source: code/trunk/downstream.go@ 101

Last change on this file since 101 was 100, checked in by contact, 5 years ago

Strip client & network name from username

File size: 14.6 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "time"
10
11 "golang.org/x/crypto/bcrypt"
12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
49
50type consumption struct {
51 consumer *RingConsumer
52 upstreamConn *upstreamConn
53}
54
55type downstreamConn struct {
56 net net.Conn
57 irc *irc.Conn
58 srv *Server
59 logger Logger
60 messages chan *irc.Message
61 consumptions chan consumption
62 closed chan struct{}
63
64 registered bool
65 user *user
66 nick string
67 username string
68 rawUsername string
69 realname string
70 password string // empty after authentication
71 network *network // can be nil
72}
73
74func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
75 dc := &downstreamConn{
76 net: netConn,
77 irc: irc.NewConn(netConn),
78 srv: srv,
79 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
80 messages: make(chan *irc.Message, 64),
81 consumptions: make(chan consumption),
82 closed: make(chan struct{}),
83 }
84
85 go func() {
86 if err := dc.writeMessages(); err != nil {
87 dc.logger.Printf("failed to write message: %v", err)
88 }
89 if err := dc.net.Close(); err != nil {
90 dc.logger.Printf("failed to close connection: %v", err)
91 } else {
92 dc.logger.Printf("connection closed")
93 }
94 }()
95
96 return dc
97}
98
99func (dc *downstreamConn) prefix() *irc.Prefix {
100 return &irc.Prefix{
101 Name: dc.nick,
102 User: dc.username,
103 // TODO: fill the host?
104 }
105}
106
107func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
108 return name
109}
110
111func (dc *downstreamConn) forEachNetwork(f func(*network)) {
112 if dc.network != nil {
113 f(dc.network)
114 } else {
115 dc.user.forEachNetwork(f)
116 }
117}
118
119func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
120 dc.user.forEachUpstream(func(uc *upstreamConn) {
121 if dc.network != nil && uc.network != dc.network {
122 return
123 }
124 f(uc)
125 })
126}
127
128// upstream returns the upstream connection, if any. If there are zero or if
129// there are multiple upstream connections, it returns nil.
130func (dc *downstreamConn) upstream() *upstreamConn {
131 if dc.network == nil {
132 return nil
133 }
134
135 var upstream *upstreamConn
136 dc.forEachUpstream(func(uc *upstreamConn) {
137 upstream = uc
138 })
139 return upstream
140}
141
142func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
143 if uc := dc.upstream(); uc != nil {
144 return uc, name, nil
145 }
146
147 // TODO: extract network name from channel name if dc.upstream == nil
148 var channel *upstreamChannel
149 var err error
150 dc.forEachUpstream(func(uc *upstreamConn) {
151 if err != nil {
152 return
153 }
154 if ch, ok := uc.channels[name]; ok {
155 if channel != nil {
156 err = fmt.Errorf("ambiguous channel name %q", name)
157 } else {
158 channel = ch
159 }
160 }
161 })
162 if channel == nil {
163 return nil, "", ircError{&irc.Message{
164 Command: irc.ERR_NOSUCHCHANNEL,
165 Params: []string{name, "No such channel"},
166 }}
167 }
168 return channel.conn, channel.Name, nil
169}
170
171func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
172 if nick == uc.nick {
173 return dc.nick
174 }
175 return nick
176}
177
178func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
179 if prefix.Name == uc.nick {
180 return dc.prefix()
181 }
182 return prefix
183}
184
185func (dc *downstreamConn) isClosed() bool {
186 select {
187 case <-dc.closed:
188 return true
189 default:
190 return false
191 }
192}
193
194func (dc *downstreamConn) readMessages() error {
195 dc.logger.Printf("new connection")
196
197 for {
198 msg, err := dc.irc.ReadMessage()
199 if err == io.EOF {
200 break
201 } else if err != nil {
202 return fmt.Errorf("failed to read IRC command: %v", err)
203 }
204
205 if dc.srv.Debug {
206 dc.logger.Printf("received: %v", msg)
207 }
208
209 err = dc.handleMessage(msg)
210 if ircErr, ok := err.(ircError); ok {
211 ircErr.Message.Prefix = dc.srv.prefix()
212 dc.SendMessage(ircErr.Message)
213 } else if err != nil {
214 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
215 }
216
217 if dc.isClosed() {
218 return nil
219 }
220 }
221
222 return nil
223}
224
225func (dc *downstreamConn) writeMessages() error {
226 for {
227 var err error
228 var closed bool
229 select {
230 case msg := <-dc.messages:
231 if dc.srv.Debug {
232 dc.logger.Printf("sent: %v", msg)
233 }
234 err = dc.irc.WriteMessage(msg)
235 case consumption := <-dc.consumptions:
236 consumer, uc := consumption.consumer, consumption.upstreamConn
237 for {
238 msg := consumer.Peek()
239 if msg == nil {
240 break
241 }
242 msg = msg.Copy()
243 switch msg.Command {
244 case "PRIVMSG":
245 // TODO: detect whether it's a user or a channel
246 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
247 default:
248 panic("expected to consume a PRIVMSG message")
249 }
250 if dc.srv.Debug {
251 dc.logger.Printf("sent: %v", msg)
252 }
253 err = dc.irc.WriteMessage(msg)
254 if err != nil {
255 break
256 }
257 consumer.Consume()
258 }
259 case <-dc.closed:
260 closed = true
261 }
262 if err != nil {
263 return err
264 }
265 if closed {
266 break
267 }
268 }
269 return nil
270}
271
272func (dc *downstreamConn) Close() error {
273 if dc.isClosed() {
274 return fmt.Errorf("downstream connection already closed")
275 }
276
277 if u := dc.user; u != nil {
278 u.lock.Lock()
279 for i := range u.downstreamConns {
280 if u.downstreamConns[i] == dc {
281 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
282 break
283 }
284 }
285 u.lock.Unlock()
286 }
287
288 close(dc.closed)
289 return nil
290}
291
292func (dc *downstreamConn) SendMessage(msg *irc.Message) {
293 dc.messages <- msg
294}
295
296func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
297 switch msg.Command {
298 case "QUIT":
299 return dc.Close()
300 case "PING":
301 dc.SendMessage(&irc.Message{
302 Prefix: dc.srv.prefix(),
303 Command: "PONG",
304 Params: msg.Params,
305 })
306 return nil
307 default:
308 if dc.registered {
309 return dc.handleMessageRegistered(msg)
310 } else {
311 return dc.handleMessageUnregistered(msg)
312 }
313 }
314}
315
316func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
317 switch msg.Command {
318 case "NICK":
319 if err := parseMessageParams(msg, &dc.nick); err != nil {
320 return err
321 }
322 case "USER":
323 var username string
324 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
325 return err
326 }
327 dc.rawUsername = username
328 case "PASS":
329 if err := parseMessageParams(msg, &dc.password); err != nil {
330 return err
331 }
332 default:
333 dc.logger.Printf("unhandled message: %v", msg)
334 return newUnknownCommandError(msg.Command)
335 }
336 if dc.rawUsername != "" && dc.nick != "" {
337 return dc.register()
338 }
339 return nil
340}
341
// sanityCheckServer checks that a TLS connection to addr (host:port) can be
// established within 30 seconds, using the default TLS configuration. The
// connection is closed right away; its Close error is returned.
func sanityCheckServer(addr string) error {
	const dialTimeout = 30 * time.Second
	d := &net.Dialer{Timeout: dialTimeout}
	c, err := tls.DialWithDialer(d, "tcp", addr, nil)
	if err == nil {
		err = c.Close()
	}
	return err
}
350
351func (dc *downstreamConn) register() error {
352 username := dc.rawUsername
353 var networkName string
354 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
355 networkName = username[i+1:]
356 }
357 if i := strings.IndexAny(username, "/@"); i >= 0 {
358 username = username[:i]
359 }
360 dc.username = "~" + username
361
362 password := dc.password
363 dc.password = ""
364
365 u := dc.srv.getUser(username)
366 if u == nil {
367 dc.logger.Printf("failed authentication for %q: unknown username", username)
368 return errAuthFailed
369 }
370
371 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
372 if err != nil {
373 dc.logger.Printf("failed authentication for %q: %v", username, err)
374 return errAuthFailed
375 }
376
377 var network *network
378 if networkName != "" {
379 network = u.getNetwork(networkName)
380 if network == nil {
381 addr := networkName
382 if !strings.ContainsRune(addr, ':') {
383 addr = addr + ":6697"
384 }
385
386 dc.logger.Printf("trying to connect to new network %q", addr)
387 if err := sanityCheckServer(addr); err != nil {
388 dc.logger.Printf("failed to connect to %q: %v", addr, err)
389 return ircError{&irc.Message{
390 Command: irc.ERR_PASSWDMISMATCH,
391 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
392 }}
393 }
394
395 dc.logger.Printf("auto-saving network %q", networkName)
396 network, err = u.createNetwork(networkName, dc.nick)
397 if err != nil {
398 return err
399 }
400 }
401 }
402
403 dc.registered = true
404 dc.user = u
405 dc.network = network
406
407 u.lock.Lock()
408 firstDownstream := len(u.downstreamConns) == 0
409 u.downstreamConns = append(u.downstreamConns, dc)
410 u.lock.Unlock()
411
412 dc.SendMessage(&irc.Message{
413 Prefix: dc.srv.prefix(),
414 Command: irc.RPL_WELCOME,
415 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
416 })
417 dc.SendMessage(&irc.Message{
418 Prefix: dc.srv.prefix(),
419 Command: irc.RPL_YOURHOST,
420 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
421 })
422 dc.SendMessage(&irc.Message{
423 Prefix: dc.srv.prefix(),
424 Command: irc.RPL_CREATED,
425 Params: []string{dc.nick, "Who cares when the server was created?"},
426 })
427 dc.SendMessage(&irc.Message{
428 Prefix: dc.srv.prefix(),
429 Command: irc.RPL_MYINFO,
430 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
431 })
432 // TODO: RPL_ISUPPORT
433 dc.SendMessage(&irc.Message{
434 Prefix: dc.srv.prefix(),
435 Command: irc.ERR_NOMOTD,
436 Params: []string{dc.nick, "No MOTD"},
437 })
438
439 dc.forEachUpstream(func(uc *upstreamConn) {
440 // TODO: fix races accessing upstream connection data
441 for _, ch := range uc.channels {
442 if ch.complete {
443 forwardChannel(dc, ch)
444 }
445 }
446
447 historyName := dc.username
448
449 var seqPtr *uint64
450 if firstDownstream {
451 seq, ok := uc.history[historyName]
452 if ok {
453 seqPtr = &seq
454 }
455 }
456
457 consumer, ch := uc.ring.NewConsumer(seqPtr)
458 go func() {
459 for {
460 var closed bool
461 select {
462 case <-ch:
463 dc.consumptions <- consumption{consumer, uc}
464 case <-dc.closed:
465 closed = true
466 }
467 if closed {
468 break
469 }
470 }
471
472 seq := consumer.Close()
473
474 dc.user.lock.Lock()
475 lastDownstream := len(dc.user.downstreamConns) == 0
476 dc.user.lock.Unlock()
477
478 if lastDownstream {
479 uc.history[historyName] = seq
480 }
481 }()
482 })
483
484 return nil
485}
486
487func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
488 switch msg.Command {
489 case "USER":
490 return ircError{&irc.Message{
491 Command: irc.ERR_ALREADYREGISTERED,
492 Params: []string{dc.nick, "You may not reregister"},
493 }}
494 case "NICK":
495 var nick string
496 if err := parseMessageParams(msg, &nick); err != nil {
497 return err
498 }
499
500 var err error
501 dc.forEachNetwork(func(n *network) {
502 if err != nil {
503 return
504 }
505 n.Nick = nick
506 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
507 })
508 if err != nil {
509 return err
510 }
511
512 dc.forEachUpstream(func(uc *upstreamConn) {
513 uc.SendMessage(msg)
514 })
515 case "JOIN", "PART":
516 var name string
517 if err := parseMessageParams(msg, &name); err != nil {
518 return err
519 }
520
521 uc, upstreamName, err := dc.unmarshalChannel(name)
522 if err != nil {
523 return ircError{&irc.Message{
524 Command: irc.ERR_NOSUCHCHANNEL,
525 Params: []string{name, err.Error()},
526 }}
527 }
528
529 uc.SendMessage(&irc.Message{
530 Command: msg.Command,
531 Params: []string{upstreamName},
532 })
533
534 switch msg.Command {
535 case "JOIN":
536 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
537 Name: upstreamName,
538 })
539 if err != nil {
540 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
541 }
542 case "PART":
543 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
544 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
545 }
546 }
547 case "MODE":
548 if msg.Prefix == nil {
549 return fmt.Errorf("missing prefix")
550 }
551
552 var name string
553 if err := parseMessageParams(msg, &name); err != nil {
554 return err
555 }
556
557 var modeStr string
558 if len(msg.Params) > 1 {
559 modeStr = msg.Params[1]
560 }
561
562 if msg.Prefix.Name != name {
563 uc, upstreamName, err := dc.unmarshalChannel(name)
564 if err != nil {
565 return err
566 }
567
568 if modeStr != "" {
569 uc.SendMessage(&irc.Message{
570 Command: "MODE",
571 Params: []string{upstreamName, modeStr},
572 })
573 } else {
574 ch, ok := uc.channels[upstreamName]
575 if !ok {
576 return ircError{&irc.Message{
577 Command: irc.ERR_NOSUCHCHANNEL,
578 Params: []string{name, "No such channel"},
579 }}
580 }
581
582 dc.SendMessage(&irc.Message{
583 Prefix: dc.srv.prefix(),
584 Command: irc.RPL_CHANNELMODEIS,
585 Params: []string{name, string(ch.modes)},
586 })
587 }
588 } else {
589 if name != dc.nick {
590 return ircError{&irc.Message{
591 Command: irc.ERR_USERSDONTMATCH,
592 Params: []string{dc.nick, "Cannot change mode for other users"},
593 }}
594 }
595
596 if modeStr != "" {
597 dc.forEachUpstream(func(uc *upstreamConn) {
598 uc.SendMessage(&irc.Message{
599 Command: "MODE",
600 Params: []string{uc.nick, modeStr},
601 })
602 })
603 } else {
604 dc.SendMessage(&irc.Message{
605 Prefix: dc.srv.prefix(),
606 Command: irc.RPL_UMODEIS,
607 Params: []string{""}, // TODO
608 })
609 }
610 }
611 case "PRIVMSG":
612 var targetsStr, text string
613 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
614 return err
615 }
616
617 for _, name := range strings.Split(targetsStr, ",") {
618 uc, upstreamName, err := dc.unmarshalChannel(name)
619 if err != nil {
620 return err
621 }
622
623 if upstreamName == "NickServ" {
624 dc.handleNickServPRIVMSG(uc, text)
625 }
626
627 uc.SendMessage(&irc.Message{
628 Command: "PRIVMSG",
629 Params: []string{upstreamName, text},
630 })
631 }
632 default:
633 dc.logger.Printf("unhandled message: %v", msg)
634 return newUnknownCommandError(msg.Command)
635 }
636 return nil
637}
638
639func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
640 username, password, ok := parseNickServCredentials(text, uc.nick)
641 if !ok {
642 return
643 }
644
645 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
646 n := uc.network
647 n.SASL.Mechanism = "PLAIN"
648 n.SASL.Plain.Username = username
649 n.SASL.Plain.Password = password
650 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
651 dc.logger.Printf("failed to save NickServ credentials: %v", err)
652 }
653}
654
// parseNickServCredentials extracts account credentials from a message sent
// to NickServ. The command word is matched case-insensitively.
//
// Recognized forms:
//
//	REGISTER <password> [...]       -> (nick, password)
//	IDENTIFY <password>             -> (nick, password)
//	IDENTIFY <account> <password>   -> (account, password)
//
// ok is false for any other command, or when a recognized command lacks its
// password. This keeps unrelated NickServ commands from overwriting saved
// credentials with empty values.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.