source: code/trunk/upstream.go@ 159

Last change on this file since 159 was 159, checked in by delthas, 5 years ago

Add KICK support

Downstream and upstream message handling are slightly different because
downstreams can send KICK messages with multiple channels or users,
while upstreams can only send KICK messages with one channel and one
user (according to the RFC).

File size: 28.2 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "strconv"
11 "strings"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
18type upstreamChannel struct {
19 Name string
20 conn *upstreamConn
21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
25 modes channelModes
26 Members map[string]*membership
27 complete bool
28}
29
30type upstreamConn struct {
31 network *network
32 logger Logger
33 net net.Conn
34 irc *irc.Conn
35 srv *Server
36 user *user
37 outgoing chan<- *irc.Message
38
39 serverName string
40 availableUserModes string
41 availableChannelModes map[byte]channelModeType
42 availableChannelTypes string
43 availableMemberships []membership
44
45 registered bool
46 nick string
47 username string
48 realname string
49 closed bool
50 modes userModes
51 channels map[string]*upstreamChannel
52 caps map[string]string
53 batches map[string]batch
54
55 tagsSupported bool
56 labelsSupported bool
57 nextLabelId uint64
58
59 saslClient sasl.Client
60 saslStarted bool
61}
62
63func connectToUpstream(network *network) (*upstreamConn, error) {
64 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
65
66 addr := network.Addr
67 if !strings.ContainsRune(addr, ':') {
68 addr = addr + ":6697"
69 }
70
71 logger.Printf("connecting to TLS server at address %q", addr)
72 netConn, err := tls.Dial("tcp", addr, nil)
73 if err != nil {
74 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
75 }
76
77 setKeepAlive(netConn)
78
79 outgoing := make(chan *irc.Message, 64)
80 uc := &upstreamConn{
81 network: network,
82 logger: logger,
83 net: netConn,
84 irc: irc.NewConn(netConn),
85 srv: network.user.srv,
86 user: network.user,
87 outgoing: outgoing,
88 channels: make(map[string]*upstreamChannel),
89 caps: make(map[string]string),
90 batches: make(map[string]batch),
91 availableChannelTypes: stdChannelTypes,
92 availableChannelModes: stdChannelModes,
93 availableMemberships: stdMemberships,
94 }
95
96 go func() {
97 for msg := range outgoing {
98 if uc.srv.Debug {
99 uc.logger.Printf("sent: %v", msg)
100 }
101 if err := uc.irc.WriteMessage(msg); err != nil {
102 uc.logger.Printf("failed to write message: %v", err)
103 }
104 }
105 if err := uc.net.Close(); err != nil {
106 uc.logger.Printf("failed to close connection: %v", err)
107 } else {
108 uc.logger.Printf("connection closed")
109 }
110 }()
111
112 return uc, nil
113}
114
115func (uc *upstreamConn) Close() error {
116 if uc.closed {
117 return fmt.Errorf("upstream connection already closed")
118 }
119 close(uc.outgoing)
120 uc.closed = true
121 return nil
122}
123
124func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
125 uc.user.forEachDownstream(func(dc *downstreamConn) {
126 if dc.network != nil && dc.network != uc.network {
127 return
128 }
129 f(dc)
130 })
131}
132
133func (uc *upstreamConn) forEachDownstreamById(id uint64, f func(*downstreamConn)) {
134 uc.forEachDownstream(func(dc *downstreamConn) {
135 if id != 0 && id != dc.id {
136 return
137 }
138 f(dc)
139 })
140}
141
142func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
143 ch, ok := uc.channels[name]
144 if !ok {
145 return nil, fmt.Errorf("unknown channel %q", name)
146 }
147 return ch, nil
148}
149
150func (uc *upstreamConn) isChannel(entity string) bool {
151 if i := strings.IndexByte(uc.availableChannelTypes, entity[0]); i >= 0 {
152 return true
153 }
154 return false
155}
156
157func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) {
158 for _, m := range uc.availableMemberships {
159 if m.Prefix == s[0] {
160 return &m, s[1:]
161 }
162 }
163 return nil, s
164}
165
166func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
167 var label string
168 if l, ok := msg.GetTag("label"); ok {
169 label = l
170 }
171
172 var msgBatch *batch
173 if batchName, ok := msg.GetTag("batch"); ok {
174 b, ok := uc.batches[batchName]
175 if !ok {
176 return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
177 }
178 msgBatch = &b
179 if label == "" {
180 label = msgBatch.Label
181 }
182 }
183
184 var downstreamId uint64 = 0
185 if label != "" {
186 var labelOffset uint64
187 n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamId, &labelOffset)
188 if err == nil && n < 2 {
189 err = errors.New("not enough arguments")
190 }
191 if err != nil {
192 return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
193 }
194 }
195
196 switch msg.Command {
197 case "PING":
198 uc.SendMessage(&irc.Message{
199 Command: "PONG",
200 Params: msg.Params,
201 })
202 return nil
203 case "NOTICE":
204 uc.logger.Print(msg)
205
206 uc.forEachDownstream(func(dc *downstreamConn) {
207 dc.SendMessage(msg)
208 })
209 case "CAP":
210 var subCmd string
211 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
212 return err
213 }
214 subCmd = strings.ToUpper(subCmd)
215 subParams := msg.Params[2:]
216 switch subCmd {
217 case "LS":
218 if len(subParams) < 1 {
219 return newNeedMoreParamsError(msg.Command)
220 }
221 caps := strings.Fields(subParams[len(subParams)-1])
222 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
223
224 for _, s := range caps {
225 kv := strings.SplitN(s, "=", 2)
226 k := strings.ToLower(kv[0])
227 var v string
228 if len(kv) == 2 {
229 v = kv[1]
230 }
231 uc.caps[k] = v
232 }
233
234 if more {
235 break // wait to receive all capabilities
236 }
237
238 requestCaps := make([]string, 0, 16)
239 for _, c := range []string{"message-tags", "batch", "labeled-response"} {
240 if _, ok := uc.caps[c]; ok {
241 requestCaps = append(requestCaps, c)
242 }
243 }
244
245 if uc.requestSASL() {
246 requestCaps = append(requestCaps, "sasl")
247 }
248
249 if len(requestCaps) > 0 {
250 uc.SendMessage(&irc.Message{
251 Command: "CAP",
252 Params: []string{"REQ", strings.Join(requestCaps, " ")},
253 })
254 }
255
256 if uc.requestSASL() {
257 break // we'll send CAP END after authentication is completed
258 }
259
260 uc.SendMessage(&irc.Message{
261 Command: "CAP",
262 Params: []string{"END"},
263 })
264 case "ACK", "NAK":
265 if len(subParams) < 1 {
266 return newNeedMoreParamsError(msg.Command)
267 }
268 caps := strings.Fields(subParams[0])
269
270 for _, name := range caps {
271 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
272 return err
273 }
274 }
275
276 if uc.saslClient == nil {
277 uc.SendMessage(&irc.Message{
278 Command: "CAP",
279 Params: []string{"END"},
280 })
281 }
282 default:
283 uc.logger.Printf("unhandled message: %v", msg)
284 }
285 case "AUTHENTICATE":
286 if uc.saslClient == nil {
287 return fmt.Errorf("received unexpected AUTHENTICATE message")
288 }
289
290 // TODO: if a challenge is 400 bytes long, buffer it
291 var challengeStr string
292 if err := parseMessageParams(msg, &challengeStr); err != nil {
293 uc.SendMessage(&irc.Message{
294 Command: "AUTHENTICATE",
295 Params: []string{"*"},
296 })
297 return err
298 }
299
300 var challenge []byte
301 if challengeStr != "+" {
302 var err error
303 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
304 if err != nil {
305 uc.SendMessage(&irc.Message{
306 Command: "AUTHENTICATE",
307 Params: []string{"*"},
308 })
309 return err
310 }
311 }
312
313 var resp []byte
314 var err error
315 if !uc.saslStarted {
316 _, resp, err = uc.saslClient.Start()
317 uc.saslStarted = true
318 } else {
319 resp, err = uc.saslClient.Next(challenge)
320 }
321 if err != nil {
322 uc.SendMessage(&irc.Message{
323 Command: "AUTHENTICATE",
324 Params: []string{"*"},
325 })
326 return err
327 }
328
329 // TODO: send response in multiple chunks if >= 400 bytes
330 var respStr = "+"
331 if resp != nil {
332 respStr = base64.StdEncoding.EncodeToString(resp)
333 }
334
335 uc.SendMessage(&irc.Message{
336 Command: "AUTHENTICATE",
337 Params: []string{respStr},
338 })
339 case irc.RPL_LOGGEDIN:
340 var account string
341 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
342 return err
343 }
344 uc.logger.Printf("logged in with account %q", account)
345 case irc.RPL_LOGGEDOUT:
346 uc.logger.Printf("logged out")
347 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
348 var info string
349 if err := parseMessageParams(msg, nil, &info); err != nil {
350 return err
351 }
352 switch msg.Command {
353 case irc.ERR_NICKLOCKED:
354 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
355 case irc.ERR_SASLFAIL:
356 uc.logger.Printf("SASL authentication failed: %v", info)
357 case irc.ERR_SASLTOOLONG:
358 uc.logger.Printf("SASL message too long: %v", info)
359 }
360
361 uc.saslClient = nil
362 uc.saslStarted = false
363
364 uc.SendMessage(&irc.Message{
365 Command: "CAP",
366 Params: []string{"END"},
367 })
368 case irc.RPL_WELCOME:
369 uc.registered = true
370 uc.logger.Printf("connection registered")
371
372 channels, err := uc.srv.db.ListChannels(uc.network.ID)
373 if err != nil {
374 uc.logger.Printf("failed to list channels from database: %v", err)
375 break
376 }
377
378 for _, ch := range channels {
379 params := []string{ch.Name}
380 if ch.Key != "" {
381 params = append(params, ch.Key)
382 }
383 uc.SendMessage(&irc.Message{
384 Command: "JOIN",
385 Params: params,
386 })
387 }
388 case irc.RPL_MYINFO:
389 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
390 return err
391 }
392 case irc.RPL_ISUPPORT:
393 if err := parseMessageParams(msg, nil, nil); err != nil {
394 return err
395 }
396 for _, token := range msg.Params[1 : len(msg.Params)-1] {
397 negate := false
398 parameter := token
399 value := ""
400 if strings.HasPrefix(token, "-") {
401 negate = true
402 token = token[1:]
403 } else {
404 if i := strings.IndexByte(token, '='); i >= 0 {
405 parameter = token[:i]
406 value = token[i+1:]
407 }
408 }
409 if !negate {
410 switch parameter {
411 case "CHANMODES":
412 parts := strings.SplitN(value, ",", 5)
413 if len(parts) < 4 {
414 return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", value)
415 }
416 modes := make(map[byte]channelModeType)
417 for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
418 for j := 0; j < len(parts[i]); j++ {
419 mode := parts[i][j]
420 modes[mode] = mt
421 }
422 }
423 uc.availableChannelModes = modes
424 case "CHANTYPES":
425 uc.availableChannelTypes = value
426 case "PREFIX":
427 if value == "" {
428 uc.availableMemberships = nil
429 } else {
430 if value[0] != '(' {
431 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
432 }
433 sep := strings.IndexByte(value, ')')
434 if sep < 0 || len(value) != sep*2 {
435 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
436 }
437 memberships := make([]membership, len(value)/2-1)
438 for i := range memberships {
439 memberships[i] = membership{
440 Mode: value[i+1],
441 Prefix: value[sep+i+1],
442 }
443 }
444 uc.availableMemberships = memberships
445 }
446 }
447 } else {
448 // TODO: handle ISUPPORT negations
449 }
450 }
451 case "BATCH":
452 var tag string
453 if err := parseMessageParams(msg, &tag); err != nil {
454 return err
455 }
456
457 if strings.HasPrefix(tag, "+") {
458 tag = tag[1:]
459 if _, ok := uc.batches[tag]; ok {
460 return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
461 }
462 var batchType string
463 if err := parseMessageParams(msg, nil, &batchType); err != nil {
464 return err
465 }
466 label := label
467 if label == "" && msgBatch != nil {
468 label = msgBatch.Label
469 }
470 uc.batches[tag] = batch{
471 Type: batchType,
472 Params: msg.Params[2:],
473 Outer: msgBatch,
474 Label: label,
475 }
476 } else if strings.HasPrefix(tag, "-") {
477 tag = tag[1:]
478 if _, ok := uc.batches[tag]; !ok {
479 return fmt.Errorf("unknown BATCH reference tag: %q", tag)
480 }
481 delete(uc.batches, tag)
482 } else {
483 return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
484 }
485 case "NICK":
486 if msg.Prefix == nil {
487 return fmt.Errorf("expected a prefix")
488 }
489
490 var newNick string
491 if err := parseMessageParams(msg, &newNick); err != nil {
492 return err
493 }
494
495 if msg.Prefix.Name == uc.nick {
496 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
497 uc.nick = newNick
498 }
499
500 for _, ch := range uc.channels {
501 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
502 delete(ch.Members, msg.Prefix.Name)
503 ch.Members[newNick] = membership
504 }
505 }
506
507 if msg.Prefix.Name != uc.nick {
508 uc.forEachDownstream(func(dc *downstreamConn) {
509 dc.SendMessage(&irc.Message{
510 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
511 Command: "NICK",
512 Params: []string{newNick},
513 })
514 })
515 }
516 case "JOIN":
517 if msg.Prefix == nil {
518 return fmt.Errorf("expected a prefix")
519 }
520
521 var channels string
522 if err := parseMessageParams(msg, &channels); err != nil {
523 return err
524 }
525
526 for _, ch := range strings.Split(channels, ",") {
527 if msg.Prefix.Name == uc.nick {
528 uc.logger.Printf("joined channel %q", ch)
529 uc.channels[ch] = &upstreamChannel{
530 Name: ch,
531 conn: uc,
532 Members: make(map[string]*membership),
533 }
534
535 uc.SendMessage(&irc.Message{
536 Command: "MODE",
537 Params: []string{ch},
538 })
539 } else {
540 ch, err := uc.getChannel(ch)
541 if err != nil {
542 return err
543 }
544 ch.Members[msg.Prefix.Name] = nil
545 }
546
547 uc.forEachDownstream(func(dc *downstreamConn) {
548 dc.SendMessage(&irc.Message{
549 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
550 Command: "JOIN",
551 Params: []string{dc.marshalChannel(uc, ch)},
552 })
553 })
554 }
555 case "PART":
556 if msg.Prefix == nil {
557 return fmt.Errorf("expected a prefix")
558 }
559
560 var channels string
561 if err := parseMessageParams(msg, &channels); err != nil {
562 return err
563 }
564
565 for _, ch := range strings.Split(channels, ",") {
566 if msg.Prefix.Name == uc.nick {
567 uc.logger.Printf("parted channel %q", ch)
568 delete(uc.channels, ch)
569 } else {
570 ch, err := uc.getChannel(ch)
571 if err != nil {
572 return err
573 }
574 delete(ch.Members, msg.Prefix.Name)
575 }
576
577 uc.forEachDownstream(func(dc *downstreamConn) {
578 dc.SendMessage(&irc.Message{
579 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
580 Command: "PART",
581 Params: []string{dc.marshalChannel(uc, ch)},
582 })
583 })
584 }
585 case "KICK":
586 if msg.Prefix == nil {
587 return fmt.Errorf("expected a prefix")
588 }
589
590 var channel, user string
591 if err := parseMessageParams(msg, &channel, &user); err != nil {
592 return err
593 }
594
595 var reason string
596 if len(msg.Params) > 2 {
597 reason = msg.Params[1]
598 }
599
600 if user == uc.nick {
601 uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
602 delete(uc.channels, channel)
603 } else {
604 ch, err := uc.getChannel(channel)
605 if err != nil {
606 return err
607 }
608 delete(ch.Members, user)
609 }
610
611 uc.forEachDownstream(func(dc *downstreamConn) {
612 params := []string{dc.marshalChannel(uc, channel), dc.marshalNick(uc, user)}
613 if reason != "" {
614 params = append(params, reason)
615 }
616 dc.SendMessage(&irc.Message{
617 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
618 Command: "KICK",
619 Params: params,
620 })
621 })
622 case "QUIT":
623 if msg.Prefix == nil {
624 return fmt.Errorf("expected a prefix")
625 }
626
627 if msg.Prefix.Name == uc.nick {
628 uc.logger.Printf("quit")
629 }
630
631 for _, ch := range uc.channels {
632 delete(ch.Members, msg.Prefix.Name)
633 }
634
635 if msg.Prefix.Name != uc.nick {
636 uc.forEachDownstream(func(dc *downstreamConn) {
637 dc.SendMessage(&irc.Message{
638 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
639 Command: "QUIT",
640 Params: msg.Params,
641 })
642 })
643 }
644 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
645 var name, topic string
646 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
647 return err
648 }
649 ch, err := uc.getChannel(name)
650 if err != nil {
651 return err
652 }
653 if msg.Command == irc.RPL_TOPIC {
654 ch.Topic = topic
655 } else {
656 ch.Topic = ""
657 }
658 case "TOPIC":
659 var name string
660 if err := parseMessageParams(msg, &name); err != nil {
661 return err
662 }
663 ch, err := uc.getChannel(name)
664 if err != nil {
665 return err
666 }
667 if len(msg.Params) > 1 {
668 ch.Topic = msg.Params[1]
669 } else {
670 ch.Topic = ""
671 }
672 uc.forEachDownstream(func(dc *downstreamConn) {
673 params := []string{dc.marshalChannel(uc, name)}
674 if ch.Topic != "" {
675 params = append(params, ch.Topic)
676 }
677 dc.SendMessage(&irc.Message{
678 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
679 Command: "TOPIC",
680 Params: params,
681 })
682 })
683 case "MODE":
684 var name, modeStr string
685 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
686 return err
687 }
688
689 if !uc.isChannel(name) { // user mode change
690 if name != uc.nick {
691 return fmt.Errorf("received MODE message for unknown nick %q", name)
692 }
693 return uc.modes.Apply(modeStr)
694 // TODO: notify downstreams about user mode change?
695 } else { // channel mode change
696 ch, err := uc.getChannel(name)
697 if err != nil {
698 return err
699 }
700
701 if ch.modes != nil {
702 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil {
703 return err
704 }
705 }
706
707 uc.forEachDownstream(func(dc *downstreamConn) {
708 params := []string{dc.marshalChannel(uc, name), modeStr}
709 params = append(params, msg.Params[2:]...)
710
711 dc.SendMessage(&irc.Message{
712 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
713 Command: "MODE",
714 Params: params,
715 })
716 })
717 }
718 case irc.RPL_UMODEIS:
719 if err := parseMessageParams(msg, nil); err != nil {
720 return err
721 }
722 modeStr := ""
723 if len(msg.Params) > 1 {
724 modeStr = msg.Params[1]
725 }
726
727 uc.modes = ""
728 if err := uc.modes.Apply(modeStr); err != nil {
729 return err
730 }
731 // TODO: send RPL_UMODEIS to downstream connections when applicable
732 case irc.RPL_CHANNELMODEIS:
733 var channel string
734 if err := parseMessageParams(msg, nil, &channel); err != nil {
735 return err
736 }
737 modeStr := ""
738 if len(msg.Params) > 2 {
739 modeStr = msg.Params[2]
740 }
741
742 ch, err := uc.getChannel(channel)
743 if err != nil {
744 return err
745 }
746
747 firstMode := ch.modes == nil
748 ch.modes = make(map[byte]string)
749 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil {
750 return err
751 }
752 if firstMode {
753 modeStr, modeParams := ch.modes.Format()
754
755 uc.forEachDownstream(func(dc *downstreamConn) {
756 params := []string{dc.nick, dc.marshalChannel(uc, channel), modeStr}
757 params = append(params, modeParams...)
758
759 dc.SendMessage(&irc.Message{
760 Prefix: dc.srv.prefix(),
761 Command: irc.RPL_CHANNELMODEIS,
762 Params: params,
763 })
764 })
765 }
766 case rpl_topicwhotime:
767 var name, who, timeStr string
768 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
769 return err
770 }
771 ch, err := uc.getChannel(name)
772 if err != nil {
773 return err
774 }
775 ch.TopicWho = who
776 sec, err := strconv.ParseInt(timeStr, 10, 64)
777 if err != nil {
778 return fmt.Errorf("failed to parse topic time: %v", err)
779 }
780 ch.TopicTime = time.Unix(sec, 0)
781 case irc.RPL_NAMREPLY:
782 var name, statusStr, members string
783 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
784 return err
785 }
786 members = strings.TrimRight(members, " ")
787
788 ch, ok := uc.channels[name]
789 if !ok {
790 // NAMES on a channel we have not joined, forward to downstream
791 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
792 channel := dc.marshalChannel(uc, name)
793 members := strings.Split(members, " ")
794 for i, member := range members {
795 membership, nick := uc.parseMembershipPrefix(member)
796 members[i] = membership.String() + dc.marshalNick(uc, nick)
797 }
798 memberStr := strings.Join(members, " ")
799
800 dc.SendMessage(&irc.Message{
801 Prefix: dc.srv.prefix(),
802 Command: irc.RPL_NAMREPLY,
803 Params: []string{dc.nick, statusStr, channel, memberStr},
804 })
805 })
806 return nil
807 }
808
809 status, err := parseChannelStatus(statusStr)
810 if err != nil {
811 return err
812 }
813 ch.Status = status
814
815 for _, s := range strings.Split(members, " ") {
816 membership, nick := uc.parseMembershipPrefix(s)
817 ch.Members[nick] = membership
818 }
819 case irc.RPL_ENDOFNAMES:
820 var name string
821 if err := parseMessageParams(msg, nil, &name); err != nil {
822 return err
823 }
824
825 ch, ok := uc.channels[name]
826 if !ok {
827 // NAMES on a channel we have not joined, forward to downstream
828 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
829 channel := dc.marshalChannel(uc, name)
830
831 dc.SendMessage(&irc.Message{
832 Prefix: dc.srv.prefix(),
833 Command: irc.RPL_ENDOFNAMES,
834 Params: []string{dc.nick, channel, "End of /NAMES list"},
835 })
836 })
837 return nil
838 }
839
840 if ch.complete {
841 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
842 }
843 ch.complete = true
844
845 uc.forEachDownstream(func(dc *downstreamConn) {
846 forwardChannel(dc, ch)
847 })
848 case irc.RPL_WHOREPLY:
849 var channel, username, host, server, nick, mode, trailing string
850 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
851 return err
852 }
853
854 parts := strings.SplitN(trailing, " ", 2)
855 if len(parts) != 2 {
856 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
857 }
858 realname := parts[1]
859 hops, err := strconv.Atoi(parts[0])
860 if err != nil {
861 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
862 }
863 hops++
864
865 trailing = strconv.Itoa(hops) + " " + realname
866
867 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
868 channel := channel
869 if channel != "*" {
870 channel = dc.marshalChannel(uc, channel)
871 }
872 nick := dc.marshalNick(uc, nick)
873 dc.SendMessage(&irc.Message{
874 Prefix: dc.srv.prefix(),
875 Command: irc.RPL_WHOREPLY,
876 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
877 })
878 })
879 case irc.RPL_ENDOFWHO:
880 var name string
881 if err := parseMessageParams(msg, nil, &name); err != nil {
882 return err
883 }
884
885 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
886 name := name
887 if name != "*" {
888 // TODO: support WHO masks
889 name = dc.marshalEntity(uc, name)
890 }
891 dc.SendMessage(&irc.Message{
892 Prefix: dc.srv.prefix(),
893 Command: irc.RPL_ENDOFWHO,
894 Params: []string{dc.nick, name, "End of /WHO list"},
895 })
896 })
897 case irc.RPL_WHOISUSER:
898 var nick, username, host, realname string
899 if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
900 return err
901 }
902
903 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
904 nick := dc.marshalNick(uc, nick)
905 dc.SendMessage(&irc.Message{
906 Prefix: dc.srv.prefix(),
907 Command: irc.RPL_WHOISUSER,
908 Params: []string{dc.nick, nick, username, host, "*", realname},
909 })
910 })
911 case irc.RPL_WHOISSERVER:
912 var nick, server, serverInfo string
913 if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
914 return err
915 }
916
917 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
918 nick := dc.marshalNick(uc, nick)
919 dc.SendMessage(&irc.Message{
920 Prefix: dc.srv.prefix(),
921 Command: irc.RPL_WHOISSERVER,
922 Params: []string{dc.nick, nick, server, serverInfo},
923 })
924 })
925 case irc.RPL_WHOISOPERATOR:
926 var nick string
927 if err := parseMessageParams(msg, nil, &nick); err != nil {
928 return err
929 }
930
931 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
932 nick := dc.marshalNick(uc, nick)
933 dc.SendMessage(&irc.Message{
934 Prefix: dc.srv.prefix(),
935 Command: irc.RPL_WHOISOPERATOR,
936 Params: []string{dc.nick, nick, "is an IRC operator"},
937 })
938 })
939 case irc.RPL_WHOISIDLE:
940 var nick string
941 if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
942 return err
943 }
944
945 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
946 nick := dc.marshalNick(uc, nick)
947 params := []string{dc.nick, nick}
948 params = append(params, msg.Params[2:]...)
949 dc.SendMessage(&irc.Message{
950 Prefix: dc.srv.prefix(),
951 Command: irc.RPL_WHOISIDLE,
952 Params: params,
953 })
954 })
955 case irc.RPL_WHOISCHANNELS:
956 var nick, channelList string
957 if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
958 return err
959 }
960 channels := strings.Split(channelList, " ")
961
962 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
963 nick := dc.marshalNick(uc, nick)
964 channelList := make([]string, len(channels))
965 for i, channel := range channels {
966 prefix, channel := uc.parseMembershipPrefix(channel)
967 channel = dc.marshalChannel(uc, channel)
968 channelList[i] = prefix.String() + channel
969 }
970 channels := strings.Join(channelList, " ")
971 dc.SendMessage(&irc.Message{
972 Prefix: dc.srv.prefix(),
973 Command: irc.RPL_WHOISCHANNELS,
974 Params: []string{dc.nick, nick, channels},
975 })
976 })
977 case irc.RPL_ENDOFWHOIS:
978 var nick string
979 if err := parseMessageParams(msg, nil, &nick); err != nil {
980 return err
981 }
982
983 uc.forEachDownstreamById(downstreamId, func(dc *downstreamConn) {
984 nick := dc.marshalNick(uc, nick)
985 dc.SendMessage(&irc.Message{
986 Prefix: dc.srv.prefix(),
987 Command: irc.RPL_ENDOFWHOIS,
988 Params: []string{dc.nick, nick, "End of /WHOIS list"},
989 })
990 })
991 case "PRIVMSG":
992 if msg.Prefix == nil {
993 return fmt.Errorf("expected a prefix")
994 }
995
996 var nick string
997 if err := parseMessageParams(msg, &nick, nil); err != nil {
998 return err
999 }
1000
1001 if msg.Prefix.Name == serviceNick {
1002 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
1003 break
1004 }
1005 if nick == serviceNick {
1006 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
1007 break
1008 }
1009
1010 uc.network.ring.Produce(msg)
1011 case "INVITE":
1012 var nick string
1013 var channel string
1014 if err := parseMessageParams(msg, &nick, &channel); err != nil {
1015 return err
1016 }
1017
1018 uc.forEachDownstream(func(dc *downstreamConn) {
1019 dc.SendMessage(&irc.Message{
1020 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
1021 Command: "INVITE",
1022 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
1023 })
1024 })
1025 case "TAGMSG":
1026 // TODO: relay to downstream connections that accept message-tags
1027 case "ACK":
1028 // Ignore
1029 case irc.RPL_YOURHOST, irc.RPL_CREATED:
1030 // Ignore
1031 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
1032 // Ignore
1033 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
1034 // Ignore
1035 case rpl_localusers, rpl_globalusers:
1036 // Ignore
1037 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
1038 // Ignore
1039 default:
1040 uc.logger.Printf("unhandled message: %v", msg)
1041 }
1042 return nil
1043}
1044
1045func (uc *upstreamConn) register() {
1046 uc.nick = uc.network.Nick
1047 uc.username = uc.network.Username
1048 if uc.username == "" {
1049 uc.username = uc.nick
1050 }
1051 uc.realname = uc.network.Realname
1052 if uc.realname == "" {
1053 uc.realname = uc.nick
1054 }
1055
1056 uc.SendMessage(&irc.Message{
1057 Command: "CAP",
1058 Params: []string{"LS", "302"},
1059 })
1060
1061 if uc.network.Pass != "" {
1062 uc.SendMessage(&irc.Message{
1063 Command: "PASS",
1064 Params: []string{uc.network.Pass},
1065 })
1066 }
1067
1068 uc.SendMessage(&irc.Message{
1069 Command: "NICK",
1070 Params: []string{uc.nick},
1071 })
1072 uc.SendMessage(&irc.Message{
1073 Command: "USER",
1074 Params: []string{uc.username, "0", "*", uc.realname},
1075 })
1076}
1077
1078func (uc *upstreamConn) requestSASL() bool {
1079 if uc.network.SASL.Mechanism == "" {
1080 return false
1081 }
1082
1083 v, ok := uc.caps["sasl"]
1084 if !ok {
1085 return false
1086 }
1087 if v != "" {
1088 mechanisms := strings.Split(v, ",")
1089 found := false
1090 for _, mech := range mechanisms {
1091 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
1092 found = true
1093 break
1094 }
1095 }
1096 if !found {
1097 return false
1098 }
1099 }
1100
1101 return true
1102}
1103
1104func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
1105 auth := &uc.network.SASL
1106 switch name {
1107 case "sasl":
1108 if !ok {
1109 uc.logger.Printf("server refused to acknowledge the SASL capability")
1110 return nil
1111 }
1112
1113 switch auth.Mechanism {
1114 case "PLAIN":
1115 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
1116 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
1117 default:
1118 return fmt.Errorf("unsupported SASL mechanism %q", name)
1119 }
1120
1121 uc.SendMessage(&irc.Message{
1122 Command: "AUTHENTICATE",
1123 Params: []string{auth.Mechanism},
1124 })
1125 case "message-tags":
1126 uc.tagsSupported = ok
1127 case "labeled-response":
1128 uc.labelsSupported = ok
1129 }
1130 return nil
1131}
1132
1133func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
1134 for {
1135 msg, err := uc.irc.ReadMessage()
1136 if err == io.EOF {
1137 break
1138 } else if err != nil {
1139 return fmt.Errorf("failed to read IRC command: %v", err)
1140 }
1141
1142 if uc.srv.Debug {
1143 uc.logger.Printf("received: %v", msg)
1144 }
1145
1146 ch <- upstreamIncomingMessage{msg, uc}
1147 }
1148
1149 return nil
1150}
1151
1152func (uc *upstreamConn) SendMessage(msg *irc.Message) {
1153 uc.outgoing <- msg
1154}
1155
1156func (uc *upstreamConn) SendMessageLabeled(dc *downstreamConn, msg *irc.Message) {
1157 if uc.labelsSupported {
1158 if msg.Tags == nil {
1159 msg.Tags = make(map[string]irc.TagValue)
1160 }
1161 msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", dc.id, uc.nextLabelId))
1162 uc.nextLabelId++
1163 }
1164 uc.SendMessage(msg)
1165}
Note: See TracBrowser for help on using the repository browser.