source: code/trunk/upstream.go@ 216

Last change on this file since 216 was 216, checked in by contact, 5 years ago

Add time tag to all messages

File size: 34.3 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "os"
11 "strconv"
12 "strings"
13 "time"
14
15 "github.com/emersion/go-sasl"
16 "gopkg.in/irc.v3"
17)
18
19type upstreamChannel struct {
20 Name string
21 conn *upstreamConn
22 Topic string
23 TopicWho string
24 TopicTime time.Time
25 Status channelStatus
26 modes channelModes
27 creationTime string
28 Members map[string]*membership
29 complete bool
30}
31
32type upstreamConn struct {
33 conn
34
35 network *network
36 user *user
37
38 serverName string
39 availableUserModes string
40 availableChannelModes map[byte]channelModeType
41 availableChannelTypes string
42 availableMemberships []membership
43
44 registered bool
45 nick string
46 username string
47 realname string
48 modes userModes
49 channels map[string]*upstreamChannel
50 caps map[string]string
51 batches map[string]batch
52 away bool
53
54 tagsSupported bool
55 labelsSupported bool
56 nextLabelID uint64
57
58 saslClient sasl.Client
59 saslStarted bool
60
61 // set of LIST commands in progress, per downstream
62 pendingLISTDownstreamSet map[uint64]struct{}
63
64 messageLoggers map[string]*messageLogger
65}
66
// entityLog pairs an entity name with a file handle.
// NOTE(review): not referenced in the visible part of this file; it may be
// unused — confirm before relying on it.
type entityLog struct {
	name string   // entity name
	file *os.File // associated file handle
}
71
72func connectToUpstream(network *network) (*upstreamConn, error) {
73 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
74
75 addr := network.Addr
76 if !strings.ContainsRune(addr, ':') {
77 addr = addr + ":6697"
78 }
79
80 dialer := net.Dialer{Timeout: connectTimeout}
81
82 logger.Printf("connecting to TLS server at address %q", addr)
83 netConn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)
84 if err != nil {
85 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
86 }
87
88 uc := &upstreamConn{
89 conn: *newConn(network.user.srv, netConn, logger),
90 network: network,
91 user: network.user,
92 channels: make(map[string]*upstreamChannel),
93 caps: make(map[string]string),
94 batches: make(map[string]batch),
95 availableChannelTypes: stdChannelTypes,
96 availableChannelModes: stdChannelModes,
97 availableMemberships: stdMemberships,
98 pendingLISTDownstreamSet: make(map[uint64]struct{}),
99 messageLoggers: make(map[string]*messageLogger),
100 }
101
102 return uc, nil
103}
104
105func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
106 uc.user.forEachDownstream(func(dc *downstreamConn) {
107 if dc.network != nil && dc.network != uc.network {
108 return
109 }
110 f(dc)
111 })
112}
113
114func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
115 uc.forEachDownstream(func(dc *downstreamConn) {
116 if id != 0 && id != dc.id {
117 return
118 }
119 f(dc)
120 })
121}
122
123func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
124 ch, ok := uc.channels[name]
125 if !ok {
126 return nil, fmt.Errorf("unknown channel %q", name)
127 }
128 return ch, nil
129}
130
131func (uc *upstreamConn) isChannel(entity string) bool {
132 if i := strings.IndexByte(uc.availableChannelTypes, entity[0]); i >= 0 {
133 return true
134 }
135 return false
136}
137
138func (uc *upstreamConn) getPendingLIST() *pendingLIST {
139 for _, pl := range uc.user.pendingLISTs {
140 if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
141 continue
142 }
143 return &pl
144 }
145 return nil
146}
147
148func (uc *upstreamConn) endPendingLISTs(all bool) (found bool) {
149 found = false
150 for i := 0; i < len(uc.user.pendingLISTs); i++ {
151 pl := uc.user.pendingLISTs[i]
152 if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
153 continue
154 }
155 delete(pl.pendingCommands, uc.network.ID)
156 if len(pl.pendingCommands) == 0 {
157 uc.user.pendingLISTs = append(uc.user.pendingLISTs[:i], uc.user.pendingLISTs[i+1:]...)
158 i--
159 uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
160 dc.SendMessage(&irc.Message{
161 Prefix: dc.srv.prefix(),
162 Command: irc.RPL_LISTEND,
163 Params: []string{dc.nick, "End of /LIST"},
164 })
165 })
166 }
167 found = true
168 if !all {
169 delete(uc.pendingLISTDownstreamSet, pl.downstreamID)
170 uc.user.forEachUpstream(func(uc *upstreamConn) {
171 uc.trySendLIST(pl.downstreamID)
172 })
173 return
174 }
175 }
176 return
177}
178
179func (uc *upstreamConn) trySendLIST(downstreamID uint64) {
180 // must be called with a lock in uc.user.pendingLISTsLock
181
182 if _, ok := uc.pendingLISTDownstreamSet[downstreamID]; ok {
183 // a LIST command is already pending
184 // we will try again when that command is completed
185 return
186 }
187
188 for _, pl := range uc.user.pendingLISTs {
189 if pl.downstreamID != downstreamID {
190 continue
191 }
192 // this is the first pending LIST command list of the downstream
193 listCommand, ok := pl.pendingCommands[uc.network.ID]
194 if !ok {
195 // there is no command for this upstream in these LIST commands
196 // do not send anything
197 continue
198 }
199 // there is a command for this upstream in these LIST commands
200 // send it now
201
202 uc.SendMessageLabeled(downstreamID, listCommand)
203
204 uc.pendingLISTDownstreamSet[downstreamID] = struct{}{}
205 return
206 }
207}
208
209func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) {
210 for _, m := range uc.availableMemberships {
211 if m.Prefix == s[0] {
212 return &m, s[1:]
213 }
214 }
215 return nil, s
216}
217
218func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
219 var label string
220 if l, ok := msg.GetTag("label"); ok {
221 label = l
222 }
223
224 var msgBatch *batch
225 if batchName, ok := msg.GetTag("batch"); ok {
226 b, ok := uc.batches[batchName]
227 if !ok {
228 return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
229 }
230 msgBatch = &b
231 if label == "" {
232 label = msgBatch.Label
233 }
234 }
235
236 var downstreamID uint64 = 0
237 if label != "" {
238 var labelOffset uint64
239 n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
240 if err == nil && n < 2 {
241 err = errors.New("not enough arguments")
242 }
243 if err != nil {
244 return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
245 }
246 }
247
248 if _, ok := msg.Tags["time"]; !ok {
249 msg.Tags["time"] = irc.TagValue(time.Now().Format(serverTimeLayout))
250 }
251
252 switch msg.Command {
253 case "PING":
254 uc.SendMessage(&irc.Message{
255 Command: "PONG",
256 Params: msg.Params,
257 })
258 return nil
259 case "NOTICE":
260 uc.logger.Print(msg)
261
262 if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
263 uc.forEachDownstream(func(dc *downstreamConn) {
264 dc.SendMessage(&irc.Message{
265 Command: "NOTICE",
266 Params: msg.Params,
267 })
268 })
269 } else { // regular user NOTICE
270 var nick, text string
271 if err := parseMessageParams(msg, &nick, &text); err != nil {
272 return err
273 }
274
275 target := nick
276 if nick == uc.nick {
277 target = msg.Prefix.Name
278 }
279 uc.appendLog(target, msg)
280
281 uc.forEachDownstream(func(dc *downstreamConn) {
282 dc.SendMessage(&irc.Message{
283 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
284 Command: "NOTICE",
285 Params: []string{dc.marshalEntity(uc, nick), text},
286 })
287 })
288 }
289 case "CAP":
290 var subCmd string
291 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
292 return err
293 }
294 subCmd = strings.ToUpper(subCmd)
295 subParams := msg.Params[2:]
296 switch subCmd {
297 case "LS":
298 if len(subParams) < 1 {
299 return newNeedMoreParamsError(msg.Command)
300 }
301 caps := strings.Fields(subParams[len(subParams)-1])
302 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
303
304 for _, s := range caps {
305 kv := strings.SplitN(s, "=", 2)
306 k := strings.ToLower(kv[0])
307 var v string
308 if len(kv) == 2 {
309 v = kv[1]
310 }
311 uc.caps[k] = v
312 }
313
314 if more {
315 break // wait to receive all capabilities
316 }
317
318 requestCaps := make([]string, 0, 16)
319 for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time"} {
320 if _, ok := uc.caps[c]; ok {
321 requestCaps = append(requestCaps, c)
322 }
323 }
324
325 if uc.requestSASL() {
326 requestCaps = append(requestCaps, "sasl")
327 }
328
329 if len(requestCaps) > 0 {
330 uc.SendMessage(&irc.Message{
331 Command: "CAP",
332 Params: []string{"REQ", strings.Join(requestCaps, " ")},
333 })
334 }
335
336 if uc.requestSASL() {
337 break // we'll send CAP END after authentication is completed
338 }
339
340 uc.SendMessage(&irc.Message{
341 Command: "CAP",
342 Params: []string{"END"},
343 })
344 case "ACK", "NAK":
345 if len(subParams) < 1 {
346 return newNeedMoreParamsError(msg.Command)
347 }
348 caps := strings.Fields(subParams[0])
349
350 for _, name := range caps {
351 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
352 return err
353 }
354 }
355
356 if uc.saslClient == nil {
357 uc.SendMessage(&irc.Message{
358 Command: "CAP",
359 Params: []string{"END"},
360 })
361 }
362 default:
363 uc.logger.Printf("unhandled message: %v", msg)
364 }
365 case "AUTHENTICATE":
366 if uc.saslClient == nil {
367 return fmt.Errorf("received unexpected AUTHENTICATE message")
368 }
369
370 // TODO: if a challenge is 400 bytes long, buffer it
371 var challengeStr string
372 if err := parseMessageParams(msg, &challengeStr); err != nil {
373 uc.SendMessage(&irc.Message{
374 Command: "AUTHENTICATE",
375 Params: []string{"*"},
376 })
377 return err
378 }
379
380 var challenge []byte
381 if challengeStr != "+" {
382 var err error
383 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
384 if err != nil {
385 uc.SendMessage(&irc.Message{
386 Command: "AUTHENTICATE",
387 Params: []string{"*"},
388 })
389 return err
390 }
391 }
392
393 var resp []byte
394 var err error
395 if !uc.saslStarted {
396 _, resp, err = uc.saslClient.Start()
397 uc.saslStarted = true
398 } else {
399 resp, err = uc.saslClient.Next(challenge)
400 }
401 if err != nil {
402 uc.SendMessage(&irc.Message{
403 Command: "AUTHENTICATE",
404 Params: []string{"*"},
405 })
406 return err
407 }
408
409 // TODO: send response in multiple chunks if >= 400 bytes
410 var respStr = "+"
411 if resp != nil {
412 respStr = base64.StdEncoding.EncodeToString(resp)
413 }
414
415 uc.SendMessage(&irc.Message{
416 Command: "AUTHENTICATE",
417 Params: []string{respStr},
418 })
419 case irc.RPL_LOGGEDIN:
420 var account string
421 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
422 return err
423 }
424 uc.logger.Printf("logged in with account %q", account)
425 case irc.RPL_LOGGEDOUT:
426 uc.logger.Printf("logged out")
427 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
428 var info string
429 if err := parseMessageParams(msg, nil, &info); err != nil {
430 return err
431 }
432 switch msg.Command {
433 case irc.ERR_NICKLOCKED:
434 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
435 case irc.ERR_SASLFAIL:
436 uc.logger.Printf("SASL authentication failed: %v", info)
437 case irc.ERR_SASLTOOLONG:
438 uc.logger.Printf("SASL message too long: %v", info)
439 }
440
441 uc.saslClient = nil
442 uc.saslStarted = false
443
444 uc.SendMessage(&irc.Message{
445 Command: "CAP",
446 Params: []string{"END"},
447 })
448 case irc.RPL_WELCOME:
449 uc.registered = true
450 uc.logger.Printf("connection registered")
451
452 channels, err := uc.srv.db.ListChannels(uc.network.ID)
453 if err != nil {
454 uc.logger.Printf("failed to list channels from database: %v", err)
455 break
456 }
457
458 for _, ch := range channels {
459 params := []string{ch.Name}
460 if ch.Key != "" {
461 params = append(params, ch.Key)
462 }
463 uc.SendMessage(&irc.Message{
464 Command: "JOIN",
465 Params: params,
466 })
467 }
468 case irc.RPL_MYINFO:
469 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
470 return err
471 }
472 case irc.RPL_ISUPPORT:
473 if err := parseMessageParams(msg, nil, nil); err != nil {
474 return err
475 }
476 for _, token := range msg.Params[1 : len(msg.Params)-1] {
477 negate := false
478 parameter := token
479 value := ""
480 if strings.HasPrefix(token, "-") {
481 negate = true
482 token = token[1:]
483 } else {
484 if i := strings.IndexByte(token, '='); i >= 0 {
485 parameter = token[:i]
486 value = token[i+1:]
487 }
488 }
489 if !negate {
490 switch parameter {
491 case "CHANMODES":
492 parts := strings.SplitN(value, ",", 5)
493 if len(parts) < 4 {
494 return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", value)
495 }
496 modes := make(map[byte]channelModeType)
497 for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
498 for j := 0; j < len(parts[i]); j++ {
499 mode := parts[i][j]
500 modes[mode] = mt
501 }
502 }
503 uc.availableChannelModes = modes
504 case "CHANTYPES":
505 uc.availableChannelTypes = value
506 case "PREFIX":
507 if value == "" {
508 uc.availableMemberships = nil
509 } else {
510 if value[0] != '(' {
511 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
512 }
513 sep := strings.IndexByte(value, ')')
514 if sep < 0 || len(value) != sep*2 {
515 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
516 }
517 memberships := make([]membership, len(value)/2-1)
518 for i := range memberships {
519 memberships[i] = membership{
520 Mode: value[i+1],
521 Prefix: value[sep+i+1],
522 }
523 }
524 uc.availableMemberships = memberships
525 }
526 }
527 } else {
528 // TODO: handle ISUPPORT negations
529 }
530 }
531 case "BATCH":
532 var tag string
533 if err := parseMessageParams(msg, &tag); err != nil {
534 return err
535 }
536
537 if strings.HasPrefix(tag, "+") {
538 tag = tag[1:]
539 if _, ok := uc.batches[tag]; ok {
540 return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
541 }
542 var batchType string
543 if err := parseMessageParams(msg, nil, &batchType); err != nil {
544 return err
545 }
546 label := label
547 if label == "" && msgBatch != nil {
548 label = msgBatch.Label
549 }
550 uc.batches[tag] = batch{
551 Type: batchType,
552 Params: msg.Params[2:],
553 Outer: msgBatch,
554 Label: label,
555 }
556 } else if strings.HasPrefix(tag, "-") {
557 tag = tag[1:]
558 if _, ok := uc.batches[tag]; !ok {
559 return fmt.Errorf("unknown BATCH reference tag: %q", tag)
560 }
561 delete(uc.batches, tag)
562 } else {
563 return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
564 }
565 case "NICK":
566 if msg.Prefix == nil {
567 return fmt.Errorf("expected a prefix")
568 }
569
570 var newNick string
571 if err := parseMessageParams(msg, &newNick); err != nil {
572 return err
573 }
574
575 if msg.Prefix.Name == uc.nick {
576 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
577 uc.nick = newNick
578 }
579
580 for _, ch := range uc.channels {
581 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
582 delete(ch.Members, msg.Prefix.Name)
583 ch.Members[newNick] = membership
584 uc.appendLog(ch.Name, msg)
585 }
586 }
587
588 if msg.Prefix.Name != uc.nick {
589 uc.forEachDownstream(func(dc *downstreamConn) {
590 dc.SendMessage(&irc.Message{
591 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
592 Command: "NICK",
593 Params: []string{newNick},
594 })
595 })
596 }
597 case "JOIN":
598 if msg.Prefix == nil {
599 return fmt.Errorf("expected a prefix")
600 }
601
602 var channels string
603 if err := parseMessageParams(msg, &channels); err != nil {
604 return err
605 }
606
607 for _, ch := range strings.Split(channels, ",") {
608 if msg.Prefix.Name == uc.nick {
609 uc.logger.Printf("joined channel %q", ch)
610 uc.channels[ch] = &upstreamChannel{
611 Name: ch,
612 conn: uc,
613 Members: make(map[string]*membership),
614 }
615
616 uc.SendMessage(&irc.Message{
617 Command: "MODE",
618 Params: []string{ch},
619 })
620 } else {
621 ch, err := uc.getChannel(ch)
622 if err != nil {
623 return err
624 }
625 ch.Members[msg.Prefix.Name] = nil
626 }
627
628 uc.appendLog(ch, msg)
629
630 uc.forEachDownstream(func(dc *downstreamConn) {
631 dc.SendMessage(&irc.Message{
632 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
633 Command: "JOIN",
634 Params: []string{dc.marshalChannel(uc, ch)},
635 })
636 })
637 }
638 case "PART":
639 if msg.Prefix == nil {
640 return fmt.Errorf("expected a prefix")
641 }
642
643 var channels string
644 if err := parseMessageParams(msg, &channels); err != nil {
645 return err
646 }
647
648 var reason string
649 if len(msg.Params) > 1 {
650 reason = msg.Params[1]
651 }
652
653 for _, ch := range strings.Split(channels, ",") {
654 if msg.Prefix.Name == uc.nick {
655 uc.logger.Printf("parted channel %q", ch)
656 delete(uc.channels, ch)
657 } else {
658 ch, err := uc.getChannel(ch)
659 if err != nil {
660 return err
661 }
662 delete(ch.Members, msg.Prefix.Name)
663 }
664
665 uc.appendLog(ch, msg)
666
667 uc.forEachDownstream(func(dc *downstreamConn) {
668 params := []string{dc.marshalChannel(uc, ch)}
669 if reason != "" {
670 params = append(params, reason)
671 }
672 dc.SendMessage(&irc.Message{
673 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
674 Command: "PART",
675 Params: params,
676 })
677 })
678 }
679 case "KICK":
680 if msg.Prefix == nil {
681 return fmt.Errorf("expected a prefix")
682 }
683
684 var channel, user string
685 if err := parseMessageParams(msg, &channel, &user); err != nil {
686 return err
687 }
688
689 var reason string
690 if len(msg.Params) > 2 {
691 reason = msg.Params[2]
692 }
693
694 if user == uc.nick {
695 uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
696 delete(uc.channels, channel)
697 } else {
698 ch, err := uc.getChannel(channel)
699 if err != nil {
700 return err
701 }
702 delete(ch.Members, user)
703 }
704
705 uc.appendLog(channel, msg)
706
707 uc.forEachDownstream(func(dc *downstreamConn) {
708 params := []string{dc.marshalChannel(uc, channel), dc.marshalNick(uc, user)}
709 if reason != "" {
710 params = append(params, reason)
711 }
712 dc.SendMessage(&irc.Message{
713 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
714 Command: "KICK",
715 Params: params,
716 })
717 })
718 case "QUIT":
719 if msg.Prefix == nil {
720 return fmt.Errorf("expected a prefix")
721 }
722
723 if msg.Prefix.Name == uc.nick {
724 uc.logger.Printf("quit")
725 }
726
727 for _, ch := range uc.channels {
728 if _, ok := ch.Members[msg.Prefix.Name]; ok {
729 delete(ch.Members, msg.Prefix.Name)
730
731 uc.appendLog(ch.Name, msg)
732 }
733 }
734
735 if msg.Prefix.Name != uc.nick {
736 uc.forEachDownstream(func(dc *downstreamConn) {
737 dc.SendMessage(&irc.Message{
738 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
739 Command: "QUIT",
740 Params: msg.Params,
741 })
742 })
743 }
744 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
745 var name, topic string
746 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
747 return err
748 }
749 ch, err := uc.getChannel(name)
750 if err != nil {
751 return err
752 }
753 if msg.Command == irc.RPL_TOPIC {
754 ch.Topic = topic
755 } else {
756 ch.Topic = ""
757 }
758 case "TOPIC":
759 var name string
760 if err := parseMessageParams(msg, &name); err != nil {
761 return err
762 }
763 ch, err := uc.getChannel(name)
764 if err != nil {
765 return err
766 }
767 if len(msg.Params) > 1 {
768 ch.Topic = msg.Params[1]
769 } else {
770 ch.Topic = ""
771 }
772 uc.forEachDownstream(func(dc *downstreamConn) {
773 params := []string{dc.marshalChannel(uc, name)}
774 if ch.Topic != "" {
775 params = append(params, ch.Topic)
776 }
777 dc.SendMessage(&irc.Message{
778 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
779 Command: "TOPIC",
780 Params: params,
781 })
782 })
783 case "MODE":
784 var name, modeStr string
785 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
786 return err
787 }
788
789 if !uc.isChannel(name) { // user mode change
790 if name != uc.nick {
791 return fmt.Errorf("received MODE message for unknown nick %q", name)
792 }
793 return uc.modes.Apply(modeStr)
794 // TODO: notify downstreams about user mode change?
795 } else { // channel mode change
796 ch, err := uc.getChannel(name)
797 if err != nil {
798 return err
799 }
800
801 if ch.modes != nil {
802 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil {
803 return err
804 }
805 }
806
807 uc.appendLog(ch.Name, msg)
808
809 uc.forEachDownstream(func(dc *downstreamConn) {
810 params := []string{dc.marshalChannel(uc, name), modeStr}
811 params = append(params, msg.Params[2:]...)
812
813 dc.SendMessage(&irc.Message{
814 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
815 Command: "MODE",
816 Params: params,
817 })
818 })
819 }
820 case irc.RPL_UMODEIS:
821 if err := parseMessageParams(msg, nil); err != nil {
822 return err
823 }
824 modeStr := ""
825 if len(msg.Params) > 1 {
826 modeStr = msg.Params[1]
827 }
828
829 uc.modes = ""
830 if err := uc.modes.Apply(modeStr); err != nil {
831 return err
832 }
833 // TODO: send RPL_UMODEIS to downstream connections when applicable
834 case irc.RPL_CHANNELMODEIS:
835 var channel string
836 if err := parseMessageParams(msg, nil, &channel); err != nil {
837 return err
838 }
839 modeStr := ""
840 if len(msg.Params) > 2 {
841 modeStr = msg.Params[2]
842 }
843
844 ch, err := uc.getChannel(channel)
845 if err != nil {
846 return err
847 }
848
849 firstMode := ch.modes == nil
850 ch.modes = make(map[byte]string)
851 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil {
852 return err
853 }
854 if firstMode {
855 modeStr, modeParams := ch.modes.Format()
856
857 uc.forEachDownstream(func(dc *downstreamConn) {
858 params := []string{dc.nick, dc.marshalChannel(uc, channel), modeStr}
859 params = append(params, modeParams...)
860
861 dc.SendMessage(&irc.Message{
862 Prefix: dc.srv.prefix(),
863 Command: irc.RPL_CHANNELMODEIS,
864 Params: params,
865 })
866 })
867 }
868 case rpl_creationtime:
869 var channel, creationTime string
870 if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
871 return err
872 }
873
874 ch, err := uc.getChannel(channel)
875 if err != nil {
876 return err
877 }
878
879 firstCreationTime := ch.creationTime == ""
880 ch.creationTime = creationTime
881 if firstCreationTime {
882 uc.forEachDownstream(func(dc *downstreamConn) {
883 dc.SendMessage(&irc.Message{
884 Prefix: dc.srv.prefix(),
885 Command: rpl_creationtime,
886 Params: []string{dc.nick, channel, creationTime},
887 })
888 })
889 }
890 case rpl_topicwhotime:
891 var name, who, timeStr string
892 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
893 return err
894 }
895 ch, err := uc.getChannel(name)
896 if err != nil {
897 return err
898 }
899 ch.TopicWho = who
900 sec, err := strconv.ParseInt(timeStr, 10, 64)
901 if err != nil {
902 return fmt.Errorf("failed to parse topic time: %v", err)
903 }
904 ch.TopicTime = time.Unix(sec, 0)
905 case irc.RPL_LIST:
906 var channel, clients, topic string
907 if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
908 return err
909 }
910
911 pl := uc.getPendingLIST()
912 if pl == nil {
913 return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
914 }
915
916 uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
917 dc.SendMessage(&irc.Message{
918 Prefix: dc.srv.prefix(),
919 Command: irc.RPL_LIST,
920 Params: []string{dc.nick, dc.marshalChannel(uc, channel), clients, topic},
921 })
922 })
923 case irc.RPL_LISTEND:
924 ok := uc.endPendingLISTs(false)
925 if !ok {
926 return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
927 }
928 case irc.RPL_NAMREPLY:
929 var name, statusStr, members string
930 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
931 return err
932 }
933
934 ch, ok := uc.channels[name]
935 if !ok {
936 // NAMES on a channel we have not joined, forward to downstream
937 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
938 channel := dc.marshalChannel(uc, name)
939 members := splitSpace(members)
940 for i, member := range members {
941 membership, nick := uc.parseMembershipPrefix(member)
942 members[i] = membership.String() + dc.marshalNick(uc, nick)
943 }
944 memberStr := strings.Join(members, " ")
945
946 dc.SendMessage(&irc.Message{
947 Prefix: dc.srv.prefix(),
948 Command: irc.RPL_NAMREPLY,
949 Params: []string{dc.nick, statusStr, channel, memberStr},
950 })
951 })
952 return nil
953 }
954
955 status, err := parseChannelStatus(statusStr)
956 if err != nil {
957 return err
958 }
959 ch.Status = status
960
961 for _, s := range splitSpace(members) {
962 membership, nick := uc.parseMembershipPrefix(s)
963 ch.Members[nick] = membership
964 }
965 case irc.RPL_ENDOFNAMES:
966 var name string
967 if err := parseMessageParams(msg, nil, &name); err != nil {
968 return err
969 }
970
971 ch, ok := uc.channels[name]
972 if !ok {
973 // NAMES on a channel we have not joined, forward to downstream
974 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
975 channel := dc.marshalChannel(uc, name)
976
977 dc.SendMessage(&irc.Message{
978 Prefix: dc.srv.prefix(),
979 Command: irc.RPL_ENDOFNAMES,
980 Params: []string{dc.nick, channel, "End of /NAMES list"},
981 })
982 })
983 return nil
984 }
985
986 if ch.complete {
987 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
988 }
989 ch.complete = true
990
991 uc.forEachDownstream(func(dc *downstreamConn) {
992 forwardChannel(dc, ch)
993 })
994 case irc.RPL_WHOREPLY:
995 var channel, username, host, server, nick, mode, trailing string
996 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
997 return err
998 }
999
1000 parts := strings.SplitN(trailing, " ", 2)
1001 if len(parts) != 2 {
1002 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
1003 }
1004 realname := parts[1]
1005 hops, err := strconv.Atoi(parts[0])
1006 if err != nil {
1007 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
1008 }
1009 hops++
1010
1011 trailing = strconv.Itoa(hops) + " " + realname
1012
1013 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1014 channel := channel
1015 if channel != "*" {
1016 channel = dc.marshalChannel(uc, channel)
1017 }
1018 nick := dc.marshalNick(uc, nick)
1019 dc.SendMessage(&irc.Message{
1020 Prefix: dc.srv.prefix(),
1021 Command: irc.RPL_WHOREPLY,
1022 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
1023 })
1024 })
1025 case irc.RPL_ENDOFWHO:
1026 var name string
1027 if err := parseMessageParams(msg, nil, &name); err != nil {
1028 return err
1029 }
1030
1031 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1032 name := name
1033 if name != "*" {
1034 // TODO: support WHO masks
1035 name = dc.marshalEntity(uc, name)
1036 }
1037 dc.SendMessage(&irc.Message{
1038 Prefix: dc.srv.prefix(),
1039 Command: irc.RPL_ENDOFWHO,
1040 Params: []string{dc.nick, name, "End of /WHO list"},
1041 })
1042 })
1043 case irc.RPL_WHOISUSER:
1044 var nick, username, host, realname string
1045 if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
1046 return err
1047 }
1048
1049 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1050 nick := dc.marshalNick(uc, nick)
1051 dc.SendMessage(&irc.Message{
1052 Prefix: dc.srv.prefix(),
1053 Command: irc.RPL_WHOISUSER,
1054 Params: []string{dc.nick, nick, username, host, "*", realname},
1055 })
1056 })
1057 case irc.RPL_WHOISSERVER:
1058 var nick, server, serverInfo string
1059 if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
1060 return err
1061 }
1062
1063 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1064 nick := dc.marshalNick(uc, nick)
1065 dc.SendMessage(&irc.Message{
1066 Prefix: dc.srv.prefix(),
1067 Command: irc.RPL_WHOISSERVER,
1068 Params: []string{dc.nick, nick, server, serverInfo},
1069 })
1070 })
1071 case irc.RPL_WHOISOPERATOR:
1072 var nick string
1073 if err := parseMessageParams(msg, nil, &nick); err != nil {
1074 return err
1075 }
1076
1077 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1078 nick := dc.marshalNick(uc, nick)
1079 dc.SendMessage(&irc.Message{
1080 Prefix: dc.srv.prefix(),
1081 Command: irc.RPL_WHOISOPERATOR,
1082 Params: []string{dc.nick, nick, "is an IRC operator"},
1083 })
1084 })
1085 case irc.RPL_WHOISIDLE:
1086 var nick string
1087 if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
1088 return err
1089 }
1090
1091 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1092 nick := dc.marshalNick(uc, nick)
1093 params := []string{dc.nick, nick}
1094 params = append(params, msg.Params[2:]...)
1095 dc.SendMessage(&irc.Message{
1096 Prefix: dc.srv.prefix(),
1097 Command: irc.RPL_WHOISIDLE,
1098 Params: params,
1099 })
1100 })
1101 case irc.RPL_WHOISCHANNELS:
1102 var nick, channelList string
1103 if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
1104 return err
1105 }
1106 channels := splitSpace(channelList)
1107
1108 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1109 nick := dc.marshalNick(uc, nick)
1110 channelList := make([]string, len(channels))
1111 for i, channel := range channels {
1112 prefix, channel := uc.parseMembershipPrefix(channel)
1113 channel = dc.marshalChannel(uc, channel)
1114 channelList[i] = prefix.String() + channel
1115 }
1116 channels := strings.Join(channelList, " ")
1117 dc.SendMessage(&irc.Message{
1118 Prefix: dc.srv.prefix(),
1119 Command: irc.RPL_WHOISCHANNELS,
1120 Params: []string{dc.nick, nick, channels},
1121 })
1122 })
1123 case irc.RPL_ENDOFWHOIS:
1124 var nick string
1125 if err := parseMessageParams(msg, nil, &nick); err != nil {
1126 return err
1127 }
1128
1129 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1130 nick := dc.marshalNick(uc, nick)
1131 dc.SendMessage(&irc.Message{
1132 Prefix: dc.srv.prefix(),
1133 Command: irc.RPL_ENDOFWHOIS,
1134 Params: []string{dc.nick, nick, "End of /WHOIS list"},
1135 })
1136 })
1137 case "PRIVMSG":
1138 if msg.Prefix == nil {
1139 return fmt.Errorf("expected a prefix")
1140 }
1141
1142 var nick, text string
1143 if err := parseMessageParams(msg, &nick, &text); err != nil {
1144 return err
1145 }
1146
1147 if msg.Prefix.Name == serviceNick {
1148 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
1149 break
1150 }
1151 if nick == serviceNick {
1152 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
1153 break
1154 }
1155
1156 target := nick
1157 if nick == uc.nick {
1158 target = msg.Prefix.Name
1159 }
1160 uc.appendLog(target, msg)
1161
1162 uc.network.ring.Produce(msg)
1163 case "INVITE":
1164 var nick string
1165 var channel string
1166 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1167 return err
1168 }
1169
1170 uc.forEachDownstream(func(dc *downstreamConn) {
1171 dc.SendMessage(&irc.Message{
1172 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
1173 Command: "INVITE",
1174 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1175 })
1176 })
1177 case irc.RPL_INVITING:
1178 var nick string
1179 var channel string
1180 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1181 return err
1182 }
1183
1184 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1185 dc.SendMessage(&irc.Message{
1186 Prefix: dc.srv.prefix(),
1187 Command: irc.RPL_INVITING,
1188 Params: []string{dc.nick, dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1189 })
1190 })
1191 case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
1192 var command, reason string
1193 if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
1194 return err
1195 }
1196
1197 if command == "LIST" {
1198 ok := uc.endPendingLISTs(false)
1199 if !ok {
1200 return fmt.Errorf("unexpected response for LIST: %q: no matching pending LIST", msg.Command)
1201 }
1202 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1203 dc.SendMessage(&irc.Message{
1204 Prefix: uc.srv.prefix(),
1205 Command: msg.Command,
1206 Params: []string{dc.nick, "LIST", reason},
1207 })
1208 })
1209 }
1210 case "TAGMSG":
1211 // TODO: relay to downstream connections that accept message-tags
1212 case "ACK":
1213 // Ignore
1214 case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
1215 // Ignore
1216 case irc.RPL_YOURHOST, irc.RPL_CREATED:
1217 // Ignore
1218 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
1219 // Ignore
1220 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
1221 // Ignore
1222 case irc.RPL_LISTSTART:
1223 // Ignore
1224 case rpl_localusers, rpl_globalusers:
1225 // Ignore
1226 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
1227 // Ignore
1228 default:
1229 uc.logger.Printf("unhandled message: %v", msg)
1230 }
1231 return nil
1232}
1233
1234func splitSpace(s string) []string {
1235 return strings.FieldsFunc(s, func(r rune) bool {
1236 return r == ' '
1237 })
1238}
1239
1240func (uc *upstreamConn) register() {
1241 uc.nick = uc.network.Nick
1242 uc.username = uc.network.Username
1243 if uc.username == "" {
1244 uc.username = uc.nick
1245 }
1246 uc.realname = uc.network.Realname
1247 if uc.realname == "" {
1248 uc.realname = uc.nick
1249 }
1250
1251 uc.SendMessage(&irc.Message{
1252 Command: "CAP",
1253 Params: []string{"LS", "302"},
1254 })
1255
1256 if uc.network.Pass != "" {
1257 uc.SendMessage(&irc.Message{
1258 Command: "PASS",
1259 Params: []string{uc.network.Pass},
1260 })
1261 }
1262
1263 uc.SendMessage(&irc.Message{
1264 Command: "NICK",
1265 Params: []string{uc.nick},
1266 })
1267 uc.SendMessage(&irc.Message{
1268 Command: "USER",
1269 Params: []string{uc.username, "0", "*", uc.realname},
1270 })
1271}
1272
1273func (uc *upstreamConn) runUntilRegistered() error {
1274 for !uc.registered {
1275 msg, err := uc.ReadMessage()
1276 if err != nil {
1277 return fmt.Errorf("failed to read message: %v", err)
1278 }
1279
1280 if err := uc.handleMessage(msg); err != nil {
1281 return fmt.Errorf("failed to handle message %q: %v", msg, err)
1282 }
1283 }
1284
1285 return nil
1286}
1287
1288func (uc *upstreamConn) requestSASL() bool {
1289 if uc.network.SASL.Mechanism == "" {
1290 return false
1291 }
1292
1293 v, ok := uc.caps["sasl"]
1294 if !ok {
1295 return false
1296 }
1297 if v != "" {
1298 mechanisms := strings.Split(v, ",")
1299 found := false
1300 for _, mech := range mechanisms {
1301 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
1302 found = true
1303 break
1304 }
1305 }
1306 if !found {
1307 return false
1308 }
1309 }
1310
1311 return true
1312}
1313
1314func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
1315 auth := &uc.network.SASL
1316 switch name {
1317 case "sasl":
1318 if !ok {
1319 uc.logger.Printf("server refused to acknowledge the SASL capability")
1320 return nil
1321 }
1322
1323 switch auth.Mechanism {
1324 case "PLAIN":
1325 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
1326 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
1327 default:
1328 return fmt.Errorf("unsupported SASL mechanism %q", name)
1329 }
1330
1331 uc.SendMessage(&irc.Message{
1332 Command: "AUTHENTICATE",
1333 Params: []string{auth.Mechanism},
1334 })
1335 case "message-tags":
1336 uc.tagsSupported = ok
1337 case "labeled-response":
1338 uc.labelsSupported = ok
1339 case "batch", "server-time":
1340 // Nothing to do
1341 default:
1342 uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
1343 }
1344 return nil
1345}
1346
1347func (uc *upstreamConn) readMessages(ch chan<- event) error {
1348 for {
1349 msg, err := uc.ReadMessage()
1350 if err == io.EOF {
1351 break
1352 } else if err != nil {
1353 return fmt.Errorf("failed to read IRC command: %v", err)
1354 }
1355
1356 ch <- eventUpstreamMessage{msg, uc}
1357 }
1358
1359 return nil
1360}
1361
1362func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
1363 if uc.labelsSupported {
1364 if msg.Tags == nil {
1365 msg.Tags = make(map[string]irc.TagValue)
1366 }
1367 msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
1368 uc.nextLabelID++
1369 }
1370 uc.SendMessage(msg)
1371}
1372
1373// TODO: handle moving logs when a network name changes, when support for this is added
1374func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) {
1375 if uc.srv.LogPath == "" {
1376 return
1377 }
1378
1379 ml, ok := uc.messageLoggers[entity]
1380 if !ok {
1381 ml = newMessageLogger(uc, entity)
1382 uc.messageLoggers[entity] = ml
1383 }
1384
1385 if err := ml.Append(msg); err != nil {
1386 uc.logger.Printf("failed to log message: %v", err)
1387 }
1388}
1389
1390func (uc *upstreamConn) updateAway() {
1391 away := true
1392 uc.forEachDownstream(func(*downstreamConn) {
1393 away = false
1394 })
1395 if away == uc.away {
1396 return
1397 }
1398 if away {
1399 uc.SendMessage(&irc.Message{
1400 Command: "AWAY",
1401 Params: []string{"Auto away"},
1402 })
1403 } else {
1404 uc.SendMessage(&irc.Message{
1405 Command: "AWAY",
1406 })
1407 }
1408 uc.away = away
1409}
Note: See TracBrowser for help on using the repository browser.