source: code/trunk/upstream.go@ 96

Last change on this file since 96 was 96, checked in by contact, 5 years ago

Update dependencies

go-irc v3.1.1 contains a breaking change.

References: https://github.com/go-irc/irc/issues/76

File size: 15.2 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/emersion/go-sasl"
14 "gopkg.in/irc.v3"
15)
16
17type upstreamChannel struct {
18 Name string
19 conn *upstreamConn
20 Topic string
21 TopicWho string
22 TopicTime time.Time
23 Status channelStatus
24 modes modeSet
25 Members map[string]membership
26 complete bool
27}
28
29type upstreamConn struct {
30 network *network
31 logger Logger
32 net net.Conn
33 irc *irc.Conn
34 srv *Server
35 user *user
36 messages chan<- *irc.Message
37 ring *Ring
38
39 serverName string
40 availableUserModes string
41 availableChannelModes string
42 channelModesWithParam string
43
44 registered bool
45 nick string
46 username string
47 realname string
48 closed bool
49 modes modeSet
50 channels map[string]*upstreamChannel
51 history map[string]uint64
52 caps map[string]string
53
54 saslClient sasl.Client
55 saslStarted bool
56}
57
58func connectToUpstream(network *network) (*upstreamConn, error) {
59 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
60
61 addr := network.Addr
62 if !strings.ContainsRune(addr, ':') {
63 addr = addr + ":6697"
64 }
65
66 logger.Printf("connecting to TLS server at address %q", addr)
67 netConn, err := tls.Dial("tcp", addr, nil)
68 if err != nil {
69 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
70 }
71
72 setKeepAlive(netConn)
73
74 msgs := make(chan *irc.Message, 64)
75 uc := &upstreamConn{
76 network: network,
77 logger: logger,
78 net: netConn,
79 irc: irc.NewConn(netConn),
80 srv: network.user.srv,
81 user: network.user,
82 messages: msgs,
83 ring: NewRing(network.user.srv.RingCap),
84 channels: make(map[string]*upstreamChannel),
85 history: make(map[string]uint64),
86 caps: make(map[string]string),
87 }
88
89 go func() {
90 for msg := range msgs {
91 if uc.srv.Debug {
92 uc.logger.Printf("sent: %v", msg)
93 }
94 if err := uc.irc.WriteMessage(msg); err != nil {
95 uc.logger.Printf("failed to write message: %v", err)
96 }
97 }
98 if err := uc.net.Close(); err != nil {
99 uc.logger.Printf("failed to close connection: %v", err)
100 } else {
101 uc.logger.Printf("connection closed")
102 }
103 }()
104
105 return uc, nil
106}
107
108func (uc *upstreamConn) Close() error {
109 if uc.closed {
110 return fmt.Errorf("upstream connection already closed")
111 }
112 close(uc.messages)
113 uc.closed = true
114 return nil
115}
116
117func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
118 uc.user.forEachDownstream(func(dc *downstreamConn) {
119 if dc.network != nil && dc.network != uc.network {
120 return
121 }
122 f(dc)
123 })
124}
125
126func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
127 ch, ok := uc.channels[name]
128 if !ok {
129 return nil, fmt.Errorf("unknown channel %q", name)
130 }
131 return ch, nil
132}
133
134func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
135 switch msg.Command {
136 case "PING":
137 uc.SendMessage(&irc.Message{
138 Command: "PONG",
139 Params: msg.Params,
140 })
141 return nil
142 case "MODE":
143 if msg.Prefix == nil {
144 return fmt.Errorf("missing prefix")
145 }
146
147 var name, modeStr string
148 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
149 return err
150 }
151
152 if name == msg.Prefix.Name { // user mode change
153 if name != uc.nick {
154 return fmt.Errorf("received MODE message for unknow nick %q", name)
155 }
156 return uc.modes.Apply(modeStr)
157 } else { // channel mode change
158 ch, err := uc.getChannel(name)
159 if err != nil {
160 return err
161 }
162 if err := ch.modes.Apply(modeStr); err != nil {
163 return err
164 }
165
166 uc.forEachDownstream(func(dc *downstreamConn) {
167 dc.SendMessage(&irc.Message{
168 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
169 Command: "MODE",
170 Params: []string{dc.marshalChannel(uc, name), modeStr},
171 })
172 })
173 }
174 case "NOTICE":
175 uc.logger.Print(msg)
176 case "CAP":
177 var subCmd string
178 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
179 return err
180 }
181 subCmd = strings.ToUpper(subCmd)
182 subParams := msg.Params[2:]
183 switch subCmd {
184 case "LS":
185 if len(subParams) < 1 {
186 return newNeedMoreParamsError(msg.Command)
187 }
188 caps := strings.Fields(subParams[len(subParams)-1])
189 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
190
191 for _, s := range caps {
192 kv := strings.SplitN(s, "=", 2)
193 k := strings.ToLower(kv[0])
194 var v string
195 if len(kv) == 2 {
196 v = kv[1]
197 }
198 uc.caps[k] = v
199 }
200
201 if more {
202 break // wait to receive all capabilities
203 }
204
205 if uc.requestSASL() {
206 uc.SendMessage(&irc.Message{
207 Command: "CAP",
208 Params: []string{"REQ", "sasl"},
209 })
210 break // we'll send CAP END after authentication is completed
211 }
212
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"END"},
216 })
217 case "ACK", "NAK":
218 if len(subParams) < 1 {
219 return newNeedMoreParamsError(msg.Command)
220 }
221 caps := strings.Fields(subParams[0])
222
223 for _, name := range caps {
224 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
225 return err
226 }
227 }
228
229 if uc.saslClient == nil {
230 uc.SendMessage(&irc.Message{
231 Command: "CAP",
232 Params: []string{"END"},
233 })
234 }
235 default:
236 uc.logger.Printf("unhandled message: %v", msg)
237 }
238 case "AUTHENTICATE":
239 if uc.saslClient == nil {
240 return fmt.Errorf("received unexpected AUTHENTICATE message")
241 }
242
243 // TODO: if a challenge is 400 bytes long, buffer it
244 var challengeStr string
245 if err := parseMessageParams(msg, &challengeStr); err != nil {
246 uc.SendMessage(&irc.Message{
247 Command: "AUTHENTICATE",
248 Params: []string{"*"},
249 })
250 return err
251 }
252
253 var challenge []byte
254 if challengeStr != "+" {
255 var err error
256 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
257 if err != nil {
258 uc.SendMessage(&irc.Message{
259 Command: "AUTHENTICATE",
260 Params: []string{"*"},
261 })
262 return err
263 }
264 }
265
266 var resp []byte
267 var err error
268 if !uc.saslStarted {
269 _, resp, err = uc.saslClient.Start()
270 uc.saslStarted = true
271 } else {
272 resp, err = uc.saslClient.Next(challenge)
273 }
274 if err != nil {
275 uc.SendMessage(&irc.Message{
276 Command: "AUTHENTICATE",
277 Params: []string{"*"},
278 })
279 return err
280 }
281
282 // TODO: send response in multiple chunks if >= 400 bytes
283 var respStr = "+"
284 if resp != nil {
285 respStr = base64.StdEncoding.EncodeToString(resp)
286 }
287
288 uc.SendMessage(&irc.Message{
289 Command: "AUTHENTICATE",
290 Params: []string{respStr},
291 })
292 case rpl_loggedin:
293 var account string
294 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
295 return err
296 }
297 uc.logger.Printf("logged in with account %q", account)
298 case rpl_loggedout:
299 uc.logger.Printf("logged out")
300 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
301 var info string
302 if err := parseMessageParams(msg, nil, &info); err != nil {
303 return err
304 }
305 switch msg.Command {
306 case err_nicklocked:
307 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
308 case err_saslfail:
309 uc.logger.Printf("SASL authentication failed: %v", info)
310 case err_sasltoolong:
311 uc.logger.Printf("SASL message too long: %v", info)
312 }
313
314 uc.saslClient = nil
315 uc.saslStarted = false
316
317 uc.SendMessage(&irc.Message{
318 Command: "CAP",
319 Params: []string{"END"},
320 })
321 case irc.RPL_WELCOME:
322 uc.registered = true
323 uc.logger.Printf("connection registered")
324
325 channels, err := uc.srv.db.ListChannels(uc.network.ID)
326 if err != nil {
327 uc.logger.Printf("failed to list channels from database: %v", err)
328 break
329 }
330
331 for _, ch := range channels {
332 uc.SendMessage(&irc.Message{
333 Command: "JOIN",
334 Params: []string{ch.Name},
335 })
336 }
337 case irc.RPL_MYINFO:
338 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
339 return err
340 }
341 if len(msg.Params) > 5 {
342 uc.channelModesWithParam = msg.Params[5]
343 }
344 case "NICK":
345 if msg.Prefix == nil {
346 return fmt.Errorf("expected a prefix")
347 }
348
349 var newNick string
350 if err := parseMessageParams(msg, &newNick); err != nil {
351 return err
352 }
353
354 if msg.Prefix.Name == uc.nick {
355 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
356 uc.nick = newNick
357 }
358
359 for _, ch := range uc.channels {
360 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
361 delete(ch.Members, msg.Prefix.Name)
362 ch.Members[newNick] = membership
363 }
364 }
365
366 if msg.Prefix.Name != uc.nick {
367 uc.forEachDownstream(func(dc *downstreamConn) {
368 dc.SendMessage(&irc.Message{
369 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
370 Command: "NICK",
371 Params: []string{newNick},
372 })
373 })
374 }
375 case "JOIN":
376 if msg.Prefix == nil {
377 return fmt.Errorf("expected a prefix")
378 }
379
380 var channels string
381 if err := parseMessageParams(msg, &channels); err != nil {
382 return err
383 }
384
385 for _, ch := range strings.Split(channels, ",") {
386 if msg.Prefix.Name == uc.nick {
387 uc.logger.Printf("joined channel %q", ch)
388 uc.channels[ch] = &upstreamChannel{
389 Name: ch,
390 conn: uc,
391 Members: make(map[string]membership),
392 }
393 } else {
394 ch, err := uc.getChannel(ch)
395 if err != nil {
396 return err
397 }
398 ch.Members[msg.Prefix.Name] = 0
399 }
400
401 uc.forEachDownstream(func(dc *downstreamConn) {
402 dc.SendMessage(&irc.Message{
403 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
404 Command: "JOIN",
405 Params: []string{dc.marshalChannel(uc, ch)},
406 })
407 })
408 }
409 case "PART":
410 if msg.Prefix == nil {
411 return fmt.Errorf("expected a prefix")
412 }
413
414 var channels string
415 if err := parseMessageParams(msg, &channels); err != nil {
416 return err
417 }
418
419 for _, ch := range strings.Split(channels, ",") {
420 if msg.Prefix.Name == uc.nick {
421 uc.logger.Printf("parted channel %q", ch)
422 delete(uc.channels, ch)
423 } else {
424 ch, err := uc.getChannel(ch)
425 if err != nil {
426 return err
427 }
428 delete(ch.Members, msg.Prefix.Name)
429 }
430
431 uc.forEachDownstream(func(dc *downstreamConn) {
432 dc.SendMessage(&irc.Message{
433 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
434 Command: "PART",
435 Params: []string{dc.marshalChannel(uc, ch)},
436 })
437 })
438 }
439 case "QUIT":
440 if msg.Prefix == nil {
441 return fmt.Errorf("expected a prefix")
442 }
443
444 if msg.Prefix.Name == uc.nick {
445 uc.logger.Printf("quit")
446 }
447
448 for _, ch := range uc.channels {
449 delete(ch.Members, msg.Prefix.Name)
450 }
451
452 if msg.Prefix.Name != uc.nick {
453 uc.forEachDownstream(func(dc *downstreamConn) {
454 dc.SendMessage(&irc.Message{
455 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
456 Command: "QUIT",
457 Params: msg.Params,
458 })
459 })
460 }
461 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
462 var name, topic string
463 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
464 return err
465 }
466 ch, err := uc.getChannel(name)
467 if err != nil {
468 return err
469 }
470 if msg.Command == irc.RPL_TOPIC {
471 ch.Topic = topic
472 } else {
473 ch.Topic = ""
474 }
475 case "TOPIC":
476 var name string
477 if err := parseMessageParams(msg, &name); err != nil {
478 return err
479 }
480 ch, err := uc.getChannel(name)
481 if err != nil {
482 return err
483 }
484 if len(msg.Params) > 1 {
485 ch.Topic = msg.Params[1]
486 } else {
487 ch.Topic = ""
488 }
489 uc.forEachDownstream(func(dc *downstreamConn) {
490 params := []string{dc.marshalChannel(uc, name)}
491 if ch.Topic != "" {
492 params = append(params, ch.Topic)
493 }
494 dc.SendMessage(&irc.Message{
495 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
496 Command: "TOPIC",
497 Params: params,
498 })
499 })
500 case rpl_topicwhotime:
501 var name, who, timeStr string
502 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
503 return err
504 }
505 ch, err := uc.getChannel(name)
506 if err != nil {
507 return err
508 }
509 ch.TopicWho = who
510 sec, err := strconv.ParseInt(timeStr, 10, 64)
511 if err != nil {
512 return fmt.Errorf("failed to parse topic time: %v", err)
513 }
514 ch.TopicTime = time.Unix(sec, 0)
515 case irc.RPL_NAMREPLY:
516 var name, statusStr, members string
517 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
518 return err
519 }
520 ch, err := uc.getChannel(name)
521 if err != nil {
522 return err
523 }
524
525 status, err := parseChannelStatus(statusStr)
526 if err != nil {
527 return err
528 }
529 ch.Status = status
530
531 for _, s := range strings.Split(members, " ") {
532 membership, nick := parseMembershipPrefix(s)
533 ch.Members[nick] = membership
534 }
535 case irc.RPL_ENDOFNAMES:
536 var name string
537 if err := parseMessageParams(msg, nil, &name); err != nil {
538 return err
539 }
540 ch, err := uc.getChannel(name)
541 if err != nil {
542 return err
543 }
544
545 if ch.complete {
546 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
547 }
548 ch.complete = true
549
550 uc.forEachDownstream(func(dc *downstreamConn) {
551 forwardChannel(dc, ch)
552 })
553 case "PRIVMSG":
554 if err := parseMessageParams(msg, nil, nil); err != nil {
555 return err
556 }
557 uc.ring.Produce(msg)
558 case irc.RPL_YOURHOST, irc.RPL_CREATED:
559 // Ignore
560 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
561 // Ignore
562 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
563 // Ignore
564 case rpl_localusers, rpl_globalusers:
565 // Ignore
566 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
567 // Ignore
568 default:
569 uc.logger.Printf("unhandled message: %v", msg)
570 }
571 return nil
572}
573
574func (uc *upstreamConn) register() {
575 uc.nick = uc.network.Nick
576 uc.username = uc.network.Username
577 if uc.username == "" {
578 uc.username = uc.nick
579 }
580 uc.realname = uc.network.Realname
581 if uc.realname == "" {
582 uc.realname = uc.nick
583 }
584
585 uc.SendMessage(&irc.Message{
586 Command: "CAP",
587 Params: []string{"LS", "302"},
588 })
589
590 if uc.network.Pass != "" {
591 uc.SendMessage(&irc.Message{
592 Command: "PASS",
593 Params: []string{uc.network.Pass},
594 })
595 }
596
597 uc.SendMessage(&irc.Message{
598 Command: "NICK",
599 Params: []string{uc.nick},
600 })
601 uc.SendMessage(&irc.Message{
602 Command: "USER",
603 Params: []string{uc.username, "0", "*", uc.realname},
604 })
605}
606
607func (uc *upstreamConn) requestSASL() bool {
608 if uc.network.SASL.Mechanism == "" {
609 return false
610 }
611
612 v, ok := uc.caps["sasl"]
613 if !ok {
614 return false
615 }
616 if v != "" {
617 mechanisms := strings.Split(v, ",")
618 found := false
619 for _, mech := range mechanisms {
620 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
621 found = true
622 break
623 }
624 }
625 if !found {
626 return false
627 }
628 }
629
630 return true
631}
632
633func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
634 auth := &uc.network.SASL
635 switch name {
636 case "sasl":
637 if !ok {
638 uc.logger.Printf("server refused to acknowledge the SASL capability")
639 return nil
640 }
641
642 switch auth.Mechanism {
643 case "PLAIN":
644 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
645 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
646 default:
647 return fmt.Errorf("unsupported SASL mechanism %q", name)
648 }
649
650 uc.SendMessage(&irc.Message{
651 Command: "AUTHENTICATE",
652 Params: []string{auth.Mechanism},
653 })
654 }
655 return nil
656}
657
658func (uc *upstreamConn) readMessages() error {
659 for {
660 msg, err := uc.irc.ReadMessage()
661 if err == io.EOF {
662 break
663 } else if err != nil {
664 return fmt.Errorf("failed to read IRC command: %v", err)
665 }
666
667 if uc.srv.Debug {
668 uc.logger.Printf("received: %v", msg)
669 }
670
671 if err := uc.handleMessage(msg); err != nil {
672 uc.logger.Printf("failed to handle message %q: %v", msg, err)
673 }
674 }
675
676 return nil
677}
678
679func (uc *upstreamConn) SendMessage(msg *irc.Message) {
680 uc.messages <- msg
681}
Note: See TracBrowser for help on using the repository browser.