source: code/trunk/downstream.go@ 102

Last change on this file since 102 was 102, checked in by contact, 5 years ago

Rename messages channels to outgoing

File size: 14.6 KB
RevLine 
[98]1package soju
[13]2
3import (
[91]4 "crypto/tls"
[13]5 "fmt"
6 "io"
7 "net"
[39]8 "strings"
[91]9 "time"
[13]10
[85]11 "golang.org/x/crypto/bcrypt"
[13]12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
[85]19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
[13]23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
[85]45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
[13]49
[69]50type consumption struct {
51 consumer *RingConsumer
52 upstreamConn *upstreamConn
53}
54
[13]55type downstreamConn struct {
[69]56 net net.Conn
57 irc *irc.Conn
58 srv *Server
59 logger Logger
[102]60 outgoing chan *irc.Message
[69]61 consumptions chan consumption
62 closed chan struct{}
[22]63
[100]64 registered bool
65 user *user
66 nick string
67 username string
68 rawUsername string
69 realname string
70 password string // empty after authentication
71 network *network // can be nil
[13]72}
73
[22]74func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
[55]75 dc := &downstreamConn{
[69]76 net: netConn,
77 irc: irc.NewConn(netConn),
78 srv: srv,
79 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
[102]80 outgoing: make(chan *irc.Message, 64),
[69]81 consumptions: make(chan consumption),
82 closed: make(chan struct{}),
[22]83 }
[26]84
85 go func() {
[56]86 if err := dc.writeMessages(); err != nil {
87 dc.logger.Printf("failed to write message: %v", err)
[26]88 }
[55]89 if err := dc.net.Close(); err != nil {
90 dc.logger.Printf("failed to close connection: %v", err)
[45]91 } else {
[55]92 dc.logger.Printf("connection closed")
[45]93 }
[26]94 }()
95
[55]96 return dc
[22]97}
98
[55]99func (dc *downstreamConn) prefix() *irc.Prefix {
[27]100 return &irc.Prefix{
[55]101 Name: dc.nick,
102 User: dc.username,
[27]103 // TODO: fill the host?
104 }
105}
106
[69]107func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
108 return name
109}
110
[90]111func (dc *downstreamConn) forEachNetwork(f func(*network)) {
112 if dc.network != nil {
113 f(dc.network)
114 } else {
115 dc.user.forEachNetwork(f)
116 }
117}
118
[73]119func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
120 dc.user.forEachUpstream(func(uc *upstreamConn) {
[77]121 if dc.network != nil && uc.network != dc.network {
[73]122 return
123 }
124 f(uc)
125 })
126}
127
[89]128// upstream returns the upstream connection, if any. If there are zero or if
129// there are multiple upstream connections, it returns nil.
130func (dc *downstreamConn) upstream() *upstreamConn {
131 if dc.network == nil {
132 return nil
133 }
134
135 var upstream *upstreamConn
136 dc.forEachUpstream(func(uc *upstreamConn) {
137 upstream = uc
138 })
139 return upstream
140}
141
[69]142func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
[89]143 if uc := dc.upstream(); uc != nil {
144 return uc, name, nil
145 }
146
[73]147 // TODO: extract network name from channel name if dc.upstream == nil
148 var channel *upstreamChannel
149 var err error
150 dc.forEachUpstream(func(uc *upstreamConn) {
151 if err != nil {
152 return
153 }
154 if ch, ok := uc.channels[name]; ok {
155 if channel != nil {
156 err = fmt.Errorf("ambiguous channel name %q", name)
157 } else {
158 channel = ch
159 }
160 }
161 })
162 if channel == nil {
163 return nil, "", ircError{&irc.Message{
164 Command: irc.ERR_NOSUCHCHANNEL,
165 Params: []string{name, "No such channel"},
166 }}
[69]167 }
[73]168 return channel.conn, channel.Name, nil
[69]169}
170
171func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
172 if nick == uc.nick {
173 return dc.nick
174 }
175 return nick
176}
177
178func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
179 if prefix.Name == uc.nick {
180 return dc.prefix()
181 }
182 return prefix
183}
184
[57]185func (dc *downstreamConn) isClosed() bool {
186 select {
187 case <-dc.closed:
188 return true
189 default:
190 return false
191 }
192}
193
[55]194func (dc *downstreamConn) readMessages() error {
195 dc.logger.Printf("new connection")
[22]196
197 for {
[55]198 msg, err := dc.irc.ReadMessage()
[22]199 if err == io.EOF {
200 break
201 } else if err != nil {
202 return fmt.Errorf("failed to read IRC command: %v", err)
203 }
204
[64]205 if dc.srv.Debug {
206 dc.logger.Printf("received: %v", msg)
207 }
208
[55]209 err = dc.handleMessage(msg)
[22]210 if ircErr, ok := err.(ircError); ok {
[55]211 ircErr.Message.Prefix = dc.srv.prefix()
212 dc.SendMessage(ircErr.Message)
[22]213 } else if err != nil {
214 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
215 }
216
[57]217 if dc.isClosed() {
[22]218 return nil
219 }
220 }
221
[45]222 return nil
[22]223}
224
[56]225func (dc *downstreamConn) writeMessages() error {
[57]226 for {
227 var err error
228 var closed bool
229 select {
[102]230 case msg := <-dc.outgoing:
[64]231 if dc.srv.Debug {
232 dc.logger.Printf("sent: %v", msg)
233 }
[57]234 err = dc.irc.WriteMessage(msg)
[69]235 case consumption := <-dc.consumptions:
236 consumer, uc := consumption.consumer, consumption.upstreamConn
[57]237 for {
238 msg := consumer.Peek()
239 if msg == nil {
240 break
241 }
[69]242 msg = msg.Copy()
243 switch msg.Command {
244 case "PRIVMSG":
245 // TODO: detect whether it's a user or a channel
246 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
247 default:
248 panic("expected to consume a PRIVMSG message")
249 }
[64]250 if dc.srv.Debug {
251 dc.logger.Printf("sent: %v", msg)
252 }
[57]253 err = dc.irc.WriteMessage(msg)
254 if err != nil {
255 break
256 }
257 consumer.Consume()
258 }
259 case <-dc.closed:
260 closed = true
261 }
262 if err != nil {
[56]263 return err
264 }
[57]265 if closed {
266 break
267 }
[56]268 }
269 return nil
270}
271
[55]272func (dc *downstreamConn) Close() error {
[57]273 if dc.isClosed() {
[26]274 return fmt.Errorf("downstream connection already closed")
275 }
[40]276
[55]277 if u := dc.user; u != nil {
[40]278 u.lock.Lock()
279 for i := range u.downstreamConns {
[55]280 if u.downstreamConns[i] == dc {
[40]281 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
[63]282 break
[40]283 }
284 }
285 u.lock.Unlock()
[13]286 }
[40]287
[57]288 close(dc.closed)
[45]289 return nil
[13]290}
291
[55]292func (dc *downstreamConn) SendMessage(msg *irc.Message) {
[102]293 dc.outgoing <- msg
[54]294}
295
[55]296func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
[13]297 switch msg.Command {
[28]298 case "QUIT":
[55]299 return dc.Close()
[13]300 case "PING":
[55]301 dc.SendMessage(&irc.Message{
302 Prefix: dc.srv.prefix(),
[13]303 Command: "PONG",
[68]304 Params: msg.Params,
[54]305 })
[26]306 return nil
[13]307 default:
[55]308 if dc.registered {
309 return dc.handleMessageRegistered(msg)
[13]310 } else {
[55]311 return dc.handleMessageUnregistered(msg)
[13]312 }
313 }
314}
315
[55]316func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
[13]317 switch msg.Command {
318 case "NICK":
[55]319 if err := parseMessageParams(msg, &dc.nick); err != nil {
[43]320 return err
[13]321 }
322 case "USER":
[43]323 var username string
[55]324 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
[43]325 return err
[13]326 }
[100]327 dc.rawUsername = username
[85]328 case "PASS":
329 if err := parseMessageParams(msg, &dc.password); err != nil {
330 return err
331 }
[13]332 default:
[55]333 dc.logger.Printf("unhandled message: %v", msg)
[13]334 return newUnknownCommandError(msg.Command)
335 }
[100]336 if dc.rawUsername != "" && dc.nick != "" {
[55]337 return dc.register()
[13]338 }
339 return nil
340}
341
// sanityCheckServer checks that a TLS connection can be established to addr
// within 30 seconds. The connection is closed again immediately. It returns
// the dial/handshake error, or the error from closing the connection.
func sanityCheckServer(addr string) error {
	d := &net.Dialer{Timeout: 30 * time.Second}
	c, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err == nil {
		err = c.Close()
	}
	return err
}
350
[55]351func (dc *downstreamConn) register() error {
[100]352 username := dc.rawUsername
[77]353 var networkName string
[73]354 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
[77]355 networkName = username[i+1:]
[73]356 }
357 if i := strings.IndexAny(username, "/@"); i >= 0 {
358 username = username[:i]
359 }
[100]360 dc.username = "~" + username
[73]361
[85]362 password := dc.password
363 dc.password = ""
364
[73]365 u := dc.srv.getUser(username)
[38]366 if u == nil {
[85]367 dc.logger.Printf("failed authentication for %q: unknown username", username)
368 return errAuthFailed
[37]369 }
370
[85]371 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
372 if err != nil {
373 dc.logger.Printf("failed authentication for %q: %v", username, err)
374 return errAuthFailed
375 }
376
[88]377 var network *network
[77]378 if networkName != "" {
[88]379 network = u.getNetwork(networkName)
380 if network == nil {
[91]381 addr := networkName
382 if !strings.ContainsRune(addr, ':') {
383 addr = addr + ":6697"
384 }
385
[95]386 dc.logger.Printf("trying to connect to new network %q", addr)
[91]387 if err := sanityCheckServer(addr); err != nil {
388 dc.logger.Printf("failed to connect to %q: %v", addr, err)
389 return ircError{&irc.Message{
390 Command: irc.ERR_PASSWDMISMATCH,
391 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
392 }}
393 }
394
[95]395 dc.logger.Printf("auto-saving network %q", networkName)
[91]396 network, err = u.createNetwork(networkName, dc.nick)
397 if err != nil {
398 return err
399 }
[73]400 }
401 }
402
[55]403 dc.registered = true
404 dc.user = u
[88]405 dc.network = network
[13]406
[40]407 u.lock.Lock()
[57]408 firstDownstream := len(u.downstreamConns) == 0
[55]409 u.downstreamConns = append(u.downstreamConns, dc)
[40]410 u.lock.Unlock()
411
[55]412 dc.SendMessage(&irc.Message{
413 Prefix: dc.srv.prefix(),
[13]414 Command: irc.RPL_WELCOME,
[98]415 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
[54]416 })
[55]417 dc.SendMessage(&irc.Message{
418 Prefix: dc.srv.prefix(),
[13]419 Command: irc.RPL_YOURHOST,
[55]420 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
[54]421 })
[55]422 dc.SendMessage(&irc.Message{
423 Prefix: dc.srv.prefix(),
[13]424 Command: irc.RPL_CREATED,
[55]425 Params: []string{dc.nick, "Who cares when the server was created?"},
[54]426 })
[55]427 dc.SendMessage(&irc.Message{
428 Prefix: dc.srv.prefix(),
[13]429 Command: irc.RPL_MYINFO,
[98]430 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
[54]431 })
[93]432 // TODO: RPL_ISUPPORT
[55]433 dc.SendMessage(&irc.Message{
434 Prefix: dc.srv.prefix(),
[13]435 Command: irc.ERR_NOMOTD,
[55]436 Params: []string{dc.nick, "No MOTD"},
[54]437 })
[13]438
[73]439 dc.forEachUpstream(func(uc *upstreamConn) {
[30]440 // TODO: fix races accessing upstream connection data
441 for _, ch := range uc.channels {
442 if ch.complete {
[55]443 forwardChannel(dc, ch)
[30]444 }
445 }
[50]446
[73]447 historyName := dc.username
[57]448
449 var seqPtr *uint64
450 if firstDownstream {
451 seq, ok := uc.history[historyName]
452 if ok {
453 seqPtr = &seq
[50]454 }
455 }
[57]456
[59]457 consumer, ch := uc.ring.NewConsumer(seqPtr)
[57]458 go func() {
459 for {
460 var closed bool
461 select {
462 case <-ch:
[69]463 dc.consumptions <- consumption{consumer, uc}
[57]464 case <-dc.closed:
465 closed = true
466 }
467 if closed {
468 break
469 }
470 }
471
472 seq := consumer.Close()
473
474 dc.user.lock.Lock()
475 lastDownstream := len(dc.user.downstreamConns) == 0
476 dc.user.lock.Unlock()
477
478 if lastDownstream {
479 uc.history[historyName] = seq
480 }
481 }()
[39]482 })
[50]483
[13]484 return nil
485}
486
[55]487func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
[13]488 switch msg.Command {
[42]489 case "USER":
[13]490 return ircError{&irc.Message{
491 Command: irc.ERR_ALREADYREGISTERED,
[55]492 Params: []string{dc.nick, "You may not reregister"},
[13]493 }}
[42]494 case "NICK":
[90]495 var nick string
496 if err := parseMessageParams(msg, &nick); err != nil {
497 return err
498 }
499
500 var err error
501 dc.forEachNetwork(func(n *network) {
502 if err != nil {
503 return
504 }
505 n.Nick = nick
506 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
507 })
508 if err != nil {
509 return err
510 }
511
[73]512 dc.forEachUpstream(func(uc *upstreamConn) {
[60]513 uc.SendMessage(msg)
[42]514 })
[69]515 case "JOIN", "PART":
[48]516 var name string
517 if err := parseMessageParams(msg, &name); err != nil {
518 return err
519 }
520
[69]521 uc, upstreamName, err := dc.unmarshalChannel(name)
522 if err != nil {
523 return ircError{&irc.Message{
524 Command: irc.ERR_NOSUCHCHANNEL,
525 Params: []string{name, err.Error()},
526 }}
[48]527 }
528
[69]529 uc.SendMessage(&irc.Message{
530 Command: msg.Command,
531 Params: []string{upstreamName},
532 })
[89]533
534 switch msg.Command {
535 case "JOIN":
536 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
537 Name: upstreamName,
538 })
539 if err != nil {
540 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
541 }
542 case "PART":
543 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
544 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
545 }
546 }
[69]547 case "MODE":
548 if msg.Prefix == nil {
549 return fmt.Errorf("missing prefix")
[49]550 }
551
[46]552 var name string
553 if err := parseMessageParams(msg, &name); err != nil {
554 return err
555 }
556
557 var modeStr string
558 if len(msg.Params) > 1 {
559 modeStr = msg.Params[1]
560 }
561
562 if msg.Prefix.Name != name {
[69]563 uc, upstreamName, err := dc.unmarshalChannel(name)
[46]564 if err != nil {
565 return err
566 }
567
568 if modeStr != "" {
[69]569 uc.SendMessage(&irc.Message{
570 Command: "MODE",
571 Params: []string{upstreamName, modeStr},
572 })
[46]573 } else {
[69]574 ch, ok := uc.channels[upstreamName]
575 if !ok {
576 return ircError{&irc.Message{
577 Command: irc.ERR_NOSUCHCHANNEL,
578 Params: []string{name, "No such channel"},
579 }}
580 }
581
[55]582 dc.SendMessage(&irc.Message{
583 Prefix: dc.srv.prefix(),
[46]584 Command: irc.RPL_CHANNELMODEIS,
[69]585 Params: []string{name, string(ch.modes)},
[54]586 })
[46]587 }
588 } else {
[55]589 if name != dc.nick {
[46]590 return ircError{&irc.Message{
591 Command: irc.ERR_USERSDONTMATCH,
[55]592 Params: []string{dc.nick, "Cannot change mode for other users"},
[46]593 }}
594 }
595
596 if modeStr != "" {
[73]597 dc.forEachUpstream(func(uc *upstreamConn) {
[69]598 uc.SendMessage(&irc.Message{
599 Command: "MODE",
600 Params: []string{uc.nick, modeStr},
601 })
[46]602 })
603 } else {
[55]604 dc.SendMessage(&irc.Message{
605 Prefix: dc.srv.prefix(),
[46]606 Command: irc.RPL_UMODEIS,
607 Params: []string{""}, // TODO
[54]608 })
[46]609 }
610 }
[58]611 case "PRIVMSG":
612 var targetsStr, text string
613 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
614 return err
615 }
616
617 for _, name := range strings.Split(targetsStr, ",") {
[69]618 uc, upstreamName, err := dc.unmarshalChannel(name)
[58]619 if err != nil {
620 return err
621 }
622
[95]623 if upstreamName == "NickServ" {
624 dc.handleNickServPRIVMSG(uc, text)
625 }
626
[69]627 uc.SendMessage(&irc.Message{
[58]628 Command: "PRIVMSG",
[69]629 Params: []string{upstreamName, text},
[60]630 })
[58]631 }
[13]632 default:
[55]633 dc.logger.Printf("unhandled message: %v", msg)
[13]634 return newUnknownCommandError(msg.Command)
635 }
[42]636 return nil
[13]637}
[95]638
639func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
640 username, password, ok := parseNickServCredentials(text, uc.nick)
641 if !ok {
642 return
643 }
644
645 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
646 n := uc.network
647 n.SASL.Mechanism = "PLAIN"
648 n.SASL.Plain.Username = username
649 n.SASL.Plain.Password = password
650 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
651 dc.logger.Printf("failed to save NickServ credentials: %v", err)
652 }
653}
654
// parseNickServCredentials extracts account credentials from a message sent to
// NickServ. Supported forms (command matched case-insensitively):
//
//	REGISTER <password> [...]     -> username = nick
//	IDENTIFY <password>           -> username = nick
//	IDENTIFY <account> <password> -> username = account
//
// ok is false for any other command or when parameters are missing. Callers
// then leave their stored credentials untouched.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		username = nick
		password = params[0]
	case "IDENTIFY":
		if len(params) == 1 {
			// Only the password was given; params[1] does not exist here.
			username = nick
			password = params[0]
		} else {
			username = params[0]
			password = params[1]
		}
	default:
		// Not a credential-carrying command: don't report empty credentials.
		return "", "", false
	}
	return username, password, true
}
Note: See TracBrowser for help on using the repository browser.