source: code/trunk/user.go@ 670

Last change on this file since 670 was 666, checked in by contact, 4 years ago

msgstore: take Network as arg instead of network

The message stores don't need to access the internal network
struct, they just need network metadata such as ID and name.

This will make it easier to move the message stores into a separate
package in the future.

File size: 22.2 KB
Line 
1package soju
2
3import (
4 "context"
5 "crypto/sha256"
6 "encoding/binary"
7 "encoding/hex"
8 "fmt"
9 "time"
10
11 "gopkg.in/irc.v3"
12)
13
14type event interface{}
15
16type eventUpstreamMessage struct {
17 msg *irc.Message
18 uc *upstreamConn
19}
20
21type eventUpstreamConnectionError struct {
22 net *network
23 err error
24}
25
26type eventUpstreamConnected struct {
27 uc *upstreamConn
28}
29
30type eventUpstreamDisconnected struct {
31 uc *upstreamConn
32}
33
34type eventUpstreamError struct {
35 uc *upstreamConn
36 err error
37}
38
39type eventDownstreamMessage struct {
40 msg *irc.Message
41 dc *downstreamConn
42}
43
44type eventDownstreamConnected struct {
45 dc *downstreamConn
46}
47
48type eventDownstreamDisconnected struct {
49 dc *downstreamConn
50}
51
52type eventChannelDetach struct {
53 uc *upstreamConn
54 name string
55}
56
57type eventBroadcast struct {
58 msg *irc.Message
59}
60
61type eventStop struct{}
62
63type eventUserUpdate struct {
64 password *string
65 admin *bool
66 done chan error
67}
68
69type deliveredClientMap map[string]string // client name -> msg ID
70
71type deliveredStore struct {
72 m deliveredCasemapMap
73}
74
75func newDeliveredStore() deliveredStore {
76 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
77}
78
79func (ds deliveredStore) HasTarget(target string) bool {
80 return ds.m.Value(target) != nil
81}
82
83func (ds deliveredStore) LoadID(target, clientName string) string {
84 clients := ds.m.Value(target)
85 if clients == nil {
86 return ""
87 }
88 return clients[clientName]
89}
90
91func (ds deliveredStore) StoreID(target, clientName, msgID string) {
92 clients := ds.m.Value(target)
93 if clients == nil {
94 clients = make(deliveredClientMap)
95 ds.m.SetValue(target, clients)
96 }
97 clients[clientName] = msgID
98}
99
100func (ds deliveredStore) ForEachTarget(f func(target string)) {
101 for _, entry := range ds.m.innerMap {
102 f(entry.originalKey)
103 }
104}
105
106func (ds deliveredStore) ForEachClient(f func(clientName string)) {
107 clients := make(map[string]struct{})
108 for _, entry := range ds.m.innerMap {
109 delivered := entry.value.(deliveredClientMap)
110 for clientName := range delivered {
111 clients[clientName] = struct{}{}
112 }
113 }
114
115 for clientName := range clients {
116 f(clientName)
117 }
118}
119
120type network struct {
121 Network
122 user *user
123 logger Logger
124 stopped chan struct{}
125
126 conn *upstreamConn
127 channels channelCasemapMap
128 delivered deliveredStore
129 lastError error
130 casemap casemapping
131}
132
133func newNetwork(user *user, record *Network, channels []Channel) *network {
134 logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
135
136 m := channelCasemapMap{newCasemapMap(0)}
137 for _, ch := range channels {
138 ch := ch
139 m.SetValue(ch.Name, &ch)
140 }
141
142 return &network{
143 Network: *record,
144 user: user,
145 logger: logger,
146 stopped: make(chan struct{}),
147 channels: m,
148 delivered: newDeliveredStore(),
149 casemap: casemapRFC1459,
150 }
151}
152
153func (net *network) forEachDownstream(f func(*downstreamConn)) {
154 net.user.forEachDownstream(func(dc *downstreamConn) {
155 if dc.network == nil && dc.caps["soju.im/bouncer-networks"] {
156 return
157 }
158 if dc.network != nil && dc.network != net {
159 return
160 }
161 f(dc)
162 })
163}
164
165func (net *network) isStopped() bool {
166 select {
167 case <-net.stopped:
168 return true
169 default:
170 return false
171 }
172}
173
174func userIdent(u *User) string {
175 // The ident is a string we will send to upstream servers in clear-text.
176 // For privacy reasons, make sure it doesn't expose any meaningful user
177 // metadata. We just use the base64-encoded hashed ID, so that people don't
178 // start relying on the string being an integer or following a pattern.
179 var b [64]byte
180 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
181 h := sha256.Sum256(b[:])
182 return hex.EncodeToString(h[:16])
183}
184
185func (net *network) run() {
186 if !net.Enabled {
187 return
188 }
189
190 var lastTry time.Time
191 for {
192 if net.isStopped() {
193 return
194 }
195
196 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
197 delay := retryConnectDelay - dur
198 net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
199 time.Sleep(delay)
200 }
201 lastTry = time.Now()
202
203 uc, err := connectToUpstream(net)
204 if err != nil {
205 net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
206 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
207 continue
208 }
209
210 if net.user.srv.Identd != nil {
211 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
212 }
213
214 uc.register()
215 if err := uc.runUntilRegistered(); err != nil {
216 text := err.Error()
217 if regErr, ok := err.(registrationError); ok {
218 text = string(regErr)
219 }
220 uc.logger.Printf("failed to register: %v", text)
221 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
222 uc.Close()
223 continue
224 }
225
226 // TODO: this is racy with net.stopped. If the network is stopped
227 // before the user goroutine receives eventUpstreamConnected, the
228 // connection won't be closed.
229 net.user.events <- eventUpstreamConnected{uc}
230 if err := uc.readMessages(net.user.events); err != nil {
231 uc.logger.Printf("failed to handle messages: %v", err)
232 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
233 }
234 uc.Close()
235 net.user.events <- eventUpstreamDisconnected{uc}
236
237 if net.user.srv.Identd != nil {
238 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
239 }
240 }
241}
242
243func (net *network) stop() {
244 if !net.isStopped() {
245 close(net.stopped)
246 }
247
248 if net.conn != nil {
249 net.conn.Close()
250 }
251}
252
253func (net *network) detach(ch *Channel) {
254 if ch.Detached {
255 return
256 }
257
258 net.logger.Printf("detaching channel %q", ch.Name)
259
260 ch.Detached = true
261
262 if net.user.msgStore != nil {
263 nameCM := net.casemap(ch.Name)
264 lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
265 if err != nil {
266 net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
267 }
268 ch.DetachedInternalMsgID = lastID
269 }
270
271 if net.conn != nil {
272 uch := net.conn.channels.Value(ch.Name)
273 if uch != nil {
274 uch.updateAutoDetach(0)
275 }
276 }
277
278 net.forEachDownstream(func(dc *downstreamConn) {
279 dc.SendMessage(&irc.Message{
280 Prefix: dc.prefix(),
281 Command: "PART",
282 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
283 })
284 })
285}
286
287func (net *network) attach(ch *Channel) {
288 if !ch.Detached {
289 return
290 }
291
292 net.logger.Printf("attaching channel %q", ch.Name)
293
294 detachedMsgID := ch.DetachedInternalMsgID
295 ch.Detached = false
296 ch.DetachedInternalMsgID = ""
297
298 var uch *upstreamChannel
299 if net.conn != nil {
300 uch = net.conn.channels.Value(ch.Name)
301
302 net.conn.updateChannelAutoDetach(ch.Name)
303 }
304
305 net.forEachDownstream(func(dc *downstreamConn) {
306 dc.SendMessage(&irc.Message{
307 Prefix: dc.prefix(),
308 Command: "JOIN",
309 Params: []string{dc.marshalEntity(net, ch.Name)},
310 })
311
312 if uch != nil {
313 forwardChannel(dc, uch)
314 }
315
316 if detachedMsgID != "" {
317 dc.sendTargetBacklog(net, ch.Name, detachedMsgID)
318 }
319 })
320}
321
322func (net *network) deleteChannel(name string) error {
323 ch := net.channels.Value(name)
324 if ch == nil {
325 return fmt.Errorf("unknown channel %q", name)
326 }
327 if net.conn != nil {
328 uch := net.conn.channels.Value(ch.Name)
329 if uch != nil {
330 uch.updateAutoDetach(0)
331 }
332 }
333
334 if err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID); err != nil {
335 return err
336 }
337 net.channels.Delete(name)
338 return nil
339}
340
341func (net *network) updateCasemapping(newCasemap casemapping) {
342 net.casemap = newCasemap
343 net.channels.SetCasemapping(newCasemap)
344 net.delivered.m.SetCasemapping(newCasemap)
345 if net.conn != nil {
346 net.conn.channels.SetCasemapping(newCasemap)
347 for _, entry := range net.conn.channels.innerMap {
348 uch := entry.value.(*upstreamChannel)
349 uch.Members.SetCasemapping(newCasemap)
350 }
351 }
352}
353
354func (net *network) storeClientDeliveryReceipts(clientName string) {
355 if !net.user.hasPersistentMsgStore() {
356 return
357 }
358
359 var receipts []DeliveryReceipt
360 net.delivered.ForEachTarget(func(target string) {
361 msgID := net.delivered.LoadID(target, clientName)
362 if msgID == "" {
363 return
364 }
365 receipts = append(receipts, DeliveryReceipt{
366 Target: target,
367 InternalMsgID: msgID,
368 })
369 })
370
371 if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
372 net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
373 }
374}
375
376func (net *network) isHighlight(msg *irc.Message) bool {
377 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
378 return false
379 }
380
381 text := msg.Params[1]
382
383 nick := net.Nick
384 if net.conn != nil {
385 nick = net.conn.nick
386 }
387
388 // TODO: use case-mapping aware comparison here
389 return msg.Prefix.Name != nick && isHighlight(text, nick)
390}
391
392func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
393 highlight := net.isHighlight(msg)
394 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
395}
396
397type user struct {
398 User
399 srv *Server
400 logger Logger
401
402 events chan event
403 done chan struct{}
404
405 networks []*network
406 downstreamConns []*downstreamConn
407 msgStore messageStore
408
409 // LIST commands in progress
410 pendingLISTs []pendingLIST
411}
412
413type pendingLIST struct {
414 downstreamID uint64
415 // list of per-upstream LIST commands not yet sent or completed
416 pendingCommands map[int64]*irc.Message
417}
418
419func newUser(srv *Server, record *User) *user {
420 logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
421
422 var msgStore messageStore
423 if srv.LogPath != "" {
424 msgStore = newFSMessageStore(srv.LogPath, record.Username)
425 } else {
426 msgStore = newMemoryMessageStore()
427 }
428
429 return &user{
430 User: *record,
431 srv: srv,
432 logger: logger,
433 events: make(chan event, 64),
434 done: make(chan struct{}),
435 msgStore: msgStore,
436 }
437}
438
439func (u *user) forEachNetwork(f func(*network)) {
440 for _, network := range u.networks {
441 f(network)
442 }
443}
444
445func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
446 for _, network := range u.networks {
447 if network.conn == nil {
448 continue
449 }
450 f(network.conn)
451 }
452}
453
454func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
455 for _, dc := range u.downstreamConns {
456 f(dc)
457 }
458}
459
460func (u *user) getNetwork(name string) *network {
461 for _, network := range u.networks {
462 if network.Addr == name {
463 return network
464 }
465 if network.Name != "" && network.Name == name {
466 return network
467 }
468 }
469 return nil
470}
471
472func (u *user) getNetworkByID(id int64) *network {
473 for _, net := range u.networks {
474 if net.ID == id {
475 return net
476 }
477 }
478 return nil
479}
480
481func (u *user) run() {
482 defer func() {
483 if u.msgStore != nil {
484 if err := u.msgStore.Close(); err != nil {
485 u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
486 }
487 }
488 close(u.done)
489 }()
490
491 networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
492 if err != nil {
493 u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
494 return
495 }
496
497 for _, record := range networks {
498 record := record
499 channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
500 if err != nil {
501 u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
502 continue
503 }
504
505 network := newNetwork(u, &record, channels)
506 u.networks = append(u.networks, network)
507
508 if u.hasPersistentMsgStore() {
509 receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
510 if err != nil {
511 u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
512 return
513 }
514
515 for _, rcpt := range receipts {
516 network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
517 }
518 }
519
520 go network.run()
521 }
522
523 for e := range u.events {
524 switch e := e.(type) {
525 case eventUpstreamConnected:
526 uc := e.uc
527
528 uc.network.conn = uc
529
530 uc.updateAway()
531
532 netIDStr := fmt.Sprintf("%v", uc.network.ID)
533 uc.forEachDownstream(func(dc *downstreamConn) {
534 dc.updateSupportedCaps()
535
536 if !dc.caps["soju.im/bouncer-networks"] {
537 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
538 }
539
540 dc.updateNick()
541 dc.updateRealname()
542 })
543 u.forEachDownstream(func(dc *downstreamConn) {
544 if dc.caps["soju.im/bouncer-networks-notify"] {
545 dc.SendMessage(&irc.Message{
546 Prefix: dc.srv.prefix(),
547 Command: "BOUNCER",
548 Params: []string{"NETWORK", netIDStr, "state=connected"},
549 })
550 }
551 })
552 uc.network.lastError = nil
553 case eventUpstreamDisconnected:
554 u.handleUpstreamDisconnected(e.uc)
555 case eventUpstreamConnectionError:
556 net := e.net
557
558 stopped := false
559 select {
560 case <-net.stopped:
561 stopped = true
562 default:
563 }
564
565 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
566 net.forEachDownstream(func(dc *downstreamConn) {
567 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
568 })
569 }
570 net.lastError = e.err
571 case eventUpstreamError:
572 uc := e.uc
573
574 uc.forEachDownstream(func(dc *downstreamConn) {
575 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
576 })
577 uc.network.lastError = e.err
578 case eventUpstreamMessage:
579 msg, uc := e.msg, e.uc
580 if uc.isClosed() {
581 uc.logger.Printf("ignoring message on closed connection: %v", msg)
582 break
583 }
584 if err := uc.handleMessage(msg); err != nil {
585 uc.logger.Printf("failed to handle message %q: %v", msg, err)
586 }
587 case eventChannelDetach:
588 uc, name := e.uc, e.name
589 c := uc.network.channels.Value(name)
590 if c == nil || c.Detached {
591 continue
592 }
593 uc.network.detach(c)
594 if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
595 u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
596 }
597 case eventDownstreamConnected:
598 dc := e.dc
599
600 if err := dc.welcome(); err != nil {
601 dc.logger.Printf("failed to handle new registered connection: %v", err)
602 break
603 }
604
605 u.downstreamConns = append(u.downstreamConns, dc)
606
607 dc.forEachNetwork(func(network *network) {
608 if network.lastError != nil {
609 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
610 }
611 })
612
613 u.forEachUpstream(func(uc *upstreamConn) {
614 uc.updateAway()
615 })
616 case eventDownstreamDisconnected:
617 dc := e.dc
618
619 for i := range u.downstreamConns {
620 if u.downstreamConns[i] == dc {
621 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
622 break
623 }
624 }
625
626 dc.forEachNetwork(func(net *network) {
627 net.storeClientDeliveryReceipts(dc.clientName)
628 })
629
630 u.forEachUpstream(func(uc *upstreamConn) {
631 uc.updateAway()
632 })
633 case eventDownstreamMessage:
634 msg, dc := e.msg, e.dc
635 if dc.isClosed() {
636 dc.logger.Printf("ignoring message on closed connection: %v", msg)
637 break
638 }
639 err := dc.handleMessage(msg)
640 if ircErr, ok := err.(ircError); ok {
641 ircErr.Message.Prefix = dc.srv.prefix()
642 dc.SendMessage(ircErr.Message)
643 } else if err != nil {
644 dc.logger.Printf("failed to handle message %q: %v", msg, err)
645 dc.Close()
646 }
647 case eventBroadcast:
648 msg := e.msg
649 u.forEachDownstream(func(dc *downstreamConn) {
650 dc.SendMessage(msg)
651 })
652 case eventUserUpdate:
653 // copy the user record because we'll mutate it
654 record := u.User
655
656 if e.password != nil {
657 record.Password = *e.password
658 }
659 if e.admin != nil {
660 record.Admin = *e.admin
661 }
662
663 e.done <- u.updateUser(&record)
664
665 // If the password was updated, kill all downstream connections to
666 // force them to re-authenticate with the new credentials.
667 if e.password != nil {
668 u.forEachDownstream(func(dc *downstreamConn) {
669 dc.Close()
670 })
671 }
672 case eventStop:
673 u.forEachDownstream(func(dc *downstreamConn) {
674 dc.Close()
675 })
676 for _, n := range u.networks {
677 n.stop()
678
679 n.delivered.ForEachClient(func(clientName string) {
680 n.storeClientDeliveryReceipts(clientName)
681 })
682 }
683 return
684 default:
685 panic(fmt.Sprintf("received unknown event type: %T", e))
686 }
687 }
688}
689
690func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
691 uc.network.conn = nil
692
693 uc.endPendingLISTs(true)
694
695 for _, entry := range uc.channels.innerMap {
696 uch := entry.value.(*upstreamChannel)
697 uch.updateAutoDetach(0)
698 }
699
700 netIDStr := fmt.Sprintf("%v", uc.network.ID)
701 uc.forEachDownstream(func(dc *downstreamConn) {
702 dc.updateSupportedCaps()
703 })
704
705 // If the network has been removed, don't send a state change notification
706 found := false
707 for _, net := range u.networks {
708 if net == uc.network {
709 found = true
710 break
711 }
712 }
713 if !found {
714 return
715 }
716
717 u.forEachDownstream(func(dc *downstreamConn) {
718 if dc.caps["soju.im/bouncer-networks-notify"] {
719 dc.SendMessage(&irc.Message{
720 Prefix: dc.srv.prefix(),
721 Command: "BOUNCER",
722 Params: []string{"NETWORK", netIDStr, "state=disconnected"},
723 })
724 }
725 })
726
727 if uc.network.lastError == nil {
728 uc.forEachDownstream(func(dc *downstreamConn) {
729 if !dc.caps["soju.im/bouncer-networks"] {
730 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
731 }
732 })
733 }
734}
735
736func (u *user) addNetwork(network *network) {
737 u.networks = append(u.networks, network)
738 go network.run()
739}
740
741func (u *user) removeNetwork(network *network) {
742 network.stop()
743
744 u.forEachDownstream(func(dc *downstreamConn) {
745 if dc.network != nil && dc.network == network {
746 dc.Close()
747 }
748 })
749
750 for i, net := range u.networks {
751 if net == network {
752 u.networks = append(u.networks[:i], u.networks[i+1:]...)
753 return
754 }
755 }
756
757 panic("tried to remove a non-existing network")
758}
759
760func (u *user) checkNetwork(record *Network) error {
761 for _, net := range u.networks {
762 if net.GetName() == record.GetName() && net.ID != record.ID {
763 return fmt.Errorf("a network with the name %q already exists", record.GetName())
764 }
765 }
766 return nil
767}
768
769func (u *user) createNetwork(record *Network) (*network, error) {
770 if record.ID != 0 {
771 panic("tried creating an already-existing network")
772 }
773
774 if err := u.checkNetwork(record); err != nil {
775 return nil, err
776 }
777
778 if u.srv.MaxUserNetworks >= 0 && len(u.networks) >= u.srv.MaxUserNetworks {
779 return nil, fmt.Errorf("maximum number of networks reached")
780 }
781
782 network := newNetwork(u, record, nil)
783 err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network)
784 if err != nil {
785 return nil, err
786 }
787
788 u.addNetwork(network)
789
790 idStr := fmt.Sprintf("%v", network.ID)
791 attrs := getNetworkAttrs(network)
792 u.forEachDownstream(func(dc *downstreamConn) {
793 if dc.caps["soju.im/bouncer-networks-notify"] {
794 dc.SendMessage(&irc.Message{
795 Prefix: dc.srv.prefix(),
796 Command: "BOUNCER",
797 Params: []string{"NETWORK", idStr, attrs.String()},
798 })
799 }
800 })
801
802 return network, nil
803}
804
805func (u *user) updateNetwork(record *Network) (*network, error) {
806 if record.ID == 0 {
807 panic("tried updating a new network")
808 }
809
810 // If the realname is reset to the default, just wipe the per-network
811 // setting
812 if record.Realname == u.Realname {
813 record.Realname = ""
814 }
815
816 if err := u.checkNetwork(record); err != nil {
817 return nil, err
818 }
819
820 network := u.getNetworkByID(record.ID)
821 if network == nil {
822 panic("tried updating a non-existing network")
823 }
824
825 if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
826 return nil, err
827 }
828
829 // Most network changes require us to re-connect to the upstream server
830
831 channels := make([]Channel, 0, network.channels.Len())
832 for _, entry := range network.channels.innerMap {
833 ch := entry.value.(*Channel)
834 channels = append(channels, *ch)
835 }
836
837 updatedNetwork := newNetwork(u, record, channels)
838
839 // If we're currently connected, disconnect and perform the necessary
840 // bookkeeping
841 if network.conn != nil {
842 network.stop()
843 // Note: this will set network.conn to nil
844 u.handleUpstreamDisconnected(network.conn)
845 }
846
847 // Patch downstream connections to use our fresh updated network
848 u.forEachDownstream(func(dc *downstreamConn) {
849 if dc.network != nil && dc.network == network {
850 dc.network = updatedNetwork
851 }
852 })
853
854 // We need to remove the network after patching downstream connections,
855 // otherwise they'll get closed
856 u.removeNetwork(network)
857
858 // The filesystem message store needs to be notified whenever the network
859 // is renamed
860 fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
861 if isFS && updatedNetwork.GetName() != network.GetName() {
862 if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
863 network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
864 }
865 }
866
867 // This will re-connect to the upstream server
868 u.addNetwork(updatedNetwork)
869
870 // TODO: only broadcast attributes that have changed
871 idStr := fmt.Sprintf("%v", updatedNetwork.ID)
872 attrs := getNetworkAttrs(updatedNetwork)
873 u.forEachDownstream(func(dc *downstreamConn) {
874 if dc.caps["soju.im/bouncer-networks-notify"] {
875 dc.SendMessage(&irc.Message{
876 Prefix: dc.srv.prefix(),
877 Command: "BOUNCER",
878 Params: []string{"NETWORK", idStr, attrs.String()},
879 })
880 }
881 })
882
883 return updatedNetwork, nil
884}
885
886func (u *user) deleteNetwork(id int64) error {
887 network := u.getNetworkByID(id)
888 if network == nil {
889 panic("tried deleting a non-existing network")
890 }
891
892 if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
893 return err
894 }
895
896 u.removeNetwork(network)
897
898 idStr := fmt.Sprintf("%v", network.ID)
899 u.forEachDownstream(func(dc *downstreamConn) {
900 if dc.caps["soju.im/bouncer-networks-notify"] {
901 dc.SendMessage(&irc.Message{
902 Prefix: dc.srv.prefix(),
903 Command: "BOUNCER",
904 Params: []string{"NETWORK", idStr, "*"},
905 })
906 }
907 })
908
909 return nil
910}
911
912func (u *user) updateUser(record *User) error {
913 if u.ID != record.ID {
914 panic("ID mismatch when updating user")
915 }
916
917 realnameUpdated := u.Realname != record.Realname
918 if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
919 return fmt.Errorf("failed to update user %q: %v", u.Username, err)
920 }
921 u.User = *record
922
923 if realnameUpdated {
924 // Re-connect to networks which use the default realname
925 var needUpdate []Network
926 u.forEachNetwork(func(net *network) {
927 if net.Realname == "" {
928 needUpdate = append(needUpdate, net.Network)
929 }
930 })
931
932 var netErr error
933 for _, net := range needUpdate {
934 if _, err := u.updateNetwork(&net); err != nil {
935 netErr = err
936 }
937 }
938 if netErr != nil {
939 return netErr
940 }
941 }
942
943 return nil
944}
945
946func (u *user) stop() {
947 u.events <- eventStop{}
948 <-u.done
949}
950
951func (u *user) hasPersistentMsgStore() bool {
952 if u.msgStore == nil {
953 return false
954 }
955 _, isMem := u.msgStore.(*memoryMessageStore)
956 return !isMem
957}
Note: See TracBrowser for help on using the repository browser.