source: code/trunk/upstream.go@ 98

Last change on this file since 98 was 98, checked in by contact, 5 years ago

Rename project to soju

File size: 15.2 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/emersion/go-sasl"
14 "gopkg.in/irc.v3"
15)
16
// upstreamChannel is the state kept for a channel joined on an upstream
// connection. An entry is created when our own JOIN is received and removed
// when we PART the channel.
type upstreamChannel struct {
	Name      string
	conn      *upstreamConn
	// Topic is the current topic; empty when the channel has no topic.
	Topic     string
	// TopicWho and TopicTime are filled from RPL_TOPICWHOTIME.
	TopicWho  string
	TopicTime time.Time
	// Status is parsed from RPL_NAMREPLY.
	Status    channelStatus
	// modes is updated from channel MODE messages.
	modes     modeSet
	// Members maps a member's nick to the membership parsed from its NAMES
	// prefix (zero for members seen via JOIN).
	Members   map[string]membership
	// complete is set once RPL_ENDOFNAMES has been received.
	complete  bool
}
28
// upstreamConn is a connection from the bouncer to an upstream IRC server
// for one network of a user.
type upstreamConn struct {
	network  *network
	logger   Logger
	net      net.Conn
	irc      *irc.Conn
	srv      *Server
	user     *user
	// messages is the outgoing queue fed by SendMessage and drained by the
	// writer goroutine started in connectToUpstream.
	messages chan<- *irc.Message
	// ring stores PRIVMSGs received from upstream.
	ring     *Ring

	// Server information from RPL_MYINFO.
	serverName            string
	availableUserModes    string
	availableChannelModes string
	channelModesWithParam string

	// registered is set once RPL_WELCOME has been received.
	registered bool
	// nick, username and realname are set by register; nick is updated
	// when the server reports our own NICK change.
	nick       string
	username   string
	realname   string
	// closed is set by Close.
	closed     bool
	// modes holds our user modes.
	modes      modeSet
	// channels holds joined channels keyed by name.
	channels   map[string]*upstreamChannel
	// NOTE(review): history is only initialized in this file; presumably
	// per-client positions in ring — confirm against its users.
	history    map[string]uint64
	// caps maps lowercase capability names advertised in CAP LS to their
	// values (empty when the capability has no value).
	caps       map[string]string

	// saslClient is non-nil while SASL authentication is in progress.
	saslClient  sasl.Client
	// saslStarted records whether saslClient.Start has been called.
	saslStarted bool
}
57
58func connectToUpstream(network *network) (*upstreamConn, error) {
59 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
60
61 addr := network.Addr
62 if !strings.ContainsRune(addr, ':') {
63 addr = addr + ":6697"
64 }
65
66 logger.Printf("connecting to TLS server at address %q", addr)
67 netConn, err := tls.Dial("tcp", addr, nil)
68 if err != nil {
69 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
70 }
71
72 setKeepAlive(netConn)
73
74 msgs := make(chan *irc.Message, 64)
75 uc := &upstreamConn{
76 network: network,
77 logger: logger,
78 net: netConn,
79 irc: irc.NewConn(netConn),
80 srv: network.user.srv,
81 user: network.user,
82 messages: msgs,
83 ring: NewRing(network.user.srv.RingCap),
84 channels: make(map[string]*upstreamChannel),
85 history: make(map[string]uint64),
86 caps: make(map[string]string),
87 }
88
89 go func() {
90 for msg := range msgs {
91 if uc.srv.Debug {
92 uc.logger.Printf("sent: %v", msg)
93 }
94 if err := uc.irc.WriteMessage(msg); err != nil {
95 uc.logger.Printf("failed to write message: %v", err)
96 }
97 }
98 if err := uc.net.Close(); err != nil {
99 uc.logger.Printf("failed to close connection: %v", err)
100 } else {
101 uc.logger.Printf("connection closed")
102 }
103 }()
104
105 return uc, nil
106}
107
108func (uc *upstreamConn) Close() error {
109 if uc.closed {
110 return fmt.Errorf("upstream connection already closed")
111 }
112 close(uc.messages)
113 uc.closed = true
114 return nil
115}
116
117func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
118 uc.user.forEachDownstream(func(dc *downstreamConn) {
119 if dc.network != nil && dc.network != uc.network {
120 return
121 }
122 f(dc)
123 })
124}
125
126func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
127 ch, ok := uc.channels[name]
128 if !ok {
129 return nil, fmt.Errorf("unknown channel %q", name)
130 }
131 return ch, nil
132}
133
134func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
135 switch msg.Command {
136 case "PING":
137 uc.SendMessage(&irc.Message{
138 Command: "PONG",
139 Params: msg.Params,
140 })
141 return nil
142 case "MODE":
143 if msg.Prefix == nil {
144 return fmt.Errorf("missing prefix")
145 }
146
147 var name, modeStr string
148 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
149 return err
150 }
151
152 if name == msg.Prefix.Name { // user mode change
153 if name != uc.nick {
154 return fmt.Errorf("received MODE message for unknow nick %q", name)
155 }
156 return uc.modes.Apply(modeStr)
157 } else { // channel mode change
158 ch, err := uc.getChannel(name)
159 if err != nil {
160 return err
161 }
162 if err := ch.modes.Apply(modeStr); err != nil {
163 return err
164 }
165
166 uc.forEachDownstream(func(dc *downstreamConn) {
167 dc.SendMessage(&irc.Message{
168 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
169 Command: "MODE",
170 Params: []string{dc.marshalChannel(uc, name), modeStr},
171 })
172 })
173 }
174 case "NOTICE":
175 uc.logger.Print(msg)
176
177 uc.forEachDownstream(func(dc *downstreamConn) {
178 dc.SendMessage(msg)
179 })
180 case "CAP":
181 var subCmd string
182 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
183 return err
184 }
185 subCmd = strings.ToUpper(subCmd)
186 subParams := msg.Params[2:]
187 switch subCmd {
188 case "LS":
189 if len(subParams) < 1 {
190 return newNeedMoreParamsError(msg.Command)
191 }
192 caps := strings.Fields(subParams[len(subParams)-1])
193 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
194
195 for _, s := range caps {
196 kv := strings.SplitN(s, "=", 2)
197 k := strings.ToLower(kv[0])
198 var v string
199 if len(kv) == 2 {
200 v = kv[1]
201 }
202 uc.caps[k] = v
203 }
204
205 if more {
206 break // wait to receive all capabilities
207 }
208
209 if uc.requestSASL() {
210 uc.SendMessage(&irc.Message{
211 Command: "CAP",
212 Params: []string{"REQ", "sasl"},
213 })
214 break // we'll send CAP END after authentication is completed
215 }
216
217 uc.SendMessage(&irc.Message{
218 Command: "CAP",
219 Params: []string{"END"},
220 })
221 case "ACK", "NAK":
222 if len(subParams) < 1 {
223 return newNeedMoreParamsError(msg.Command)
224 }
225 caps := strings.Fields(subParams[0])
226
227 for _, name := range caps {
228 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
229 return err
230 }
231 }
232
233 if uc.saslClient == nil {
234 uc.SendMessage(&irc.Message{
235 Command: "CAP",
236 Params: []string{"END"},
237 })
238 }
239 default:
240 uc.logger.Printf("unhandled message: %v", msg)
241 }
242 case "AUTHENTICATE":
243 if uc.saslClient == nil {
244 return fmt.Errorf("received unexpected AUTHENTICATE message")
245 }
246
247 // TODO: if a challenge is 400 bytes long, buffer it
248 var challengeStr string
249 if err := parseMessageParams(msg, &challengeStr); err != nil {
250 uc.SendMessage(&irc.Message{
251 Command: "AUTHENTICATE",
252 Params: []string{"*"},
253 })
254 return err
255 }
256
257 var challenge []byte
258 if challengeStr != "+" {
259 var err error
260 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
261 if err != nil {
262 uc.SendMessage(&irc.Message{
263 Command: "AUTHENTICATE",
264 Params: []string{"*"},
265 })
266 return err
267 }
268 }
269
270 var resp []byte
271 var err error
272 if !uc.saslStarted {
273 _, resp, err = uc.saslClient.Start()
274 uc.saslStarted = true
275 } else {
276 resp, err = uc.saslClient.Next(challenge)
277 }
278 if err != nil {
279 uc.SendMessage(&irc.Message{
280 Command: "AUTHENTICATE",
281 Params: []string{"*"},
282 })
283 return err
284 }
285
286 // TODO: send response in multiple chunks if >= 400 bytes
287 var respStr = "+"
288 if resp != nil {
289 respStr = base64.StdEncoding.EncodeToString(resp)
290 }
291
292 uc.SendMessage(&irc.Message{
293 Command: "AUTHENTICATE",
294 Params: []string{respStr},
295 })
296 case rpl_loggedin:
297 var account string
298 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
299 return err
300 }
301 uc.logger.Printf("logged in with account %q", account)
302 case rpl_loggedout:
303 uc.logger.Printf("logged out")
304 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
305 var info string
306 if err := parseMessageParams(msg, nil, &info); err != nil {
307 return err
308 }
309 switch msg.Command {
310 case err_nicklocked:
311 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
312 case err_saslfail:
313 uc.logger.Printf("SASL authentication failed: %v", info)
314 case err_sasltoolong:
315 uc.logger.Printf("SASL message too long: %v", info)
316 }
317
318 uc.saslClient = nil
319 uc.saslStarted = false
320
321 uc.SendMessage(&irc.Message{
322 Command: "CAP",
323 Params: []string{"END"},
324 })
325 case irc.RPL_WELCOME:
326 uc.registered = true
327 uc.logger.Printf("connection registered")
328
329 channels, err := uc.srv.db.ListChannels(uc.network.ID)
330 if err != nil {
331 uc.logger.Printf("failed to list channels from database: %v", err)
332 break
333 }
334
335 for _, ch := range channels {
336 uc.SendMessage(&irc.Message{
337 Command: "JOIN",
338 Params: []string{ch.Name},
339 })
340 }
341 case irc.RPL_MYINFO:
342 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
343 return err
344 }
345 if len(msg.Params) > 5 {
346 uc.channelModesWithParam = msg.Params[5]
347 }
348 case "NICK":
349 if msg.Prefix == nil {
350 return fmt.Errorf("expected a prefix")
351 }
352
353 var newNick string
354 if err := parseMessageParams(msg, &newNick); err != nil {
355 return err
356 }
357
358 if msg.Prefix.Name == uc.nick {
359 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
360 uc.nick = newNick
361 }
362
363 for _, ch := range uc.channels {
364 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
365 delete(ch.Members, msg.Prefix.Name)
366 ch.Members[newNick] = membership
367 }
368 }
369
370 if msg.Prefix.Name != uc.nick {
371 uc.forEachDownstream(func(dc *downstreamConn) {
372 dc.SendMessage(&irc.Message{
373 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
374 Command: "NICK",
375 Params: []string{newNick},
376 })
377 })
378 }
379 case "JOIN":
380 if msg.Prefix == nil {
381 return fmt.Errorf("expected a prefix")
382 }
383
384 var channels string
385 if err := parseMessageParams(msg, &channels); err != nil {
386 return err
387 }
388
389 for _, ch := range strings.Split(channels, ",") {
390 if msg.Prefix.Name == uc.nick {
391 uc.logger.Printf("joined channel %q", ch)
392 uc.channels[ch] = &upstreamChannel{
393 Name: ch,
394 conn: uc,
395 Members: make(map[string]membership),
396 }
397 } else {
398 ch, err := uc.getChannel(ch)
399 if err != nil {
400 return err
401 }
402 ch.Members[msg.Prefix.Name] = 0
403 }
404
405 uc.forEachDownstream(func(dc *downstreamConn) {
406 dc.SendMessage(&irc.Message{
407 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
408 Command: "JOIN",
409 Params: []string{dc.marshalChannel(uc, ch)},
410 })
411 })
412 }
413 case "PART":
414 if msg.Prefix == nil {
415 return fmt.Errorf("expected a prefix")
416 }
417
418 var channels string
419 if err := parseMessageParams(msg, &channels); err != nil {
420 return err
421 }
422
423 for _, ch := range strings.Split(channels, ",") {
424 if msg.Prefix.Name == uc.nick {
425 uc.logger.Printf("parted channel %q", ch)
426 delete(uc.channels, ch)
427 } else {
428 ch, err := uc.getChannel(ch)
429 if err != nil {
430 return err
431 }
432 delete(ch.Members, msg.Prefix.Name)
433 }
434
435 uc.forEachDownstream(func(dc *downstreamConn) {
436 dc.SendMessage(&irc.Message{
437 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
438 Command: "PART",
439 Params: []string{dc.marshalChannel(uc, ch)},
440 })
441 })
442 }
443 case "QUIT":
444 if msg.Prefix == nil {
445 return fmt.Errorf("expected a prefix")
446 }
447
448 if msg.Prefix.Name == uc.nick {
449 uc.logger.Printf("quit")
450 }
451
452 for _, ch := range uc.channels {
453 delete(ch.Members, msg.Prefix.Name)
454 }
455
456 if msg.Prefix.Name != uc.nick {
457 uc.forEachDownstream(func(dc *downstreamConn) {
458 dc.SendMessage(&irc.Message{
459 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
460 Command: "QUIT",
461 Params: msg.Params,
462 })
463 })
464 }
465 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
466 var name, topic string
467 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
468 return err
469 }
470 ch, err := uc.getChannel(name)
471 if err != nil {
472 return err
473 }
474 if msg.Command == irc.RPL_TOPIC {
475 ch.Topic = topic
476 } else {
477 ch.Topic = ""
478 }
479 case "TOPIC":
480 var name string
481 if err := parseMessageParams(msg, &name); err != nil {
482 return err
483 }
484 ch, err := uc.getChannel(name)
485 if err != nil {
486 return err
487 }
488 if len(msg.Params) > 1 {
489 ch.Topic = msg.Params[1]
490 } else {
491 ch.Topic = ""
492 }
493 uc.forEachDownstream(func(dc *downstreamConn) {
494 params := []string{dc.marshalChannel(uc, name)}
495 if ch.Topic != "" {
496 params = append(params, ch.Topic)
497 }
498 dc.SendMessage(&irc.Message{
499 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
500 Command: "TOPIC",
501 Params: params,
502 })
503 })
504 case rpl_topicwhotime:
505 var name, who, timeStr string
506 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
507 return err
508 }
509 ch, err := uc.getChannel(name)
510 if err != nil {
511 return err
512 }
513 ch.TopicWho = who
514 sec, err := strconv.ParseInt(timeStr, 10, 64)
515 if err != nil {
516 return fmt.Errorf("failed to parse topic time: %v", err)
517 }
518 ch.TopicTime = time.Unix(sec, 0)
519 case irc.RPL_NAMREPLY:
520 var name, statusStr, members string
521 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
522 return err
523 }
524 ch, err := uc.getChannel(name)
525 if err != nil {
526 return err
527 }
528
529 status, err := parseChannelStatus(statusStr)
530 if err != nil {
531 return err
532 }
533 ch.Status = status
534
535 for _, s := range strings.Split(members, " ") {
536 membership, nick := parseMembershipPrefix(s)
537 ch.Members[nick] = membership
538 }
539 case irc.RPL_ENDOFNAMES:
540 var name string
541 if err := parseMessageParams(msg, nil, &name); err != nil {
542 return err
543 }
544 ch, err := uc.getChannel(name)
545 if err != nil {
546 return err
547 }
548
549 if ch.complete {
550 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
551 }
552 ch.complete = true
553
554 uc.forEachDownstream(func(dc *downstreamConn) {
555 forwardChannel(dc, ch)
556 })
557 case "PRIVMSG":
558 if err := parseMessageParams(msg, nil, nil); err != nil {
559 return err
560 }
561 uc.ring.Produce(msg)
562 case irc.RPL_YOURHOST, irc.RPL_CREATED:
563 // Ignore
564 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
565 // Ignore
566 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
567 // Ignore
568 case rpl_localusers, rpl_globalusers:
569 // Ignore
570 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
571 // Ignore
572 default:
573 uc.logger.Printf("unhandled message: %v", msg)
574 }
575 return nil
576}
577
578func (uc *upstreamConn) register() {
579 uc.nick = uc.network.Nick
580 uc.username = uc.network.Username
581 if uc.username == "" {
582 uc.username = uc.nick
583 }
584 uc.realname = uc.network.Realname
585 if uc.realname == "" {
586 uc.realname = uc.nick
587 }
588
589 uc.SendMessage(&irc.Message{
590 Command: "CAP",
591 Params: []string{"LS", "302"},
592 })
593
594 if uc.network.Pass != "" {
595 uc.SendMessage(&irc.Message{
596 Command: "PASS",
597 Params: []string{uc.network.Pass},
598 })
599 }
600
601 uc.SendMessage(&irc.Message{
602 Command: "NICK",
603 Params: []string{uc.nick},
604 })
605 uc.SendMessage(&irc.Message{
606 Command: "USER",
607 Params: []string{uc.username, "0", "*", uc.realname},
608 })
609}
610
611func (uc *upstreamConn) requestSASL() bool {
612 if uc.network.SASL.Mechanism == "" {
613 return false
614 }
615
616 v, ok := uc.caps["sasl"]
617 if !ok {
618 return false
619 }
620 if v != "" {
621 mechanisms := strings.Split(v, ",")
622 found := false
623 for _, mech := range mechanisms {
624 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
625 found = true
626 break
627 }
628 }
629 if !found {
630 return false
631 }
632 }
633
634 return true
635}
636
637func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
638 auth := &uc.network.SASL
639 switch name {
640 case "sasl":
641 if !ok {
642 uc.logger.Printf("server refused to acknowledge the SASL capability")
643 return nil
644 }
645
646 switch auth.Mechanism {
647 case "PLAIN":
648 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
649 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
650 default:
651 return fmt.Errorf("unsupported SASL mechanism %q", name)
652 }
653
654 uc.SendMessage(&irc.Message{
655 Command: "AUTHENTICATE",
656 Params: []string{auth.Mechanism},
657 })
658 }
659 return nil
660}
661
662func (uc *upstreamConn) readMessages() error {
663 for {
664 msg, err := uc.irc.ReadMessage()
665 if err == io.EOF {
666 break
667 } else if err != nil {
668 return fmt.Errorf("failed to read IRC command: %v", err)
669 }
670
671 if uc.srv.Debug {
672 uc.logger.Printf("received: %v", msg)
673 }
674
675 if err := uc.handleMessage(msg); err != nil {
676 uc.logger.Printf("failed to handle message %q: %v", msg, err)
677 }
678 }
679
680 return nil
681}
682
683func (uc *upstreamConn) SendMessage(msg *irc.Message) {
684 uc.messages <- msg
685}
Note: See TracBrowser for help on using the repository browser.