source: code/trunk/upstream.go@ 121

Last change on this file since 121 was 117, checked in by contact, 5 years ago

Add basic infrastructure for bouncer service

File size: 16.0 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
18type upstreamChannel struct {
19 Name string
20 conn *upstreamConn
21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
25 modes modeSet
26 Members map[string]membership
27 complete bool
28}
29
30type upstreamConn struct {
31 network *network
32 logger Logger
33 net net.Conn
34 irc *irc.Conn
35 srv *Server
36 user *user
37 outgoing chan<- *irc.Message
38 ring *Ring
39
40 serverName string
41 availableUserModes string
42 availableChannelModes string
43 channelModesWithParam string
44
45 registered bool
46 nick string
47 username string
48 realname string
49 closed bool
50 modes modeSet
51 channels map[string]*upstreamChannel
52 caps map[string]string
53
54 saslClient sasl.Client
55 saslStarted bool
56
57 lock sync.Mutex
58 history map[string]uint64 // TODO: move to network
59}
60
61func connectToUpstream(network *network) (*upstreamConn, error) {
62 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
63
64 addr := network.Addr
65 if !strings.ContainsRune(addr, ':') {
66 addr = addr + ":6697"
67 }
68
69 logger.Printf("connecting to TLS server at address %q", addr)
70 netConn, err := tls.Dial("tcp", addr, nil)
71 if err != nil {
72 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
73 }
74
75 setKeepAlive(netConn)
76
77 outgoing := make(chan *irc.Message, 64)
78 uc := &upstreamConn{
79 network: network,
80 logger: logger,
81 net: netConn,
82 irc: irc.NewConn(netConn),
83 srv: network.user.srv,
84 user: network.user,
85 outgoing: outgoing,
86 ring: NewRing(network.user.srv.RingCap),
87 channels: make(map[string]*upstreamChannel),
88 history: make(map[string]uint64),
89 caps: make(map[string]string),
90 }
91
92 go func() {
93 for msg := range outgoing {
94 if uc.srv.Debug {
95 uc.logger.Printf("sent: %v", msg)
96 }
97 if err := uc.irc.WriteMessage(msg); err != nil {
98 uc.logger.Printf("failed to write message: %v", err)
99 }
100 }
101 if err := uc.net.Close(); err != nil {
102 uc.logger.Printf("failed to close connection: %v", err)
103 } else {
104 uc.logger.Printf("connection closed")
105 }
106 }()
107
108 return uc, nil
109}
110
111func (uc *upstreamConn) Close() error {
112 if uc.closed {
113 return fmt.Errorf("upstream connection already closed")
114 }
115 close(uc.outgoing)
116 uc.closed = true
117 return nil
118}
119
120func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
121 uc.user.forEachDownstream(func(dc *downstreamConn) {
122 if dc.network != nil && dc.network != uc.network {
123 return
124 }
125 f(dc)
126 })
127}
128
129func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
130 ch, ok := uc.channels[name]
131 if !ok {
132 return nil, fmt.Errorf("unknown channel %q", name)
133 }
134 return ch, nil
135}
136
137func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
138 switch msg.Command {
139 case "PING":
140 uc.SendMessage(&irc.Message{
141 Command: "PONG",
142 Params: msg.Params,
143 })
144 return nil
145 case "MODE":
146 if msg.Prefix == nil {
147 return fmt.Errorf("missing prefix")
148 }
149
150 var name, modeStr string
151 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
152 return err
153 }
154
155 if name == msg.Prefix.Name { // user mode change
156 if name != uc.nick {
157 return fmt.Errorf("received MODE message for unknow nick %q", name)
158 }
159 return uc.modes.Apply(modeStr)
160 } else { // channel mode change
161 ch, err := uc.getChannel(name)
162 if err != nil {
163 return err
164 }
165 if err := ch.modes.Apply(modeStr); err != nil {
166 return err
167 }
168
169 uc.forEachDownstream(func(dc *downstreamConn) {
170 dc.SendMessage(&irc.Message{
171 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
172 Command: "MODE",
173 Params: []string{dc.marshalChannel(uc, name), modeStr},
174 })
175 })
176 }
177 case "NOTICE":
178 uc.logger.Print(msg)
179
180 uc.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(msg)
182 })
183 case "CAP":
184 var subCmd string
185 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
186 return err
187 }
188 subCmd = strings.ToUpper(subCmd)
189 subParams := msg.Params[2:]
190 switch subCmd {
191 case "LS":
192 if len(subParams) < 1 {
193 return newNeedMoreParamsError(msg.Command)
194 }
195 caps := strings.Fields(subParams[len(subParams)-1])
196 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
197
198 for _, s := range caps {
199 kv := strings.SplitN(s, "=", 2)
200 k := strings.ToLower(kv[0])
201 var v string
202 if len(kv) == 2 {
203 v = kv[1]
204 }
205 uc.caps[k] = v
206 }
207
208 if more {
209 break // wait to receive all capabilities
210 }
211
212 if uc.requestSASL() {
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"REQ", "sasl"},
216 })
217 break // we'll send CAP END after authentication is completed
218 }
219
220 uc.SendMessage(&irc.Message{
221 Command: "CAP",
222 Params: []string{"END"},
223 })
224 case "ACK", "NAK":
225 if len(subParams) < 1 {
226 return newNeedMoreParamsError(msg.Command)
227 }
228 caps := strings.Fields(subParams[0])
229
230 for _, name := range caps {
231 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
232 return err
233 }
234 }
235
236 if uc.saslClient == nil {
237 uc.SendMessage(&irc.Message{
238 Command: "CAP",
239 Params: []string{"END"},
240 })
241 }
242 default:
243 uc.logger.Printf("unhandled message: %v", msg)
244 }
245 case "AUTHENTICATE":
246 if uc.saslClient == nil {
247 return fmt.Errorf("received unexpected AUTHENTICATE message")
248 }
249
250 // TODO: if a challenge is 400 bytes long, buffer it
251 var challengeStr string
252 if err := parseMessageParams(msg, &challengeStr); err != nil {
253 uc.SendMessage(&irc.Message{
254 Command: "AUTHENTICATE",
255 Params: []string{"*"},
256 })
257 return err
258 }
259
260 var challenge []byte
261 if challengeStr != "+" {
262 var err error
263 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
264 if err != nil {
265 uc.SendMessage(&irc.Message{
266 Command: "AUTHENTICATE",
267 Params: []string{"*"},
268 })
269 return err
270 }
271 }
272
273 var resp []byte
274 var err error
275 if !uc.saslStarted {
276 _, resp, err = uc.saslClient.Start()
277 uc.saslStarted = true
278 } else {
279 resp, err = uc.saslClient.Next(challenge)
280 }
281 if err != nil {
282 uc.SendMessage(&irc.Message{
283 Command: "AUTHENTICATE",
284 Params: []string{"*"},
285 })
286 return err
287 }
288
289 // TODO: send response in multiple chunks if >= 400 bytes
290 var respStr = "+"
291 if resp != nil {
292 respStr = base64.StdEncoding.EncodeToString(resp)
293 }
294
295 uc.SendMessage(&irc.Message{
296 Command: "AUTHENTICATE",
297 Params: []string{respStr},
298 })
299 case rpl_loggedin:
300 var account string
301 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
302 return err
303 }
304 uc.logger.Printf("logged in with account %q", account)
305 case rpl_loggedout:
306 uc.logger.Printf("logged out")
307 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
308 var info string
309 if err := parseMessageParams(msg, nil, &info); err != nil {
310 return err
311 }
312 switch msg.Command {
313 case err_nicklocked:
314 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
315 case err_saslfail:
316 uc.logger.Printf("SASL authentication failed: %v", info)
317 case err_sasltoolong:
318 uc.logger.Printf("SASL message too long: %v", info)
319 }
320
321 uc.saslClient = nil
322 uc.saslStarted = false
323
324 uc.SendMessage(&irc.Message{
325 Command: "CAP",
326 Params: []string{"END"},
327 })
328 case irc.RPL_WELCOME:
329 uc.registered = true
330 uc.logger.Printf("connection registered")
331
332 channels, err := uc.srv.db.ListChannels(uc.network.ID)
333 if err != nil {
334 uc.logger.Printf("failed to list channels from database: %v", err)
335 break
336 }
337
338 for _, ch := range channels {
339 uc.SendMessage(&irc.Message{
340 Command: "JOIN",
341 Params: []string{ch.Name},
342 })
343 }
344 case irc.RPL_MYINFO:
345 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
346 return err
347 }
348 if len(msg.Params) > 5 {
349 uc.channelModesWithParam = msg.Params[5]
350 }
351 case "NICK":
352 if msg.Prefix == nil {
353 return fmt.Errorf("expected a prefix")
354 }
355
356 var newNick string
357 if err := parseMessageParams(msg, &newNick); err != nil {
358 return err
359 }
360
361 if msg.Prefix.Name == uc.nick {
362 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
363 uc.nick = newNick
364 }
365
366 for _, ch := range uc.channels {
367 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
368 delete(ch.Members, msg.Prefix.Name)
369 ch.Members[newNick] = membership
370 }
371 }
372
373 if msg.Prefix.Name != uc.nick {
374 uc.forEachDownstream(func(dc *downstreamConn) {
375 dc.SendMessage(&irc.Message{
376 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
377 Command: "NICK",
378 Params: []string{newNick},
379 })
380 })
381 }
382 case "JOIN":
383 if msg.Prefix == nil {
384 return fmt.Errorf("expected a prefix")
385 }
386
387 var channels string
388 if err := parseMessageParams(msg, &channels); err != nil {
389 return err
390 }
391
392 for _, ch := range strings.Split(channels, ",") {
393 if msg.Prefix.Name == uc.nick {
394 uc.logger.Printf("joined channel %q", ch)
395 uc.channels[ch] = &upstreamChannel{
396 Name: ch,
397 conn: uc,
398 Members: make(map[string]membership),
399 }
400 } else {
401 ch, err := uc.getChannel(ch)
402 if err != nil {
403 return err
404 }
405 ch.Members[msg.Prefix.Name] = 0
406 }
407
408 uc.forEachDownstream(func(dc *downstreamConn) {
409 dc.SendMessage(&irc.Message{
410 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
411 Command: "JOIN",
412 Params: []string{dc.marshalChannel(uc, ch)},
413 })
414 })
415 }
416 case "PART":
417 if msg.Prefix == nil {
418 return fmt.Errorf("expected a prefix")
419 }
420
421 var channels string
422 if err := parseMessageParams(msg, &channels); err != nil {
423 return err
424 }
425
426 for _, ch := range strings.Split(channels, ",") {
427 if msg.Prefix.Name == uc.nick {
428 uc.logger.Printf("parted channel %q", ch)
429 delete(uc.channels, ch)
430 } else {
431 ch, err := uc.getChannel(ch)
432 if err != nil {
433 return err
434 }
435 delete(ch.Members, msg.Prefix.Name)
436 }
437
438 uc.forEachDownstream(func(dc *downstreamConn) {
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
441 Command: "PART",
442 Params: []string{dc.marshalChannel(uc, ch)},
443 })
444 })
445 }
446 case "QUIT":
447 if msg.Prefix == nil {
448 return fmt.Errorf("expected a prefix")
449 }
450
451 if msg.Prefix.Name == uc.nick {
452 uc.logger.Printf("quit")
453 }
454
455 for _, ch := range uc.channels {
456 delete(ch.Members, msg.Prefix.Name)
457 }
458
459 if msg.Prefix.Name != uc.nick {
460 uc.forEachDownstream(func(dc *downstreamConn) {
461 dc.SendMessage(&irc.Message{
462 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
463 Command: "QUIT",
464 Params: msg.Params,
465 })
466 })
467 }
468 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
469 var name, topic string
470 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
471 return err
472 }
473 ch, err := uc.getChannel(name)
474 if err != nil {
475 return err
476 }
477 if msg.Command == irc.RPL_TOPIC {
478 ch.Topic = topic
479 } else {
480 ch.Topic = ""
481 }
482 case "TOPIC":
483 var name string
484 if err := parseMessageParams(msg, &name); err != nil {
485 return err
486 }
487 ch, err := uc.getChannel(name)
488 if err != nil {
489 return err
490 }
491 if len(msg.Params) > 1 {
492 ch.Topic = msg.Params[1]
493 } else {
494 ch.Topic = ""
495 }
496 uc.forEachDownstream(func(dc *downstreamConn) {
497 params := []string{dc.marshalChannel(uc, name)}
498 if ch.Topic != "" {
499 params = append(params, ch.Topic)
500 }
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "TOPIC",
504 Params: params,
505 })
506 })
507 case rpl_topicwhotime:
508 var name, who, timeStr string
509 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
510 return err
511 }
512 ch, err := uc.getChannel(name)
513 if err != nil {
514 return err
515 }
516 ch.TopicWho = who
517 sec, err := strconv.ParseInt(timeStr, 10, 64)
518 if err != nil {
519 return fmt.Errorf("failed to parse topic time: %v", err)
520 }
521 ch.TopicTime = time.Unix(sec, 0)
522 case irc.RPL_NAMREPLY:
523 var name, statusStr, members string
524 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
525 return err
526 }
527 ch, err := uc.getChannel(name)
528 if err != nil {
529 return err
530 }
531
532 status, err := parseChannelStatus(statusStr)
533 if err != nil {
534 return err
535 }
536 ch.Status = status
537
538 for _, s := range strings.Split(members, " ") {
539 membership, nick := parseMembershipPrefix(s)
540 ch.Members[nick] = membership
541 }
542 case irc.RPL_ENDOFNAMES:
543 var name string
544 if err := parseMessageParams(msg, nil, &name); err != nil {
545 return err
546 }
547 ch, err := uc.getChannel(name)
548 if err != nil {
549 return err
550 }
551
552 if ch.complete {
553 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
554 }
555 ch.complete = true
556
557 uc.forEachDownstream(func(dc *downstreamConn) {
558 forwardChannel(dc, ch)
559 })
560 case "PRIVMSG":
561 if msg.Prefix == nil {
562 return fmt.Errorf("expected a prefix")
563 }
564
565 var nick string
566 if err := parseMessageParams(msg, &nick, nil); err != nil {
567 return err
568 }
569
570 if msg.Prefix.Name == serviceNick {
571 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
572 break
573 }
574 if nick == serviceNick {
575 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
576 break
577 }
578
579 uc.ring.Produce(msg)
580 case "INVITE":
581 var nick string
582 var channel string
583 if err := parseMessageParams(msg, &nick, &channel); err != nil {
584 return err
585 }
586
587 uc.forEachDownstream(func(dc *downstreamConn) {
588 dc.SendMessage(&irc.Message{
589 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
590 Command: "INVITE",
591 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
592 })
593 })
594 case irc.RPL_YOURHOST, irc.RPL_CREATED:
595 // Ignore
596 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
597 // Ignore
598 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
599 // Ignore
600 case rpl_localusers, rpl_globalusers:
601 // Ignore
602 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
603 // Ignore
604 default:
605 uc.logger.Printf("unhandled message: %v", msg)
606 }
607 return nil
608}
609
610func (uc *upstreamConn) register() {
611 uc.nick = uc.network.Nick
612 uc.username = uc.network.Username
613 if uc.username == "" {
614 uc.username = uc.nick
615 }
616 uc.realname = uc.network.Realname
617 if uc.realname == "" {
618 uc.realname = uc.nick
619 }
620
621 uc.SendMessage(&irc.Message{
622 Command: "CAP",
623 Params: []string{"LS", "302"},
624 })
625
626 if uc.network.Pass != "" {
627 uc.SendMessage(&irc.Message{
628 Command: "PASS",
629 Params: []string{uc.network.Pass},
630 })
631 }
632
633 uc.SendMessage(&irc.Message{
634 Command: "NICK",
635 Params: []string{uc.nick},
636 })
637 uc.SendMessage(&irc.Message{
638 Command: "USER",
639 Params: []string{uc.username, "0", "*", uc.realname},
640 })
641}
642
643func (uc *upstreamConn) requestSASL() bool {
644 if uc.network.SASL.Mechanism == "" {
645 return false
646 }
647
648 v, ok := uc.caps["sasl"]
649 if !ok {
650 return false
651 }
652 if v != "" {
653 mechanisms := strings.Split(v, ",")
654 found := false
655 for _, mech := range mechanisms {
656 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
657 found = true
658 break
659 }
660 }
661 if !found {
662 return false
663 }
664 }
665
666 return true
667}
668
669func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
670 auth := &uc.network.SASL
671 switch name {
672 case "sasl":
673 if !ok {
674 uc.logger.Printf("server refused to acknowledge the SASL capability")
675 return nil
676 }
677
678 switch auth.Mechanism {
679 case "PLAIN":
680 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
681 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
682 default:
683 return fmt.Errorf("unsupported SASL mechanism %q", name)
684 }
685
686 uc.SendMessage(&irc.Message{
687 Command: "AUTHENTICATE",
688 Params: []string{auth.Mechanism},
689 })
690 }
691 return nil
692}
693
694func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
695 for {
696 msg, err := uc.irc.ReadMessage()
697 if err == io.EOF {
698 break
699 } else if err != nil {
700 return fmt.Errorf("failed to read IRC command: %v", err)
701 }
702
703 if uc.srv.Debug {
704 uc.logger.Printf("received: %v", msg)
705 }
706
707 ch <- upstreamIncomingMessage{msg, uc}
708 }
709
710 return nil
711}
712
713func (uc *upstreamConn) SendMessage(msg *irc.Message) {
714 uc.outgoing <- msg
715}
Note: See TracBrowser for help on using the repository browser.