source: code/trunk/downstream.go@ 89

Last change on this file since 89 was 89, checked in by contact, 5 years ago

Update DB on JOIN and PART

File size: 12.5 KB
Line 
1package jounce
2
3import (
4 "fmt"
5 "io"
6 "net"
7 "strings"
8
9 "golang.org/x/crypto/bcrypt"
10 "gopkg.in/irc.v3"
11)
12
// ircError is an error that wraps an IRC message. readMessages treats it as
// non-fatal: the wrapped message is sent back to the downstream client
// instead of terminating the connection.
type ircError struct {
	Message *irc.Message
}
16
17func (err ircError) Error() string {
18 return err.Message.String()
19}
20
21func newUnknownCommandError(cmd string) ircError {
22 return ircError{&irc.Message{
23 Command: irc.ERR_UNKNOWNCOMMAND,
24 Params: []string{
25 "*",
26 cmd,
27 "Unknown command",
28 },
29 }}
30}
31
32func newNeedMoreParamsError(cmd string) ircError {
33 return ircError{&irc.Message{
34 Command: irc.ERR_NEEDMOREPARAMS,
35 Params: []string{
36 "*",
37 cmd,
38 "Not enough parameters",
39 },
40 }}
41}
42
43var errAuthFailed = ircError{&irc.Message{
44 Command: irc.ERR_PASSWDMISMATCH,
45 Params: []string{"*", "Invalid username or password"},
46}}
47
// consumption is the unit of work sent on downstreamConn.consumptions: it
// asks writeMessages to drain the pending messages of consumer, which was
// created from upstreamConn's ring.
type consumption struct {
	consumer     *RingConsumer
	upstreamConn *upstreamConn
}
52
// downstreamConn is a connection from an IRC client to the bouncer.
//
// The first group of fields is set once by newDownstreamConn: messages
// carries outgoing messages queued by SendMessage, consumptions carries ring
// consumers to drain, and closed is closed by Close; all three are serviced
// by writeMessages.
//
// The second group holds registration state. nick, username, realname and
// password are filled in by handleMessageUnregistered; registered, user and
// network are set by register once authentication succeeds.
type downstreamConn struct {
	net          net.Conn
	irc          *irc.Conn
	srv          *Server
	logger       Logger
	messages     chan *irc.Message
	consumptions chan consumption
	closed       chan struct{}

	registered bool
	user       *user
	nick       string
	username   string
	realname   string
	password   string   // empty after authentication
	network    *network // can be nil
}
70
71func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
72 dc := &downstreamConn{
73 net: netConn,
74 irc: irc.NewConn(netConn),
75 srv: srv,
76 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
77 messages: make(chan *irc.Message, 64),
78 consumptions: make(chan consumption),
79 closed: make(chan struct{}),
80 }
81
82 go func() {
83 if err := dc.writeMessages(); err != nil {
84 dc.logger.Printf("failed to write message: %v", err)
85 }
86 if err := dc.net.Close(); err != nil {
87 dc.logger.Printf("failed to close connection: %v", err)
88 } else {
89 dc.logger.Printf("connection closed")
90 }
91 }()
92
93 return dc
94}
95
96func (dc *downstreamConn) prefix() *irc.Prefix {
97 return &irc.Prefix{
98 Name: dc.nick,
99 User: dc.username,
100 // TODO: fill the host?
101 }
102}
103
104func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
105 return name
106}
107
108func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
109 dc.user.forEachUpstream(func(uc *upstreamConn) {
110 if dc.network != nil && uc.network != dc.network {
111 return
112 }
113 f(uc)
114 })
115}
116
117// upstream returns the upstream connection, if any. If there are zero or if
118// there are multiple upstream connections, it returns nil.
119func (dc *downstreamConn) upstream() *upstreamConn {
120 if dc.network == nil {
121 return nil
122 }
123
124 var upstream *upstreamConn
125 dc.forEachUpstream(func(uc *upstreamConn) {
126 upstream = uc
127 })
128 return upstream
129}
130
131func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
132 if uc := dc.upstream(); uc != nil {
133 return uc, name, nil
134 }
135
136 // TODO: extract network name from channel name if dc.upstream == nil
137 var channel *upstreamChannel
138 var err error
139 dc.forEachUpstream(func(uc *upstreamConn) {
140 if err != nil {
141 return
142 }
143 if ch, ok := uc.channels[name]; ok {
144 if channel != nil {
145 err = fmt.Errorf("ambiguous channel name %q", name)
146 } else {
147 channel = ch
148 }
149 }
150 })
151 if channel == nil {
152 return nil, "", ircError{&irc.Message{
153 Command: irc.ERR_NOSUCHCHANNEL,
154 Params: []string{name, "No such channel"},
155 }}
156 }
157 return channel.conn, channel.Name, nil
158}
159
160func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
161 if nick == uc.nick {
162 return dc.nick
163 }
164 return nick
165}
166
167func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
168 if prefix.Name == uc.nick {
169 return dc.prefix()
170 }
171 return prefix
172}
173
174func (dc *downstreamConn) isClosed() bool {
175 select {
176 case <-dc.closed:
177 return true
178 default:
179 return false
180 }
181}
182
183func (dc *downstreamConn) readMessages() error {
184 dc.logger.Printf("new connection")
185
186 for {
187 msg, err := dc.irc.ReadMessage()
188 if err == io.EOF {
189 break
190 } else if err != nil {
191 return fmt.Errorf("failed to read IRC command: %v", err)
192 }
193
194 if dc.srv.Debug {
195 dc.logger.Printf("received: %v", msg)
196 }
197
198 err = dc.handleMessage(msg)
199 if ircErr, ok := err.(ircError); ok {
200 ircErr.Message.Prefix = dc.srv.prefix()
201 dc.SendMessage(ircErr.Message)
202 } else if err != nil {
203 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
204 }
205
206 if dc.isClosed() {
207 return nil
208 }
209 }
210
211 return nil
212}
213
214func (dc *downstreamConn) writeMessages() error {
215 for {
216 var err error
217 var closed bool
218 select {
219 case msg := <-dc.messages:
220 if dc.srv.Debug {
221 dc.logger.Printf("sent: %v", msg)
222 }
223 err = dc.irc.WriteMessage(msg)
224 case consumption := <-dc.consumptions:
225 consumer, uc := consumption.consumer, consumption.upstreamConn
226 for {
227 msg := consumer.Peek()
228 if msg == nil {
229 break
230 }
231 msg = msg.Copy()
232 switch msg.Command {
233 case "PRIVMSG":
234 // TODO: detect whether it's a user or a channel
235 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
236 default:
237 panic("expected to consume a PRIVMSG message")
238 }
239 if dc.srv.Debug {
240 dc.logger.Printf("sent: %v", msg)
241 }
242 err = dc.irc.WriteMessage(msg)
243 if err != nil {
244 break
245 }
246 consumer.Consume()
247 }
248 case <-dc.closed:
249 closed = true
250 }
251 if err != nil {
252 return err
253 }
254 if closed {
255 break
256 }
257 }
258 return nil
259}
260
261func (dc *downstreamConn) Close() error {
262 if dc.isClosed() {
263 return fmt.Errorf("downstream connection already closed")
264 }
265
266 if u := dc.user; u != nil {
267 u.lock.Lock()
268 for i := range u.downstreamConns {
269 if u.downstreamConns[i] == dc {
270 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
271 break
272 }
273 }
274 u.lock.Unlock()
275 }
276
277 close(dc.closed)
278 return nil
279}
280
281func (dc *downstreamConn) SendMessage(msg *irc.Message) {
282 dc.messages <- msg
283}
284
285func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
286 switch msg.Command {
287 case "QUIT":
288 return dc.Close()
289 case "PING":
290 dc.SendMessage(&irc.Message{
291 Prefix: dc.srv.prefix(),
292 Command: "PONG",
293 Params: msg.Params,
294 })
295 return nil
296 default:
297 if dc.registered {
298 return dc.handleMessageRegistered(msg)
299 } else {
300 return dc.handleMessageUnregistered(msg)
301 }
302 }
303}
304
305func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
306 switch msg.Command {
307 case "NICK":
308 if err := parseMessageParams(msg, &dc.nick); err != nil {
309 return err
310 }
311 case "USER":
312 var username string
313 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
314 return err
315 }
316 dc.username = "~" + username
317 case "PASS":
318 if err := parseMessageParams(msg, &dc.password); err != nil {
319 return err
320 }
321 default:
322 dc.logger.Printf("unhandled message: %v", msg)
323 return newUnknownCommandError(msg.Command)
324 }
325 if dc.username != "" && dc.nick != "" {
326 return dc.register()
327 }
328 return nil
329}
330
// register authenticates the downstream connection and, on success, attaches
// it to the user, sends the welcome burst and starts relaying upstream
// channels and ring messages.
//
// The username given with USER may carry a network name after a "/" or "@"
// separator. errAuthFailed is returned when the user is unknown or the
// password does not match. An unknown network name is reported to the client
// directly and nil is returned, leaving the connection unregistered.
func (dc *downstreamConn) register() error {
	username := strings.TrimPrefix(dc.username, "~")
	var networkName string
	// The network name is whatever follows the last separator...
	if i := strings.LastIndexAny(username, "/@"); i >= 0 {
		networkName = username[i+1:]
	}
	// ...while the account name is whatever precedes the first one.
	if i := strings.IndexAny(username, "/@"); i >= 0 {
		username = username[:i]
	}

	// Take the password out of the connection state before checking it, so
	// that it is cleared whether or not authentication succeeds.
	password := dc.password
	dc.password = ""

	u := dc.srv.getUser(username)
	if u == nil {
		dc.logger.Printf("failed authentication for %q: unknown username", username)
		return errAuthFailed
	}

	err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
	if err != nil {
		dc.logger.Printf("failed authentication for %q: %v", username, err)
		return errAuthFailed
	}

	var network *network
	if networkName != "" {
		network = u.getNetwork(networkName)
		if network == nil {
			dc.logger.Printf("failed registration: unknown network %q", networkName)
			dc.SendMessage(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_PASSWDMISMATCH,
				Params:  []string{"*", fmt.Sprintf("Unknown network %q", networkName)},
			})
			return nil
		}
	}

	dc.registered = true
	dc.user = u
	dc.network = network

	// Record whether this is the user's first downstream connection; only
	// that one looks up a saved history position below.
	u.lock.Lock()
	firstDownstream := len(u.downstreamConns) == 0
	u.downstreamConns = append(u.downstreamConns, dc)
	u.lock.Unlock()

	dc.SendMessage(&irc.Message{
		Prefix:  dc.srv.prefix(),
		Command: irc.RPL_WELCOME,
		Params:  []string{dc.nick, "Welcome to jounce, " + dc.nick},
	})
	dc.SendMessage(&irc.Message{
		Prefix:  dc.srv.prefix(),
		Command: irc.RPL_YOURHOST,
		Params:  []string{dc.nick, "Your host is " + dc.srv.Hostname},
	})
	dc.SendMessage(&irc.Message{
		Prefix:  dc.srv.prefix(),
		Command: irc.RPL_CREATED,
		Params:  []string{dc.nick, "Who cares when the server was created?"},
	})
	dc.SendMessage(&irc.Message{
		Prefix:  dc.srv.prefix(),
		Command: irc.RPL_MYINFO,
		Params:  []string{dc.nick, dc.srv.Hostname, "jounce", "aiwroO", "OovaimnqpsrtklbeI"},
	})
	dc.SendMessage(&irc.Message{
		Prefix:  dc.srv.prefix(),
		Command: irc.ERR_NOMOTD,
		Params:  []string{dc.nick, "No MOTD"},
	})

	dc.forEachUpstream(func(uc *upstreamConn) {
		// TODO: fix races accessing upstream connection data
		for _, ch := range uc.channels {
			if ch.complete {
				forwardChannel(dc, ch)
			}
		}

		// History positions are keyed by the full downstream username, i.e.
		// including the "~" prefix and any network suffix.
		historyName := dc.username

		// seqPtr stays nil unless this is the first downstream connection
		// and a position was saved for historyName.
		// NOTE(review): the meaning of a nil seqPtr is defined by
		// Ring.NewConsumer, which is not visible here.
		var seqPtr *uint64
		if firstDownstream {
			seq, ok := uc.history[historyName]
			if ok {
				seqPtr = &seq
			}
		}

		consumer, ch := uc.ring.NewConsumer(seqPtr)
		go func() {
			// Hand the consumer to writeMessages each time ch fires, until
			// the downstream connection is closed.
			for {
				var closed bool
				select {
				case <-ch:
					dc.consumptions <- consumption{consumer, uc}
				case <-dc.closed:
					closed = true
				}
				if closed {
					break
				}
			}

			seq := consumer.Close()

			dc.user.lock.Lock()
			lastDownstream := len(dc.user.downstreamConns) == 0
			dc.user.lock.Unlock()

			// Save the position only when no downstream connection remains
			// for the user.
			// NOTE(review): uc.history is written here without holding a
			// lock; see the TODO about races above.
			if lastDownstream {
				uc.history[historyName] = seq
			}
		}()
	})

	return nil
}
452
// handleMessageRegistered handles a command received from a registered
// downstream client: USER is rejected, NICK, JOIN, PART, MODE and PRIVMSG are
// relayed to the relevant upstream connection(s), and anything else is logged
// and answered with an unknown-command error.
//
// Returning an ircError makes readMessages reply to the client; returning any
// other error makes readMessages abort the connection.
func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
	switch msg.Command {
	case "USER":
		return ircError{&irc.Message{
			Command: irc.ERR_ALREADYREGISTERED,
			Params:  []string{dc.nick, "You may not reregister"},
		}}
	case "NICK":
		// The message is forwarded as is to every visible upstream; dc.nick
		// is not updated here.
		dc.forEachUpstream(func(uc *upstreamConn) {
			uc.SendMessage(msg)
		})
	case "JOIN", "PART":
		// Only the first parameter (the channel name) is used.
		var name string
		if err := parseMessageParams(msg, &name); err != nil {
			return err
		}

		uc, upstreamName, err := dc.unmarshalChannel(name)
		if err != nil {
			return ircError{&irc.Message{
				Command: irc.ERR_NOSUCHCHANNEL,
				Params:  []string{name, err.Error()},
			}}
		}

		uc.SendMessage(&irc.Message{
			Command: msg.Command,
			Params:  []string{upstreamName},
		})

		// The DB is updated right after the request is queued for the
		// upstream, without waiting for the upstream's answer. DB failures
		// are only logged.
		switch msg.Command {
		case "JOIN":
			err := dc.srv.db.StoreChannel(uc.network.ID, &Channel{
				Name: upstreamName,
			})
			if err != nil {
				dc.logger.Printf("failed to create channel %q in DB: %v", upstreamName, err)
			}
		case "PART":
			if err := dc.srv.db.DeleteChannel(uc.network.ID, upstreamName); err != nil {
				dc.logger.Printf("failed to delete channel %q in DB: %v", upstreamName, err)
			}
		}
	case "MODE":
		// NOTE(review): this is a plain error, so a MODE message without a
		// prefix aborts the connection in readMessages. Confirm that
		// downstream clients are expected to always send a prefix.
		if msg.Prefix == nil {
			return fmt.Errorf("missing prefix")
		}

		var name string
		if err := parseMessageParams(msg, &name); err != nil {
			return err
		}

		// An absent mode string means the client is querying current modes.
		var modeStr string
		if len(msg.Params) > 1 {
			modeStr = msg.Params[1]
		}

		// A target different from the sender's own name is treated as a
		// channel; otherwise this is a user mode request.
		if msg.Prefix.Name != name {
			uc, upstreamName, err := dc.unmarshalChannel(name)
			if err != nil {
				return err
			}

			if modeStr != "" {
				uc.SendMessage(&irc.Message{
					Command: "MODE",
					Params:  []string{upstreamName, modeStr},
				})
			} else {
				// Answer the query from the locally known channel state.
				ch, ok := uc.channels[upstreamName]
				if !ok {
					return ircError{&irc.Message{
						Command: irc.ERR_NOSUCHCHANNEL,
						Params:  []string{name, "No such channel"},
					}}
				}

				dc.SendMessage(&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.RPL_CHANNELMODEIS,
					Params:  []string{name, string(ch.modes)},
				})
			}
		} else {
			if name != dc.nick {
				return ircError{&irc.Message{
					Command: irc.ERR_USERSDONTMATCH,
					Params:  []string{dc.nick, "Cannot change mode for other users"},
				}}
			}

			if modeStr != "" {
				// Apply the change on every visible upstream, addressing each
				// by the nick used on that upstream.
				dc.forEachUpstream(func(uc *upstreamConn) {
					uc.SendMessage(&irc.Message{
						Command: "MODE",
						Params:  []string{uc.nick, modeStr},
					})
				})
			} else {
				dc.SendMessage(&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.RPL_UMODEIS,
					Params:  []string{""}, // TODO
				})
			}
		}
	case "PRIVMSG":
		var targetsStr, text string
		if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
			return err
		}

		// Targets are comma-separated. Processing stops at the first target
		// that cannot be resolved; earlier targets have already been sent.
		for _, name := range strings.Split(targetsStr, ",") {
			uc, upstreamName, err := dc.unmarshalChannel(name)
			if err != nil {
				return err
			}

			uc.SendMessage(&irc.Message{
				Command: "PRIVMSG",
				Params:  []string{upstreamName, text},
			})
		}
	default:
		dc.logger.Printf("unhandled message: %v", msg)
		return newUnknownCommandError(msg.Command)
	}
	return nil
}
Note: See TracBrowser for help on using the repository browser.