source: code/trunk/downstream.go@ 106

Last change on this file since 106 was 106, checked in by contact, 5 years ago

Make downstreamConn.runUntilRegistered exit with an error on EOF

File size: 15.2 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "sync"
10 "time"
11
12 "golang.org/x/crypto/bcrypt"
13 "gopkg.in/irc.v3"
14)
15
16type ircError struct {
17 Message *irc.Message
18}
19
20func (err ircError) Error() string {
21 return err.Message.String()
22}
23
24func newUnknownCommandError(cmd string) ircError {
25 return ircError{&irc.Message{
26 Command: irc.ERR_UNKNOWNCOMMAND,
27 Params: []string{
28 "*",
29 cmd,
30 "Unknown command",
31 },
32 }}
33}
34
35func newNeedMoreParamsError(cmd string) ircError {
36 return ircError{&irc.Message{
37 Command: irc.ERR_NEEDMOREPARAMS,
38 Params: []string{
39 "*",
40 cmd,
41 "Not enough parameters",
42 },
43 }}
44}
45
46var errAuthFailed = ircError{&irc.Message{
47 Command: irc.ERR_PASSWDMISMATCH,
48 Params: []string{"*", "Invalid username or password"},
49}}
50
51type ringMessage struct {
52 consumer *RingConsumer
53 upstreamConn *upstreamConn
54}
55
56type downstreamConn struct {
57 net net.Conn
58 irc *irc.Conn
59 srv *Server
60 logger Logger
61 outgoing chan *irc.Message
62 ringMessages chan ringMessage
63 closed chan struct{}
64
65 registered bool
66 user *user
67 nick string
68 username string
69 rawUsername string
70 realname string
71 password string // empty after authentication
72 network *network // can be nil
73
74 lock sync.Mutex
75 ourMessages map[*irc.Message]struct{}
76}
77
78func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
79 dc := &downstreamConn{
80 net: netConn,
81 irc: irc.NewConn(netConn),
82 srv: srv,
83 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
84 outgoing: make(chan *irc.Message, 64),
85 ringMessages: make(chan ringMessage),
86 closed: make(chan struct{}),
87 ourMessages: make(map[*irc.Message]struct{}),
88 }
89
90 go func() {
91 if err := dc.writeMessages(); err != nil {
92 dc.logger.Printf("failed to write message: %v", err)
93 }
94 if err := dc.net.Close(); err != nil {
95 dc.logger.Printf("failed to close connection: %v", err)
96 } else {
97 dc.logger.Printf("connection closed")
98 }
99 }()
100
101 return dc
102}
103
104func (dc *downstreamConn) prefix() *irc.Prefix {
105 return &irc.Prefix{
106 Name: dc.nick,
107 User: dc.username,
108 // TODO: fill the host?
109 }
110}
111
112func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
113 return name
114}
115
116func (dc *downstreamConn) forEachNetwork(f func(*network)) {
117 if dc.network != nil {
118 f(dc.network)
119 } else {
120 dc.user.forEachNetwork(f)
121 }
122}
123
124func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
125 dc.user.forEachUpstream(func(uc *upstreamConn) {
126 if dc.network != nil && uc.network != dc.network {
127 return
128 }
129 f(uc)
130 })
131}
132
133// upstream returns the upstream connection, if any. If there are zero or if
134// there are multiple upstream connections, it returns nil.
135func (dc *downstreamConn) upstream() *upstreamConn {
136 if dc.network == nil {
137 return nil
138 }
139
140 var upstream *upstreamConn
141 dc.forEachUpstream(func(uc *upstreamConn) {
142 upstream = uc
143 })
144 return upstream
145}
146
147func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
148 if uc := dc.upstream(); uc != nil {
149 return uc, name, nil
150 }
151
152 // TODO: extract network name from channel name if dc.upstream == nil
153 var channel *upstreamChannel
154 var err error
155 dc.forEachUpstream(func(uc *upstreamConn) {
156 if err != nil {
157 return
158 }
159 if ch, ok := uc.channels[name]; ok {
160 if channel != nil {
161 err = fmt.Errorf("ambiguous channel name %q", name)
162 } else {
163 channel = ch
164 }
165 }
166 })
167 if channel == nil {
168 return nil, "", ircError{&irc.Message{
169 Command: irc.ERR_NOSUCHCHANNEL,
170 Params: []string{name, "No such channel"},
171 }}
172 }
173 return channel.conn, channel.Name, nil
174}
175
176func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
177 if nick == uc.nick {
178 return dc.nick
179 }
180 return nick
181}
182
183func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
184 if prefix.Name == uc.nick {
185 return dc.prefix()
186 }
187 return prefix
188}
189
190func (dc *downstreamConn) isClosed() bool {
191 select {
192 case <-dc.closed:
193 return true
194 default:
195 return false
196 }
197}
198
199func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
200 dc.logger.Printf("new connection")
201
202 for {
203 msg, err := dc.irc.ReadMessage()
204 if err == io.EOF {
205 break
206 } else if err != nil {
207 return fmt.Errorf("failed to read IRC command: %v", err)
208 }
209
210 if dc.srv.Debug {
211 dc.logger.Printf("received: %v", msg)
212 }
213
214 ch <- downstreamIncomingMessage{msg, dc}
215 }
216
217 return nil
218}
219
220func (dc *downstreamConn) writeMessages() error {
221 for {
222 var err error
223 var closed bool
224 select {
225 case msg := <-dc.outgoing:
226 if dc.srv.Debug {
227 dc.logger.Printf("sent: %v", msg)
228 }
229 err = dc.irc.WriteMessage(msg)
230 case ringMessage := <-dc.ringMessages:
231 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
232 for {
233 msg := consumer.Peek()
234 if msg == nil {
235 break
236 }
237
238 dc.lock.Lock()
239 _, ours := dc.ourMessages[msg]
240 delete(dc.ourMessages, msg)
241 dc.lock.Unlock()
242 if ours {
243 // The message comes from our connection, don't echo it
244 // back
245 continue
246 }
247
248 msg = msg.Copy()
249 switch msg.Command {
250 case "PRIVMSG":
251 // TODO: detect whether it's a user or a channel
252 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
253 default:
254 panic("expected to consume a PRIVMSG message")
255 }
256 if dc.srv.Debug {
257 dc.logger.Printf("sent: %v", msg)
258 }
259 err = dc.irc.WriteMessage(msg)
260 if err != nil {
261 break
262 }
263 consumer.Consume()
264 }
265 case <-dc.closed:
266 closed = true
267 }
268 if err != nil {
269 return err
270 }
271 if closed {
272 break
273 }
274 }
275 return nil
276}
277
278func (dc *downstreamConn) Close() error {
279 if dc.isClosed() {
280 return fmt.Errorf("downstream connection already closed")
281 }
282
283 if u := dc.user; u != nil {
284 u.lock.Lock()
285 for i := range u.downstreamConns {
286 if u.downstreamConns[i] == dc {
287 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
288 break
289 }
290 }
291 u.lock.Unlock()
292 }
293
294 close(dc.closed)
295 return nil
296}
297
298func (dc *downstreamConn) SendMessage(msg *irc.Message) {
299 dc.outgoing <- msg
300}
301
302func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
303 switch msg.Command {
304 case "QUIT":
305 return dc.Close()
306 case "PING":
307 dc.SendMessage(&irc.Message{
308 Prefix: dc.srv.prefix(),
309 Command: "PONG",
310 Params: msg.Params,
311 })
312 return nil
313 default:
314 if dc.registered {
315 return dc.handleMessageRegistered(msg)
316 } else {
317 return dc.handleMessageUnregistered(msg)
318 }
319 }
320}
321
322func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
323 switch msg.Command {
324 case "NICK":
325 if err := parseMessageParams(msg, &dc.nick); err != nil {
326 return err
327 }
328 case "USER":
329 var username string
330 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
331 return err
332 }
333 dc.rawUsername = username
334 case "PASS":
335 if err := parseMessageParams(msg, &dc.password); err != nil {
336 return err
337 }
338 default:
339 dc.logger.Printf("unhandled message: %v", msg)
340 return newUnknownCommandError(msg.Command)
341 }
342 if dc.rawUsername != "" && dc.nick != "" {
343 return dc.register()
344 }
345 return nil
346}
347
// sanityCheckServer checks that a TLS connection to addr ("host:port") can be
// established within 30 seconds, using the default TLS configuration. The
// connection is closed again right away.
func sanityCheckServer(addr string) error {
	dialer := &net.Dialer{Timeout: 30 * time.Second}
	conn, err := tls.DialWithDialer(dialer, "tcp", addr, nil)
	if err == nil {
		err = conn.Close()
	}
	return err
}
356
357func (dc *downstreamConn) register() error {
358 username := dc.rawUsername
359 var networkName string
360 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
361 networkName = username[i+1:]
362 }
363 if i := strings.IndexAny(username, "/@"); i >= 0 {
364 username = username[:i]
365 }
366 dc.username = "~" + username
367
368 password := dc.password
369 dc.password = ""
370
371 u := dc.srv.getUser(username)
372 if u == nil {
373 dc.logger.Printf("failed authentication for %q: unknown username", username)
374 return errAuthFailed
375 }
376
377 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
378 if err != nil {
379 dc.logger.Printf("failed authentication for %q: %v", username, err)
380 return errAuthFailed
381 }
382
383 var network *network
384 if networkName != "" {
385 network = u.getNetwork(networkName)
386 if network == nil {
387 addr := networkName
388 if !strings.ContainsRune(addr, ':') {
389 addr = addr + ":6697"
390 }
391
392 dc.logger.Printf("trying to connect to new network %q", addr)
393 if err := sanityCheckServer(addr); err != nil {
394 dc.logger.Printf("failed to connect to %q: %v", addr, err)
395 return ircError{&irc.Message{
396 Command: irc.ERR_PASSWDMISMATCH,
397 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
398 }}
399 }
400
401 dc.logger.Printf("auto-saving network %q", networkName)
402 network, err = u.createNetwork(networkName, dc.nick)
403 if err != nil {
404 return err
405 }
406 }
407 }
408
409 dc.registered = true
410 dc.user = u
411 dc.network = network
412
413 u.lock.Lock()
414 firstDownstream := len(u.downstreamConns) == 0
415 u.downstreamConns = append(u.downstreamConns, dc)
416 u.lock.Unlock()
417
418 dc.SendMessage(&irc.Message{
419 Prefix: dc.srv.prefix(),
420 Command: irc.RPL_WELCOME,
421 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
422 })
423 dc.SendMessage(&irc.Message{
424 Prefix: dc.srv.prefix(),
425 Command: irc.RPL_YOURHOST,
426 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
427 })
428 dc.SendMessage(&irc.Message{
429 Prefix: dc.srv.prefix(),
430 Command: irc.RPL_CREATED,
431 Params: []string{dc.nick, "Who cares when the server was created?"},
432 })
433 dc.SendMessage(&irc.Message{
434 Prefix: dc.srv.prefix(),
435 Command: irc.RPL_MYINFO,
436 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
437 })
438 // TODO: RPL_ISUPPORT
439 dc.SendMessage(&irc.Message{
440 Prefix: dc.srv.prefix(),
441 Command: irc.ERR_NOMOTD,
442 Params: []string{dc.nick, "No MOTD"},
443 })
444
445 dc.forEachUpstream(func(uc *upstreamConn) {
446 for _, ch := range uc.channels {
447 if ch.complete {
448 forwardChannel(dc, ch)
449 }
450 }
451
452 historyName := dc.username
453
454 var seqPtr *uint64
455 if firstDownstream {
456 seq, ok := uc.history[historyName]
457 if ok {
458 seqPtr = &seq
459 }
460 }
461
462 consumer, ch := uc.ring.NewConsumer(seqPtr)
463 go func() {
464 for {
465 var closed bool
466 select {
467 case <-ch:
468 dc.ringMessages <- ringMessage{consumer, uc}
469 case <-dc.closed:
470 closed = true
471 }
472 if closed {
473 break
474 }
475 }
476
477 seq := consumer.Close()
478
479 dc.user.lock.Lock()
480 lastDownstream := len(dc.user.downstreamConns) == 0
481 dc.user.lock.Unlock()
482
483 if lastDownstream {
484 uc.history[historyName] = seq
485 }
486 }()
487 })
488
489 return nil
490}
491
492func (dc *downstreamConn) runUntilRegistered() error {
493 for !dc.registered {
494 msg, err := dc.irc.ReadMessage()
495 if err != nil {
496 return fmt.Errorf("failed to read IRC command: %v", err)
497 }
498
499 err = dc.handleMessage(msg)
500 if ircErr, ok := err.(ircError); ok {
501 ircErr.Message.Prefix = dc.srv.prefix()
502 dc.SendMessage(ircErr.Message)
503 } else if err != nil {
504 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
505 }
506 }
507
508 return nil
509}
510
511func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
512 switch msg.Command {
513 case "USER":
514 return ircError{&irc.Message{
515 Command: irc.ERR_ALREADYREGISTERED,
516 Params: []string{dc.nick, "You may not reregister"},
517 }}
518 case "NICK":
519 var nick string
520 if err := parseMessageParams(msg, &nick); err != nil {
521 return err
522 }
523
524 var err error
525 dc.forEachNetwork(func(n *network) {
526 if err != nil {
527 return
528 }
529 n.Nick = nick
530 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
531 })
532 if err != nil {
533 return err
534 }
535
536 dc.forEachUpstream(func(uc *upstreamConn) {
537 uc.SendMessage(msg)
538 })
539 case "JOIN", "PART":
540 var name string
541 if err := parseMessageParams(msg, &name); err != nil {
542 return err
543 }
544
545 uc, upstreamName, err := dc.unmarshalChannel(name)
546 if err != nil {
547 return ircError{&irc.Message{
548 Command: irc.ERR_NOSUCHCHANNEL,
549 Params: []string{name, err.Error()},
550 }}
551 }
552
553 uc.SendMessage(&irc.Message{
554 Command: msg.Command,
555 Params: []string{upstreamName},
556 })
557
558 switch msg.Command {
559 case "JOIN":
560 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
561 Name: upstreamName,
562 })
563 if err != nil {
564 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
565 }
566 case "PART":
567 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
568 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
569 }
570 }
571 case "MODE":
572 if msg.Prefix == nil {
573 return fmt.Errorf("missing prefix")
574 }
575
576 var name string
577 if err := parseMessageParams(msg, &name); err != nil {
578 return err
579 }
580
581 var modeStr string
582 if len(msg.Params) > 1 {
583 modeStr = msg.Params[1]
584 }
585
586 if msg.Prefix.Name != name {
587 uc, upstreamName, err := dc.unmarshalChannel(name)
588 if err != nil {
589 return err
590 }
591
592 if modeStr != "" {
593 uc.SendMessage(&irc.Message{
594 Command: "MODE",
595 Params: []string{upstreamName, modeStr},
596 })
597 } else {
598 ch, ok := uc.channels[upstreamName]
599 if !ok {
600 return ircError{&irc.Message{
601 Command: irc.ERR_NOSUCHCHANNEL,
602 Params: []string{name, "No such channel"},
603 }}
604 }
605
606 dc.SendMessage(&irc.Message{
607 Prefix: dc.srv.prefix(),
608 Command: irc.RPL_CHANNELMODEIS,
609 Params: []string{name, string(ch.modes)},
610 })
611 }
612 } else {
613 if name != dc.nick {
614 return ircError{&irc.Message{
615 Command: irc.ERR_USERSDONTMATCH,
616 Params: []string{dc.nick, "Cannot change mode for other users"},
617 }}
618 }
619
620 if modeStr != "" {
621 dc.forEachUpstream(func(uc *upstreamConn) {
622 uc.SendMessage(&irc.Message{
623 Command: "MODE",
624 Params: []string{uc.nick, modeStr},
625 })
626 })
627 } else {
628 dc.SendMessage(&irc.Message{
629 Prefix: dc.srv.prefix(),
630 Command: irc.RPL_UMODEIS,
631 Params: []string{""}, // TODO
632 })
633 }
634 }
635 case "PRIVMSG":
636 var targetsStr, text string
637 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
638 return err
639 }
640
641 for _, name := range strings.Split(targetsStr, ",") {
642 uc, upstreamName, err := dc.unmarshalChannel(name)
643 if err != nil {
644 return err
645 }
646
647 if upstreamName == "NickServ" {
648 dc.handleNickServPRIVMSG(uc, text)
649 }
650
651 uc.SendMessage(&irc.Message{
652 Command: "PRIVMSG",
653 Params: []string{upstreamName, text},
654 })
655
656 dc.lock.Lock()
657 dc.ourMessages[msg] = struct{}{}
658 dc.lock.Unlock()
659
660 uc.ring.Produce(msg)
661 }
662 default:
663 dc.logger.Printf("unhandled message: %v", msg)
664 return newUnknownCommandError(msg.Command)
665 }
666 return nil
667}
668
669func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
670 username, password, ok := parseNickServCredentials(text, uc.nick)
671 if !ok {
672 return
673 }
674
675 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
676 n := uc.network
677 n.SASL.Mechanism = "PLAIN"
678 n.SASL.Plain.Username = username
679 n.SASL.Plain.Password = password
680 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
681 dc.logger.Printf("failed to save NickServ credentials: %v", err)
682 }
683}
684
// parseNickServCredentials extracts account credentials from the text of a
// message sent to NickServ. These forms are recognized (command names are
// case-insensitive):
//
//	REGISTER <password> ...            username = nick
//	IDENTIFY <password>                username = nick
//	IDENTIFY <username> <password> ...
//
// nick is the user's current nickname on the upstream network. ok is false
// when text is not one of these commands or lacks its required arguments.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			// Only a password: it identifies the current nick. (Indexing
			// params[1] here would panic.)
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		// Other NickServ commands carry no credentials; reporting ok would
		// make the caller save empty SASL credentials.
		return "", "", false
	}
	return username, password, true
}
Note: See TracBrowser for help on using the repository browser.