source: code/trunk/downstream.go@ 98

Last change on this file since 98 was 98, checked in by contact, 5 years ago

Rename project to soju

File size: 14.6 KB
Line 
1package soju
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "time"
10
11 "golang.org/x/crypto/bcrypt"
12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
49
50type consumption struct {
51 consumer *RingConsumer
52 upstreamConn *upstreamConn
53}
54
55type downstreamConn struct {
56 net net.Conn
57 irc *irc.Conn
58 srv *Server
59 logger Logger
60 messages chan *irc.Message
61 consumptions chan consumption
62 closed chan struct{}
63
64 registered bool
65 user *user
66 nick string
67 username string
68 realname string
69 password string // empty after authentication
70 network *network // can be nil
71}
72
73func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
74 dc := &downstreamConn{
75 net: netConn,
76 irc: irc.NewConn(netConn),
77 srv: srv,
78 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
79 messages: make(chan *irc.Message, 64),
80 consumptions: make(chan consumption),
81 closed: make(chan struct{}),
82 }
83
84 go func() {
85 if err := dc.writeMessages(); err != nil {
86 dc.logger.Printf("failed to write message: %v", err)
87 }
88 if err := dc.net.Close(); err != nil {
89 dc.logger.Printf("failed to close connection: %v", err)
90 } else {
91 dc.logger.Printf("connection closed")
92 }
93 }()
94
95 return dc
96}
97
98func (dc *downstreamConn) prefix() *irc.Prefix {
99 return &irc.Prefix{
100 Name: dc.nick,
101 User: dc.username,
102 // TODO: fill the host?
103 }
104}
105
106func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
107 return name
108}
109
110func (dc *downstreamConn) forEachNetwork(f func(*network)) {
111 if dc.network != nil {
112 f(dc.network)
113 } else {
114 dc.user.forEachNetwork(f)
115 }
116}
117
118func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
119 dc.user.forEachUpstream(func(uc *upstreamConn) {
120 if dc.network != nil && uc.network != dc.network {
121 return
122 }
123 f(uc)
124 })
125}
126
127// upstream returns the upstream connection, if any. If there are zero or if
128// there are multiple upstream connections, it returns nil.
129func (dc *downstreamConn) upstream() *upstreamConn {
130 if dc.network == nil {
131 return nil
132 }
133
134 var upstream *upstreamConn
135 dc.forEachUpstream(func(uc *upstreamConn) {
136 upstream = uc
137 })
138 return upstream
139}
140
141func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
142 if uc := dc.upstream(); uc != nil {
143 return uc, name, nil
144 }
145
146 // TODO: extract network name from channel name if dc.upstream == nil
147 var channel *upstreamChannel
148 var err error
149 dc.forEachUpstream(func(uc *upstreamConn) {
150 if err != nil {
151 return
152 }
153 if ch, ok := uc.channels[name]; ok {
154 if channel != nil {
155 err = fmt.Errorf("ambiguous channel name %q", name)
156 } else {
157 channel = ch
158 }
159 }
160 })
161 if channel == nil {
162 return nil, "", ircError{&irc.Message{
163 Command: irc.ERR_NOSUCHCHANNEL,
164 Params: []string{name, "No such channel"},
165 }}
166 }
167 return channel.conn, channel.Name, nil
168}
169
170func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
171 if nick == uc.nick {
172 return dc.nick
173 }
174 return nick
175}
176
177func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
178 if prefix.Name == uc.nick {
179 return dc.prefix()
180 }
181 return prefix
182}
183
184func (dc *downstreamConn) isClosed() bool {
185 select {
186 case <-dc.closed:
187 return true
188 default:
189 return false
190 }
191}
192
193func (dc *downstreamConn) readMessages() error {
194 dc.logger.Printf("new connection")
195
196 for {
197 msg, err := dc.irc.ReadMessage()
198 if err == io.EOF {
199 break
200 } else if err != nil {
201 return fmt.Errorf("failed to read IRC command: %v", err)
202 }
203
204 if dc.srv.Debug {
205 dc.logger.Printf("received: %v", msg)
206 }
207
208 err = dc.handleMessage(msg)
209 if ircErr, ok := err.(ircError); ok {
210 ircErr.Message.Prefix = dc.srv.prefix()
211 dc.SendMessage(ircErr.Message)
212 } else if err != nil {
213 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
214 }
215
216 if dc.isClosed() {
217 return nil
218 }
219 }
220
221 return nil
222}
223
224func (dc *downstreamConn) writeMessages() error {
225 for {
226 var err error
227 var closed bool
228 select {
229 case msg := <-dc.messages:
230 if dc.srv.Debug {
231 dc.logger.Printf("sent: %v", msg)
232 }
233 err = dc.irc.WriteMessage(msg)
234 case consumption := <-dc.consumptions:
235 consumer, uc := consumption.consumer, consumption.upstreamConn
236 for {
237 msg := consumer.Peek()
238 if msg == nil {
239 break
240 }
241 msg = msg.Copy()
242 switch msg.Command {
243 case "PRIVMSG":
244 // TODO: detect whether it's a user or a channel
245 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
246 default:
247 panic("expected to consume a PRIVMSG message")
248 }
249 if dc.srv.Debug {
250 dc.logger.Printf("sent: %v", msg)
251 }
252 err = dc.irc.WriteMessage(msg)
253 if err != nil {
254 break
255 }
256 consumer.Consume()
257 }
258 case <-dc.closed:
259 closed = true
260 }
261 if err != nil {
262 return err
263 }
264 if closed {
265 break
266 }
267 }
268 return nil
269}
270
271func (dc *downstreamConn) Close() error {
272 if dc.isClosed() {
273 return fmt.Errorf("downstream connection already closed")
274 }
275
276 if u := dc.user; u != nil {
277 u.lock.Lock()
278 for i := range u.downstreamConns {
279 if u.downstreamConns[i] == dc {
280 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
281 break
282 }
283 }
284 u.lock.Unlock()
285 }
286
287 close(dc.closed)
288 return nil
289}
290
291func (dc *downstreamConn) SendMessage(msg *irc.Message) {
292 dc.messages <- msg
293}
294
295func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
296 switch msg.Command {
297 case "QUIT":
298 return dc.Close()
299 case "PING":
300 dc.SendMessage(&irc.Message{
301 Prefix: dc.srv.prefix(),
302 Command: "PONG",
303 Params: msg.Params,
304 })
305 return nil
306 default:
307 if dc.registered {
308 return dc.handleMessageRegistered(msg)
309 } else {
310 return dc.handleMessageUnregistered(msg)
311 }
312 }
313}
314
315func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
316 switch msg.Command {
317 case "NICK":
318 if err := parseMessageParams(msg, &dc.nick); err != nil {
319 return err
320 }
321 case "USER":
322 var username string
323 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
324 return err
325 }
326 dc.username = "~" + username
327 case "PASS":
328 if err := parseMessageParams(msg, &dc.password); err != nil {
329 return err
330 }
331 default:
332 dc.logger.Printf("unhandled message: %v", msg)
333 return newUnknownCommandError(msg.Command)
334 }
335 if dc.username != "" && dc.nick != "" {
336 return dc.register()
337 }
338 return nil
339}
340
// sanityCheckServer verifies that a TLS connection to addr can be established
// (dial plus handshake, bounded by a 30-second dialer timeout), then closes it
// again. It returns the dial error or the error from closing the connection.
func sanityCheckServer(addr string) error {
	dialer := &net.Dialer{Timeout: 30 * time.Second}
	conn, err := tls.DialWithDialer(dialer, "tcp", addr, nil)
	if err == nil {
		err = conn.Close()
	}
	return err
}
349
350func (dc *downstreamConn) register() error {
351 username := strings.TrimPrefix(dc.username, "~")
352 var networkName string
353 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
354 networkName = username[i+1:]
355 }
356 if i := strings.IndexAny(username, "/@"); i >= 0 {
357 username = username[:i]
358 }
359
360 password := dc.password
361 dc.password = ""
362
363 u := dc.srv.getUser(username)
364 if u == nil {
365 dc.logger.Printf("failed authentication for %q: unknown username", username)
366 return errAuthFailed
367 }
368
369 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
370 if err != nil {
371 dc.logger.Printf("failed authentication for %q: %v", username, err)
372 return errAuthFailed
373 }
374
375 var network *network
376 if networkName != "" {
377 network = u.getNetwork(networkName)
378 if network == nil {
379 addr := networkName
380 if !strings.ContainsRune(addr, ':') {
381 addr = addr + ":6697"
382 }
383
384 dc.logger.Printf("trying to connect to new network %q", addr)
385 if err := sanityCheckServer(addr); err != nil {
386 dc.logger.Printf("failed to connect to %q: %v", addr, err)
387 return ircError{&irc.Message{
388 Command: irc.ERR_PASSWDMISMATCH,
389 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
390 }}
391 }
392
393 dc.logger.Printf("auto-saving network %q", networkName)
394 network, err = u.createNetwork(networkName, dc.nick)
395 if err != nil {
396 return err
397 }
398 }
399 }
400
401 dc.registered = true
402 dc.user = u
403 dc.network = network
404
405 u.lock.Lock()
406 firstDownstream := len(u.downstreamConns) == 0
407 u.downstreamConns = append(u.downstreamConns, dc)
408 u.lock.Unlock()
409
410 dc.SendMessage(&irc.Message{
411 Prefix: dc.srv.prefix(),
412 Command: irc.RPL_WELCOME,
413 Params: []string{dc.nick, "Welcome to soju, " + dc.nick},
414 })
415 dc.SendMessage(&irc.Message{
416 Prefix: dc.srv.prefix(),
417 Command: irc.RPL_YOURHOST,
418 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
419 })
420 dc.SendMessage(&irc.Message{
421 Prefix: dc.srv.prefix(),
422 Command: irc.RPL_CREATED,
423 Params: []string{dc.nick, "Who cares when the server was created?"},
424 })
425 dc.SendMessage(&irc.Message{
426 Prefix: dc.srv.prefix(),
427 Command: irc.RPL_MYINFO,
428 Params: []string{dc.nick, dc.srv.Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"},
429 })
430 // TODO: RPL_ISUPPORT
431 dc.SendMessage(&irc.Message{
432 Prefix: dc.srv.prefix(),
433 Command: irc.ERR_NOMOTD,
434 Params: []string{dc.nick, "No MOTD"},
435 })
436
437 dc.forEachUpstream(func(uc *upstreamConn) {
438 // TODO: fix races accessing upstream connection data
439 for _, ch := range uc.channels {
440 if ch.complete {
441 forwardChannel(dc, ch)
442 }
443 }
444
445 historyName := dc.username
446
447 var seqPtr *uint64
448 if firstDownstream {
449 seq, ok := uc.history[historyName]
450 if ok {
451 seqPtr = &seq
452 }
453 }
454
455 consumer, ch := uc.ring.NewConsumer(seqPtr)
456 go func() {
457 for {
458 var closed bool
459 select {
460 case <-ch:
461 dc.consumptions <- consumption{consumer, uc}
462 case <-dc.closed:
463 closed = true
464 }
465 if closed {
466 break
467 }
468 }
469
470 seq := consumer.Close()
471
472 dc.user.lock.Lock()
473 lastDownstream := len(dc.user.downstreamConns) == 0
474 dc.user.lock.Unlock()
475
476 if lastDownstream {
477 uc.history[historyName] = seq
478 }
479 }()
480 })
481
482 return nil
483}
484
485func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
486 switch msg.Command {
487 case "USER":
488 return ircError{&irc.Message{
489 Command: irc.ERR_ALREADYREGISTERED,
490 Params: []string{dc.nick, "You may not reregister"},
491 }}
492 case "NICK":
493 var nick string
494 if err := parseMessageParams(msg, &nick); err != nil {
495 return err
496 }
497
498 var err error
499 dc.forEachNetwork(func(n *network) {
500 if err != nil {
501 return
502 }
503 n.Nick = nick
504 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
505 })
506 if err != nil {
507 return err
508 }
509
510 dc.forEachUpstream(func(uc *upstreamConn) {
511 uc.SendMessage(msg)
512 })
513 case "JOIN", "PART":
514 var name string
515 if err := parseMessageParams(msg, &name); err != nil {
516 return err
517 }
518
519 uc, upstreamName, err := dc.unmarshalChannel(name)
520 if err != nil {
521 return ircError{&irc.Message{
522 Command: irc.ERR_NOSUCHCHANNEL,
523 Params: []string{name, err.Error()},
524 }}
525 }
526
527 uc.SendMessage(&irc.Message{
528 Command: msg.Command,
529 Params: []string{upstreamName},
530 })
531
532 switch msg.Command {
533 case "JOIN":
534 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
535 Name: upstreamName,
536 })
537 if err != nil {
538 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
539 }
540 case "PART":
541 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
542 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
543 }
544 }
545 case "MODE":
546 if msg.Prefix == nil {
547 return fmt.Errorf("missing prefix")
548 }
549
550 var name string
551 if err := parseMessageParams(msg, &name); err != nil {
552 return err
553 }
554
555 var modeStr string
556 if len(msg.Params) > 1 {
557 modeStr = msg.Params[1]
558 }
559
560 if msg.Prefix.Name != name {
561 uc, upstreamName, err := dc.unmarshalChannel(name)
562 if err != nil {
563 return err
564 }
565
566 if modeStr != "" {
567 uc.SendMessage(&irc.Message{
568 Command: "MODE",
569 Params: []string{upstreamName, modeStr},
570 })
571 } else {
572 ch, ok := uc.channels[upstreamName]
573 if !ok {
574 return ircError{&irc.Message{
575 Command: irc.ERR_NOSUCHCHANNEL,
576 Params: []string{name, "No such channel"},
577 }}
578 }
579
580 dc.SendMessage(&irc.Message{
581 Prefix: dc.srv.prefix(),
582 Command: irc.RPL_CHANNELMODEIS,
583 Params: []string{name, string(ch.modes)},
584 })
585 }
586 } else {
587 if name != dc.nick {
588 return ircError{&irc.Message{
589 Command: irc.ERR_USERSDONTMATCH,
590 Params: []string{dc.nick, "Cannot change mode for other users"},
591 }}
592 }
593
594 if modeStr != "" {
595 dc.forEachUpstream(func(uc *upstreamConn) {
596 uc.SendMessage(&irc.Message{
597 Command: "MODE",
598 Params: []string{uc.nick, modeStr},
599 })
600 })
601 } else {
602 dc.SendMessage(&irc.Message{
603 Prefix: dc.srv.prefix(),
604 Command: irc.RPL_UMODEIS,
605 Params: []string{""}, // TODO
606 })
607 }
608 }
609 case "PRIVMSG":
610 var targetsStr, text string
611 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
612 return err
613 }
614
615 for _, name := range strings.Split(targetsStr, ",") {
616 uc, upstreamName, err := dc.unmarshalChannel(name)
617 if err != nil {
618 return err
619 }
620
621 if upstreamName == "NickServ" {
622 dc.handleNickServPRIVMSG(uc, text)
623 }
624
625 uc.SendMessage(&irc.Message{
626 Command: "PRIVMSG",
627 Params: []string{upstreamName, text},
628 })
629 }
630 default:
631 dc.logger.Printf("unhandled message: %v", msg)
632 return newUnknownCommandError(msg.Command)
633 }
634 return nil
635}
636
637func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
638 username, password, ok := parseNickServCredentials(text, uc.nick)
639 if !ok {
640 return
641 }
642
643 dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
644 n := uc.network
645 n.SASL.Mechanism = "PLAIN"
646 n.SASL.Plain.Username = username
647 n.SASL.Plain.Password = password
648 if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
649 dc.logger.Printf("failed to save NickServ credentials: %v", err)
650 }
651}
652
// parseNickServCredentials extracts account credentials from a message sent
// to NickServ. nick is the user's current nickname, used when the command
// does not name an account explicitly.
//
// Recognized forms (the command is case-insensitive):
//
//	REGISTER <password> [...]       -> (nick, password)
//	IDENTIFY <password>             -> (nick, password)
//	IDENTIFY <account> <password>   -> (account, password)
//
// ok is false for any other text, so callers never store empty credentials.
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {
		return "", "", false
	}
	cmd := strings.ToUpper(fields[0])
	params := fields[1:]
	switch cmd {
	case "REGISTER":
		return nick, params[0], true
	case "IDENTIFY":
		if len(params) == 1 {
			// Previously this read params[1] and panicked.
			return nick, params[0], true
		}
		return params[0], params[1], true
	default:
		// Previously unknown commands returned ok=true with empty credentials.
		return "", "", false
	}
}
Note: See TracBrowser for help on using the repository browser.