source: code/trunk/upstream.go@ 250

Last change on this file since 250 was 248, checked in by contact, 5 years ago

Make newMessageLogger take a *network instead of an *upstreamConn

There's no reason why messageLogger needs access to the whole connection,
the network is enough.

File size: 32.7 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "strconv"
11 "strings"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
18type upstreamChannel struct {
19 Name string
20 conn *upstreamConn
21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
25 modes channelModes
26 creationTime string
27 Members map[string]*membership
28 complete bool
29}
30
31type upstreamConn struct {
32 conn
33
34 network *network
35 user *user
36
37 serverName string
38 availableUserModes string
39 availableChannelModes map[byte]channelModeType
40 availableChannelTypes string
41 availableMemberships []membership
42
43 registered bool
44 nick string
45 username string
46 realname string
47 modes userModes
48 channels map[string]*upstreamChannel
49 caps map[string]string
50 batches map[string]batch
51 away bool
52
53 tagsSupported bool
54 labelsSupported bool
55 nextLabelID uint64
56
57 saslClient sasl.Client
58 saslStarted bool
59
60 // set of LIST commands in progress, per downstream
61 pendingLISTDownstreamSet map[uint64]struct{}
62
63 messageLoggers map[string]*messageLogger
64}
65
66func connectToUpstream(network *network) (*upstreamConn, error) {
67 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
68
69 addr := network.Addr
70 if !strings.ContainsRune(addr, ':') {
71 addr = addr + ":6697"
72 }
73
74 dialer := net.Dialer{Timeout: connectTimeout}
75
76 logger.Printf("connecting to TLS server at address %q", addr)
77 netConn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)
78 if err != nil {
79 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
80 }
81
82 uc := &upstreamConn{
83 conn: *newConn(network.user.srv, netConn, logger),
84 network: network,
85 user: network.user,
86 channels: make(map[string]*upstreamChannel),
87 caps: make(map[string]string),
88 batches: make(map[string]batch),
89 availableChannelTypes: stdChannelTypes,
90 availableChannelModes: stdChannelModes,
91 availableMemberships: stdMemberships,
92 pendingLISTDownstreamSet: make(map[uint64]struct{}),
93 messageLoggers: make(map[string]*messageLogger),
94 }
95
96 return uc, nil
97}
98
99func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
100 uc.network.forEachDownstream(f)
101}
102
103func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
104 uc.forEachDownstream(func(dc *downstreamConn) {
105 if id != 0 && id != dc.id {
106 return
107 }
108 f(dc)
109 })
110}
111
112func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
113 ch, ok := uc.channels[name]
114 if !ok {
115 return nil, fmt.Errorf("unknown channel %q", name)
116 }
117 return ch, nil
118}
119
120func (uc *upstreamConn) isChannel(entity string) bool {
121 if i := strings.IndexByte(uc.availableChannelTypes, entity[0]); i >= 0 {
122 return true
123 }
124 return false
125}
126
127func (uc *upstreamConn) getPendingLIST() *pendingLIST {
128 for _, pl := range uc.user.pendingLISTs {
129 if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
130 continue
131 }
132 return &pl
133 }
134 return nil
135}
136
// endPendingLISTs marks this network's share of pending LIST requests as
// answered. A request whose last outstanding network was this one is
// removed from uc.user.pendingLISTs, and RPL_LISTEND is sent to the
// downstream that issued it.
//
// If all is false, processing stops after the first request that involved
// this network. That downstream is then removed from
// pendingLISTDownstreamSet, and trySendLIST is invoked on each upstream
// passed to uc.user.forEachUpstream so the downstream's next queued LIST
// command can go out.
// If all is true, every request involving this network is processed.
//
// It reports whether any pending request involved this network.
func (uc *upstreamConn) endPendingLISTs(all bool) (found bool) {
	found = false
	for i := 0; i < len(uc.user.pendingLISTs); i++ {
		// pl is a copy of the element, but pendingCommands is a map and
		// therefore shared: the delete below is visible through
		// uc.user.pendingLISTs.
		pl := uc.user.pendingLISTs[i]
		if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
			continue
		}
		delete(pl.pendingCommands, uc.network.ID)
		if len(pl.pendingCommands) == 0 {
			// No network is left to answer this request: drop it and
			// close the LIST on the downstream side. i is decremented
			// because the removal shifted the remaining elements left.
			uc.user.pendingLISTs = append(uc.user.pendingLISTs[:i], uc.user.pendingLISTs[i+1:]...)
			i--
			uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
				dc.SendMessage(&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.RPL_LISTEND,
					Params:  []string{dc.nick, "End of /LIST"},
				})
			})
		}
		found = true
		if !all {
			delete(uc.pendingLISTDownstreamSet, pl.downstreamID)
			// The closure parameter shadows uc: trySendLIST runs on each
			// upstream handed out by forEachUpstream, not only on this one.
			uc.user.forEachUpstream(func(uc *upstreamConn) {
				uc.trySendLIST(pl.downstreamID)
			})
			return
		}
	}
	return
}
167
168func (uc *upstreamConn) trySendLIST(downstreamID uint64) {
169 if _, ok := uc.pendingLISTDownstreamSet[downstreamID]; ok {
170 // a LIST command is already pending
171 // we will try again when that command is completed
172 return
173 }
174
175 for _, pl := range uc.user.pendingLISTs {
176 if pl.downstreamID != downstreamID {
177 continue
178 }
179 // this is the first pending LIST command list of the downstream
180 listCommand, ok := pl.pendingCommands[uc.network.ID]
181 if !ok {
182 // there is no command for this upstream in these LIST commands
183 // do not send anything
184 continue
185 }
186 // there is a command for this upstream in these LIST commands
187 // send it now
188
189 uc.SendMessageLabeled(downstreamID, listCommand)
190
191 uc.pendingLISTDownstreamSet[downstreamID] = struct{}{}
192 return
193 }
194}
195
196func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) {
197 for _, m := range uc.availableMemberships {
198 if m.Prefix == s[0] {
199 return &m, s[1:]
200 }
201 }
202 return nil, s
203}
204
205func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
206 var label string
207 if l, ok := msg.GetTag("label"); ok {
208 label = l
209 }
210
211 var msgBatch *batch
212 if batchName, ok := msg.GetTag("batch"); ok {
213 b, ok := uc.batches[batchName]
214 if !ok {
215 return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
216 }
217 msgBatch = &b
218 if label == "" {
219 label = msgBatch.Label
220 }
221 }
222
223 var downstreamID uint64 = 0
224 if label != "" {
225 var labelOffset uint64
226 n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
227 if err == nil && n < 2 {
228 err = errors.New("not enough arguments")
229 }
230 if err != nil {
231 return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
232 }
233 }
234
235 if _, ok := msg.Tags["time"]; !ok {
236 msg.Tags["time"] = irc.TagValue(time.Now().UTC().Format(serverTimeLayout))
237 }
238
239 switch msg.Command {
240 case "PING":
241 uc.SendMessage(&irc.Message{
242 Command: "PONG",
243 Params: msg.Params,
244 })
245 return nil
246 case "NOTICE":
247 if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
248 uc.produce("", msg, nil)
249 } else { // regular user NOTICE
250 var entity, text string
251 if err := parseMessageParams(msg, &entity, &text); err != nil {
252 return err
253 }
254
255 target := entity
256 if target == uc.nick {
257 target = msg.Prefix.Name
258 }
259 uc.produce(target, msg, nil)
260 }
261 case "CAP":
262 var subCmd string
263 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
264 return err
265 }
266 subCmd = strings.ToUpper(subCmd)
267 subParams := msg.Params[2:]
268 switch subCmd {
269 case "LS":
270 if len(subParams) < 1 {
271 return newNeedMoreParamsError(msg.Command)
272 }
273 caps := strings.Fields(subParams[len(subParams)-1])
274 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
275
276 for _, s := range caps {
277 kv := strings.SplitN(s, "=", 2)
278 k := strings.ToLower(kv[0])
279 var v string
280 if len(kv) == 2 {
281 v = kv[1]
282 }
283 uc.caps[k] = v
284 }
285
286 if more {
287 break // wait to receive all capabilities
288 }
289
290 requestCaps := make([]string, 0, 16)
291 for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time"} {
292 if _, ok := uc.caps[c]; ok {
293 requestCaps = append(requestCaps, c)
294 }
295 }
296
297 if uc.requestSASL() {
298 requestCaps = append(requestCaps, "sasl")
299 }
300
301 if len(requestCaps) > 0 {
302 uc.SendMessage(&irc.Message{
303 Command: "CAP",
304 Params: []string{"REQ", strings.Join(requestCaps, " ")},
305 })
306 }
307
308 if uc.requestSASL() {
309 break // we'll send CAP END after authentication is completed
310 }
311
312 uc.SendMessage(&irc.Message{
313 Command: "CAP",
314 Params: []string{"END"},
315 })
316 case "ACK", "NAK":
317 if len(subParams) < 1 {
318 return newNeedMoreParamsError(msg.Command)
319 }
320 caps := strings.Fields(subParams[0])
321
322 for _, name := range caps {
323 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
324 return err
325 }
326 }
327
328 if uc.saslClient == nil {
329 uc.SendMessage(&irc.Message{
330 Command: "CAP",
331 Params: []string{"END"},
332 })
333 }
334 default:
335 uc.logger.Printf("unhandled message: %v", msg)
336 }
337 case "AUTHENTICATE":
338 if uc.saslClient == nil {
339 return fmt.Errorf("received unexpected AUTHENTICATE message")
340 }
341
342 // TODO: if a challenge is 400 bytes long, buffer it
343 var challengeStr string
344 if err := parseMessageParams(msg, &challengeStr); err != nil {
345 uc.SendMessage(&irc.Message{
346 Command: "AUTHENTICATE",
347 Params: []string{"*"},
348 })
349 return err
350 }
351
352 var challenge []byte
353 if challengeStr != "+" {
354 var err error
355 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
356 if err != nil {
357 uc.SendMessage(&irc.Message{
358 Command: "AUTHENTICATE",
359 Params: []string{"*"},
360 })
361 return err
362 }
363 }
364
365 var resp []byte
366 var err error
367 if !uc.saslStarted {
368 _, resp, err = uc.saslClient.Start()
369 uc.saslStarted = true
370 } else {
371 resp, err = uc.saslClient.Next(challenge)
372 }
373 if err != nil {
374 uc.SendMessage(&irc.Message{
375 Command: "AUTHENTICATE",
376 Params: []string{"*"},
377 })
378 return err
379 }
380
381 // TODO: send response in multiple chunks if >= 400 bytes
382 var respStr = "+"
383 if resp != nil {
384 respStr = base64.StdEncoding.EncodeToString(resp)
385 }
386
387 uc.SendMessage(&irc.Message{
388 Command: "AUTHENTICATE",
389 Params: []string{respStr},
390 })
391 case irc.RPL_LOGGEDIN:
392 var account string
393 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
394 return err
395 }
396 uc.logger.Printf("logged in with account %q", account)
397 case irc.RPL_LOGGEDOUT:
398 uc.logger.Printf("logged out")
399 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
400 var info string
401 if err := parseMessageParams(msg, nil, &info); err != nil {
402 return err
403 }
404 switch msg.Command {
405 case irc.ERR_NICKLOCKED:
406 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
407 case irc.ERR_SASLFAIL:
408 uc.logger.Printf("SASL authentication failed: %v", info)
409 case irc.ERR_SASLTOOLONG:
410 uc.logger.Printf("SASL message too long: %v", info)
411 }
412
413 uc.saslClient = nil
414 uc.saslStarted = false
415
416 uc.SendMessage(&irc.Message{
417 Command: "CAP",
418 Params: []string{"END"},
419 })
420 case irc.RPL_WELCOME:
421 uc.registered = true
422 uc.logger.Printf("connection registered")
423
424 channels, err := uc.srv.db.ListChannels(uc.network.ID)
425 if err != nil {
426 uc.logger.Printf("failed to list channels from database: %v", err)
427 break
428 }
429
430 for _, ch := range channels {
431 params := []string{ch.Name}
432 if ch.Key != "" {
433 params = append(params, ch.Key)
434 }
435 uc.SendMessage(&irc.Message{
436 Command: "JOIN",
437 Params: params,
438 })
439 }
440 case irc.RPL_MYINFO:
441 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
442 return err
443 }
444 case irc.RPL_ISUPPORT:
445 if err := parseMessageParams(msg, nil, nil); err != nil {
446 return err
447 }
448 for _, token := range msg.Params[1 : len(msg.Params)-1] {
449 negate := false
450 parameter := token
451 value := ""
452 if strings.HasPrefix(token, "-") {
453 negate = true
454 token = token[1:]
455 } else {
456 if i := strings.IndexByte(token, '='); i >= 0 {
457 parameter = token[:i]
458 value = token[i+1:]
459 }
460 }
461 if !negate {
462 switch parameter {
463 case "CHANMODES":
464 parts := strings.SplitN(value, ",", 5)
465 if len(parts) < 4 {
466 return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", value)
467 }
468 modes := make(map[byte]channelModeType)
469 for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
470 for j := 0; j < len(parts[i]); j++ {
471 mode := parts[i][j]
472 modes[mode] = mt
473 }
474 }
475 uc.availableChannelModes = modes
476 case "CHANTYPES":
477 uc.availableChannelTypes = value
478 case "PREFIX":
479 if value == "" {
480 uc.availableMemberships = nil
481 } else {
482 if value[0] != '(' {
483 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
484 }
485 sep := strings.IndexByte(value, ')')
486 if sep < 0 || len(value) != sep*2 {
487 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
488 }
489 memberships := make([]membership, len(value)/2-1)
490 for i := range memberships {
491 memberships[i] = membership{
492 Mode: value[i+1],
493 Prefix: value[sep+i+1],
494 }
495 }
496 uc.availableMemberships = memberships
497 }
498 }
499 } else {
500 // TODO: handle ISUPPORT negations
501 }
502 }
503 case "BATCH":
504 var tag string
505 if err := parseMessageParams(msg, &tag); err != nil {
506 return err
507 }
508
509 if strings.HasPrefix(tag, "+") {
510 tag = tag[1:]
511 if _, ok := uc.batches[tag]; ok {
512 return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
513 }
514 var batchType string
515 if err := parseMessageParams(msg, nil, &batchType); err != nil {
516 return err
517 }
518 label := label
519 if label == "" && msgBatch != nil {
520 label = msgBatch.Label
521 }
522 uc.batches[tag] = batch{
523 Type: batchType,
524 Params: msg.Params[2:],
525 Outer: msgBatch,
526 Label: label,
527 }
528 } else if strings.HasPrefix(tag, "-") {
529 tag = tag[1:]
530 if _, ok := uc.batches[tag]; !ok {
531 return fmt.Errorf("unknown BATCH reference tag: %q", tag)
532 }
533 delete(uc.batches, tag)
534 } else {
535 return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
536 }
537 case "NICK":
538 if msg.Prefix == nil {
539 return fmt.Errorf("expected a prefix")
540 }
541
542 var newNick string
543 if err := parseMessageParams(msg, &newNick); err != nil {
544 return err
545 }
546
547 me := false
548 if msg.Prefix.Name == uc.nick {
549 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
550 me = true
551 uc.nick = newNick
552 }
553
554 for _, ch := range uc.channels {
555 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
556 delete(ch.Members, msg.Prefix.Name)
557 ch.Members[newNick] = membership
558 uc.appendLog(ch.Name, msg)
559 }
560 }
561
562 if !me {
563 uc.network.ring.Produce(msg)
564 uc.forEachDownstream(func(dc *downstreamConn) {
565 dc.SendMessage(dc.marshalMessage(msg, uc))
566 })
567 }
568 case "JOIN":
569 if msg.Prefix == nil {
570 return fmt.Errorf("expected a prefix")
571 }
572
573 var channels string
574 if err := parseMessageParams(msg, &channels); err != nil {
575 return err
576 }
577
578 for _, ch := range strings.Split(channels, ",") {
579 if msg.Prefix.Name == uc.nick {
580 uc.logger.Printf("joined channel %q", ch)
581 uc.channels[ch] = &upstreamChannel{
582 Name: ch,
583 conn: uc,
584 Members: make(map[string]*membership),
585 }
586
587 uc.SendMessage(&irc.Message{
588 Command: "MODE",
589 Params: []string{ch},
590 })
591 } else {
592 ch, err := uc.getChannel(ch)
593 if err != nil {
594 return err
595 }
596 ch.Members[msg.Prefix.Name] = nil
597 }
598
599 chMsg := msg.Copy()
600 chMsg.Params[0] = ch
601 uc.produce(ch, chMsg, nil)
602 }
603 case "PART":
604 if msg.Prefix == nil {
605 return fmt.Errorf("expected a prefix")
606 }
607
608 var channels string
609 if err := parseMessageParams(msg, &channels); err != nil {
610 return err
611 }
612
613 for _, ch := range strings.Split(channels, ",") {
614 if msg.Prefix.Name == uc.nick {
615 uc.logger.Printf("parted channel %q", ch)
616 delete(uc.channels, ch)
617 } else {
618 ch, err := uc.getChannel(ch)
619 if err != nil {
620 return err
621 }
622 delete(ch.Members, msg.Prefix.Name)
623 }
624
625 chMsg := msg.Copy()
626 chMsg.Params[0] = ch
627 uc.produce(ch, chMsg, nil)
628 }
629 case "KICK":
630 if msg.Prefix == nil {
631 return fmt.Errorf("expected a prefix")
632 }
633
634 var channel, user string
635 if err := parseMessageParams(msg, &channel, &user); err != nil {
636 return err
637 }
638
639 if user == uc.nick {
640 uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
641 delete(uc.channels, channel)
642 } else {
643 ch, err := uc.getChannel(channel)
644 if err != nil {
645 return err
646 }
647 delete(ch.Members, user)
648 }
649
650 uc.produce(channel, msg, nil)
651 case "QUIT":
652 if msg.Prefix == nil {
653 return fmt.Errorf("expected a prefix")
654 }
655
656 if msg.Prefix.Name == uc.nick {
657 uc.logger.Printf("quit")
658 }
659
660 for _, ch := range uc.channels {
661 if _, ok := ch.Members[msg.Prefix.Name]; ok {
662 delete(ch.Members, msg.Prefix.Name)
663
664 uc.appendLog(ch.Name, msg)
665 }
666 }
667
668 if msg.Prefix.Name != uc.nick {
669 uc.network.ring.Produce(msg)
670 uc.forEachDownstream(func(dc *downstreamConn) {
671 dc.SendMessage(dc.marshalMessage(msg, uc))
672 })
673 }
674 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
675 var name, topic string
676 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
677 return err
678 }
679 ch, err := uc.getChannel(name)
680 if err != nil {
681 return err
682 }
683 if msg.Command == irc.RPL_TOPIC {
684 ch.Topic = topic
685 } else {
686 ch.Topic = ""
687 }
688 case "TOPIC":
689 var name string
690 if err := parseMessageParams(msg, &name); err != nil {
691 return err
692 }
693 ch, err := uc.getChannel(name)
694 if err != nil {
695 return err
696 }
697 if len(msg.Params) > 1 {
698 ch.Topic = msg.Params[1]
699 } else {
700 ch.Topic = ""
701 }
702 uc.produce(ch.Name, msg, nil)
703 case "MODE":
704 var name, modeStr string
705 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
706 return err
707 }
708
709 if !uc.isChannel(name) { // user mode change
710 if name != uc.nick {
711 return fmt.Errorf("received MODE message for unknown nick %q", name)
712 }
713 return uc.modes.Apply(modeStr)
714 // TODO: notify downstreams about user mode change?
715 } else { // channel mode change
716 ch, err := uc.getChannel(name)
717 if err != nil {
718 return err
719 }
720
721 if ch.modes != nil {
722 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil {
723 return err
724 }
725 }
726
727 uc.produce(ch.Name, msg, nil)
728 }
729 case irc.RPL_UMODEIS:
730 if err := parseMessageParams(msg, nil); err != nil {
731 return err
732 }
733 modeStr := ""
734 if len(msg.Params) > 1 {
735 modeStr = msg.Params[1]
736 }
737
738 uc.modes = ""
739 if err := uc.modes.Apply(modeStr); err != nil {
740 return err
741 }
742 // TODO: send RPL_UMODEIS to downstream connections when applicable
743 case irc.RPL_CHANNELMODEIS:
744 var channel string
745 if err := parseMessageParams(msg, nil, &channel); err != nil {
746 return err
747 }
748 modeStr := ""
749 if len(msg.Params) > 2 {
750 modeStr = msg.Params[2]
751 }
752
753 ch, err := uc.getChannel(channel)
754 if err != nil {
755 return err
756 }
757
758 firstMode := ch.modes == nil
759 ch.modes = make(map[byte]string)
760 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil {
761 return err
762 }
763 if firstMode {
764 modeStr, modeParams := ch.modes.Format()
765
766 uc.forEachDownstream(func(dc *downstreamConn) {
767 params := []string{dc.nick, dc.marshalChannel(uc, channel), modeStr}
768 params = append(params, modeParams...)
769
770 dc.SendMessage(&irc.Message{
771 Prefix: dc.srv.prefix(),
772 Command: irc.RPL_CHANNELMODEIS,
773 Params: params,
774 })
775 })
776 }
777 case rpl_creationtime:
778 var channel, creationTime string
779 if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
780 return err
781 }
782
783 ch, err := uc.getChannel(channel)
784 if err != nil {
785 return err
786 }
787
788 firstCreationTime := ch.creationTime == ""
789 ch.creationTime = creationTime
790 if firstCreationTime {
791 uc.forEachDownstream(func(dc *downstreamConn) {
792 dc.SendMessage(&irc.Message{
793 Prefix: dc.srv.prefix(),
794 Command: rpl_creationtime,
795 Params: []string{dc.nick, channel, creationTime},
796 })
797 })
798 }
799 case rpl_topicwhotime:
800 var name, who, timeStr string
801 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
802 return err
803 }
804 ch, err := uc.getChannel(name)
805 if err != nil {
806 return err
807 }
808 ch.TopicWho = who
809 sec, err := strconv.ParseInt(timeStr, 10, 64)
810 if err != nil {
811 return fmt.Errorf("failed to parse topic time: %v", err)
812 }
813 ch.TopicTime = time.Unix(sec, 0)
814 case irc.RPL_LIST:
815 var channel, clients, topic string
816 if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
817 return err
818 }
819
820 pl := uc.getPendingLIST()
821 if pl == nil {
822 return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
823 }
824
825 uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
826 dc.SendMessage(&irc.Message{
827 Prefix: dc.srv.prefix(),
828 Command: irc.RPL_LIST,
829 Params: []string{dc.nick, dc.marshalChannel(uc, channel), clients, topic},
830 })
831 })
832 case irc.RPL_LISTEND:
833 ok := uc.endPendingLISTs(false)
834 if !ok {
835 return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
836 }
837 case irc.RPL_NAMREPLY:
838 var name, statusStr, members string
839 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
840 return err
841 }
842
843 ch, ok := uc.channels[name]
844 if !ok {
845 // NAMES on a channel we have not joined, forward to downstream
846 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
847 channel := dc.marshalChannel(uc, name)
848 members := splitSpace(members)
849 for i, member := range members {
850 membership, nick := uc.parseMembershipPrefix(member)
851 members[i] = membership.String() + dc.marshalNick(uc, nick)
852 }
853 memberStr := strings.Join(members, " ")
854
855 dc.SendMessage(&irc.Message{
856 Prefix: dc.srv.prefix(),
857 Command: irc.RPL_NAMREPLY,
858 Params: []string{dc.nick, statusStr, channel, memberStr},
859 })
860 })
861 return nil
862 }
863
864 status, err := parseChannelStatus(statusStr)
865 if err != nil {
866 return err
867 }
868 ch.Status = status
869
870 for _, s := range splitSpace(members) {
871 membership, nick := uc.parseMembershipPrefix(s)
872 ch.Members[nick] = membership
873 }
874 case irc.RPL_ENDOFNAMES:
875 var name string
876 if err := parseMessageParams(msg, nil, &name); err != nil {
877 return err
878 }
879
880 ch, ok := uc.channels[name]
881 if !ok {
882 // NAMES on a channel we have not joined, forward to downstream
883 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
884 channel := dc.marshalChannel(uc, name)
885
886 dc.SendMessage(&irc.Message{
887 Prefix: dc.srv.prefix(),
888 Command: irc.RPL_ENDOFNAMES,
889 Params: []string{dc.nick, channel, "End of /NAMES list"},
890 })
891 })
892 return nil
893 }
894
895 if ch.complete {
896 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
897 }
898 ch.complete = true
899
900 uc.forEachDownstream(func(dc *downstreamConn) {
901 forwardChannel(dc, ch)
902 })
903 case irc.RPL_WHOREPLY:
904 var channel, username, host, server, nick, mode, trailing string
905 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
906 return err
907 }
908
909 parts := strings.SplitN(trailing, " ", 2)
910 if len(parts) != 2 {
911 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
912 }
913 realname := parts[1]
914 hops, err := strconv.Atoi(parts[0])
915 if err != nil {
916 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
917 }
918 hops++
919
920 trailing = strconv.Itoa(hops) + " " + realname
921
922 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
923 channel := channel
924 if channel != "*" {
925 channel = dc.marshalChannel(uc, channel)
926 }
927 nick := dc.marshalNick(uc, nick)
928 dc.SendMessage(&irc.Message{
929 Prefix: dc.srv.prefix(),
930 Command: irc.RPL_WHOREPLY,
931 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
932 })
933 })
934 case irc.RPL_ENDOFWHO:
935 var name string
936 if err := parseMessageParams(msg, nil, &name); err != nil {
937 return err
938 }
939
940 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
941 name := name
942 if name != "*" {
943 // TODO: support WHO masks
944 name = dc.marshalEntity(uc, name)
945 }
946 dc.SendMessage(&irc.Message{
947 Prefix: dc.srv.prefix(),
948 Command: irc.RPL_ENDOFWHO,
949 Params: []string{dc.nick, name, "End of /WHO list"},
950 })
951 })
952 case irc.RPL_WHOISUSER:
953 var nick, username, host, realname string
954 if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
955 return err
956 }
957
958 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
959 nick := dc.marshalNick(uc, nick)
960 dc.SendMessage(&irc.Message{
961 Prefix: dc.srv.prefix(),
962 Command: irc.RPL_WHOISUSER,
963 Params: []string{dc.nick, nick, username, host, "*", realname},
964 })
965 })
966 case irc.RPL_WHOISSERVER:
967 var nick, server, serverInfo string
968 if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
969 return err
970 }
971
972 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
973 nick := dc.marshalNick(uc, nick)
974 dc.SendMessage(&irc.Message{
975 Prefix: dc.srv.prefix(),
976 Command: irc.RPL_WHOISSERVER,
977 Params: []string{dc.nick, nick, server, serverInfo},
978 })
979 })
980 case irc.RPL_WHOISOPERATOR:
981 var nick string
982 if err := parseMessageParams(msg, nil, &nick); err != nil {
983 return err
984 }
985
986 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
987 nick := dc.marshalNick(uc, nick)
988 dc.SendMessage(&irc.Message{
989 Prefix: dc.srv.prefix(),
990 Command: irc.RPL_WHOISOPERATOR,
991 Params: []string{dc.nick, nick, "is an IRC operator"},
992 })
993 })
994 case irc.RPL_WHOISIDLE:
995 var nick string
996 if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
997 return err
998 }
999
1000 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1001 nick := dc.marshalNick(uc, nick)
1002 params := []string{dc.nick, nick}
1003 params = append(params, msg.Params[2:]...)
1004 dc.SendMessage(&irc.Message{
1005 Prefix: dc.srv.prefix(),
1006 Command: irc.RPL_WHOISIDLE,
1007 Params: params,
1008 })
1009 })
1010 case irc.RPL_WHOISCHANNELS:
1011 var nick, channelList string
1012 if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
1013 return err
1014 }
1015 channels := splitSpace(channelList)
1016
1017 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1018 nick := dc.marshalNick(uc, nick)
1019 channelList := make([]string, len(channels))
1020 for i, channel := range channels {
1021 prefix, channel := uc.parseMembershipPrefix(channel)
1022 channel = dc.marshalChannel(uc, channel)
1023 channelList[i] = prefix.String() + channel
1024 }
1025 channels := strings.Join(channelList, " ")
1026 dc.SendMessage(&irc.Message{
1027 Prefix: dc.srv.prefix(),
1028 Command: irc.RPL_WHOISCHANNELS,
1029 Params: []string{dc.nick, nick, channels},
1030 })
1031 })
1032 case irc.RPL_ENDOFWHOIS:
1033 var nick string
1034 if err := parseMessageParams(msg, nil, &nick); err != nil {
1035 return err
1036 }
1037
1038 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1039 nick := dc.marshalNick(uc, nick)
1040 dc.SendMessage(&irc.Message{
1041 Prefix: dc.srv.prefix(),
1042 Command: irc.RPL_ENDOFWHOIS,
1043 Params: []string{dc.nick, nick, "End of /WHOIS list"},
1044 })
1045 })
1046 case "PRIVMSG":
1047 if msg.Prefix == nil {
1048 return fmt.Errorf("expected a prefix")
1049 }
1050
1051 var entity, text string
1052 if err := parseMessageParams(msg, &entity, &text); err != nil {
1053 return err
1054 }
1055
1056 if msg.Prefix.Name == serviceNick {
1057 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
1058 break
1059 }
1060 if entity == serviceNick {
1061 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
1062 break
1063 }
1064
1065 target := entity
1066 if target == uc.nick {
1067 target = msg.Prefix.Name
1068 }
1069 uc.produce(target, msg, nil)
1070 case "INVITE":
1071 var nick string
1072 var channel string
1073 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1074 return err
1075 }
1076
1077 uc.forEachDownstream(func(dc *downstreamConn) {
1078 dc.SendMessage(&irc.Message{
1079 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
1080 Command: "INVITE",
1081 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1082 })
1083 })
1084 case irc.RPL_INVITING:
1085 var nick string
1086 var channel string
1087 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1088 return err
1089 }
1090
1091 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1092 dc.SendMessage(&irc.Message{
1093 Prefix: dc.srv.prefix(),
1094 Command: irc.RPL_INVITING,
1095 Params: []string{dc.nick, dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1096 })
1097 })
1098 case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
1099 var command, reason string
1100 if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
1101 return err
1102 }
1103
1104 if command == "LIST" {
1105 ok := uc.endPendingLISTs(false)
1106 if !ok {
1107 return fmt.Errorf("unexpected response for LIST: %q: no matching pending LIST", msg.Command)
1108 }
1109 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1110 dc.SendMessage(&irc.Message{
1111 Prefix: uc.srv.prefix(),
1112 Command: msg.Command,
1113 Params: []string{dc.nick, "LIST", reason},
1114 })
1115 })
1116 }
1117 case "TAGMSG":
1118 // TODO: relay to downstream connections that accept message-tags
1119 case "ACK":
1120 // Ignore
1121 case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
1122 // Ignore
1123 case irc.RPL_YOURHOST, irc.RPL_CREATED:
1124 // Ignore
1125 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
1126 // Ignore
1127 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
1128 // Ignore
1129 case irc.RPL_LISTSTART:
1130 // Ignore
1131 case rpl_localusers, rpl_globalusers:
1132 // Ignore
1133 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
1134 // Ignore
1135 default:
1136 uc.logger.Printf("unhandled message: %v", msg)
1137 }
1138 return nil
1139}
1140
// splitSpace splits s on ASCII space characters only, dropping empty fields
// produced by leading, trailing or repeated spaces. Other whitespace (such
// as tabs) is kept inside fields. The result is never nil.
func splitSpace(s string) []string {
	fields := []string{}
	for _, part := range strings.Split(s, " ") {
		if part != "" {
			fields = append(fields, part)
		}
	}
	return fields
}
1146
1147func (uc *upstreamConn) register() {
1148 uc.nick = uc.network.Nick
1149 uc.username = uc.network.Username
1150 if uc.username == "" {
1151 uc.username = uc.nick
1152 }
1153 uc.realname = uc.network.Realname
1154 if uc.realname == "" {
1155 uc.realname = uc.nick
1156 }
1157
1158 uc.SendMessage(&irc.Message{
1159 Command: "CAP",
1160 Params: []string{"LS", "302"},
1161 })
1162
1163 if uc.network.Pass != "" {
1164 uc.SendMessage(&irc.Message{
1165 Command: "PASS",
1166 Params: []string{uc.network.Pass},
1167 })
1168 }
1169
1170 uc.SendMessage(&irc.Message{
1171 Command: "NICK",
1172 Params: []string{uc.nick},
1173 })
1174 uc.SendMessage(&irc.Message{
1175 Command: "USER",
1176 Params: []string{uc.username, "0", "*", uc.realname},
1177 })
1178}
1179
1180func (uc *upstreamConn) runUntilRegistered() error {
1181 for !uc.registered {
1182 msg, err := uc.ReadMessage()
1183 if err != nil {
1184 return fmt.Errorf("failed to read message: %v", err)
1185 }
1186
1187 if err := uc.handleMessage(msg); err != nil {
1188 return fmt.Errorf("failed to handle message %q: %v", msg, err)
1189 }
1190 }
1191
1192 return nil
1193}
1194
1195func (uc *upstreamConn) requestSASL() bool {
1196 if uc.network.SASL.Mechanism == "" {
1197 return false
1198 }
1199
1200 v, ok := uc.caps["sasl"]
1201 if !ok {
1202 return false
1203 }
1204 if v != "" {
1205 mechanisms := strings.Split(v, ",")
1206 found := false
1207 for _, mech := range mechanisms {
1208 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
1209 found = true
1210 break
1211 }
1212 }
1213 if !found {
1214 return false
1215 }
1216 }
1217
1218 return true
1219}
1220
1221func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
1222 auth := &uc.network.SASL
1223 switch name {
1224 case "sasl":
1225 if !ok {
1226 uc.logger.Printf("server refused to acknowledge the SASL capability")
1227 return nil
1228 }
1229
1230 switch auth.Mechanism {
1231 case "PLAIN":
1232 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
1233 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
1234 default:
1235 return fmt.Errorf("unsupported SASL mechanism %q", name)
1236 }
1237
1238 uc.SendMessage(&irc.Message{
1239 Command: "AUTHENTICATE",
1240 Params: []string{auth.Mechanism},
1241 })
1242 case "message-tags":
1243 uc.tagsSupported = ok
1244 case "labeled-response":
1245 uc.labelsSupported = ok
1246 case "batch", "server-time":
1247 // Nothing to do
1248 default:
1249 uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
1250 }
1251 return nil
1252}
1253
1254func (uc *upstreamConn) readMessages(ch chan<- event) error {
1255 for {
1256 msg, err := uc.ReadMessage()
1257 if err == io.EOF {
1258 break
1259 } else if err != nil {
1260 return fmt.Errorf("failed to read IRC command: %v", err)
1261 }
1262
1263 ch <- eventUpstreamMessage{msg, uc}
1264 }
1265
1266 return nil
1267}
1268
1269func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
1270 if uc.labelsSupported {
1271 if msg.Tags == nil {
1272 msg.Tags = make(map[string]irc.TagValue)
1273 }
1274 msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
1275 uc.nextLabelID++
1276 }
1277 uc.SendMessage(msg)
1278}
1279
1280// TODO: handle moving logs when a network name changes, when support for this is added
1281func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
1282 if uc.srv.LogPath == "" {
1283 return
1284 }
1285
1286 ml, ok := uc.messageLoggers[entity]
1287 if !ok {
1288 ml = newMessageLogger(uc.network, entity)
1289 uc.messageLoggers[entity] = ml
1290 }
1291
1292 if err := ml.Append(msg); err != nil {
1293 uc.logger.Printf("failed to log message: %v", err)
1294 }
1295}
1296
1297// produce appends a message to the logs, adds it to the history and forwards
1298// it to connected downstream connections.
1299//
1300// If origin is not nil and origin doesn't support echo-message, the message is
1301// forwarded to all connections except origin.
1302func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) {
1303 if target != "" {
1304 uc.appendLog(target, msg)
1305 }
1306
1307 uc.network.ring.Produce(msg)
1308
1309 uc.forEachDownstream(func(dc *downstreamConn) {
1310 if dc != origin || dc.caps["echo-message"] {
1311 dc.SendMessage(dc.marshalMessage(msg, uc))
1312 }
1313 })
1314}
1315
1316func (uc *upstreamConn) updateAway() {
1317 away := true
1318 uc.forEachDownstream(func(*downstreamConn) {
1319 away = false
1320 })
1321 if away == uc.away {
1322 return
1323 }
1324 if away {
1325 uc.SendMessage(&irc.Message{
1326 Command: "AWAY",
1327 Params: []string{"Auto away"},
1328 })
1329 } else {
1330 uc.SendMessage(&irc.Message{
1331 Command: "AWAY",
1332 })
1333 }
1334 uc.away = away
1335}
Note: See TracBrowser for help on using the repository browser.