source: code/trunk/upstream.go@ 210

Last change on this file since 210 was 210, checked in by contact, 5 years ago

Introduce conn for common connection logic

This centralizes the common upstream & downstream bits.

File size: 35.9 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "os"
11 "path/filepath"
12 "strconv"
13 "strings"
14 "time"
15
16 "github.com/emersion/go-sasl"
17 "gopkg.in/irc.v3"
18)
19
20type upstreamChannel struct {
21 Name string
22 conn *upstreamConn
23 Topic string
24 TopicWho string
25 TopicTime time.Time
26 Status channelStatus
27 modes channelModes
28 creationTime string
29 Members map[string]*membership
30 complete bool
31}
32
33type upstreamConn struct {
34 conn
35
36 network *network
37 user *user
38
39 serverName string
40 availableUserModes string
41 availableChannelModes map[byte]channelModeType
42 availableChannelTypes string
43 availableMemberships []membership
44
45 registered bool
46 nick string
47 username string
48 realname string
49 modes userModes
50 channels map[string]*upstreamChannel
51 caps map[string]string
52 batches map[string]batch
53 away bool
54
55 tagsSupported bool
56 labelsSupported bool
57 nextLabelID uint64
58
59 saslClient sasl.Client
60 saslStarted bool
61
62 // set of LIST commands in progress, per downstream
63 pendingLISTDownstreamSet map[uint64]struct{}
64
65 logs map[string]entityLog
66}
67
// entityLog pairs a name with an open file handle.
type entityLog struct {
	name string   // name associated with the file
	file *os.File // open file handle
}
72
73func connectToUpstream(network *network) (*upstreamConn, error) {
74 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
75
76 addr := network.Addr
77 if !strings.ContainsRune(addr, ':') {
78 addr = addr + ":6697"
79 }
80
81 dialer := net.Dialer{Timeout: connectTimeout}
82
83 logger.Printf("connecting to TLS server at address %q", addr)
84 netConn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)
85 if err != nil {
86 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
87 }
88
89 uc := &upstreamConn{
90 conn: *newConn(network.user.srv, netConn, logger),
91 network: network,
92 user: network.user,
93 channels: make(map[string]*upstreamChannel),
94 caps: make(map[string]string),
95 batches: make(map[string]batch),
96 availableChannelTypes: stdChannelTypes,
97 availableChannelModes: stdChannelModes,
98 availableMemberships: stdMemberships,
99 pendingLISTDownstreamSet: make(map[uint64]struct{}),
100 logs: make(map[string]entityLog),
101 }
102
103 return uc, nil
104}
105
106func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
107 uc.user.forEachDownstream(func(dc *downstreamConn) {
108 if dc.network != nil && dc.network != uc.network {
109 return
110 }
111 f(dc)
112 })
113}
114
115func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
116 uc.forEachDownstream(func(dc *downstreamConn) {
117 if id != 0 && id != dc.id {
118 return
119 }
120 f(dc)
121 })
122}
123
124func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
125 ch, ok := uc.channels[name]
126 if !ok {
127 return nil, fmt.Errorf("unknown channel %q", name)
128 }
129 return ch, nil
130}
131
132func (uc *upstreamConn) isChannel(entity string) bool {
133 if i := strings.IndexByte(uc.availableChannelTypes, entity[0]); i >= 0 {
134 return true
135 }
136 return false
137}
138
139func (uc *upstreamConn) getPendingLIST() *pendingLIST {
140 for _, pl := range uc.user.pendingLISTs {
141 if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
142 continue
143 }
144 return &pl
145 }
146 return nil
147}
148
149func (uc *upstreamConn) endPendingLISTs(all bool) (found bool) {
150 found = false
151 for i := 0; i < len(uc.user.pendingLISTs); i++ {
152 pl := uc.user.pendingLISTs[i]
153 if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
154 continue
155 }
156 delete(pl.pendingCommands, uc.network.ID)
157 if len(pl.pendingCommands) == 0 {
158 uc.user.pendingLISTs = append(uc.user.pendingLISTs[:i], uc.user.pendingLISTs[i+1:]...)
159 i--
160 uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
161 dc.SendMessage(&irc.Message{
162 Prefix: dc.srv.prefix(),
163 Command: irc.RPL_LISTEND,
164 Params: []string{dc.nick, "End of /LIST"},
165 })
166 })
167 }
168 found = true
169 if !all {
170 delete(uc.pendingLISTDownstreamSet, pl.downstreamID)
171 uc.user.forEachUpstream(func(uc *upstreamConn) {
172 uc.trySendLIST(pl.downstreamID)
173 })
174 return
175 }
176 }
177 return
178}
179
180func (uc *upstreamConn) trySendLIST(downstreamID uint64) {
181 // must be called with a lock in uc.user.pendingLISTsLock
182
183 if _, ok := uc.pendingLISTDownstreamSet[downstreamID]; ok {
184 // a LIST command is already pending
185 // we will try again when that command is completed
186 return
187 }
188
189 for _, pl := range uc.user.pendingLISTs {
190 if pl.downstreamID != downstreamID {
191 continue
192 }
193 // this is the first pending LIST command list of the downstream
194 listCommand, ok := pl.pendingCommands[uc.network.ID]
195 if !ok {
196 // there is no command for this upstream in these LIST commands
197 // do not send anything
198 continue
199 }
200 // there is a command for this upstream in these LIST commands
201 // send it now
202
203 uc.SendMessageLabeled(downstreamID, listCommand)
204
205 uc.pendingLISTDownstreamSet[downstreamID] = struct{}{}
206 return
207 }
208}
209
210func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) {
211 for _, m := range uc.availableMemberships {
212 if m.Prefix == s[0] {
213 return &m, s[1:]
214 }
215 }
216 return nil, s
217}
218
219func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
220 var label string
221 if l, ok := msg.GetTag("label"); ok {
222 label = l
223 }
224
225 var msgBatch *batch
226 if batchName, ok := msg.GetTag("batch"); ok {
227 b, ok := uc.batches[batchName]
228 if !ok {
229 return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
230 }
231 msgBatch = &b
232 if label == "" {
233 label = msgBatch.Label
234 }
235 }
236
237 var downstreamID uint64 = 0
238 if label != "" {
239 var labelOffset uint64
240 n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
241 if err == nil && n < 2 {
242 err = errors.New("not enough arguments")
243 }
244 if err != nil {
245 return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
246 }
247 }
248
249 switch msg.Command {
250 case "PING":
251 uc.SendMessage(&irc.Message{
252 Command: "PONG",
253 Params: msg.Params,
254 })
255 return nil
256 case "NOTICE":
257 uc.logger.Print(msg)
258
259 if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
260 uc.forEachDownstream(func(dc *downstreamConn) {
261 dc.SendMessage(&irc.Message{
262 Command: "NOTICE",
263 Params: msg.Params,
264 })
265 })
266 } else { // regular user NOTICE
267 var nick, text string
268 if err := parseMessageParams(msg, &nick, &text); err != nil {
269 return err
270 }
271
272 target := nick
273 if nick == uc.nick {
274 target = msg.Prefix.Name
275 }
276 uc.appendLog(target, "<%s> %s", msg.Prefix.Name, text)
277
278 uc.forEachDownstream(func(dc *downstreamConn) {
279 dc.SendMessage(&irc.Message{
280 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
281 Command: "NOTICE",
282 Params: []string{dc.marshalEntity(uc, nick), text},
283 })
284 })
285 }
286 case "CAP":
287 var subCmd string
288 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
289 return err
290 }
291 subCmd = strings.ToUpper(subCmd)
292 subParams := msg.Params[2:]
293 switch subCmd {
294 case "LS":
295 if len(subParams) < 1 {
296 return newNeedMoreParamsError(msg.Command)
297 }
298 caps := strings.Fields(subParams[len(subParams)-1])
299 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
300
301 for _, s := range caps {
302 kv := strings.SplitN(s, "=", 2)
303 k := strings.ToLower(kv[0])
304 var v string
305 if len(kv) == 2 {
306 v = kv[1]
307 }
308 uc.caps[k] = v
309 }
310
311 if more {
312 break // wait to receive all capabilities
313 }
314
315 requestCaps := make([]string, 0, 16)
316 for _, c := range []string{"message-tags", "batch", "labeled-response", "server-time"} {
317 if _, ok := uc.caps[c]; ok {
318 requestCaps = append(requestCaps, c)
319 }
320 }
321
322 if uc.requestSASL() {
323 requestCaps = append(requestCaps, "sasl")
324 }
325
326 if len(requestCaps) > 0 {
327 uc.SendMessage(&irc.Message{
328 Command: "CAP",
329 Params: []string{"REQ", strings.Join(requestCaps, " ")},
330 })
331 }
332
333 if uc.requestSASL() {
334 break // we'll send CAP END after authentication is completed
335 }
336
337 uc.SendMessage(&irc.Message{
338 Command: "CAP",
339 Params: []string{"END"},
340 })
341 case "ACK", "NAK":
342 if len(subParams) < 1 {
343 return newNeedMoreParamsError(msg.Command)
344 }
345 caps := strings.Fields(subParams[0])
346
347 for _, name := range caps {
348 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
349 return err
350 }
351 }
352
353 if uc.saslClient == nil {
354 uc.SendMessage(&irc.Message{
355 Command: "CAP",
356 Params: []string{"END"},
357 })
358 }
359 default:
360 uc.logger.Printf("unhandled message: %v", msg)
361 }
362 case "AUTHENTICATE":
363 if uc.saslClient == nil {
364 return fmt.Errorf("received unexpected AUTHENTICATE message")
365 }
366
367 // TODO: if a challenge is 400 bytes long, buffer it
368 var challengeStr string
369 if err := parseMessageParams(msg, &challengeStr); err != nil {
370 uc.SendMessage(&irc.Message{
371 Command: "AUTHENTICATE",
372 Params: []string{"*"},
373 })
374 return err
375 }
376
377 var challenge []byte
378 if challengeStr != "+" {
379 var err error
380 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
381 if err != nil {
382 uc.SendMessage(&irc.Message{
383 Command: "AUTHENTICATE",
384 Params: []string{"*"},
385 })
386 return err
387 }
388 }
389
390 var resp []byte
391 var err error
392 if !uc.saslStarted {
393 _, resp, err = uc.saslClient.Start()
394 uc.saslStarted = true
395 } else {
396 resp, err = uc.saslClient.Next(challenge)
397 }
398 if err != nil {
399 uc.SendMessage(&irc.Message{
400 Command: "AUTHENTICATE",
401 Params: []string{"*"},
402 })
403 return err
404 }
405
406 // TODO: send response in multiple chunks if >= 400 bytes
407 var respStr = "+"
408 if resp != nil {
409 respStr = base64.StdEncoding.EncodeToString(resp)
410 }
411
412 uc.SendMessage(&irc.Message{
413 Command: "AUTHENTICATE",
414 Params: []string{respStr},
415 })
416 case irc.RPL_LOGGEDIN:
417 var account string
418 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
419 return err
420 }
421 uc.logger.Printf("logged in with account %q", account)
422 case irc.RPL_LOGGEDOUT:
423 uc.logger.Printf("logged out")
424 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
425 var info string
426 if err := parseMessageParams(msg, nil, &info); err != nil {
427 return err
428 }
429 switch msg.Command {
430 case irc.ERR_NICKLOCKED:
431 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
432 case irc.ERR_SASLFAIL:
433 uc.logger.Printf("SASL authentication failed: %v", info)
434 case irc.ERR_SASLTOOLONG:
435 uc.logger.Printf("SASL message too long: %v", info)
436 }
437
438 uc.saslClient = nil
439 uc.saslStarted = false
440
441 uc.SendMessage(&irc.Message{
442 Command: "CAP",
443 Params: []string{"END"},
444 })
445 case irc.RPL_WELCOME:
446 uc.registered = true
447 uc.logger.Printf("connection registered")
448
449 channels, err := uc.srv.db.ListChannels(uc.network.ID)
450 if err != nil {
451 uc.logger.Printf("failed to list channels from database: %v", err)
452 break
453 }
454
455 for _, ch := range channels {
456 params := []string{ch.Name}
457 if ch.Key != "" {
458 params = append(params, ch.Key)
459 }
460 uc.SendMessage(&irc.Message{
461 Command: "JOIN",
462 Params: params,
463 })
464 }
465 case irc.RPL_MYINFO:
466 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
467 return err
468 }
469 case irc.RPL_ISUPPORT:
470 if err := parseMessageParams(msg, nil, nil); err != nil {
471 return err
472 }
473 for _, token := range msg.Params[1 : len(msg.Params)-1] {
474 negate := false
475 parameter := token
476 value := ""
477 if strings.HasPrefix(token, "-") {
478 negate = true
479 token = token[1:]
480 } else {
481 if i := strings.IndexByte(token, '='); i >= 0 {
482 parameter = token[:i]
483 value = token[i+1:]
484 }
485 }
486 if !negate {
487 switch parameter {
488 case "CHANMODES":
489 parts := strings.SplitN(value, ",", 5)
490 if len(parts) < 4 {
491 return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", value)
492 }
493 modes := make(map[byte]channelModeType)
494 for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
495 for j := 0; j < len(parts[i]); j++ {
496 mode := parts[i][j]
497 modes[mode] = mt
498 }
499 }
500 uc.availableChannelModes = modes
501 case "CHANTYPES":
502 uc.availableChannelTypes = value
503 case "PREFIX":
504 if value == "" {
505 uc.availableMemberships = nil
506 } else {
507 if value[0] != '(' {
508 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
509 }
510 sep := strings.IndexByte(value, ')')
511 if sep < 0 || len(value) != sep*2 {
512 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
513 }
514 memberships := make([]membership, len(value)/2-1)
515 for i := range memberships {
516 memberships[i] = membership{
517 Mode: value[i+1],
518 Prefix: value[sep+i+1],
519 }
520 }
521 uc.availableMemberships = memberships
522 }
523 }
524 } else {
525 // TODO: handle ISUPPORT negations
526 }
527 }
528 case "BATCH":
529 var tag string
530 if err := parseMessageParams(msg, &tag); err != nil {
531 return err
532 }
533
534 if strings.HasPrefix(tag, "+") {
535 tag = tag[1:]
536 if _, ok := uc.batches[tag]; ok {
537 return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
538 }
539 var batchType string
540 if err := parseMessageParams(msg, nil, &batchType); err != nil {
541 return err
542 }
543 label := label
544 if label == "" && msgBatch != nil {
545 label = msgBatch.Label
546 }
547 uc.batches[tag] = batch{
548 Type: batchType,
549 Params: msg.Params[2:],
550 Outer: msgBatch,
551 Label: label,
552 }
553 } else if strings.HasPrefix(tag, "-") {
554 tag = tag[1:]
555 if _, ok := uc.batches[tag]; !ok {
556 return fmt.Errorf("unknown BATCH reference tag: %q", tag)
557 }
558 delete(uc.batches, tag)
559 } else {
560 return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
561 }
562 case "NICK":
563 if msg.Prefix == nil {
564 return fmt.Errorf("expected a prefix")
565 }
566
567 var newNick string
568 if err := parseMessageParams(msg, &newNick); err != nil {
569 return err
570 }
571
572 if msg.Prefix.Name == uc.nick {
573 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
574 uc.nick = newNick
575 }
576
577 for _, ch := range uc.channels {
578 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
579 delete(ch.Members, msg.Prefix.Name)
580 ch.Members[newNick] = membership
581 uc.appendLog(ch.Name, "*** %s is now known as %s", msg.Prefix.Name, newNick)
582 }
583 }
584
585 if msg.Prefix.Name != uc.nick {
586 uc.forEachDownstream(func(dc *downstreamConn) {
587 dc.SendMessage(&irc.Message{
588 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
589 Command: "NICK",
590 Params: []string{newNick},
591 })
592 })
593 }
594 case "JOIN":
595 if msg.Prefix == nil {
596 return fmt.Errorf("expected a prefix")
597 }
598
599 var channels string
600 if err := parseMessageParams(msg, &channels); err != nil {
601 return err
602 }
603
604 for _, ch := range strings.Split(channels, ",") {
605 if msg.Prefix.Name == uc.nick {
606 uc.logger.Printf("joined channel %q", ch)
607 uc.channels[ch] = &upstreamChannel{
608 Name: ch,
609 conn: uc,
610 Members: make(map[string]*membership),
611 }
612
613 uc.SendMessage(&irc.Message{
614 Command: "MODE",
615 Params: []string{ch},
616 })
617 } else {
618 ch, err := uc.getChannel(ch)
619 if err != nil {
620 return err
621 }
622 ch.Members[msg.Prefix.Name] = nil
623 }
624
625 uc.appendLog(ch, "*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host)
626
627 uc.forEachDownstream(func(dc *downstreamConn) {
628 dc.SendMessage(&irc.Message{
629 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
630 Command: "JOIN",
631 Params: []string{dc.marshalChannel(uc, ch)},
632 })
633 })
634 }
635 case "PART":
636 if msg.Prefix == nil {
637 return fmt.Errorf("expected a prefix")
638 }
639
640 var channels string
641 if err := parseMessageParams(msg, &channels); err != nil {
642 return err
643 }
644
645 var reason string
646 if len(msg.Params) > 1 {
647 reason = msg.Params[1]
648 }
649
650 for _, ch := range strings.Split(channels, ",") {
651 if msg.Prefix.Name == uc.nick {
652 uc.logger.Printf("parted channel %q", ch)
653 delete(uc.channels, ch)
654 } else {
655 ch, err := uc.getChannel(ch)
656 if err != nil {
657 return err
658 }
659 delete(ch.Members, msg.Prefix.Name)
660 }
661
662 uc.appendLog(ch, "*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
663
664 uc.forEachDownstream(func(dc *downstreamConn) {
665 dc.SendMessage(&irc.Message{
666 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
667 Command: "PART",
668 Params: []string{dc.marshalChannel(uc, ch)},
669 })
670 })
671 }
672 case "KICK":
673 if msg.Prefix == nil {
674 return fmt.Errorf("expected a prefix")
675 }
676
677 var channel, user string
678 if err := parseMessageParams(msg, &channel, &user); err != nil {
679 return err
680 }
681
682 var reason string
683 if len(msg.Params) > 2 {
684 reason = msg.Params[1]
685 }
686
687 if user == uc.nick {
688 uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
689 delete(uc.channels, channel)
690 } else {
691 ch, err := uc.getChannel(channel)
692 if err != nil {
693 return err
694 }
695 delete(ch.Members, user)
696 }
697
698 uc.appendLog(channel, "*** %s was kicked by %s (%s)", user, msg.Prefix.Name, reason)
699
700 uc.forEachDownstream(func(dc *downstreamConn) {
701 params := []string{dc.marshalChannel(uc, channel), dc.marshalNick(uc, user)}
702 if reason != "" {
703 params = append(params, reason)
704 }
705 dc.SendMessage(&irc.Message{
706 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
707 Command: "KICK",
708 Params: params,
709 })
710 })
711 case "QUIT":
712 if msg.Prefix == nil {
713 return fmt.Errorf("expected a prefix")
714 }
715
716 var reason string
717 if len(msg.Params) > 0 {
718 reason = msg.Params[0]
719 }
720
721 if msg.Prefix.Name == uc.nick {
722 uc.logger.Printf("quit")
723 }
724
725 for _, ch := range uc.channels {
726 if _, ok := ch.Members[msg.Prefix.Name]; ok {
727 delete(ch.Members, msg.Prefix.Name)
728
729 uc.appendLog(ch.Name, "*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
730 }
731 }
732
733 if msg.Prefix.Name != uc.nick {
734 uc.forEachDownstream(func(dc *downstreamConn) {
735 dc.SendMessage(&irc.Message{
736 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
737 Command: "QUIT",
738 Params: msg.Params,
739 })
740 })
741 }
742 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
743 var name, topic string
744 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
745 return err
746 }
747 ch, err := uc.getChannel(name)
748 if err != nil {
749 return err
750 }
751 if msg.Command == irc.RPL_TOPIC {
752 ch.Topic = topic
753 } else {
754 ch.Topic = ""
755 }
756 case "TOPIC":
757 var name string
758 if err := parseMessageParams(msg, &name); err != nil {
759 return err
760 }
761 ch, err := uc.getChannel(name)
762 if err != nil {
763 return err
764 }
765 if len(msg.Params) > 1 {
766 ch.Topic = msg.Params[1]
767 } else {
768 ch.Topic = ""
769 }
770 uc.forEachDownstream(func(dc *downstreamConn) {
771 params := []string{dc.marshalChannel(uc, name)}
772 if ch.Topic != "" {
773 params = append(params, ch.Topic)
774 }
775 dc.SendMessage(&irc.Message{
776 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
777 Command: "TOPIC",
778 Params: params,
779 })
780 })
781 case "MODE":
782 var name, modeStr string
783 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
784 return err
785 }
786
787 if !uc.isChannel(name) { // user mode change
788 if name != uc.nick {
789 return fmt.Errorf("received MODE message for unknown nick %q", name)
790 }
791 return uc.modes.Apply(modeStr)
792 // TODO: notify downstreams about user mode change?
793 } else { // channel mode change
794 ch, err := uc.getChannel(name)
795 if err != nil {
796 return err
797 }
798
799 if ch.modes != nil {
800 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil {
801 return err
802 }
803 }
804
805 modeMsg := modeStr
806 for _, v := range msg.Params[2:] {
807 modeMsg += " " + v
808 }
809 uc.appendLog(ch.Name, "*** %s sets mode: %s", msg.Prefix.Name, modeMsg)
810
811 uc.forEachDownstream(func(dc *downstreamConn) {
812 params := []string{dc.marshalChannel(uc, name), modeStr}
813 params = append(params, msg.Params[2:]...)
814
815 dc.SendMessage(&irc.Message{
816 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
817 Command: "MODE",
818 Params: params,
819 })
820 })
821 }
822 case irc.RPL_UMODEIS:
823 if err := parseMessageParams(msg, nil); err != nil {
824 return err
825 }
826 modeStr := ""
827 if len(msg.Params) > 1 {
828 modeStr = msg.Params[1]
829 }
830
831 uc.modes = ""
832 if err := uc.modes.Apply(modeStr); err != nil {
833 return err
834 }
835 // TODO: send RPL_UMODEIS to downstream connections when applicable
836 case irc.RPL_CHANNELMODEIS:
837 var channel string
838 if err := parseMessageParams(msg, nil, &channel); err != nil {
839 return err
840 }
841 modeStr := ""
842 if len(msg.Params) > 2 {
843 modeStr = msg.Params[2]
844 }
845
846 ch, err := uc.getChannel(channel)
847 if err != nil {
848 return err
849 }
850
851 firstMode := ch.modes == nil
852 ch.modes = make(map[byte]string)
853 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil {
854 return err
855 }
856 if firstMode {
857 modeStr, modeParams := ch.modes.Format()
858
859 uc.forEachDownstream(func(dc *downstreamConn) {
860 params := []string{dc.nick, dc.marshalChannel(uc, channel), modeStr}
861 params = append(params, modeParams...)
862
863 dc.SendMessage(&irc.Message{
864 Prefix: dc.srv.prefix(),
865 Command: irc.RPL_CHANNELMODEIS,
866 Params: params,
867 })
868 })
869 }
870 case rpl_creationtime:
871 var channel, creationTime string
872 if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
873 return err
874 }
875
876 ch, err := uc.getChannel(channel)
877 if err != nil {
878 return err
879 }
880
881 firstCreationTime := ch.creationTime == ""
882 ch.creationTime = creationTime
883 if firstCreationTime {
884 uc.forEachDownstream(func(dc *downstreamConn) {
885 dc.SendMessage(&irc.Message{
886 Prefix: dc.srv.prefix(),
887 Command: rpl_creationtime,
888 Params: []string{dc.nick, channel, creationTime},
889 })
890 })
891 }
892 case rpl_topicwhotime:
893 var name, who, timeStr string
894 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
895 return err
896 }
897 ch, err := uc.getChannel(name)
898 if err != nil {
899 return err
900 }
901 ch.TopicWho = who
902 sec, err := strconv.ParseInt(timeStr, 10, 64)
903 if err != nil {
904 return fmt.Errorf("failed to parse topic time: %v", err)
905 }
906 ch.TopicTime = time.Unix(sec, 0)
907 case irc.RPL_LIST:
908 var channel, clients, topic string
909 if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
910 return err
911 }
912
913 pl := uc.getPendingLIST()
914 if pl == nil {
915 return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
916 }
917
918 uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) {
919 dc.SendMessage(&irc.Message{
920 Prefix: dc.srv.prefix(),
921 Command: irc.RPL_LIST,
922 Params: []string{dc.nick, dc.marshalChannel(uc, channel), clients, topic},
923 })
924 })
925 case irc.RPL_LISTEND:
926 ok := uc.endPendingLISTs(false)
927 if !ok {
928 return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
929 }
930 case irc.RPL_NAMREPLY:
931 var name, statusStr, members string
932 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
933 return err
934 }
935
936 ch, ok := uc.channels[name]
937 if !ok {
938 // NAMES on a channel we have not joined, forward to downstream
939 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
940 channel := dc.marshalChannel(uc, name)
941 members := splitSpace(members)
942 for i, member := range members {
943 membership, nick := uc.parseMembershipPrefix(member)
944 members[i] = membership.String() + dc.marshalNick(uc, nick)
945 }
946 memberStr := strings.Join(members, " ")
947
948 dc.SendMessage(&irc.Message{
949 Prefix: dc.srv.prefix(),
950 Command: irc.RPL_NAMREPLY,
951 Params: []string{dc.nick, statusStr, channel, memberStr},
952 })
953 })
954 return nil
955 }
956
957 status, err := parseChannelStatus(statusStr)
958 if err != nil {
959 return err
960 }
961 ch.Status = status
962
963 for _, s := range splitSpace(members) {
964 membership, nick := uc.parseMembershipPrefix(s)
965 ch.Members[nick] = membership
966 }
967 case irc.RPL_ENDOFNAMES:
968 var name string
969 if err := parseMessageParams(msg, nil, &name); err != nil {
970 return err
971 }
972
973 ch, ok := uc.channels[name]
974 if !ok {
975 // NAMES on a channel we have not joined, forward to downstream
976 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
977 channel := dc.marshalChannel(uc, name)
978
979 dc.SendMessage(&irc.Message{
980 Prefix: dc.srv.prefix(),
981 Command: irc.RPL_ENDOFNAMES,
982 Params: []string{dc.nick, channel, "End of /NAMES list"},
983 })
984 })
985 return nil
986 }
987
988 if ch.complete {
989 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
990 }
991 ch.complete = true
992
993 uc.forEachDownstream(func(dc *downstreamConn) {
994 forwardChannel(dc, ch)
995 })
996 case irc.RPL_WHOREPLY:
997 var channel, username, host, server, nick, mode, trailing string
998 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
999 return err
1000 }
1001
1002 parts := strings.SplitN(trailing, " ", 2)
1003 if len(parts) != 2 {
1004 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
1005 }
1006 realname := parts[1]
1007 hops, err := strconv.Atoi(parts[0])
1008 if err != nil {
1009 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
1010 }
1011 hops++
1012
1013 trailing = strconv.Itoa(hops) + " " + realname
1014
1015 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1016 channel := channel
1017 if channel != "*" {
1018 channel = dc.marshalChannel(uc, channel)
1019 }
1020 nick := dc.marshalNick(uc, nick)
1021 dc.SendMessage(&irc.Message{
1022 Prefix: dc.srv.prefix(),
1023 Command: irc.RPL_WHOREPLY,
1024 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
1025 })
1026 })
1027 case irc.RPL_ENDOFWHO:
1028 var name string
1029 if err := parseMessageParams(msg, nil, &name); err != nil {
1030 return err
1031 }
1032
1033 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1034 name := name
1035 if name != "*" {
1036 // TODO: support WHO masks
1037 name = dc.marshalEntity(uc, name)
1038 }
1039 dc.SendMessage(&irc.Message{
1040 Prefix: dc.srv.prefix(),
1041 Command: irc.RPL_ENDOFWHO,
1042 Params: []string{dc.nick, name, "End of /WHO list"},
1043 })
1044 })
1045 case irc.RPL_WHOISUSER:
1046 var nick, username, host, realname string
1047 if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
1048 return err
1049 }
1050
1051 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1052 nick := dc.marshalNick(uc, nick)
1053 dc.SendMessage(&irc.Message{
1054 Prefix: dc.srv.prefix(),
1055 Command: irc.RPL_WHOISUSER,
1056 Params: []string{dc.nick, nick, username, host, "*", realname},
1057 })
1058 })
1059 case irc.RPL_WHOISSERVER:
1060 var nick, server, serverInfo string
1061 if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
1062 return err
1063 }
1064
1065 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1066 nick := dc.marshalNick(uc, nick)
1067 dc.SendMessage(&irc.Message{
1068 Prefix: dc.srv.prefix(),
1069 Command: irc.RPL_WHOISSERVER,
1070 Params: []string{dc.nick, nick, server, serverInfo},
1071 })
1072 })
1073 case irc.RPL_WHOISOPERATOR:
1074 var nick string
1075 if err := parseMessageParams(msg, nil, &nick); err != nil {
1076 return err
1077 }
1078
1079 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1080 nick := dc.marshalNick(uc, nick)
1081 dc.SendMessage(&irc.Message{
1082 Prefix: dc.srv.prefix(),
1083 Command: irc.RPL_WHOISOPERATOR,
1084 Params: []string{dc.nick, nick, "is an IRC operator"},
1085 })
1086 })
1087 case irc.RPL_WHOISIDLE:
1088 var nick string
1089 if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
1090 return err
1091 }
1092
1093 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1094 nick := dc.marshalNick(uc, nick)
1095 params := []string{dc.nick, nick}
1096 params = append(params, msg.Params[2:]...)
1097 dc.SendMessage(&irc.Message{
1098 Prefix: dc.srv.prefix(),
1099 Command: irc.RPL_WHOISIDLE,
1100 Params: params,
1101 })
1102 })
1103 case irc.RPL_WHOISCHANNELS:
1104 var nick, channelList string
1105 if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
1106 return err
1107 }
1108 channels := splitSpace(channelList)
1109
1110 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1111 nick := dc.marshalNick(uc, nick)
1112 channelList := make([]string, len(channels))
1113 for i, channel := range channels {
1114 prefix, channel := uc.parseMembershipPrefix(channel)
1115 channel = dc.marshalChannel(uc, channel)
1116 channelList[i] = prefix.String() + channel
1117 }
1118 channels := strings.Join(channelList, " ")
1119 dc.SendMessage(&irc.Message{
1120 Prefix: dc.srv.prefix(),
1121 Command: irc.RPL_WHOISCHANNELS,
1122 Params: []string{dc.nick, nick, channels},
1123 })
1124 })
1125 case irc.RPL_ENDOFWHOIS:
1126 var nick string
1127 if err := parseMessageParams(msg, nil, &nick); err != nil {
1128 return err
1129 }
1130
1131 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1132 nick := dc.marshalNick(uc, nick)
1133 dc.SendMessage(&irc.Message{
1134 Prefix: dc.srv.prefix(),
1135 Command: irc.RPL_ENDOFWHOIS,
1136 Params: []string{dc.nick, nick, "End of /WHOIS list"},
1137 })
1138 })
1139 case "PRIVMSG":
1140 if msg.Prefix == nil {
1141 return fmt.Errorf("expected a prefix")
1142 }
1143
1144 var nick, text string
1145 if err := parseMessageParams(msg, &nick, &text); err != nil {
1146 return err
1147 }
1148
1149 if msg.Prefix.Name == serviceNick {
1150 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
1151 break
1152 }
1153 if nick == serviceNick {
1154 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
1155 break
1156 }
1157
1158 if _, ok := msg.Tags["time"]; !ok {
1159 msg.Tags["time"] = irc.TagValue(time.Now().Format(serverTimeLayout))
1160 }
1161
1162 target := nick
1163 if nick == uc.nick {
1164 target = msg.Prefix.Name
1165 }
1166 uc.appendLog(target, "<%s> %s", msg.Prefix.Name, text)
1167
1168 uc.network.ring.Produce(msg)
1169 case "INVITE":
1170 var nick string
1171 var channel string
1172 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1173 return err
1174 }
1175
1176 uc.forEachDownstream(func(dc *downstreamConn) {
1177 dc.SendMessage(&irc.Message{
1178 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
1179 Command: "INVITE",
1180 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1181 })
1182 })
1183 case irc.RPL_INVITING:
1184 var nick string
1185 var channel string
1186 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1187 return err
1188 }
1189
1190 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1191 dc.SendMessage(&irc.Message{
1192 Prefix: dc.srv.prefix(),
1193 Command: irc.RPL_INVITING,
1194 Params: []string{dc.nick, dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1195 })
1196 })
1197 case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
1198 var command, reason string
1199 if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
1200 return err
1201 }
1202
1203 if command == "LIST" {
1204 ok := uc.endPendingLISTs(false)
1205 if !ok {
1206 return fmt.Errorf("unexpected response for LIST: %q: no matching pending LIST", msg.Command)
1207 }
1208 uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
1209 dc.SendMessage(&irc.Message{
1210 Prefix: uc.srv.prefix(),
1211 Command: msg.Command,
1212 Params: []string{dc.nick, "LIST", reason},
1213 })
1214 })
1215 }
1216 case "TAGMSG":
1217 // TODO: relay to downstream connections that accept message-tags
1218 case "ACK":
1219 // Ignore
1220 case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
1221 // Ignore
1222 case irc.RPL_YOURHOST, irc.RPL_CREATED:
1223 // Ignore
1224 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
1225 // Ignore
1226 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
1227 // Ignore
1228 case irc.RPL_LISTSTART:
1229 // Ignore
1230 case rpl_localusers, rpl_globalusers:
1231 // Ignore
1232 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
1233 // Ignore
1234 default:
1235 uc.logger.Printf("unhandled message: %v", msg)
1236 }
1237 return nil
1238}
1239
// splitSpace splits s around runs of one or more ASCII space characters
// (' ') and returns the non-empty pieces. Other whitespace, such as tabs,
// is not treated as a separator.
func splitSpace(s string) []string {
	isSpace := func(c rune) bool { return c == ' ' }
	return strings.FieldsFunc(s, isSpace)
}
1245
1246func (uc *upstreamConn) register() {
1247 uc.nick = uc.network.Nick
1248 uc.username = uc.network.Username
1249 if uc.username == "" {
1250 uc.username = uc.nick
1251 }
1252 uc.realname = uc.network.Realname
1253 if uc.realname == "" {
1254 uc.realname = uc.nick
1255 }
1256
1257 uc.SendMessage(&irc.Message{
1258 Command: "CAP",
1259 Params: []string{"LS", "302"},
1260 })
1261
1262 if uc.network.Pass != "" {
1263 uc.SendMessage(&irc.Message{
1264 Command: "PASS",
1265 Params: []string{uc.network.Pass},
1266 })
1267 }
1268
1269 uc.SendMessage(&irc.Message{
1270 Command: "NICK",
1271 Params: []string{uc.nick},
1272 })
1273 uc.SendMessage(&irc.Message{
1274 Command: "USER",
1275 Params: []string{uc.username, "0", "*", uc.realname},
1276 })
1277}
1278
1279func (uc *upstreamConn) runUntilRegistered() error {
1280 for !uc.registered {
1281 msg, err := uc.irc.ReadMessage()
1282 if err != nil {
1283 return fmt.Errorf("failed to read message: %v", err)
1284 }
1285
1286 if uc.srv.Debug {
1287 uc.logger.Printf("received: %v", msg)
1288 }
1289
1290 if err := uc.handleMessage(msg); err != nil {
1291 return fmt.Errorf("failed to handle message %q: %v", msg, err)
1292 }
1293 }
1294
1295 return nil
1296}
1297
1298func (uc *upstreamConn) requestSASL() bool {
1299 if uc.network.SASL.Mechanism == "" {
1300 return false
1301 }
1302
1303 v, ok := uc.caps["sasl"]
1304 if !ok {
1305 return false
1306 }
1307 if v != "" {
1308 mechanisms := strings.Split(v, ",")
1309 found := false
1310 for _, mech := range mechanisms {
1311 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
1312 found = true
1313 break
1314 }
1315 }
1316 if !found {
1317 return false
1318 }
1319 }
1320
1321 return true
1322}
1323
1324func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
1325 auth := &uc.network.SASL
1326 switch name {
1327 case "sasl":
1328 if !ok {
1329 uc.logger.Printf("server refused to acknowledge the SASL capability")
1330 return nil
1331 }
1332
1333 switch auth.Mechanism {
1334 case "PLAIN":
1335 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
1336 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
1337 default:
1338 return fmt.Errorf("unsupported SASL mechanism %q", name)
1339 }
1340
1341 uc.SendMessage(&irc.Message{
1342 Command: "AUTHENTICATE",
1343 Params: []string{auth.Mechanism},
1344 })
1345 case "message-tags":
1346 uc.tagsSupported = ok
1347 case "labeled-response":
1348 uc.labelsSupported = ok
1349 case "batch", "server-time":
1350 // Nothing to do
1351 default:
1352 uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
1353 }
1354 return nil
1355}
1356
1357func (uc *upstreamConn) readMessages(ch chan<- event) error {
1358 for {
1359 msg, err := uc.ReadMessage()
1360 if err == io.EOF {
1361 break
1362 } else if err != nil {
1363 return fmt.Errorf("failed to read IRC command: %v", err)
1364 }
1365
1366 ch <- eventUpstreamMessage{msg, uc}
1367 }
1368
1369 return nil
1370}
1371
1372func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
1373 if uc.labelsSupported {
1374 if msg.Tags == nil {
1375 msg.Tags = make(map[string]irc.TagValue)
1376 }
1377 msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
1378 uc.nextLabelID++
1379 }
1380 uc.SendMessage(msg)
1381}
1382
1383// TODO: handle moving logs when a network name changes, when support for this is added
1384func (uc *upstreamConn) appendLog(entity string, format string, a ...interface{}) {
1385 if uc.srv.LogPath == "" {
1386 return
1387 }
1388 // TODO: enforce maximum open file handles (LRU cache of file handles)
1389 // TODO: handle non-monotonic clock behaviour
1390 now := time.Now()
1391 year, month, day := now.Date()
1392 name := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
1393 log, ok := uc.logs[entity]
1394 if !ok || log.name != name {
1395 if ok {
1396 log.file.Close()
1397 delete(uc.logs, entity)
1398 }
1399 // TODO: handle/forbid network/entity names with illegal path characters
1400 dir := filepath.Join(uc.srv.LogPath, uc.user.Username, uc.network.Name, entity)
1401 if err := os.MkdirAll(dir, 0700); err != nil {
1402 uc.logger.Printf("failed to log message: could not create logs directory %q: %v", dir, err)
1403 return
1404 }
1405 path := filepath.Join(dir, name)
1406 f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600)
1407 if err != nil {
1408 uc.logger.Printf("failed to log message: could not open or create log file %q: %v", path, err)
1409 return
1410 }
1411 log = entityLog{
1412 name: name,
1413 file: f,
1414 }
1415 uc.logs[entity] = log
1416 }
1417
1418 format = "[%02d:%02d:%02d] " + format + "\n"
1419 args := []interface{}{now.Hour(), now.Minute(), now.Second()}
1420 args = append(args, a...)
1421
1422 if _, err := fmt.Fprintf(log.file, format, args...); err != nil {
1423 uc.logger.Printf("failed to log message to %q: %v", log.name, err)
1424 }
1425}
1426
1427func (uc *upstreamConn) updateAway() {
1428 away := true
1429 uc.forEachDownstream(func(*downstreamConn) {
1430 away = false
1431 })
1432 if away == uc.away {
1433 return
1434 }
1435 if away {
1436 uc.SendMessage(&irc.Message{
1437 Command: "AWAY",
1438 Params: []string{"Auto away"},
1439 })
1440 } else {
1441 uc.SendMessage(&irc.Message{
1442 Command: "AWAY",
1443 })
1444 }
1445 uc.away = away
1446}
Note: See TracBrowser for help on using the repository browser.