source: code/trunk/downstream.go@ 106

Last change on this file since 106 was 106, checked in by contact, 5 years ago

Make downstreamConn.runUntilRegistered exit with an error on EOF

File size: 15.2 KB
RevLine 
[98]1package soju
[13]2
3import (
[91]4 "crypto/tls"
[13]5 "fmt"
6 "io"
7 "net"
[39]8 "strings"
[105]9 "sync"
[91]10 "time"
[13]11
[85]12 "golang.org/x/crypto/bcrypt"
[13]13 "gopkg.in/irc.v3"
14)
15
16type ircError struct {
17 Message *irc.Message
18}
19
[85]20func (err ircError) Error() string {
21 return err.Message.String()
22}
23
[13]24func newUnknownCommandError(cmd string) ircError {
25 return ircError{&irc.Message{
26 Command: irc.ERR_UNKNOWNCOMMAND,
27 Params: []string{
28 "*",
29 cmd,
30 "Unknown command",
31 },
32 }}
33}
34
35func newNeedMoreParamsError(cmd string) ircError {
36 return ircError{&irc.Message{
37 Command: irc.ERR_NEEDMOREPARAMS,
38 Params: []string{
39 "*",
40 cmd,
41 "Not enough parameters",
42 },
43 }}
44}
45
[85]46var errAuthFailed = ircError{&irc.Message{
47 Command: irc.ERR_PASSWDMISMATCH,
48 Params: []string{"*", "Invalid username or password"},
49}}
[13]50
[104]51type ringMessage struct {
[69]52 consumer *RingConsumer
53 upstreamConn *upstreamConn
54}
55
[13]56type downstreamConn struct {
[69]57 net net.Conn
58 irc *irc.Conn
59 srv *Server
60 logger Logger
[102]61 outgoing chan *irc.Message
[104]62 ringMessages chan ringMessage
[69]63 closed chan struct{}
[22]64
[100]65 registered bool
66 user *user
67 nick string
68 username string
69 rawUsername string
70 realname string
71 password string // empty after authentication
72 network *network // can be nil
[105]73
74 lock sync.Mutex
75 ourMessages map[*irc.Message]struct{}
[13]76}
77
[22]78func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
[55]79 dc := &downstreamConn{
[69]80 net: netConn,
81 irc: irc.NewConn(netConn),
82 srv: srv,
83 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
[102]84 outgoing: make(chan *irc.Message, 64),
[104]85 ringMessages: make(chan ringMessage),
[69]86 closed: make(chan struct{}),
[105]87 ourMessages: make(map[*irc.Message]struct{}),
[22]88 }
[26]89
90 go func() {
[56]91 if err := dc.writeMessages(); err != nil {
92 dc.logger.Printf("failed to write message: %v", err)
[26]93 }
[55]94 if err := dc.net.Close(); err != nil {
95 dc.logger.Printf("failed to close connection: %v", err)
[45]96 } else {
[55]97 dc.logger.Printf("connection closed")
[45]98 }
[26]99 }()
100
[55]101 return dc
[22]102}
103
[55]104func (dc *downstreamConn) prefix() *irc.Prefix {
[27]105 return &irc.Prefix{
[55]106 Name: dc.nick,
107 User: dc.username,
[27]108 // TODO: fill the host?
109 }
110}
111
[69]112func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
113 return name
114}
115
[90]116func (dc *downstreamConn) forEachNetwork(f func(*network)) {
117 if dc.network != nil {
118 f(dc.network)
119 } else {
120 dc.user.forEachNetwork(f)
121 }
122}
123
[73]124func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
125 dc.user.forEachUpstream(func(uc *upstreamConn) {
[77]126 if dc.network != nil && uc.network != dc.network {
[73]127 return
128 }
129 f(uc)
130 })
131}
132
[89]133// upstream returns the upstream connection, if any. If there are zero or if
134// there are multiple upstream connections, it returns nil.
135func (dc *downstreamConn) upstream() *upstreamConn {
136 if dc.network == nil {
137 return nil
138 }
139
140 var upstream *upstreamConn
141 dc.forEachUpstream(func(uc *upstreamConn) {
142 upstream = uc
143 })
144 return upstream
145}
146
[69]147func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
[89]148 if uc := dc.upstream(); uc != nil {
149 return uc, name, nil
150 }
151
[73]152 // TODO: extract network name from channel name if dc.upstream == nil
153 var channel *upstreamChannel
154 var err error
155 dc.forEachUpstream(func(uc *upstreamConn) {
156 if err != nil {
157 return
158 }
159 if ch, ok := uc.channels[name]; ok {
160 if channel != nil {
161 err = fmt.Errorf("ambiguous channel name %q", name)
162 } else {
163 channel = ch
164 }
165 }
166 })
167 if channel == nil {
168 return nil, "", ircError{&irc.Message{
169 Command: irc.ERR_NOSUCHCHANNEL,
170 Params: []string{name, "No such channel"},
171 }}
[69]172 }
[73]173 return channel.conn, channel.Name, nil
[69]174}
175
176func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
177 if nick == uc.nick {
178 return dc.nick
179 }
180 return nick
181}
182
183func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
184 if prefix.Name == uc.nick {
185 return dc.prefix()
186 }
187 return prefix
188}
189
[57]190func (dc *downstreamConn) isClosed() bool {
191 select {
192 case <-dc.closed:
193 return true
194 default:
195 return false
196 }
197}
198
[103]199func (dc *downstreamConn) readMessages(ch chan<- downstreamIncomingMessage) error {
[55]200 dc.logger.Printf("new connection")
[22]201
202 for {
[55]203 msg, err := dc.irc.ReadMessage()
[22]204 if err == io.EOF {
205 break
206 } else if err != nil {
207 return fmt.Errorf("failed to read IRC command: %v", err)
208 }
209
[64]210 if dc.srv.Debug {
211 dc.logger.Printf("received: %v", msg)
212 }
213
[103]214 ch <- downstreamIncomingMessage{msg, dc}
[22]215 }
216
[45]217 return nil
[22]218}
219
[56]220func (dc *downstreamConn) writeMessages() error {
[57]221 for {
222 var err error
223 var closed bool
224 select {
[102]225 case msg := <-dc.outgoing:
[64]226 if dc.srv.Debug {
227 dc.logger.Printf("sent: %v", msg)
228 }
[57]229 err = dc.irc.WriteMessage(msg)
[104]230 case ringMessage := <-dc.ringMessages:
231 consumer, uc := ringMessage.consumer, ringMessage.upstreamConn
[57]232 for {
233 msg := consumer.Peek()
234 if msg == nil {
235 break
236 }
[105]237
238 dc.lock.Lock()
239 _, ours := dc.ourMessages[msg]
240 delete(dc.ourMessages, msg)
241 dc.lock.Unlock()
242 if ours {
243 // The message comes from our connection, don't echo it
244 // back
245 continue
246 }
247
[69]248 msg = msg.Copy()
249 switch msg.Command {
250 case "PRIVMSG":
251 // TODO: detect whether it's a user or a channel
252 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
253 default:
254 panic("expected to consume a PRIVMSG message")
255 }
[64]256 if dc.srv.Debug {
257 dc.logger.Printf("sent: %v", msg)
258 }
[57]259 err = dc.irc.WriteMessage(msg)
260 if err != nil {
261 break
262 }
263 consumer.Consume()
264 }
265 case <-dc.closed:
266 closed = true
267 }
268 if err != nil {
[56]269 return err
270 }
[57]271 if closed {
272 break
273 }
[56]274 }
275 return nil
276}
277
[55]278func (dc *downstreamConn) Close() error {
[57]279 if dc.isClosed() {
[26]280 return fmt.Errorf("downstream connection already closed")
281 }
[40]282
[55]283 if u := dc.user; u != nil {
[40]284 u.lock.Lock()
285 for i := range u.downstreamConns {
[55]286 if u.downstreamConns[i] == dc {
[40]287 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
[63]288 break
[40]289 }
290 }
291 u.lock.Unlock()
[13]292 }
[40]293
[57]294 close(dc.closed)
[45]295 return nil
[13]296}
297
[55]298func (dc *downstreamConn) SendMessage(msg *irc.Message) {
[102]299 dc.outgoing <- msg
[54]300}
301
[55]302func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
[13]303 switch msg.Command {
[28]304 case "QUIT":
[55]305 return dc.Close()
[13]306 case "PING":
[55]307 dc.SendMessage(&irc.Message{
308 Prefix: dc.srv.prefix(),
[13]309 Command: "PONG",
[68]310 Params: msg.Params,
[54]311 })
[26]312 return nil
[13]313 default:
[55]314 if dc.registered {
315 return dc.handleMessageRegistered(msg)
[13]316 } else {
[55]317 return dc.handleMessageUnregistered(msg)
[13]318 }
319 }
320}
321
[55]322func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
[13]323 switch msg.Command {
324 case "NICK":
[55]325 if err := parseMessageParams(msg, &dc.nick); err != nil {
[43]326 return err
[13]327 }
328 case "USER":
[43]329 var username string
[55]330 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
[43]331 return err
[13]332 }
[100]333 dc.rawUsername = username
[85]334 case "PASS":
335 if err := parseMessageParams(msg, &dc.password); err != nil {
336 return err
337 }
[13]338 default:
[55]339 dc.logger.Printf("unhandled message: %v", msg)
[13]340 return newUnknownCommandError(msg.Command)
341 }
[100]342 if dc.rawUsername != "" && dc.nick != "" {
[55]343 return dc.register()
[13]344 }
345 return nil
346}
347
// sanityCheckServer checks that a TLS connection can be established to addr
// over TCP, with a 30 second dial timeout. The connection is closed right
// away; the dial error or the close error is returned.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	conn, dialErr := tls.DialWithDialer(d, "tcp", addr, nil)
	if dialErr != nil {
		return dialErr
	}
	closeErr := conn.Close()
	return closeErr
}
356
[55]357func (dc *downstreamConn) register() error {
[100]358 username := dc.rawUsername
[77]359 var networkName string
[73]360 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
[77]361 networkName = username[i+1:]
[73]362 }
363 if i := strings.IndexAny(username, "/@"); i >= 0 {
364 username = username[:i]
365 }
[100]366 dc.username = "~" + username
[73]367
[85]368 password := dc.password
369 dc.password = ""
370
[73]371 u := dc.srv.getUser(username)
[38]372 if u == nil {
[85]373 dc.logger.Printf("failed authentication for %q: unknown username", username)
374 return errAuthFailed
[37]375 }
376
[85]377 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
378 if err != nil {
379 dc.logger.Printf("failed authentication for %q: %v", username, err)
380 return errAuthFailed
381 }
382
[88]383 var network *network
[77]384 if networkName != "" {
[88]385 network = u.getNetwork(networkName)
386 if network == nil {
[91]387 addr := networkName
388 if !strings.ContainsRune(addr, ':') {
389 addr = addr + ":6697"
390 }
391
[95]392 dc.logger.Printf("trying to connect to new network %q", addr)
[91]393 if err := sanityCheckServer(addr); err != nil {
394 dc.logger.Printf("failed to connect to %q: %v", addr, err)
395 return ircError{&irc.Message{
396 Command: irc.ERR_PASSWDMISMATCH,
397 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
398 }}
399 }
400
[95]401 dc.logger.Printf("auto-saving network %q", networkName)
[91]402 network, err = u.createNetwork(networkName, dc.nick)
403 if err != nil {
404 return err
405 }
[73]406 }
407 }
408
[55]409 dc.registered = true
410 dc.user = u
[88]411 dc.network = network
[13]412
[40]413 u.lock.Lock()
[57]414 firstDownstream := len(u.downstreamConns) == 0
[55]415 u.downstreamConns = append(u.downstreamConns, dc)
[40]416 u.lock.Unlock()
417
[55]418 dc.SendMessage(&irc.Message{
419 Prefix: dc.srv.prefix(),
[13]420 Command: irc.RPL_WELCOME,
[98]421 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
[54]422 })
[55]423 dc.SendMessage(&irc.Message{
424 Prefix: dc.srv.prefix(),
[13]425 Command: irc.RPL_YOURHOST,
[55]426 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
[54]427 })
[55]428 dc.SendMessage(&irc.Message{
429 Prefix: dc.srv.prefix(),
[13]430 Command: irc.RPL_CREATED,
[55]431 Params: []string{dc.nick, "Who cares when the server was created?"},
[54]432 })
[55]433 dc.SendMessage(&irc.Message{
434 Prefix: dc.srv.prefix(),
[13]435 Command: irc.RPL_MYINFO,
[98]436 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
[54]437 })
[93]438 // TODO: RPL_ISUPPORT
[55]439 dc.SendMessage(&irc.Message{
440 Prefix: dc.srv.prefix(),
[13]441 Command: irc.ERR_NOMOTD,
[55]442 Params: []string{dc.nick, "No MOTD"},
[54]443 })
[13]444
[73]445 dc.forEachUpstream(func(uc *upstreamConn) {
[30]446 for _, ch := range uc.channels {
447 if ch.complete {
[55]448 forwardChannel(dc, ch)
[30]449 }
450 }
[50]451
[73]452 historyName := dc.username
[57]453
454 var seqPtr *uint64
455 if firstDownstream {
456 seq, ok := uc.history[historyName]
457 if ok {
458 seqPtr = &seq
[50]459 }
460 }
[57]461
[59]462 consumer, ch := uc.ring.NewConsumer(seqPtr)
[57]463 go func() {
464 for {
465 var closed bool
466 select {
467 case <-ch:
[104]468 dc.ringMessages <- ringMessage{consumer, uc}
[57]469 case <-dc.closed:
470 closed = true
471 }
472 if closed {
473 break
474 }
475 }
476
477 seq := consumer.Close()
478
479 dc.user.lock.Lock()
480 lastDownstream := len(dc.user.downstreamConns) == 0
481 dc.user.lock.Unlock()
482
483 if lastDownstream {
484 uc.history[historyName] = seq
485 }
486 }()
[39]487 })
[50]488
[13]489 return nil
490}
491
[103]492func (dc *downstreamConn) runUntilRegistered() error {
493 for !dc.registered {
494 msg, err := dc.irc.ReadMessage()
[106]495 if err != nil {
[103]496 return fmt.Errorf("failed to read IRC command: %v", err)
497 }
498
499 err = dc.handleMessage(msg)
500 if ircErr, ok := err.(ircError); ok {
501 ircErr.Message.Prefix = dc.srv.prefix()
502 dc.SendMessage(ircErr.Message)
503 } else if err != nil {
504 return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
505 }
506 }
507
508 return nil
509}
510
[55]511func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
[13]512 switch msg.Command {
[42]513 case "USER":
[13]514 return ircError{&irc.Message{
515 Command: irc.ERR_ALREADYREGISTERED,
[55]516 Params: []string{dc.nick, "You may not reregister"},
[13]517 }}
[42]518 case "NICK":
[90]519 var nick string
520 if err := parseMessageParams(msg, &nick); err != nil {
521 return err
522 }
523
524 var err error
525 dc.forEachNetwork(func(n *network) {
526 if err != nil {
527 return
528 }
529 n.Nick = nick
530 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
531 })
532 if err != nil {
533 return err
534 }
535
[73]536 dc.forEachUpstream(func(uc *upstreamConn) {
[60]537 uc.SendMessage(msg)
[42]538 })
[69]539 case "JOIN", "PART":
[48]540 var name string
541 if err := parseMessageParams(msg, &name); err != nil {
542 return err
543 }
544
[69]545 uc, upstreamName, err := dc.unmarshalChannel(name)
546 if err != nil {
547 return ircError{&irc.Message{
548 Command: irc.ERR_NOSUCHCHANNEL,
549 Params: []string{name, err.Error()},
550 }}
[48]551 }
552
[69]553 uc.SendMessage(&irc.Message{
554 Command: msg.Command,
555 Params: []string{upstreamName},
556 })
[89]557
558 switch msg.Command {
559 case "JOIN":
560 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
561 Name: upstreamName,
562 })
563 if err != nil {
564 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
565 }
566 case "PART":
567 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
568 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
569 }
570 }
[69]571 case "MODE":
572 if msg.Prefix == nil {
573 return fmt.Errorf("missing prefix")
[49]574 }
575
[46]576 var name string
577 if err := parseMessageParams(msg, &name); err != nil {
578 return err
579 }
580
581 var modeStr string
582 if len(msg.Params) > 1 {
583 modeStr = msg.Params[1]
584 }
585
586 if msg.Prefix.Name != name {
[69]587 uc, upstreamName, err := dc.unmarshalChannel(name)
[46]588 if err != nil {
589 return err
590 }
591
592 if modeStr != "" {
[69]593 uc.SendMessage(&irc.Message{
594 Command: "MODE",
595 Params: []string{upstreamName, modeStr},
596 })
[46]597 } else {
[69]598 ch, ok := uc.channels[upstreamName]
599 if !ok {
600 return ircError{&irc.Message{
601 Command: irc.ERR_NOSUCHCHANNEL,
602 Params: []string{name, "No such channel"},
603 }}
604 }
605
[55]606 dc.SendMessage(&irc.Message{
607 Prefix: dc.srv.prefix(),
[46]608 Command: irc.RPL_CHANNELMODEIS,
[69]609 Params: []string{name, string(ch.modes)},
[54]610 })
[46]611 }
612 } else {
[55]613 if name != dc.nick {
[46]614 return ircError{&irc.Message{
615 Command: irc.ERR_USERSDONTMATCH,
[55]616 Params: []string{dc.nick, "Cannot change mode for other users"},
[46]617 }}
618 }
619
620 if modeStr != "" {
[73]621 dc.forEachUpstream(func(uc *upstreamConn) {
[69]622 uc.SendMessage(&irc.Message{
623 Command: "MODE",
624 Params: []string{uc.nick, modeStr},
625 })
[46]626 })
627 } else {
[55]628 dc.SendMessage(&irc.Message{
629 Prefix: dc.srv.prefix(),
[46]630 Command: irc.RPL_UMODEIS,
631 Params: []string{""}, // TODO
[54]632 })
[46]633 }
634 }
[58]635 case "PRIVMSG":
636 var targetsStr, text string
637 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
638 return err
639 }
640
641 for _, name := range strings.Split(targetsStr, ",") {
[69]642 uc, upstreamName, err := dc.unmarshalChannel(name)
[58]643 if err != nil {
644 return err
645 }
646
[95]647 if upstreamName == "NickServ" {
648 dc.handleNickServPRIVMSG(uc, text)
649 }
650
[69]651 uc.SendMessage(&irc.Message{
[58]652 Command: "PRIVMSG",
[69]653 Params: []string{upstreamName, text},
[60]654 })
[105]655
656 dc.lock.Lock()
657 dc.ourMessages[msg] = struct{}{}
658 dc.lock.Unlock()
659
660 uc.ring.Produce(msg)
[58]661 }
[13]662 default:
[55]663 dc.logger.Printf("unhandled message: %v", msg)
[13]664 return newUnknownCommandError(msg.Command)
665 }
[42]666 return nil
[13]667}
[95]668
669func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
670 username, password, ok := parseNickServCredentials(text, uc.nick)
671 if !ok {
672 return
673 }
674
675 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
676 n := uc.network
677 n.SASL.Mechanism = "PLAIN"
678 n.SASL.Plain.Username = username
679 n.SASL.Plain.Password = password
680 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
681 dc.logger.Printf("failed to save NickServ credentials: %v", err)
682 }
683}
684
// parseNickServCredentials extracts credentials from the text of a message
// sent to NickServ. nick is the nick used when the command does not name an
// account.
//
// Recognised forms (the command word is case-insensitive):
//
//	REGISTER <password> [...]           -> (nick, password)
//	IDENTIFY <password>                 -> (nick, password)
//	IDENTIFY <username> <password> ...  -> (username, password)
//
// ok is false when the text has fewer than two fields or the command is not
// one of the above. In that case username and password are empty and must
// not be used.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}

	// params holds at least one element because of the length check above.
	params := fields[1:]
	switch strings.ToUpper(fields[0]) {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			// Only a password was given: it applies to the current nick.
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		// Not a command carrying credentials.
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.