source: code/trunk/downstream.go@ 91

Last change on this file since 91 was 91, checked in by contact, 5 years ago

Auto-save IRC networks

File size: 13.5 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "time"
10
11 "golang.org/x/crypto/bcrypt"
12 "gopkg.in/irc.v3"
13)
14
15type ircError struct {
16 Message *irc.Message
17}
18
19func (err ircError) Error() string {
20 return err.Message.String()
21}
22
23func newUnknownCommandError(cmd string) ircError {
24 return ircError{&irc.Message{
25 Command: irc.ERR_UNKNOWNCOMMAND,
26 Params: []string{
27 "*",
28 cmd,
29 "Unknown command",
30 },
31 }}
32}
33
34func newNeedMoreParamsError(cmd string) ircError {
35 return ircError{&irc.Message{
36 Command: irc.ERR_NEEDMOREPARAMS,
37 Params: []string{
38 "*",
39 cmd,
40 "Not enough parameters",
41 },
42 }}
43}
44
45var errAuthFailed = ircError{&irc.Message{
46 Command: irc.ERR_PASSWDMISMATCH,
47 Params: []string{"*", "Invalid username or password"},
48}}
49
50type consumption struct {
51 consumer *RingConsumer
52 upstreamConn *upstreamConn
53}
54
55type downstreamConn struct {
56 net net.Conn
57 irc *irc.Conn
58 srv *Server
59 logger Logger
60 messages chan *irc.Message
61 consumptions chan consumption
62 closed chan struct{}
63
64 registered bool
65 user *user
66 nick string
67 username string
68 realname string
69 password string // empty after authentication
70 network *network // can be nil
71}
72
73func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
74 dc := &downstreamConn{
75 net: netConn,
76 irc: irc.NewConn(netConn),
77 srv: srv,
78 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
79 messages: make(chan *irc.Message, 64),
80 consumptions: make(chan consumption),
81 closed: make(chan struct{}),
82 }
83
84 go func() {
85 if err := dc.writeMessages(); err != nil {
86 dc.logger.Printf("failed to write message: %v", err)
87 }
88 if err := dc.net.Close(); err != nil {
89 dc.logger.Printf("failed to close connection: %v", err)
90 } else {
91 dc.logger.Printf("connection closed")
92 }
93 }()
94
95 return dc
96}
97
98func (dc *downstreamConn) prefix() *irc.Prefix {
99 return &irc.Prefix{
100 Name: dc.nick,
101 User: dc.username,
102 // TODO: fill the host?
103 }
104}
105
106func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
107 return name
108}
109
110func (dc *downstreamConn) forEachNetwork(f func(*network)) {
111 if dc.network != nil {
112 f(dc.network)
113 } else {
114 dc.user.forEachNetwork(f)
115 }
116}
117
118func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
119 dc.user.forEachUpstream(func(uc *upstreamConn) {
120 if dc.network != nil && uc.network != dc.network {
121 return
122 }
123 f(uc)
124 })
125}
126
127// upstream returns the upstream connection, if any. If there are zero or if
128// there are multiple upstream connections, it returns nil.
129func (dc *downstreamConn) upstream() *upstreamConn {
130 if dc.network == nil {
131 return nil
132 }
133
134 var upstream *upstreamConn
135 dc.forEachUpstream(func(uc *upstreamConn) {
136 upstream = uc
137 })
138 return upstream
139}
140
141func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
142 if uc := dc.upstream(); uc != nil {
143 return uc, name, nil
144 }
145
146 // TODO: extract network name from channel name if dc.upstream == nil
147 var channel *upstreamChannel
148 var err error
149 dc.forEachUpstream(func(uc *upstreamConn) {
150 if err != nil {
151 return
152 }
153 if ch, ok := uc.channels[name]; ok {
154 if channel != nil {
155 err = fmt.Errorf("ambiguous channel name %q", name)
156 } else {
157 channel = ch
158 }
159 }
160 })
161 if channel == nil {
162 return nil, "", ircError{&irc.Message{
163 Command: irc.ERR_NOSUCHCHANNEL,
164 Params: []string{name, "No such channel"},
165 }}
166 }
167 return channel.conn, channel.Name, nil
168}
169
170func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
171 if nick == uc.nick {
172 return dc.nick
173 }
174 return nick
175}
176
177func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
178 if prefix.Name == uc.nick {
179 return dc.prefix()
180 }
181 return prefix
182}
183
184func (dc *downstreamConn) isClosed() bool {
185 select {
186 case <-dc.closed:
187 return true
188 default:
189 return false
190 }
191}
192
193func (dc *downstreamConn) readMessages() error {
194 dc.logger.Printf("new connection")
195
196 for {
197 msg, err := dc.irc.ReadMessage()
198 if err == io.EOF {
199 break
200 } else if err != nil {
201 return fmt.Errorf("failed to read IRC command: %v", err)
202 }
203
204 if dc.srv.Debug {
205 dc.logger.Printf("received: %v", msg)
206 }
207
208 err = dc.handleMessage(msg)
209 if ircErr, ok := err.(ircError); ok {
210 ircErr.Message.Prefix = dc.srv.prefix()
211 dc.SendMessage(ircErr.Message)
212 } else if err != nil {
213 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
214 }
215
216 if dc.isClosed() {
217 return nil
218 }
219 }
220
221 return nil
222}
223
224func (dc *downstreamConn) writeMessages() error {
225 for {
226 var err error
227 var closed bool
228 select {
229 case msg := <-dc.messages:
230 if dc.srv.Debug {
231 dc.logger.Printf("sent: %v", msg)
232 }
233 err = dc.irc.WriteMessage(msg)
234 case consumption := <-dc.consumptions:
235 consumer, uc := consumption.consumer, consumption.upstreamConn
236 for {
237 msg := consumer.Peek()
238 if msg == nil {
239 break
240 }
241 msg = msg.Copy()
242 switch msg.Command {
243 case "PRIVMSG":
244 // TODO: detect whether it's a user or a channel
245 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
246 default:
247 panic("expected to consume a PRIVMSG message")
248 }
249 if dc.srv.Debug {
250 dc.logger.Printf("sent: %v", msg)
251 }
252 err = dc.irc.WriteMessage(msg)
253 if err != nil {
254 break
255 }
256 consumer.Consume()
257 }
258 case <-dc.closed:
259 closed = true
260 }
261 if err != nil {
262 return err
263 }
264 if closed {
265 break
266 }
267 }
268 return nil
269}
270
271func (dc *downstreamConn) Close() error {
272 if dc.isClosed() {
273 return fmt.Errorf("downstream connection already closed")
274 }
275
276 if u := dc.user; u != nil {
277 u.lock.Lock()
278 for i := range u.downstreamConns {
279 if u.downstreamConns[i] == dc {
280 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
281 break
282 }
283 }
284 u.lock.Unlock()
285 }
286
287 close(dc.closed)
288 return nil
289}
290
291func (dc *downstreamConn) SendMessage(msg *irc.Message) {
292 dc.messages <- msg
293}
294
295func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
296 switch msg.Command {
297 case "QUIT":
298 return dc.Close()
299 case "PING":
300 dc.SendMessage(&irc.Message{
301 Prefix: dc.srv.prefix(),
302 Command: "PONG",
303 Params: msg.Params,
304 })
305 return nil
306 default:
307 if dc.registered {
308 return dc.handleMessageRegistered(msg)
309 } else {
310 return dc.handleMessageUnregistered(msg)
311 }
312 }
313}
314
315func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
316 switch msg.Command {
317 case "NICK":
318 if err := parseMessageParams(msg, &dc.nick); err != nil {
319 return err
320 }
321 case "USER":
322 var username string
323 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
324 return err
325 }
326 dc.username = "~" + username
327 case "PASS":
328 if err := parseMessageParams(msg, &dc.password); err != nil {
329 return err
330 }
331 default:
332 dc.logger.Printf("unhandled message: %v", msg)
333 return newUnknownCommandError(msg.Command)
334 }
335 if dc.username != "" && dc.nick != "" {
336 return dc.register()
337 }
338 return nil
339}
340
// sanityCheckServer checks that a TLS connection to addr can be established
// within 30 seconds, using the default TLS configuration. The connection is
// closed right away. It returns the dial error, or the error from closing the
// connection.
func sanityCheckServer(addr string) error {
	conn, err := tls.DialWithDialer(&net.Dialer{Timeout: 30 * time.Second}, "tcp", addr, nil)
	if err != nil {
		return err
	}
	return conn.Close()
}
349
350func (dc *downstreamConn) register() error {
351 username := strings.TrimPrefix(dc.username, "~")
352 var networkName string
353 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
354 networkName = username[i+1:]
355 }
356 if i := strings.IndexAny(username, "/@"); i >= 0 {
357 username = username[:i]
358 }
359
360 password := dc.password
361 dc.password = ""
362
363 u := dc.srv.getUser(username)
364 if u == nil {
365 dc.logger.Printf("failed authentication for %q: unknown username", username)
366 return errAuthFailed
367 }
368
369 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
370 if err != nil {
371 dc.logger.Printf("failed authentication for %q: %v", username, err)
372 return errAuthFailed
373 }
374
375 var network *network
376 if networkName != "" {
377 network = u.getNetwork(networkName)
378 if network == nil {
379 addr := networkName
380 if !strings.ContainsRune(addr, ':') {
381 addr = addr + ":6697"
382 }
383
384 dc.logger.Printf("trying to connect to new upstream server %q", addr)
385 if err := sanityCheckServer(addr); err != nil {
386 dc.logger.Printf("failed to connect to %q: %v", addr, err)
387 return ircError{&irc.Message{
388 Command: irc.ERR_PASSWDMISMATCH,
389 Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
390 }}
391 }
392
393 dc.logger.Printf("auto-adding network %q", networkName)
394 network, err = u.createNetwork(networkName, dc.nick)
395 if err != nil {
396 return err
397 }
398 }
399 }
400
401 dc.registered = true
402 dc.user = u
403 dc.network = network
404
405 u.lock.Lock()
406 firstDownstream := len(u.downstreamConns) == 0
407 u.downstreamConns = append(u.downstreamConns, dc)
408 u.lock.Unlock()
409
410 dc.SendMessage(&irc.Message{
411 Prefix: dc.srv.prefix(),
412 Command: irc.RPL_WELCOME,
413 Params: []string{dc.nick, "Welcome to jounce, " + dc.nick},
414 })
415 dc.SendMessage(&irc.Message{
416 Prefix: dc.srv.prefix(),
417 Command: irc.RPL_YOURHOST,
418 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
419 })
420 dc.SendMessage(&irc.Message{
421 Prefix: dc.srv.prefix(),
422 Command: irc.RPL_CREATED,
423 Params: []string{dc.nick, "Who cares when the server was created?"},
424 })
425 dc.SendMessage(&irc.Message{
426 Prefix: dc.srv.prefix(),
427 Command: irc.RPL_MYINFO,
428 Params: []string{dc.nick, dc.srv.Hostname, "jounce", "aiwroO", "OovaimnqpsrtklbeI"},
429 })
430 dc.SendMessage(&irc.Message{
431 Prefix: dc.srv.prefix(),
432 Command: irc.ERR_NOMOTD,
433 Params: []string{dc.nick, "No MOTD"},
434 })
435
436 dc.forEachUpstream(func(uc *upstreamConn) {
437 // TODO: fix races accessing upstream connection data
438 for _, ch := range uc.channels {
439 if ch.complete {
440 forwardChannel(dc, ch)
441 }
442 }
443
444 historyName := dc.username
445
446 var seqPtr *uint64
447 if firstDownstream {
448 seq, ok := uc.history[historyName]
449 if ok {
450 seqPtr = &seq
451 }
452 }
453
454 consumer, ch := uc.ring.NewConsumer(seqPtr)
455 go func() {
456 for {
457 var closed bool
458 select {
459 case <-ch:
460 dc.consumptions <- consumption{consumer, uc}
461 case <-dc.closed:
462 closed = true
463 }
464 if closed {
465 break
466 }
467 }
468
469 seq := consumer.Close()
470
471 dc.user.lock.Lock()
472 lastDownstream := len(dc.user.downstreamConns) == 0
473 dc.user.lock.Unlock()
474
475 if lastDownstream {
476 uc.history[historyName] = seq
477 }
478 }()
479 })
480
481 return nil
482}
483
484func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
485 switch msg.Command {
486 case "USER":
487 return ircError{&irc.Message{
488 Command: irc.ERR_ALREADYREGISTERED,
489 Params: []string{dc.nick, "You may not reregister"},
490 }}
491 case "NICK":
492 var nick string
493 if err := parseMessageParams(msg, &nick); err != nil {
494 return err
495 }
496
497 var err error
498 dc.forEachNetwork(func(n *network) {
499 if err != nil {
500 return
501 }
502 n.Nick = nick
503 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
504 })
505 if err != nil {
506 return err
507 }
508
509 dc.forEachUpstream(func(uc *upstreamConn) {
510 uc.SendMessage(msg)
511 })
512 case "JOIN", "PART":
513 var name string
514 if err := parseMessageParams(msg, &name); err != nil {
515 return err
516 }
517
518 uc, upstreamName, err := dc.unmarshalChannel(name)
519 if err != nil {
520 return ircError{&irc.Message{
521 Command: irc.ERR_NOSUCHCHANNEL,
522 Params: []string{name, err.Error()},
523 }}
524 }
525
526 uc.SendMessage(&irc.Message{
527 Command: msg.Command,
528 Params: []string{upstreamName},
529 })
530
531 switch msg.Command {
532 case "JOIN":
533 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
534 Name: upstreamName,
535 })
536 if err != nil {
537 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
538 }
539 case "PART":
540 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
541 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
542 }
543 }
544 case "MODE":
545 if msg.Prefix == nil {
546 return fmt.Errorf("missing prefix")
547 }
548
549 var name string
550 if err := parseMessageParams(msg, &name); err != nil {
551 return err
552 }
553
554 var modeStr string
555 if len(msg.Params) > 1 {
556 modeStr = msg.Params[1]
557 }
558
559 if msg.Prefix.Name != name {
560 uc, upstreamName, err := dc.unmarshalChannel(name)
561 if err != nil {
562 return err
563 }
564
565 if modeStr != "" {
566 uc.SendMessage(&irc.Message{
567 Command: "MODE",
568 Params: []string{upstreamName, modeStr},
569 })
570 } else {
571 ch, ok := uc.channels[upstreamName]
572 if !ok {
573 return ircError{&irc.Message{
574 Command: irc.ERR_NOSUCHCHANNEL,
575 Params: []string{name, "No such channel"},
576 }}
577 }
578
579 dc.SendMessage(&irc.Message{
580 Prefix: dc.srv.prefix(),
581 Command: irc.RPL_CHANNELMODEIS,
582 Params: []string{name, string(ch.modes)},
583 })
584 }
585 } else {
586 if name != dc.nick {
587 return ircError{&irc.Message{
588 Command: irc.ERR_USERSDONTMATCH,
589 Params: []string{dc.nick, "Cannot change mode for other users"},
590 }}
591 }
592
593 if modeStr != "" {
594 dc.forEachUpstream(func(uc *upstreamConn) {
595 uc.SendMessage(&irc.Message{
596 Command: "MODE",
597 Params: []string{uc.nick, modeStr},
598 })
599 })
600 } else {
601 dc.SendMessage(&irc.Message{
602 Prefix: dc.srv.prefix(),
603 Command: irc.RPL_UMODEIS,
604 Params: []string{""}, // TODO
605 })
606 }
607 }
608 case "PRIVMSG":
609 var targetsStr, text string
610 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
611 return err
612 }
613
614 for _, name := range strings.Split(targetsStr, ",") {
615 uc, upstreamName, err := dc.unmarshalChannel(name)
616 if err != nil {
617 return err
618 }
619
620 uc.SendMessage(&irc.Message{
621 Command: "PRIVMSG",
622 Params: []string{upstreamName, text},
623 })
624 }
625 default:
626 dc.logger.Printf("unhandled message: %v", msg)
627 return newUnknownCommandError(msg.Command)
628 }
629 return nil
630}
Note: See TracBrowser for help on using the repository browser.