source: code/trunk/upstream.go@ 127

Last change on this file since 127 was 127, checked in by delthas, 5 years ago

Add WHO support

File size: 17.4 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
// upstreamChannel is the state tracked for a channel joined on an upstream
// connection. handleMessage creates it when the upstream connection itself
// JOINs the channel and deletes it when the connection PARTs.
type upstreamChannel struct {
	// Name is the channel name as received from the upstream server.
	Name string
	// conn is the upstream connection this channel belongs to.
	conn *upstreamConn
	// Topic is the current topic, updated from RPL_TOPIC, RPL_NOTOPIC and
	// TOPIC messages; it is empty when no topic is set.
	Topic string
	// TopicWho and TopicTime record who set the topic and when, as reported
	// by RPL_TOPICWHOTIME.
	TopicWho  string
	TopicTime time.Time
	// Status is the channel status parsed from RPL_NAMREPLY.
	Status channelStatus
	// modes holds the channel modes, updated by MODE messages.
	modes modeSet
	// Members maps each known member's nick to its membership (as parsed by
	// parseMembershipPrefix). It is filled from RPL_NAMREPLY and kept up to
	// date by JOIN, PART, NICK and QUIT.
	Members map[string]membership
	// complete is set once RPL_ENDOFNAMES has been received, at which point
	// the channel is forwarded to downstream connections.
	complete bool
}
29
// upstreamConn is a connection to the upstream IRC server of one of a user's
// networks. Messages passed to SendMessage are queued on outgoing and written
// by the goroutine started in connectToUpstream; incoming messages are read
// by readMessages and processed by handleMessage.
type upstreamConn struct {
	network *network
	logger  Logger

	// net is the underlying TLS connection. The writer goroutine started in
	// connectToUpstream closes it once outgoing has been closed.
	net net.Conn
	// irc reads and writes IRC messages over net.
	irc  *irc.Conn
	srv  *Server
	user *user
	// outgoing queues messages for the writer goroutine; see SendMessage and
	// Close.
	outgoing chan<- *irc.Message
	// ring stores PRIVMSG messages received from upstream.
	ring *Ring

	// The following fields are filled from RPL_MYINFO.
	serverName            string
	availableUserModes    string
	availableChannelModes string
	channelModesWithParam string

	// registered is set once RPL_WELCOME has been received.
	registered bool
	// nick, username and realname are the identity sent by register; nick is
	// updated when the server reports our own NICK change.
	nick     string
	username string
	realname string
	// closed is set by Close.
	closed bool
	// modes holds this connection's user modes, updated by MODE messages.
	modes modeSet
	// channels maps channel names to the channels currently joined.
	channels map[string]*upstreamChannel
	// caps maps lowercased capability names advertised in CAP LS to their
	// values ("" when a capability has no value).
	caps map[string]string

	// saslClient is the SASL client used during authentication, or nil when
	// no authentication is in progress; saslStarted records whether the
	// client's initial response has already been sent.
	saslClient  sasl.Client
	saslStarted bool

	// NOTE(review): lock is not used in the code shown here; presumably it
	// guards history — confirm against the rest of the package.
	lock    sync.Mutex
	history map[string]uint64 // TODO: move to network
}
60
61func connectToUpstream(network *network) (*upstreamConn, error) {
62 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
63
64 addr := network.Addr
65 if !strings.ContainsRune(addr, ':') {
66 addr = addr + ":6697"
67 }
68
69 logger.Printf("connecting to TLS server at address %q", addr)
70 netConn, err := tls.Dial("tcp", addr, nil)
71 if err != nil {
72 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
73 }
74
75 setKeepAlive(netConn)
76
77 outgoing := make(chan *irc.Message, 64)
78 uc := &upstreamConn{
79 network: network,
80 logger: logger,
81 net: netConn,
82 irc: irc.NewConn(netConn),
83 srv: network.user.srv,
84 user: network.user,
85 outgoing: outgoing,
86 ring: NewRing(network.user.srv.RingCap),
87 channels: make(map[string]*upstreamChannel),
88 history: make(map[string]uint64),
89 caps: make(map[string]string),
90 }
91
92 go func() {
93 for msg := range outgoing {
94 if uc.srv.Debug {
95 uc.logger.Printf("sent: %v", msg)
96 }
97 if err := uc.irc.WriteMessage(msg); err != nil {
98 uc.logger.Printf("failed to write message: %v", err)
99 }
100 }
101 if err := uc.net.Close(); err != nil {
102 uc.logger.Printf("failed to close connection: %v", err)
103 } else {
104 uc.logger.Printf("connection closed")
105 }
106 }()
107
108 return uc, nil
109}
110
111func (uc *upstreamConn) Close() error {
112 if uc.closed {
113 return fmt.Errorf("upstream connection already closed")
114 }
115 close(uc.outgoing)
116 uc.closed = true
117 return nil
118}
119
120func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
121 uc.user.forEachDownstream(func(dc *downstreamConn) {
122 if dc.network != nil && dc.network != uc.network {
123 return
124 }
125 f(dc)
126 })
127}
128
129func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
130 ch, ok := uc.channels[name]
131 if !ok {
132 return nil, fmt.Errorf("unknown channel %q", name)
133 }
134 return ch, nil
135}
136
137func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
138 switch msg.Command {
139 case "PING":
140 uc.SendMessage(&irc.Message{
141 Command: "PONG",
142 Params: msg.Params,
143 })
144 return nil
145 case "MODE":
146 if msg.Prefix == nil {
147 return fmt.Errorf("missing prefix")
148 }
149
150 var name, modeStr string
151 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
152 return err
153 }
154
155 if name == msg.Prefix.Name { // user mode change
156 if name != uc.nick {
157 return fmt.Errorf("received MODE message for unknow nick %q", name)
158 }
159 return uc.modes.Apply(modeStr)
160 } else { // channel mode change
161 ch, err := uc.getChannel(name)
162 if err != nil {
163 return err
164 }
165 if err := ch.modes.Apply(modeStr); err != nil {
166 return err
167 }
168
169 uc.forEachDownstream(func(dc *downstreamConn) {
170 dc.SendMessage(&irc.Message{
171 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
172 Command: "MODE",
173 Params: []string{dc.marshalChannel(uc, name), modeStr},
174 })
175 })
176 }
177 case "NOTICE":
178 uc.logger.Print(msg)
179
180 uc.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(msg)
182 })
183 case "CAP":
184 var subCmd string
185 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
186 return err
187 }
188 subCmd = strings.ToUpper(subCmd)
189 subParams := msg.Params[2:]
190 switch subCmd {
191 case "LS":
192 if len(subParams) < 1 {
193 return newNeedMoreParamsError(msg.Command)
194 }
195 caps := strings.Fields(subParams[len(subParams)-1])
196 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
197
198 for _, s := range caps {
199 kv := strings.SplitN(s, "=", 2)
200 k := strings.ToLower(kv[0])
201 var v string
202 if len(kv) == 2 {
203 v = kv[1]
204 }
205 uc.caps[k] = v
206 }
207
208 if more {
209 break // wait to receive all capabilities
210 }
211
212 if uc.requestSASL() {
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"REQ", "sasl"},
216 })
217 break // we'll send CAP END after authentication is completed
218 }
219
220 uc.SendMessage(&irc.Message{
221 Command: "CAP",
222 Params: []string{"END"},
223 })
224 case "ACK", "NAK":
225 if len(subParams) < 1 {
226 return newNeedMoreParamsError(msg.Command)
227 }
228 caps := strings.Fields(subParams[0])
229
230 for _, name := range caps {
231 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
232 return err
233 }
234 }
235
236 if uc.saslClient == nil {
237 uc.SendMessage(&irc.Message{
238 Command: "CAP",
239 Params: []string{"END"},
240 })
241 }
242 default:
243 uc.logger.Printf("unhandled message: %v", msg)
244 }
245 case "AUTHENTICATE":
246 if uc.saslClient == nil {
247 return fmt.Errorf("received unexpected AUTHENTICATE message")
248 }
249
250 // TODO: if a challenge is 400 bytes long, buffer it
251 var challengeStr string
252 if err := parseMessageParams(msg, &challengeStr); err != nil {
253 uc.SendMessage(&irc.Message{
254 Command: "AUTHENTICATE",
255 Params: []string{"*"},
256 })
257 return err
258 }
259
260 var challenge []byte
261 if challengeStr != "+" {
262 var err error
263 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
264 if err != nil {
265 uc.SendMessage(&irc.Message{
266 Command: "AUTHENTICATE",
267 Params: []string{"*"},
268 })
269 return err
270 }
271 }
272
273 var resp []byte
274 var err error
275 if !uc.saslStarted {
276 _, resp, err = uc.saslClient.Start()
277 uc.saslStarted = true
278 } else {
279 resp, err = uc.saslClient.Next(challenge)
280 }
281 if err != nil {
282 uc.SendMessage(&irc.Message{
283 Command: "AUTHENTICATE",
284 Params: []string{"*"},
285 })
286 return err
287 }
288
289 // TODO: send response in multiple chunks if >= 400 bytes
290 var respStr = "+"
291 if resp != nil {
292 respStr = base64.StdEncoding.EncodeToString(resp)
293 }
294
295 uc.SendMessage(&irc.Message{
296 Command: "AUTHENTICATE",
297 Params: []string{respStr},
298 })
299 case irc.RPL_LOGGEDIN:
300 var account string
301 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
302 return err
303 }
304 uc.logger.Printf("logged in with account %q", account)
305 case irc.RPL_LOGGEDOUT:
306 uc.logger.Printf("logged out")
307 case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
308 var info string
309 if err := parseMessageParams(msg, nil, &info); err != nil {
310 return err
311 }
312 switch msg.Command {
313 case irc.ERR_NICKLOCKED:
314 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
315 case irc.ERR_SASLFAIL:
316 uc.logger.Printf("SASL authentication failed: %v", info)
317 case irc.ERR_SASLTOOLONG:
318 uc.logger.Printf("SASL message too long: %v", info)
319 }
320
321 uc.saslClient = nil
322 uc.saslStarted = false
323
324 uc.SendMessage(&irc.Message{
325 Command: "CAP",
326 Params: []string{"END"},
327 })
328 case irc.RPL_WELCOME:
329 uc.registered = true
330 uc.logger.Printf("connection registered")
331
332 channels, err := uc.srv.db.ListChannels(uc.network.ID)
333 if err != nil {
334 uc.logger.Printf("failed to list channels from database: %v", err)
335 break
336 }
337
338 for _, ch := range channels {
339 uc.SendMessage(&irc.Message{
340 Command: "JOIN",
341 Params: []string{ch.Name},
342 })
343 }
344 case irc.RPL_MYINFO:
345 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
346 return err
347 }
348 if len(msg.Params) > 5 {
349 uc.channelModesWithParam = msg.Params[5]
350 }
351 case "NICK":
352 if msg.Prefix == nil {
353 return fmt.Errorf("expected a prefix")
354 }
355
356 var newNick string
357 if err := parseMessageParams(msg, &newNick); err != nil {
358 return err
359 }
360
361 if msg.Prefix.Name == uc.nick {
362 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
363 uc.nick = newNick
364 }
365
366 for _, ch := range uc.channels {
367 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
368 delete(ch.Members, msg.Prefix.Name)
369 ch.Members[newNick] = membership
370 }
371 }
372
373 if msg.Prefix.Name != uc.nick {
374 uc.forEachDownstream(func(dc *downstreamConn) {
375 dc.SendMessage(&irc.Message{
376 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
377 Command: "NICK",
378 Params: []string{newNick},
379 })
380 })
381 }
382 case "JOIN":
383 if msg.Prefix == nil {
384 return fmt.Errorf("expected a prefix")
385 }
386
387 var channels string
388 if err := parseMessageParams(msg, &channels); err != nil {
389 return err
390 }
391
392 for _, ch := range strings.Split(channels, ",") {
393 if msg.Prefix.Name == uc.nick {
394 uc.logger.Printf("joined channel %q", ch)
395 uc.channels[ch] = &upstreamChannel{
396 Name: ch,
397 conn: uc,
398 Members: make(map[string]membership),
399 }
400 } else {
401 ch, err := uc.getChannel(ch)
402 if err != nil {
403 return err
404 }
405 ch.Members[msg.Prefix.Name] = 0
406 }
407
408 uc.forEachDownstream(func(dc *downstreamConn) {
409 dc.SendMessage(&irc.Message{
410 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
411 Command: "JOIN",
412 Params: []string{dc.marshalChannel(uc, ch)},
413 })
414 })
415 }
416 case "PART":
417 if msg.Prefix == nil {
418 return fmt.Errorf("expected a prefix")
419 }
420
421 var channels string
422 if err := parseMessageParams(msg, &channels); err != nil {
423 return err
424 }
425
426 for _, ch := range strings.Split(channels, ",") {
427 if msg.Prefix.Name == uc.nick {
428 uc.logger.Printf("parted channel %q", ch)
429 delete(uc.channels, ch)
430 } else {
431 ch, err := uc.getChannel(ch)
432 if err != nil {
433 return err
434 }
435 delete(ch.Members, msg.Prefix.Name)
436 }
437
438 uc.forEachDownstream(func(dc *downstreamConn) {
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
441 Command: "PART",
442 Params: []string{dc.marshalChannel(uc, ch)},
443 })
444 })
445 }
446 case "QUIT":
447 if msg.Prefix == nil {
448 return fmt.Errorf("expected a prefix")
449 }
450
451 if msg.Prefix.Name == uc.nick {
452 uc.logger.Printf("quit")
453 }
454
455 for _, ch := range uc.channels {
456 delete(ch.Members, msg.Prefix.Name)
457 }
458
459 if msg.Prefix.Name != uc.nick {
460 uc.forEachDownstream(func(dc *downstreamConn) {
461 dc.SendMessage(&irc.Message{
462 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
463 Command: "QUIT",
464 Params: msg.Params,
465 })
466 })
467 }
468 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
469 var name, topic string
470 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
471 return err
472 }
473 ch, err := uc.getChannel(name)
474 if err != nil {
475 return err
476 }
477 if msg.Command == irc.RPL_TOPIC {
478 ch.Topic = topic
479 } else {
480 ch.Topic = ""
481 }
482 case "TOPIC":
483 var name string
484 if err := parseMessageParams(msg, &name); err != nil {
485 return err
486 }
487 ch, err := uc.getChannel(name)
488 if err != nil {
489 return err
490 }
491 if len(msg.Params) > 1 {
492 ch.Topic = msg.Params[1]
493 } else {
494 ch.Topic = ""
495 }
496 uc.forEachDownstream(func(dc *downstreamConn) {
497 params := []string{dc.marshalChannel(uc, name)}
498 if ch.Topic != "" {
499 params = append(params, ch.Topic)
500 }
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "TOPIC",
504 Params: params,
505 })
506 })
507 case rpl_topicwhotime:
508 var name, who, timeStr string
509 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
510 return err
511 }
512 ch, err := uc.getChannel(name)
513 if err != nil {
514 return err
515 }
516 ch.TopicWho = who
517 sec, err := strconv.ParseInt(timeStr, 10, 64)
518 if err != nil {
519 return fmt.Errorf("failed to parse topic time: %v", err)
520 }
521 ch.TopicTime = time.Unix(sec, 0)
522 case irc.RPL_NAMREPLY:
523 var name, statusStr, members string
524 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
525 return err
526 }
527 ch, err := uc.getChannel(name)
528 if err != nil {
529 return err
530 }
531
532 status, err := parseChannelStatus(statusStr)
533 if err != nil {
534 return err
535 }
536 ch.Status = status
537
538 for _, s := range strings.Split(members, " ") {
539 membership, nick := parseMembershipPrefix(s)
540 ch.Members[nick] = membership
541 }
542 case irc.RPL_ENDOFNAMES:
543 var name string
544 if err := parseMessageParams(msg, nil, &name); err != nil {
545 return err
546 }
547 ch, err := uc.getChannel(name)
548 if err != nil {
549 return err
550 }
551
552 if ch.complete {
553 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
554 }
555 ch.complete = true
556
557 uc.forEachDownstream(func(dc *downstreamConn) {
558 forwardChannel(dc, ch)
559 })
560 case irc.RPL_WHOREPLY:
561 var channel, username, host, server, nick, mode, trailing string
562 if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &mode, &trailing); err != nil {
563 return err
564 }
565
566 parts := strings.SplitN(trailing, " ", 2)
567 if len(parts) != 2 {
568 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing)
569 }
570 realname := parts[1]
571 hops, err := strconv.Atoi(parts[0])
572 if err != nil {
573 return fmt.Errorf("received malformed RPL_WHOREPLY: wrong hop count: %s", parts[0])
574 }
575 hops++
576
577 trailing = strconv.Itoa(hops) + " " + realname
578
579 uc.forEachDownstream(func(dc *downstreamConn) {
580 channel := channel
581 if channel != "*" {
582 channel = dc.marshalChannel(uc, channel)
583 }
584 nick := dc.marshalNick(uc, nick)
585 dc.SendMessage(&irc.Message{
586 Prefix: dc.srv.prefix(),
587 Command: irc.RPL_WHOREPLY,
588 Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing},
589 })
590 })
591 case irc.RPL_ENDOFWHO:
592 var name string
593 if err := parseMessageParams(msg, nil, &name); err != nil {
594 return err
595 }
596
597 uc.forEachDownstream(func(dc *downstreamConn) {
598 name := name
599 if name != "*" {
600 // TODO: support WHO masks
601 name = dc.marshalEntity(uc, name)
602 }
603 dc.SendMessage(&irc.Message{
604 Prefix: dc.srv.prefix(),
605 Command: irc.RPL_ENDOFWHO,
606 Params: []string{dc.nick, name, "End of /WHO list."},
607 })
608 })
609 case "PRIVMSG":
610 if msg.Prefix == nil {
611 return fmt.Errorf("expected a prefix")
612 }
613
614 var nick string
615 if err := parseMessageParams(msg, &nick, nil); err != nil {
616 return err
617 }
618
619 if msg.Prefix.Name == serviceNick {
620 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
621 break
622 }
623 if nick == serviceNick {
624 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
625 break
626 }
627
628 uc.ring.Produce(msg)
629 case "INVITE":
630 var nick string
631 var channel string
632 if err := parseMessageParams(msg, &nick, &channel); err != nil {
633 return err
634 }
635
636 uc.forEachDownstream(func(dc *downstreamConn) {
637 dc.SendMessage(&irc.Message{
638 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
639 Command: "INVITE",
640 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
641 })
642 })
643 case irc.RPL_YOURHOST, irc.RPL_CREATED:
644 // Ignore
645 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
646 // Ignore
647 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
648 // Ignore
649 case rpl_localusers, rpl_globalusers:
650 // Ignore
651 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
652 // Ignore
653 default:
654 uc.logger.Printf("unhandled message: %v", msg)
655 }
656 return nil
657}
658
659func (uc *upstreamConn) register() {
660 uc.nick = uc.network.Nick
661 uc.username = uc.network.Username
662 if uc.username == "" {
663 uc.username = uc.nick
664 }
665 uc.realname = uc.network.Realname
666 if uc.realname == "" {
667 uc.realname = uc.nick
668 }
669
670 uc.SendMessage(&irc.Message{
671 Command: "CAP",
672 Params: []string{"LS", "302"},
673 })
674
675 if uc.network.Pass != "" {
676 uc.SendMessage(&irc.Message{
677 Command: "PASS",
678 Params: []string{uc.network.Pass},
679 })
680 }
681
682 uc.SendMessage(&irc.Message{
683 Command: "NICK",
684 Params: []string{uc.nick},
685 })
686 uc.SendMessage(&irc.Message{
687 Command: "USER",
688 Params: []string{uc.username, "0", "*", uc.realname},
689 })
690}
691
692func (uc *upstreamConn) requestSASL() bool {
693 if uc.network.SASL.Mechanism == "" {
694 return false
695 }
696
697 v, ok := uc.caps["sasl"]
698 if !ok {
699 return false
700 }
701 if v != "" {
702 mechanisms := strings.Split(v, ",")
703 found := false
704 for _, mech := range mechanisms {
705 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
706 found = true
707 break
708 }
709 }
710 if !found {
711 return false
712 }
713 }
714
715 return true
716}
717
718func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
719 auth := &uc.network.SASL
720 switch name {
721 case "sasl":
722 if !ok {
723 uc.logger.Printf("server refused to acknowledge the SASL capability")
724 return nil
725 }
726
727 switch auth.Mechanism {
728 case "PLAIN":
729 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
730 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
731 default:
732 return fmt.Errorf("unsupported SASL mechanism %q", name)
733 }
734
735 uc.SendMessage(&irc.Message{
736 Command: "AUTHENTICATE",
737 Params: []string{auth.Mechanism},
738 })
739 }
740 return nil
741}
742
743func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
744 for {
745 msg, err := uc.irc.ReadMessage()
746 if err == io.EOF {
747 break
748 } else if err != nil {
749 return fmt.Errorf("failed to read IRC command: %v", err)
750 }
751
752 if uc.srv.Debug {
753 uc.logger.Printf("received: %v", msg)
754 }
755
756 ch <- upstreamIncomingMessage{msg, uc}
757 }
758
759 return nil
760}
761
762func (uc *upstreamConn) SendMessage(msg *irc.Message) {
763 uc.outgoing <- msg
764}
Note: See TracBrowser for help on using the repository browser.