source: code/trunk/user.go@ 499

Last change on this file since 499 was 499, checked in by contact, 4 years ago

Relay detached channel backlog as BouncerServ NOTICE if necessary

Instead of ignoring detached channels when replaying backlog,
process them as usual and relay messages as BouncerServ NOTICEs
if necessary. Advance the delivery receipts as if the channel was
attached.

Closes: https://todo.sr.ht/~emersion/soju/98

File size: 17.6 KB
Line 
1package soju
2
3import (
4 "crypto/sha256"
5 "encoding/binary"
6 "encoding/hex"
7 "fmt"
8 "time"
9
10 "gopkg.in/irc.v3"
11)
12
13type event interface{}
14
15type eventUpstreamMessage struct {
16 msg *irc.Message
17 uc *upstreamConn
18}
19
20type eventUpstreamConnectionError struct {
21 net *network
22 err error
23}
24
25type eventUpstreamConnected struct {
26 uc *upstreamConn
27}
28
29type eventUpstreamDisconnected struct {
30 uc *upstreamConn
31}
32
33type eventUpstreamError struct {
34 uc *upstreamConn
35 err error
36}
37
38type eventDownstreamMessage struct {
39 msg *irc.Message
40 dc *downstreamConn
41}
42
43type eventDownstreamConnected struct {
44 dc *downstreamConn
45}
46
47type eventDownstreamDisconnected struct {
48 dc *downstreamConn
49}
50
51type eventChannelDetach struct {
52 uc *upstreamConn
53 name string
54}
55
56type eventStop struct{}
57
58type deliveredClientMap map[string]string // client name -> msg ID
59
60type deliveredStore struct {
61 m deliveredCasemapMap
62}
63
64func newDeliveredStore() deliveredStore {
65 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
66}
67
68func (ds deliveredStore) HasTarget(target string) bool {
69 return ds.m.Value(target) != nil
70}
71
72func (ds deliveredStore) LoadID(target, clientName string) string {
73 clients := ds.m.Value(target)
74 if clients == nil {
75 return ""
76 }
77 return clients[clientName]
78}
79
80func (ds deliveredStore) StoreID(target, clientName, msgID string) {
81 clients := ds.m.Value(target)
82 if clients == nil {
83 clients = make(deliveredClientMap)
84 ds.m.SetValue(target, clients)
85 }
86 clients[clientName] = msgID
87}
88
89func (ds deliveredStore) ForEachTarget(f func(target string)) {
90 for _, entry := range ds.m.innerMap {
91 f(entry.originalKey)
92 }
93}
94
95func (ds deliveredStore) ForEachClient(f func(clientName string)) {
96 clients := make(map[string]struct{})
97 for _, entry := range ds.m.innerMap {
98 delivered := entry.value.(deliveredClientMap)
99 for clientName := range delivered {
100 clients[clientName] = struct{}{}
101 }
102 }
103
104 for clientName := range clients {
105 f(clientName)
106 }
107}
108
// network is the runtime state of one of a user's networks: its persisted
// configuration plus the current upstream connection and related bookkeeping.
type network struct {
	// Network is the persisted network configuration.
	Network
	// user is the user owning this network.
	user *user
	// stopped is closed by stop; run returns once it notices.
	stopped chan struct{}

	// conn is the current upstream connection; nil while disconnected.
	conn *upstreamConn
	// channels holds the saved channels, keyed by casemapped name.
	channels channelCasemapMap
	// delivered holds the last delivered message ID per target and client.
	delivered deliveredStore
	// lastError is the last upstream error; reset to nil on connection.
	lastError error
	// casemap is the name casemapping; newNetwork starts with RFC 1459 and
	// updateCasemapping replaces it.
	casemap casemapping
}
120
121func newNetwork(user *user, record *Network, channels []Channel) *network {
122 m := channelCasemapMap{newCasemapMap(0)}
123 for _, ch := range channels {
124 ch := ch
125 m.SetValue(ch.Name, &ch)
126 }
127
128 return &network{
129 Network: *record,
130 user: user,
131 stopped: make(chan struct{}),
132 channels: m,
133 delivered: newDeliveredStore(),
134 casemap: casemapRFC1459,
135 }
136}
137
138func (net *network) forEachDownstream(f func(*downstreamConn)) {
139 net.user.forEachDownstream(func(dc *downstreamConn) {
140 if dc.network != nil && dc.network != net {
141 return
142 }
143 f(dc)
144 })
145}
146
147func (net *network) isStopped() bool {
148 select {
149 case <-net.stopped:
150 return true
151 default:
152 return false
153 }
154}
155
156func userIdent(u *User) string {
157 // The ident is a string we will send to upstream servers in clear-text.
158 // For privacy reasons, make sure it doesn't expose any meaningful user
159 // metadata. We just use the base64-encoded hashed ID, so that people don't
160 // start relying on the string being an integer or following a pattern.
161 var b [64]byte
162 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
163 h := sha256.Sum256(b[:])
164 return hex.EncodeToString(h[:16])
165}
166
167func (net *network) run() {
168 var lastTry time.Time
169 for {
170 if net.isStopped() {
171 return
172 }
173
174 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
175 delay := retryConnectDelay - dur
176 net.user.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
177 time.Sleep(delay)
178 }
179 lastTry = time.Now()
180
181 uc, err := connectToUpstream(net)
182 if err != nil {
183 net.user.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
184 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
185 continue
186 }
187
188 if net.user.srv.Identd != nil {
189 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
190 }
191
192 uc.register()
193 if err := uc.runUntilRegistered(); err != nil {
194 text := err.Error()
195 if regErr, ok := err.(registrationError); ok {
196 text = string(regErr)
197 }
198 uc.logger.Printf("failed to register: %v", text)
199 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
200 uc.Close()
201 continue
202 }
203
204 // TODO: this is racy with net.stopped. If the network is stopped
205 // before the user goroutine receives eventUpstreamConnected, the
206 // connection won't be closed.
207 net.user.events <- eventUpstreamConnected{uc}
208 if err := uc.readMessages(net.user.events); err != nil {
209 uc.logger.Printf("failed to handle messages: %v", err)
210 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
211 }
212 uc.Close()
213 net.user.events <- eventUpstreamDisconnected{uc}
214
215 if net.user.srv.Identd != nil {
216 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
217 }
218 }
219}
220
221func (net *network) stop() {
222 if !net.isStopped() {
223 close(net.stopped)
224 }
225
226 if net.conn != nil {
227 net.conn.Close()
228 }
229}
230
231func (net *network) detach(ch *Channel) {
232 if ch.Detached {
233 return
234 }
235
236 net.user.logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
237
238 ch.Detached = true
239
240 if net.user.msgStore != nil {
241 nameCM := net.casemap(ch.Name)
242 lastID, err := net.user.msgStore.LastMsgID(net, nameCM, time.Now())
243 if err != nil {
244 net.user.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
245 }
246 ch.DetachedInternalMsgID = lastID
247 }
248
249 if net.conn != nil {
250 uch := net.conn.channels.Value(ch.Name)
251 if uch != nil {
252 uch.updateAutoDetach(0)
253 }
254 }
255
256 net.forEachDownstream(func(dc *downstreamConn) {
257 dc.SendMessage(&irc.Message{
258 Prefix: dc.prefix(),
259 Command: "PART",
260 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
261 })
262 })
263}
264
265func (net *network) attach(ch *Channel) {
266 if !ch.Detached {
267 return
268 }
269
270 net.user.logger.Printf("network %q: attaching channel %q", net.GetName(), ch.Name)
271
272 detachedMsgID := ch.DetachedInternalMsgID
273 ch.Detached = false
274 ch.DetachedInternalMsgID = ""
275
276 var uch *upstreamChannel
277 if net.conn != nil {
278 uch = net.conn.channels.Value(ch.Name)
279
280 net.conn.updateChannelAutoDetach(ch.Name)
281 }
282
283 net.forEachDownstream(func(dc *downstreamConn) {
284 dc.SendMessage(&irc.Message{
285 Prefix: dc.prefix(),
286 Command: "JOIN",
287 Params: []string{dc.marshalEntity(net, ch.Name)},
288 })
289
290 if uch != nil {
291 forwardChannel(dc, uch)
292 }
293
294 if detachedMsgID != "" {
295 dc.sendTargetBacklog(net, ch.Name, detachedMsgID)
296 }
297 })
298}
299
300func (net *network) deleteChannel(name string) error {
301 ch := net.channels.Value(name)
302 if ch == nil {
303 return fmt.Errorf("unknown channel %q", name)
304 }
305 if net.conn != nil {
306 uch := net.conn.channels.Value(ch.Name)
307 if uch != nil {
308 uch.updateAutoDetach(0)
309 }
310 }
311
312 if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
313 return err
314 }
315 net.channels.Delete(name)
316 return nil
317}
318
319func (net *network) updateCasemapping(newCasemap casemapping) {
320 net.casemap = newCasemap
321 net.channels.SetCasemapping(newCasemap)
322 net.delivered.m.SetCasemapping(newCasemap)
323 if net.conn != nil {
324 net.conn.channels.SetCasemapping(newCasemap)
325 for _, entry := range net.conn.channels.innerMap {
326 uch := entry.value.(*upstreamChannel)
327 uch.Members.SetCasemapping(newCasemap)
328 }
329 }
330}
331
332func (net *network) storeClientDeliveryReceipts(clientName string) {
333 if !net.user.hasPersistentMsgStore() {
334 return
335 }
336
337 var receipts []DeliveryReceipt
338 net.delivered.ForEachTarget(func(target string) {
339 msgID := net.delivered.LoadID(target, clientName)
340 if msgID == "" {
341 return
342 }
343 receipts = append(receipts, DeliveryReceipt{
344 Target: target,
345 InternalMsgID: msgID,
346 })
347 })
348
349 if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
350 net.user.logger.Printf("failed to store delivery receipts for user %q, client %q, network %q: %v", net.user.Username, clientName, net.GetName(), err)
351 }
352}
353
354func (net *network) isHighlight(msg *irc.Message) bool {
355 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
356 return false
357 }
358
359 text := msg.Params[1]
360
361 nick := net.Nick
362 if net.conn != nil {
363 nick = net.conn.nick
364 }
365
366 // TODO: use case-mapping aware comparison here
367 return msg.Prefix.Name != nick && isHighlight(text, nick)
368}
369
370func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
371 highlight := net.isHighlight(msg)
372 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
373}
374
375type user struct {
376 User
377 srv *Server
378 logger Logger
379
380 events chan event
381 done chan struct{}
382
383 networks []*network
384 downstreamConns []*downstreamConn
385 msgStore messageStore
386
387 // LIST commands in progress
388 pendingLISTs []pendingLIST
389}
390
391type pendingLIST struct {
392 downstreamID uint64
393 // list of per-upstream LIST commands not yet sent or completed
394 pendingCommands map[int64]*irc.Message
395}
396
397func newUser(srv *Server, record *User) *user {
398 logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
399
400 var msgStore messageStore
401 if srv.LogPath != "" {
402 msgStore = newFSMessageStore(srv.LogPath, record.Username)
403 } else {
404 msgStore = newMemoryMessageStore()
405 }
406
407 return &user{
408 User: *record,
409 srv: srv,
410 logger: logger,
411 events: make(chan event, 64),
412 done: make(chan struct{}),
413 msgStore: msgStore,
414 }
415}
416
417func (u *user) forEachNetwork(f func(*network)) {
418 for _, network := range u.networks {
419 f(network)
420 }
421}
422
423func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
424 for _, network := range u.networks {
425 if network.conn == nil {
426 continue
427 }
428 f(network.conn)
429 }
430}
431
432func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
433 for _, dc := range u.downstreamConns {
434 f(dc)
435 }
436}
437
438func (u *user) getNetwork(name string) *network {
439 for _, network := range u.networks {
440 if network.Addr == name {
441 return network
442 }
443 if network.Name != "" && network.Name == name {
444 return network
445 }
446 }
447 return nil
448}
449
450func (u *user) getNetworkByID(id int64) *network {
451 for _, net := range u.networks {
452 if net.ID == id {
453 return net
454 }
455 }
456 return nil
457}
458
459func (u *user) run() {
460 defer func() {
461 if u.msgStore != nil {
462 if err := u.msgStore.Close(); err != nil {
463 u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
464 }
465 }
466 close(u.done)
467 }()
468
469 networks, err := u.srv.db.ListNetworks(u.ID)
470 if err != nil {
471 u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
472 return
473 }
474
475 for _, record := range networks {
476 record := record
477 channels, err := u.srv.db.ListChannels(record.ID)
478 if err != nil {
479 u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
480 continue
481 }
482
483 network := newNetwork(u, &record, channels)
484 u.networks = append(u.networks, network)
485
486 if u.hasPersistentMsgStore() {
487 receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
488 if err != nil {
489 u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
490 return
491 }
492
493 for _, rcpt := range receipts {
494 network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
495 }
496 }
497
498 go network.run()
499 }
500
501 for e := range u.events {
502 switch e := e.(type) {
503 case eventUpstreamConnected:
504 uc := e.uc
505
506 uc.network.conn = uc
507
508 uc.updateAway()
509
510 uc.forEachDownstream(func(dc *downstreamConn) {
511 dc.updateSupportedCaps()
512 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
513
514 dc.updateNick()
515 })
516 uc.network.lastError = nil
517 case eventUpstreamDisconnected:
518 u.handleUpstreamDisconnected(e.uc)
519 case eventUpstreamConnectionError:
520 net := e.net
521
522 stopped := false
523 select {
524 case <-net.stopped:
525 stopped = true
526 default:
527 }
528
529 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
530 net.forEachDownstream(func(dc *downstreamConn) {
531 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
532 })
533 }
534 net.lastError = e.err
535 case eventUpstreamError:
536 uc := e.uc
537
538 uc.forEachDownstream(func(dc *downstreamConn) {
539 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
540 })
541 uc.network.lastError = e.err
542 case eventUpstreamMessage:
543 msg, uc := e.msg, e.uc
544 if uc.isClosed() {
545 uc.logger.Printf("ignoring message on closed connection: %v", msg)
546 break
547 }
548 if err := uc.handleMessage(msg); err != nil {
549 uc.logger.Printf("failed to handle message %q: %v", msg, err)
550 }
551 case eventChannelDetach:
552 uc, name := e.uc, e.name
553 c := uc.network.channels.Value(name)
554 if c == nil || c.Detached {
555 continue
556 }
557 uc.network.detach(c)
558 if err := uc.srv.db.StoreChannel(uc.network.ID, c); err != nil {
559 u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
560 }
561 case eventDownstreamConnected:
562 dc := e.dc
563
564 if err := dc.welcome(); err != nil {
565 dc.logger.Printf("failed to handle new registered connection: %v", err)
566 break
567 }
568
569 u.downstreamConns = append(u.downstreamConns, dc)
570
571 dc.forEachNetwork(func(network *network) {
572 if network.lastError != nil {
573 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
574 }
575 })
576
577 u.forEachUpstream(func(uc *upstreamConn) {
578 uc.updateAway()
579 })
580 case eventDownstreamDisconnected:
581 dc := e.dc
582
583 for i := range u.downstreamConns {
584 if u.downstreamConns[i] == dc {
585 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
586 break
587 }
588 }
589
590 dc.forEachNetwork(func(net *network) {
591 net.storeClientDeliveryReceipts(dc.clientName)
592 })
593
594 u.forEachUpstream(func(uc *upstreamConn) {
595 uc.updateAway()
596 })
597 case eventDownstreamMessage:
598 msg, dc := e.msg, e.dc
599 if dc.isClosed() {
600 dc.logger.Printf("ignoring message on closed connection: %v", msg)
601 break
602 }
603 err := dc.handleMessage(msg)
604 if ircErr, ok := err.(ircError); ok {
605 ircErr.Message.Prefix = dc.srv.prefix()
606 dc.SendMessage(ircErr.Message)
607 } else if err != nil {
608 dc.logger.Printf("failed to handle message %q: %v", msg, err)
609 dc.Close()
610 }
611 case eventStop:
612 u.forEachDownstream(func(dc *downstreamConn) {
613 dc.Close()
614 })
615 for _, n := range u.networks {
616 n.stop()
617
618 n.delivered.ForEachClient(func(clientName string) {
619 n.storeClientDeliveryReceipts(clientName)
620 })
621 }
622 return
623 default:
624 panic(fmt.Sprintf("received unknown event type: %T", e))
625 }
626 }
627}
628
629func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
630 uc.network.conn = nil
631
632 uc.endPendingLISTs(true)
633
634 for _, entry := range uc.channels.innerMap {
635 uch := entry.value.(*upstreamChannel)
636 uch.updateAutoDetach(0)
637 }
638
639 uc.forEachDownstream(func(dc *downstreamConn) {
640 dc.updateSupportedCaps()
641 })
642
643 if uc.network.lastError == nil {
644 uc.forEachDownstream(func(dc *downstreamConn) {
645 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
646 })
647 }
648}
649
650func (u *user) addNetwork(network *network) {
651 u.networks = append(u.networks, network)
652 go network.run()
653}
654
655func (u *user) removeNetwork(network *network) {
656 network.stop()
657
658 u.forEachDownstream(func(dc *downstreamConn) {
659 if dc.network != nil && dc.network == network {
660 dc.Close()
661 }
662 })
663
664 for i, net := range u.networks {
665 if net == network {
666 u.networks = append(u.networks[:i], u.networks[i+1:]...)
667 return
668 }
669 }
670
671 panic("tried to remove a non-existing network")
672}
673
674func (u *user) createNetwork(record *Network) (*network, error) {
675 if record.ID != 0 {
676 panic("tried creating an already-existing network")
677 }
678
679 network := newNetwork(u, record, nil)
680 err := u.srv.db.StoreNetwork(u.ID, &network.Network)
681 if err != nil {
682 return nil, err
683 }
684
685 u.addNetwork(network)
686
687 return network, nil
688}
689
690func (u *user) updateNetwork(record *Network) (*network, error) {
691 if record.ID == 0 {
692 panic("tried updating a new network")
693 }
694
695 network := u.getNetworkByID(record.ID)
696 if network == nil {
697 panic("tried updating a non-existing network")
698 }
699
700 if err := u.srv.db.StoreNetwork(u.ID, record); err != nil {
701 return nil, err
702 }
703
704 // Most network changes require us to re-connect to the upstream server
705
706 channels := make([]Channel, 0, network.channels.Len())
707 for _, entry := range network.channels.innerMap {
708 ch := entry.value.(*Channel)
709 channels = append(channels, *ch)
710 }
711
712 updatedNetwork := newNetwork(u, record, channels)
713
714 // If we're currently connected, disconnect and perform the necessary
715 // bookkeeping
716 if network.conn != nil {
717 network.stop()
718 // Note: this will set network.conn to nil
719 u.handleUpstreamDisconnected(network.conn)
720 }
721
722 // Patch downstream connections to use our fresh updated network
723 u.forEachDownstream(func(dc *downstreamConn) {
724 if dc.network != nil && dc.network == network {
725 dc.network = updatedNetwork
726 }
727 })
728
729 // We need to remove the network after patching downstream connections,
730 // otherwise they'll get closed
731 u.removeNetwork(network)
732
733 // This will re-connect to the upstream server
734 u.addNetwork(updatedNetwork)
735
736 return updatedNetwork, nil
737}
738
739func (u *user) deleteNetwork(id int64) error {
740 network := u.getNetworkByID(id)
741 if network == nil {
742 panic("tried deleting a non-existing network")
743 }
744
745 if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
746 return err
747 }
748
749 u.removeNetwork(network)
750 return nil
751}
752
753func (u *user) updatePassword(hashed string) error {
754 u.User.Password = hashed
755 return u.srv.db.StoreUser(&u.User)
756}
757
758func (u *user) stop() {
759 u.events <- eventStop{}
760 <-u.done
761}
762
763func (u *user) hasPersistentMsgStore() bool {
764 if u.msgStore == nil {
765 return false
766 }
767 _, isMem := u.msgStore.(*memoryMessageStore)
768 return !isMem
769}
Note: See TracBrowser for help on using the repository browser.