source: code/trunk/upstream.go@ 152

Last change on this file since 152 was 152, checked in by delthas, 5 years ago

Add upstream message-tags capability support

File size: 24.9 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/emersion/go-sasl"
14 "gopkg.in/irc.v3"
15)
16
17type upstreamChannel struct {
18 Name string
19 conn *upstreamConn
20 Topic string
21 TopicWho string
22 TopicTime time.Time
23 Status channelStatus
24 modes channelModes
25 Members map[string]*membership
26 complete bool
27}
28
29type upstreamConn struct {
30 network *network
31 logger Logger
32 net net.Conn
33 irc *irc.Conn
34 srv *Server
35 user *user
36 outgoing chan<- *irc.Message
37
38 serverName string
39 availableUserModes string
40 availableChannelModes map[byte]channelModeType
41 availableChannelTypes string
42 availableMemberships []membership
43
44 registered bool
45 nick string
46 username string
47 realname string
48 closed bool
49 modes userModes
50 channels map[string]*upstreamChannel
51 caps map[string]string
52
53 tagsSupported bool
54
55 saslClient sasl.Client
56 saslStarted bool
57}
58
59func connectToUpstream(network *network) (*upstreamConn, error) {
60 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
61
62 addr := network.Addr
63 if !strings.ContainsRune(addr, ':') {
64 addr = addr + ":6697"
65 }
66
67 logger.Printf("connecting to TLS server at address %q", addr)
68 netConn, err := tls.Dial("tcp", addr, nil)
69 if err != nil {
70 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
71 }
72
73 setKeepAlive(netConn)
74
75 outgoing := make(chan *irc.Message, 64)
76 uc := &upstreamConn{
77 network: network,
78 logger: logger,
79 net: netConn,
80 irc: irc.NewConn(netConn),
81 srv: network.user.srv,
82 user: network.user,
83 outgoing: outgoing,
84 channels: make(map[string]*upstreamChannel),
85 caps: make(map[string]string),
86 availableChannelTypes: stdChannelTypes,
87 availableChannelModes: stdChannelModes,
88 availableMemberships: stdMemberships,
89 }
90
91 go func() {
92 for msg := range outgoing {
93 if uc.srv.Debug {
94 uc.logger.Printf("sent: %v", msg)
95 }
96 if err := uc.irc.WriteMessage(msg); err != nil {
97 uc.logger.Printf("failed to write message: %v", err)
98 }
99 }
100 if err := uc.net.Close(); err != nil {
101 uc.logger.Printf("failed to close connection: %v", err)
102 } else {
103 uc.logger.Printf("connection closed")
104 }
105 }()
106
107 return uc, nil
108}
109
110func (uc *upstreamConn) Close() error {
111 if uc.closed {
112 return fmt.Errorf("upstream connection already closed")
113 }
114 close(uc.outgoing)
115 uc.closed = true
116 return nil
117}
118
119func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
120 uc.user.forEachDownstream(func(dc *downstreamConn) {
121 if dc.network != nil && dc.network != uc.network {
122 return
123 }
124 f(dc)
125 })
126}
127
128func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
129 ch, ok := uc.channels[name]
130 if !ok {
131 return nil, fmt.Errorf("unknown channel %q", name)
132 }
133 return ch, nil
134}
135
136func (uc *upstreamConn) isChannel(entity string) bool {
137 if i := strings.IndexByte(uc.availableChannelTypes, entity[0]); i >= 0 {
138 return true
139 }
140 return false
141}
142
143func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) {
144 for _, m := range uc.availableMemberships {
145 if m.Prefix == s[0] {
146 return &m, s[1:]
147 }
148 }
149 return nil, s
150}
151
152func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
153 switch msg.Command {
154 case "PING":
155 uc.SendMessage(&irc.Message{
156 Command: "PONG",
157 Params: msg.Params,
158 })
159 return nil
160 case "NOTICE":
161 uc.logger.Print(msg)
162
163 uc.forEachDownstream(func(dc *downstreamConn) {
164 dc.SendMessage(msg)
165 })
166 case "CAP":
167 var subCmd string
168 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
169 return err
170 }
171 subCmd = strings.ToUpper(subCmd)
172 subParams := msg.Params[2:]
173 switch subCmd {
174 case "LS":
175 if len(subParams) < 1 {
176 return newNeedMoreParamsError(msg.Command)
177 }
178 caps := strings.Fields(subParams[len(subParams)-1])
179 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
180
181 for _, s := range caps {
182 kv := strings.SplitN(s, "=", 2)
183 k := strings.ToLower(kv[0])
184 var v string
185 if len(kv) == 2 {
186 v = kv[1]
187 }
188 uc.caps[k] = v
189 }
190
191 if more {
192 break // wait to receive all capabilities
193 }
194
195 requestCaps := make([]string, 0, 16)
196 for _, c := range []string{"message-tags"} {
197 if _, ok := uc.caps[c]; ok {
198 requestCaps = append(requestCaps, c)
199 }
200 }
201
202 if uc.requestSASL() {
203 requestCaps = append(requestCaps, "sasl")
204 }
205
206 if len(requestCaps) > 0 {
207 uc.SendMessage(&irc.Message{
208 Command: "CAP",
209 Params: []string{"REQ", strings.Join(requestCaps, " ")},
210 })
211 }
212
213 if uc.requestSASL() {
214 break // we'll send CAP END after authentication is completed
215 }
216
217 uc.SendMessage(&irc.Message{
218 Command: "CAP",
219 Params: []string{"END"},
220 })
221 case "ACK", "NAK":
222 if len(subParams) < 1 {
223 return newNeedMoreParamsError(msg.Command)
224 }
225 caps := strings.Fields(subParams[0])
226
227 for _, name := range caps {
228 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
229 return err
230 }
231 }
232
233 if uc.saslClient == nil {
234 uc.SendMessage(&irc.Message{
235 Command: "CAP",
236 Params: []string{"END"},
237 })
238 }
239 default:
240 uc.logger.Printf("unhandled message: %v", msg)
241 }
242 case "AUTHENTICATE":
243 if uc.saslClient == nil {
244 return fmt.Errorf("received unexpected AUTHENTICATE message")
245 }
246
247 // TODO: if a challenge is 400 bytes long, buffer it
248 var challengeStr string
249 if err := parseMessageParams(msg, &challengeStr); err != nil {
250 uc.SendMessage(&irc.Message{
251 Command: "AUTHENTICATE",
252 Params: []string{"*"},
253 })
254 return err
255 }
256
257 var challenge []byte
258 if challengeStr != "+" {
259 var err error
260 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
261 if err != nil {
262 uc.SendMessage(&irc.Message{
263 Command: "AUTHENTICATE",
264 Params: []string{"*"},
265 })
266 return err
267 }
268 }
269
270 var resp []byte
271 var err error
272 if !uc.saslStarted {
273 _, resp, err = uc.saslClient.Start()
274 uc.saslStarted = true
275 } else {
276 resp, err = uc.saslClient.Next(challenge)
277 }
278 if err != nil {
279 uc.SendMessage(&irc.Message{
280 Command: "AUTHENTICATE",
281 Params: []string{"*"},
282 })
283 return err
284 }
285
286 // TODO: send response in multiple chunks if >= 400 bytes
287 var respStr = "+"
288 if resp != nil {
289 respStr = base64.StdEncoding.EncodeToString(resp)
290 }
291
292 uc.SendMessage(&irc.Message{
293 Command: "AUTHENTICATE",
294 Params: []string{respStr},
295 })
296 case irc.RPL_LOGGEDIN:
297 var account string
298 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
299 return err
300 }
301 uc.logger.Printf("logged in with account %q", account)
302 case irc.RPL_LOGGEDOUT:
303 uc.logger.Printf("logged out")
304 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
305 var info string
306 if err := parseMessageParams(msg, nil, &info); err != nil {
307 return err
308 }
309 switch msg.Command {
310 case irc.ERR_NICKLOCKED:
311 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
312 case irc.ERR_SASLFAIL:
313 uc.logger.Printf("SASL authentication failed: %v", info)
314 case irc.ERR_SASLTOOLONG:
315 uc.logger.Printf("SASL message too long: %v", info)
316 }
317
318 uc.saslClient = nil
319 uc.saslStarted = false
320
321 uc.SendMessage(&irc.Message{
322 Command: "CAP",
323 Params: []string{"END"},
324 })
325 case irc.RPL_WELCOME:
326 uc.registered = true
327 uc.logger.Printf("connection registered")
328
329 channels, err := uc.srv.db.ListChannels(uc.network.ID)
330 if err != nil {
331 uc.logger.Printf("failed to list channels from database: %v", err)
332 break
333 }
334
335 for _, ch := range channels {
336 params := []string{ch.Name}
337 if ch.Key != "" {
338 params = append(params, ch.Key)
339 }
340 uc.SendMessage(&irc.Message{
341 Command: "JOIN",
342 Params: params,
343 })
344 }
345 case irc.RPL_MYINFO:
346 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
347 return err
348 }
349 case irc.RPL_ISUPPORT:
350 if err := parseMessageParams(msg, nil, nil); err != nil {
351 return err
352 }
353 for _, token := range msg.Params[1 : len(msg.Params)-1] {
354 negate := false
355 parameter := token
356 value := ""
357 if strings.HasPrefix(token, "-") {
358 negate = true
359 token = token[1:]
360 } else {
361 if i := strings.IndexByte(token, '='); i >= 0 {
362 parameter = token[:i]
363 value = token[i+1:]
364 }
365 }
366 if !negate {
367 switch parameter {
368 case "CHANMODES":
369 parts := strings.SplitN(value, ",", 5)
370 if len(parts) < 4 {
371 return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", value)
372 }
373 modes := make(map[byte]channelModeType)
374 for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
375 for j := 0; j < len(parts[i]); j++ {
376 mode := parts[i][j]
377 modes[mode] = mt
378 }
379 }
380 uc.availableChannelModes = modes
381 case "CHANTYPES":
382 uc.availableChannelTypes = value
383 case "PREFIX":
384 if value == "" {
385 uc.availableMemberships = nil
386 } else {
387 if value[0] != '(' {
388 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
389 }
390 sep := strings.IndexByte(value, ')')
391 if sep < 0 || len(value) != sep*2 {
392 return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", value)
393 }
394 memberships := make([]membership, len(value)/2-1)
395 for i := range memberships {
396 memberships[i] = membership{
397 Mode: value[i+1],
398 Prefix: value[sep+i+1],
399 }
400 }
401 uc.availableMemberships = memberships
402 }
403 }
404 } else {
405 // TODO: handle ISUPPORT negations
406 }
407 }
408 case "NICK":
409 if msg.Prefix == nil {
410 return fmt.Errorf("expected a prefix")
411 }
412
413 var newNick string
414 if err := parseMessageParams(msg, &newNick); err != nil {
415 return err
416 }
417
418 if msg.Prefix.Name == uc.nick {
419 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
420 uc.nick = newNick
421 }
422
423 for _, ch := range uc.channels {
424 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
425 delete(ch.Members, msg.Prefix.Name)
426 ch.Members[newNick] = membership
427 }
428 }
429
430 if msg.Prefix.Name != uc.nick {
431 uc.forEachDownstream(func(dc *downstreamConn) {
432 dc.SendMessage(&irc.Message{
433 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
434 Command: "NICK",
435 Params: []string{newNick},
436 })
437 })
438 }
439 case "JOIN":
440 if msg.Prefix == nil {
441 return fmt.Errorf("expected a prefix")
442 }
443
444 var channels string
445 if err := parseMessageParams(msg, &channels); err != nil {
446 return err
447 }
448
449 for _, ch := range strings.Split(channels, ",") {
450 if msg.Prefix.Name == uc.nick {
451 uc.logger.Printf("joined channel %q", ch)
452 uc.channels[ch] = &upstreamChannel{
453 Name: ch,
454 conn: uc,
455 Members: make(map[string]*membership),
456 }
457
458 uc.SendMessage(&irc.Message{
459 Command: "MODE",
460 Params: []string{ch},
461 })
462 } else {
463 ch, err := uc.getChannel(ch)
464 if err != nil {
465 return err
466 }
467 ch.Members[msg.Prefix.Name] = nil
468 }
469
470 uc.forEachDownstream(func(dc *downstreamConn) {
471 dc.SendMessage(&irc.Message{
472 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
473 Command: "JOIN",
474 Params: []string{dc.marshalChannel(uc, ch)},
475 })
476 })
477 }
478 case "PART":
479 if msg.Prefix == nil {
480 return fmt.Errorf("expected a prefix")
481 }
482
483 var channels string
484 if err := parseMessageParams(msg, &channels); err != nil {
485 return err
486 }
487
488 for _, ch := range strings.Split(channels, ",") {
489 if msg.Prefix.Name == uc.nick {
490 uc.logger.Printf("parted channel %q", ch)
491 delete(uc.channels, ch)
492 } else {
493 ch, err := uc.getChannel(ch)
494 if err != nil {
495 return err
496 }
497 delete(ch.Members, msg.Prefix.Name)
498 }
499
500 uc.forEachDownstream(func(dc *downstreamConn) {
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "PART",
504 Params: []string{dc.marshalChannel(uc, ch)},
505 })
506 })
507 }
508 case "QUIT":
509 if msg.Prefix == nil {
510 return fmt.Errorf("expected a prefix")
511 }
512
513 if msg.Prefix.Name == uc.nick {
514 uc.logger.Printf("quit")
515 }
516
517 for _, ch := range uc.channels {
518 delete(ch.Members, msg.Prefix.Name)
519 }
520
521 if msg.Prefix.Name != uc.nick {
522 uc.forEachDownstream(func(dc *downstreamConn) {
523 dc.SendMessage(&irc.Message{
524 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
525 Command: "QUIT",
526 Params: msg.Params,
527 })
528 })
529 }
530 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
531 var name, topic string
532 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
533 return err
534 }
535 ch, err := uc.getChannel(name)
536 if err != nil {
537 return err
538 }
539 if msg.Command == irc.RPL_TOPIC {
540 ch.Topic = topic
541 } else {
542 ch.Topic = ""
543 }
544 case "TOPIC":
545 var name string
546 if err := parseMessageParams(msg, &name); err != nil {
547 return err
548 }
549 ch, err := uc.getChannel(name)
550 if err != nil {
551 return err
552 }
553 if len(msg.Params) > 1 {
554 ch.Topic = msg.Params[1]
555 } else {
556 ch.Topic = ""
557 }
558 uc.forEachDownstream(func(dc *downstreamConn) {
559 params := []string{dc.marshalChannel(uc, name)}
560 if ch.Topic != "" {
561 params = append(params, ch.Topic)
562 }
563 dc.SendMessage(&irc.Message{
564 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
565 Command: "TOPIC",
566 Params: params,
567 })
568 })
569 case "MODE":
570 var name, modeStr string
571 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
572 return err
573 }
574
575 if !uc.isChannel(name) { // user mode change
576 if name != uc.nick {
577 return fmt.Errorf("received MODE message for unknown nick %q", name)
578 }
579 return uc.modes.Apply(modeStr)
580 // TODO: notify downstreams about user mode change?
581 } else { // channel mode change
582 ch, err := uc.getChannel(name)
583 if err != nil {
584 return err
585 }
586
587 if ch.modes != nil {
588 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil {
589 return err
590 }
591 }
592
593 uc.forEachDownstream(func(dc *downstreamConn) {
594 params := []string{dc.marshalChannel(uc, name), modeStr}
595 params = append(params, msg.Params[2:]...)
596
597 dc.SendMessage(&irc.Message{
598 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
599 Command: "MODE",
600 Params: params,
601 })
602 })
603 }
604 case irc.RPL_UMODEIS:
605 if err := parseMessageParams(msg, nil); err != nil {
606 return err
607 }
608 modeStr := ""
609 if len(msg.Params) > 1 {
610 modeStr = msg.Params[1]
611 }
612
613 uc.modes = ""
614 if err := uc.modes.Apply(modeStr); err != nil {
615 return err
616 }
617 // TODO: send RPL_UMODEIS to downstream connections when applicable
618 case irc.RPL_CHANNELMODEIS:
619 var channel string
620 if err := parseMessageParams(msg, nil, &channel); err != nil {
621 return err
622 }
623 modeStr := ""
624 if len(msg.Params) > 2 {
625 modeStr = msg.Params[2]
626 }
627
628 ch, err := uc.getChannel(channel)
629 if err != nil {
630 return err
631 }
632
633 firstMode := ch.modes == nil
634 ch.modes = make(map[byte]string)
635 if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil {
636 return err
637 }
638 if firstMode {
639 modeStr, modeParams := ch.modes.Format()
640
641 uc.forEachDownstream(func(dc *downstreamConn) {
642 params := []string{dc.nick, dc.marshalChannel(uc, channel), modeStr}
643 params = append(params, modeParams...)
644
645 dc.SendMessage(&irc.Message{
646 Prefix: dc.srv.prefix(),
647 Command: irc.RPL_CHANNELMODEIS,
648 Params: params,
649 })
650 })
651 }
652 case rpl_topicwhotime:
653 var name, who, timeStr string
654 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
655 return err
656 }
657 ch, err := uc.getChannel(name)
658 if err != nil {
659 return err
660 }
661 ch.TopicWho = who
662 sec, err := strconv.ParseInt(timeStr, 10, 64)
663 if err != nil {
664 return fmt.Errorf("failed to parse topic time: %v", err)
665 }
666 ch.TopicTime = time.Unix(sec, 0)
667 case irc.RPL_NAMREPLY:
668 var name, statusStr, members string
669 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
670 return err
671 }
672
673 ch, ok := uc.channels[name]
674 if !ok {
675 // NAMES on a channel we have not joined, forward to downstream
676 uc.forEachDownstream(func(dc *downstreamConn) {
677 channel := dc.marshalChannel(uc, name)
678 members := strings.Split(members, " ")
679 for i, member := range members {
680 membership, nick := uc.parseMembershipPrefix(member)
681 members[i] = membership.String() + dc.marshalNick(uc, nick)
682 }
683 memberStr := strings.Join(members, " ")
684
685 dc.SendMessage(&irc.Message{
686 Prefix: dc.srv.prefix(),
687 Command: irc.RPL_NAMREPLY,
688 Params: []string{dc.nick, statusStr, channel, memberStr},
689 })
690 })
691 return nil
692 }
693
694 status, err := parseChannelStatus(statusStr)
695 if err != nil {
696 return err
697 }
698 ch.Status = status
699
700 for _, s := range strings.Split(members, " ") {
701 membership, nick := uc.parseMembershipPrefix(s)
702 ch.Members[nick] = membership
703 }
704 case irc.RPL_ENDOFNAMES:
705 var name string
706 if err := parseMessageParams(msg, nil, &name); err != nil {
707 return err
708 }
709
710 ch, ok := uc.channels[name]
711 if !ok {
712 // NAMES on a channel we have not joined, forward to downstream
713 uc.forEachDownstream(func(dc *downstreamConn) {
714 channel := dc.marshalChannel(uc, name)
715
716 dc.SendMessage(&irc.Message{
717 Prefix: dc.srv.prefix(),
718 Command: irc.RPL_ENDOFNAMES,
719 Params: []string{dc.nick, channel, "End of /NAMES list"},
720 })
721 })
722 return nil
723 }
724
725 if ch.complete {
726 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
727 }
728 ch.complete = true
729
730 uc.forEachDownstream(func(dc *downstreamConn) {
731 forwardChannel(dc, ch)
732 })
733 case irc.RPL_WHOREPLY:
734 var channel, username, host, server, nick, mode, trailing string
735 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
736 return err
737 }
738
739 parts := strings.SplitN(trailing, " ", 2)
740 if len(parts) != 2 {
741 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
742 }
743 realname := parts[1]
744 hops, err := strconv.Atoi(parts[0])
745 if err != nil {
746 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
747 }
748 hops++
749
750 trailing = strconv.Itoa(hops) + " " + realname
751
752 uc.forEachDownstream(func(dc *downstreamConn) {
753 channel := channel
754 if channel != "*" {
755 channel = dc.marshalChannel(uc, channel)
756 }
757 nick := dc.marshalNick(uc, nick)
758 dc.SendMessage(&irc.Message{
759 Prefix: dc.srv.prefix(),
760 Command: irc.RPL_WHOREPLY,
761 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
762 })
763 })
764 case irc.RPL_ENDOFWHO:
765 var name string
766 if err := parseMessageParams(msg, nil, &name); err != nil {
767 return err
768 }
769
770 uc.forEachDownstream(func(dc *downstreamConn) {
771 name := name
772 if name != "*" {
773 // TODO: support WHO masks
774 name = dc.marshalEntity(uc, name)
775 }
776 dc.SendMessage(&irc.Message{
777 Prefix: dc.srv.prefix(),
778 Command: irc.RPL_ENDOFWHO,
779 Params: []string{dc.nick, name, "End of /WHO list"},
780 })
781 })
782 case irc.RPL_WHOISUSER:
783 var nick, username, host, realname string
784 if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
785 return err
786 }
787
788 uc.forEachDownstream(func(dc *downstreamConn) {
789 nick := dc.marshalNick(uc, nick)
790 dc.SendMessage(&irc.Message{
791 Prefix: dc.srv.prefix(),
792 Command: irc.RPL_WHOISUSER,
793 Params: []string{dc.nick, nick, username, host, "*", realname},
794 })
795 })
796 case irc.RPL_WHOISSERVER:
797 var nick, server, serverInfo string
798 if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
799 return err
800 }
801
802 uc.forEachDownstream(func(dc *downstreamConn) {
803 nick := dc.marshalNick(uc, nick)
804 dc.SendMessage(&irc.Message{
805 Prefix: dc.srv.prefix(),
806 Command: irc.RPL_WHOISSERVER,
807 Params: []string{dc.nick, nick, server, serverInfo},
808 })
809 })
810 case irc.RPL_WHOISOPERATOR:
811 var nick string
812 if err := parseMessageParams(msg, nil, &nick); err != nil {
813 return err
814 }
815
816 uc.forEachDownstream(func(dc *downstreamConn) {
817 nick := dc.marshalNick(uc, nick)
818 dc.SendMessage(&irc.Message{
819 Prefix: dc.srv.prefix(),
820 Command: irc.RPL_WHOISOPERATOR,
821 Params: []string{dc.nick, nick, "is an IRC operator"},
822 })
823 })
824 case irc.RPL_WHOISIDLE:
825 var nick string
826 if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
827 return err
828 }
829
830 uc.forEachDownstream(func(dc *downstreamConn) {
831 nick := dc.marshalNick(uc, nick)
832 params := []string{dc.nick, nick}
833 params = append(params, msg.Params[2:]...)
834 dc.SendMessage(&irc.Message{
835 Prefix: dc.srv.prefix(),
836 Command: irc.RPL_WHOISIDLE,
837 Params: params,
838 })
839 })
840 case irc.RPL_WHOISCHANNELS:
841 var nick, channelList string
842 if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
843 return err
844 }
845 channels := strings.Split(channelList, " ")
846
847 uc.forEachDownstream(func(dc *downstreamConn) {
848 nick := dc.marshalNick(uc, nick)
849 channelList := make([]string, len(channels))
850 for i, channel := range channels {
851 prefix, channel := uc.parseMembershipPrefix(channel)
852 channel = dc.marshalChannel(uc, channel)
853 channelList[i] = prefix.String() + channel
854 }
855 channels := strings.Join(channelList, " ")
856 dc.SendMessage(&irc.Message{
857 Prefix: dc.srv.prefix(),
858 Command: irc.RPL_WHOISCHANNELS,
859 Params: []string{dc.nick, nick, channels},
860 })
861 })
862 case irc.RPL_ENDOFWHOIS:
863 var nick string
864 if err := parseMessageParams(msg, nil, &nick); err != nil {
865 return err
866 }
867
868 uc.forEachDownstream(func(dc *downstreamConn) {
869 nick := dc.marshalNick(uc, nick)
870 dc.SendMessage(&irc.Message{
871 Prefix: dc.srv.prefix(),
872 Command: irc.RPL_ENDOFWHOIS,
873 Params: []string{dc.nick, nick, "End of /WHOIS list"},
874 })
875 })
876 case "PRIVMSG":
877 if msg.Prefix == nil {
878 return fmt.Errorf("expected a prefix")
879 }
880
881 var nick string
882 if err := parseMessageParams(msg, &nick, nil); err != nil {
883 return err
884 }
885
886 if msg.Prefix.Name == serviceNick {
887 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
888 break
889 }
890 if nick == serviceNick {
891 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
892 break
893 }
894
895 uc.network.ring.Produce(msg)
896 case "INVITE":
897 var nick string
898 var channel string
899 if err := parseMessageParams(msg, &nick, &channel); err != nil {
900 return err
901 }
902
903 uc.forEachDownstream(func(dc *downstreamConn) {
904 dc.SendMessage(&irc.Message{
905 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
906 Command: "INVITE",
907 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
908 })
909 })
910 case "TAGMSG":
911 // TODO: relay to downstream connections that accept message-tags
912 case irc.RPL_YOURHOST, irc.RPL_CREATED:
913 // Ignore
914 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
915 // Ignore
916 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
917 // Ignore
918 case rpl_localusers, rpl_globalusers:
919 // Ignore
920 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
921 // Ignore
922 default:
923 uc.logger.Printf("unhandled message: %v", msg)
924 }
925 return nil
926}
927
928func (uc *upstreamConn) register() {
929 uc.nick = uc.network.Nick
930 uc.username = uc.network.Username
931 if uc.username == "" {
932 uc.username = uc.nick
933 }
934 uc.realname = uc.network.Realname
935 if uc.realname == "" {
936 uc.realname = uc.nick
937 }
938
939 uc.SendMessage(&irc.Message{
940 Command: "CAP",
941 Params: []string{"LS", "302"},
942 })
943
944 if uc.network.Pass != "" {
945 uc.SendMessage(&irc.Message{
946 Command: "PASS",
947 Params: []string{uc.network.Pass},
948 })
949 }
950
951 uc.SendMessage(&irc.Message{
952 Command: "NICK",
953 Params: []string{uc.nick},
954 })
955 uc.SendMessage(&irc.Message{
956 Command: "USER",
957 Params: []string{uc.username, "0", "*", uc.realname},
958 })
959}
960
961func (uc *upstreamConn) requestSASL() bool {
962 if uc.network.SASL.Mechanism == "" {
963 return false
964 }
965
966 v, ok := uc.caps["sasl"]
967 if !ok {
968 return false
969 }
970 if v != "" {
971 mechanisms := strings.Split(v, ",")
972 found := false
973 for _, mech := range mechanisms {
974 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
975 found = true
976 break
977 }
978 }
979 if !found {
980 return false
981 }
982 }
983
984 return true
985}
986
987func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
988 auth := &uc.network.SASL
989 switch name {
990 case "sasl":
991 if !ok {
992 uc.logger.Printf("server refused to acknowledge the SASL capability")
993 return nil
994 }
995
996 switch auth.Mechanism {
997 case "PLAIN":
998 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
999 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
1000 default:
1001 return fmt.Errorf("unsupported SASL mechanism %q", name)
1002 }
1003
1004 uc.SendMessage(&irc.Message{
1005 Command: "AUTHENTICATE",
1006 Params: []string{auth.Mechanism},
1007 })
1008 case "message-tags":
1009 uc.tagsSupported = ok
1010 }
1011 return nil
1012}
1013
1014func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
1015 for {
1016 msg, err := uc.irc.ReadMessage()
1017 if err == io.EOF {
1018 break
1019 } else if err != nil {
1020 return fmt.Errorf("failed to read IRC command: %v", err)
1021 }
1022
1023 if uc.srv.Debug {
1024 uc.logger.Printf("received: %v", msg)
1025 }
1026
1027 ch <- upstreamIncomingMessage{msg, uc}
1028 }
1029
1030 return nil
1031}
1032
1033func (uc *upstreamConn) SendMessage(msg *irc.Message) {
1034 uc.outgoing <- msg
1035}
Note: See TracBrowser for help on using the repository browser.