source: code/trunk/downstream.go@ 90

Last change on this file since 90 was 90, checked in by contact, 5 years ago

Store NICK changes in the DB

File size: 13.0 KB
Line 
1package jounce
2
3import (
4 "fmt"
5 "io"
6 "net"
7 "strings"
8
9 "golang.org/x/crypto/bcrypt"
10 "gopkg.in/irc.v3"
11)
12
13type ircError struct {
14 Message *irc.Message
15}
16
17func (err ircError) Error() string {
18 return err.Message.String()
19}
20
21func newUnknownCommandError(cmd string) ircError {
22 return ircError{&irc.Message{
23 Command: irc.ERR_UNKNOWNCOMMAND,
24 Params: []string{
25 "*",
26 cmd,
27 "Unknown command",
28 },
29 }}
30}
31
32func newNeedMoreParamsError(cmd string) ircError {
33 return ircError{&irc.Message{
34 Command: irc.ERR_NEEDMOREPARAMS,
35 Params: []string{
36 "*",
37 cmd,
38 "Not enough parameters",
39 },
40 }}
41}
42
43var errAuthFailed = ircError{&irc.Message{
44 Command: irc.ERR_PASSWDMISMATCH,
45 Params: []string{"*", "Invalid username or password"},
46}}
47
48type consumption struct {
49 consumer *RingConsumer
50 upstreamConn *upstreamConn
51}
52
53type downstreamConn struct {
54 net net.Conn
55 irc *irc.Conn
56 srv *Server
57 logger Logger
58 messages chan *irc.Message
59 consumptions chan consumption
60 closed chan struct{}
61
62 registered bool
63 user *user
64 nick string
65 username string
66 realname string
67 password string // empty after authentication
68 network *network // can be nil
69}
70
71func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
72 dc := &downstreamConn{
73 net: netConn,
74 irc: irc.NewConn(netConn),
75 srv: srv,
76 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
77 messages: make(chan *irc.Message, 64),
78 consumptions: make(chan consumption),
79 closed: make(chan struct{}),
80 }
81
82 go func() {
83 if err := dc.writeMessages(); err != nil {
84 dc.logger.Printf("failed to write message: %v", err)
85 }
86 if err := dc.net.Close(); err != nil {
87 dc.logger.Printf("failed to close connection: %v", err)
88 } else {
89 dc.logger.Printf("connection closed")
90 }
91 }()
92
93 return dc
94}
95
96func (dc *downstreamConn) prefix() *irc.Prefix {
97 return &irc.Prefix{
98 Name: dc.nick,
99 User: dc.username,
100 // TODO: fill the host?
101 }
102}
103
104func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
105 return name
106}
107
108func (dc *downstreamConn) forEachNetwork(f func(*network)) {
109 if dc.network != nil {
110 f(dc.network)
111 } else {
112 dc.user.forEachNetwork(f)
113 }
114}
115
116func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
117 dc.user.forEachUpstream(func(uc *upstreamConn) {
118 if dc.network != nil && uc.network != dc.network {
119 return
120 }
121 f(uc)
122 })
123}
124
125// upstream returns the upstream connection, if any. If there are zero or if
126// there are multiple upstream connections, it returns nil.
127func (dc *downstreamConn) upstream() *upstreamConn {
128 if dc.network == nil {
129 return nil
130 }
131
132 var upstream *upstreamConn
133 dc.forEachUpstream(func(uc *upstreamConn) {
134 upstream = uc
135 })
136 return upstream
137}
138
139func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
140 if uc := dc.upstream(); uc != nil {
141 return uc, name, nil
142 }
143
144 // TODO: extract network name from channel name if dc.upstream == nil
145 var channel *upstreamChannel
146 var err error
147 dc.forEachUpstream(func(uc *upstreamConn) {
148 if err != nil {
149 return
150 }
151 if ch, ok := uc.channels[name]; ok {
152 if channel != nil {
153 err = fmt.Errorf("ambiguous channel name %q", name)
154 } else {
155 channel = ch
156 }
157 }
158 })
159 if channel == nil {
160 return nil, "", ircError{&irc.Message{
161 Command: irc.ERR_NOSUCHCHANNEL,
162 Params: []string{name, "No such channel"},
163 }}
164 }
165 return channel.conn, channel.Name, nil
166}
167
168func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
169 if nick == uc.nick {
170 return dc.nick
171 }
172 return nick
173}
174
175func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
176 if prefix.Name == uc.nick {
177 return dc.prefix()
178 }
179 return prefix
180}
181
182func (dc *downstreamConn) isClosed() bool {
183 select {
184 case <-dc.closed:
185 return true
186 default:
187 return false
188 }
189}
190
191func (dc *downstreamConn) readMessages() error {
192 dc.logger.Printf("new connection")
193
194 for {
195 msg, err := dc.irc.ReadMessage()
196 if err == io.EOF {
197 break
198 } else if err != nil {
199 return fmt.Errorf("failed to read IRC command: %v", err)
200 }
201
202 if dc.srv.Debug {
203 dc.logger.Printf("received: %v", msg)
204 }
205
206 err = dc.handleMessage(msg)
207 if ircErr, ok := err.(ircError); ok {
208 ircErr.Message.Prefix = dc.srv.prefix()
209 dc.SendMessage(ircErr.Message)
210 } else if err != nil {
211 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
212 }
213
214 if dc.isClosed() {
215 return nil
216 }
217 }
218
219 return nil
220}
221
222func (dc *downstreamConn) writeMessages() error {
223 for {
224 var err error
225 var closed bool
226 select {
227 case msg := <-dc.messages:
228 if dc.srv.Debug {
229 dc.logger.Printf("sent: %v", msg)
230 }
231 err = dc.irc.WriteMessage(msg)
232 case consumption := <-dc.consumptions:
233 consumer, uc := consumption.consumer, consumption.upstreamConn
234 for {
235 msg := consumer.Peek()
236 if msg == nil {
237 break
238 }
239 msg = msg.Copy()
240 switch msg.Command {
241 case "PRIVMSG":
242 // TODO: detect whether it's a user or a channel
243 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
244 default:
245 panic("expected to consume a PRIVMSG message")
246 }
247 if dc.srv.Debug {
248 dc.logger.Printf("sent: %v", msg)
249 }
250 err = dc.irc.WriteMessage(msg)
251 if err != nil {
252 break
253 }
254 consumer.Consume()
255 }
256 case <-dc.closed:
257 closed = true
258 }
259 if err != nil {
260 return err
261 }
262 if closed {
263 break
264 }
265 }
266 return nil
267}
268
269func (dc *downstreamConn) Close() error {
270 if dc.isClosed() {
271 return fmt.Errorf("downstream connection already closed")
272 }
273
274 if u := dc.user; u != nil {
275 u.lock.Lock()
276 for i := range u.downstreamConns {
277 if u.downstreamConns[i] == dc {
278 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
279 break
280 }
281 }
282 u.lock.Unlock()
283 }
284
285 close(dc.closed)
286 return nil
287}
288
289func (dc *downstreamConn) SendMessage(msg *irc.Message) {
290 dc.messages <- msg
291}
292
293func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
294 switch msg.Command {
295 case "QUIT":
296 return dc.Close()
297 case "PING":
298 dc.SendMessage(&irc.Message{
299 Prefix: dc.srv.prefix(),
300 Command: "PONG",
301 Params: msg.Params,
302 })
303 return nil
304 default:
305 if dc.registered {
306 return dc.handleMessageRegistered(msg)
307 } else {
308 return dc.handleMessageUnregistered(msg)
309 }
310 }
311}
312
313func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
314 switch msg.Command {
315 case "NICK":
316 if err := parseMessageParams(msg, &dc.nick); err != nil {
317 return err
318 }
319 case "USER":
320 var username string
321 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
322 return err
323 }
324 dc.username = "~" + username
325 case "PASS":
326 if err := parseMessageParams(msg, &dc.password); err != nil {
327 return err
328 }
329 default:
330 dc.logger.Printf("unhandled message: %v", msg)
331 return newUnknownCommandError(msg.Command)
332 }
333 if dc.username != "" && dc.nick != "" {
334 return dc.register()
335 }
336 return nil
337}
338
339func (dc *downstreamConn) register() error {
340 username := strings.TrimPrefix(dc.username, "~")
341 var networkName string
342 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
343 networkName = username[i+1:]
344 }
345 if i := strings.IndexAny(username, "/@"); i >= 0 {
346 username = username[:i]
347 }
348
349 password := dc.password
350 dc.password = ""
351
352 u := dc.srv.getUser(username)
353 if u == nil {
354 dc.logger.Printf("failed authentication for %q: unknown username", username)
355 return errAuthFailed
356 }
357
358 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
359 if err != nil {
360 dc.logger.Printf("failed authentication for %q: %v", username, err)
361 return errAuthFailed
362 }
363
364 var network *network
365 if networkName != "" {
366 network = u.getNetwork(networkName)
367 if network == nil {
368 dc.logger.Printf("failed registration: unknown network %q", networkName)
369 dc.SendMessage(&irc.Message{
370 Prefix: dc.srv.prefix(),
371 Command: irc.ERR_PASSWDMISMATCH,
372 Params: []string{"*", fmt.Sprintf("Unknown network %q", networkName)},
373 })
374 return nil
375 }
376 }
377
378 dc.registered = true
379 dc.user = u
380 dc.network = network
381
382 u.lock.Lock()
383 firstDownstream := len(u.downstreamConns) == 0
384 u.downstreamConns = append(u.downstreamConns, dc)
385 u.lock.Unlock()
386
387 dc.SendMessage(&irc.Message{
388 Prefix: dc.srv.prefix(),
389 Command: irc.RPL_WELCOME,
390 Params: []string{dc.nick, "Welcome to jounce, " + dc.nick},
391 })
392 dc.SendMessage(&irc.Message{
393 Prefix: dc.srv.prefix(),
394 Command: irc.RPL_YOURHOST,
395 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
396 })
397 dc.SendMessage(&irc.Message{
398 Prefix: dc.srv.prefix(),
399 Command: irc.RPL_CREATED,
400 Params: []string{dc.nick, "Who cares when the server was created?"},
401 })
402 dc.SendMessage(&irc.Message{
403 Prefix: dc.srv.prefix(),
404 Command: irc.RPL_MYINFO,
405 Params: []string{dc.nick, dc.srv.Hostname, "jounce", "aiwroO", "OovaimnqpsrtklbeI"},
406 })
407 dc.SendMessage(&irc.Message{
408 Prefix: dc.srv.prefix(),
409 Command: irc.ERR_NOMOTD,
410 Params: []string{dc.nick, "No MOTD"},
411 })
412
413 dc.forEachUpstream(func(uc *upstreamConn) {
414 // TODO: fix races accessing upstream connection data
415 for _, ch := range uc.channels {
416 if ch.complete {
417 forwardChannel(dc, ch)
418 }
419 }
420
421 historyName := dc.username
422
423 var seqPtr *uint64
424 if firstDownstream {
425 seq, ok := uc.history[historyName]
426 if ok {
427 seqPtr = &seq
428 }
429 }
430
431 consumer, ch := uc.ring.NewConsumer(seqPtr)
432 go func() {
433 for {
434 var closed bool
435 select {
436 case <-ch:
437 dc.consumptions <- consumption{consumer, uc}
438 case <-dc.closed:
439 closed = true
440 }
441 if closed {
442 break
443 }
444 }
445
446 seq := consumer.Close()
447
448 dc.user.lock.Lock()
449 lastDownstream := len(dc.user.downstreamConns) == 0
450 dc.user.lock.Unlock()
451
452 if lastDownstream {
453 uc.history[historyName] = seq
454 }
455 }()
456 })
457
458 return nil
459}
460
461func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
462 switch msg.Command {
463 case "USER":
464 return ircError{&irc.Message{
465 Command: irc.ERR_ALREADYREGISTERED,
466 Params: []string{dc.nick, "You may not reregister"},
467 }}
468 case "NICK":
469 var nick string
470 if err := parseMessageParams(msg, &nick); err != nil {
471 return err
472 }
473
474 var err error
475 dc.forEachNetwork(func(n *network) {
476 if err != nil {
477 return
478 }
479 n.Nick = nick
480 err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network)
481 })
482 if err != nil {
483 return err
484 }
485
486 dc.forEachUpstream(func(uc *upstreamConn) {
487 uc.SendMessage(msg)
488 })
489 case "JOIN", "PART":
490 var name string
491 if err := parseMessageParams(msg, &name); err != nil {
492 return err
493 }
494
495 uc, upstreamName, err := dc.unmarshalChannel(name)
496 if err != nil {
497 return ircError{&irc.Message{
498 Command: irc.ERR_NOSUCHCHANNEL,
499 Params: []string{name, err.Error()},
500 }}
501 }
502
503 uc.SendMessage(&irc.Message{
504 Command: msg.Command,
505 Params: []string{upstreamName},
506 })
507
508 switch msg.Command {
509 case "JOIN":
510 err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
511 Name: upstreamName,
512 })
513 if err != nil {
514 dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
515 }
516 case "PART":
517 if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
518 dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
519 }
520 }
521 case "MODE":
522 if msg.Prefix == nil {
523 return fmt.Errorf("missing prefix")
524 }
525
526 var name string
527 if err := parseMessageParams(msg, &name); err != nil {
528 return err
529 }
530
531 var modeStr string
532 if len(msg.Params) > 1 {
533 modeStr = msg.Params[1]
534 }
535
536 if msg.Prefix.Name != name {
537 uc, upstreamName, err := dc.unmarshalChannel(name)
538 if err != nil {
539 return err
540 }
541
542 if modeStr != "" {
543 uc.SendMessage(&irc.Message{
544 Command: "MODE",
545 Params: []string{upstreamName, modeStr},
546 })
547 } else {
548 ch, ok := uc.channels[upstreamName]
549 if !ok {
550 return ircError{&irc.Message{
551 Command: irc.ERR_NOSUCHCHANNEL,
552 Params: []string{name, "No such channel"},
553 }}
554 }
555
556 dc.SendMessage(&irc.Message{
557 Prefix: dc.srv.prefix(),
558 Command: irc.RPL_CHANNELMODEIS,
559 Params: []string{name, string(ch.modes)},
560 })
561 }
562 } else {
563 if name != dc.nick {
564 return ircError{&irc.Message{
565 Command: irc.ERR_USERSDONTMATCH,
566 Params: []string{dc.nick, "Cannot change mode for other users"},
567 }}
568 }
569
570 if modeStr != "" {
571 dc.forEachUpstream(func(uc *upstreamConn) {
572 uc.SendMessage(&irc.Message{
573 Command: "MODE",
574 Params: []string{uc.nick, modeStr},
575 })
576 })
577 } else {
578 dc.SendMessage(&irc.Message{
579 Prefix: dc.srv.prefix(),
580 Command: irc.RPL_UMODEIS,
581 Params: []string{""}, // TODO
582 })
583 }
584 }
585 case "PRIVMSG":
586 var targetsStr, text string
587 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
588 return err
589 }
590
591 for _, name := range strings.Split(targetsStr, ",") {
592 uc, upstreamName, err := dc.unmarshalChannel(name)
593 if err != nil {
594 return err
595 }
596
597 uc.SendMessage(&irc.Message{
598 Command: "PRIVMSG",
599 Params: []string{upstreamName, text},
600 })
601 }
602 default:
603 dc.logger.Printf("unhandled message: %v", msg)
604 return newUnknownCommandError(msg.Command)
605 }
606 return nil
607}
Note: See TracBrowser for help on using the repository browser.