source: code/trunk/downstream.go@ 109

Last change on this file since 109 was 109, checked in by contact, 5 years ago

Protect upstreamConn.history with a lock

File size: 17.5 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "sync"
11 "time"
12
13 "golang.org/x/crypto/bcrypt"
14 "gopkg.in/irc.v3"
15)
16
17type ircError struct {
18 Message *irc.Message
19}
20
21func (err ircError) Error() string {
22 return err.Message.String()
23}
24
25func newUnknownCommandError(cmd string) ircError {
26 return ircError{&irc.Message{
27 Command: irc.ERR_UNKNOWNCOMMAND,
28 Params: []string{
29 "*",
30 cmd,
31 "Unknown command",
32 },
33 }}
34}
35
36func newNeedMoreParamsError(cmd string) ircError {
37 return ircError{&irc.Message{
38 Command: irc.ERR_NEEDMOREPARAMS,
39 Params: []string{
40 "*",
41 cmd,
42 "Not enough parameters",
43 },
44 }}
45}
46
47var errAuthFailed = ircError{&irc.Message{
48 Command: irc.ERR_PASSWDMISMATCH,
49 Params: []string{"*", "Invalid username or password"},
50}}
51
52type ringMessage struct {
53 consumer *RingConsumer
54 upstreamConn *upstreamConn
55}
56
57type downstreamConn struct {
58 net net.Conn
59 irc *irc.Conn
60 srv *Server
61 logger Logger
62 outgoing chan *irc.Message
63 ringMessages chan ringMessage
64 closed chan struct{}
65
66 registered bool
67 user *user
68 nick string
69 username string
70 rawUsername string
71 realname string
72 password string // empty after authentication
73 network *network // can be nil
74
75 negociatingCaps bool
76 capVersion int
77 caps map[string]bool
78
79 lock sync.Mutex
80 ourMessages map[*irc.Message]struct{}
81}
82
83func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
84 dc := &downstreamConn{
85 net: netConn,
86 irc: irc.NewConn(netConn),
87 srv: srv,
88 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
89 outgoing: make(chan *irc.Message, 64),
90 ringMessages: make(chan ringMessage),
91 closed: make(chan struct{}),
92 caps: make(map[string]bool),
93 ourMessages: make(map[*irc.Message]struct{}),
94 }
95
96 go func() {
97 if err := dc.writeMessages(); err != nil {
98 dc.logger.Printf("failed to write message: %v", err)
99 }
100 if err := dc.net.Close(); err != nil {
101 dc.logger.Printf("failed to close connection: %v", err)
102 } else {
103 dc.logger.Printf("connection closed")
104 }
105 }()
106
107 return dc
108}
109
110func (dc *downstreamConn) prefix() *irc.Prefix {
111 return &irc.Prefix{
112 Name: dc.nick,
113 User: dc.username,
114 // TODO: fill the host?
115 }
116}
117
118func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
119 return name
120}
121
122func (dc *downstreamConn) forEachNetwork(f func(*network)) {
123 if dc.network != nil {
124 f(dc.network)
125 } else {
126 dc.user.forEachNetwork(f)
127 }
128}
129
130func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
131 dc.user.forEachUpstream(func(uc *upstreamConn) {
132 if dc.network != nil && uc.network != dc.network {
133 return
134 }
135 f(uc)
136 })
137}
138
139// upstream returns the upstream connection, if any. If there are zero or if
140// there are multiple upstream connections, it returns nil.
141func (dc *downstreamConn) upstream() *upstreamConn {
142 if dc.network == nil {
143 return nil
144 }
145
146 var upstream *upstreamConn
147 dc.forEachUpstream(func(uc *upstreamConn) {
148 upstream = uc
149 })
150 return upstream
151}
152
153func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
154 if uc := dc.upstream(); uc != nil {
155 return uc, name, nil
156 }
157
158 // TODO: extract network name from channel name if dc.upstream == nil
159 var channel *upstreamChannel
160 var err error
161 dc.forEachUpstream(func(uc *upstreamConn) {
162 if err != nil {
163 return
164 }
165 if ch, ok := uc.channels[name]; ok {
166 if channel != nil {
167 err = fmt.Errorf("ambiguous channel name %q", name)
168 } else {
169 channel = ch
170 }
171 }
172 })
173 if channel == nil {
174 return nil, "", ircError{&irc.Message{
175 Command: irc.ERR_NOSUCHCHANNEL,
176 Params: []string{name, "No such channel"},
177 }}
178 }
179 return channel.conn, channel.Name, nil
180}
181
182func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
183 if nick == uc.nick {
184 return dc.nick
185 }
186 return nick
187}
188
189func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
190 if prefix.Name == uc.nick {
191 return dc.prefix()
192 }
193 return prefix
194}
195
196func (dc *downstreamConn) isClosed() bool {
197 select {
198 case <-dc.closed:
199 return true
200 default:
201 return false
202 }
203}
204
205func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
206 dc.logger.Printf("new connection")
207
208 for {
209 msg, err := dc.irc.ReadMessage()
210 if err == io.EOF {
211 break
212 } else if err != nil {
213 return fmt.Errorf("failed to read IRC command: %v", err)
214 }
215
216 if dc.srv.Debug {
217 dc.logger.Printf("received: %v", msg)
218 }
219
220 ch <- downstreamIncomingMessage{msg, dc}
221 }
222
223 return nil
224}
225
226func (dc *downstreamConn) writeMessages() error {
227 for {
228 var err error
229 var closed bool
230 select {
231 case msg := <-dc.outgoing:
232 if dc.srv.Debug {
233 dc.logger.Printf("sent: %v", msg)
234 }
235 err = dc.irc.WriteMessage(msg)
236 case ringMessage := <-dc.ringMessages:
237 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
238 for {
239 msg := consumer.Peek()
240 if msg == nil {
241 break
242 }
243
244 dc.lock.Lock()
245 _, ours := dc.ourMessages[msg]
246 delete(dc.ourMessages, msg)
247 dc.lock.Unlock()
248 if ours {
249 // The message comes from our connection, don't echo it
250 // back
251 continue
252 }
253
254 msg = msg.Copy()
255 switch msg.Command {
256 case "PRIVMSG":
257 // TODO: detect whether it's a user or a channel
258 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
259 default:
260 panic("expected to consume a PRIVMSG message")
261 }
262 if dc.srv.Debug {
263 dc.logger.Printf("sent: %v", msg)
264 }
265 err = dc.irc.WriteMessage(msg)
266 if err != nil {
267 break
268 }
269 consumer.Consume()
270 }
271 case <-dc.closed:
272 closed = true
273 }
274 if err != nil {
275 return err
276 }
277 if closed {
278 break
279 }
280 }
281 return nil
282}
283
284func (dc *downstreamConn) Close() error {
285 if dc.isClosed() {
286 return fmt.Errorf("downstream connection already closed")
287 }
288
289 if u := dc.user; u != nil {
290 u.lock.Lock()
291 for i := range u.downstreamConns {
292 if u.downstreamConns[i] == dc {
293 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
294 break
295 }
296 }
297 u.lock.Unlock()
298 }
299
300 close(dc.closed)
301 return nil
302}
303
304func (dc *downstreamConn) SendMessage(msg *irc.Message) {
305 dc.outgoing <- msg
306}
307
308func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
309 switch msg.Command {
310 case "QUIT":
311 return dc.Close()
312 default:
313 if dc.registered {
314 return dc.handleMessageRegistered(msg)
315 } else {
316 return dc.handleMessageUnregistered(msg)
317 }
318 }
319}
320
321func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
322 switch msg.Command {
323 case "NICK":
324 if err := parseMessageParams(msg, &dc.nick); err != nil {
325 return err
326 }
327 case "USER":
328 var username string
329 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
330 return err
331 }
332 dc.rawUsername = username
333 case "PASS":
334 if err := parseMessageParams(msg, &dc.password); err != nil {
335 return err
336 }
337 case "CAP":
338 var subCmd string
339 if err := parseMessageParams(msg, &subCmd); err != nil {
340 return err
341 }
342 subCmd = strings.ToUpper(subCmd)
343 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
344 return err
345 }
346 default:
347 dc.logger.Printf("unhandled message: %v", msg)
348 return newUnknownCommandError(msg.Command)
349 }
350 if dc.rawUsername != "" && dc.nick != "" && !dc.negociatingCaps {
351 return dc.register()
352 }
353 return nil
354}
355
356func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
357 replyTo := dc.nick
358 if !dc.registered {
359 replyTo = "*"
360 }
361
362 switch cmd {
363 case "LS":
364 if len(args) > 0 {
365 var err error
366 if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
367 return err
368 }
369 }
370
371 var caps []string
372 /*if dc.capVersion >= 302 {
373 caps = append(caps, "sasl=PLAIN")
374 } else {
375 caps = append(caps, "sasl")
376 }*/
377
378 // TODO: multi-line replies
379 dc.SendMessage(&irc.Message{
380 Prefix: dc.srv.prefix(),
381 Command: "CAP",
382 Params: []string{replyTo, "LS", strings.Join(caps, " ")},
383 })
384
385 if !dc.registered {
386 dc.negociatingCaps = true
387 }
388 case "LIST":
389 var caps []string
390 for name := range dc.caps {
391 caps = append(caps, name)
392 }
393
394 // TODO: multi-line replies
395 dc.SendMessage(&irc.Message{
396 Prefix: dc.srv.prefix(),
397 Command: "CAP",
398 Params: []string{replyTo, "LIST", strings.Join(caps, " ")},
399 })
400 case "REQ":
401 if len(args) == 0 {
402 return ircError{&irc.Message{
403 Command: err_invalidcapcmd,
404 Params: []string{replyTo, cmd, "Missing argument in CAP REQ command"},
405 }}
406 }
407
408 caps := strings.Fields(args[0])
409 ack := true
410 for _, name := range caps {
411 name = strings.ToLower(name)
412 enable := !strings.HasPrefix(name, "-")
413 if !enable {
414 name = strings.TrimPrefix(name, "-")
415 }
416
417 enabled := dc.caps[name]
418 if enable == enabled {
419 continue
420 }
421
422 switch name {
423 /*case "sasl":
424 dc.caps[name] = enable*/
425 default:
426 ack = false
427 }
428 }
429
430 reply := "NAK"
431 if ack {
432 reply = "ACK"
433 }
434 dc.SendMessage(&irc.Message{
435 Prefix: dc.srv.prefix(),
436 Command: "CAP",
437 Params: []string{replyTo, reply, args[0]},
438 })
439 case "END":
440 dc.negociatingCaps = false
441 default:
442 return ircError{&irc.Message{
443 Command: err_invalidcapcmd,
444 Params: []string{replyTo, cmd, "Unknown CAP command"},
445 }}
446 }
447 return nil
448}
449
// sanityCheckServer checks that a TLS connection to addr can be established
// within 30 seconds. The connection is closed right away; the dial error, or
// failing that the close error, is returned.
func sanityCheckServer(addr string) error {
	dialer := &net.Dialer{Timeout: 30 * time.Second}
	conn, dialErr := tls.DialWithDialer(dialer, "tcp", addr, nil)
	if dialErr != nil {
		return dialErr
	}
	closeErr := conn.Close()
	return closeErr
}
458
459func (dc *downstreamConn) register() error {
460 username := dc.rawUsername
461 var networkName string
462 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
463 networkName = username[i+1:]
464 }
465 if i := strings.IndexAny(username, "/@"); i >= 0 {
466 username = username[:i]
467 }
468 dc.username = "~" + username
469
470 password := dc.password
471 dc.password = ""
472
473 u := dc.srv.getUser(username)
474 if u == nil {
475 dc.logger.Printf("failed authentication for %q: unknown username", username)
476 return errAuthFailed
477 }
478
479 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
480 if err != nil {
481 dc.logger.Printf("failed authentication for %q: %v", username, err)
482 return errAuthFailed
483 }
484
485 var network *network
486 if networkName != "" {
487 network = u.getNetwork(networkName)
488 if network == nil {
489 addr := networkName
490 if !strings.ContainsRune(addr, ':') {
491 addr = addr + ":6697"
492 }
493
494 dc.logger.Printf("trying to connect to new network %q", addr)
495 if err := sanityCheckServer(addr); err != nil {
496 dc.logger.Printf("failed to connect to %q: %v", addr, err)
497 return ircError{&irc.Message{
498 Command: irc.ERR_PASSWDMISMATCH,
499 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
500 }}
501 }
502
503 dc.logger.Printf("auto-saving network %q", networkName)
504 network, err = u.createNetwork(networkName, dc.nick)
505 if err != nil {
506 return err
507 }
508 }
509 }
510
511 dc.registered = true
512 dc.user = u
513 dc.network = network
514
515 u.lock.Lock()
516 firstDownstream := len(u.downstreamConns) == 0
517 u.downstreamConns = append(u.downstreamConns, dc)
518 u.lock.Unlock()
519
520 dc.SendMessage(&irc.Message{
521 Prefix: dc.srv.prefix(),
522 Command: irc.RPL_WELCOME,
523 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
524 })
525 dc.SendMessage(&irc.Message{
526 Prefix: dc.srv.prefix(),
527 Command: irc.RPL_YOURHOST,
528 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
529 })
530 dc.SendMessage(&irc.Message{
531 Prefix: dc.srv.prefix(),
532 Command: irc.RPL_CREATED,
533 Params: []string{dc.nick, "Who cares when the server was created?"},
534 })
535 dc.SendMessage(&irc.Message{
536 Prefix: dc.srv.prefix(),
537 Command: irc.RPL_MYINFO,
538 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
539 })
540 // TODO: RPL_ISUPPORT
541 dc.SendMessage(&irc.Message{
542 Prefix: dc.srv.prefix(),
543 Command: irc.ERR_NOMOTD,
544 Params: []string{dc.nick, "No MOTD"},
545 })
546
547 dc.forEachUpstream(func(uc *upstreamConn) {
548 for _, ch := range uc.channels {
549 if ch.complete {
550 forwardChannel(dc, ch)
551 }
552 }
553
554 historyName := dc.username
555
556 var seqPtr *uint64
557 if firstDownstream {
558 uc.lock.Lock()
559 seq, ok := uc.history[historyName]
560 uc.lock.Unlock()
561 if ok {
562 seqPtr = &seq
563 }
564 }
565
566 consumer, ch := uc.ring.NewConsumer(seqPtr)
567 go func() {
568 for {
569 var closed bool
570 select {
571 case <-ch:
572 dc.ringMessages <- ringMessage{consumer, uc}
573 case <-dc.closed:
574 closed = true
575 }
576 if closed {
577 break
578 }
579 }
580
581 seq := consumer.Close()
582
583 dc.user.lock.Lock()
584 lastDownstream := len(dc.user.downstreamConns) == 0
585 dc.user.lock.Unlock()
586
587 if lastDownstream {
588 uc.lock.Lock()
589 uc.history[historyName] = seq
590 uc.lock.Unlock()
591 }
592 }()
593 })
594
595 return nil
596}
597
598func (dc *downstreamConn) runUntilRegistered() error {
599 for !dc.registered {
600 msg, err := dc.irc.ReadMessage()
601 if err != nil {
602 return fmt.Errorf("failed to read IRC command: %v", err)
603 }
604
605 err = dc.handleMessage(msg)
606 if ircErr, ok := err.(ircError); ok {
607 ircErr.Message.Prefix = dc.srv.prefix()
608 dc.SendMessage(ircErr.Message)
609 } else if err != nil {
610 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
611 }
612 }
613
614 return nil
615}
616
617func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
618 switch msg.Command {
619 case "PING":
620 dc.SendMessage(&irc.Message{
621 Prefix: dc.srv.prefix(),
622 Command: "PONG",
623 Params: msg.Params,
624 })
625 return nil
626 case "USER":
627 return ircError{&irc.Message{
628 Command: irc.ERR_ALREADYREGISTERED,
629 Params: []string{dc.nick, "You may not reregister"},
630 }}
631 case "NICK":
632 var nick string
633 if err := parseMessageParams(msg, &nick); err != nil {
634 return err
635 }
636
637 var err error
638 dc.forEachNetwork(func(n *network) {
639 if err != nil {
640 return
641 }
642 n.Nick = nick
643 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
644 })
645 if err != nil {
646 return err
647 }
648
649 dc.forEachUpstream(func(uc *upstreamConn) {
650 uc.SendMessage(msg)
651 })
652 case "JOIN", "PART":
653 var name string
654 if err := parseMessageParams(msg, &name); err != nil {
655 return err
656 }
657
658 uc, upstreamName, err := dc.unmarshalChannel(name)
659 if err != nil {
660 return ircError{&irc.Message{
661 Command: irc.ERR_NOSUCHCHANNEL,
662 Params: []string{name, err.Error()},
663 }}
664 }
665
666 uc.SendMessage(&irc.Message{
667 Command: msg.Command,
668 Params: []string{upstreamName},
669 })
670
671 switch msg.Command {
672 case "JOIN":
673 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
674 Name: upstreamName,
675 })
676 if err != nil {
677 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
678 }
679 case "PART":
680 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
681 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
682 }
683 }
684 case "MODE":
685 if msg.Prefix == nil {
686 return fmt.Errorf("missing prefix")
687 }
688
689 var name string
690 if err := parseMessageParams(msg, &name); err != nil {
691 return err
692 }
693
694 var modeStr string
695 if len(msg.Params) > 1 {
696 modeStr = msg.Params[1]
697 }
698
699 if msg.Prefix.Name != name {
700 uc, upstreamName, err := dc.unmarshalChannel(name)
701 if err != nil {
702 return err
703 }
704
705 if modeStr != "" {
706 uc.SendMessage(&irc.Message{
707 Command: "MODE",
708 Params: []string{upstreamName, modeStr},
709 })
710 } else {
711 ch, ok := uc.channels[upstreamName]
712 if !ok {
713 return ircError{&irc.Message{
714 Command: irc.ERR_NOSUCHCHANNEL,
715 Params: []string{name, "No such channel"},
716 }}
717 }
718
719 dc.SendMessage(&irc.Message{
720 Prefix: dc.srv.prefix(),
721 Command: irc.RPL_CHANNELMODEIS,
722 Params: []string{name, string(ch.modes)},
723 })
724 }
725 } else {
726 if name != dc.nick {
727 return ircError{&irc.Message{
728 Command: irc.ERR_USERSDONTMATCH,
729 Params: []string{dc.nick, "Cannot change mode for other users"},
730 }}
731 }
732
733 if modeStr != "" {
734 dc.forEachUpstream(func(uc *upstreamConn) {
735 uc.SendMessage(&irc.Message{
736 Command: "MODE",
737 Params: []string{uc.nick, modeStr},
738 })
739 })
740 } else {
741 dc.SendMessage(&irc.Message{
742 Prefix: dc.srv.prefix(),
743 Command: irc.RPL_UMODEIS,
744 Params: []string{""}, // TODO
745 })
746 }
747 }
748 case "PRIVMSG":
749 var targetsStr, text string
750 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
751 return err
752 }
753
754 for _, name := range strings.Split(targetsStr, ",") {
755 uc, upstreamName, err := dc.unmarshalChannel(name)
756 if err != nil {
757 return err
758 }
759
760 if upstreamName == "NickServ" {
761 dc.handleNickServPRIVMSG(uc, text)
762 }
763
764 uc.SendMessage(&irc.Message{
765 Command: "PRIVMSG",
766 Params: []string{upstreamName, text},
767 })
768
769 dc.lock.Lock()
770 dc.ourMessages[msg] = struct{}{}
771 dc.lock.Unlock()
772
773 uc.ring.Produce(msg)
774 }
775 default:
776 dc.logger.Printf("unhandled message: %v", msg)
777 return newUnknownCommandError(msg.Command)
778 }
779 return nil
780}
781
782func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
783 username, password, ok := parseNickServCredentials(text, uc.nick)
784 if !ok {
785 return
786 }
787
788 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
789 n := uc.network
790 n.SASL.Mechanism = "PLAIN"
791 n.SASL.Plain.Username = username
792 n.SASL.Plain.Password = password
793 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
794 dc.logger.Printf("failed to save NickServ credentials: %v", err)
795 }
796}
797
// parseNickServCredentials extracts account credentials from the text of a
// message addressed to NickServ. nick is the nickname used when the command
// does not name an account.
//
// Recognised commands (case-insensitive):
//
//	REGISTER <password> [...]       username is nick
//	IDENTIFY <password>             username is nick
//	IDENTIFY <username> <password>  username as given
//
// ok is false when the text is not one of these commands or lacks the
// password argument; username and password are then empty.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	params := fields[1:]

	switch strings.ToUpper(fields[0]) {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			// Single argument: it is the password, not the username.
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		// Any other NickServ command carries no credentials.
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.