source: code/trunk/user.go@ 669

Last change on this file since 669 was 666, checked in by contact, 4 years ago

msgstore: take Network as arg instead of network

The message stores don't need to access the internal network
struct, they just need network metadata such as ID and name.

This can ease moving message stores into a separate package in the
future.

File size: 22.2 KB
RevLine 
[101]1package soju
2
3import (
[652]4 "context"
[385]5 "crypto/sha256"
6 "encoding/binary"
[395]7 "encoding/hex"
[218]8 "fmt"
[101]9 "time"
[103]10
11 "gopkg.in/irc.v3"
[101]12)
13
[165]14type event interface{}
15
16type eventUpstreamMessage struct {
[103]17 msg *irc.Message
18 uc *upstreamConn
19}
20
[218]21type eventUpstreamConnectionError struct {
22 net *network
23 err error
24}
25
[196]26type eventUpstreamConnected struct {
27 uc *upstreamConn
28}
29
[179]30type eventUpstreamDisconnected struct {
31 uc *upstreamConn
32}
33
[218]34type eventUpstreamError struct {
35 uc *upstreamConn
36 err error
37}
38
[165]39type eventDownstreamMessage struct {
[103]40 msg *irc.Message
41 dc *downstreamConn
42}
43
[166]44type eventDownstreamConnected struct {
45 dc *downstreamConn
46}
47
[167]48type eventDownstreamDisconnected struct {
49 dc *downstreamConn
50}
51
[435]52type eventChannelDetach struct {
53 uc *upstreamConn
54 name string
55}
56
[563]57type eventBroadcast struct {
58 msg *irc.Message
59}
60
[376]61type eventStop struct{}
62
[625]63type eventUserUpdate struct {
64 password *string
65 admin *bool
66 done chan error
67}
68
[480]69type deliveredClientMap map[string]string // client name -> msg ID
70
[485]71type deliveredStore struct {
72 m deliveredCasemapMap
73}
74
75func newDeliveredStore() deliveredStore {
76 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
77}
78
79func (ds deliveredStore) HasTarget(target string) bool {
80 return ds.m.Value(target) != nil
81}
82
83func (ds deliveredStore) LoadID(target, clientName string) string {
84 clients := ds.m.Value(target)
85 if clients == nil {
86 return ""
87 }
88 return clients[clientName]
89}
90
91func (ds deliveredStore) StoreID(target, clientName, msgID string) {
92 clients := ds.m.Value(target)
93 if clients == nil {
94 clients = make(deliveredClientMap)
95 ds.m.SetValue(target, clients)
96 }
97 clients[clientName] = msgID
98}
99
100func (ds deliveredStore) ForEachTarget(f func(target string)) {
101 for _, entry := range ds.m.innerMap {
102 f(entry.originalKey)
103 }
104}
105
[489]106func (ds deliveredStore) ForEachClient(f func(clientName string)) {
107 clients := make(map[string]struct{})
108 for _, entry := range ds.m.innerMap {
109 delivered := entry.value.(deliveredClientMap)
110 for clientName := range delivered {
111 clients[clientName] = struct{}{}
112 }
113 }
114
115 for clientName := range clients {
116 f(clientName)
117 }
118}
119
[101]120type network struct {
121 Network
[202]122 user *user
[501]123 logger Logger
[202]124 stopped chan struct{}
[131]125
[482]126 conn *upstreamConn
127 channels channelCasemapMap
[485]128 delivered deliveredStore
[482]129 lastError error
130 casemap casemapping
[101]131}
132
// newNetwork builds the in-memory state for record, owned by user and seeded
// with the given stored channels. It does not start the connection loop;
// callers do that with "go network.run()".
func newNetwork(user *user, record *Network, channels []Channel) *network {
	prefix := fmt.Sprintf("network %q: ", record.GetName())

	chans := channelCasemapMap{newCasemapMap(0)}
	for i := range channels {
		// Take a fresh copy per iteration so each entry gets its own pointer.
		ch := channels[i]
		chans.SetValue(ch.Name, &ch)
	}

	return &network{
		Network:   *record,
		user:      user,
		logger:    &prefixLogger{user.logger, prefix},
		stopped:   make(chan struct{}),
		channels:  chans,
		delivered: newDeliveredStore(),
		casemap:   casemapRFC1459,
	}
}
152
[218]153func (net *network) forEachDownstream(f func(*downstreamConn)) {
154 net.user.forEachDownstream(func(dc *downstreamConn) {
[532]155 if dc.network == nil && dc.caps["soju.im/bouncer-networks"] {
156 return
157 }
[218]158 if dc.network != nil && dc.network != net {
159 return
160 }
161 f(dc)
162 })
163}
164
[311]165func (net *network) isStopped() bool {
166 select {
167 case <-net.stopped:
168 return true
169 default:
170 return false
171 }
172}
173
// userIdent returns the ident string sent to upstream servers for u.
//
// The ident travels in clear-text. For privacy reasons it must not expose any
// meaningful user metadata. It is the hex encoding of the first 16 bytes of a
// SHA-256 digest of the user ID (32 characters), so that people don't start
// relying on the string being an integer or following a pattern.
//
// The digest input is a 64-byte buffer. Its first 8 bytes hold the
// little-endian user ID and the rest stay zero. Only 8 bytes are needed, but
// changing the buffer size would change the ident of every existing user, so
// it is kept as is.
func userIdent(u *User) string {
	var buf [64]byte
	binary.LittleEndian.PutUint64(buf[:8], uint64(u.ID))
	sum := sha256.Sum256(buf[:])
	return hex.EncodeToString(sum[:16])
}
184
// run keeps the network connected to its upstream server until the network is
// stopped. It waits at least retryConnectDelay between connection attempts.
// It returns immediately if the network is disabled. Connection state changes
// and upstream messages are reported to the user goroutine via
// net.user.events.
func (net *network) run() {
	if !net.Enabled {
		return
	}

	var lastTry time.Time
	for {
		if net.isStopped() {
			return
		}

		if dur := time.Since(lastTry); dur < retryConnectDelay {
			delay := retryConnectDelay - dur
			net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
			time.Sleep(delay)
		}
		lastTry = time.Now()

		uc, err := connectToUpstream(net)
		if err != nil {
			net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
			net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
			continue
		}

		if net.user.srv.Identd != nil {
			net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
		}

		uc.register()
		if err := uc.runUntilRegistered(); err != nil {
			text := err.Error()
			if regErr, ok := err.(registrationError); ok {
				text = string(regErr)
			}
			uc.logger.Printf("failed to register: %v", text)
			net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
			uc.Close()
			// The ident mapping stored above must be dropped on this path
			// too, otherwise every failed registration attempt leaks one.
			if net.user.srv.Identd != nil {
				net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
			}
			continue
		}

		// TODO: this is racy with net.stopped. If the network is stopped
		// before the user goroutine receives eventUpstreamConnected, the
		// connection won't be closed.
		net.user.events <- eventUpstreamConnected{uc}
		if err := uc.readMessages(net.user.events); err != nil {
			uc.logger.Printf("failed to handle messages: %v", err)
			net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
		}
		uc.Close()
		net.user.events <- eventUpstreamDisconnected{uc}

		if net.user.srv.Identd != nil {
			net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
		}
	}
}
242
[309]243func (net *network) stop() {
[311]244 if !net.isStopped() {
[202]245 close(net.stopped)
246 }
247
[279]248 if net.conn != nil {
249 net.conn.Close()
[202]250 }
251}
252
// detach marks ch as detached and sends a PART ("Detach") for it to the
// network's downstream clients. When a message store exists, it records the
// store's last message ID for the channel in ch.DetachedInternalMsgID, so that
// attach can replay the backlog. It also calls updateAutoDetach(0) on the
// matching upstream channel, if connected. It is a no-op when ch is already
// detached.
func (net *network) detach(ch *Channel) {
	if ch.Detached {
		return
	}

	net.logger.Printf("detaching channel %q", ch.Name)

	ch.Detached = true

	if store := net.user.msgStore; store != nil {
		lastID, err := store.LastMsgID(&net.Network, net.casemap(ch.Name), time.Now())
		if err != nil {
			net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
		}
		ch.DetachedInternalMsgID = lastID
	}

	if uc := net.conn; uc != nil {
		if uch := uc.channels.Value(ch.Name); uch != nil {
			uch.updateAutoDetach(0)
		}
	}

	net.forEachDownstream(func(dc *downstreamConn) {
		part := &irc.Message{
			Prefix:  dc.prefix(),
			Command: "PART",
			Params:  []string{dc.marshalEntity(net, ch.Name), "Detach"},
		}
		dc.SendMessage(part)
	})
}
[284]286
[435]287func (net *network) attach(ch *Channel) {
288 if !ch.Detached {
289 return
290 }
[497]291
[501]292 net.logger.Printf("attaching channel %q", ch.Name)
[284]293
[497]294 detachedMsgID := ch.DetachedInternalMsgID
295 ch.Detached = false
296 ch.DetachedInternalMsgID = ""
297
[435]298 var uch *upstreamChannel
299 if net.conn != nil {
[478]300 uch = net.conn.channels.Value(ch.Name)
[284]301
[435]302 net.conn.updateChannelAutoDetach(ch.Name)
303 }
[284]304
[435]305 net.forEachDownstream(func(dc *downstreamConn) {
306 dc.SendMessage(&irc.Message{
307 Prefix: dc.prefix(),
308 Command: "JOIN",
309 Params: []string{dc.marshalEntity(net, ch.Name)},
310 })
311
312 if uch != nil {
313 forwardChannel(dc, uch)
[284]314 }
315
[497]316 if detachedMsgID != "" {
317 dc.sendTargetBacklog(net, ch.Name, detachedMsgID)
[495]318 }
[435]319 })
[222]320}
321
// deleteChannel removes the channel called name from the database and then
// from the network's channel list. Beforehand it calls updateAutoDetach(0) on
// the matching upstream channel, if connected. It returns an error for an
// unknown channel or a failed database delete. In the latter case the
// in-memory list is left untouched.
func (net *network) deleteChannel(name string) error {
	ch := net.channels.Value(name)
	if ch == nil {
		return fmt.Errorf("unknown channel %q", name)
	}

	if uc := net.conn; uc != nil {
		if uch := uc.channels.Value(ch.Name); uch != nil {
			uch.updateAutoDetach(0)
		}
	}

	err := net.user.srv.db.DeleteChannel(context.TODO(), ch.ID)
	if err == nil {
		net.channels.Delete(name)
	}
	return err
}
340
[478]341func (net *network) updateCasemapping(newCasemap casemapping) {
342 net.casemap = newCasemap
343 net.channels.SetCasemapping(newCasemap)
[485]344 net.delivered.m.SetCasemapping(newCasemap)
[478]345 if net.conn != nil {
346 net.conn.channels.SetCasemapping(newCasemap)
347 for _, entry := range net.conn.channels.innerMap {
348 uch := entry.value.(*upstreamChannel)
349 uch.Members.SetCasemapping(newCasemap)
350 }
351 }
352}
353
// storeClientDeliveryReceipts writes to the database, for client clientName,
// one receipt per target that has a message ID in net.delivered. It does
// nothing unless the user has a persistent message store. A database error is
// logged, not returned.
func (net *network) storeClientDeliveryReceipts(clientName string) {
	if !net.user.hasPersistentMsgStore() {
		return
	}

	var receipts []DeliveryReceipt
	net.delivered.ForEachTarget(func(target string) {
		if id := net.delivered.LoadID(target, clientName); id != "" {
			receipts = append(receipts, DeliveryReceipt{Target: target, InternalMsgID: id})
		}
	})

	err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts)
	if err != nil {
		net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
	}
}
375
[499]376func (net *network) isHighlight(msg *irc.Message) bool {
377 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
378 return false
379 }
380
381 text := msg.Params[1]
382
383 nick := net.Nick
384 if net.conn != nil {
385 nick = net.conn.nick
386 }
387
388 // TODO: use case-mapping aware comparison here
389 return msg.Prefix.Name != nick && isHighlight(text, nick)
390}
391
392func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
393 highlight := net.isHighlight(msg)
394 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
395}
396
[101]397type user struct {
398 User
[493]399 srv *Server
400 logger Logger
[101]401
[165]402 events chan event
[377]403 done chan struct{}
[103]404
[101]405 networks []*network
406 downstreamConns []*downstreamConn
[439]407 msgStore messageStore
[177]408
409 // LIST commands in progress
[179]410 pendingLISTs []pendingLIST
[101]411}
412
[177]413type pendingLIST struct {
414 downstreamID uint64
415 // list of per-upstream LIST commands not yet sent or completed
416 pendingCommands map[int64]*irc.Message
417}
418
// newUser builds the in-memory state for record. Message history goes to a
// filesystem store under srv.LogPath when that is set, and to an in-memory
// store otherwise. The user does nothing until run is started.
func newUser(srv *Server, record *User) *user {
	userLogger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}

	var store messageStore
	switch {
	case srv.LogPath != "":
		store = newFSMessageStore(srv.LogPath, record.Username)
	default:
		store = newMemoryMessageStore()
	}

	u := &user{
		User:     *record,
		srv:      srv,
		logger:   userLogger,
		events:   make(chan event, 64),
		done:     make(chan struct{}),
		msgStore: store,
	}
	return u
}
438
439func (u *user) forEachNetwork(f func(*network)) {
440 for _, network := range u.networks {
441 f(network)
442 }
443}
444
445func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
446 for _, network := range u.networks {
[279]447 if network.conn == nil {
[101]448 continue
449 }
[279]450 f(network.conn)
[101]451 }
452}
453
454func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
455 for _, dc := range u.downstreamConns {
456 f(dc)
457 }
458}
459
460func (u *user) getNetwork(name string) *network {
461 for _, network := range u.networks {
462 if network.Addr == name {
463 return network
464 }
[201]465 if network.Name != "" && network.Name == name {
466 return network
467 }
[101]468 }
469 return nil
470}
471
[313]472func (u *user) getNetworkByID(id int64) *network {
473 for _, net := range u.networks {
474 if net.ID == id {
475 return net
476 }
477 }
478 return nil
479}
480
// run is the user's event loop. It loads the user's networks (with their
// channels and, for persistent message stores, delivery receipts) from the
// database and starts a connection goroutine for each. It then handles events
// from u.events until eventStop. On return it closes the message store and
// closes u.done.
func (u *user) run() {
	defer func() {
		if u.msgStore != nil {
			if err := u.msgStore.Close(); err != nil {
				u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
			}
		}
		close(u.done)
	}()

	networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
	if err != nil {
		u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
		return
	}

	for _, record := range networks {
		record := record
		channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
		if err != nil {
			u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
			continue
		}

		network := newNetwork(u, &record, channels)
		u.networks = append(u.networks, network)

		if u.hasPersistentMsgStore() {
			receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
			if err != nil {
				// Don't return here. Networks started in earlier iterations
				// would then block forever sending on u.events. Start this
				// network without its receipts instead.
				u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
			} else {
				for _, rcpt := range receipts {
					network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
				}
			}
		}

		go network.run()
	}

	for e := range u.events {
		switch e := e.(type) {
		case eventUpstreamConnected:
			uc := e.uc

			uc.network.conn = uc

			uc.updateAway()

			netIDStr := fmt.Sprintf("%v", uc.network.ID)
			uc.forEachDownstream(func(dc *downstreamConn) {
				dc.updateSupportedCaps()

				if !dc.caps["soju.im/bouncer-networks"] {
					sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
				}

				dc.updateNick()
				dc.updateRealname()
			})
			u.forEachDownstream(func(dc *downstreamConn) {
				if dc.caps["soju.im/bouncer-networks-notify"] {
					dc.SendMessage(&irc.Message{
						Prefix:  dc.srv.prefix(),
						Command: "BOUNCER",
						Params:  []string{"NETWORK", netIDStr, "state=connected"},
					})
				}
			})
			uc.network.lastError = nil
		case eventUpstreamDisconnected:
			u.handleUpstreamDisconnected(e.uc)
		case eventUpstreamConnectionError:
			net := e.net

			// Only tell clients about errors that differ from the last one,
			// and stay quiet once the network has been stopped.
			if !net.isStopped() && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
				net.forEachDownstream(func(dc *downstreamConn) {
					sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
				})
			}
			net.lastError = e.err
		case eventUpstreamError:
			uc := e.uc

			uc.forEachDownstream(func(dc *downstreamConn) {
				sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
			})
			uc.network.lastError = e.err
		case eventUpstreamMessage:
			msg, uc := e.msg, e.uc
			if uc.isClosed() {
				uc.logger.Printf("ignoring message on closed connection: %v", msg)
				break
			}
			if err := uc.handleMessage(msg); err != nil {
				uc.logger.Printf("failed to handle message %q: %v", msg, err)
			}
		case eventChannelDetach:
			uc, name := e.uc, e.name
			c := uc.network.channels.Value(name)
			if c == nil || c.Detached {
				continue
			}
			uc.network.detach(c)
			if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
				u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
			}
		case eventDownstreamConnected:
			dc := e.dc

			if err := dc.welcome(); err != nil {
				dc.logger.Printf("failed to handle new registered connection: %v", err)
				break
			}

			u.downstreamConns = append(u.downstreamConns, dc)

			dc.forEachNetwork(func(network *network) {
				if network.lastError != nil {
					sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
				}
			})

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})
		case eventDownstreamDisconnected:
			dc := e.dc

			for i := range u.downstreamConns {
				if u.downstreamConns[i] == dc {
					u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
					break
				}
			}

			dc.forEachNetwork(func(net *network) {
				net.storeClientDeliveryReceipts(dc.clientName)
			})

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})
		case eventDownstreamMessage:
			msg, dc := e.msg, e.dc
			if dc.isClosed() {
				dc.logger.Printf("ignoring message on closed connection: %v", msg)
				break
			}
			err := dc.handleMessage(msg)
			if ircErr, ok := err.(ircError); ok {
				ircErr.Message.Prefix = dc.srv.prefix()
				dc.SendMessage(ircErr.Message)
			} else if err != nil {
				dc.logger.Printf("failed to handle message %q: %v", msg, err)
				dc.Close()
			}
		case eventBroadcast:
			msg := e.msg
			u.forEachDownstream(func(dc *downstreamConn) {
				dc.SendMessage(msg)
			})
		case eventUserUpdate:
			// copy the user record because we'll mutate it
			record := u.User

			if e.password != nil {
				record.Password = *e.password
			}
			if e.admin != nil {
				record.Admin = *e.admin
			}

			e.done <- u.updateUser(&record)

			// If the password was updated, kill all downstream connections to
			// force them to re-authenticate with the new credentials.
			if e.password != nil {
				u.forEachDownstream(func(dc *downstreamConn) {
					dc.Close()
				})
			}
		case eventStop:
			u.forEachDownstream(func(dc *downstreamConn) {
				dc.Close()
			})
			for _, n := range u.networks {
				n.stop()

				n.delivered.ForEachClient(func(clientName string) {
					n.storeClientDeliveryReceipts(clientName)
				})
			}
			return
		default:
			panic(fmt.Sprintf("received unknown event type: %T", e))
		}
	}
}
689
// handleUpstreamDisconnected does the bookkeeping after uc went away:
//   - detaches uc from its network;
//   - ends its pending LIST commands;
//   - calls updateAutoDetach(0) on each of its channels;
//   - refreshes downstream capabilities.
//
// If the network still belongs to the user (pointer identity), it also sends
// state=disconnected to bouncer-networks-notify clients. When no connection
// error has already been reported, it sends a "disconnected" notice to
// clients without soju.im/bouncer-networks.
func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
	net := uc.network
	net.conn = nil

	uc.endPendingLISTs(true)

	for _, entry := range uc.channels.innerMap {
		entry.value.(*upstreamChannel).updateAutoDetach(0)
	}

	uc.forEachDownstream(func(dc *downstreamConn) {
		dc.updateSupportedCaps()
	})

	// If the network has been removed, don't send a state change notification
	stillOwned := false
	for _, n := range u.networks {
		if n == net {
			stillOwned = true
			break
		}
	}
	if !stillOwned {
		return
	}

	idStr := fmt.Sprintf("%v", net.ID)
	u.forEachDownstream(func(dc *downstreamConn) {
		if !dc.caps["soju.im/bouncer-networks-notify"] {
			return
		}
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: "BOUNCER",
			Params:  []string{"NETWORK", idStr, "state=disconnected"},
		})
	})

	if net.lastError != nil {
		return
	}
	uc.forEachDownstream(func(dc *downstreamConn) {
		if !dc.caps["soju.im/bouncer-networks"] {
			sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", net.GetName()))
		}
	})
}
735
736func (u *user) addNetwork(network *network) {
737 u.networks = append(u.networks, network)
738 go network.run()
739}
740
741func (u *user) removeNetwork(network *network) {
742 network.stop()
743
744 u.forEachDownstream(func(dc *downstreamConn) {
745 if dc.network != nil && dc.network == network {
746 dc.Close()
747 }
748 })
749
750 for i, net := range u.networks {
751 if net == network {
752 u.networks = append(u.networks[:i], u.networks[i+1:]...)
753 return
754 }
755 }
756
757 panic("tried to remove a non-existing network")
758}
759
// checkNetwork returns an error if another of the user's networks (one with a
// different ID) already uses record's name.
func (u *user) checkNetwork(record *Network) error {
	name := record.GetName()
	for _, other := range u.networks {
		if other.ID == record.ID {
			continue
		}
		if other.GetName() == name {
			return fmt.Errorf("a network with the name %q already exists", name)
		}
	}
	return nil
}
768
// createNetwork validates record, enforces srv.MaxUserNetworks (when it is
// non-negative), stores the network and starts connecting to it. It then
// announces the network to bouncer-networks-notify clients. record must be
// new (ID 0); otherwise it panics.
func (u *user) createNetwork(record *Network) (*network, error) {
	if record.ID != 0 {
		panic("tried creating an already-existing network")
	}

	if err := u.checkNetwork(record); err != nil {
		return nil, err
	}

	if limit := u.srv.MaxUserNetworks; limit >= 0 && len(u.networks) >= limit {
		return nil, fmt.Errorf("maximum number of networks reached")
	}

	network := newNetwork(u, record, nil)
	if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, &network.Network); err != nil {
		return nil, err
	}

	u.addNetwork(network)

	idStr := fmt.Sprintf("%v", network.ID)
	attrs := getNetworkAttrs(network)
	u.forEachDownstream(func(dc *downstreamConn) {
		if !dc.caps["soju.im/bouncer-networks-notify"] {
			return
		}
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: "BOUNCER",
			Params:  []string{"NETWORK", idStr, attrs.String()},
		})
	})

	return network, nil
}
[202]804
// updateNetwork stores record for an existing network and swaps the running
// network for a fresh one built from record (keeping the current channels),
// which reconnects to the upstream server.
//
// Downstream connections bound to the old network are moved to the new one,
// and the filesystem message store is told about renames.
// bouncer-networks-notify clients receive the new attributes. It panics if
// record is new or unknown.
func (u *user) updateNetwork(record *Network) (*network, error) {
	if record.ID == 0 {
		panic("tried updating a new network")
	}

	// If the realname is reset to the default, just wipe the per-network
	// setting
	if record.Realname == u.Realname {
		record.Realname = ""
	}

	if err := u.checkNetwork(record); err != nil {
		return nil, err
	}

	old := u.getNetworkByID(record.ID)
	if old == nil {
		panic("tried updating a non-existing network")
	}

	if err := u.srv.db.StoreNetwork(context.TODO(), u.ID, record); err != nil {
		return nil, err
	}

	// Most network changes require us to re-connect to the upstream server,
	// so build a brand new network carrying over the current channels.
	channels := make([]Channel, 0, old.channels.Len())
	for _, entry := range old.channels.innerMap {
		channels = append(channels, *entry.value.(*Channel))
	}
	updated := newNetwork(u, record, channels)

	// If we're currently connected, disconnect and perform the necessary
	// bookkeeping
	if uc := old.conn; uc != nil {
		old.stop()
		// Note: this will set old.conn to nil
		u.handleUpstreamDisconnected(uc)
	}

	// Patch downstream connections to use our fresh updated network
	u.forEachDownstream(func(dc *downstreamConn) {
		if dc.network == old {
			dc.network = updated
		}
	})

	// We need to remove the network after patching downstream connections,
	// otherwise they'll get closed
	u.removeNetwork(old)

	// The filesystem message store needs to be notified whenever the network
	// is renamed
	if fsStore, ok := u.msgStore.(*fsMessageStore); ok && updated.GetName() != old.GetName() {
		if err := fsStore.RenameNetwork(&old.Network, &updated.Network); err != nil {
			old.logger.Printf("failed to update FS message store network name to %q: %v", updated.GetName(), err)
		}
	}

	// This will re-connect to the upstream server
	u.addNetwork(updated)

	// TODO: only broadcast attributes that have changed
	idStr := fmt.Sprintf("%v", updated.ID)
	attrs := getNetworkAttrs(updated)
	u.forEachDownstream(func(dc *downstreamConn) {
		if !dc.caps["soju.im/bouncer-networks-notify"] {
			return
		}
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: "BOUNCER",
			Params:  []string{"NETWORK", idStr, attrs.String()},
		})
	})

	return updated, nil
}
885
// deleteNetwork deletes the network with ID id from the database, stops and
// removes it, and tells bouncer-networks-notify clients it is gone ("*").
// It panics if no such network exists.
func (u *user) deleteNetwork(id int64) error {
	network := u.getNetworkByID(id)
	if network == nil {
		panic("tried deleting a non-existing network")
	}

	if err := u.srv.db.DeleteNetwork(context.TODO(), network.ID); err != nil {
		return err
	}

	u.removeNetwork(network)

	idStr := fmt.Sprintf("%v", network.ID)
	u.forEachDownstream(func(dc *downstreamConn) {
		if !dc.caps["soju.im/bouncer-networks-notify"] {
			return
		}
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: "BOUNCER",
			Params:  []string{"NETWORK", idStr, "*"},
		})
	})

	return nil
}
[252]911
// updateUser stores record and adopts it as the user's state. It panics if
// record's ID differs from the user's.
//
// When the realname changed, every network without its own realname (empty
// Realname) is updated so that it reconnects with the new default. If any of
// those updates fail, the last failure is returned.
func (u *user) updateUser(record *User) error {
	if u.ID != record.ID {
		panic("ID mismatch when updating user")
	}

	realnameChanged := record.Realname != u.Realname
	if err := u.srv.db.StoreUser(context.TODO(), record); err != nil {
		return fmt.Errorf("failed to update user %q: %v", u.Username, err)
	}
	u.User = *record

	if !realnameChanged {
		return nil
	}

	// Re-connect to networks which use the default realname
	var stale []Network
	for _, net := range u.networks {
		if net.Realname == "" {
			stale = append(stale, net.Network)
		}
	}

	var lastErr error
	for i := range stale {
		if _, err := u.updateNetwork(&stale[i]); err != nil {
			lastErr = err
		}
	}
	return lastErr
}
945
[376]946func (u *user) stop() {
947 u.events <- eventStop{}
[377]948 <-u.done
[376]949}
[489]950
951func (u *user) hasPersistentMsgStore() bool {
952 if u.msgStore == nil {
953 return false
954 }
955 _, isMem := u.msgStore.(*memoryMessageStore)
956 return !isMem
957}
Note: See TracBrowser for help on using the repository browser.