source: code/trunk/downstream.go@ 107

Last change on this file since 107 was 107, checked in by contact, 5 years ago

Don't reply to PING when client is unregistered

File size: 15.2 KB
RevLine 
[98]1package soju
[13]2
3import (
[91]4 "crypto/tls"
[13]5 "fmt"
6 "io"
7 "net"
[39]8 "strings"
[105]9 "sync"
[91]10 "time"
[13]11
[85]12 "golang.org/x/crypto/bcrypt"
[13]13 "gopkg.in/irc.v3"
14)
15
16type ircError struct {
17 Message *irc.Message
18}
19
[85]20func (err ircError) Error() string {
21 return err.Message.String()
22}
23
[13]24func newUnknownCommandError(cmd string) ircError {
25 return ircError{&irc.Message{
26 Command: irc.ERR_UNKNOWNCOMMAND,
27 Params: []string{
28 "*",
29 cmd,
30 "Unknown command",
31 },
32 }}
33}
34
35func newNeedMoreParamsError(cmd string) ircError {
36 return ircError{&irc.Message{
37 Command: irc.ERR_NEEDMOREPARAMS,
38 Params: []string{
39 "*",
40 cmd,
41 "Not enough parameters",
42 },
43 }}
44}
45
[85]46var errAuthFailed = ircError{&irc.Message{
47 Command: irc.ERR_PASSWDMISMATCH,
48 Params: []string{"*", "Invalid username or password"},
49}}
[13]50
[104]51type ringMessage struct {
[69]52 consumer *RingConsumer
53 upstreamConn *upstreamConn
54}
55
[13]56type downstreamConn struct {
[69]57 net net.Conn
58 irc *irc.Conn
59 srv *Server
60 logger Logger
[102]61 outgoing chan *irc.Message
[104]62 ringMessages chan ringMessage
[69]63 closed chan struct{}
[22]64
[100]65 registered bool
66 user *user
67 nick string
68 username string
69 rawUsername string
70 realname string
71 password string // empty after authentication
72 network *network // can be nil
[105]73
74 lock sync.Mutex
75 ourMessages map[*irc.Message]struct{}
[13]76}
77
[22]78func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
[55]79 dc := &downstreamConn{
[69]80 net: netConn,
81 irc: irc.NewConn(netConn),
82 srv: srv,
83 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
[102]84 outgoing: make(chan *irc.Message, 64),
[104]85 ringMessages: make(chan ringMessage),
[69]86 closed: make(chan struct{}),
[105]87 ourMessages: make(map[*irc.Message]struct{}),
[22]88 }
[26]89
90 go func() {
[56]91 if err := dc.writeMessages(); err != nil {
92 dc.logger.Printf("failed to write message: %v", err)
[26]93 }
[55]94 if err := dc.net.Close(); err != nil {
95 dc.logger.Printf("failed to close connection: %v", err)
[45]96 } else {
[55]97 dc.logger.Printf("connection closed")
[45]98 }
[26]99 }()
100
[55]101 return dc
[22]102}
103
[55]104func (dc *downstreamConn) prefix() *irc.Prefix {
[27]105 return &irc.Prefix{
[55]106 Name: dc.nick,
107 User: dc.username,
[27]108 // TODO: fill the host?
109 }
110}
111
[69]112func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
113 return name
114}
115
[90]116func (dc *downstreamConn) forEachNetwork(f func(*network)) {
117 if dc.network != nil {
118 f(dc.network)
119 } else {
120 dc.user.forEachNetwork(f)
121 }
122}
123
[73]124func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
125 dc.user.forEachUpstream(func(uc *upstreamConn) {
[77]126 if dc.network != nil && uc.network != dc.network {
[73]127 return
128 }
129 f(uc)
130 })
131}
132
[89]133// upstream returns the upstream connection, if any. If there are zero or if
134// there are multiple upstream connections, it returns nil.
135func (dc *downstreamConn) upstream() *upstreamConn {
136 if dc.network == nil {
137 return nil
138 }
139
140 var upstream *upstreamConn
141 dc.forEachUpstream(func(uc *upstreamConn) {
142 upstream = uc
143 })
144 return upstream
145}
146
[69]147func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
[89]148 if uc := dc.upstream(); uc != nil {
149 return uc, name, nil
150 }
151
[73]152 // TODO: extract network name from channel name if dc.upstream == nil
153 var channel *upstreamChannel
154 var err error
155 dc.forEachUpstream(func(uc *upstreamConn) {
156 if err != nil {
157 return
158 }
159 if ch, ok := uc.channels[name]; ok {
160 if channel != nil {
161 err = fmt.Errorf("ambiguous channel name %q", name)
162 } else {
163 channel = ch
164 }
165 }
166 })
167 if channel == nil {
168 return nil, "", ircError{&irc.Message{
169 Command: irc.ERR_NOSUCHCHANNEL,
170 Params: []string{name, "No such channel"},
171 }}
[69]172 }
[73]173 return channel.conn, channel.Name, nil
[69]174}
175
176func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
177 if nick == uc.nick {
178 return dc.nick
179 }
180 return nick
181}
182
183func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
184 if prefix.Name == uc.nick {
185 return dc.prefix()
186 }
187 return prefix
188}
189
[57]190func (dc *downstreamConn) isClosed() bool {
191 select {
192 case <-dc.closed:
193 return true
194 default:
195 return false
196 }
197}
198
[103]199func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
[55]200 dc.logger.Printf("new connection")
[22]201
202 for {
[55]203 msg, err := dc.irc.ReadMessage()
[22]204 if err == io.EOF {
205 break
206 } else if err != nil {
207 return fmt.Errorf("failed to read IRC command: %v", err)
208 }
209
[64]210 if dc.srv.Debug {
211 dc.logger.Printf("received: %v", msg)
212 }
213
[103]214 ch <- downstreamIncomingMessage{msg, dc}
[22]215 }
216
[45]217 return nil
[22]218}
219
[56]220func (dc *downstreamConn) writeMessages() error {
[57]221 for {
222 var err error
223 var closed bool
224 select {
[102]225 case msg := <-dc.outgoing:
[64]226 if dc.srv.Debug {
227 dc.logger.Printf("sent: %v", msg)
228 }
[57]229 err = dc.irc.WriteMessage(msg)
[104]230 case ringMessage := <-dc.ringMessages:
231 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
[57]232 for {
233 msg := consumer.Peek()
234 if msg == nil {
235 break
236 }
[105]237
238 dc.lock.Lock()
239 _, ours := dc.ourMessages[msg]
240 delete(dc.ourMessages, msg)
241 dc.lock.Unlock()
242 if ours {
243 // The message comes from our connection, don't echo it
244 // back
245 continue
246 }
247
[69]248 msg = msg.Copy()
249 switch msg.Command {
250 case "PRIVMSG":
251 // TODO: detect whether it's a user or a channel
252 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
253 default:
254 panic("expected to consume a PRIVMSG message")
255 }
[64]256 if dc.srv.Debug {
257 dc.logger.Printf("sent: %v", msg)
258 }
[57]259 err = dc.irc.WriteMessage(msg)
260 if err != nil {
261 break
262 }
263 consumer.Consume()
264 }
265 case <-dc.closed:
266 closed = true
267 }
268 if err != nil {
[56]269 return err
270 }
[57]271 if closed {
272 break
273 }
[56]274 }
275 return nil
276}
277
[55]278func (dc *downstreamConn) Close() error {
[57]279 if dc.isClosed() {
[26]280 return fmt.Errorf("downstream connection already closed")
281 }
[40]282
[55]283 if u := dc.user; u != nil {
[40]284 u.lock.Lock()
285 for i := range u.downstreamConns {
[55]286 if u.downstreamConns[i] == dc {
[40]287 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
[63]288 break
[40]289 }
290 }
291 u.lock.Unlock()
[13]292 }
[40]293
[57]294 close(dc.closed)
[45]295 return nil
[13]296}
297
[55]298func (dc *downstreamConn) SendMessage(msg *irc.Message) {
[102]299 dc.outgoing <- msg
[54]300}
301
[55]302func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
[13]303 switch msg.Command {
[28]304 case "QUIT":
[55]305 return dc.Close()
[13]306 default:
[55]307 if dc.registered {
308 return dc.handleMessageRegistered(msg)
[13]309 } else {
[55]310 return dc.handleMessageUnregistered(msg)
[13]311 }
312 }
313}
314
[55]315func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
[13]316 switch msg.Command {
317 case "NICK":
[55]318 if err := parseMessageParams(msg, &dc.nick); err != nil {
[43]319 return err
[13]320 }
321 case "USER":
[43]322 var username string
[55]323 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
[43]324 return err
[13]325 }
[100]326 dc.rawUsername = username
[85]327 case "PASS":
328 if err := parseMessageParams(msg, &dc.password); err != nil {
329 return err
330 }
[13]331 default:
[55]332 dc.logger.Printf("unhandled message: %v", msg)
[13]333 return newUnknownCommandError(msg.Command)
334 }
[100]335 if dc.rawUsername != "" && dc.nick != "" {
[55]336 return dc.register()
[13]337 }
338 return nil
339}
340
// sanityCheckServer checks that a TLS connection to addr ("host:port") can be
// established within 30 seconds. The connection is closed right away; any
// dial, handshake or close error is returned.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	c, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err != nil {
		return err
	}
	return c.Close()
}
349
[55]350func (dc *downstreamConn) register() error {
[100]351 username := dc.rawUsername
[77]352 var networkName string
[73]353 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
[77]354 networkName = username[i+1:]
[73]355 }
356 if i := strings.IndexAny(username, "/@"); i >= 0 {
357 username = username[:i]
358 }
[100]359 dc.username = "~" + username
[73]360
[85]361 password := dc.password
362 dc.password = ""
363
[73]364 u := dc.srv.getUser(username)
[38]365 if u == nil {
[85]366 dc.logger.Printf("failed authentication for %q: unknown username", username)
367 return errAuthFailed
[37]368 }
369
[85]370 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
371 if err != nil {
372 dc.logger.Printf("failed authentication for %q: %v", username, err)
373 return errAuthFailed
374 }
375
[88]376 var network *network
[77]377 if networkName != "" {
[88]378 network = u.getNetwork(networkName)
379 if network == nil {
[91]380 addr := networkName
381 if !strings.ContainsRune(addr, ':') {
382 addr = addr + ":6697"
383 }
384
[95]385 dc.logger.Printf("trying to connect to new network %q", addr)
[91]386 if err := sanityCheckServer(addr); err != nil {
387 dc.logger.Printf("failed to connect to %q: %v", addr, err)
388 return ircError{&irc.Message{
389 Command: irc.ERR_PASSWDMISMATCH,
390 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
391 }}
392 }
393
[95]394 dc.logger.Printf("auto-saving network %q", networkName)
[91]395 network, err = u.createNetwork(networkName, dc.nick)
396 if err != nil {
397 return err
398 }
[73]399 }
400 }
401
[55]402 dc.registered = true
403 dc.user = u
[88]404 dc.network = network
[13]405
[40]406 u.lock.Lock()
[57]407 firstDownstream := len(u.downstreamConns) == 0
[55]408 u.downstreamConns = append(u.downstreamConns, dc)
[40]409 u.lock.Unlock()
410
[55]411 dc.SendMessage(&irc.Message{
412 Prefix: dc.srv.prefix(),
[13]413 Command: irc.RPL_WELCOME,
[98]414 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
[54]415 })
[55]416 dc.SendMessage(&irc.Message{
417 Prefix: dc.srv.prefix(),
[13]418 Command: irc.RPL_YOURHOST,
[55]419 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
[54]420 })
[55]421 dc.SendMessage(&irc.Message{
422 Prefix: dc.srv.prefix(),
[13]423 Command: irc.RPL_CREATED,
[55]424 Params: []string{dc.nick, "Who cares when the server was created?"},
[54]425 })
[55]426 dc.SendMessage(&irc.Message{
427 Prefix: dc.srv.prefix(),
[13]428 Command: irc.RPL_MYINFO,
[98]429 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
[54]430 })
[93]431 // TODO: RPL_ISUPPORT
[55]432 dc.SendMessage(&irc.Message{
433 Prefix: dc.srv.prefix(),
[13]434 Command: irc.ERR_NOMOTD,
[55]435 Params: []string{dc.nick, "No MOTD"},
[54]436 })
[13]437
[73]438 dc.forEachUpstream(func(uc *upstreamConn) {
[30]439 for _, ch := range uc.channels {
440 if ch.complete {
[55]441 forwardChannel(dc, ch)
[30]442 }
443 }
[50]444
[73]445 historyName := dc.username
[57]446
447 var seqPtr *uint64
448 if firstDownstream {
449 seq, ok := uc.history[historyName]
450 if ok {
451 seqPtr = &seq
[50]452 }
453 }
[57]454
[59]455 consumer, ch := uc.ring.NewConsumer(seqPtr)
[57]456 go func() {
457 for {
458 var closed bool
459 select {
460 case <-ch:
[104]461 dc.ringMessages <- ringMessage{consumer, uc}
[57]462 case <-dc.closed:
463 closed = true
464 }
465 if closed {
466 break
467 }
468 }
469
470 seq := consumer.Close()
471
472 dc.user.lock.Lock()
473 lastDownstream := len(dc.user.downstreamConns) == 0
474 dc.user.lock.Unlock()
475
476 if lastDownstream {
477 uc.history[historyName] = seq
478 }
479 }()
[39]480 })
[50]481
[13]482 return nil
483}
484
[103]485func (dc *downstreamConn) runUntilRegistered() error {
486 for !dc.registered {
487 msg, err := dc.irc.ReadMessage()
[106]488 if err != nil {
[103]489 return fmt.Errorf("failed to read IRC command: %v", err)
490 }
491
492 err = dc.handleMessage(msg)
493 if ircErr, ok := err.(ircError); ok {
494 ircErr.Message.Prefix = dc.srv.prefix()
495 dc.SendMessage(ircErr.Message)
496 } else if err != nil {
497 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
498 }
499 }
500
501 return nil
502}
503
[55]504func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
[13]505 switch msg.Command {
[107]506 case "PING":
507 dc.SendMessage(&irc.Message{
508 Prefix: dc.srv.prefix(),
509 Command: "PONG",
510 Params: msg.Params,
511 })
512 return nil
[42]513 case "USER":
[13]514 return ircError{&irc.Message{
515 Command: irc.ERR_ALREADYREGISTERED,
[55]516 Params: []string{dc.nick, "You may not reregister"},
[13]517 }}
[42]518 case "NICK":
[90]519 var nick string
520 if err := parseMessageParams(msg, &nick); err != nil {
521 return err
522 }
523
524 var err error
525 dc.forEachNetwork(func(n *network) {
526 if err != nil {
527 return
528 }
529 n.Nick = nick
530 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
531 })
532 if err != nil {
533 return err
534 }
535
[73]536 dc.forEachUpstream(func(uc *upstreamConn) {
[60]537 uc.SendMessage(msg)
[42]538 })
[69]539 case "JOIN", "PART":
[48]540 var name string
541 if err := parseMessageParams(msg, &name); err != nil {
542 return err
543 }
544
[69]545 uc, upstreamName, err := dc.unmarshalChannel(name)
546 if err != nil {
547 return ircError{&irc.Message{
548 Command: irc.ERR_NOSUCHCHANNEL,
549 Params: []string{name, err.Error()},
550 }}
[48]551 }
552
[69]553 uc.SendMessage(&irc.Message{
554 Command: msg.Command,
555 Params: []string{upstreamName},
556 })
[89]557
558 switch msg.Command {
559 case "JOIN":
560 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
561 Name: upstreamName,
562 })
563 if err != nil {
564 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
565 }
566 case "PART":
567 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
568 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
569 }
570 }
[69]571 case "MODE":
572 if msg.Prefix == nil {
573 return fmt.Errorf("missing prefix")
[49]574 }
575
[46]576 var name string
577 if err := parseMessageParams(msg, &name); err != nil {
578 return err
579 }
580
581 var modeStr string
582 if len(msg.Params) > 1 {
583 modeStr = msg.Params[1]
584 }
585
586 if msg.Prefix.Name != name {
[69]587 uc, upstreamName, err := dc.unmarshalChannel(name)
[46]588 if err != nil {
589 return err
590 }
591
592 if modeStr != "" {
[69]593 uc.SendMessage(&irc.Message{
594 Command: "MODE",
595 Params: []string{upstreamName, modeStr},
596 })
[46]597 } else {
[69]598 ch, ok := uc.channels[upstreamName]
599 if !ok {
600 return ircError{&irc.Message{
601 Command: irc.ERR_NOSUCHCHANNEL,
602 Params: []string{name, "No such channel"},
603 }}
604 }
605
[55]606 dc.SendMessage(&irc.Message{
607 Prefix: dc.srv.prefix(),
[46]608 Command: irc.RPL_CHANNELMODEIS,
[69]609 Params: []string{name, string(ch.modes)},
[54]610 })
[46]611 }
612 } else {
[55]613 if name != dc.nick {
[46]614 return ircError{&irc.Message{
615 Command: irc.ERR_USERSDONTMATCH,
[55]616 Params: []string{dc.nick, "Cannot change mode for other users"},
[46]617 }}
618 }
619
620 if modeStr != "" {
[73]621 dc.forEachUpstream(func(uc *upstreamConn) {
[69]622 uc.SendMessage(&irc.Message{
623 Command: "MODE",
624 Params: []string{uc.nick, modeStr},
625 })
[46]626 })
627 } else {
[55]628 dc.SendMessage(&irc.Message{
629 Prefix: dc.srv.prefix(),
[46]630 Command: irc.RPL_UMODEIS,
631 Params: []string{""}, // TODO
[54]632 })
[46]633 }
634 }
[58]635 case "PRIVMSG":
636 var targetsStr, text string
637 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
638 return err
639 }
640
641 for _, name := range strings.Split(targetsStr, ",") {
[69]642 uc, upstreamName, err := dc.unmarshalChannel(name)
[58]643 if err != nil {
644 return err
645 }
646
[95]647 if upstreamName == "NickServ" {
648 dc.handleNickServPRIVMSG(uc, text)
649 }
650
[69]651 uc.SendMessage(&irc.Message{
[58]652 Command: "PRIVMSG",
[69]653 Params: []string{upstreamName, text},
[60]654 })
[105]655
656 dc.lock.Lock()
657 dc.ourMessages[msg] = struct{}{}
658 dc.lock.Unlock()
659
660 uc.ring.Produce(msg)
[58]661 }
[13]662 default:
[55]663 dc.logger.Printf("unhandled message: %v", msg)
[13]664 return newUnknownCommandError(msg.Command)
665 }
[42]666 return nil
[13]667}
[95]668
669func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
670 username, password, ok := parseNickServCredentials(text, uc.nick)
671 if !ok {
672 return
673 }
674
675 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
676 n := uc.network
677 n.SASL.Mechanism = "PLAIN"
678 n.SASL.Plain.Username = username
679 n.SASL.Plain.Password = password
680 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
681 dc.logger.Printf("failed to save NickServ credentials: %v", err)
682 }
683}
684
// parseNickServCredentials extracts account credentials from the text of a
// message sent to NickServ. The command word is case-insensitive.
//
// Recognized forms:
//   - "REGISTER <password> ..." — the account name is nick.
//   - "IDENTIFY <password>" — the account name is nick.
//   - "IDENTIFY <account> <password> ..."
//
// ok is false for any other command, or when the password is missing.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			// Previously indexed params[1] here, panicking on the common
			// "IDENTIFY <password>" form.
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		// Not a credential-bearing command: previously this reported ok with
		// empty credentials, which were then saved as SASL settings.
		return "", "", false
	}
	return username, password, true
}
Note: See TracBrowser for help on using the repository browser.