source: code/trunk/downstream.go@ 105

Last change on this file since 105 was 105, checked in by contact, 5 years ago

Echo downstream PRIVMSGs to other downstream connections

File size: 15.3 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "sync"
10 "time"
11
12 "golang.org/x/crypto/bcrypt"
13 "gopkg.in/irc.v3"
14)
15
16type ircError struct {
17 Message *irc.Message
18}
19
20func (err ircError) Error() string {
21 return err.Message.String()
22}
23
24func newUnknownCommandError(cmd string) ircError {
25 return ircError{&irc.Message{
26 Command: irc.ERR_UNKNOWNCOMMAND,
27 Params: []string{
28 "*",
29 cmd,
30 "Unknown command",
31 },
32 }}
33}
34
35func newNeedMoreParamsError(cmd string) ircError {
36 return ircError{&irc.Message{
37 Command: irc.ERR_NEEDMOREPARAMS,
38 Params: []string{
39 "*",
40 cmd,
41 "Not enough parameters",
42 },
43 }}
44}
45
46var errAuthFailed = ircError{&irc.Message{
47 Command: irc.ERR_PASSWDMISMATCH,
48 Params: []string{"*", "Invalid username or password"},
49}}
50
51type ringMessage struct {
52 consumer *RingConsumer
53 upstreamConn *upstreamConn
54}
55
56type downstreamConn struct {
57 net net.Conn
58 irc *irc.Conn
59 srv *Server
60 logger Logger
61 outgoing chan *irc.Message
62 ringMessages chan ringMessage
63 closed chan struct{}
64
65 registered bool
66 user *user
67 nick string
68 username string
69 rawUsername string
70 realname string
71 password string // empty after authentication
72 network *network // can be nil
73
74 lock sync.Mutex
75 ourMessages map[*irc.Message]struct{}
76}
77
78func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
79 dc := &downstreamConn{
80 net: netConn,
81 irc: irc.NewConn(netConn),
82 srv: srv,
83 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
84 outgoing: make(chan *irc.Message, 64),
85 ringMessages: make(chan ringMessage),
86 closed: make(chan struct{}),
87 ourMessages: make(map[*irc.Message]struct{}),
88 }
89
90 go func() {
91 if err := dc.writeMessages(); err != nil {
92 dc.logger.Printf("failed to write message: %v", err)
93 }
94 if err := dc.net.Close(); err != nil {
95 dc.logger.Printf("failed to close connection: %v", err)
96 } else {
97 dc.logger.Printf("connection closed")
98 }
99 }()
100
101 return dc
102}
103
104func (dc *downstreamConn) prefix() *irc.Prefix {
105 return &irc.Prefix{
106 Name: dc.nick,
107 User: dc.username,
108 // TODO: fill the host?
109 }
110}
111
112func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
113 return name
114}
115
116func (dc *downstreamConn) forEachNetwork(f func(*network)) {
117 if dc.network != nil {
118 f(dc.network)
119 } else {
120 dc.user.forEachNetwork(f)
121 }
122}
123
124func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
125 dc.user.forEachUpstream(func(uc *upstreamConn) {
126 if dc.network != nil && uc.network != dc.network {
127 return
128 }
129 f(uc)
130 })
131}
132
133// upstream returns the upstream connection, if any. If there are zero or if
134// there are multiple upstream connections, it returns nil.
135func (dc *downstreamConn) upstream() *upstreamConn {
136 if dc.network == nil {
137 return nil
138 }
139
140 var upstream *upstreamConn
141 dc.forEachUpstream(func(uc *upstreamConn) {
142 upstream = uc
143 })
144 return upstream
145}
146
147func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
148 if uc := dc.upstream(); uc != nil {
149 return uc, name, nil
150 }
151
152 // TODO: extract network name from channel name if dc.upstream == nil
153 var channel *upstreamChannel
154 var err error
155 dc.forEachUpstream(func(uc *upstreamConn) {
156 if err != nil {
157 return
158 }
159 if ch, ok := uc.channels[name]; ok {
160 if channel != nil {
161 err = fmt.Errorf("ambiguous channel name %q", name)
162 } else {
163 channel = ch
164 }
165 }
166 })
167 if channel == nil {
168 return nil, "", ircError{&irc.Message{
169 Command: irc.ERR_NOSUCHCHANNEL,
170 Params: []string{name, "No such channel"},
171 }}
172 }
173 return channel.conn, channel.Name, nil
174}
175
176func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
177 if nick == uc.nick {
178 return dc.nick
179 }
180 return nick
181}
182
183func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
184 if prefix.Name == uc.nick {
185 return dc.prefix()
186 }
187 return prefix
188}
189
190func (dc *downstreamConn) isClosed() bool {
191 select {
192 case <-dc.closed:
193 return true
194 default:
195 return false
196 }
197}
198
199func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
200 dc.logger.Printf("new connection")
201
202 for {
203 msg, err := dc.irc.ReadMessage()
204 if err == io.EOF {
205 break
206 } else if err != nil {
207 return fmt.Errorf("failed to read IRC command: %v", err)
208 }
209
210 if dc.srv.Debug {
211 dc.logger.Printf("received: %v", msg)
212 }
213
214 ch <- downstreamIncomingMessage{msg, dc}
215 }
216
217 return nil
218}
219
220func (dc *downstreamConn) writeMessages() error {
221 for {
222 var err error
223 var closed bool
224 select {
225 case msg := <-dc.outgoing:
226 if dc.srv.Debug {
227 dc.logger.Printf("sent: %v", msg)
228 }
229 err = dc.irc.WriteMessage(msg)
230 case ringMessage := <-dc.ringMessages:
231 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
232 for {
233 msg := consumer.Peek()
234 if msg == nil {
235 break
236 }
237
238 dc.lock.Lock()
239 _, ours := dc.ourMessages[msg]
240 delete(dc.ourMessages, msg)
241 dc.lock.Unlock()
242 if ours {
243 // The message comes from our connection, don't echo it
244 // back
245 continue
246 }
247
248 msg = msg.Copy()
249 switch msg.Command {
250 case "PRIVMSG":
251 // TODO: detect whether it's a user or a channel
252 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
253 default:
254 panic("expected to consume a PRIVMSG message")
255 }
256 if dc.srv.Debug {
257 dc.logger.Printf("sent: %v", msg)
258 }
259 err = dc.irc.WriteMessage(msg)
260 if err != nil {
261 break
262 }
263 consumer.Consume()
264 }
265 case <-dc.closed:
266 closed = true
267 }
268 if err != nil {
269 return err
270 }
271 if closed {
272 break
273 }
274 }
275 return nil
276}
277
278func (dc *downstreamConn) Close() error {
279 if dc.isClosed() {
280 return fmt.Errorf("downstream connection already closed")
281 }
282
283 if u := dc.user; u != nil {
284 u.lock.Lock()
285 for i := range u.downstreamConns {
286 if u.downstreamConns[i] == dc {
287 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
288 break
289 }
290 }
291 u.lock.Unlock()
292 }
293
294 close(dc.closed)
295 return nil
296}
297
298func (dc *downstreamConn) SendMessage(msg *irc.Message) {
299 dc.outgoing <- msg
300}
301
302func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
303 switch msg.Command {
304 case "QUIT":
305 return dc.Close()
306 case "PING":
307 dc.SendMessage(&irc.Message{
308 Prefix: dc.srv.prefix(),
309 Command: "PONG",
310 Params: msg.Params,
311 })
312 return nil
313 default:
314 if dc.registered {
315 return dc.handleMessageRegistered(msg)
316 } else {
317 return dc.handleMessageUnregistered(msg)
318 }
319 }
320}
321
322func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
323 switch msg.Command {
324 case "NICK":
325 if err := parseMessageParams(msg, &dc.nick); err != nil {
326 return err
327 }
328 case "USER":
329 var username string
330 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
331 return err
332 }
333 dc.rawUsername = username
334 case "PASS":
335 if err := parseMessageParams(msg, &dc.password); err != nil {
336 return err
337 }
338 default:
339 dc.logger.Printf("unhandled message: %v", msg)
340 return newUnknownCommandError(msg.Command)
341 }
342 if dc.rawUsername != "" && dc.nick != "" {
343 return dc.register()
344 }
345 return nil
346}
347
// sanityCheckServer checks that addr accepts TLS connections. It dials addr
// with a 30-second timeout and immediately closes the resulting connection.
// It returns the dial error, or otherwise the error from closing.
func sanityCheckServer(addr string) error {
	const timeout = 30 * time.Second
	dialer := &net.Dialer{Timeout: timeout}

	conn, err := tls.DialWithDialer(dialer, "tcp", addr, nil)
	if err != nil {
		return err
	}
	return conn.Close()
}
356
357func (dc *downstreamConn) register() error {
358 username := dc.rawUsername
359 var networkName string
360 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
361 networkName = username[i+1:]
362 }
363 if i := strings.IndexAny(username, "/@"); i >= 0 {
364 username = username[:i]
365 }
366 dc.username = "~" + username
367
368 password := dc.password
369 dc.password = ""
370
371 u := dc.srv.getUser(username)
372 if u == nil {
373 dc.logger.Printf("failed authentication for %q: unknown username", username)
374 return errAuthFailed
375 }
376
377 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
378 if err != nil {
379 dc.logger.Printf("failed authentication for %q: %v", username, err)
380 return errAuthFailed
381 }
382
383 var network *network
384 if networkName != "" {
385 network = u.getNetwork(networkName)
386 if network == nil {
387 addr := networkName
388 if !strings.ContainsRune(addr, ':') {
389 addr = addr + ":6697"
390 }
391
392 dc.logger.Printf("trying to connect to new network %q", addr)
393 if err := sanityCheckServer(addr); err != nil {
394 dc.logger.Printf("failed to connect to %q: %v", addr, err)
395 return ircError{&irc.Message{
396 Command: irc.ERR_PASSWDMISMATCH,
397 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
398 }}
399 }
400
401 dc.logger.Printf("auto-saving network %q", networkName)
402 network, err = u.createNetwork(networkName, dc.nick)
403 if err != nil {
404 return err
405 }
406 }
407 }
408
409 dc.registered = true
410 dc.user = u
411 dc.network = network
412
413 u.lock.Lock()
414 firstDownstream := len(u.downstreamConns) == 0
415 u.downstreamConns = append(u.downstreamConns, dc)
416 u.lock.Unlock()
417
418 dc.SendMessage(&irc.Message{
419 Prefix: dc.srv.prefix(),
420 Command: irc.RPL_WELCOME,
421 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
422 })
423 dc.SendMessage(&irc.Message{
424 Prefix: dc.srv.prefix(),
425 Command: irc.RPL_YOURHOST,
426 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
427 })
428 dc.SendMessage(&irc.Message{
429 Prefix: dc.srv.prefix(),
430 Command: irc.RPL_CREATED,
431 Params: []string{dc.nick, "Who cares when the server was created?"},
432 })
433 dc.SendMessage(&irc.Message{
434 Prefix: dc.srv.prefix(),
435 Command: irc.RPL_MYINFO,
436 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
437 })
438 // TODO: RPL_ISUPPORT
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.srv.prefix(),
441 Command: irc.ERR_NOMOTD,
442 Params: []string{dc.nick, "No MOTD"},
443 })
444
445 dc.forEachUpstream(func(uc *upstreamConn) {
446 for _, ch := range uc.channels {
447 if ch.complete {
448 forwardChannel(dc, ch)
449 }
450 }
451
452 historyName := dc.username
453
454 var seqPtr *uint64
455 if firstDownstream {
456 seq, ok := uc.history[historyName]
457 if ok {
458 seqPtr = &seq
459 }
460 }
461
462 consumer, ch := uc.ring.NewConsumer(seqPtr)
463 go func() {
464 for {
465 var closed bool
466 select {
467 case <-ch:
468 dc.ringMessages <- ringMessage{consumer, uc}
469 case <-dc.closed:
470 closed = true
471 }
472 if closed {
473 break
474 }
475 }
476
477 seq := consumer.Close()
478
479 dc.user.lock.Lock()
480 lastDownstream := len(dc.user.downstreamConns) == 0
481 dc.user.lock.Unlock()
482
483 if lastDownstream {
484 uc.history[historyName] = seq
485 }
486 }()
487 })
488
489 return nil
490}
491
492func (dc *downstreamConn) runUntilRegistered() error {
493 for !dc.registered {
494 msg, err := dc.irc.ReadMessage()
495 if err == io.EOF {
496 break
497 } else if err != nil {
498 return fmt.Errorf("failed to read IRC command: %v", err)
499 }
500
501 err = dc.handleMessage(msg)
502 if ircErr, ok := err.(ircError); ok {
503 ircErr.Message.Prefix = dc.srv.prefix()
504 dc.SendMessage(ircErr.Message)
505 } else if err != nil {
506 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
507 }
508 }
509
510 return nil
511}
512
513func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
514 switch msg.Command {
515 case "USER":
516 return ircError{&irc.Message{
517 Command: irc.ERR_ALREADYREGISTERED,
518 Params: []string{dc.nick, "You may not reregister"},
519 }}
520 case "NICK":
521 var nick string
522 if err := parseMessageParams(msg, &nick); err != nil {
523 return err
524 }
525
526 var err error
527 dc.forEachNetwork(func(n *network) {
528 if err != nil {
529 return
530 }
531 n.Nick = nick
532 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
533 })
534 if err != nil {
535 return err
536 }
537
538 dc.forEachUpstream(func(uc *upstreamConn) {
539 uc.SendMessage(msg)
540 })
541 case "JOIN", "PART":
542 var name string
543 if err := parseMessageParams(msg, &name); err != nil {
544 return err
545 }
546
547 uc, upstreamName, err := dc.unmarshalChannel(name)
548 if err != nil {
549 return ircError{&irc.Message{
550 Command: irc.ERR_NOSUCHCHANNEL,
551 Params: []string{name, err.Error()},
552 }}
553 }
554
555 uc.SendMessage(&irc.Message{
556 Command: msg.Command,
557 Params: []string{upstreamName},
558 })
559
560 switch msg.Command {
561 case "JOIN":
562 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
563 Name: upstreamName,
564 })
565 if err != nil {
566 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
567 }
568 case "PART":
569 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
570 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
571 }
572 }
573 case "MODE":
574 if msg.Prefix == nil {
575 return fmt.Errorf("missing prefix")
576 }
577
578 var name string
579 if err := parseMessageParams(msg, &name); err != nil {
580 return err
581 }
582
583 var modeStr string
584 if len(msg.Params) > 1 {
585 modeStr = msg.Params[1]
586 }
587
588 if msg.Prefix.Name != name {
589 uc, upstreamName, err := dc.unmarshalChannel(name)
590 if err != nil {
591 return err
592 }
593
594 if modeStr != "" {
595 uc.SendMessage(&irc.Message{
596 Command: "MODE",
597 Params: []string{upstreamName, modeStr},
598 })
599 } else {
600 ch, ok := uc.channels[upstreamName]
601 if !ok {
602 return ircError{&irc.Message{
603 Command: irc.ERR_NOSUCHCHANNEL,
604 Params: []string{name, "No such channel"},
605 }}
606 }
607
608 dc.SendMessage(&irc.Message{
609 Prefix: dc.srv.prefix(),
610 Command: irc.RPL_CHANNELMODEIS,
611 Params: []string{name, string(ch.modes)},
612 })
613 }
614 } else {
615 if name != dc.nick {
616 return ircError{&irc.Message{
617 Command: irc.ERR_USERSDONTMATCH,
618 Params: []string{dc.nick, "Cannot change mode for other users"},
619 }}
620 }
621
622 if modeStr != "" {
623 dc.forEachUpstream(func(uc *upstreamConn) {
624 uc.SendMessage(&irc.Message{
625 Command: "MODE",
626 Params: []string{uc.nick, modeStr},
627 })
628 })
629 } else {
630 dc.SendMessage(&irc.Message{
631 Prefix: dc.srv.prefix(),
632 Command: irc.RPL_UMODEIS,
633 Params: []string{""}, // TODO
634 })
635 }
636 }
637 case "PRIVMSG":
638 var targetsStr, text string
639 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
640 return err
641 }
642
643 for _, name := range strings.Split(targetsStr, ",") {
644 uc, upstreamName, err := dc.unmarshalChannel(name)
645 if err != nil {
646 return err
647 }
648
649 if upstreamName == "NickServ" {
650 dc.handleNickServPRIVMSG(uc, text)
651 }
652
653 uc.SendMessage(&irc.Message{
654 Command: "PRIVMSG",
655 Params: []string{upstreamName, text},
656 })
657
658 dc.lock.Lock()
659 dc.ourMessages[msg] = struct{}{}
660 dc.lock.Unlock()
661
662 uc.ring.Produce(msg)
663 }
664 default:
665 dc.logger.Printf("unhandled message: %v", msg)
666 return newUnknownCommandError(msg.Command)
667 }
668 return nil
669}
670
671func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
672 username, password, ok := parseNickServCredentials(text, uc.nick)
673 if !ok {
674 return
675 }
676
677 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
678 n := uc.network
679 n.SASL.Mechanism = "PLAIN"
680 n.SASL.Plain.Username = username
681 n.SASL.Plain.Password = password
682 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
683 dc.logger.Printf("failed to save NickServ credentials: %v", err)
684 }
685}
686
// parseNickServCredentials extracts an account name and password from a
// message sent to NickServ. Commands are matched case-insensitively:
//
//	REGISTER <password> [...]            account is nick
//	IDENTIFY <password>                  account is nick
//	IDENTIFY <account> <password> [...]
//
// ok is false for any other command and for messages missing the password.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}

	params := fields[1:]
	switch strings.ToUpper(fields[0]) {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			// One-argument form: the password for the current nick.
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		// Not a credentials command (e.g. HELP, INFO): nothing to save.
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.