source: code/trunk/upstream.go@116

Last change on this file since 116 was 115, checked in by delthas, 5 years ago

Add upstream INVITE support

File size: 15.6 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
18type upstreamChannel struct {
19 Name string
20 conn *upstreamConn
21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
25 modes modeSet
26 Members map[string]membership
27 complete bool
28}
29
30type upstreamConn struct {
31 network *network
32 logger Logger
33 net net.Conn
34 irc *irc.Conn
35 srv *Server
36 user *user
37 outgoing chan<- *irc.Message
38 ring *Ring
39
40 serverName string
41 availableUserModes string
42 availableChannelModes string
43 channelModesWithParam string
44
45 registered bool
46 nick string
47 username string
48 realname string
49 closed bool
50 modes modeSet
51 channels map[string]*upstreamChannel
52 caps map[string]string
53
54 saslClient sasl.Client
55 saslStarted bool
56
57 lock sync.Mutex
58 history map[string]uint64 // TODO: move to network
59}
60
61func connectToUpstream(network *network) (*upstreamConn, error) {
62 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
63
64 addr := network.Addr
65 if !strings.ContainsRune(addr, ':') {
66 addr = addr + ":6697"
67 }
68
69 logger.Printf("connecting to TLS server at address %q", addr)
70 netConn, err := tls.Dial("tcp", addr, nil)
71 if err != nil {
72 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
73 }
74
75 setKeepAlive(netConn)
76
77 outgoing := make(chan *irc.Message, 64)
78 uc := &upstreamConn{
79 network: network,
80 logger: logger,
81 net: netConn,
82 irc: irc.NewConn(netConn),
83 srv: network.user.srv,
84 user: network.user,
85 outgoing: outgoing,
86 ring: NewRing(network.user.srv.RingCap),
87 channels: make(map[string]*upstreamChannel),
88 history: make(map[string]uint64),
89 caps: make(map[string]string),
90 }
91
92 go func() {
93 for msg := range outgoing {
94 if uc.srv.Debug {
95 uc.logger.Printf("sent: %v", msg)
96 }
97 if err := uc.irc.WriteMessage(msg); err != nil {
98 uc.logger.Printf("failed to write message: %v", err)
99 }
100 }
101 if err := uc.net.Close(); err != nil {
102 uc.logger.Printf("failed to close connection: %v", err)
103 } else {
104 uc.logger.Printf("connection closed")
105 }
106 }()
107
108 return uc, nil
109}
110
111func (uc *upstreamConn) Close() error {
112 if uc.closed {
113 return fmt.Errorf("upstream connection already closed")
114 }
115 close(uc.outgoing)
116 uc.closed = true
117 return nil
118}
119
120func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
121 uc.user.forEachDownstream(func(dc *downstreamConn) {
122 if dc.network != nil && dc.network != uc.network {
123 return
124 }
125 f(dc)
126 })
127}
128
129func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
130 ch, ok := uc.channels[name]
131 if !ok {
132 return nil, fmt.Errorf("unknown channel %q", name)
133 }
134 return ch, nil
135}
136
137func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
138 switch msg.Command {
139 case "PING":
140 uc.SendMessage(&irc.Message{
141 Command: "PONG",
142 Params: msg.Params,
143 })
144 return nil
145 case "MODE":
146 if msg.Prefix == nil {
147 return fmt.Errorf("missing prefix")
148 }
149
150 var name, modeStr string
151 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
152 return err
153 }
154
155 if name == msg.Prefix.Name { // user mode change
156 if name != uc.nick {
157 return fmt.Errorf("received MODE message for unknow nick %q", name)
158 }
159 return uc.modes.Apply(modeStr)
160 } else { // channel mode change
161 ch, err := uc.getChannel(name)
162 if err != nil {
163 return err
164 }
165 if err := ch.modes.Apply(modeStr); err != nil {
166 return err
167 }
168
169 uc.forEachDownstream(func(dc *downstreamConn) {
170 dc.SendMessage(&irc.Message{
171 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
172 Command: "MODE",
173 Params: []string{dc.marshalChannel(uc, name), modeStr},
174 })
175 })
176 }
177 case "NOTICE":
178 uc.logger.Print(msg)
179
180 uc.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(msg)
182 })
183 case "CAP":
184 var subCmd string
185 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
186 return err
187 }
188 subCmd = strings.ToUpper(subCmd)
189 subParams := msg.Params[2:]
190 switch subCmd {
191 case "LS":
192 if len(subParams) < 1 {
193 return newNeedMoreParamsError(msg.Command)
194 }
195 caps := strings.Fields(subParams[len(subParams)-1])
196 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
197
198 for _, s := range caps {
199 kv := strings.SplitN(s, "=", 2)
200 k := strings.ToLower(kv[0])
201 var v string
202 if len(kv) == 2 {
203 v = kv[1]
204 }
205 uc.caps[k] = v
206 }
207
208 if more {
209 break // wait to receive all capabilities
210 }
211
212 if uc.requestSASL() {
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"REQ", "sasl"},
216 })
217 break // we'll send CAP END after authentication is completed
218 }
219
220 uc.SendMessage(&irc.Message{
221 Command: "CAP",
222 Params: []string{"END"},
223 })
224 case "ACK", "NAK":
225 if len(subParams) < 1 {
226 return newNeedMoreParamsError(msg.Command)
227 }
228 caps := strings.Fields(subParams[0])
229
230 for _, name := range caps {
231 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
232 return err
233 }
234 }
235
236 if uc.saslClient == nil {
237 uc.SendMessage(&irc.Message{
238 Command: "CAP",
239 Params: []string{"END"},
240 })
241 }
242 default:
243 uc.logger.Printf("unhandled message: %v", msg)
244 }
245 case "AUTHENTICATE":
246 if uc.saslClient == nil {
247 return fmt.Errorf("received unexpected AUTHENTICATE message")
248 }
249
250 // TODO: if a challenge is 400 bytes long, buffer it
251 var challengeStr string
252 if err := parseMessageParams(msg, &challengeStr); err != nil {
253 uc.SendMessage(&irc.Message{
254 Command: "AUTHENTICATE",
255 Params: []string{"*"},
256 })
257 return err
258 }
259
260 var challenge []byte
261 if challengeStr != "+" {
262 var err error
263 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
264 if err != nil {
265 uc.SendMessage(&irc.Message{
266 Command: "AUTHENTICATE",
267 Params: []string{"*"},
268 })
269 return err
270 }
271 }
272
273 var resp []byte
274 var err error
275 if !uc.saslStarted {
276 _, resp, err = uc.saslClient.Start()
277 uc.saslStarted = true
278 } else {
279 resp, err = uc.saslClient.Next(challenge)
280 }
281 if err != nil {
282 uc.SendMessage(&irc.Message{
283 Command: "AUTHENTICATE",
284 Params: []string{"*"},
285 })
286 return err
287 }
288
289 // TODO: send response in multiple chunks if >= 400 bytes
290 var respStr = "+"
291 if resp != nil {
292 respStr = base64.StdEncoding.EncodeToString(resp)
293 }
294
295 uc.SendMessage(&irc.Message{
296 Command: "AUTHENTICATE",
297 Params: []string{respStr},
298 })
299 case rpl_loggedin:
300 var account string
301 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
302 return err
303 }
304 uc.logger.Printf("logged in with account %q", account)
305 case rpl_loggedout:
306 uc.logger.Printf("logged out")
307 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
308 var info string
309 if err := parseMessageParams(msg, nil, &info); err != nil {
310 return err
311 }
312 switch msg.Command {
313 case err_nicklocked:
314 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
315 case err_saslfail:
316 uc.logger.Printf("SASL authentication failed: %v", info)
317 case err_sasltoolong:
318 uc.logger.Printf("SASL message too long: %v", info)
319 }
320
321 uc.saslClient = nil
322 uc.saslStarted = false
323
324 uc.SendMessage(&irc.Message{
325 Command: "CAP",
326 Params: []string{"END"},
327 })
328 case irc.RPL_WELCOME:
329 uc.registered = true
330 uc.logger.Printf("connection registered")
331
332 channels, err := uc.srv.db.ListChannels(uc.network.ID)
333 if err != nil {
334 uc.logger.Printf("failed to list channels from database: %v", err)
335 break
336 }
337
338 for _, ch := range channels {
339 uc.SendMessage(&irc.Message{
340 Command: "JOIN",
341 Params: []string{ch.Name},
342 })
343 }
344 case irc.RPL_MYINFO:
345 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
346 return err
347 }
348 if len(msg.Params) > 5 {
349 uc.channelModesWithParam = msg.Params[5]
350 }
351 case "NICK":
352 if msg.Prefix == nil {
353 return fmt.Errorf("expected a prefix")
354 }
355
356 var newNick string
357 if err := parseMessageParams(msg, &newNick); err != nil {
358 return err
359 }
360
361 if msg.Prefix.Name == uc.nick {
362 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
363 uc.nick = newNick
364 }
365
366 for _, ch := range uc.channels {
367 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
368 delete(ch.Members, msg.Prefix.Name)
369 ch.Members[newNick] = membership
370 }
371 }
372
373 if msg.Prefix.Name != uc.nick {
374 uc.forEachDownstream(func(dc *downstreamConn) {
375 dc.SendMessage(&irc.Message{
376 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
377 Command: "NICK",
378 Params: []string{newNick},
379 })
380 })
381 }
382 case "JOIN":
383 if msg.Prefix == nil {
384 return fmt.Errorf("expected a prefix")
385 }
386
387 var channels string
388 if err := parseMessageParams(msg, &channels); err != nil {
389 return err
390 }
391
392 for _, ch := range strings.Split(channels, ",") {
393 if msg.Prefix.Name == uc.nick {
394 uc.logger.Printf("joined channel %q", ch)
395 uc.channels[ch] = &upstreamChannel{
396 Name: ch,
397 conn: uc,
398 Members: make(map[string]membership),
399 }
400 } else {
401 ch, err := uc.getChannel(ch)
402 if err != nil {
403 return err
404 }
405 ch.Members[msg.Prefix.Name] = 0
406 }
407
408 uc.forEachDownstream(func(dc *downstreamConn) {
409 dc.SendMessage(&irc.Message{
410 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
411 Command: "JOIN",
412 Params: []string{dc.marshalChannel(uc, ch)},
413 })
414 })
415 }
416 case "PART":
417 if msg.Prefix == nil {
418 return fmt.Errorf("expected a prefix")
419 }
420
421 var channels string
422 if err := parseMessageParams(msg, &channels); err != nil {
423 return err
424 }
425
426 for _, ch := range strings.Split(channels, ",") {
427 if msg.Prefix.Name == uc.nick {
428 uc.logger.Printf("parted channel %q", ch)
429 delete(uc.channels, ch)
430 } else {
431 ch, err := uc.getChannel(ch)
432 if err != nil {
433 return err
434 }
435 delete(ch.Members, msg.Prefix.Name)
436 }
437
438 uc.forEachDownstream(func(dc *downstreamConn) {
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
441 Command: "PART",
442 Params: []string{dc.marshalChannel(uc, ch)},
443 })
444 })
445 }
446 case "QUIT":
447 if msg.Prefix == nil {
448 return fmt.Errorf("expected a prefix")
449 }
450
451 if msg.Prefix.Name == uc.nick {
452 uc.logger.Printf("quit")
453 }
454
455 for _, ch := range uc.channels {
456 delete(ch.Members, msg.Prefix.Name)
457 }
458
459 if msg.Prefix.Name != uc.nick {
460 uc.forEachDownstream(func(dc *downstreamConn) {
461 dc.SendMessage(&irc.Message{
462 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
463 Command: "QUIT",
464 Params: msg.Params,
465 })
466 })
467 }
468 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
469 var name, topic string
470 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
471 return err
472 }
473 ch, err := uc.getChannel(name)
474 if err != nil {
475 return err
476 }
477 if msg.Command == irc.RPL_TOPIC {
478 ch.Topic = topic
479 } else {
480 ch.Topic = ""
481 }
482 case "TOPIC":
483 var name string
484 if err := parseMessageParams(msg, &name); err != nil {
485 return err
486 }
487 ch, err := uc.getChannel(name)
488 if err != nil {
489 return err
490 }
491 if len(msg.Params) > 1 {
492 ch.Topic = msg.Params[1]
493 } else {
494 ch.Topic = ""
495 }
496 uc.forEachDownstream(func(dc *downstreamConn) {
497 params := []string{dc.marshalChannel(uc, name)}
498 if ch.Topic != "" {
499 params = append(params, ch.Topic)
500 }
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "TOPIC",
504 Params: params,
505 })
506 })
507 case rpl_topicwhotime:
508 var name, who, timeStr string
509 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
510 return err
511 }
512 ch, err := uc.getChannel(name)
513 if err != nil {
514 return err
515 }
516 ch.TopicWho = who
517 sec, err := strconv.ParseInt(timeStr, 10, 64)
518 if err != nil {
519 return fmt.Errorf("failed to parse topic time: %v", err)
520 }
521 ch.TopicTime = time.Unix(sec, 0)
522 case irc.RPL_NAMREPLY:
523 var name, statusStr, members string
524 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
525 return err
526 }
527 ch, err := uc.getChannel(name)
528 if err != nil {
529 return err
530 }
531
532 status, err := parseChannelStatus(statusStr)
533 if err != nil {
534 return err
535 }
536 ch.Status = status
537
538 for _, s := range strings.Split(members, " ") {
539 membership, nick := parseMembershipPrefix(s)
540 ch.Members[nick] = membership
541 }
542 case irc.RPL_ENDOFNAMES:
543 var name string
544 if err := parseMessageParams(msg, nil, &name); err != nil {
545 return err
546 }
547 ch, err := uc.getChannel(name)
548 if err != nil {
549 return err
550 }
551
552 if ch.complete {
553 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
554 }
555 ch.complete = true
556
557 uc.forEachDownstream(func(dc *downstreamConn) {
558 forwardChannel(dc, ch)
559 })
560 case "PRIVMSG":
561 if err := parseMessageParams(msg, nil, nil); err != nil {
562 return err
563 }
564 uc.ring.Produce(msg)
565 case "INVITE":
566 var nick string
567 var channel string
568 if err := parseMessageParams(msg, &nick, &channel); err != nil {
569 return err
570 }
571
572 uc.forEachDownstream(func(dc *downstreamConn) {
573 dc.SendMessage(&irc.Message{
574 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
575 Command: "INVITE",
576 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
577 })
578 })
579 case irc.RPL_YOURHOST, irc.RPL_CREATED:
580 // Ignore
581 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
582 // Ignore
583 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
584 // Ignore
585 case rpl_localusers, rpl_globalusers:
586 // Ignore
587 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
588 // Ignore
589 default:
590 uc.logger.Printf("unhandled message: %v", msg)
591 }
592 return nil
593}
594
595func (uc *upstreamConn) register() {
596 uc.nick = uc.network.Nick
597 uc.username = uc.network.Username
598 if uc.username == "" {
599 uc.username = uc.nick
600 }
601 uc.realname = uc.network.Realname
602 if uc.realname == "" {
603 uc.realname = uc.nick
604 }
605
606 uc.SendMessage(&irc.Message{
607 Command: "CAP",
608 Params: []string{"LS", "302"},
609 })
610
611 if uc.network.Pass != "" {
612 uc.SendMessage(&irc.Message{
613 Command: "PASS",
614 Params: []string{uc.network.Pass},
615 })
616 }
617
618 uc.SendMessage(&irc.Message{
619 Command: "NICK",
620 Params: []string{uc.nick},
621 })
622 uc.SendMessage(&irc.Message{
623 Command: "USER",
624 Params: []string{uc.username, "0", "*", uc.realname},
625 })
626}
627
628func (uc *upstreamConn) requestSASL() bool {
629 if uc.network.SASL.Mechanism == "" {
630 return false
631 }
632
633 v, ok := uc.caps["sasl"]
634 if !ok {
635 return false
636 }
637 if v != "" {
638 mechanisms := strings.Split(v, ",")
639 found := false
640 for _, mech := range mechanisms {
641 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
642 found = true
643 break
644 }
645 }
646 if !found {
647 return false
648 }
649 }
650
651 return true
652}
653
654func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
655 auth := &uc.network.SASL
656 switch name {
657 case "sasl":
658 if !ok {
659 uc.logger.Printf("server refused to acknowledge the SASL capability")
660 return nil
661 }
662
663 switch auth.Mechanism {
664 case "PLAIN":
665 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
666 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
667 default:
668 return fmt.Errorf("unsupported SASL mechanism %q", name)
669 }
670
671 uc.SendMessage(&irc.Message{
672 Command: "AUTHENTICATE",
673 Params: []string{auth.Mechanism},
674 })
675 }
676 return nil
677}
678
679func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
680 for {
681 msg, err := uc.irc.ReadMessage()
682 if err == io.EOF {
683 break
684 } else if err != nil {
685 return fmt.Errorf("failed to read IRC command: %v", err)
686 }
687
688 if uc.srv.Debug {
689 uc.logger.Printf("received: %v", msg)
690 }
691
692 ch <- upstreamIncomingMessage{msg, uc}
693 }
694
695 return nil
696}
697
698func (uc *upstreamConn) SendMessage(msg *irc.Message) {
699 uc.outgoing <- msg
700}
Note: See TracBrowser for help on using the repository browser.