source: code/trunk/user.go@ 729

Last change on this file since 729 was 724, checked in by contact, 4 years ago

Add support for post-connection-registration upstream SASL auth

Once the downstream connection has logged in with their bouncer
credentials, allow them to issue more SASL auths which will be
redirected to the upstream network. This allows downstream clients
to provide UIs to transparently log in to upstream networks.

File size: 24.0 KB
RevLine 
[101]1package soju
2
3import (
[652]4 "context"
[385]5 "crypto/sha256"
6 "encoding/binary"
[395]7 "encoding/hex"
[218]8 "fmt"
[705]9 "math/big"
10 "net"
[101]11 "time"
[103]12
13 "gopkg.in/irc.v3"
[101]14)
15
[165]16type event interface{}
17
18type eventUpstreamMessage struct {
[103]19 msg *irc.Message
20 uc *upstreamConn
21}
22
[218]23type eventUpstreamConnectionError struct {
24 net *network
25 err error
26}
27
[196]28type eventUpstreamConnected struct {
29 uc *upstreamConn
30}
31
[179]32type eventUpstreamDisconnected struct {
33 uc *upstreamConn
34}
35
[218]36type eventUpstreamError struct {
37 uc *upstreamConn
38 err error
39}
40
[165]41type eventDownstreamMessage struct {
[103]42 msg *irc.Message
43 dc *downstreamConn
44}
45
[166]46type eventDownstreamConnected struct {
47 dc *downstreamConn
48}
49
[167]50type eventDownstreamDisconnected struct {
51 dc *downstreamConn
52}
53
[435]54type eventChannelDetach struct {
55 uc *upstreamConn
56 name string
57}
58
[563]59type eventBroadcast struct {
60 msg *irc.Message
61}
62
[376]63type eventStop struct{}
64
[625]65type eventUserUpdate struct {
66 password *string
67 admin *bool
68 done chan error
69}
70
[480]71type deliveredClientMap map[string]string // client name -> msg ID
72
[485]73type deliveredStore struct {
74 m deliveredCasemapMap
75}
76
77func newDeliveredStore() deliveredStore {
78 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
79}
80
81func (ds deliveredStore) HasTarget(target string) bool {
82 return ds.m.Value(target) != nil
83}
84
85func (ds deliveredStore) LoadID(target, clientName string) string {
86 clients := ds.m.Value(target)
87 if clients == nil {
88 return ""
89 }
90 return clients[clientName]
91}
92
93func (ds deliveredStore) StoreID(target, clientName, msgID string) {
94 clients := ds.m.Value(target)
95 if clients == nil {
96 clients = make(deliveredClientMap)
97 ds.m.SetValue(target, clients)
98 }
99 clients[clientName] = msgID
100}
101
102func (ds deliveredStore) ForEachTarget(f func(target string)) {
103 for _, entry := range ds.m.innerMap {
104 f(entry.originalKey)
105 }
106}
107
[489]108func (ds deliveredStore) ForEachClient(f func(clientName string)) {
109 clients := make(map[string]struct{})
110 for _, entry := range ds.m.innerMap {
111 delivered := entry.value.(deliveredClientMap)
112 for clientName := range delivered {
113 clients[clientName] = struct{}{}
114 }
115 }
116
117 for clientName := range clients {
118 f(clientName)
119 }
120}
121
[101]122type network struct {
123 Network
[202]124 user *user
[501]125 logger Logger
[202]126 stopped chan struct{}
[131]127
[482]128 conn *upstreamConn
129 channels channelCasemapMap
[485]130 delivered deliveredStore
[482]131 lastError error
132 casemap casemapping
[101]133}
134
[267]135func newNetwork(user *user, record *Network, channels []Channel) *network {
[501]136 logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
137
[478]138 m := channelCasemapMap{newCasemapMap(0)}
[267]139 for _, ch := range channels {
[283]140 ch := ch
[478]141 m.SetValue(ch.Name, &ch)
[267]142 }
143
[101]144 return &network{
[482]145 Network: *record,
146 user: user,
[501]147 logger: logger,
[482]148 stopped: make(chan struct{}),
149 channels: m,
[485]150 delivered: newDeliveredStore(),
[482]151 casemap: casemapRFC1459,
[101]152 }
153}
154
[218]155func (net *network) forEachDownstream(f func(*downstreamConn)) {
156 net.user.forEachDownstream(func(dc *downstreamConn) {
[693]157 if dc.network == nil && !dc.isMultiUpstream {
[532]158 return
159 }
[218]160 if dc.network != nil && dc.network != net {
161 return
162 }
163 f(dc)
164 })
165}
166
[311]167func (net *network) isStopped() bool {
168 select {
169 case <-net.stopped:
170 return true
171 default:
172 return false
173 }
174}
175
[385]176func userIdent(u *User) string {
177 // The ident is a string we will send to upstream servers in clear-text.
178 // For privacy reasons, make sure it doesn't expose any meaningful user
179 // metadata. We just use the base64-encoded hashed ID, so that people don't
180 // start relying on the string being an integer or following a pattern.
181 var b [64]byte
182 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
183 h := sha256.Sum256(b[:])
[395]184 return hex.EncodeToString(h[:16])
[385]185}
186
[101]187func (net *network) run() {
[542]188 if !net.Enabled {
189 return
190 }
191
[101]192 var lastTry time.Time
193 for {
[311]194 if net.isStopped() {
[202]195 return
196 }
197
[398]198 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
199 delay := retryConnectDelay - dur
[501]200 net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
[101]201 time.Sleep(delay)
202 }
203 lastTry = time.Now()
204
205 uc, err := connectToUpstream(net)
206 if err != nil {
[501]207 net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]208 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[101]209 continue
210 }
211
[385]212 if net.user.srv.Identd != nil {
213 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
214 }
215
[710]216 net.user.srv.metrics.upstreams.Add(1)
217
[101]218 uc.register()
[197]219 if err := uc.runUntilRegistered(); err != nil {
[399]220 text := err.Error()
221 if regErr, ok := err.(registrationError); ok {
222 text = string(regErr)
223 }
224 uc.logger.Printf("failed to register: %v", text)
225 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
[197]226 uc.Close()
227 continue
228 }
[101]229
[311]230 // TODO: this is racy with net.stopped. If the network is stopped
231 // before the user goroutine receives eventUpstreamConnected, the
232 // connection won't be closed.
[196]233 net.user.events <- eventUpstreamConnected{uc}
[165]234 if err := uc.readMessages(net.user.events); err != nil {
[101]235 uc.logger.Printf("failed to handle messages: %v", err)
[218]236 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]237 }
238 uc.Close()
[179]239 net.user.events <- eventUpstreamDisconnected{uc}
[385]240
241 if net.user.srv.Identd != nil {
242 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
243 }
[710]244
245 net.user.srv.metrics.upstreams.Add(-1)
[101]246 }
247}
248
[309]249func (net *network) stop() {
[311]250 if !net.isStopped() {
[202]251 close(net.stopped)
252 }
253
[279]254 if net.conn != nil {
255 net.conn.Close()
[202]256 }
257}
258
[435]259func (net *network) detach(ch *Channel) {
260 if ch.Detached {
261 return
[267]262 }
[497]263
[501]264 net.logger.Printf("detaching channel %q", ch.Name)
[435]265
[497]266 ch.Detached = true
267
268 if net.user.msgStore != nil {
269 nameCM := net.casemap(ch.Name)
[666]270 lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
[497]271 if err != nil {
[501]272 net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
[497]273 }
274 ch.DetachedInternalMsgID = lastID
275 }
276
[435]277 if net.conn != nil {
[478]278 uch := net.conn.channels.Value(ch.Name)
279 if uch != nil {
[435]280 uch.updateAutoDetach(0)
281 }
[222]282 }
[284]283
[435]284 net.forEachDownstream(func(dc *downstreamConn) {
285 dc.SendMessage(&irc.Message{
286 Prefix: dc.prefix(),
287 Command: "PART",
288 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
289 })
290 })
291}
[284]292
[435]293func (net *network) attach(ch *Channel) {
294 if !ch.Detached {
295 return
296 }
[497]297
[501]298 net.logger.Printf("attaching channel %q", ch.Name)
[284]299
[497]300 detachedMsgID := ch.DetachedInternalMsgID
301 ch.Detached = false
302 ch.DetachedInternalMsgID = ""
303
[435]304 var uch *upstreamChannel
305 if net.conn != nil {
[478]306 uch = net.conn.channels.Value(ch.Name)
[284]307
[435]308 net.conn.updateChannelAutoDetach(ch.Name)
309 }
[284]310
[435]311 net.forEachDownstream(func(dc *downstreamConn) {
312 dc.SendMessage(&irc.Message{
313 Prefix: dc.prefix(),
314 Command: "JOIN",
315 Params: []string{dc.marshalEntity(net, ch.Name)},
316 })
317
318 if uch != nil {
319 forwardChannel(dc, uch)
[284]320 }
321
[497]322 if detachedMsgID != "" {
[701]323 dc.sendTargetBacklog(context.TODO(), net, ch.Name, detachedMsgID)
[495]324 }
[435]325 })
[222]326}
327
[676]328func (net *network) deleteChannel(ctx context.Context, name string) error {
[478]329 ch := net.channels.Value(name)
330 if ch == nil {
[416]331 return fmt.Errorf("unknown channel %q", name)
332 }
[435]333 if net.conn != nil {
[478]334 uch := net.conn.channels.Value(ch.Name)
335 if uch != nil {
[435]336 uch.updateAutoDetach(0)
337 }
338 }
339
[676]340 if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
[267]341 return err
342 }
[478]343 net.channels.Delete(name)
[267]344 return nil
[222]345}
346
[478]347func (net *network) updateCasemapping(newCasemap casemapping) {
348 net.casemap = newCasemap
349 net.channels.SetCasemapping(newCasemap)
[485]350 net.delivered.m.SetCasemapping(newCasemap)
[684]351 if uc := net.conn; uc != nil {
352 uc.channels.SetCasemapping(newCasemap)
353 for _, entry := range uc.channels.innerMap {
[478]354 uch := entry.value.(*upstreamChannel)
355 uch.Members.SetCasemapping(newCasemap)
356 }
[684]357 uc.monitored.SetCasemapping(newCasemap)
[478]358 }
[684]359 net.forEachDownstream(func(dc *downstreamConn) {
360 dc.monitored.SetCasemapping(newCasemap)
361 })
[478]362}
363
[489]364func (net *network) storeClientDeliveryReceipts(clientName string) {
365 if !net.user.hasPersistentMsgStore() {
366 return
367 }
368
369 var receipts []DeliveryReceipt
370 net.delivered.ForEachTarget(func(target string) {
371 msgID := net.delivered.LoadID(target, clientName)
372 if msgID == "" {
373 return
374 }
375 receipts = append(receipts, DeliveryReceipt{
376 Target: target,
377 InternalMsgID: msgID,
378 })
379 })
380
[652]381 if err := net.user.srv.db.StoreClientDeliveryReceipts(context.TODO(), net.ID, clientName, receipts); err != nil {
[501]382 net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
[489]383 }
384}
385
[499]386func (net *network) isHighlight(msg *irc.Message) bool {
387 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
388 return false
389 }
390
391 text := msg.Params[1]
392
393 nick := net.Nick
394 if net.conn != nil {
395 nick = net.conn.nick
396 }
397
398 // TODO: use case-mapping aware comparison here
399 return msg.Prefix.Name != nick && isHighlight(text, nick)
400}
401
402func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
403 highlight := net.isHighlight(msg)
404 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
405}
406
[724]407func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
408 // User may have e.g. EXTERNAL mechanism configured. We do not want to
409 // automatically erase the key pair or any other credentials.
410 if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
411 return
412 }
413
414 net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
415 net.SASL.Mechanism = "PLAIN"
416 net.SASL.Plain.Username = username
417 net.SASL.Plain.Password = password
418 if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
419 net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
420 }
421}
422
[101]423type user struct {
424 User
[493]425 srv *Server
426 logger Logger
[101]427
[165]428 events chan event
[377]429 done chan struct{}
[103]430
[101]431 networks []*network
432 downstreamConns []*downstreamConn
[439]433 msgStore messageStore
[101]434}
435
436func newUser(srv *Server, record *User) *user {
[493]437 logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
438
[439]439 var msgStore messageStore
[691]440 if logPath := srv.Config().LogPath; logPath != "" {
441 msgStore = newFSMessageStore(logPath, record.Username)
[442]442 } else {
443 msgStore = newMemoryMessageStore()
[423]444 }
445
[101]446 return &user{
[489]447 User: *record,
448 srv: srv,
[493]449 logger: logger,
[489]450 events: make(chan event, 64),
451 done: make(chan struct{}),
452 msgStore: msgStore,
[101]453 }
454}
455
456func (u *user) forEachNetwork(f func(*network)) {
457 for _, network := range u.networks {
458 f(network)
459 }
460}
461
462func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
463 for _, network := range u.networks {
[279]464 if network.conn == nil {
[101]465 continue
466 }
[279]467 f(network.conn)
[101]468 }
469}
470
471func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
472 for _, dc := range u.downstreamConns {
473 f(dc)
474 }
475}
476
477func (u *user) getNetwork(name string) *network {
478 for _, network := range u.networks {
479 if network.Addr == name {
480 return network
481 }
[201]482 if network.Name != "" && network.Name == name {
483 return network
484 }
[101]485 }
486 return nil
487}
488
[313]489func (u *user) getNetworkByID(id int64) *network {
490 for _, net := range u.networks {
491 if net.ID == id {
492 return net
493 }
494 }
495 return nil
496}
497
[101]498func (u *user) run() {
[423]499 defer func() {
500 if u.msgStore != nil {
501 if err := u.msgStore.Close(); err != nil {
[493]502 u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
[423]503 }
504 }
505 close(u.done)
506 }()
[377]507
[652]508 networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
[101]509 if err != nil {
[493]510 u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
[101]511 return
512 }
513
514 for _, record := range networks {
[283]515 record := record
[652]516 channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
[267]517 if err != nil {
[493]518 u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
[359]519 continue
[267]520 }
521
522 network := newNetwork(u, &record, channels)
[101]523 u.networks = append(u.networks, network)
524
[489]525 if u.hasPersistentMsgStore() {
[652]526 receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
[489]527 if err != nil {
[493]528 u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
[489]529 return
530 }
531
532 for _, rcpt := range receipts {
533 network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
534 }
535 }
536
[101]537 go network.run()
538 }
[103]539
[165]540 for e := range u.events {
541 switch e := e.(type) {
[196]542 case eventUpstreamConnected:
[198]543 uc := e.uc
[199]544
545 uc.network.conn = uc
546
[198]547 uc.updateAway()
[684]548 uc.updateMonitor()
[218]549
[532]550 netIDStr := fmt.Sprintf("%v", uc.network.ID)
[218]551 uc.forEachDownstream(func(dc *downstreamConn) {
[276]552 dc.updateSupportedCaps()
[296]553
[543]554 if !dc.caps["soju.im/bouncer-networks"] {
555 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
556 }
557
558 dc.updateNick()
559 dc.updateRealname()
[722]560 dc.updateAccount()
[543]561 })
562 u.forEachDownstream(func(dc *downstreamConn) {
[535]563 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]564 dc.SendMessage(&irc.Message{
565 Prefix: dc.srv.prefix(),
566 Command: "BOUNCER",
[544]567 Params: []string{"NETWORK", netIDStr, "state=connected"},
[532]568 })
569 }
[218]570 })
571 uc.network.lastError = nil
[179]572 case eventUpstreamDisconnected:
[313]573 u.handleUpstreamDisconnected(e.uc)
574 case eventUpstreamConnectionError:
575 net := e.net
[199]576
[313]577 stopped := false
578 select {
579 case <-net.stopped:
580 stopped = true
581 default:
[179]582 }
[199]583
[313]584 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
[218]585 net.forEachDownstream(func(dc *downstreamConn) {
[223]586 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
[218]587 })
588 }
589 net.lastError = e.err
590 case eventUpstreamError:
591 uc := e.uc
592
593 uc.forEachDownstream(func(dc *downstreamConn) {
[223]594 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
[218]595 })
596 uc.network.lastError = e.err
[165]597 case eventUpstreamMessage:
598 msg, uc := e.msg, e.uc
[175]599 if uc.isClosed() {
[133]600 uc.logger.Printf("ignoring message on closed connection: %v", msg)
601 break
602 }
[103]603 if err := uc.handleMessage(msg); err != nil {
604 uc.logger.Printf("failed to handle message %q: %v", msg, err)
605 }
[435]606 case eventChannelDetach:
607 uc, name := e.uc, e.name
[478]608 c := uc.network.channels.Value(name)
609 if c == nil || c.Detached {
[435]610 continue
611 }
612 uc.network.detach(c)
[652]613 if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
[493]614 u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
[435]615 }
[166]616 case eventDownstreamConnected:
617 dc := e.dc
[168]618
[684]619 if dc.network != nil {
620 dc.monitored.SetCasemapping(dc.network.casemap)
621 }
622
[701]623 if err := dc.welcome(context.TODO()); err != nil {
[168]624 dc.logger.Printf("failed to handle new registered connection: %v", err)
625 break
626 }
627
[166]628 u.downstreamConns = append(u.downstreamConns, dc)
[198]629
[467]630 dc.forEachNetwork(func(network *network) {
631 if network.lastError != nil {
632 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
633 }
634 })
635
[198]636 u.forEachUpstream(func(uc *upstreamConn) {
637 uc.updateAway()
638 })
[167]639 case eventDownstreamDisconnected:
640 dc := e.dc
[204]641
[167]642 for i := range u.downstreamConns {
643 if u.downstreamConns[i] == dc {
644 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
645 break
646 }
647 }
[198]648
[489]649 dc.forEachNetwork(func(net *network) {
650 net.storeClientDeliveryReceipts(dc.clientName)
651 })
652
[198]653 u.forEachUpstream(func(uc *upstreamConn) {
654 uc.updateAway()
[684]655 uc.updateMonitor()
[198]656 })
[165]657 case eventDownstreamMessage:
658 msg, dc := e.msg, e.dc
[133]659 if dc.isClosed() {
660 dc.logger.Printf("ignoring message on closed connection: %v", msg)
661 break
662 }
[704]663 err := dc.handleMessage(context.TODO(), msg)
[103]664 if ircErr, ok := err.(ircError); ok {
665 ircErr.Message.Prefix = dc.srv.prefix()
666 dc.SendMessage(ircErr.Message)
667 } else if err != nil {
668 dc.logger.Printf("failed to handle message %q: %v", msg, err)
669 dc.Close()
670 }
[563]671 case eventBroadcast:
672 msg := e.msg
673 u.forEachDownstream(func(dc *downstreamConn) {
674 dc.SendMessage(msg)
675 })
[625]676 case eventUserUpdate:
677 // copy the user record because we'll mutate it
678 record := u.User
679
680 if e.password != nil {
681 record.Password = *e.password
682 }
683 if e.admin != nil {
684 record.Admin = *e.admin
685 }
686
[676]687 e.done <- u.updateUser(context.TODO(), &record)
[625]688
689 // If the password was updated, kill all downstream connections to
690 // force them to re-authenticate with the new credentials.
691 if e.password != nil {
692 u.forEachDownstream(func(dc *downstreamConn) {
693 dc.Close()
694 })
695 }
[376]696 case eventStop:
697 u.forEachDownstream(func(dc *downstreamConn) {
698 dc.Close()
699 })
700 for _, n := range u.networks {
701 n.stop()
[489]702
703 n.delivered.ForEachClient(func(clientName string) {
704 n.storeClientDeliveryReceipts(clientName)
705 })
[376]706 }
707 return
[165]708 default:
[494]709 panic(fmt.Sprintf("received unknown event type: %T", e))
[103]710 }
711 }
[101]712}
713
[313]714func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
715 uc.network.conn = nil
716
[682]717 uc.endPendingCommands()
[313]718
[478]719 for _, entry := range uc.channels.innerMap {
720 uch := entry.value.(*upstreamChannel)
[435]721 uch.updateAutoDetach(0)
722 }
723
[532]724 netIDStr := fmt.Sprintf("%v", uc.network.ID)
[313]725 uc.forEachDownstream(func(dc *downstreamConn) {
726 dc.updateSupportedCaps()
[543]727 })
[583]728
729 // If the network has been removed, don't send a state change notification
730 found := false
731 for _, net := range u.networks {
732 if net == uc.network {
733 found = true
734 break
735 }
736 }
737 if !found {
738 return
739 }
740
[543]741 u.forEachDownstream(func(dc *downstreamConn) {
[535]742 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]743 dc.SendMessage(&irc.Message{
744 Prefix: dc.srv.prefix(),
745 Command: "BOUNCER",
[544]746 Params: []string{"NETWORK", netIDStr, "state=disconnected"},
[532]747 })
748 }
[313]749 })
750
751 if uc.network.lastError == nil {
752 uc.forEachDownstream(func(dc *downstreamConn) {
[532]753 if !dc.caps["soju.im/bouncer-networks"] {
754 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
755 }
[313]756 })
757 }
758}
759
760func (u *user) addNetwork(network *network) {
761 u.networks = append(u.networks, network)
762 go network.run()
763}
764
765func (u *user) removeNetwork(network *network) {
766 network.stop()
767
768 u.forEachDownstream(func(dc *downstreamConn) {
769 if dc.network != nil && dc.network == network {
770 dc.Close()
771 }
772 })
773
774 for i, net := range u.networks {
775 if net == network {
776 u.networks = append(u.networks[:i], u.networks[i+1:]...)
777 return
778 }
779 }
780
781 panic("tried to remove a non-existing network")
782}
783
[500]784func (u *user) checkNetwork(record *Network) error {
785 for _, net := range u.networks {
786 if net.GetName() == record.GetName() && net.ID != record.ID {
787 return fmt.Errorf("a network with the name %q already exists", record.GetName())
788 }
789 }
790 return nil
791}
792
[676]793func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
[313]794 if record.ID != 0 {
[144]795 panic("tried creating an already-existing network")
796 }
797
[500]798 if err := u.checkNetwork(record); err != nil {
799 return nil, err
800 }
801
[691]802 if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
[612]803 return nil, fmt.Errorf("maximum number of networks reached")
804 }
805
[313]806 network := newNetwork(u, record, nil)
[676]807 err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
[101]808 if err != nil {
809 return nil, err
810 }
[144]811
[313]812 u.addNetwork(network)
[144]813
[532]814 idStr := fmt.Sprintf("%v", network.ID)
[535]815 attrs := getNetworkAttrs(network)
[532]816 u.forEachDownstream(func(dc *downstreamConn) {
[535]817 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]818 dc.SendMessage(&irc.Message{
819 Prefix: dc.srv.prefix(),
820 Command: "BOUNCER",
[535]821 Params: []string{"NETWORK", idStr, attrs.String()},
[532]822 })
823 }
824 })
825
[101]826 return network, nil
827}
[202]828
[676]829func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
[313]830 if record.ID == 0 {
831 panic("tried updating a new network")
832 }
[202]833
[568]834 // If the realname is reset to the default, just wipe the per-network
835 // setting
836 if record.Realname == u.Realname {
837 record.Realname = ""
838 }
839
[500]840 if err := u.checkNetwork(record); err != nil {
841 return nil, err
842 }
843
[313]844 network := u.getNetworkByID(record.ID)
845 if network == nil {
846 panic("tried updating a non-existing network")
847 }
848
[676]849 if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
[313]850 return nil, err
851 }
852
853 // Most network changes require us to re-connect to the upstream server
854
[478]855 channels := make([]Channel, 0, network.channels.Len())
856 for _, entry := range network.channels.innerMap {
857 ch := entry.value.(*Channel)
[313]858 channels = append(channels, *ch)
859 }
860
861 updatedNetwork := newNetwork(u, record, channels)
862
863 // If we're currently connected, disconnect and perform the necessary
864 // bookkeeping
865 if network.conn != nil {
866 network.stop()
867 // Note: this will set network.conn to nil
868 u.handleUpstreamDisconnected(network.conn)
869 }
870
871 // Patch downstream connections to use our fresh updated network
872 u.forEachDownstream(func(dc *downstreamConn) {
873 if dc.network != nil && dc.network == network {
874 dc.network = updatedNetwork
[202]875 }
[313]876 })
[202]877
[313]878 // We need to remove the network after patching downstream connections,
879 // otherwise they'll get closed
880 u.removeNetwork(network)
[202]881
[644]882 // The filesystem message store needs to be notified whenever the network
883 // is renamed
884 fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
885 if isFS && updatedNetwork.GetName() != network.GetName() {
[666]886 if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
[644]887 network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
888 }
889 }
890
[313]891 // This will re-connect to the upstream server
892 u.addNetwork(updatedNetwork)
893
[535]894 // TODO: only broadcast attributes that have changed
895 idStr := fmt.Sprintf("%v", updatedNetwork.ID)
896 attrs := getNetworkAttrs(updatedNetwork)
897 u.forEachDownstream(func(dc *downstreamConn) {
898 if dc.caps["soju.im/bouncer-networks-notify"] {
899 dc.SendMessage(&irc.Message{
900 Prefix: dc.srv.prefix(),
901 Command: "BOUNCER",
902 Params: []string{"NETWORK", idStr, attrs.String()},
903 })
904 }
905 })
[532]906
[313]907 return updatedNetwork, nil
908}
909
[676]910func (u *user) deleteNetwork(ctx context.Context, id int64) error {
[313]911 network := u.getNetworkByID(id)
912 if network == nil {
913 panic("tried deleting a non-existing network")
[202]914 }
915
[676]916 if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
[313]917 return err
918 }
919
920 u.removeNetwork(network)
[532]921
922 idStr := fmt.Sprintf("%v", network.ID)
923 u.forEachDownstream(func(dc *downstreamConn) {
[535]924 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]925 dc.SendMessage(&irc.Message{
926 Prefix: dc.srv.prefix(),
927 Command: "BOUNCER",
928 Params: []string{"NETWORK", idStr, "*"},
929 })
930 }
931 })
932
[313]933 return nil
[202]934}
[252]935
[676]936func (u *user) updateUser(ctx context.Context, record *User) error {
[572]937 if u.ID != record.ID {
938 panic("ID mismatch when updating user")
939 }
[376]940
[572]941 realnameUpdated := u.Realname != record.Realname
[676]942 if err := u.srv.db.StoreUser(ctx, record); err != nil {
[568]943 return fmt.Errorf("failed to update user %q: %v", u.Username, err)
944 }
[572]945 u.User = *record
[568]946
[572]947 if realnameUpdated {
948 // Re-connect to networks which use the default realname
949 var needUpdate []Network
950 u.forEachNetwork(func(net *network) {
951 if net.Realname == "" {
952 needUpdate = append(needUpdate, net.Network)
953 }
954 })
[568]955
[572]956 var netErr error
957 for _, net := range needUpdate {
[676]958 if _, err := u.updateNetwork(ctx, &net); err != nil {
[572]959 netErr = err
960 }
[568]961 }
[572]962 if netErr != nil {
963 return netErr
964 }
[568]965 }
966
[572]967 return nil
[568]968}
969
[376]970func (u *user) stop() {
971 u.events <- eventStop{}
[377]972 <-u.done
[376]973}
[489]974
975func (u *user) hasPersistentMsgStore() bool {
976 if u.msgStore == nil {
977 return false
978 }
979 _, isMem := u.msgStore.(*memoryMessageStore)
980 return !isMem
981}
[705]982
983// localAddrForHost returns the local address to use when connecting to host.
984// A nil address is returned when the OS should automatically pick one.
985func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) {
986 upstreamUserIPs := u.srv.Config().UpstreamUserIPs
987 if len(upstreamUserIPs) == 0 {
988 return nil, nil
989 }
990
991 ips, err := net.LookupIP(host)
992 if err != nil {
993 return nil, err
994 }
995
996 wantIPv6 := false
997 for _, ip := range ips {
998 if ip.To4() == nil {
999 wantIPv6 = true
1000 break
1001 }
1002 }
1003
1004 var ipNet *net.IPNet
1005 for _, in := range upstreamUserIPs {
1006 if wantIPv6 == (in.IP.To4() == nil) {
1007 ipNet = in
1008 break
1009 }
1010 }
1011 if ipNet == nil {
1012 return nil, nil
1013 }
1014
1015 var ipInt big.Int
1016 ipInt.SetBytes(ipNet.IP)
1017 ipInt.Add(&ipInt, big.NewInt(u.ID+1))
1018 ip := net.IP(ipInt.Bytes())
1019 if !ipNet.Contains(ip) {
1020 return nil, fmt.Errorf("IP network %v too small", ipNet)
1021 }
1022
1023 return &net.TCPAddr{IP: ip}, nil
1024}
Note: See TracBrowser for help on using the repository browser.