source: code/trunk/downstream.go@ 209

Last change on this file since 209 was 209, checked in by contact, 5 years ago

Fix writer goroutine races

Any SendMessage call after Close could potentially block forever if the
outgoing channel was filled up. Now the channel is drained before the
writer goroutine exits.

File size: 32.0 KB
RevLine 
[98]1package soju
[13]2
3import (
[91]4 "crypto/tls"
[112]5 "encoding/base64"
[13]6 "fmt"
7 "io"
8 "net"
[108]9 "strconv"
[39]10 "strings"
[105]11 "sync"
[91]12 "time"
[13]13
[112]14 "github.com/emersion/go-sasl"
[85]15 "golang.org/x/crypto/bcrypt"
[13]16 "gopkg.in/irc.v3"
17)
18
19type ircError struct {
20 Message *irc.Message
21}
22
[85]23func (err ircError) Error() string {
24 return err.Message.String()
25}
26
[13]27func newUnknownCommandError(cmd string) ircError {
28 return ircError{&irc.Message{
29 Command: irc.ERR_UNKNOWNCOMMAND,
30 Params: []string{
31 "*",
32 cmd,
33 "Unknown command",
34 },
35 }}
36}
37
38func newNeedMoreParamsError(cmd string) ircError {
39 return ircError{&irc.Message{
40 Command: irc.ERR_NEEDMOREPARAMS,
41 Params: []string{
42 "*",
43 cmd,
44 "Not enough parameters",
45 },
46 }}
47}
48
[85]49var errAuthFailed = ircError{&irc.Message{
50 Command: irc.ERR_PASSWDMISMATCH,
51 Params: []string{"*", "Invalid username or password"},
52}}
[13]53
// downstreamConn is a connection from an IRC client to the bouncer.
type downstreamConn struct {
	id       uint64
	net      net.Conn
	irc      *irc.Conn
	srv      *Server
	logger   Logger
	outgoing chan<- *irc.Message // queue drained by the writer goroutine; closed by Close
	closed   chan struct{}       // closed by Close; checked via isClosed

	registered  bool // set by register once registration has completed
	user        *user
	nick        string
	rawUsername string // username as received in USER, before unmarshalUsername
	networkName string // network part parsed from the username, may be empty
	clientName  string // client part parsed from the username, may be empty
	realname    string
	hostname    string // client's remote address, with the port stripped when possible
	password    string   // empty after authentication
	network     *network // can be nil (client not bound to a single network)

	// One ring buffer consumer per network, registered by runNetwork.
	ringConsumers map[*network]*RingConsumer

	// True between CAP LS and CAP END before registration; delays register.
	negociatingCaps bool
	capVersion      int             // version passed to CAP LS, 0 if none
	caps            map[string]bool // capabilities toggled via CAP REQ

	saslServer sasl.Server // in-progress SASL exchange, nil when none

	lock sync.Mutex // guards ourMessages
	// Messages produced by this connection itself; runNetwork skips them
	// instead of echoing them back to the client.
	ourMessages map[*irc.Message]struct{}
}
85
[154]86func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn {
[209]87 outgoing := make(chan *irc.Message, 64)
[55]88 dc := &downstreamConn{
[204]89 id: id,
90 net: netConn,
91 irc: irc.NewConn(netConn),
92 srv: srv,
93 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
[209]94 outgoing: outgoing,
[204]95 closed: make(chan struct{}),
96 ringConsumers: make(map[*network]*RingConsumer),
97 caps: make(map[string]bool),
98 ourMessages: make(map[*irc.Message]struct{}),
[22]99 }
[141]100 dc.hostname = netConn.RemoteAddr().String()
101 if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
102 dc.hostname = host
103 }
[26]104
105 go func() {
[209]106 for msg := range outgoing {
107 if dc.srv.Debug {
108 dc.logger.Printf("sent: %v", msg)
109 }
110 dc.net.SetWriteDeadline(time.Now().Add(writeTimeout))
111 if err := dc.irc.WriteMessage(msg); err != nil {
112 dc.logger.Printf("failed to write message: %v", err)
113 break
114 }
[26]115 }
[55]116 if err := dc.net.Close(); err != nil {
117 dc.logger.Printf("failed to close connection: %v", err)
[45]118 } else {
[55]119 dc.logger.Printf("connection closed")
[45]120 }
[209]121 // Drain the outgoing channel to prevent SendMessage from blocking
122 for range outgoing {
123 // This space is intentionally left blank
124 }
[26]125 }()
126
[130]127 dc.logger.Printf("new connection")
[55]128 return dc
[22]129}
130
[55]131func (dc *downstreamConn) prefix() *irc.Prefix {
[27]132 return &irc.Prefix{
[55]133 Name: dc.nick,
[184]134 User: dc.user.Username,
[141]135 Host: dc.hostname,
[27]136 }
137}
138
[90]139func (dc *downstreamConn) forEachNetwork(f func(*network)) {
140 if dc.network != nil {
141 f(dc.network)
142 } else {
143 dc.user.forEachNetwork(f)
144 }
145}
146
[73]147func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
148 dc.user.forEachUpstream(func(uc *upstreamConn) {
[77]149 if dc.network != nil && uc.network != dc.network {
[73]150 return
151 }
152 f(uc)
153 })
154}
155
[89]156// upstream returns the upstream connection, if any. If there are zero or if
157// there are multiple upstream connections, it returns nil.
158func (dc *downstreamConn) upstream() *upstreamConn {
159 if dc.network == nil {
160 return nil
161 }
[136]162 return dc.network.upstream()
[89]163}
164
[129]165func (dc *downstreamConn) marshalEntity(uc *upstreamConn, entity string) string {
166 if uc.isChannel(entity) {
167 return dc.marshalChannel(uc, entity)
[119]168 }
[129]169 return dc.marshalNick(uc, entity)
[119]170}
171
172func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
[130]173 if dc.network != nil {
[119]174 return name
175 }
176 return name + "/" + uc.network.GetName()
177}
178
[127]179func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
[89]180 if uc := dc.upstream(); uc != nil {
181 return uc, name, nil
182 }
183
[127]184 var conn *upstreamConn
[119]185 if i := strings.LastIndexByte(name, '/'); i >= 0 {
[127]186 network := name[i+1:]
[119]187 name = name[:i]
188
189 dc.forEachUpstream(func(uc *upstreamConn) {
190 if network != uc.network.GetName() {
191 return
192 }
193 conn = uc
194 })
195 }
196
[127]197 if conn == nil {
[73]198 return nil, "", ircError{&irc.Message{
199 Command: irc.ERR_NOSUCHCHANNEL,
200 Params: []string{name, "No such channel"},
201 }}
[69]202 }
[127]203 return conn, name, nil
[69]204}
205
206func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
207 if nick == uc.nick {
208 return dc.nick
209 }
[130]210 if dc.network != nil {
[119]211 return nick
212 }
213 return nick + "/" + uc.network.GetName()
[69]214}
215
216func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
217 if prefix.Name == uc.nick {
218 return dc.prefix()
219 }
[130]220 if dc.network != nil {
[119]221 return prefix
222 }
223 return &irc.Prefix{
224 Name: prefix.Name + "/" + uc.network.GetName(),
225 User: prefix.User,
226 Host: prefix.Host,
227 }
[69]228}
229
[57]230func (dc *downstreamConn) isClosed() bool {
231 select {
232 case <-dc.closed:
233 return true
234 default:
235 return false
236 }
237}
238
[165]239func (dc *downstreamConn) readMessages(ch chan<- event) error {
[22]240 for {
[55]241 msg, err := dc.irc.ReadMessage()
[22]242 if err == io.EOF {
243 break
244 } else if err != nil {
245 return fmt.Errorf("failed to read IRC command: %v", err)
246 }
247
[64]248 if dc.srv.Debug {
249 dc.logger.Printf("received: %v", msg)
250 }
251
[165]252 ch <- eventDownstreamMessage{msg, dc}
[22]253 }
254
[45]255 return nil
[22]256}
257
[56]258func (dc *downstreamConn) writeMessages() error {
259 return nil
260}
261
[180]262// Close closes the connection. It is safe to call from any goroutine.
[55]263func (dc *downstreamConn) Close() error {
[57]264 if dc.isClosed() {
[26]265 return fmt.Errorf("downstream connection already closed")
266 }
[57]267 close(dc.closed)
[209]268 close(dc.outgoing)
[45]269 return nil
[13]270}
271
[180]272// SendMessage queues a new outgoing message. It is safe to call from any
273// goroutine.
[55]274func (dc *downstreamConn) SendMessage(msg *irc.Message) {
[209]275 if dc.isClosed() {
276 return
277 }
[191]278 // TODO: strip tags if the client doesn't support them (see runNetwork)
[102]279 dc.outgoing <- msg
[54]280}
281
[55]282func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
[13]283 switch msg.Command {
[28]284 case "QUIT":
[55]285 return dc.Close()
[13]286 default:
[55]287 if dc.registered {
288 return dc.handleMessageRegistered(msg)
[13]289 } else {
[55]290 return dc.handleMessageUnregistered(msg)
[13]291 }
292 }
293}
294
[55]295func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
[13]296 switch msg.Command {
297 case "NICK":
[117]298 var nick string
299 if err := parseMessageParams(msg, &nick); err != nil {
[43]300 return err
[13]301 }
[117]302 if nick == serviceNick {
303 return ircError{&irc.Message{
304 Command: irc.ERR_NICKNAMEINUSE,
305 Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
306 }}
307 }
308 dc.nick = nick
[13]309 case "USER":
[117]310 if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
[43]311 return err
[13]312 }
[85]313 case "PASS":
314 if err := parseMessageParams(msg, &dc.password); err != nil {
315 return err
316 }
[108]317 case "CAP":
318 var subCmd string
319 if err := parseMessageParams(msg, &subCmd); err != nil {
320 return err
321 }
322 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
323 return err
324 }
[112]325 case "AUTHENTICATE":
326 if !dc.caps["sasl"] {
327 return ircError{&irc.Message{
[125]328 Command: irc.ERR_SASLFAIL,
[112]329 Params: []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
330 }}
331 }
332 if len(msg.Params) == 0 {
333 return ircError{&irc.Message{
[125]334 Command: irc.ERR_SASLFAIL,
[112]335 Params: []string{"*", "Missing AUTHENTICATE argument"},
336 }}
337 }
338 if dc.nick == "" {
339 return ircError{&irc.Message{
[125]340 Command: irc.ERR_SASLFAIL,
[112]341 Params: []string{"*", "Expected NICK command before AUTHENTICATE"},
342 }}
343 }
344
345 var resp []byte
346 if dc.saslServer == nil {
347 mech := strings.ToUpper(msg.Params[0])
348 switch mech {
349 case "PLAIN":
350 dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
351 return dc.authenticate(username, password)
352 }))
353 default:
354 return ircError{&irc.Message{
[125]355 Command: irc.ERR_SASLFAIL,
[112]356 Params: []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
357 }}
358 }
359 } else if msg.Params[0] == "*" {
360 dc.saslServer = nil
361 return ircError{&irc.Message{
[125]362 Command: irc.ERR_SASLABORTED,
[112]363 Params: []string{"*", "SASL authentication aborted"},
364 }}
365 } else if msg.Params[0] == "+" {
366 resp = nil
367 } else {
368 // TODO: multi-line messages
369 var err error
370 resp, err = base64.StdEncoding.DecodeString(msg.Params[0])
371 if err != nil {
372 dc.saslServer = nil
373 return ircError{&irc.Message{
[125]374 Command: irc.ERR_SASLFAIL,
[112]375 Params: []string{"*", "Invalid base64-encoded response"},
376 }}
377 }
378 }
379
380 challenge, done, err := dc.saslServer.Next(resp)
381 if err != nil {
382 dc.saslServer = nil
383 if ircErr, ok := err.(ircError); ok && ircErr.Message.Command == irc.ERR_PASSWDMISMATCH {
384 return ircError{&irc.Message{
[125]385 Command: irc.ERR_SASLFAIL,
[112]386 Params: []string{"*", ircErr.Message.Params[1]},
387 }}
388 }
389 dc.SendMessage(&irc.Message{
390 Prefix: dc.srv.prefix(),
[125]391 Command: irc.ERR_SASLFAIL,
[112]392 Params: []string{"*", "SASL error"},
393 })
394 return fmt.Errorf("SASL authentication failed: %v", err)
395 } else if done {
396 dc.saslServer = nil
397 dc.SendMessage(&irc.Message{
398 Prefix: dc.srv.prefix(),
[125]399 Command: irc.RPL_LOGGEDIN,
[112]400 Params: []string{dc.nick, dc.nick, dc.user.Username, "You are now logged in"},
401 })
402 dc.SendMessage(&irc.Message{
403 Prefix: dc.srv.prefix(),
[125]404 Command: irc.RPL_SASLSUCCESS,
[112]405 Params: []string{dc.nick, "SASL authentication successful"},
406 })
407 } else {
408 challengeStr := "+"
[135]409 if len(challenge) > 0 {
[112]410 challengeStr = base64.StdEncoding.EncodeToString(challenge)
411 }
412
413 // TODO: multi-line messages
414 dc.SendMessage(&irc.Message{
415 Prefix: dc.srv.prefix(),
416 Command: "AUTHENTICATE",
417 Params: []string{challengeStr},
418 })
419 }
[13]420 default:
[55]421 dc.logger.Printf("unhandled message: %v", msg)
[13]422 return newUnknownCommandError(msg.Command)
423 }
[108]424 if dc.rawUsername != "" && dc.nick != "" && !dc.negociatingCaps {
[55]425 return dc.register()
[13]426 }
427 return nil
428}
429
[108]430func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
[111]431 cmd = strings.ToUpper(cmd)
432
[108]433 replyTo := dc.nick
434 if !dc.registered {
435 replyTo = "*"
436 }
437
438 switch cmd {
439 case "LS":
440 if len(args) > 0 {
441 var err error
442 if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
443 return err
444 }
445 }
446
[194]447 caps := []string{"message-tags", "server-time"}
448
[112]449 if dc.capVersion >= 302 {
[108]450 caps = append(caps, "sasl=PLAIN")
451 } else {
452 caps = append(caps, "sasl")
[112]453 }
[108]454
455 // TODO: multi-line replies
456 dc.SendMessage(&irc.Message{
457 Prefix: dc.srv.prefix(),
458 Command: "CAP",
459 Params: []string{replyTo, "LS", strings.Join(caps, " ")},
460 })
461
462 if !dc.registered {
463 dc.negociatingCaps = true
464 }
465 case "LIST":
466 var caps []string
467 for name := range dc.caps {
468 caps = append(caps, name)
469 }
470
471 // TODO: multi-line replies
472 dc.SendMessage(&irc.Message{
473 Prefix: dc.srv.prefix(),
474 Command: "CAP",
475 Params: []string{replyTo, "LIST", strings.Join(caps, " ")},
476 })
477 case "REQ":
478 if len(args) == 0 {
479 return ircError{&irc.Message{
480 Command: err_invalidcapcmd,
481 Params: []string{replyTo, cmd, "Missing argument in CAP REQ command"},
482 }}
483 }
484
485 caps := strings.Fields(args[0])
486 ack := true
487 for _, name := range caps {
488 name = strings.ToLower(name)
489 enable := !strings.HasPrefix(name, "-")
490 if !enable {
491 name = strings.TrimPrefix(name, "-")
492 }
493
494 enabled := dc.caps[name]
495 if enable == enabled {
496 continue
497 }
498
499 switch name {
[194]500 case "sasl", "message-tags", "server-time":
[112]501 dc.caps[name] = enable
[108]502 default:
503 ack = false
504 }
505 }
506
507 reply := "NAK"
508 if ack {
509 reply = "ACK"
510 }
511 dc.SendMessage(&irc.Message{
512 Prefix: dc.srv.prefix(),
513 Command: "CAP",
514 Params: []string{replyTo, reply, args[0]},
515 })
516 case "END":
517 dc.negociatingCaps = false
518 default:
519 return ircError{&irc.Message{
520 Command: err_invalidcapcmd,
521 Params: []string{replyTo, cmd, "Unknown CAP command"},
522 }}
523 }
524 return nil
525}
526
// sanityCheckServer checks that a TLS connection to addr can be established
// within 30 seconds, then closes it.
func sanityCheckServer(addr string) error {
	var dialer net.Dialer
	dialer.Timeout = 30 * time.Second

	conn, err := tls.DialWithDialer(&dialer, "tcp", addr, nil)
	if err != nil {
		return err
	}
	return conn.Close()
}
535
// unmarshalUsername splits a raw username of the form
// "user[/network][@client]" (the two suffixes may appear in either order)
// into its parts. A suffix introduced by '@' is the client name and one
// introduced by '/' is the network name. Missing parts are returned empty.
func unmarshalUsername(rawUsername string) (username, client, network string) {
	first := strings.IndexAny(rawUsername, "/@")
	if first < 0 {
		return rawUsername, "", ""
	}
	last := strings.LastIndexAny(rawUsername, "/@")

	// assign stores value as the client or network name depending on the
	// separator that introduced it.
	assign := func(sep byte, value string) {
		if sep == '@' {
			client = value
		} else {
			network = value
		}
	}

	username = rawUsername[:first]
	assign(rawUsername[last], rawUsername[last+1:])
	if first < last {
		assign(rawUsername[first], rawUsername[first+1:last])
	}
	return username, client, network
}
[73]561
[168]562func (dc *downstreamConn) authenticate(username, password string) error {
[183]563 username, clientName, networkName := unmarshalUsername(username)
[168]564
[173]565 u, err := dc.srv.db.GetUser(username)
566 if err != nil {
567 dc.logger.Printf("failed authentication for %q: %v", username, err)
[168]568 return errAuthFailed
569 }
570
[173]571 err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
[168]572 if err != nil {
573 dc.logger.Printf("failed authentication for %q: %v", username, err)
574 return errAuthFailed
575 }
576
[173]577 dc.user = dc.srv.getUser(username)
578 if dc.user == nil {
579 dc.logger.Printf("failed authentication for %q: user not active", username)
580 return errAuthFailed
581 }
[183]582 dc.clientName = clientName
[168]583 dc.networkName = networkName
584 return nil
585}
586
587func (dc *downstreamConn) register() error {
588 if dc.registered {
589 return fmt.Errorf("tried to register twice")
590 }
591
592 password := dc.password
593 dc.password = ""
594 if dc.user == nil {
595 if err := dc.authenticate(dc.rawUsername, password); err != nil {
596 return err
597 }
598 }
599
[183]600 if dc.clientName == "" && dc.networkName == "" {
601 _, dc.clientName, dc.networkName = unmarshalUsername(dc.rawUsername)
[168]602 }
603
604 dc.registered = true
[184]605 dc.logger.Printf("registration complete for user %q", dc.user.Username)
[168]606 return nil
607}
608
609func (dc *downstreamConn) loadNetwork() error {
610 if dc.networkName == "" {
[112]611 return nil
612 }
[85]613
[168]614 network := dc.user.getNetwork(dc.networkName)
[112]615 if network == nil {
[168]616 addr := dc.networkName
[112]617 if !strings.ContainsRune(addr, ':') {
618 addr = addr + ":6697"
619 }
620
621 dc.logger.Printf("trying to connect to new network %q", addr)
622 if err := sanityCheckServer(addr); err != nil {
623 dc.logger.Printf("failed to connect to %q: %v", addr, err)
624 return ircError{&irc.Message{
625 Command: irc.ERR_PASSWDMISMATCH,
[168]626 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)},
[112]627 }}
628 }
629
[168]630 dc.logger.Printf("auto-saving network %q", dc.networkName)
[112]631 var err error
[120]632 network, err = dc.user.createNetwork(&Network{
[168]633 Addr: dc.networkName,
[120]634 Nick: dc.nick,
635 })
[112]636 if err != nil {
637 return err
638 }
639 }
640
641 dc.network = network
642 return nil
643}
644
[168]645func (dc *downstreamConn) welcome() error {
646 if dc.user == nil || !dc.registered {
647 panic("tried to welcome an unregistered connection")
[37]648 }
649
[168]650 // TODO: doing this might take some time. We should do it in dc.register
651 // instead, but we'll potentially be adding a new network and this must be
652 // done in the user goroutine.
653 if err := dc.loadNetwork(); err != nil {
654 return err
[85]655 }
656
[185]657 // Only send history if we're the first connected client with that name and
658 // network
659 sendHistory := true
660 dc.user.forEachDownstream(func(conn *downstreamConn) {
661 if dc.clientName == conn.clientName && dc.network == conn.network {
662 sendHistory = false
663 }
664 })
[40]665
[55]666 dc.SendMessage(&irc.Message{
667 Prefix: dc.srv.prefix(),
[13]668 Command: irc.RPL_WELCOME,
[98]669 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
[54]670 })
[55]671 dc.SendMessage(&irc.Message{
672 Prefix: dc.srv.prefix(),
[13]673 Command: irc.RPL_YOURHOST,
[55]674 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
[54]675 })
[55]676 dc.SendMessage(&irc.Message{
677 Prefix: dc.srv.prefix(),
[13]678 Command: irc.RPL_CREATED,
[55]679 Params: []string{dc.nick, "Who cares when the server was created?"},
[54]680 })
[55]681 dc.SendMessage(&irc.Message{
682 Prefix: dc.srv.prefix(),
[13]683 Command: irc.RPL_MYINFO,
[98]684 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
[54]685 })
[93]686 // TODO: RPL_ISUPPORT
[55]687 dc.SendMessage(&irc.Message{
688 Prefix: dc.srv.prefix(),
[13]689 Command: irc.ERR_NOMOTD,
[55]690 Params: []string{dc.nick, "No MOTD"},
[54]691 })
[13]692
[73]693 dc.forEachUpstream(func(uc *upstreamConn) {
[30]694 for _, ch := range uc.channels {
695 if ch.complete {
[132]696 dc.SendMessage(&irc.Message{
697 Prefix: dc.prefix(),
698 Command: "JOIN",
699 Params: []string{dc.marshalChannel(ch.conn, ch.Name)},
700 })
701
[55]702 forwardChannel(dc, ch)
[30]703 }
704 }
[143]705 })
[50]706
[143]707 dc.forEachNetwork(func(net *network) {
[185]708 dc.runNetwork(net, sendHistory)
[144]709 })
[57]710
[144]711 return nil
712}
713
// runNetwork starts listening for messages coming from the network's ring
// buffer.
//
// It registers a ring consumer for net in dc.ringConsumers and starts a
// goroutine that forwards the network's PRIVMSG messages to this client.
// Messages this connection produced itself (dc.ourMessages) are skipped.
// Prefixes and targets are translated with marshalUserPrefix/marshalEntity,
// and tags the client did not negotiate are stripped.
//
// When loadHistory is true and a position is recorded in net.history for
// this client name, the consumer starts from that position. Otherwise it is
// created with a nil position. NOTE(review): presumably nil means "start at
// the current end of the ring"; confirm against NewConsumer.
//
// It panics if the network is not suitable for the downstream connection,
// or if a consumer was already registered for net.
func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) {
	if dc.network != nil && net != dc.network {
		panic("network not suitable for downstream connection")
	}

	var seqPtr *uint64
	if loadHistory {
		seq, ok := net.history[dc.clientName]
		if ok {
			seqPtr = &seq
		}
	}

	// TODO: can't be enabled/disabled on-the-fly
	msgTagsEnabled := dc.caps["message-tags"]
	serverTimeEnabled := dc.caps["server-time"]

	consumer, ch := net.ring.NewConsumer(seqPtr)

	if _, ok := dc.ringConsumers[net]; ok {
		panic("network has been added twice")
	}
	dc.ringConsumers[net] = consumer

	go func() {
		// Values received on ch are ignored. Each one triggers draining
		// every message currently available through Peek/Consume.
		for range ch {
			uc := net.upstream()
			if uc == nil {
				// Messages are left unconsumed and retried on the next
				// notification.
				dc.logger.Printf("ignoring messages for upstream %q: upstream is disconnected", net.Addr)
				continue
			}

			for {
				msg := consumer.Peek()
				if msg == nil {
					break
				}

				// ourMessages marks each of our own messages once; the
				// entry is removed as it is seen.
				dc.lock.Lock()
				_, ours := dc.ourMessages[msg]
				delete(dc.ourMessages, msg)
				dc.lock.Unlock()
				if ours {
					// The message comes from our connection, don't echo it
					// back
					consumer.Consume()
					continue
				}

				// Work on a copy so the message stored in the ring is not
				// modified.
				msg = msg.Copy()
				switch msg.Command {
				case "PRIVMSG":
					msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix)
					msg.Params[0] = dc.marshalEntity(uc, msg.Params[0])
				default:
					panic("expected to consume a PRIVMSG message")
				}

				// Without message-tags, keep only tags the client enabled
				// individually: "time" when server-time is enabled.
				if !msgTagsEnabled {
					for name := range msg.Tags {
						supported := false
						switch name {
						case "time":
							supported = serverTimeEnabled
						}
						if !supported {
							delete(msg.Tags, name)
						}
					}
				}

				dc.SendMessage(msg)
				consumer.Consume()
			}
		}
	}()
}
795
[103]796func (dc *downstreamConn) runUntilRegistered() error {
797 for !dc.registered {
798 msg, err := dc.irc.ReadMessage()
[106]799 if err != nil {
[103]800 return fmt.Errorf("failed to read IRC command: %v", err)
801 }
802
[110]803 if dc.srv.Debug {
804 dc.logger.Printf("received: %v", msg)
805 }
806
[103]807 err = dc.handleMessage(msg)
808 if ircErr, ok := err.(ircError); ok {
809 ircErr.Message.Prefix = dc.srv.prefix()
810 dc.SendMessage(ircErr.Message)
811 } else if err != nil {
812 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
813 }
814 }
815
816 return nil
817}
818
[55]819func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
[13]820 switch msg.Command {
[111]821 case "CAP":
822 var subCmd string
823 if err := parseMessageParams(msg, &subCmd); err != nil {
824 return err
825 }
826 if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
827 return err
828 }
[107]829 case "PING":
830 dc.SendMessage(&irc.Message{
831 Prefix: dc.srv.prefix(),
832 Command: "PONG",
833 Params: msg.Params,
834 })
835 return nil
[42]836 case "USER":
[13]837 return ircError{&irc.Message{
838 Command: irc.ERR_ALREADYREGISTERED,
[55]839 Params: []string{dc.nick, "You may not reregister"},
[13]840 }}
[42]841 case "NICK":
[90]842 var nick string
843 if err := parseMessageParams(msg, &nick); err != nil {
844 return err
845 }
846
847 var err error
848 dc.forEachNetwork(func(n *network) {
849 if err != nil {
850 return
851 }
852 n.Nick = nick
853 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
854 })
855 if err != nil {
856 return err
857 }
858
[73]859 dc.forEachUpstream(func(uc *upstreamConn) {
[60]860 uc.SendMessage(msg)
[42]861 })
[146]862 case "JOIN":
863 var namesStr string
864 if err := parseMessageParams(msg, &namesStr); err != nil {
[48]865 return err
866 }
867
[146]868 var keys []string
869 if len(msg.Params) > 1 {
870 keys = strings.Split(msg.Params[1], ",")
871 }
872
873 for i, name := range strings.Split(namesStr, ",") {
[145]874 uc, upstreamName, err := dc.unmarshalEntity(name)
875 if err != nil {
[158]876 return err
[145]877 }
[48]878
[146]879 var key string
880 if len(keys) > i {
881 key = keys[i]
882 }
883
884 params := []string{upstreamName}
885 if key != "" {
886 params = append(params, key)
887 }
[145]888 uc.SendMessage(&irc.Message{
[146]889 Command: "JOIN",
890 Params: params,
[145]891 })
[89]892
[207]893 ch, err := dc.srv.db.GetChannel(uc.network.ID, upstreamName)
894 if err == ErrNoSuchChannel {
895 ch = &Channel{Name: upstreamName}
896 } else if err != nil {
897 return err
[89]898 }
[207]899
900 ch.Key = key
901
902 if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
903 return err
904 }
[89]905 }
[146]906 case "PART":
907 var namesStr string
908 if err := parseMessageParams(msg, &namesStr); err != nil {
909 return err
910 }
911
912 var reason string
913 if len(msg.Params) > 1 {
914 reason = msg.Params[1]
915 }
916
917 for _, name := range strings.Split(namesStr, ",") {
918 uc, upstreamName, err := dc.unmarshalEntity(name)
919 if err != nil {
[158]920 return err
[146]921 }
922
923 params := []string{upstreamName}
924 if reason != "" {
925 params = append(params, reason)
926 }
927 uc.SendMessage(&irc.Message{
928 Command: "PART",
929 Params: params,
930 })
931
932 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
933 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
934 }
935 }
[159]936 case "KICK":
937 var channelStr, userStr string
938 if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
939 return err
940 }
941
942 channels := strings.Split(channelStr, ",")
943 users := strings.Split(userStr, ",")
944
945 var reason string
946 if len(msg.Params) > 2 {
947 reason = msg.Params[2]
948 }
949
950 if len(channels) != 1 && len(channels) != len(users) {
951 return ircError{&irc.Message{
952 Command: irc.ERR_BADCHANMASK,
953 Params: []string{dc.nick, channelStr, "Bad channel mask"},
954 }}
955 }
956
957 for i, user := range users {
958 var channel string
959 if len(channels) == 1 {
960 channel = channels[0]
961 } else {
962 channel = channels[i]
963 }
964
965 ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
966 if err != nil {
967 return err
968 }
969
970 ucUser, upstreamUser, err := dc.unmarshalEntity(user)
971 if err != nil {
972 return err
973 }
974
975 if ucChannel != ucUser {
976 return ircError{&irc.Message{
977 Command: irc.ERR_USERNOTINCHANNEL,
978 Params: []string{dc.nick, user, channel, "They aren't on that channel"},
979 }}
980 }
981 uc := ucChannel
982
983 params := []string{upstreamChannel, upstreamUser}
984 if reason != "" {
985 params = append(params, reason)
986 }
987 uc.SendMessage(&irc.Message{
988 Command: "KICK",
989 Params: params,
990 })
991 }
[69]992 case "MODE":
[46]993 var name string
994 if err := parseMessageParams(msg, &name); err != nil {
995 return err
996 }
997
998 var modeStr string
999 if len(msg.Params) > 1 {
1000 modeStr = msg.Params[1]
1001 }
1002
[139]1003 if name == dc.nick {
[46]1004 if modeStr != "" {
[73]1005 dc.forEachUpstream(func(uc *upstreamConn) {
[69]1006 uc.SendMessage(&irc.Message{
1007 Command: "MODE",
1008 Params: []string{uc.nick, modeStr},
1009 })
[46]1010 })
1011 } else {
[55]1012 dc.SendMessage(&irc.Message{
1013 Prefix: dc.srv.prefix(),
[46]1014 Command: irc.RPL_UMODEIS,
[129]1015 Params: []string{dc.nick, ""}, // TODO
[54]1016 })
[46]1017 }
[139]1018 return nil
[46]1019 }
[139]1020
1021 uc, upstreamName, err := dc.unmarshalEntity(name)
1022 if err != nil {
1023 return err
1024 }
1025
1026 if !uc.isChannel(upstreamName) {
1027 return ircError{&irc.Message{
1028 Command: irc.ERR_USERSDONTMATCH,
1029 Params: []string{dc.nick, "Cannot change mode for other users"},
1030 }}
1031 }
1032
1033 if modeStr != "" {
1034 params := []string{upstreamName, modeStr}
1035 params = append(params, msg.Params[2:]...)
1036 uc.SendMessage(&irc.Message{
1037 Command: "MODE",
1038 Params: params,
1039 })
1040 } else {
1041 ch, ok := uc.channels[upstreamName]
1042 if !ok {
1043 return ircError{&irc.Message{
1044 Command: irc.ERR_NOSUCHCHANNEL,
1045 Params: []string{dc.nick, name, "No such channel"},
1046 }}
1047 }
1048
1049 if ch.modes == nil {
1050 // we haven't received the initial RPL_CHANNELMODEIS yet
1051 // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS
1052 return nil
1053 }
1054
1055 modeStr, modeParams := ch.modes.Format()
1056 params := []string{dc.nick, name, modeStr}
1057 params = append(params, modeParams...)
1058
1059 dc.SendMessage(&irc.Message{
1060 Prefix: dc.srv.prefix(),
1061 Command: irc.RPL_CHANNELMODEIS,
1062 Params: params,
1063 })
[162]1064 if ch.creationTime != "" {
1065 dc.SendMessage(&irc.Message{
1066 Prefix: dc.srv.prefix(),
1067 Command: rpl_creationtime,
1068 Params: []string{dc.nick, name, ch.creationTime},
1069 })
1070 }
[139]1071 }
[160]1072 case "TOPIC":
1073 var channel string
1074 if err := parseMessageParams(msg, &channel); err != nil {
1075 return err
1076 }
1077
1078 uc, upstreamChannel, err := dc.unmarshalEntity(channel)
1079 if err != nil {
1080 return err
1081 }
1082
1083 if len(msg.Params) > 1 { // setting topic
1084 topic := msg.Params[1]
1085 uc.SendMessage(&irc.Message{
1086 Command: "TOPIC",
1087 Params: []string{upstreamChannel, topic},
1088 })
1089 } else { // getting topic
1090 ch, ok := uc.channels[upstreamChannel]
1091 if !ok {
1092 return ircError{&irc.Message{
1093 Command: irc.ERR_NOSUCHCHANNEL,
1094 Params: []string{dc.nick, upstreamChannel, "No such channel"},
1095 }}
1096 }
1097 sendTopic(dc, ch)
1098 }
[177]1099 case "LIST":
1100 // TODO: support ELIST when supported by all upstreams
1101
1102 pl := pendingLIST{
1103 downstreamID: dc.id,
1104 pendingCommands: make(map[int64]*irc.Message),
1105 }
1106 var upstreamChannels map[int64][]string
1107 if len(msg.Params) > 0 {
1108 upstreamChannels = make(map[int64][]string)
1109 channels := strings.Split(msg.Params[0], ",")
1110 for _, channel := range channels {
1111 uc, upstreamChannel, err := dc.unmarshalEntity(channel)
1112 if err != nil {
1113 return err
1114 }
1115 upstreamChannels[uc.network.ID] = append(upstreamChannels[uc.network.ID], upstreamChannel)
1116 }
1117 }
1118
1119 dc.user.pendingLISTs = append(dc.user.pendingLISTs, pl)
1120 dc.forEachUpstream(func(uc *upstreamConn) {
1121 var params []string
1122 if upstreamChannels != nil {
1123 if channels, ok := upstreamChannels[uc.network.ID]; ok {
1124 params = []string{strings.Join(channels, ",")}
1125 } else {
1126 return
1127 }
1128 }
1129 pl.pendingCommands[uc.network.ID] = &irc.Message{
1130 Command: "LIST",
1131 Params: params,
1132 }
[181]1133 uc.trySendLIST(dc.id)
[177]1134 })
[140]1135 case "NAMES":
1136 if len(msg.Params) == 0 {
1137 dc.SendMessage(&irc.Message{
1138 Prefix: dc.srv.prefix(),
1139 Command: irc.RPL_ENDOFNAMES,
1140 Params: []string{dc.nick, "*", "End of /NAMES list"},
1141 })
1142 return nil
1143 }
1144
1145 channels := strings.Split(msg.Params[0], ",")
1146 for _, channel := range channels {
1147 uc, upstreamChannel, err := dc.unmarshalEntity(channel)
1148 if err != nil {
1149 return err
1150 }
1151
1152 ch, ok := uc.channels[upstreamChannel]
1153 if ok {
1154 sendNames(dc, ch)
1155 } else {
1156 // NAMES on a channel we have not joined, ask upstream
[176]1157 uc.SendMessageLabeled(dc.id, &irc.Message{
[140]1158 Command: "NAMES",
1159 Params: []string{upstreamChannel},
1160 })
1161 }
1162 }
[127]1163 case "WHO":
1164 if len(msg.Params) == 0 {
1165 // TODO: support WHO without parameters
1166 dc.SendMessage(&irc.Message{
1167 Prefix: dc.srv.prefix(),
1168 Command: irc.RPL_ENDOFWHO,
[140]1169 Params: []string{dc.nick, "*", "End of /WHO list"},
[127]1170 })
1171 return nil
1172 }
1173
1174 // TODO: support WHO masks
1175 entity := msg.Params[0]
1176
[142]1177 if entity == dc.nick {
1178 // TODO: support AWAY (H/G) in self WHO reply
1179 dc.SendMessage(&irc.Message{
1180 Prefix: dc.srv.prefix(),
1181 Command: irc.RPL_WHOREPLY,
[184]1182 Params: []string{dc.nick, "*", dc.user.Username, dc.hostname, dc.srv.Hostname, dc.nick, "H", "0 " + dc.realname},
[142]1183 })
1184 dc.SendMessage(&irc.Message{
1185 Prefix: dc.srv.prefix(),
1186 Command: irc.RPL_ENDOFWHO,
1187 Params: []string{dc.nick, dc.nick, "End of /WHO list"},
1188 })
1189 return nil
1190 }
1191
[127]1192 uc, upstreamName, err := dc.unmarshalEntity(entity)
1193 if err != nil {
1194 return err
1195 }
1196
1197 var params []string
1198 if len(msg.Params) == 2 {
1199 params = []string{upstreamName, msg.Params[1]}
1200 } else {
1201 params = []string{upstreamName}
1202 }
1203
[176]1204 uc.SendMessageLabeled(dc.id, &irc.Message{
[127]1205 Command: "WHO",
1206 Params: params,
1207 })
[128]1208 case "WHOIS":
1209 if len(msg.Params) == 0 {
1210 return ircError{&irc.Message{
1211 Command: irc.ERR_NONICKNAMEGIVEN,
1212 Params: []string{dc.nick, "No nickname given"},
1213 }}
1214 }
1215
1216 var target, mask string
1217 if len(msg.Params) == 1 {
1218 target = ""
1219 mask = msg.Params[0]
1220 } else {
1221 target = msg.Params[0]
1222 mask = msg.Params[1]
1223 }
1224 // TODO: support multiple WHOIS users
1225 if i := strings.IndexByte(mask, ','); i >= 0 {
1226 mask = mask[:i]
1227 }
1228
[142]1229 if mask == dc.nick {
1230 dc.SendMessage(&irc.Message{
1231 Prefix: dc.srv.prefix(),
1232 Command: irc.RPL_WHOISUSER,
[184]1233 Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname},
[142]1234 })
1235 dc.SendMessage(&irc.Message{
1236 Prefix: dc.srv.prefix(),
1237 Command: irc.RPL_WHOISSERVER,
1238 Params: []string{dc.nick, dc.nick, dc.srv.Hostname, "soju"},
1239 })
1240 dc.SendMessage(&irc.Message{
1241 Prefix: dc.srv.prefix(),
1242 Command: irc.RPL_ENDOFWHOIS,
1243 Params: []string{dc.nick, dc.nick, "End of /WHOIS list"},
1244 })
1245 return nil
1246 }
1247
[128]1248 // TODO: support WHOIS masks
1249 uc, upstreamNick, err := dc.unmarshalEntity(mask)
1250 if err != nil {
1251 return err
1252 }
1253
1254 var params []string
1255 if target != "" {
1256 params = []string{target, upstreamNick}
1257 } else {
1258 params = []string{upstreamNick}
1259 }
1260
[176]1261 uc.SendMessageLabeled(dc.id, &irc.Message{
[128]1262 Command: "WHOIS",
1263 Params: params,
1264 })
[58]1265 case "PRIVMSG":
1266 var targetsStr, text string
1267 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
1268 return err
1269 }
1270
1271 for _, name := range strings.Split(targetsStr, ",") {
[117]1272 if name == serviceNick {
1273 handleServicePRIVMSG(dc, text)
1274 continue
1275 }
1276
[127]1277 uc, upstreamName, err := dc.unmarshalEntity(name)
[58]1278 if err != nil {
1279 return err
1280 }
1281
[95]1282 if upstreamName == "NickServ" {
1283 dc.handleNickServPRIVMSG(uc, text)
1284 }
1285
[69]1286 uc.SendMessage(&irc.Message{
[58]1287 Command: "PRIVMSG",
[69]1288 Params: []string{upstreamName, text},
[60]1289 })
[105]1290
[113]1291 echoMsg := &irc.Message{
1292 Prefix: &irc.Prefix{
1293 Name: uc.nick,
1294 User: uc.username,
1295 },
[114]1296 Command: "PRIVMSG",
[113]1297 Params: []string{upstreamName, text},
1298 }
[105]1299 dc.lock.Lock()
[113]1300 dc.ourMessages[echoMsg] = struct{}{}
[105]1301 dc.lock.Unlock()
1302
[143]1303 uc.network.ring.Produce(echoMsg)
[58]1304 }
[164]1305 case "NOTICE":
1306 var targetsStr, text string
1307 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
1308 return err
1309 }
1310
1311 for _, name := range strings.Split(targetsStr, ",") {
1312 uc, upstreamName, err := dc.unmarshalEntity(name)
1313 if err != nil {
1314 return err
1315 }
1316
1317 uc.SendMessage(&irc.Message{
1318 Command: "NOTICE",
1319 Params: []string{upstreamName, text},
1320 })
1321 }
[163]1322 case "INVITE":
1323 var user, channel string
1324 if err := parseMessageParams(msg, &user, &channel); err != nil {
1325 return err
1326 }
1327
1328 ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
1329 if err != nil {
1330 return err
1331 }
1332
1333 ucUser, upstreamUser, err := dc.unmarshalEntity(user)
1334 if err != nil {
1335 return err
1336 }
1337
1338 if ucChannel != ucUser {
1339 return ircError{&irc.Message{
1340 Command: irc.ERR_USERNOTINCHANNEL,
1341 Params: []string{dc.nick, user, channel, "They aren't on that channel"},
1342 }}
1343 }
1344 uc := ucChannel
1345
[176]1346 uc.SendMessageLabeled(dc.id, &irc.Message{
[163]1347 Command: "INVITE",
1348 Params: []string{upstreamUser, upstreamChannel},
1349 })
[13]1350 default:
[55]1351 dc.logger.Printf("unhandled message: %v", msg)
[13]1352 return newUnknownCommandError(msg.Command)
1353 }
[42]1354 return nil
[13]1355}
[95]1356
1357func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
1358 username, password, ok := parseNickServCredentials(text, uc.nick)
1359 if !ok {
1360 return
1361 }
1362
1363 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
1364 n := uc.network
1365 n.SASL.Mechanism = "PLAIN"
1366 n.SASL.Plain.Username = username
1367 n.SASL.Plain.Password = password
1368 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
1369 dc.logger.Printf("failed to save NickServ credentials: %v", err)
1370 }
1371}
1372
// parseNickServCredentials extracts account credentials from the text of a
// message a client sends to NickServ. nick is the client's current nickname,
// used as the account name when the command does not name one explicitly.
//
// Recognized commands (command word and the PASSWORD keyword are matched
// case-insensitively):
//
//	REGISTER <password> [...]           -> (nick, password)
//	IDENTIFY <password>                 -> (nick, password)
//	IDENTIFY <account> <password> [...] -> (account, password)
//	SET PASSWORD <password>             -> (nick, password)
//
// For any other input ok is false. Previously, unrelated commands such as
// "HELP REGISTER" or "SET EMAIL x" reported ok with empty credentials. That
// caused callers to overwrite previously saved credentials with empty strings.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:] // len(params) >= 1 is guaranteed by the check above

	switch cmd {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			return nick, params[0], true
		}
		return params[0], params[1], true
	case "SET":
		if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") {
			return nick, params[1], true
		}
	}
	// Unrecognized command, or SET of something other than the password.
	return "", "", false
}
Note: See TracBrowser for help on using the repository browser.