source: code/trunk/downstream.go@ 103

Last change on this file since 103 was 103, checked in by contact, 5 years ago

Per-user dispatcher goroutine

This allows message handlers to read upstream/downstream connection
information without causing any race condition.

References: https://todo.sr.ht/~emersion/soju/1

File size: 14.9 KB
RevLine 
[98]1package soju
[13]2
3import (
[91]4 "crypto/tls"
[13]5 "fmt"
6 "io"
7 "net"
[39]8 "strings"
[91]9 "time"
[13]10
[85]11 "golang.org/x/crypto/bcrypt"
[13]12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
[85]19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
[13]23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
[85]45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
[13]49
[69]50type consumption struct {
51 consumer *RingConsumer
52 upstreamConn *upstreamConn
53}
54
[13]55type downstreamConn struct {
[69]56 net net.Conn
57 irc *irc.Conn
58 srv *Server
59 logger Logger
[102]60 outgoing chan *irc.Message
[69]61 consumptions chan consumption
62 closed chan struct{}
[22]63
[100]64 registered bool
65 user *user
66 nick string
67 username string
68 rawUsername string
69 realname string
70 password string // empty after authentication
71 network *network // can be nil
[13]72}
73
[22]74func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
[55]75 dc := &downstreamConn{
[69]76 net: netConn,
77 irc: irc.NewConn(netConn),
78 srv: srv,
79 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
[102]80 outgoing: make(chan *irc.Message, 64),
[69]81 consumptions: make(chan consumption),
82 closed: make(chan struct{}),
[22]83 }
[26]84
85 go func() {
[56]86 if err := dc.writeMessages(); err != nil {
87 dc.logger.Printf("failed to write message: %v", err)
[26]88 }
[55]89 if err := dc.net.Close(); err != nil {
90 dc.logger.Printf("failed to close connection: %v", err)
[45]91 } else {
[55]92 dc.logger.Printf("connection closed")
[45]93 }
[26]94 }()
95
[55]96 return dc
[22]97}
98
[55]99func (dc *downstreamConn) prefix() *irc.Prefix {
[27]100 return &irc.Prefix{
[55]101 Name: dc.nick,
102 User: dc.username,
[27]103 // TODO: fill the host?
104 }
105}
106
[69]107func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
108 return name
109}
110
[90]111func (dc *downstreamConn) forEachNetwork(f func(*network)) {
112 if dc.network != nil {
113 f(dc.network)
114 } else {
115 dc.user.forEachNetwork(f)
116 }
117}
118
[73]119func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
120 dc.user.forEachUpstream(func(uc *upstreamConn) {
[77]121 if dc.network != nil && uc.network != dc.network {
[73]122 return
123 }
124 f(uc)
125 })
126}
127
[89]128// upstream returns the upstream connection, if any. If there are zero or if
129// there are multiple upstream connections, it returns nil.
130func (dc *downstreamConn) upstream() *upstreamConn {
131 if dc.network == nil {
132 return nil
133 }
134
135 var upstream *upstreamConn
136 dc.forEachUpstream(func(uc *upstreamConn) {
137 upstream = uc
138 })
139 return upstream
140}
141
[69]142func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
[89]143 if uc := dc.upstream(); uc != nil {
144 return uc, name, nil
145 }
146
[73]147 // TODO: extract network name from channel name if dc.upstream == nil
148 var channel *upstreamChannel
149 var err error
150 dc.forEachUpstream(func(uc *upstreamConn) {
151 if err != nil {
152 return
153 }
154 if ch, ok := uc.channels[name]; ok {
155 if channel != nil {
156 err = fmt.Errorf("ambiguous channel name %q", name)
157 } else {
158 channel = ch
159 }
160 }
161 })
162 if channel == nil {
163 return nil, "", ircError{&irc.Message{
164 Command: irc.ERR_NOSUCHCHANNEL,
165 Params: []string{name, "No such channel"},
166 }}
[69]167 }
[73]168 return channel.conn, channel.Name, nil
[69]169}
170
171func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
172 if nick == uc.nick {
173 return dc.nick
174 }
175 return nick
176}
177
178func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
179 if prefix.Name == uc.nick {
180 return dc.prefix()
181 }
182 return prefix
183}
184
[57]185func (dc *downstreamConn) isClosed() bool {
186 select {
187 case <-dc.closed:
188 return true
189 default:
190 return false
191 }
192}
193
[103]194func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
[55]195 dc.logger.Printf("new connection")
[22]196
197 for {
[55]198 msg, err := dc.irc.ReadMessage()
[22]199 if err == io.EOF {
200 break
201 } else if err != nil {
202 return fmt.Errorf("failed to read IRC command: %v", err)
203 }
204
[64]205 if dc.srv.Debug {
206 dc.logger.Printf("received: %v", msg)
207 }
208
[103]209 ch <- downstreamIncomingMessage{msg, dc}
[22]210 }
211
[45]212 return nil
[22]213}
214
[56]215func (dc *downstreamConn) writeMessages() error {
[57]216 for {
217 var err error
218 var closed bool
219 select {
[102]220 case msg := <-dc.outgoing:
[64]221 if dc.srv.Debug {
222 dc.logger.Printf("sent: %v", msg)
223 }
[57]224 err = dc.irc.WriteMessage(msg)
[69]225 case consumption := <-dc.consumptions:
226 consumer, uc := consumption.consumer, consumption.upstreamConn
[57]227 for {
228 msg := consumer.Peek()
229 if msg == nil {
230 break
231 }
[69]232 msg = msg.Copy()
233 switch msg.Command {
234 case "PRIVMSG":
235 // TODO: detect whether it's a user or a channel
236 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
237 default:
238 panic("expected to consume a PRIVMSG message")
239 }
[64]240 if dc.srv.Debug {
241 dc.logger.Printf("sent: %v", msg)
242 }
[57]243 err = dc.irc.WriteMessage(msg)
244 if err != nil {
245 break
246 }
247 consumer.Consume()
248 }
249 case <-dc.closed:
250 closed = true
251 }
252 if err != nil {
[56]253 return err
254 }
[57]255 if closed {
256 break
257 }
[56]258 }
259 return nil
260}
261
[55]262func (dc *downstreamConn) Close() error {
[57]263 if dc.isClosed() {
[26]264 return fmt.Errorf("downstream connection already closed")
265 }
[40]266
[55]267 if u := dc.user; u != nil {
[40]268 u.lock.Lock()
269 for i := range u.downstreamConns {
[55]270 if u.downstreamConns[i] == dc {
[40]271 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
[63]272 break
[40]273 }
274 }
275 u.lock.Unlock()
[13]276 }
[40]277
[57]278 close(dc.closed)
[45]279 return nil
[13]280}
281
[55]282func (dc *downstreamConn) SendMessage(msg *irc.Message) {
[102]283 dc.outgoing <- msg
[54]284}
285
[55]286func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
[13]287 switch msg.Command {
[28]288 case "QUIT":
[55]289 return dc.Close()
[13]290 case "PING":
[55]291 dc.SendMessage(&irc.Message{
292 Prefix: dc.srv.prefix(),
[13]293 Command: "PONG",
[68]294 Params: msg.Params,
[54]295 })
[26]296 return nil
[13]297 default:
[55]298 if dc.registered {
299 return dc.handleMessageRegistered(msg)
[13]300 } else {
[55]301 return dc.handleMessageUnregistered(msg)
[13]302 }
303 }
304}
305
[55]306func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
[13]307 switch msg.Command {
308 case "NICK":
[55]309 if err := parseMessageParams(msg, &dc.nick); err != nil {
[43]310 return err
[13]311 }
312 case "USER":
[43]313 var username string
[55]314 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
[43]315 return err
[13]316 }
[100]317 dc.rawUsername = username
[85]318 case "PASS":
319 if err := parseMessageParams(msg, &dc.password); err != nil {
320 return err
321 }
[13]322 default:
[55]323 dc.logger.Printf("unhandled message: %v", msg)
[13]324 return newUnknownCommandError(msg.Command)
325 }
[100]326 if dc.rawUsername != "" && dc.nick != "" {
[55]327 return dc.register()
[13]328 }
329 return nil
330}
331
// sanityCheckServer checks that a TLS connection (including the handshake,
// with the default configuration) can be established to addr, a host:port
// string. The dialer timeout is 30 seconds. The connection is closed right
// away.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	tlsConn, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err == nil {
		err = tlsConn.Close()
	}
	return err
}
340
[55]341func (dc *downstreamConn) register() error {
[100]342 username := dc.rawUsername
[77]343 var networkName string
[73]344 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
[77]345 networkName = username[i+1:]
[73]346 }
347 if i := strings.IndexAny(username, "/@"); i >= 0 {
348 username = username[:i]
349 }
[100]350 dc.username = "~" + username
[73]351
[85]352 password := dc.password
353 dc.password = ""
354
[73]355 u := dc.srv.getUser(username)
[38]356 if u == nil {
[85]357 dc.logger.Printf("failed authentication for %q: unknown username", username)
358 return errAuthFailed
[37]359 }
360
[85]361 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
362 if err != nil {
363 dc.logger.Printf("failed authentication for %q: %v", username, err)
364 return errAuthFailed
365 }
366
[88]367 var network *network
[77]368 if networkName != "" {
[88]369 network = u.getNetwork(networkName)
370 if network == nil {
[91]371 addr := networkName
372 if !strings.ContainsRune(addr, ':') {
373 addr = addr + ":6697"
374 }
375
[95]376 dc.logger.Printf("trying to connect to new network %q", addr)
[91]377 if err := sanityCheckServer(addr); err != nil {
378 dc.logger.Printf("failed to connect to %q: %v", addr, err)
379 return ircError{&irc.Message{
380 Command: irc.ERR_PASSWDMISMATCH,
381 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
382 }}
383 }
384
[95]385 dc.logger.Printf("auto-saving network %q", networkName)
[91]386 network, err = u.createNetwork(networkName, dc.nick)
387 if err != nil {
388 return err
389 }
[73]390 }
391 }
392
[55]393 dc.registered = true
394 dc.user = u
[88]395 dc.network = network
[13]396
[40]397 u.lock.Lock()
[57]398 firstDownstream := len(u.downstreamConns) == 0
[55]399 u.downstreamConns = append(u.downstreamConns, dc)
[40]400 u.lock.Unlock()
401
[55]402 dc.SendMessage(&irc.Message{
403 Prefix: dc.srv.prefix(),
[13]404 Command: irc.RPL_WELCOME,
[98]405 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
[54]406 })
[55]407 dc.SendMessage(&irc.Message{
408 Prefix: dc.srv.prefix(),
[13]409 Command: irc.RPL_YOURHOST,
[55]410 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
[54]411 })
[55]412 dc.SendMessage(&irc.Message{
413 Prefix: dc.srv.prefix(),
[13]414 Command: irc.RPL_CREATED,
[55]415 Params: []string{dc.nick, "Who cares when the server was created?"},
[54]416 })
[55]417 dc.SendMessage(&irc.Message{
418 Prefix: dc.srv.prefix(),
[13]419 Command: irc.RPL_MYINFO,
[98]420 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
[54]421 })
[93]422 // TODO: RPL_ISUPPORT
[55]423 dc.SendMessage(&irc.Message{
424 Prefix: dc.srv.prefix(),
[13]425 Command: irc.ERR_NOMOTD,
[55]426 Params: []string{dc.nick, "No MOTD"},
[54]427 })
[13]428
[73]429 dc.forEachUpstream(func(uc *upstreamConn) {
[30]430 // TODO: fix races accessing upstream connection data
431 for _, ch := range uc.channels {
432 if ch.complete {
[55]433 forwardChannel(dc, ch)
[30]434 }
435 }
[50]436
[73]437 historyName := dc.username
[57]438
439 var seqPtr *uint64
440 if firstDownstream {
441 seq, ok := uc.history[historyName]
442 if ok {
443 seqPtr = &seq
[50]444 }
445 }
[57]446
[59]447 consumer, ch := uc.ring.NewConsumer(seqPtr)
[57]448 go func() {
449 for {
450 var closed bool
451 select {
452 case <-ch:
[69]453 dc.consumptions <- consumption{consumer, uc}
[57]454 case <-dc.closed:
455 closed = true
456 }
457 if closed {
458 break
459 }
460 }
461
462 seq := consumer.Close()
463
464 dc.user.lock.Lock()
465 lastDownstream := len(dc.user.downstreamConns) == 0
466 dc.user.lock.Unlock()
467
468 if lastDownstream {
469 uc.history[historyName] = seq
470 }
471 }()
[39]472 })
[50]473
[13]474 return nil
475}
476
[103]477func (dc *downstreamConn) runUntilRegistered() error {
478 for !dc.registered {
479 msg, err := dc.irc.ReadMessage()
480 if err == io.EOF {
481 break
482 } else if err != nil {
483 return fmt.Errorf("failed to read IRC command: %v", err)
484 }
485
486 err = dc.handleMessage(msg)
487 if ircErr, ok := err.(ircError); ok {
488 ircErr.Message.Prefix = dc.srv.prefix()
489 dc.SendMessage(ircErr.Message)
490 } else if err != nil {
491 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
492 }
493 }
494
495 return nil
496}
497
[55]498func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
[13]499 switch msg.Command {
[42]500 case "USER":
[13]501 return ircError{&irc.Message{
502 Command: irc.ERR_ALREADYREGISTERED,
[55]503 Params: []string{dc.nick, "You may not reregister"},
[13]504 }}
[42]505 case "NICK":
[90]506 var nick string
507 if err := parseMessageParams(msg, &nick); err != nil {
508 return err
509 }
510
511 var err error
512 dc.forEachNetwork(func(n *network) {
513 if err != nil {
514 return
515 }
516 n.Nick = nick
517 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
518 })
519 if err != nil {
520 return err
521 }
522
[73]523 dc.forEachUpstream(func(uc *upstreamConn) {
[60]524 uc.SendMessage(msg)
[42]525 })
[69]526 case "JOIN", "PART":
[48]527 var name string
528 if err := parseMessageParams(msg, &name); err != nil {
529 return err
530 }
531
[69]532 uc, upstreamName, err := dc.unmarshalChannel(name)
533 if err != nil {
534 return ircError{&irc.Message{
535 Command: irc.ERR_NOSUCHCHANNEL,
536 Params: []string{name, err.Error()},
537 }}
[48]538 }
539
[69]540 uc.SendMessage(&irc.Message{
541 Command: msg.Command,
542 Params: []string{upstreamName},
543 })
[89]544
545 switch msg.Command {
546 case "JOIN":
547 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
548 Name: upstreamName,
549 })
550 if err != nil {
551 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
552 }
553 case "PART":
554 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
555 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
556 }
557 }
[69]558 case "MODE":
559 if msg.Prefix == nil {
560 return fmt.Errorf("missing prefix")
[49]561 }
562
[46]563 var name string
564 if err := parseMessageParams(msg, &name); err != nil {
565 return err
566 }
567
568 var modeStr string
569 if len(msg.Params) > 1 {
570 modeStr = msg.Params[1]
571 }
572
573 if msg.Prefix.Name != name {
[69]574 uc, upstreamName, err := dc.unmarshalChannel(name)
[46]575 if err != nil {
576 return err
577 }
578
579 if modeStr != "" {
[69]580 uc.SendMessage(&irc.Message{
581 Command: "MODE",
582 Params: []string{upstreamName, modeStr},
583 })
[46]584 } else {
[69]585 ch, ok := uc.channels[upstreamName]
586 if !ok {
587 return ircError{&irc.Message{
588 Command: irc.ERR_NOSUCHCHANNEL,
589 Params: []string{name, "No such channel"},
590 }}
591 }
592
[55]593 dc.SendMessage(&irc.Message{
594 Prefix: dc.srv.prefix(),
[46]595 Command: irc.RPL_CHANNELMODEIS,
[69]596 Params: []string{name, string(ch.modes)},
[54]597 })
[46]598 }
599 } else {
[55]600 if name != dc.nick {
[46]601 return ircError{&irc.Message{
602 Command: irc.ERR_USERSDONTMATCH,
[55]603 Params: []string{dc.nick, "Cannot change mode for other users"},
[46]604 }}
605 }
606
607 if modeStr != "" {
[73]608 dc.forEachUpstream(func(uc *upstreamConn) {
[69]609 uc.SendMessage(&irc.Message{
610 Command: "MODE",
611 Params: []string{uc.nick, modeStr},
612 })
[46]613 })
614 } else {
[55]615 dc.SendMessage(&irc.Message{
616 Prefix: dc.srv.prefix(),
[46]617 Command: irc.RPL_UMODEIS,
618 Params: []string{""}, // TODO
[54]619 })
[46]620 }
621 }
[58]622 case "PRIVMSG":
623 var targetsStr, text string
624 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
625 return err
626 }
627
628 for _, name := range strings.Split(targetsStr, ",") {
[69]629 uc, upstreamName, err := dc.unmarshalChannel(name)
[58]630 if err != nil {
631 return err
632 }
633
[95]634 if upstreamName == "NickServ" {
635 dc.handleNickServPRIVMSG(uc, text)
636 }
637
[69]638 uc.SendMessage(&irc.Message{
[58]639 Command: "PRIVMSG",
[69]640 Params: []string{upstreamName, text},
[60]641 })
[58]642 }
[13]643 default:
[55]644 dc.logger.Printf("unhandled message: %v", msg)
[13]645 return newUnknownCommandError(msg.Command)
646 }
[42]647 return nil
[13]648}
[95]649
650func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
651 username, password, ok := parseNickServCredentials(text, uc.nick)
652 if !ok {
653 return
654 }
655
656 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
657 n := uc.network
658 n.SASL.Mechanism = "PLAIN"
659 n.SASL.Plain.Username = username
660 n.SASL.Plain.Password = password
661 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
662 dc.logger.Printf("failed to save NickServ credentials: %v", err)
663 }
664}
665
// parseNickServCredentials extracts account credentials from the text of a
// message sent to NickServ. nick is the user's current nick. Recognized
// commands (case-insensitive):
//
//	REGISTER <password> ...              -> (nick, password)
//	IDENTIFY <password>                  -> (nick, password)
//	IDENTIFY <account> <password> ...    -> (account, password)
//
// ok is false for any other command or when the password is missing. This
// stops unrelated commands (e.g. "HELP IDENTIFY") from producing empty
// credentials.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			// Only the password was given; it applies to the current nick.
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		return "", "", false
	}
	return username, password, true
}
686}
Note: See TracBrowser for help on using the repository browser.