source: code/trunk/downstream.go@ 110

Last change on this file since 110 was 110, checked in by contact, 5 years ago

Log downstream messages before registration

File size: 17.6 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "sync"
11 "time"
12
13 "golang.org/x/crypto/bcrypt"
14 "gopkg.in/irc.v3"
15)
16
17type ircError struct {
18 Message *irc.Message
19}
20
21func (err ircError) Error() string {
22 return err.Message.String()
23}
24
25func newUnknownCommandError(cmd string) ircError {
26 return ircError{&irc.Message{
27 Command: irc.ERR_UNKNOWNCOMMAND,
28 Params: []string{
29 "*",
30 cmd,
31 "Unknown command",
32 },
33 }}
34}
35
36func newNeedMoreParamsError(cmd string) ircError {
37 return ircError{&irc.Message{
38 Command: irc.ERR_NEEDMOREPARAMS,
39 Params: []string{
40 "*",
41 cmd,
42 "Not enough parameters",
43 },
44 }}
45}
46
47var errAuthFailed = ircError{&irc.Message{
48 Command: irc.ERR_PASSWDMISMATCH,
49 Params: []string{"*", "Invalid username or password"},
50}}
51
52type ringMessage struct {
53 consumer *RingConsumer
54 upstreamConn *upstreamConn
55}
56
57type downstreamConn struct {
58 net net.Conn
59 irc *irc.Conn
60 srv *Server
61 logger Logger
62 outgoing chan *irc.Message
63 ringMessages chan ringMessage
64 closed chan struct{}
65
66 registered bool
67 user *user
68 nick string
69 username string
70 rawUsername string
71 realname string
72 password string // empty after authentication
73 network *network // can be nil
74
75 negociatingCaps bool
76 capVersion int
77 caps map[string]bool
78
79 lock sync.Mutex
80 ourMessages map[*irc.Message]struct{}
81}
82
83func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
84 dc := &downstreamConn{
85 net: netConn,
86 irc: irc.NewConn(netConn),
87 srv: srv,
88 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
89 outgoing: make(chan *irc.Message, 64),
90 ringMessages: make(chan ringMessage),
91 closed: make(chan struct{}),
92 caps: make(map[string]bool),
93 ourMessages: make(map[*irc.Message]struct{}),
94 }
95
96 go func() {
97 if err := dc.writeMessages(); err != nil {
98 dc.logger.Printf("failed to write message: %v", err)
99 }
100 if err := dc.net.Close(); err != nil {
101 dc.logger.Printf("failed to close connection: %v", err)
102 } else {
103 dc.logger.Printf("connection closed")
104 }
105 }()
106
107 return dc
108}
109
110func (dc *downstreamConn) prefix() *irc.Prefix {
111 return &irc.Prefix{
112 Name: dc.nick,
113 User: dc.username,
114 // TODO: fill the host?
115 }
116}
117
118func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
119 return name
120}
121
122func (dc *downstreamConn) forEachNetwork(f func(*network)) {
123 if dc.network != nil {
124 f(dc.network)
125 } else {
126 dc.user.forEachNetwork(f)
127 }
128}
129
130func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
131 dc.user.forEachUpstream(func(uc *upstreamConn) {
132 if dc.network != nil && uc.network != dc.network {
133 return
134 }
135 f(uc)
136 })
137}
138
139// upstream returns the upstream connection, if any. If there are zero or if
140// there are multiple upstream connections, it returns nil.
141func (dc *downstreamConn) upstream() *upstreamConn {
142 if dc.network == nil {
143 return nil
144 }
145
146 var upstream *upstreamConn
147 dc.forEachUpstream(func(uc *upstreamConn) {
148 upstream = uc
149 })
150 return upstream
151}
152
153func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
154 if uc := dc.upstream(); uc != nil {
155 return uc, name, nil
156 }
157
158 // TODO: extract network name from channel name if dc.upstream == nil
159 var channel *upstreamChannel
160 var err error
161 dc.forEachUpstream(func(uc *upstreamConn) {
162 if err != nil {
163 return
164 }
165 if ch, ok := uc.channels[name]; ok {
166 if channel != nil {
167 err = fmt.Errorf("ambiguous channel name %q", name)
168 } else {
169 channel = ch
170 }
171 }
172 })
173 if channel == nil {
174 return nil, "", ircError{&irc.Message{
175 Command: irc.ERR_NOSUCHCHANNEL,
176 Params: []string{name, "No such channel"},
177 }}
178 }
179 return channel.conn, channel.Name, nil
180}
181
182func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
183 if nick == uc.nick {
184 return dc.nick
185 }
186 return nick
187}
188
189func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
190 if prefix.Name == uc.nick {
191 return dc.prefix()
192 }
193 return prefix
194}
195
196func (dc *downstreamConn) isClosed() bool {
197 select {
198 case <-dc.closed:
199 return true
200 default:
201 return false
202 }
203}
204
205func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
206 dc.logger.Printf("new connection")
207
208 for {
209 msg, err := dc.irc.ReadMessage()
210 if err == io.EOF {
211 break
212 } else if err != nil {
213 return fmt.Errorf("failed to read IRC command: %v", err)
214 }
215
216 if dc.srv.Debug {
217 dc.logger.Printf("received: %v", msg)
218 }
219
220 ch <- downstreamIncomingMessage{msg, dc}
221 }
222
223 return nil
224}
225
226func (dc *downstreamConn) writeMessages() error {
227 for {
228 var err error
229 var closed bool
230 select {
231 case msg := <-dc.outgoing:
232 if dc.srv.Debug {
233 dc.logger.Printf("sent: %v", msg)
234 }
235 err = dc.irc.WriteMessage(msg)
236 case ringMessage := <-dc.ringMessages:
237 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
238 for {
239 msg := consumer.Peek()
240 if msg == nil {
241 break
242 }
243
244 dc.lock.Lock()
245 _, ours := dc.ourMessages[msg]
246 delete(dc.ourMessages, msg)
247 dc.lock.Unlock()
248 if ours {
249 // The message comes from our connection, don't echo it
250 // back
251 continue
252 }
253
254 msg = msg.Copy()
255 switch msg.Command {
256 case "PRIVMSG":
257 // TODO: detect whether it's a user or a channel
258 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
259 default:
260 panic("expected to consume a PRIVMSG message")
261 }
262 if dc.srv.Debug {
263 dc.logger.Printf("sent: %v", msg)
264 }
265 err = dc.irc.WriteMessage(msg)
266 if err != nil {
267 break
268 }
269 consumer.Consume()
270 }
271 case <-dc.closed:
272 closed = true
273 }
274 if err != nil {
275 return err
276 }
277 if closed {
278 break
279 }
280 }
281 return nil
282}
283
284func (dc *downstreamConn) Close() error {
285 if dc.isClosed() {
286 return fmt.Errorf("downstream connection already closed")
287 }
288
289 if u := dc.user; u != nil {
290 u.lock.Lock()
291 for i := range u.downstreamConns {
292 if u.downstreamConns[i] == dc {
293 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
294 break
295 }
296 }
297 u.lock.Unlock()
298 }
299
300 close(dc.closed)
301 return nil
302}
303
304func (dc *downstreamConn) SendMessage(msg *irc.Message) {
305 dc.outgoing <- msg
306}
307
308func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
309 switch msg.Command {
310 case "QUIT":
311 return dc.Close()
312 default:
313 if dc.registered {
314 return dc.handleMessageRegistered(msg)
315 } else {
316 return dc.handleMessageUnregistered(msg)
317 }
318 }
319}
320
321func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
322 switch msg.Command {
323 case "NICK":
324 if err := parseMessageParams(msg, &dc.nick); err != nil {
325 return err
326 }
327 case "USER":
328 var username string
329 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
330 return err
331 }
332 dc.rawUsername = username
333 case "PASS":
334 if err := parseMessageParams(msg, &dc.password); err != nil {
335 return err
336 }
337 case "CAP":
338 var subCmd string
339 if err := parseMessageParams(msg, &subCmd); err != nil {
340 return err
341 }
342 subCmd = strings.ToUpper(subCmd)
343 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
344 return err
345 }
346 default:
347 dc.logger.Printf("unhandled message: %v", msg)
348 return newUnknownCommandError(msg.Command)
349 }
350 if dc.rawUsername != "" && dc.nick != "" && !dc.negociatingCaps {
351 return dc.register()
352 }
353 return nil
354}
355
356func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
357 replyTo := dc.nick
358 if !dc.registered {
359 replyTo = "*"
360 }
361
362 switch cmd {
363 case "LS":
364 if len(args) > 0 {
365 var err error
366 if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
367 return err
368 }
369 }
370
371 var caps []string
372 /*if dc.capVersion >= 302 {
373 caps = append(caps, "sasl=PLAIN")
374 } else {
375 caps = append(caps, "sasl")
376 }*/
377
378 // TODO: multi-line replies
379 dc.SendMessage(&irc.Message{
380 Prefix: dc.srv.prefix(),
381 Command: "CAP",
382 Params: []string{replyTo, "LS", strings.Join(caps, " ")},
383 })
384
385 if !dc.registered {
386 dc.negociatingCaps = true
387 }
388 case "LIST":
389 var caps []string
390 for name := range dc.caps {
391 caps = append(caps, name)
392 }
393
394 // TODO: multi-line replies
395 dc.SendMessage(&irc.Message{
396 Prefix: dc.srv.prefix(),
397 Command: "CAP",
398 Params: []string{replyTo, "LIST", strings.Join(caps, " ")},
399 })
400 case "REQ":
401 if len(args) == 0 {
402 return ircError{&irc.Message{
403 Command: err_invalidcapcmd,
404 Params: []string{replyTo, cmd, "Missing argument in CAP REQ command"},
405 }}
406 }
407
408 caps := strings.Fields(args[0])
409 ack := true
410 for _, name := range caps {
411 name = strings.ToLower(name)
412 enable := !strings.HasPrefix(name, "-")
413 if !enable {
414 name = strings.TrimPrefix(name, "-")
415 }
416
417 enabled := dc.caps[name]
418 if enable == enabled {
419 continue
420 }
421
422 switch name {
423 /*case "sasl":
424 dc.caps[name] = enable*/
425 default:
426 ack = false
427 }
428 }
429
430 reply := "NAK"
431 if ack {
432 reply = "ACK"
433 }
434 dc.SendMessage(&irc.Message{
435 Prefix: dc.srv.prefix(),
436 Command: "CAP",
437 Params: []string{replyTo, reply, args[0]},
438 })
439 case "END":
440 dc.negociatingCaps = false
441 default:
442 return ircError{&irc.Message{
443 Command: err_invalidcapcmd,
444 Params: []string{replyTo, cmd, "Unknown CAP command"},
445 }}
446 }
447 return nil
448}
449
// sanityCheckServer checks that a TLS connection can be established to addr
// (host:port) within 30 seconds. The connection is closed right away; the
// dial error, or the error from closing, is returned.
func sanityCheckServer(addr string) error {
	dialer := &net.Dialer{Timeout: 30 * time.Second}
	conn, err := tls.DialWithDialer(dialer, "tcp", addr, nil)
	if err != nil {
		return err
	}
	return conn.Close()
}
458
459func (dc *downstreamConn) register() error {
460 username := dc.rawUsername
461 var networkName string
462 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
463 networkName = username[i+1:]
464 }
465 if i := strings.IndexAny(username, "/@"); i >= 0 {
466 username = username[:i]
467 }
468 dc.username = "~" + username
469
470 password := dc.password
471 dc.password = ""
472
473 u := dc.srv.getUser(username)
474 if u == nil {
475 dc.logger.Printf("failed authentication for %q: unknown username", username)
476 return errAuthFailed
477 }
478
479 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
480 if err != nil {
481 dc.logger.Printf("failed authentication for %q: %v", username, err)
482 return errAuthFailed
483 }
484
485 var network *network
486 if networkName != "" {
487 network = u.getNetwork(networkName)
488 if network == nil {
489 addr := networkName
490 if !strings.ContainsRune(addr, ':') {
491 addr = addr + ":6697"
492 }
493
494 dc.logger.Printf("trying to connect to new network %q", addr)
495 if err := sanityCheckServer(addr); err != nil {
496 dc.logger.Printf("failed to connect to %q: %v", addr, err)
497 return ircError{&irc.Message{
498 Command: irc.ERR_PASSWDMISMATCH,
499 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
500 }}
501 }
502
503 dc.logger.Printf("auto-saving network %q", networkName)
504 network, err = u.createNetwork(networkName, dc.nick)
505 if err != nil {
506 return err
507 }
508 }
509 }
510
511 dc.registered = true
512 dc.user = u
513 dc.network = network
514
515 u.lock.Lock()
516 firstDownstream := len(u.downstreamConns) == 0
517 u.downstreamConns = append(u.downstreamConns, dc)
518 u.lock.Unlock()
519
520 dc.SendMessage(&irc.Message{
521 Prefix: dc.srv.prefix(),
522 Command: irc.RPL_WELCOME,
523 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
524 })
525 dc.SendMessage(&irc.Message{
526 Prefix: dc.srv.prefix(),
527 Command: irc.RPL_YOURHOST,
528 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
529 })
530 dc.SendMessage(&irc.Message{
531 Prefix: dc.srv.prefix(),
532 Command: irc.RPL_CREATED,
533 Params: []string{dc.nick, "Who cares when the server was created?"},
534 })
535 dc.SendMessage(&irc.Message{
536 Prefix: dc.srv.prefix(),
537 Command: irc.RPL_MYINFO,
538 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
539 })
540 // TODO: RPL_ISUPPORT
541 dc.SendMessage(&irc.Message{
542 Prefix: dc.srv.prefix(),
543 Command: irc.ERR_NOMOTD,
544 Params: []string{dc.nick, "No MOTD"},
545 })
546
547 dc.forEachUpstream(func(uc *upstreamConn) {
548 for _, ch := range uc.channels {
549 if ch.complete {
550 forwardChannel(dc, ch)
551 }
552 }
553
554 historyName := dc.username
555
556 var seqPtr *uint64
557 if firstDownstream {
558 uc.lock.Lock()
559 seq, ok := uc.history[historyName]
560 uc.lock.Unlock()
561 if ok {
562 seqPtr = &seq
563 }
564 }
565
566 consumer, ch := uc.ring.NewConsumer(seqPtr)
567 go func() {
568 for {
569 var closed bool
570 select {
571 case <-ch:
572 dc.ringMessages <- ringMessage{consumer, uc}
573 case <-dc.closed:
574 closed = true
575 }
576 if closed {
577 break
578 }
579 }
580
581 seq := consumer.Close()
582
583 dc.user.lock.Lock()
584 lastDownstream := len(dc.user.downstreamConns) == 0
585 dc.user.lock.Unlock()
586
587 if lastDownstream {
588 uc.lock.Lock()
589 uc.history[historyName] = seq
590 uc.lock.Unlock()
591 }
592 }()
593 })
594
595 return nil
596}
597
598func (dc *downstreamConn) runUntilRegistered() error {
599 for !dc.registered {
600 msg, err := dc.irc.ReadMessage()
601 if err != nil {
602 return fmt.Errorf("failed to read IRC command: %v", err)
603 }
604
605 if dc.srv.Debug {
606 dc.logger.Printf("received: %v", msg)
607 }
608
609 err = dc.handleMessage(msg)
610 if ircErr, ok := err.(ircError); ok {
611 ircErr.Message.Prefix = dc.srv.prefix()
612 dc.SendMessage(ircErr.Message)
613 } else if err != nil {
614 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
615 }
616 }
617
618 return nil
619}
620
621func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
622 switch msg.Command {
623 case "PING":
624 dc.SendMessage(&irc.Message{
625 Prefix: dc.srv.prefix(),
626 Command: "PONG",
627 Params: msg.Params,
628 })
629 return nil
630 case "USER":
631 return ircError{&irc.Message{
632 Command: irc.ERR_ALREADYREGISTERED,
633 Params: []string{dc.nick, "You may not reregister"},
634 }}
635 case "NICK":
636 var nick string
637 if err := parseMessageParams(msg, &nick); err != nil {
638 return err
639 }
640
641 var err error
642 dc.forEachNetwork(func(n *network) {
643 if err != nil {
644 return
645 }
646 n.Nick = nick
647 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
648 })
649 if err != nil {
650 return err
651 }
652
653 dc.forEachUpstream(func(uc *upstreamConn) {
654 uc.SendMessage(msg)
655 })
656 case "JOIN", "PART":
657 var name string
658 if err := parseMessageParams(msg, &name); err != nil {
659 return err
660 }
661
662 uc, upstreamName, err := dc.unmarshalChannel(name)
663 if err != nil {
664 return ircError{&irc.Message{
665 Command: irc.ERR_NOSUCHCHANNEL,
666 Params: []string{name, err.Error()},
667 }}
668 }
669
670 uc.SendMessage(&irc.Message{
671 Command: msg.Command,
672 Params: []string{upstreamName},
673 })
674
675 switch msg.Command {
676 case "JOIN":
677 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
678 Name: upstreamName,
679 })
680 if err != nil {
681 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
682 }
683 case "PART":
684 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
685 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
686 }
687 }
688 case "MODE":
689 if msg.Prefix == nil {
690 return fmt.Errorf("missing prefix")
691 }
692
693 var name string
694 if err := parseMessageParams(msg, &name); err != nil {
695 return err
696 }
697
698 var modeStr string
699 if len(msg.Params) > 1 {
700 modeStr = msg.Params[1]
701 }
702
703 if msg.Prefix.Name != name {
704 uc, upstreamName, err := dc.unmarshalChannel(name)
705 if err != nil {
706 return err
707 }
708
709 if modeStr != "" {
710 uc.SendMessage(&irc.Message{
711 Command: "MODE",
712 Params: []string{upstreamName, modeStr},
713 })
714 } else {
715 ch, ok := uc.channels[upstreamName]
716 if !ok {
717 return ircError{&irc.Message{
718 Command: irc.ERR_NOSUCHCHANNEL,
719 Params: []string{name, "No such channel"},
720 }}
721 }
722
723 dc.SendMessage(&irc.Message{
724 Prefix: dc.srv.prefix(),
725 Command: irc.RPL_CHANNELMODEIS,
726 Params: []string{name, string(ch.modes)},
727 })
728 }
729 } else {
730 if name != dc.nick {
731 return ircError{&irc.Message{
732 Command: irc.ERR_USERSDONTMATCH,
733 Params: []string{dc.nick, "Cannot change mode for other users"},
734 }}
735 }
736
737 if modeStr != "" {
738 dc.forEachUpstream(func(uc *upstreamConn) {
739 uc.SendMessage(&irc.Message{
740 Command: "MODE",
741 Params: []string{uc.nick, modeStr},
742 })
743 })
744 } else {
745 dc.SendMessage(&irc.Message{
746 Prefix: dc.srv.prefix(),
747 Command: irc.RPL_UMODEIS,
748 Params: []string{""}, // TODO
749 })
750 }
751 }
752 case "PRIVMSG":
753 var targetsStr, text string
754 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
755 return err
756 }
757
758 for _, name := range strings.Split(targetsStr, ",") {
759 uc, upstreamName, err := dc.unmarshalChannel(name)
760 if err != nil {
761 return err
762 }
763
764 if upstreamName == "NickServ" {
765 dc.handleNickServPRIVMSG(uc, text)
766 }
767
768 uc.SendMessage(&irc.Message{
769 Command: "PRIVMSG",
770 Params: []string{upstreamName, text},
771 })
772
773 dc.lock.Lock()
774 dc.ourMessages[msg] = struct{}{}
775 dc.lock.Unlock()
776
777 uc.ring.Produce(msg)
778 }
779 default:
780 dc.logger.Printf("unhandled message: %v", msg)
781 return newUnknownCommandError(msg.Command)
782 }
783 return nil
784}
785
786func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
787 username, password, ok := parseNickServCredentials(text, uc.nick)
788 if !ok {
789 return
790 }
791
792 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
793 n := uc.network
794 n.SASL.Mechanism = "PLAIN"
795 n.SASL.Plain.Username = username
796 n.SASL.Plain.Password = password
797 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
798 dc.logger.Printf("failed to save NickServ credentials: %v", err)
799 }
800}
801
// parseNickServCredentials extracts account credentials from the text of a
// message sent to NickServ. nick is the user's current nick, used when the
// command does not name an account. Recognised forms (command is
// case-insensitive):
//
//	REGISTER <password> [...]          -> (nick, password)
//	IDENTIFY <password>                -> (nick, password)
//	IDENTIFY <account> <password> [...] -> (account, password)
//
// ok is false for any other text, so unrelated NickServ commands (e.g.
// "HELP REGISTER") never produce empty credentials.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			// Only a password: it applies to the current nick.
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		return "", "", false
	}
	return username, password, true
}
Note: See TracBrowser for help on using the repository browser.