source: code/trunk/user.go@ 728

Last change on this file since 728 was 724, checked in by contact, 4 years ago

Add support for post-connection-registration upstream SASL auth

Once the downstream connection has logged in with their bouncer
credentials, allow them to issue more SASL auths which will be
redirected to the upstream network. This allows downstream clients
to provide UIs to transparently log in to upstream networks.

File size: 24.0 KB
Line 
1package soju
2
3import (
4 "context"
5 "crypto/sha256"
6 "encoding/binary"
7 "encoding/hex"
8 "fmt"
9 "math/big"
10 "net"
11 "time"
12
13 "gopkg.in/irc.v3"
14)
15
16type event interface{}
17
18type eventUpstreamMessage struct {
19 msg *irc.Message
20 uc *upstreamConn
21}
22
23type eventUpstreamConnectionError struct {
24 net *network
25 err error
26}
27
28type eventUpstreamConnected struct {
29 uc *upstreamConn
30}
31
32type eventUpstreamDisconnected struct {
33 uc *upstreamConn
34}
35
36type eventUpstreamError struct {
37 uc *upstreamConn
38 err error
39}
40
41type eventDownstreamMessage struct {
42 msg *irc.Message
43 dc *downstreamConn
44}
45
46type eventDownstreamConnected struct {
47 dc *downstreamConn
48}
49
50type eventDownstreamDisconnected struct {
51 dc *downstreamConn
52}
53
54type eventChannelDetach struct {
55 uc *upstreamConn
56 name string
57}
58
59type eventBroadcast struct {
60 msg *irc.Message
61}
62
63type eventStop struct{}
64
65type eventUserUpdate struct {
66 password *string
67 admin *bool
68 done chan error
69}
70
71type deliveredClientMap map[string]string // client name -> msg ID
72
73type deliveredStore struct {
74 m deliveredCasemapMap
75}
76
77func newDeliveredStore() deliveredStore {
78 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
79}
80
81func (ds deliveredStore) HasTarget(target string) bool {
82 return ds.m.Value(target) != nil
83}
84
85func (ds deliveredStore) LoadID(target, clientName string) string {
86 clients := ds.m.Value(target)
87 if clients == nil {
88 return ""
89 }
90 return clients[clientName]
91}
92
93func (ds deliveredStore) StoreID(target, clientName, msgID string) {
94 clients := ds.m.Value(target)
95 if clients == nil {
96 clients = make(deliveredClientMap)
97 ds.m.SetValue(target, clients)
98 }
99 clients[clientName] = msgID
100}
101
102func (ds deliveredStore) ForEachTarget(f func(target string)) {
103 for _, entry := range ds.m.innerMap {
104 f(entry.originalKey)
105 }
106}
107
108func (ds deliveredStore) ForEachClient(f func(clientName string)) {
109 clients := make(map[string]struct{})
110 for _, entry := range ds.m.innerMap {
111 delivered := entry.value.(deliveredClientMap)
112 for clientName := range delivered {
113 clients[clientName] = struct{}{}
114 }
115 }
116
117 for clientName := range clients {
118 f(clientName)
119 }
120}
121
122type network struct {
123 Network
124 user *user
125 logger Logger
126 stopped chan struct{}
127
128 conn *upstreamConn
129 channels channelCasemapMap
130 delivered deliveredStore
131 lastError error
132 casemap casemapping
133}
134
135func newNetwork(user *user, record *Network, channels []Channel) *network {
136 logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
137
138 m := channelCasemapMap{newCasemapMap(0)}
139 for _, ch := range channels {
140 ch := ch
141 m.SetValue(ch.Name, &ch)
142 }
143
144 return &network{
145 Network: *record,
146 user: user,
147 logger: logger,
148 stopped: make(chan struct{}),
149 channels: m,
150 delivered: newDeliveredStore(),
151 casemap: casemapRFC1459,
152 }
153}
154
155func (net *network) forEachDownstream(f func(*downstreamConn)) {
156 net.user.forEachDownstream(func(dc *downstreamConn) {
157 if dc.network == nil && !dc.isMultiUpstream {
158 return
159 }
160 if dc.network != nil && dc.network != net {
161 return
162 }
163 f(dc)
164 })
165}
166
167func (net *network) isStopped() bool {
168 select {
169 case <-net.stopped:
170 return true
171 default:
172 return false
173 }
174}
175
176func userIdent(u *User) string {
177 // The ident is a string we will send to upstream servers in clear-text.
178 // For privacy reasons, make sure it doesn't expose any meaningful user
179 // metadata. We just use the base64-encoded hashed ID, so that people don't
180 // start relying on the string being an integer or following a pattern.
181 var b [64]byte
182 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
183 h := sha256.Sum256(b[:])
184 return hex.EncodeToString(h[:16])
185}
186
187func (net *network) run() {
188 if !net.Enabled {
189 return
190 }
191
192 var lastTry time.Time
193 for {
194 if net.isStopped() {
195 return
196 }
197
198 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
199 delay := retryConnectDelay - dur
200 net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
201 time.Sleep(delay)
202 }
203 lastTry = time.Now()
204
205 uc, err := connectToUpstream(net)
206 if err != nil {
207 net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
208 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
209 continue
210 }
211
212 if net.user.srv.Identd != nil {
213 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
214 }
215
216 net.user.srv.metrics.upstreams.Add(1)
217
218 uc.register()
219 if err := uc.runUntilRegistered(); err != nil {
220 text := err.Error()
221 if regErr, ok := err.(registrationError); ok {
222 text = string(regErr)
223 }
224 uc.logger.Printf("failed to register: %v", text)
225 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
226 uc.Close()
227 continue
228 }
229
230 // TODO: this is racy with net.stopped. If the network is stopped
231 // before the user goroutine receives eventUpstreamConnected, the
232 // connection won't be closed.
233 net.user.events <- eventUpstreamConnected{uc}
234 if err := uc.readMessages(net.user.events); err != nil {
235 uc.logger.Printf("failed to handle messages: %v", err)
236 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
237 }
238 uc.Close()
239 net.user.events <- eventUpstreamDisconnected{uc}
240
241 if net.user.srv.Identd != nil {
242 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
243 }
244
245 net.user.srv.metrics.upstreams.Add(-1)
246 }
247}
248
249func (net *network) stop() {
250 if !net.isStopped() {
251 close(net.stopped)
252 }
253
254 if net.conn != nil {
255 net.conn.Close()
256 }
257}
258
259func (net *network) detach(ch *Channel) {
260 if ch.Detached {
261 return
262 }
263
264 net.logger.Printf("detaching channel %q", ch.Name)
265
266 ch.Detached = true
267
268 if net.user.msgStore != nil {
269 nameCM := net.casemap(ch.Name)
270 lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
271 if err != nil {
272 net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
273 }
274 ch.DetachedInternalMsgID = lastID
275 }
276
277 if net.conn != nil {
278 uch := net.conn.channels.Value(ch.Name)
279 if uch != nil {
280 uch.updateAutoDetach(0)
281 }
282 }
283
284 net.forEachDownstream(func(dc *downstreamConn) {
285 dc.SendMessage(&irc.Message{
286 Prefix: dc.prefix(),
287 Command: "PART",
288 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
289 })
290 })
291}
292
293func (net *network) attach(ch *Channel) {
294 if !ch.Detached {
295 return
296 }
297
298 net.logger.Printf("attaching channel %q", ch.Name)
299
300 detachedMsgID := ch.DetachedInternalMsgID
301 ch.Detached = false
302 ch.DetachedInternalMsgID = ""
303
304 var uch *upstreamChannel
305 if net.conn != nil {
306 uch = net.conn.channels.Value(ch.Name)
307
308 net.conn.updateChannelAutoDetach(ch.Name)
309 }
310
311 net.forEachDownstream(func(dc *downstreamConn) {
312 dc.SendMessage(&irc.Message{
313 Prefix: dc.prefix(),
314 Command: "JOIN",
315 Params: []string{dc.marshalEntity(net, ch.Name)},
316 })
317
318 if uch != nil {
319 forwardChannel(dc, uch)
320 }
321
322 if detachedMsgID != "" {
323 dc.sendTargetBacklog(context.TODO(), net, ch.Name, detachedMsgID)
324 }
325 })
326}
327
328func (net *network) deleteChannel(ctx context.Context, name string) error {
329 ch := net.channels.Value(name)
330 if ch == nil {
331 return fmt.Errorf("unknown channel %q", name)
332 }
333 if net.conn != nil {
334 uch := net.conn.channels.Value(ch.Name)
335 if uch != nil {
336 uch.updateAutoDetach(0)
337 }
338 }
339
340 if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
341 return err
342 }
343 net.channels.Delete(name)
344 return nil
345}
346
347func (net *network) updateCasemapping(newCasemap casemapping) {
348 net.casemap = newCasemap
349 net.channels.SetCasemapping(newCasemap)
350 net.delivered.m.SetCasemapping(newCasemap)
351 if uc := net.conn; uc != nil {
352 uc.channels.SetCasemapping(newCasemap)
353 for _, entry := range uc.channels.innerMap {
354 uch := entry.value.(*upstreamChannel)
355 uch.Members.SetCasemapping(newCasemap)
356 }
357 uc.monitored.SetCasemapping(newCasemap)
358 }
359 net.forEachDownstream(func(dc *downstreamConn) {
360 dc.monitored.SetCasemapping(newCasemap)
361 })
362}
363
364func (net *network) storeClientDeliveryReceipts(clientName string) {
365 if !net.user.hasPersistentMsgStore() {
366 return
367 }
368
369 var receipts []DeliveryReceipt
370 net.delivered.ForEachTarget(func(target string) {
371 msgID := net.delivered.LoadID(target, clientName)
372 if msgID == "" {
373 return
374 }
375 receipts = append(receipts, DeliveryReceipt{
376 Target: target,
377 InternalMsgID: msgID,
378 })
379 })
380
381 if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
382 net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
383 }
384}
385
386func (net *network) isHighlight(msg *irc.Message) bool {
387 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
388 return false
389 }
390
391 text := msg.Params[1]
392
393 nick := net.Nick
394 if net.conn != nil {
395 nick = net.conn.nick
396 }
397
398 // TODO: use case-mapping aware comparison here
399 return msg.Prefix.Name != nick && isHighlight(text, nick)
400}
401
402func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
403 highlight := net.isHighlight(msg)
404 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
405}
406
407func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
408 // User may have e.g. EXTERNAL mechanism configured. We do not want to
409 // automatically erase the key pair or any other credentials.
410 if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
411 return
412 }
413
414 net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
415 net.SASL.Mechanism = "PLAIN"
416 net.SASL.Plain.Username = username
417 net.SASL.Plain.Password = password
418 if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
419 net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
420 }
421}
422
423type user struct {
424 User
425 srv *Server
426 logger Logger
427
428 events chan event
429 done chan struct{}
430
431 networks []*network
432 downstreamConns []*downstreamConn
433 msgStore messageStore
434}
435
436func newUser(srv *Server, record *User) *user {
437 logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
438
439 var msgStore messageStore
440 if logPath := srv.Config().LogPath; logPath != "" {
441 msgStore = newFSMessageStore(logPath, record.Username)
442 } else {
443 msgStore = newMemoryMessageStore()
444 }
445
446 return &user{
447 User: *record,
448 srv: srv,
449 logger: logger,
450 events: make(chan event, 64),
451 done: make(chan struct{}),
452 msgStore: msgStore,
453 }
454}
455
456func (u *user) forEachNetwork(f func(*network)) {
457 for _, network := range u.networks {
458 f(network)
459 }
460}
461
462func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
463 for _, network := range u.networks {
464 if network.conn == nil {
465 continue
466 }
467 f(network.conn)
468 }
469}
470
471func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
472 for _, dc := range u.downstreamConns {
473 f(dc)
474 }
475}
476
477func (u *user) getNetwork(name string) *network {
478 for _, network := range u.networks {
479 if network.Addr == name {
480 return network
481 }
482 if network.Name != "" && network.Name == name {
483 return network
484 }
485 }
486 return nil
487}
488
489func (u *user) getNetworkByID(id int64) *network {
490 for _, net := range u.networks {
491 if net.ID == id {
492 return net
493 }
494 }
495 return nil
496}
497
498func (u *user) run() {
499 defer func() {
500 if u.msgStore != nil {
501 if err := u.msgStore.Close(); err != nil {
502 u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
503 }
504 }
505 close(u.done)
506 }()
507
508 networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
509 if err != nil {
510 u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
511 return
512 }
513
514 for _, record := range networks {
515 record := record
516 channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
517 if err != nil {
518 u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
519 continue
520 }
521
522 network := newNetwork(u, &record, channels)
523 u.networks = append(u.networks, network)
524
525 if u.hasPersistentMsgStore() {
526 receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
527 if err != nil {
528 u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
529 return
530 }
531
532 for _, rcpt := range receipts {
533 network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
534 }
535 }
536
537 go network.run()
538 }
539
540 for e := range u.events {
541 switch e := e.(type) {
542 case eventUpstreamConnected:
543 uc := e.uc
544
545 uc.network.conn = uc
546
547 uc.updateAway()
548 uc.updateMonitor()
549
550 netIDStr := fmt.Sprintf("%v", uc.network.ID)
551 uc.forEachDownstream(func(dc *downstreamConn) {
552 dc.updateSupportedCaps()
553
554 if !dc.caps["soju.im/bouncer-networks"] {
555 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
556 }
557
558 dc.updateNick()
559 dc.updateRealname()
560 dc.updateAccount()
561 })
562 u.forEachDownstream(func(dc *downstreamConn) {
563 if dc.caps["soju.im/bouncer-networks-notify"] {
564 dc.SendMessage(&irc.Message{
565 Prefix: dc.srv.prefix(),
566 Command: "BOUNCER",
567 Params: []string{"NETWORK", netIDStr, "state=connected"},
568 })
569 }
570 })
571 uc.network.lastError = nil
572 case eventUpstreamDisconnected:
573 u.handleUpstreamDisconnected(e.uc)
574 case eventUpstreamConnectionError:
575 net := e.net
576
577 stopped := false
578 select {
579 case <-net.stopped:
580 stopped = true
581 default:
582 }
583
584 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
585 net.forEachDownstream(func(dc *downstreamConn) {
586 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
587 })
588 }
589 net.lastError = e.err
590 case eventUpstreamError:
591 uc := e.uc
592
593 uc.forEachDownstream(func(dc *downstreamConn) {
594 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
595 })
596 uc.network.lastError = e.err
597 case eventUpstreamMessage:
598 msg, uc := e.msg, e.uc
599 if uc.isClosed() {
600 uc.logger.Printf("ignoring message on closed connection: %v", msg)
601 break
602 }
603 if err := uc.handleMessage(msg); err != nil {
604 uc.logger.Printf("failed to handle message %q: %v", msg, err)
605 }
606 case eventChannelDetach:
607 uc, name := e.uc, e.name
608 c := uc.network.channels.Value(name)
609 if c == nil || c.Detached {
610 continue
611 }
612 uc.network.detach(c)
613 if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
614 u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
615 }
616 case eventDownstreamConnected:
617 dc := e.dc
618
619 if dc.network != nil {
620 dc.monitored.SetCasemapping(dc.network.casemap)
621 }
622
623 if err := dc.welcome(context.TODO()); err != nil {
624 dc.logger.Printf("failed to handle new registered connection: %v", err)
625 break
626 }
627
628 u.downstreamConns = append(u.downstreamConns, dc)
629
630 dc.forEachNetwork(func(network *network) {
631 if network.lastError != nil {
632 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
633 }
634 })
635
636 u.forEachUpstream(func(uc *upstreamConn) {
637 uc.updateAway()
638 })
639 case eventDownstreamDisconnected:
640 dc := e.dc
641
642 for i := range u.downstreamConns {
643 if u.downstreamConns[i] == dc {
644 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
645 break
646 }
647 }
648
649 dc.forEachNetwork(func(net *network) {
650 net.storeClientDeliveryReceipts(dc.clientName)
651 })
652
653 u.forEachUpstream(func(uc *upstreamConn) {
654 uc.updateAway()
655 uc.updateMonitor()
656 })
657 case eventDownstreamMessage:
658 msg, dc := e.msg, e.dc
659 if dc.isClosed() {
660 dc.logger.Printf("ignoring message on closed connection: %v", msg)
661 break
662 }
663 err := dc.handleMessage(context.TODO(), msg)
664 if ircErr, ok := err.(ircError); ok {
665 ircErr.Message.Prefix = dc.srv.prefix()
666 dc.SendMessage(ircErr.Message)
667 } else if err != nil {
668 dc.logger.Printf("failed to handle message %q: %v", msg, err)
669 dc.Close()
670 }
671 case eventBroadcast:
672 msg := e.msg
673 u.forEachDownstream(func(dc *downstreamConn) {
674 dc.SendMessage(msg)
675 })
676 case eventUserUpdate:
677 // copy the user record because we'll mutate it
678 record := u.User
679
680 if e.password != nil {
681 record.Password = *e.password
682 }
683 if e.admin != nil {
684 record.Admin = *e.admin
685 }
686
687 e.done <- u.updateUser(context.TODO(), &record)
688
689 // If the password was updated, kill all downstream connections to
690 // force them to re-authenticate with the new credentials.
691 if e.password != nil {
692 u.forEachDownstream(func(dc *downstreamConn) {
693 dc.Close()
694 })
695 }
696 case eventStop:
697 u.forEachDownstream(func(dc *downstreamConn) {
698 dc.Close()
699 })
700 for _, n := range u.networks {
701 n.stop()
702
703 n.delivered.ForEachClient(func(clientName string) {
704 n.storeClientDeliveryReceipts(clientName)
705 })
706 }
707 return
708 default:
709 panic(fmt.Sprintf("received unknown event type: %T", e))
710 }
711 }
712}
713
714func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
715 uc.network.conn = nil
716
717 uc.endPendingCommands()
718
719 for _, entry := range uc.channels.innerMap {
720 uch := entry.value.(*upstreamChannel)
721 uch.updateAutoDetach(0)
722 }
723
724 netIDStr := fmt.Sprintf("%v", uc.network.ID)
725 uc.forEachDownstream(func(dc *downstreamConn) {
726 dc.updateSupportedCaps()
727 })
728
729 // If the network has been removed, don't send a state change notification
730 found := false
731 for _, net := range u.networks {
732 if net == uc.network {
733 found = true
734 break
735 }
736 }
737 if !found {
738 return
739 }
740
741 u.forEachDownstream(func(dc *downstreamConn) {
742 if dc.caps["soju.im/bouncer-networks-notify"] {
743 dc.SendMessage(&irc.Message{
744 Prefix: dc.srv.prefix(),
745 Command: "BOUNCER",
746 Params: []string{"NETWORK", netIDStr, "state=disconnected"},
747 })
748 }
749 })
750
751 if uc.network.lastError == nil {
752 uc.forEachDownstream(func(dc *downstreamConn) {
753 if !dc.caps["soju.im/bouncer-networks"] {
754 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
755 }
756 })
757 }
758}
759
760func (u *user) addNetwork(network *network) {
761 u.networks = append(u.networks, network)
762 go network.run()
763}
764
765func (u *user) removeNetwork(network *network) {
766 network.stop()
767
768 u.forEachDownstream(func(dc *downstreamConn) {
769 if dc.network != nil && dc.network == network {
770 dc.Close()
771 }
772 })
773
774 for i, net := range u.networks {
775 if net == network {
776 u.networks = append(u.networks[:i], u.networks[i+1:]...)
777 return
778 }
779 }
780
781 panic("tried to remove a non-existing network")
782}
783
784func (u *user) checkNetwork(record *Network) error {
785 for _, net := range u.networks {
786 if net.GetName() == record.GetName() && net.ID != record.ID {
787 return fmt.Errorf("a network with the name %q already exists", record.GetName())
788 }
789 }
790 return nil
791}
792
793func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
794 if record.ID != 0 {
795 panic("tried creating an already-existing network")
796 }
797
798 if err := u.checkNetwork(record); err != nil {
799 return nil, err
800 }
801
802 if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
803 return nil, fmt.Errorf("maximum number of networks reached")
804 }
805
806 network := newNetwork(u, record, nil)
807 err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
808 if err != nil {
809 return nil, err
810 }
811
812 u.addNetwork(network)
813
814 idStr := fmt.Sprintf("%v", network.ID)
815 attrs := getNetworkAttrs(network)
816 u.forEachDownstream(func(dc *downstreamConn) {
817 if dc.caps["soju.im/bouncer-networks-notify"] {
818 dc.SendMessage(&irc.Message{
819 Prefix: dc.srv.prefix(),
820 Command: "BOUNCER",
821 Params: []string{"NETWORK", idStr, attrs.String()},
822 })
823 }
824 })
825
826 return network, nil
827}
828
829func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
830 if record.ID == 0 {
831 panic("tried updating a new network")
832 }
833
834 // If the realname is reset to the default, just wipe the per-network
835 // setting
836 if record.Realname == u.Realname {
837 record.Realname = ""
838 }
839
840 if err := u.checkNetwork(record); err != nil {
841 return nil, err
842 }
843
844 network := u.getNetworkByID(record.ID)
845 if network == nil {
846 panic("tried updating a non-existing network")
847 }
848
849 if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
850 return nil, err
851 }
852
853 // Most network changes require us to re-connect to the upstream server
854
855 channels := make([]Channel, 0, network.channels.Len())
856 for _, entry := range network.channels.innerMap {
857 ch := entry.value.(*Channel)
858 channels = append(channels, *ch)
859 }
860
861 updatedNetwork := newNetwork(u, record, channels)
862
863 // If we're currently connected, disconnect and perform the necessary
864 // bookkeeping
865 if network.conn != nil {
866 network.stop()
867 // Note: this will set network.conn to nil
868 u.handleUpstreamDisconnected(network.conn)
869 }
870
871 // Patch downstream connections to use our fresh updated network
872 u.forEachDownstream(func(dc *downstreamConn) {
873 if dc.network != nil && dc.network == network {
874 dc.network = updatedNetwork
875 }
876 })
877
878 // We need to remove the network after patching downstream connections,
879 // otherwise they'll get closed
880 u.removeNetwork(network)
881
882 // The filesystem message store needs to be notified whenever the network
883 // is renamed
884 fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
885 if isFS && updatedNetwork.GetName() != network.GetName() {
886 if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
887 network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
888 }
889 }
890
891 // This will re-connect to the upstream server
892 u.addNetwork(updatedNetwork)
893
894 // TODO: only broadcast attributes that have changed
895 idStr := fmt.Sprintf("%v", updatedNetwork.ID)
896 attrs := getNetworkAttrs(updatedNetwork)
897 u.forEachDownstream(func(dc *downstreamConn) {
898 if dc.caps["soju.im/bouncer-networks-notify"] {
899 dc.SendMessage(&irc.Message{
900 Prefix: dc.srv.prefix(),
901 Command: "BOUNCER",
902 Params: []string{"NETWORK", idStr, attrs.String()},
903 })
904 }
905 })
906
907 return updatedNetwork, nil
908}
909
910func (u *user) deleteNetwork(ctx context.Context, id int64) error {
911 network := u.getNetworkByID(id)
912 if network == nil {
913 panic("tried deleting a non-existing network")
914 }
915
916 if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
917 return err
918 }
919
920 u.removeNetwork(network)
921
922 idStr := fmt.Sprintf("%v", network.ID)
923 u.forEachDownstream(func(dc *downstreamConn) {
924 if dc.caps["soju.im/bouncer-networks-notify"] {
925 dc.SendMessage(&irc.Message{
926 Prefix: dc.srv.prefix(),
927 Command: "BOUNCER",
928 Params: []string{"NETWORK", idStr, "*"},
929 })
930 }
931 })
932
933 return nil
934}
935
936func (u *user) updateUser(ctx context.Context, record *User) error {
937 if u.ID != record.ID {
938 panic("ID mismatch when updating user")
939 }
940
941 realnameUpdated := u.Realname != record.Realname
942 if err := u.srv.db.StoreUser(ctx, record); err != nil {
943 return fmt.Errorf("failed to update user %q: %v", u.Username, err)
944 }
945 u.User = *record
946
947 if realnameUpdated {
948 // Re-connect to networks which use the default realname
949 var needUpdate []Network
950 u.forEachNetwork(func(net *network) {
951 if net.Realname == "" {
952 needUpdate = append(needUpdate, net.Network)
953 }
954 })
955
956 var netErr error
957 for _, net := range needUpdate {
958 if _, err := u.updateNetwork(ctx, &net); err != nil {
959 netErr = err
960 }
961 }
962 if netErr != nil {
963 return netErr
964 }
965 }
966
967 return nil
968}
969
970func (u *user) stop() {
971 u.events <- eventStop{}
972 <-u.done
973}
974
975func (u *user) hasPersistentMsgStore() bool {
976 if u.msgStore == nil {
977 return false
978 }
979 _, isMem := u.msgStore.(*memoryMessageStore)
980 return !isMem
981}
982
983// localAddrForHost returns the local address to use when connecting to host.
984// A nil address is returned when the OS should automatically pick one.
985func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) {
986 upstreamUserIPs := u.srv.Config().UpstreamUserIPs
987 if len(upstreamUserIPs) == 0 {
988 return nil, nil
989 }
990
991 ips, err := net.LookupIP(host)
992 if err != nil {
993 return nil, err
994 }
995
996 wantIPv6 := false
997 for _, ip := range ips {
998 if ip.To4() == nil {
999 wantIPv6 = true
1000 break
1001 }
1002 }
1003
1004 var ipNet *net.IPNet
1005 for _, in := range upstreamUserIPs {
1006 if wantIPv6 == (in.IP.To4() == nil) {
1007 ipNet = in
1008 break
1009 }
1010 }
1011 if ipNet == nil {
1012 return nil, nil
1013 }
1014
1015 var ipInt big.Int
1016 ipInt.SetBytes(ipNet.IP)
1017 ipInt.Add(&ipInt, big.NewInt(u.ID+1))
1018 ip := net.IP(ipInt.Bytes())
1019 if !ipNet.Contains(ip) {
1020 return nil, fmt.Errorf("IP network %v too small", ipNet)
1021 }
1022
1023 return &net.TCPAddr{IP: ip}, nil
1024}
Note: See TracBrowser for help on using the repository browser.