source: code/trunk/downstream.go@ 111

Last change on this file since 111 was 111, checked in by contact, 5 years ago

Allow CAP command when registered

File size: 17.8 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "sync"
11 "time"
12
13 "golang.org/x/crypto/bcrypt"
14 "gopkg.in/irc.v3"
15)
16
17type ircError struct {
18 Message *irc.Message
19}
20
21func (err ircError) Error() string {
22 return err.Message.String()
23}
24
25func newUnknownCommandError(cmd string) ircError {
26 return ircError{&irc.Message{
27 Command: irc.ERR_UNKNOWNCOMMAND,
28 Params: []string{
29 "*",
30 cmd,
31 "Unknown command",
32 },
33 }}
34}
35
36func newNeedMoreParamsError(cmd string) ircError {
37 return ircError{&irc.Message{
38 Command: irc.ERR_NEEDMOREPARAMS,
39 Params: []string{
40 "*",
41 cmd,
42 "Not enough parameters",
43 },
44 }}
45}
46
47var errAuthFailed = ircError{&irc.Message{
48 Command: irc.ERR_PASSWDMISMATCH,
49 Params: []string{"*", "Invalid username or password"},
50}}
51
52type ringMessage struct {
53 consumer *RingConsumer
54 upstreamConn *upstreamConn
55}
56
57type downstreamConn struct {
58 net net.Conn
59 irc *irc.Conn
60 srv *Server
61 logger Logger
62 outgoing chan *irc.Message
63 ringMessages chan ringMessage
64 closed chan struct{}
65
66 registered bool
67 user *user
68 nick string
69 username string
70 rawUsername string
71 realname string
72 password string // empty after authentication
73 network *network // can be nil
74
75 negociatingCaps bool
76 capVersion int
77 caps map[string]bool
78
79 lock sync.Mutex
80 ourMessages map[*irc.Message]struct{}
81}
82
83func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
84 dc := &downstreamConn{
85 net: netConn,
86 irc: irc.NewConn(netConn),
87 srv: srv,
88 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
89 outgoing: make(chan *irc.Message, 64),
90 ringMessages: make(chan ringMessage),
91 closed: make(chan struct{}),
92 caps: make(map[string]bool),
93 ourMessages: make(map[*irc.Message]struct{}),
94 }
95
96 go func() {
97 if err := dc.writeMessages(); err != nil {
98 dc.logger.Printf("failed to write message: %v", err)
99 }
100 if err := dc.net.Close(); err != nil {
101 dc.logger.Printf("failed to close connection: %v", err)
102 } else {
103 dc.logger.Printf("connection closed")
104 }
105 }()
106
107 return dc
108}
109
110func (dc *downstreamConn) prefix() *irc.Prefix {
111 return &irc.Prefix{
112 Name: dc.nick,
113 User: dc.username,
114 // TODO: fill the host?
115 }
116}
117
118func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
119 return name
120}
121
122func (dc *downstreamConn) forEachNetwork(f func(*network)) {
123 if dc.network != nil {
124 f(dc.network)
125 } else {
126 dc.user.forEachNetwork(f)
127 }
128}
129
130func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
131 dc.user.forEachUpstream(func(uc *upstreamConn) {
132 if dc.network != nil && uc.network != dc.network {
133 return
134 }
135 f(uc)
136 })
137}
138
139// upstream returns the upstream connection, if any. If there are zero or if
140// there are multiple upstream connections, it returns nil.
141func (dc *downstreamConn) upstream() *upstreamConn {
142 if dc.network == nil {
143 return nil
144 }
145
146 var upstream *upstreamConn
147 dc.forEachUpstream(func(uc *upstreamConn) {
148 upstream = uc
149 })
150 return upstream
151}
152
153func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
154 if uc := dc.upstream(); uc != nil {
155 return uc, name, nil
156 }
157
158 // TODO: extract network name from channel name if dc.upstream == nil
159 var channel *upstreamChannel
160 var err error
161 dc.forEachUpstream(func(uc *upstreamConn) {
162 if err != nil {
163 return
164 }
165 if ch, ok := uc.channels[name]; ok {
166 if channel != nil {
167 err = fmt.Errorf("ambiguous channel name %q", name)
168 } else {
169 channel = ch
170 }
171 }
172 })
173 if channel == nil {
174 return nil, "", ircError{&irc.Message{
175 Command: irc.ERR_NOSUCHCHANNEL,
176 Params: []string{name, "No such channel"},
177 }}
178 }
179 return channel.conn, channel.Name, nil
180}
181
182func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
183 if nick == uc.nick {
184 return dc.nick
185 }
186 return nick
187}
188
189func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
190 if prefix.Name == uc.nick {
191 return dc.prefix()
192 }
193 return prefix
194}
195
196func (dc *downstreamConn) isClosed() bool {
197 select {
198 case <-dc.closed:
199 return true
200 default:
201 return false
202 }
203}
204
205func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
206 dc.logger.Printf("new connection")
207
208 for {
209 msg, err := dc.irc.ReadMessage()
210 if err == io.EOF {
211 break
212 } else if err != nil {
213 return fmt.Errorf("failed to read IRC command: %v", err)
214 }
215
216 if dc.srv.Debug {
217 dc.logger.Printf("received: %v", msg)
218 }
219
220 ch <- downstreamIncomingMessage{msg, dc}
221 }
222
223 return nil
224}
225
226func (dc *downstreamConn) writeMessages() error {
227 for {
228 var err error
229 var closed bool
230 select {
231 case msg := <-dc.outgoing:
232 if dc.srv.Debug {
233 dc.logger.Printf("sent: %v", msg)
234 }
235 err = dc.irc.WriteMessage(msg)
236 case ringMessage := <-dc.ringMessages:
237 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
238 for {
239 msg := consumer.Peek()
240 if msg == nil {
241 break
242 }
243
244 dc.lock.Lock()
245 _, ours := dc.ourMessages[msg]
246 delete(dc.ourMessages, msg)
247 dc.lock.Unlock()
248 if ours {
249 // The message comes from our connection, don't echo it
250 // back
251 continue
252 }
253
254 msg = msg.Copy()
255 switch msg.Command {
256 case "PRIVMSG":
257 // TODO: detect whether it's a user or a channel
258 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
259 default:
260 panic("expected to consume a PRIVMSG message")
261 }
262 if dc.srv.Debug {
263 dc.logger.Printf("sent: %v", msg)
264 }
265 err = dc.irc.WriteMessage(msg)
266 if err != nil {
267 break
268 }
269 consumer.Consume()
270 }
271 case <-dc.closed:
272 closed = true
273 }
274 if err != nil {
275 return err
276 }
277 if closed {
278 break
279 }
280 }
281 return nil
282}
283
284func (dc *downstreamConn) Close() error {
285 if dc.isClosed() {
286 return fmt.Errorf("downstream connection already closed")
287 }
288
289 if u := dc.user; u != nil {
290 u.lock.Lock()
291 for i := range u.downstreamConns {
292 if u.downstreamConns[i] == dc {
293 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
294 break
295 }
296 }
297 u.lock.Unlock()
298 }
299
300 close(dc.closed)
301 return nil
302}
303
304func (dc *downstreamConn) SendMessage(msg *irc.Message) {
305 dc.outgoing <- msg
306}
307
308func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
309 switch msg.Command {
310 case "QUIT":
311 return dc.Close()
312 default:
313 if dc.registered {
314 return dc.handleMessageRegistered(msg)
315 } else {
316 return dc.handleMessageUnregistered(msg)
317 }
318 }
319}
320
321func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
322 switch msg.Command {
323 case "NICK":
324 if err := parseMessageParams(msg, &dc.nick); err != nil {
325 return err
326 }
327 case "USER":
328 var username string
329 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
330 return err
331 }
332 dc.rawUsername = username
333 case "PASS":
334 if err := parseMessageParams(msg, &dc.password); err != nil {
335 return err
336 }
337 case "CAP":
338 var subCmd string
339 if err := parseMessageParams(msg, &subCmd); err != nil {
340 return err
341 }
342 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
343 return err
344 }
345 default:
346 dc.logger.Printf("unhandled message: %v", msg)
347 return newUnknownCommandError(msg.Command)
348 }
349 if dc.rawUsername != "" && dc.nick != "" && !dc.negociatingCaps {
350 return dc.register()
351 }
352 return nil
353}
354
355func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
356 cmd = strings.ToUpper(cmd)
357
358 replyTo := dc.nick
359 if !dc.registered {
360 replyTo = "*"
361 }
362
363 switch cmd {
364 case "LS":
365 if len(args) > 0 {
366 var err error
367 if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
368 return err
369 }
370 }
371
372 var caps []string
373 /*if dc.capVersion >= 302 {
374 caps = append(caps, "sasl=PLAIN")
375 } else {
376 caps = append(caps, "sasl")
377 }*/
378
379 // TODO: multi-line replies
380 dc.SendMessage(&irc.Message{
381 Prefix: dc.srv.prefix(),
382 Command: "CAP",
383 Params: []string{replyTo, "LS", strings.Join(caps, " ")},
384 })
385
386 if !dc.registered {
387 dc.negociatingCaps = true
388 }
389 case "LIST":
390 var caps []string
391 for name := range dc.caps {
392 caps = append(caps, name)
393 }
394
395 // TODO: multi-line replies
396 dc.SendMessage(&irc.Message{
397 Prefix: dc.srv.prefix(),
398 Command: "CAP",
399 Params: []string{replyTo, "LIST", strings.Join(caps, " ")},
400 })
401 case "REQ":
402 if len(args) == 0 {
403 return ircError{&irc.Message{
404 Command: err_invalidcapcmd,
405 Params: []string{replyTo, cmd, "Missing argument in CAP REQ command"},
406 }}
407 }
408
409 caps := strings.Fields(args[0])
410 ack := true
411 for _, name := range caps {
412 name = strings.ToLower(name)
413 enable := !strings.HasPrefix(name, "-")
414 if !enable {
415 name = strings.TrimPrefix(name, "-")
416 }
417
418 enabled := dc.caps[name]
419 if enable == enabled {
420 continue
421 }
422
423 switch name {
424 /*case "sasl":
425 dc.caps[name] = enable*/
426 default:
427 ack = false
428 }
429 }
430
431 reply := "NAK"
432 if ack {
433 reply = "ACK"
434 }
435 dc.SendMessage(&irc.Message{
436 Prefix: dc.srv.prefix(),
437 Command: "CAP",
438 Params: []string{replyTo, reply, args[0]},
439 })
440 case "END":
441 dc.negociatingCaps = false
442 default:
443 return ircError{&irc.Message{
444 Command: err_invalidcapcmd,
445 Params: []string{replyTo, cmd, "Unknown CAP command"},
446 }}
447 }
448 return nil
449}
450
// sanityCheckServer checks that a TLS connection to addr ("host:port") can be
// established within 30 seconds, then closes it. It returns the dial error or
// the error from closing the connection.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	conn, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err == nil {
		err = conn.Close()
	}
	return err
}
459
460func (dc *downstreamConn) register() error {
461 username := dc.rawUsername
462 var networkName string
463 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
464 networkName = username[i+1:]
465 }
466 if i := strings.IndexAny(username, "/@"); i >= 0 {
467 username = username[:i]
468 }
469 dc.username = "~" + username
470
471 password := dc.password
472 dc.password = ""
473
474 u := dc.srv.getUser(username)
475 if u == nil {
476 dc.logger.Printf("failed authentication for %q: unknown username", username)
477 return errAuthFailed
478 }
479
480 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
481 if err != nil {
482 dc.logger.Printf("failed authentication for %q: %v", username, err)
483 return errAuthFailed
484 }
485
486 var network *network
487 if networkName != "" {
488 network = u.getNetwork(networkName)
489 if network == nil {
490 addr := networkName
491 if !strings.ContainsRune(addr, ':') {
492 addr = addr + ":6697"
493 }
494
495 dc.logger.Printf("trying to connect to new network %q", addr)
496 if err := sanityCheckServer(addr); err != nil {
497 dc.logger.Printf("failed to connect to %q: %v", addr, err)
498 return ircError{&irc.Message{
499 Command: irc.ERR_PASSWDMISMATCH,
500 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
501 }}
502 }
503
504 dc.logger.Printf("auto-saving network %q", networkName)
505 network, err = u.createNetwork(networkName, dc.nick)
506 if err != nil {
507 return err
508 }
509 }
510 }
511
512 dc.registered = true
513 dc.user = u
514 dc.network = network
515
516 u.lock.Lock()
517 firstDownstream := len(u.downstreamConns) == 0
518 u.downstreamConns = append(u.downstreamConns, dc)
519 u.lock.Unlock()
520
521 dc.SendMessage(&irc.Message{
522 Prefix: dc.srv.prefix(),
523 Command: irc.RPL_WELCOME,
524 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
525 })
526 dc.SendMessage(&irc.Message{
527 Prefix: dc.srv.prefix(),
528 Command: irc.RPL_YOURHOST,
529 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
530 })
531 dc.SendMessage(&irc.Message{
532 Prefix: dc.srv.prefix(),
533 Command: irc.RPL_CREATED,
534 Params: []string{dc.nick, "Who cares when the server was created?"},
535 })
536 dc.SendMessage(&irc.Message{
537 Prefix: dc.srv.prefix(),
538 Command: irc.RPL_MYINFO,
539 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
540 })
541 // TODO: RPL_ISUPPORT
542 dc.SendMessage(&irc.Message{
543 Prefix: dc.srv.prefix(),
544 Command: irc.ERR_NOMOTD,
545 Params: []string{dc.nick, "No MOTD"},
546 })
547
548 dc.forEachUpstream(func(uc *upstreamConn) {
549 for _, ch := range uc.channels {
550 if ch.complete {
551 forwardChannel(dc, ch)
552 }
553 }
554
555 historyName := dc.username
556
557 var seqPtr *uint64
558 if firstDownstream {
559 uc.lock.Lock()
560 seq, ok := uc.history[historyName]
561 uc.lock.Unlock()
562 if ok {
563 seqPtr = &seq
564 }
565 }
566
567 consumer, ch := uc.ring.NewConsumer(seqPtr)
568 go func() {
569 for {
570 var closed bool
571 select {
572 case <-ch:
573 dc.ringMessages <- ringMessage{consumer, uc}
574 case <-dc.closed:
575 closed = true
576 }
577 if closed {
578 break
579 }
580 }
581
582 seq := consumer.Close()
583
584 dc.user.lock.Lock()
585 lastDownstream := len(dc.user.downstreamConns) == 0
586 dc.user.lock.Unlock()
587
588 if lastDownstream {
589 uc.lock.Lock()
590 uc.history[historyName] = seq
591 uc.lock.Unlock()
592 }
593 }()
594 })
595
596 return nil
597}
598
599func (dc *downstreamConn) runUntilRegistered() error {
600 for !dc.registered {
601 msg, err := dc.irc.ReadMessage()
602 if err != nil {
603 return fmt.Errorf("failed to read IRC command: %v", err)
604 }
605
606 if dc.srv.Debug {
607 dc.logger.Printf("received: %v", msg)
608 }
609
610 err = dc.handleMessage(msg)
611 if ircErr, ok := err.(ircError); ok {
612 ircErr.Message.Prefix = dc.srv.prefix()
613 dc.SendMessage(ircErr.Message)
614 } else if err != nil {
615 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
616 }
617 }
618
619 return nil
620}
621
622func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
623 switch msg.Command {
624 case "CAP":
625 var subCmd string
626 if err := parseMessageParams(msg, &subCmd); err != nil {
627 return err
628 }
629 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
630 return err
631 }
632 case "PING":
633 dc.SendMessage(&irc.Message{
634 Prefix: dc.srv.prefix(),
635 Command: "PONG",
636 Params: msg.Params,
637 })
638 return nil
639 case "USER":
640 return ircError{&irc.Message{
641 Command: irc.ERR_ALREADYREGISTERED,
642 Params: []string{dc.nick, "You may not reregister"},
643 }}
644 case "NICK":
645 var nick string
646 if err := parseMessageParams(msg, &nick); err != nil {
647 return err
648 }
649
650 var err error
651 dc.forEachNetwork(func(n *network) {
652 if err != nil {
653 return
654 }
655 n.Nick = nick
656 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
657 })
658 if err != nil {
659 return err
660 }
661
662 dc.forEachUpstream(func(uc *upstreamConn) {
663 uc.SendMessage(msg)
664 })
665 case "JOIN", "PART":
666 var name string
667 if err := parseMessageParams(msg, &name); err != nil {
668 return err
669 }
670
671 uc, upstreamName, err := dc.unmarshalChannel(name)
672 if err != nil {
673 return ircError{&irc.Message{
674 Command: irc.ERR_NOSUCHCHANNEL,
675 Params: []string{name, err.Error()},
676 }}
677 }
678
679 uc.SendMessage(&irc.Message{
680 Command: msg.Command,
681 Params: []string{upstreamName},
682 })
683
684 switch msg.Command {
685 case "JOIN":
686 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
687 Name: upstreamName,
688 })
689 if err != nil {
690 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
691 }
692 case "PART":
693 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
694 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
695 }
696 }
697 case "MODE":
698 if msg.Prefix == nil {
699 return fmt.Errorf("missing prefix")
700 }
701
702 var name string
703 if err := parseMessageParams(msg, &name); err != nil {
704 return err
705 }
706
707 var modeStr string
708 if len(msg.Params) > 1 {
709 modeStr = msg.Params[1]
710 }
711
712 if msg.Prefix.Name != name {
713 uc, upstreamName, err := dc.unmarshalChannel(name)
714 if err != nil {
715 return err
716 }
717
718 if modeStr != "" {
719 uc.SendMessage(&irc.Message{
720 Command: "MODE",
721 Params: []string{upstreamName, modeStr},
722 })
723 } else {
724 ch, ok := uc.channels[upstreamName]
725 if !ok {
726 return ircError{&irc.Message{
727 Command: irc.ERR_NOSUCHCHANNEL,
728 Params: []string{name, "No such channel"},
729 }}
730 }
731
732 dc.SendMessage(&irc.Message{
733 Prefix: dc.srv.prefix(),
734 Command: irc.RPL_CHANNELMODEIS,
735 Params: []string{name, string(ch.modes)},
736 })
737 }
738 } else {
739 if name != dc.nick {
740 return ircError{&irc.Message{
741 Command: irc.ERR_USERSDONTMATCH,
742 Params: []string{dc.nick, "Cannot change mode for other users"},
743 }}
744 }
745
746 if modeStr != "" {
747 dc.forEachUpstream(func(uc *upstreamConn) {
748 uc.SendMessage(&irc.Message{
749 Command: "MODE",
750 Params: []string{uc.nick, modeStr},
751 })
752 })
753 } else {
754 dc.SendMessage(&irc.Message{
755 Prefix: dc.srv.prefix(),
756 Command: irc.RPL_UMODEIS,
757 Params: []string{""}, // TODO
758 })
759 }
760 }
761 case "PRIVMSG":
762 var targetsStr, text string
763 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
764 return err
765 }
766
767 for _, name := range strings.Split(targetsStr, ",") {
768 uc, upstreamName, err := dc.unmarshalChannel(name)
769 if err != nil {
770 return err
771 }
772
773 if upstreamName == "NickServ" {
774 dc.handleNickServPRIVMSG(uc, text)
775 }
776
777 uc.SendMessage(&irc.Message{
778 Command: "PRIVMSG",
779 Params: []string{upstreamName, text},
780 })
781
782 dc.lock.Lock()
783 dc.ourMessages[msg] = struct{}{}
784 dc.lock.Unlock()
785
786 uc.ring.Produce(msg)
787 }
788 default:
789 dc.logger.Printf("unhandled message: %v", msg)
790 return newUnknownCommandError(msg.Command)
791 }
792 return nil
793}
794
795func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
796 username, password, ok := parseNickServCredentials(text, uc.nick)
797 if !ok {
798 return
799 }
800
801 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
802 n := uc.network
803 n.SASL.Mechanism = "PLAIN"
804 n.SASL.Plain.Username = username
805 n.SASL.Plain.Password = password
806 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
807 dc.logger.Printf("failed to save NickServ credentials: %v", err)
808 }
809}
810
// parseNickServCredentials extracts an account name and password from the
// text of a message sent to NickServ. nick is the sender's current nick; it is
// used as the account name when the command does not name one.
//
// Recognized commands (case-insensitive):
//
//	REGISTER <password> [...]
//	IDENTIFY <password>
//	IDENTIFY <account> <password> [...]
//
// ok is false for any other command, or when no password is given.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		// Not a credential-bearing command: don't report empty credentials
		// as found.
		return "", "", false
	}
	return username, password, true
}
Note: See TracBrowser for help on using the repository browser.