source: code/trunk/downstream.go@ 103

Last change on this file since 103 was 103, checked in by contact, 5 years ago

Per-user dispatcher goroutine

This allows message handlers to read upstream/downstream connection
information without causing any race condition.

References: https://todo.sr.ht/~emersion/soju/1

File size: 14.9 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "time"
10
11 "golang.org/x/crypto/bcrypt"
12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
49
// consumption is a request, sent on downstreamConn.consumptions, asking the
// writer goroutine (writeMessages) to drain pending messages from consumer.
// NOTE: register builds this with a positional composite literal, so the field
// order must not change.
type consumption struct {
	// consumer is the ring-buffer consumer whose pending messages are written.
	consumer *RingConsumer
	// upstreamConn is the upstream the messages came from; it is passed to
	// marshalChannel when rewriting each message's target.
	upstreamConn *upstreamConn
}
54
// downstreamConn is a connection from an IRC client to the bouncer.
type downstreamConn struct {
	// net is the client's network connection; the writer goroutine started by
	// newDownstreamConn closes it once writeMessages returns.
	net net.Conn
	// irc reads and writes IRC messages over net.
	irc    *irc.Conn
	srv    *Server
	logger Logger
	// outgoing holds messages queued by SendMessage until writeMessages
	// writes them.
	outgoing chan *irc.Message
	// consumptions carries ring-buffer drain requests from the per-upstream
	// goroutines started in register to writeMessages.
	consumptions chan consumption
	// closed is closed by Close to stop the writer and consumer goroutines.
	closed chan struct{}

	// registered is set by register after successful authentication.
	registered bool
	// user is the authenticated user; set by register.
	user *user
	// nick is the nick given by the client in NICK.
	nick string
	// username is "~" followed by the account part of rawUsername.
	username string
	// rawUsername is the USER username as received; it may carry a
	// "/network" or "@network" suffix.
	rawUsername string
	realname    string
	password    string   // from PASS; cleared by register
	network     *network // can be nil (no network selected at registration)
}
73
74func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
75 dc := &downstreamConn{
76 net: netConn,
77 irc: irc.NewConn(netConn),
78 srv: srv,
79 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
80 outgoing: make(chan *irc.Message, 64),
81 consumptions: make(chan consumption),
82 closed: make(chan struct{}),
83 }
84
85 go func() {
86 if err := dc.writeMessages(); err != nil {
87 dc.logger.Printf("failed to write message: %v", err)
88 }
89 if err := dc.net.Close(); err != nil {
90 dc.logger.Printf("failed to close connection: %v", err)
91 } else {
92 dc.logger.Printf("connection closed")
93 }
94 }()
95
96 return dc
97}
98
99func (dc *downstreamConn) prefix() *irc.Prefix {
100 return &irc.Prefix{
101 Name: dc.nick,
102 User: dc.username,
103 // TODO: fill the host?
104 }
105}
106
107func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
108 return name
109}
110
111func (dc *downstreamConn) forEachNetwork(f func(*network)) {
112 if dc.network != nil {
113 f(dc.network)
114 } else {
115 dc.user.forEachNetwork(f)
116 }
117}
118
119func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
120 dc.user.forEachUpstream(func(uc *upstreamConn) {
121 if dc.network != nil && uc.network != dc.network {
122 return
123 }
124 f(uc)
125 })
126}
127
128// upstream returns the upstream connection, if any. If there are zero or if
129// there are multiple upstream connections, it returns nil.
130func (dc *downstreamConn) upstream() *upstreamConn {
131 if dc.network == nil {
132 return nil
133 }
134
135 var upstream *upstreamConn
136 dc.forEachUpstream(func(uc *upstreamConn) {
137 upstream = uc
138 })
139 return upstream
140}
141
142func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
143 if uc := dc.upstream(); uc != nil {
144 return uc, name, nil
145 }
146
147 // TODO: extract network name from channel name if dc.upstream == nil
148 var channel *upstreamChannel
149 var err error
150 dc.forEachUpstream(func(uc *upstreamConn) {
151 if err != nil {
152 return
153 }
154 if ch, ok := uc.channels[name]; ok {
155 if channel != nil {
156 err = fmt.Errorf("ambiguous channel name %q", name)
157 } else {
158 channel = ch
159 }
160 }
161 })
162 if channel == nil {
163 return nil, "", ircError{&irc.Message{
164 Command: irc.ERR_NOSUCHCHANNEL,
165 Params: []string{name, "No such channel"},
166 }}
167 }
168 return channel.conn, channel.Name, nil
169}
170
171func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
172 if nick == uc.nick {
173 return dc.nick
174 }
175 return nick
176}
177
178func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
179 if prefix.Name == uc.nick {
180 return dc.prefix()
181 }
182 return prefix
183}
184
185func (dc *downstreamConn) isClosed() bool {
186 select {
187 case <-dc.closed:
188 return true
189 default:
190 return false
191 }
192}
193
194func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
195 dc.logger.Printf("new connection")
196
197 for {
198 msg, err := dc.irc.ReadMessage()
199 if err == io.EOF {
200 break
201 } else if err != nil {
202 return fmt.Errorf("failed to read IRC command: %v", err)
203 }
204
205 if dc.srv.Debug {
206 dc.logger.Printf("received: %v", msg)
207 }
208
209 ch <- downstreamIncomingMessage{msg, dc}
210 }
211
212 return nil
213}
214
215func (dc *downstreamConn) writeMessages() error {
216 for {
217 var err error
218 var closed bool
219 select {
220 case msg := <-dc.outgoing:
221 if dc.srv.Debug {
222 dc.logger.Printf("sent: %v", msg)
223 }
224 err = dc.irc.WriteMessage(msg)
225 case consumption := <-dc.consumptions:
226 consumer, uc := consumption.consumer, consumption.upstreamConn
227 for {
228 msg := consumer.Peek()
229 if msg == nil {
230 break
231 }
232 msg = msg.Copy()
233 switch msg.Command {
234 case "PRIVMSG":
235 // TODO: detect whether it's a user or a channel
236 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
237 default:
238 panic("expected to consume a PRIVMSG message")
239 }
240 if dc.srv.Debug {
241 dc.logger.Printf("sent: %v", msg)
242 }
243 err = dc.irc.WriteMessage(msg)
244 if err != nil {
245 break
246 }
247 consumer.Consume()
248 }
249 case <-dc.closed:
250 closed = true
251 }
252 if err != nil {
253 return err
254 }
255 if closed {
256 break
257 }
258 }
259 return nil
260}
261
262func (dc *downstreamConn) Close() error {
263 if dc.isClosed() {
264 return fmt.Errorf("downstream connection already closed")
265 }
266
267 if u := dc.user; u != nil {
268 u.lock.Lock()
269 for i := range u.downstreamConns {
270 if u.downstreamConns[i] == dc {
271 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
272 break
273 }
274 }
275 u.lock.Unlock()
276 }
277
278 close(dc.closed)
279 return nil
280}
281
282func (dc *downstreamConn) SendMessage(msg *irc.Message) {
283 dc.outgoing <- msg
284}
285
286func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
287 switch msg.Command {
288 case "QUIT":
289 return dc.Close()
290 case "PING":
291 dc.SendMessage(&irc.Message{
292 Prefix: dc.srv.prefix(),
293 Command: "PONG",
294 Params: msg.Params,
295 })
296 return nil
297 default:
298 if dc.registered {
299 return dc.handleMessageRegistered(msg)
300 } else {
301 return dc.handleMessageUnregistered(msg)
302 }
303 }
304}
305
306func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
307 switch msg.Command {
308 case "NICK":
309 if err := parseMessageParams(msg, &dc.nick); err != nil {
310 return err
311 }
312 case "USER":
313 var username string
314 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
315 return err
316 }
317 dc.rawUsername = username
318 case "PASS":
319 if err := parseMessageParams(msg, &dc.password); err != nil {
320 return err
321 }
322 default:
323 dc.logger.Printf("unhandled message: %v", msg)
324 return newUnknownCommandError(msg.Command)
325 }
326 if dc.rawUsername != "" && dc.nick != "" {
327 return dc.register()
328 }
329 return nil
330}
331
// sanityCheckServer checks that a TLS connection can be established to addr (in
// host:port form), using a dialer timeout of 30 seconds. The connection is
// closed immediately. The dial error is returned, or else the error from
// closing the connection.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	conn, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err == nil {
		err = conn.Close()
	}
	return err
}
340
341func (dc *downstreamConn) register() error {
342 username := dc.rawUsername
343 var networkName string
344 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
345 networkName = username[i+1:]
346 }
347 if i := strings.IndexAny(username, "/@"); i >= 0 {
348 username = username[:i]
349 }
350 dc.username = "~" + username
351
352 password := dc.password
353 dc.password = ""
354
355 u := dc.srv.getUser(username)
356 if u == nil {
357 dc.logger.Printf("failed authentication for %q: unknown username", username)
358 return errAuthFailed
359 }
360
361 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
362 if err != nil {
363 dc.logger.Printf("failed authentication for %q: %v", username, err)
364 return errAuthFailed
365 }
366
367 var network *network
368 if networkName != "" {
369 network = u.getNetwork(networkName)
370 if network == nil {
371 addr := networkName
372 if !strings.ContainsRune(addr, ':') {
373 addr = addr + ":6697"
374 }
375
376 dc.logger.Printf("trying to connect to new network %q", addr)
377 if err := sanityCheckServer(addr); err != nil {
378 dc.logger.Printf("failed to connect to %q: %v", addr, err)
379 return ircError{&irc.Message{
380 Command: irc.ERR_PASSWDMISMATCH,
381 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
382 }}
383 }
384
385 dc.logger.Printf("auto-saving network %q", networkName)
386 network, err = u.createNetwork(networkName, dc.nick)
387 if err != nil {
388 return err
389 }
390 }
391 }
392
393 dc.registered = true
394 dc.user = u
395 dc.network = network
396
397 u.lock.Lock()
398 firstDownstream := len(u.downstreamConns) == 0
399 u.downstreamConns = append(u.downstreamConns, dc)
400 u.lock.Unlock()
401
402 dc.SendMessage(&irc.Message{
403 Prefix: dc.srv.prefix(),
404 Command: irc.RPL_WELCOME,
405 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
406 })
407 dc.SendMessage(&irc.Message{
408 Prefix: dc.srv.prefix(),
409 Command: irc.RPL_YOURHOST,
410 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
411 })
412 dc.SendMessage(&irc.Message{
413 Prefix: dc.srv.prefix(),
414 Command: irc.RPL_CREATED,
415 Params: []string{dc.nick, "Who cares when the server was created?"},
416 })
417 dc.SendMessage(&irc.Message{
418 Prefix: dc.srv.prefix(),
419 Command: irc.RPL_MYINFO,
420 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
421 })
422 // TODO: RPL_ISUPPORT
423 dc.SendMessage(&irc.Message{
424 Prefix: dc.srv.prefix(),
425 Command: irc.ERR_NOMOTD,
426 Params: []string{dc.nick, "No MOTD"},
427 })
428
429 dc.forEachUpstream(func(uc *upstreamConn) {
430 // TODO: fix races accessing upstream connection data
431 for _, ch := range uc.channels {
432 if ch.complete {
433 forwardChannel(dc, ch)
434 }
435 }
436
437 historyName := dc.username
438
439 var seqPtr *uint64
440 if firstDownstream {
441 seq, ok := uc.history[historyName]
442 if ok {
443 seqPtr = &seq
444 }
445 }
446
447 consumer, ch := uc.ring.NewConsumer(seqPtr)
448 go func() {
449 for {
450 var closed bool
451 select {
452 case <-ch:
453 dc.consumptions <- consumption{consumer, uc}
454 case <-dc.closed:
455 closed = true
456 }
457 if closed {
458 break
459 }
460 }
461
462 seq := consumer.Close()
463
464 dc.user.lock.Lock()
465 lastDownstream := len(dc.user.downstreamConns) == 0
466 dc.user.lock.Unlock()
467
468 if lastDownstream {
469 uc.history[historyName] = seq
470 }
471 }()
472 })
473
474 return nil
475}
476
477func (dc *downstreamConn) runUntilRegistered() error {
478 for !dc.registered {
479 msg, err := dc.irc.ReadMessage()
480 if err == io.EOF {
481 break
482 } else if err != nil {
483 return fmt.Errorf("failed to read IRC command: %v", err)
484 }
485
486 err = dc.handleMessage(msg)
487 if ircErr, ok := err.(ircError); ok {
488 ircErr.Message.Prefix = dc.srv.prefix()
489 dc.SendMessage(ircErr.Message)
490 } else if err != nil {
491 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
492 }
493 }
494
495 return nil
496}
497
498func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
499 switch msg.Command {
500 case "USER":
501 return ircError{&irc.Message{
502 Command: irc.ERR_ALREADYREGISTERED,
503 Params: []string{dc.nick, "You may not reregister"},
504 }}
505 case "NICK":
506 var nick string
507 if err := parseMessageParams(msg, &nick); err != nil {
508 return err
509 }
510
511 var err error
512 dc.forEachNetwork(func(n *network) {
513 if err != nil {
514 return
515 }
516 n.Nick = nick
517 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
518 })
519 if err != nil {
520 return err
521 }
522
523 dc.forEachUpstream(func(uc *upstreamConn) {
524 uc.SendMessage(msg)
525 })
526 case "JOIN", "PART":
527 var name string
528 if err := parseMessageParams(msg, &name); err != nil {
529 return err
530 }
531
532 uc, upstreamName, err := dc.unmarshalChannel(name)
533 if err != nil {
534 return ircError{&irc.Message{
535 Command: irc.ERR_NOSUCHCHANNEL,
536 Params: []string{name, err.Error()},
537 }}
538 }
539
540 uc.SendMessage(&irc.Message{
541 Command: msg.Command,
542 Params: []string{upstreamName},
543 })
544
545 switch msg.Command {
546 case "JOIN":
547 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
548 Name: upstreamName,
549 })
550 if err != nil {
551 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
552 }
553 case "PART":
554 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
555 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
556 }
557 }
558 case "MODE":
559 if msg.Prefix == nil {
560 return fmt.Errorf("missing prefix")
561 }
562
563 var name string
564 if err := parseMessageParams(msg, &name); err != nil {
565 return err
566 }
567
568 var modeStr string
569 if len(msg.Params) > 1 {
570 modeStr = msg.Params[1]
571 }
572
573 if msg.Prefix.Name != name {
574 uc, upstreamName, err := dc.unmarshalChannel(name)
575 if err != nil {
576 return err
577 }
578
579 if modeStr != "" {
580 uc.SendMessage(&irc.Message{
581 Command: "MODE",
582 Params: []string{upstreamName, modeStr},
583 })
584 } else {
585 ch, ok := uc.channels[upstreamName]
586 if !ok {
587 return ircError{&irc.Message{
588 Command: irc.ERR_NOSUCHCHANNEL,
589 Params: []string{name, "No such channel"},
590 }}
591 }
592
593 dc.SendMessage(&irc.Message{
594 Prefix: dc.srv.prefix(),
595 Command: irc.RPL_CHANNELMODEIS,
596 Params: []string{name, string(ch.modes)},
597 })
598 }
599 } else {
600 if name != dc.nick {
601 return ircError{&irc.Message{
602 Command: irc.ERR_USERSDONTMATCH,
603 Params: []string{dc.nick, "Cannot change mode for other users"},
604 }}
605 }
606
607 if modeStr != "" {
608 dc.forEachUpstream(func(uc *upstreamConn) {
609 uc.SendMessage(&irc.Message{
610 Command: "MODE",
611 Params: []string{uc.nick, modeStr},
612 })
613 })
614 } else {
615 dc.SendMessage(&irc.Message{
616 Prefix: dc.srv.prefix(),
617 Command: irc.RPL_UMODEIS,
618 Params: []string{""}, // TODO
619 })
620 }
621 }
622 case "PRIVMSG":
623 var targetsStr, text string
624 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
625 return err
626 }
627
628 for _, name := range strings.Split(targetsStr, ",") {
629 uc, upstreamName, err := dc.unmarshalChannel(name)
630 if err != nil {
631 return err
632 }
633
634 if upstreamName == "NickServ" {
635 dc.handleNickServPRIVMSG(uc, text)
636 }
637
638 uc.SendMessage(&irc.Message{
639 Command: "PRIVMSG",
640 Params: []string{upstreamName, text},
641 })
642 }
643 default:
644 dc.logger.Printf("unhandled message: %v", msg)
645 return newUnknownCommandError(msg.Command)
646 }
647 return nil
648}
649
650func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
651 username, password, ok := parseNickServCredentials(text, uc.nick)
652 if !ok {
653 return
654 }
655
656 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
657 n := uc.network
658 n.SASL.Mechanism = "PLAIN"
659 n.SASL.Plain.Username = username
660 n.SASL.Plain.Password = password
661 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
662 dc.logger.Printf("failed to save NickServ credentials: %v", err)
663 }
664}
665
// parseNickServCredentials extracts NickServ credentials from the text of a
// PRIVMSG sent to NickServ. nick is used as the account name when the command
// does not name one. Commands are matched case-insensitively:
//
//	REGISTER <password> [...]           -> (nick, password)
//	IDENTIFY <password>                 -> (nick, password)
//	IDENTIFY <account> <password> [...] -> (account, password)
//
// ok is false for any other command and when no password is given. Previously,
// unrecognized commands returned empty credentials with ok=true, and the
// one-argument IDENTIFY form panicked with an index out of range.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.