source: code/trunk/upstream.go@ 109

Last change on this file since 109 was 109, checked in by contact, 5 years ago

Protect upstreamConn.history with a lock

File size: 15.3 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "encoding/base64"
6 "fmt"
7 "io"
8 "net"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/emersion/go-sasl"
15 "gopkg.in/irc.v3"
16)
17
18type upstreamChannel struct {
19 Name string
20 conn *upstreamConn
21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
25 modes modeSet
26 Members map[string]membership
27 complete bool
28}
29
30type upstreamConn struct {
31 network *network
32 logger Logger
33 net net.Conn
34 irc *irc.Conn
35 srv *Server
36 user *user
37 outgoing chan<- *irc.Message
38 ring *Ring
39
40 serverName string
41 availableUserModes string
42 availableChannelModes string
43 channelModesWithParam string
44
45 registered bool
46 nick string
47 username string
48 realname string
49 closed bool
50 modes modeSet
51 channels map[string]*upstreamChannel
52 caps map[string]string
53
54 saslClient sasl.Client
55 saslStarted bool
56
57 lock sync.Mutex
58 history map[string]uint64 // TODO: move to network
59}
60
61func connectToUpstream(network *network) (*upstreamConn, error) {
62 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
63
64 addr := network.Addr
65 if !strings.ContainsRune(addr, ':') {
66 addr = addr + ":6697"
67 }
68
69 logger.Printf("connecting to TLS server at address %q", addr)
70 netConn, err := tls.Dial("tcp", addr, nil)
71 if err != nil {
72 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
73 }
74
75 setKeepAlive(netConn)
76
77 outgoing := make(chan *irc.Message, 64)
78 uc := &upstreamConn{
79 network: network,
80 logger: logger,
81 net: netConn,
82 irc: irc.NewConn(netConn),
83 srv: network.user.srv,
84 user: network.user,
85 outgoing: outgoing,
86 ring: NewRing(network.user.srv.RingCap),
87 channels: make(map[string]*upstreamChannel),
88 history: make(map[string]uint64),
89 caps: make(map[string]string),
90 }
91
92 go func() {
93 for msg := range outgoing {
94 if uc.srv.Debug {
95 uc.logger.Printf("sent: %v", msg)
96 }
97 if err := uc.irc.WriteMessage(msg); err != nil {
98 uc.logger.Printf("failed to write message: %v", err)
99 }
100 }
101 if err := uc.net.Close(); err != nil {
102 uc.logger.Printf("failed to close connection: %v", err)
103 } else {
104 uc.logger.Printf("connection closed")
105 }
106 }()
107
108 return uc, nil
109}
110
111func (uc *upstreamConn) Close() error {
112 if uc.closed {
113 return fmt.Errorf("upstream connection already closed")
114 }
115 close(uc.outgoing)
116 uc.closed = true
117 return nil
118}
119
120func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
121 uc.user.forEachDownstream(func(dc *downstreamConn) {
122 if dc.network != nil && dc.network != uc.network {
123 return
124 }
125 f(dc)
126 })
127}
128
129func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
130 ch, ok := uc.channels[name]
131 if !ok {
132 return nil, fmt.Errorf("unknown channel %q", name)
133 }
134 return ch, nil
135}
136
137func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
138 switch msg.Command {
139 case "PING":
140 uc.SendMessage(&irc.Message{
141 Command: "PONG",
142 Params: msg.Params,
143 })
144 return nil
145 case "MODE":
146 if msg.Prefix == nil {
147 return fmt.Errorf("missing prefix")
148 }
149
150 var name, modeStr string
151 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
152 return err
153 }
154
155 if name == msg.Prefix.Name { // user mode change
156 if name != uc.nick {
157 return fmt.Errorf("received MODE message for unknow nick %q", name)
158 }
159 return uc.modes.Apply(modeStr)
160 } else { // channel mode change
161 ch, err := uc.getChannel(name)
162 if err != nil {
163 return err
164 }
165 if err := ch.modes.Apply(modeStr); err != nil {
166 return err
167 }
168
169 uc.forEachDownstream(func(dc *downstreamConn) {
170 dc.SendMessage(&irc.Message{
171 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
172 Command: "MODE",
173 Params: []string{dc.marshalChannel(uc, name), modeStr},
174 })
175 })
176 }
177 case "NOTICE":
178 uc.logger.Print(msg)
179
180 uc.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(msg)
182 })
183 case "CAP":
184 var subCmd string
185 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
186 return err
187 }
188 subCmd = strings.ToUpper(subCmd)
189 subParams := msg.Params[2:]
190 switch subCmd {
191 case "LS":
192 if len(subParams) < 1 {
193 return newNeedMoreParamsError(msg.Command)
194 }
195 caps := strings.Fields(subParams[len(subParams)-1])
196 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
197
198 for _, s := range caps {
199 kv := strings.SplitN(s, "=", 2)
200 k := strings.ToLower(kv[0])
201 var v string
202 if len(kv) == 2 {
203 v = kv[1]
204 }
205 uc.caps[k] = v
206 }
207
208 if more {
209 break // wait to receive all capabilities
210 }
211
212 if uc.requestSASL() {
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"REQ", "sasl"},
216 })
217 break // we'll send CAP END after authentication is completed
218 }
219
220 uc.SendMessage(&irc.Message{
221 Command: "CAP",
222 Params: []string{"END"},
223 })
224 case "ACK", "NAK":
225 if len(subParams) < 1 {
226 return newNeedMoreParamsError(msg.Command)
227 }
228 caps := strings.Fields(subParams[0])
229
230 for _, name := range caps {
231 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
232 return err
233 }
234 }
235
236 if uc.saslClient == nil {
237 uc.SendMessage(&irc.Message{
238 Command: "CAP",
239 Params: []string{"END"},
240 })
241 }
242 default:
243 uc.logger.Printf("unhandled message: %v", msg)
244 }
245 case "AUTHENTICATE":
246 if uc.saslClient == nil {
247 return fmt.Errorf("received unexpected AUTHENTICATE message")
248 }
249
250 // TODO: if a challenge is 400 bytes long, buffer it
251 var challengeStr string
252 if err := parseMessageParams(msg, &challengeStr); err != nil {
253 uc.SendMessage(&irc.Message{
254 Command: "AUTHENTICATE",
255 Params: []string{"*"},
256 })
257 return err
258 }
259
260 var challenge []byte
261 if challengeStr != "+" {
262 var err error
263 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
264 if err != nil {
265 uc.SendMessage(&irc.Message{
266 Command: "AUTHENTICATE",
267 Params: []string{"*"},
268 })
269 return err
270 }
271 }
272
273 var resp []byte
274 var err error
275 if !uc.saslStarted {
276 _, resp, err = uc.saslClient.Start()
277 uc.saslStarted = true
278 } else {
279 resp, err = uc.saslClient.Next(challenge)
280 }
281 if err != nil {
282 uc.SendMessage(&irc.Message{
283 Command: "AUTHENTICATE",
284 Params: []string{"*"},
285 })
286 return err
287 }
288
289 // TODO: send response in multiple chunks if >= 400 bytes
290 var respStr = "+"
291 if resp != nil {
292 respStr = base64.StdEncoding.EncodeToString(resp)
293 }
294
295 uc.SendMessage(&irc.Message{
296 Command: "AUTHENTICATE",
297 Params: []string{respStr},
298 })
299 case rpl_loggedin:
300 var account string
301 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
302 return err
303 }
304 uc.logger.Printf("logged in with account %q", account)
305 case rpl_loggedout:
306 uc.logger.Printf("logged out")
307 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
308 var info string
309 if err := parseMessageParams(msg, nil, &info); err != nil {
310 return err
311 }
312 switch msg.Command {
313 case err_nicklocked:
314 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
315 case err_saslfail:
316 uc.logger.Printf("SASL authentication failed: %v", info)
317 case err_sasltoolong:
318 uc.logger.Printf("SASL message too long: %v", info)
319 }
320
321 uc.saslClient = nil
322 uc.saslStarted = false
323
324 uc.SendMessage(&irc.Message{
325 Command: "CAP",
326 Params: []string{"END"},
327 })
328 case irc.RPL_WELCOME:
329 uc.registered = true
330 uc.logger.Printf("connection registered")
331
332 channels, err := uc.srv.db.ListChannels(uc.network.ID)
333 if err != nil {
334 uc.logger.Printf("failed to list channels from database: %v", err)
335 break
336 }
337
338 for _, ch := range channels {
339 uc.SendMessage(&irc.Message{
340 Command: "JOIN",
341 Params: []string{ch.Name},
342 })
343 }
344 case irc.RPL_MYINFO:
345 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
346 return err
347 }
348 if len(msg.Params) > 5 {
349 uc.channelModesWithParam = msg.Params[5]
350 }
351 case "NICK":
352 if msg.Prefix == nil {
353 return fmt.Errorf("expected a prefix")
354 }
355
356 var newNick string
357 if err := parseMessageParams(msg, &newNick); err != nil {
358 return err
359 }
360
361 if msg.Prefix.Name == uc.nick {
362 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
363 uc.nick = newNick
364 }
365
366 for _, ch := range uc.channels {
367 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
368 delete(ch.Members, msg.Prefix.Name)
369 ch.Members[newNick] = membership
370 }
371 }
372
373 if msg.Prefix.Name != uc.nick {
374 uc.forEachDownstream(func(dc *downstreamConn) {
375 dc.SendMessage(&irc.Message{
376 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
377 Command: "NICK",
378 Params: []string{newNick},
379 })
380 })
381 }
382 case "JOIN":
383 if msg.Prefix == nil {
384 return fmt.Errorf("expected a prefix")
385 }
386
387 var channels string
388 if err := parseMessageParams(msg, &channels); err != nil {
389 return err
390 }
391
392 for _, ch := range strings.Split(channels, ",") {
393 if msg.Prefix.Name == uc.nick {
394 uc.logger.Printf("joined channel %q", ch)
395 uc.channels[ch] = &upstreamChannel{
396 Name: ch,
397 conn: uc,
398 Members: make(map[string]membership),
399 }
400 } else {
401 ch, err := uc.getChannel(ch)
402 if err != nil {
403 return err
404 }
405 ch.Members[msg.Prefix.Name] = 0
406 }
407
408 uc.forEachDownstream(func(dc *downstreamConn) {
409 dc.SendMessage(&irc.Message{
410 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
411 Command: "JOIN",
412 Params: []string{dc.marshalChannel(uc, ch)},
413 })
414 })
415 }
416 case "PART":
417 if msg.Prefix == nil {
418 return fmt.Errorf("expected a prefix")
419 }
420
421 var channels string
422 if err := parseMessageParams(msg, &channels); err != nil {
423 return err
424 }
425
426 for _, ch := range strings.Split(channels, ",") {
427 if msg.Prefix.Name == uc.nick {
428 uc.logger.Printf("parted channel %q", ch)
429 delete(uc.channels, ch)
430 } else {
431 ch, err := uc.getChannel(ch)
432 if err != nil {
433 return err
434 }
435 delete(ch.Members, msg.Prefix.Name)
436 }
437
438 uc.forEachDownstream(func(dc *downstreamConn) {
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
441 Command: "PART",
442 Params: []string{dc.marshalChannel(uc, ch)},
443 })
444 })
445 }
446 case "QUIT":
447 if msg.Prefix == nil {
448 return fmt.Errorf("expected a prefix")
449 }
450
451 if msg.Prefix.Name == uc.nick {
452 uc.logger.Printf("quit")
453 }
454
455 for _, ch := range uc.channels {
456 delete(ch.Members, msg.Prefix.Name)
457 }
458
459 if msg.Prefix.Name != uc.nick {
460 uc.forEachDownstream(func(dc *downstreamConn) {
461 dc.SendMessage(&irc.Message{
462 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
463 Command: "QUIT",
464 Params: msg.Params,
465 })
466 })
467 }
468 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
469 var name, topic string
470 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
471 return err
472 }
473 ch, err := uc.getChannel(name)
474 if err != nil {
475 return err
476 }
477 if msg.Command == irc.RPL_TOPIC {
478 ch.Topic = topic
479 } else {
480 ch.Topic = ""
481 }
482 case "TOPIC":
483 var name string
484 if err := parseMessageParams(msg, &name); err != nil {
485 return err
486 }
487 ch, err := uc.getChannel(name)
488 if err != nil {
489 return err
490 }
491 if len(msg.Params) > 1 {
492 ch.Topic = msg.Params[1]
493 } else {
494 ch.Topic = ""
495 }
496 uc.forEachDownstream(func(dc *downstreamConn) {
497 params := []string{dc.marshalChannel(uc, name)}
498 if ch.Topic != "" {
499 params = append(params, ch.Topic)
500 }
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "TOPIC",
504 Params: params,
505 })
506 })
507 case rpl_topicwhotime:
508 var name, who, timeStr string
509 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
510 return err
511 }
512 ch, err := uc.getChannel(name)
513 if err != nil {
514 return err
515 }
516 ch.TopicWho = who
517 sec, err := strconv.ParseInt(timeStr, 10, 64)
518 if err != nil {
519 return fmt.Errorf("failed to parse topic time: %v", err)
520 }
521 ch.TopicTime = time.Unix(sec, 0)
522 case irc.RPL_NAMREPLY:
523 var name, statusStr, members string
524 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
525 return err
526 }
527 ch, err := uc.getChannel(name)
528 if err != nil {
529 return err
530 }
531
532 status, err := parseChannelStatus(statusStr)
533 if err != nil {
534 return err
535 }
536 ch.Status = status
537
538 for _, s := range strings.Split(members, " ") {
539 membership, nick := parseMembershipPrefix(s)
540 ch.Members[nick] = membership
541 }
542 case irc.RPL_ENDOFNAMES:
543 var name string
544 if err := parseMessageParams(msg, nil, &name); err != nil {
545 return err
546 }
547 ch, err := uc.getChannel(name)
548 if err != nil {
549 return err
550 }
551
552 if ch.complete {
553 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
554 }
555 ch.complete = true
556
557 uc.forEachDownstream(func(dc *downstreamConn) {
558 forwardChannel(dc, ch)
559 })
560 case "PRIVMSG":
561 if err := parseMessageParams(msg, nil, nil); err != nil {
562 return err
563 }
564 uc.ring.Produce(msg)
565 case irc.RPL_YOURHOST, irc.RPL_CREATED:
566 // Ignore
567 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
568 // Ignore
569 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
570 // Ignore
571 case rpl_localusers, rpl_globalusers:
572 // Ignore
573 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
574 // Ignore
575 default:
576 uc.logger.Printf("unhandled message: %v", msg)
577 }
578 return nil
579}
580
581func (uc *upstreamConn) register() {
582 uc.nick = uc.network.Nick
583 uc.username = uc.network.Username
584 if uc.username == "" {
585 uc.username = uc.nick
586 }
587 uc.realname = uc.network.Realname
588 if uc.realname == "" {
589 uc.realname = uc.nick
590 }
591
592 uc.SendMessage(&irc.Message{
593 Command: "CAP",
594 Params: []string{"LS", "302"},
595 })
596
597 if uc.network.Pass != "" {
598 uc.SendMessage(&irc.Message{
599 Command: "PASS",
600 Params: []string{uc.network.Pass},
601 })
602 }
603
604 uc.SendMessage(&irc.Message{
605 Command: "NICK",
606 Params: []string{uc.nick},
607 })
608 uc.SendMessage(&irc.Message{
609 Command: "USER",
610 Params: []string{uc.username, "0", "*", uc.realname},
611 })
612}
613
614func (uc *upstreamConn) requestSASL() bool {
615 if uc.network.SASL.Mechanism == "" {
616 return false
617 }
618
619 v, ok := uc.caps["sasl"]
620 if !ok {
621 return false
622 }
623 if v != "" {
624 mechanisms := strings.Split(v, ",")
625 found := false
626 for _, mech := range mechanisms {
627 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
628 found = true
629 break
630 }
631 }
632 if !found {
633 return false
634 }
635 }
636
637 return true
638}
639
640func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
641 auth := &uc.network.SASL
642 switch name {
643 case "sasl":
644 if !ok {
645 uc.logger.Printf("server refused to acknowledge the SASL capability")
646 return nil
647 }
648
649 switch auth.Mechanism {
650 case "PLAIN":
651 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
652 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
653 default:
654 return fmt.Errorf("unsupported SASL mechanism %q", name)
655 }
656
657 uc.SendMessage(&irc.Message{
658 Command: "AUTHENTICATE",
659 Params: []string{auth.Mechanism},
660 })
661 }
662 return nil
663}
664
665func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
666 for {
667 msg, err := uc.irc.ReadMessage()
668 if err == io.EOF {
669 break
670 } else if err != nil {
671 return fmt.Errorf("failed to read IRC command: %v", err)
672 }
673
674 if uc.srv.Debug {
675 uc.logger.Printf("received: %v", msg)
676 }
677
678 ch <- upstreamIncomingMessage{msg, uc}
679 }
680
681 return nil
682}
683
684func (uc *upstreamConn) SendMessage(msg *irc.Message) {
685 uc.outgoing <- msg
686}
Note: See TracBrowser for help on using the repository browser.