source: code/trunk/user.go@ 771

Last change on this file since 771 was 769, checked in by contact, 3 years ago

Ensure consistent network ordering

Right now there is no consistent ordering in the network list:
no ORDER BY in the DB, and network updates move entries to the end.

Let's always sort by network ID so that users don't see the entries
move around.

I've contemplated sorting by Network.GetName() instead, but:

  • Clients have no way to figure out dynamic order changes, e.g. when renaming a network.
  • Some clients might use the ISUPPORT NETWORK value when a user hasn't explicitly named a network, but soju won't use that value for ordering, leading to non-alphabetical ordering in the client.

Let's leave it to clients to sort the networks by display name if
they want to.
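
The change boils down to a single sort.Slice call over the network records, as
introduced by this revision in user.run() and addNetwork() below (sketch only,
using the networks slice and the Network.ID field already defined in this file):

    // Keep the network list in a stable order, sorted by database ID, so
    // that entries don't move around when a network is updated.
    sort.Slice(networks, func(i, j int) bool {
        return networks[i].ID < networks[j].ID
    })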

File size: 25.5 KB
Rev  Line
[101]1package soju
2
3import (
[652]4 "context"
[385]5 "crypto/sha256"
6 "encoding/binary"
[395]7 "encoding/hex"
[218]8 "fmt"
[705]9 "math/big"
10 "net"
[769]11 "sort"
[101]12 "time"
[103]13
14 "gopkg.in/irc.v3"
[101]15)
16
[165]17type event interface{}
18
19type eventUpstreamMessage struct {
[103]20 msg *irc.Message
21 uc *upstreamConn
22}
23
[218]24type eventUpstreamConnectionError struct {
25 net *network
26 err error
27}
28
[196]29type eventUpstreamConnected struct {
30 uc *upstreamConn
31}
32
[179]33type eventUpstreamDisconnected struct {
34 uc *upstreamConn
35}
36
[218]37type eventUpstreamError struct {
38 uc *upstreamConn
39 err error
40}
41
[165]42type eventDownstreamMessage struct {
[103]43 msg *irc.Message
44 dc *downstreamConn
45}
46
[166]47type eventDownstreamConnected struct {
48 dc *downstreamConn
49}
50
[167]51type eventDownstreamDisconnected struct {
52 dc *downstreamConn
53}
54
[435]55type eventChannelDetach struct {
56 uc *upstreamConn
57 name string
58}
59
[563]60type eventBroadcast struct {
61 msg *irc.Message
62}
63
[376]64type eventStop struct{}
65
[625]66type eventUserUpdate struct {
67 password *string
68 admin *bool
69 done chan error
70}
71
[480]72type deliveredClientMap map[string]string // client name -> msg ID
73
[485]74type deliveredStore struct {
75 m deliveredCasemapMap
76}
77
78func newDeliveredStore() deliveredStore {
79 return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
80}
81
82func (ds deliveredStore) HasTarget(target string) bool {
83 return ds.m.Value(target) != nil
84}
85
86func (ds deliveredStore) LoadID(target, clientName string) string {
87 clients := ds.m.Value(target)
88 if clients == nil {
89 return ""
90 }
91 return clients[clientName]
92}
93
94func (ds deliveredStore) StoreID(target, clientName, msgID string) {
95 clients := ds.m.Value(target)
96 if clients == nil {
97 clients = make(deliveredClientMap)
98 ds.m.SetValue(target, clients)
99 }
100 clients[clientName] = msgID
101}
102
103func (ds deliveredStore) ForEachTarget(f func(target string)) {
104 for _, entry := range ds.m.innerMap {
105 f(entry.originalKey)
106 }
107}
108
[489]109func (ds deliveredStore) ForEachClient(f func(clientName string)) {
110 clients := make(map[string]struct{})
111 for _, entry := range ds.m.innerMap {
112 delivered := entry.value.(deliveredClientMap)
113 for clientName := range delivered {
114 clients[clientName] = struct{}{}
115 }
116 }
117
118 for clientName := range clients {
119 f(clientName)
120 }
121}
122
[101]123type network struct {
124 Network
[202]125 user *user
[501]126 logger Logger
[202]127 stopped chan struct{}
[131]128
[482]129 conn *upstreamConn
130 channels channelCasemapMap
[485]131 delivered deliveredStore
[482]132 lastError error
133 casemap casemapping
[101]134}
135
[267]136func newNetwork(user *user, record *Network, channels []Channel) *network {
[501]137 logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
138
[478]139 m := channelCasemapMap{newCasemapMap(0)}
[267]140 for _, ch := range channels {
[283]141 ch := ch
[478]142 m.SetValue(ch.Name, &ch)
[267]143 }
144
[101]145 return &network{
[482]146 Network: *record,
147 user: user,
[501]148 logger: logger,
[482]149 stopped: make(chan struct{}),
150 channels: m,
[485]151 delivered: newDeliveredStore(),
[482]152 casemap: casemapRFC1459,
[101]153 }
154}
155
[218]156func (net *network) forEachDownstream(f func(*downstreamConn)) {
157 net.user.forEachDownstream(func(dc *downstreamConn) {
[693]158 if dc.network == nil && !dc.isMultiUpstream {
[532]159 return
160 }
[218]161 if dc.network != nil && dc.network != net {
162 return
163 }
164 f(dc)
165 })
166}
167
[311]168func (net *network) isStopped() bool {
169 select {
170 case <-net.stopped:
171 return true
172 default:
173 return false
174 }
175}
176
[385]177func userIdent(u *User) string {
178 // The ident is a string we will send to upstream servers in clear-text.
179 // For privacy reasons, make sure it doesn't expose any meaningful user
[385]181 // metadata. We just use the hex-encoded hashed ID, so that people don't
181 // start relying on the string being an integer or following a pattern.
182 var b [64]byte
183 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
184 h := sha256.Sum256(b[:])
[395]185 return hex.EncodeToString(h[:16])
[385]186}
187
[101]188func (net *network) run() {
[542]189 if !net.Enabled {
190 return
191 }
192
[101]193 var lastTry time.Time
[735]194 backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
[101]195 for {
[311]196 if net.isStopped() {
[202]197 return
198 }
199
[735]200 delay := backoff.Next() - time.Now().Sub(lastTry)
201 if delay > 0 {
[501]202 net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
[101]203 time.Sleep(delay)
204 }
205 lastTry = time.Now()
206
[733]207 net.user.srv.metrics.upstreams.Add(1)
208
[732]209 uc, err := connectToUpstream(context.TODO(), net)
[101]210 if err != nil {
[501]211 net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]212 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[733]213 net.user.srv.metrics.upstreams.Add(-1)
[734]214 net.user.srv.metrics.upstreamConnectErrorsTotal.Inc()
[101]215 continue
216 }
217
[385]218 if net.user.srv.Identd != nil {
219 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
220 }
221
[101]222 uc.register()
[197]223 if err := uc.runUntilRegistered(); err != nil {
[399]224 text := err.Error()
[736]225 temp := true
[399]226 if regErr, ok := err.(registrationError); ok {
[736]227 text = regErr.Reason()
228 temp = regErr.Temporary()
[399]229 }
230 uc.logger.Printf("failed to register: %v", text)
231 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
[197]232 uc.Close()
[733]233 net.user.srv.metrics.upstreams.Add(-1)
[734]234 net.user.srv.metrics.upstreamConnectErrorsTotal.Inc()
[736]235 if !temp {
236 return
237 }
[197]238 continue
239 }
[101]240
[311]241 // TODO: this is racy with net.stopped. If the network is stopped
242 // before the user goroutine receives eventUpstreamConnected, the
243 // connection won't be closed.
[196]244 net.user.events <- eventUpstreamConnected{uc}
[165]245 if err := uc.readMessages(net.user.events); err != nil {
[101]246 uc.logger.Printf("failed to handle messages: %v", err)
[218]247 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]248 }
249 uc.Close()
[179]250 net.user.events <- eventUpstreamDisconnected{uc}
[385]251
252 if net.user.srv.Identd != nil {
253 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
254 }
[710]255
256 net.user.srv.metrics.upstreams.Add(-1)
[735]257 backoff.Reset()
[101]258 }
259}
260
[309]261func (net *network) stop() {
[311]262 if !net.isStopped() {
[202]263 close(net.stopped)
264 }
265
[279]266 if net.conn != nil {
267 net.conn.Close()
[202]268 }
269}
270
[435]271func (net *network) detach(ch *Channel) {
272 if ch.Detached {
273 return
[267]274 }
[497]275
[501]276 net.logger.Printf("detaching channel %q", ch.Name)
[435]277
[497]278 ch.Detached = true
279
280 if net.user.msgStore != nil {
281 nameCM := net.casemap(ch.Name)
[666]282 lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
[497]283 if err != nil {
[501]284 net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
[497]285 }
286 ch.DetachedInternalMsgID = lastID
287 }
288
[435]289 if net.conn != nil {
[478]290 uch := net.conn.channels.Value(ch.Name)
291 if uch != nil {
[435]292 uch.updateAutoDetach(0)
293 }
[222]294 }
[284]295
[435]296 net.forEachDownstream(func(dc *downstreamConn) {
297 dc.SendMessage(&irc.Message{
298 Prefix: dc.prefix(),
299 Command: "PART",
300 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
301 })
302 })
303}
[284]304
[435]305func (net *network) attach(ch *Channel) {
306 if !ch.Detached {
307 return
308 }
[497]309
[501]310 net.logger.Printf("attaching channel %q", ch.Name)
[284]311
[497]312 detachedMsgID := ch.DetachedInternalMsgID
313 ch.Detached = false
314 ch.DetachedInternalMsgID = ""
315
[435]316 var uch *upstreamChannel
317 if net.conn != nil {
[478]318 uch = net.conn.channels.Value(ch.Name)
[284]319
[435]320 net.conn.updateChannelAutoDetach(ch.Name)
321 }
[284]322
[435]323 net.forEachDownstream(func(dc *downstreamConn) {
324 dc.SendMessage(&irc.Message{
325 Prefix: dc.prefix(),
326 Command: "JOIN",
327 Params: []string{dc.marshalEntity(net, ch.Name)},
328 })
329
330 if uch != nil {
331 forwardChannel(dc, uch)
[284]332 }
333
[497]334 if detachedMsgID != "" {
[701]335 dc.sendTargetBacklog(context.TODO(), net, ch.Name, detachedMsgID)
[495]336 }
[435]337 })
[222]338}
339
[676]340func (net *network) deleteChannel(ctx context.Context, name string) error {
[478]341 ch := net.channels.Value(name)
342 if ch == nil {
[416]343 return fmt.Errorf("unknown channel %q", name)
344 }
[435]345 if net.conn != nil {
[478]346 uch := net.conn.channels.Value(ch.Name)
347 if uch != nil {
[435]348 uch.updateAutoDetach(0)
349 }
350 }
351
[676]352 if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
[267]353 return err
354 }
[478]355 net.channels.Delete(name)
[267]356 return nil
[222]357}
358
[478]359func (net *network) updateCasemapping(newCasemap casemapping) {
360 net.casemap = newCasemap
361 net.channels.SetCasemapping(newCasemap)
[485]362 net.delivered.m.SetCasemapping(newCasemap)
[684]363 if uc := net.conn; uc != nil {
364 uc.channels.SetCasemapping(newCasemap)
365 for _, entry := range uc.channels.innerMap {
[478]366 uch := entry.value.(*upstreamChannel)
367 uch.Members.SetCasemapping(newCasemap)
368 }
[684]369 uc.monitored.SetCasemapping(newCasemap)
[478]370 }
[684]371 net.forEachDownstream(func(dc *downstreamConn) {
372 dc.monitored.SetCasemapping(newCasemap)
373 })
[478]374}
375
[740]376func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) {
[489]377 if !net.user.hasPersistentMsgStore() {
378 return
379 }
380
381 var receipts []DeliveryReceipt
382 net.delivered.ForEachTarget(func(target string) {
383 msgID := net.delivered.LoadID(target, clientName)
384 if msgID == "" {
385 return
386 }
387 receipts = append(receipts, DeliveryReceipt{
388 Target: target,
389 InternalMsgID: msgID,
390 })
391 })
392
[740]393 if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil {
[501]394 net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
[489]395 }
396}
397
[499]398func (net *network) isHighlight(msg *irc.Message) bool {
399 if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
400 return false
401 }
402
403 text := msg.Params[1]
404
405 nick := net.Nick
406 if net.conn != nil {
407 nick = net.conn.nick
408 }
409
410 // TODO: use case-mapping aware comparison here
411 return msg.Prefix.Name != nick && isHighlight(text, nick)
412}
413
414func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
415 highlight := net.isHighlight(msg)
416 return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
417}
418
[724]419func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
420 // User may have e.g. EXTERNAL mechanism configured. We do not want to
421 // automatically erase the key pair or any other credentials.
422 if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
423 return
424 }
425
426 net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
427 net.SASL.Mechanism = "PLAIN"
428 net.SASL.Plain.Username = username
429 net.SASL.Plain.Password = password
430 if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
431 net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
432 }
433}
434
[101]435type user struct {
436 User
[493]437 srv *Server
438 logger Logger
[101]439
[165]440 events chan event
[377]441 done chan struct{}
[103]442
[101]443 networks []*network
444 downstreamConns []*downstreamConn
[439]445 msgStore messageStore
[101]446}
447
448func newUser(srv *Server, record *User) *user {
[493]449 logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
450
[439]451 var msgStore messageStore
[691]452 if logPath := srv.Config().LogPath; logPath != "" {
453 msgStore = newFSMessageStore(logPath, record.Username)
[442]454 } else {
455 msgStore = newMemoryMessageStore()
[423]456 }
457
[101]458 return &user{
[489]459 User: *record,
460 srv: srv,
[493]461 logger: logger,
[489]462 events: make(chan event, 64),
463 done: make(chan struct{}),
464 msgStore: msgStore,
[101]465 }
466}
467
468func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
469 for _, network := range u.networks {
[279]470 if network.conn == nil {
[101]471 continue
472 }
[279]473 f(network.conn)
[101]474 }
475}
476
477func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
478 for _, dc := range u.downstreamConns {
479 f(dc)
480 }
481}
482
483func (u *user) getNetwork(name string) *network {
484 for _, network := range u.networks {
485 if network.Addr == name {
486 return network
487 }
[201]488 if network.Name != "" && network.Name == name {
489 return network
490 }
[101]491 }
492 return nil
493}
494
[313]495func (u *user) getNetworkByID(id int64) *network {
496 for _, net := range u.networks {
497 if net.ID == id {
498 return net
499 }
500 }
501 return nil
502}
503
[101]504func (u *user) run() {
[423]505 defer func() {
506 if u.msgStore != nil {
507 if err := u.msgStore.Close(); err != nil {
[493]508 u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
[423]509 }
510 }
511 close(u.done)
512 }()
[377]513
[652]514 networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
[101]515 if err != nil {
[493]516 u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
[101]517 return
518 }
519
[769]520 sort.Slice(networks, func(i, j int) bool {
521 return networks[i].ID < networks[j].ID
522 })
523
[101]524 for _, record := range networks {
[283]525 record := record
[652]526 channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
[267]527 if err != nil {
[493]528 u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
[359]529 continue
[267]530 }
531
532 network := newNetwork(u, &record, channels)
[101]533 u.networks = append(u.networks, network)
534
[489]535 if u.hasPersistentMsgStore() {
[652]536 receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
[489]537 if err != nil {
[493]538 u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
[489]539 return
540 }
541
542 for _, rcpt := range receipts {
543 network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
544 }
545 }
546
[101]547 go network.run()
548 }
[103]549
[165]550 for e := range u.events {
551 switch e := e.(type) {
[196]552 case eventUpstreamConnected:
[198]553 uc := e.uc
[199]554
555 uc.network.conn = uc
556
[198]557 uc.updateAway()
[684]558 uc.updateMonitor()
[218]559
[532]560 netIDStr := fmt.Sprintf("%v", uc.network.ID)
[218]561 uc.forEachDownstream(func(dc *downstreamConn) {
[276]562 dc.updateSupportedCaps()
[296]563
[543]564 if !dc.caps["soju.im/bouncer-networks"] {
565 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
566 }
567
568 dc.updateNick()
569 dc.updateRealname()
[722]570 dc.updateAccount()
[543]571 })
572 u.forEachDownstream(func(dc *downstreamConn) {
[535]573 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]574 dc.SendMessage(&irc.Message{
575 Prefix: dc.srv.prefix(),
576 Command: "BOUNCER",
[544]577 Params: []string{"NETWORK", netIDStr, "state=connected"},
[532]578 })
579 }
[218]580 })
581 uc.network.lastError = nil
[179]582 case eventUpstreamDisconnected:
[313]583 u.handleUpstreamDisconnected(e.uc)
584 case eventUpstreamConnectionError:
585 net := e.net
[199]586
[313]587 stopped := false
588 select {
589 case <-net.stopped:
590 stopped = true
591 default:
[179]592 }
[199]593
[313]594 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
[218]595 net.forEachDownstream(func(dc *downstreamConn) {
[223]596 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
[218]597 })
598 }
599 net.lastError = e.err
600 case eventUpstreamError:
601 uc := e.uc
602
603 uc.forEachDownstream(func(dc *downstreamConn) {
[223]604 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
[218]605 })
606 uc.network.lastError = e.err
[165]607 case eventUpstreamMessage:
608 msg, uc := e.msg, e.uc
[175]609 if uc.isClosed() {
[133]610 uc.logger.Printf("ignoring message on closed connection: %v", msg)
611 break
612 }
[739]613 if err := uc.handleMessage(context.TODO(), msg); err != nil {
[103]614 uc.logger.Printf("failed to handle message %q: %v", msg, err)
615 }
[435]616 case eventChannelDetach:
617 uc, name := e.uc, e.name
[478]618 c := uc.network.channels.Value(name)
619 if c == nil || c.Detached {
[435]620 continue
621 }
622 uc.network.detach(c)
[652]623 if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
[493]624 u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
[435]625 }
[166]626 case eventDownstreamConnected:
627 dc := e.dc
[168]628
[684]629 if dc.network != nil {
630 dc.monitored.SetCasemapping(dc.network.casemap)
631 }
632
[701]633 if err := dc.welcome(context.TODO()); err != nil {
[168]634 dc.logger.Printf("failed to handle new registered connection: %v", err)
635 break
636 }
637
[166]638 u.downstreamConns = append(u.downstreamConns, dc)
[198]639
[467]640 dc.forEachNetwork(func(network *network) {
641 if network.lastError != nil {
642 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
643 }
644 })
645
[198]646 u.forEachUpstream(func(uc *upstreamConn) {
647 uc.updateAway()
648 })
[167]649 case eventDownstreamDisconnected:
650 dc := e.dc
[204]651
[167]652 for i := range u.downstreamConns {
653 if u.downstreamConns[i] == dc {
654 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
655 break
656 }
657 }
[198]658
[489]659 dc.forEachNetwork(func(net *network) {
[740]660 net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
[489]661 })
662
[198]663 u.forEachUpstream(func(uc *upstreamConn) {
[738]664 uc.cancelPendingCommandsByDownstreamID(dc.id)
[198]665 uc.updateAway()
[684]666 uc.updateMonitor()
[198]667 })
[165]668 case eventDownstreamMessage:
669 msg, dc := e.msg, e.dc
[133]670 if dc.isClosed() {
671 dc.logger.Printf("ignoring message on closed connection: %v", msg)
672 break
673 }
[704]674 err := dc.handleMessage(context.TODO(), msg)
[103]675 if ircErr, ok := err.(ircError); ok {
676 ircErr.Message.Prefix = dc.srv.prefix()
677 dc.SendMessage(ircErr.Message)
678 } else if err != nil {
679 dc.logger.Printf("failed to handle message %q: %v", msg, err)
680 dc.Close()
681 }
[563]682 case eventBroadcast:
683 msg := e.msg
684 u.forEachDownstream(func(dc *downstreamConn) {
685 dc.SendMessage(msg)
686 })
[625]687 case eventUserUpdate:
688 // copy the user record because we'll mutate it
689 record := u.User
690
691 if e.password != nil {
692 record.Password = *e.password
693 }
694 if e.admin != nil {
695 record.Admin = *e.admin
696 }
697
[676]698 e.done <- u.updateUser(context.TODO(), &record)
[625]699
700 // If the password was updated, kill all downstream connections to
701 // force them to re-authenticate with the new credentials.
702 if e.password != nil {
703 u.forEachDownstream(func(dc *downstreamConn) {
704 dc.Close()
705 })
706 }
[376]707 case eventStop:
708 u.forEachDownstream(func(dc *downstreamConn) {
709 dc.Close()
710 })
711 for _, n := range u.networks {
712 n.stop()
[489]713
714 n.delivered.ForEachClient(func(clientName string) {
[740]715 n.storeClientDeliveryReceipts(context.TODO(), clientName)
[489]716 })
[376]717 }
718 return
[165]719 default:
[494]720 panic(fmt.Sprintf("received unknown event type: %T", e))
[103]721 }
722 }
[101]723}
724
[313]725func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
726 uc.network.conn = nil
727
[752]728 uc.abortPendingCommands()
[313]729
[478]730 for _, entry := range uc.channels.innerMap {
731 uch := entry.value.(*upstreamChannel)
[435]732 uch.updateAutoDetach(0)
733 }
734
[532]735 netIDStr := fmt.Sprintf("%v", uc.network.ID)
[313]736 uc.forEachDownstream(func(dc *downstreamConn) {
737 dc.updateSupportedCaps()
[543]738 })
[583]739
740 // If the network has been removed, don't send a state change notification
741 found := false
742 for _, net := range u.networks {
743 if net == uc.network {
744 found = true
745 break
746 }
747 }
748 if !found {
749 return
750 }
751
[543]752 u.forEachDownstream(func(dc *downstreamConn) {
[535]753 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]754 dc.SendMessage(&irc.Message{
755 Prefix: dc.srv.prefix(),
756 Command: "BOUNCER",
[544]757 Params: []string{"NETWORK", netIDStr, "state=disconnected"},
[532]758 })
759 }
[313]760 })
761
762 if uc.network.lastError == nil {
763 uc.forEachDownstream(func(dc *downstreamConn) {
[532]764 if !dc.caps["soju.im/bouncer-networks"] {
765 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
766 }
[313]767 })
768 }
769}
770
771func (u *user) addNetwork(network *network) {
772 u.networks = append(u.networks, network)
[769]773
774 sort.Slice(u.networks, func(i, j int) bool {
775 return u.networks[i].ID < u.networks[j].ID
776 })
777
[313]778 go network.run()
779}
780
781func (u *user) removeNetwork(network *network) {
782 network.stop()
783
784 u.forEachDownstream(func(dc *downstreamConn) {
785 if dc.network != nil && dc.network == network {
786 dc.Close()
787 }
788 })
789
790 for i, net := range u.networks {
791 if net == network {
792 u.networks = append(u.networks[:i], u.networks[i+1:]...)
793 return
794 }
795 }
796
797 panic("tried to remove a non-existing network")
798}
799
[500]800func (u *user) checkNetwork(record *Network) error {
[731]801 url, err := record.URL()
802 if err != nil {
803 return err
804 }
805 if url.User != nil {
806 return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme)
807 }
808 if url.RawQuery != "" {
809 return fmt.Errorf("%v:// URL must not have query values", url.Scheme)
810 }
811 if url.Fragment != "" {
812 return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme)
813 }
814 switch url.Scheme {
815 case "ircs", "irc+insecure":
816 if url.Host == "" {
817 return fmt.Errorf("%v:// URL must have a host", url.Scheme)
818 }
819 if url.Path != "" {
820 return fmt.Errorf("%v:// URL must not have a path", url.Scheme)
821 }
822 case "irc+unix", "unix":
823 if url.Host != "" {
824 return fmt.Errorf("%v:// URL must not have a host", url.Scheme)
825 }
826 if url.Path == "" {
827 return fmt.Errorf("%v:// URL must have a path", url.Scheme)
828 }
829 default:
830 return fmt.Errorf("unknown URL scheme %q", url.Scheme)
831 }
832
[500]833 for _, net := range u.networks {
834 if net.GetName() == record.GetName() && net.ID != record.ID {
835 return fmt.Errorf("a network with the name %q already exists", record.GetName())
836 }
837 }
[731]838
[500]839 return nil
840}
841
[676]842func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
[313]843 if record.ID != 0 {
[144]844 panic("tried creating an already-existing network")
845 }
846
[500]847 if err := u.checkNetwork(record); err != nil {
848 return nil, err
849 }
850
[691]851 if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
[612]852 return nil, fmt.Errorf("maximum number of networks reached")
853 }
854
[313]855 network := newNetwork(u, record, nil)
[676]856 err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
[101]857 if err != nil {
858 return nil, err
859 }
[144]860
[313]861 u.addNetwork(network)
[144]862
[532]863 idStr := fmt.Sprintf("%v", network.ID)
[535]864 attrs := getNetworkAttrs(network)
[532]865 u.forEachDownstream(func(dc *downstreamConn) {
[535]866 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]867 dc.SendMessage(&irc.Message{
868 Prefix: dc.srv.prefix(),
869 Command: "BOUNCER",
[535]870 Params: []string{"NETWORK", idStr, attrs.String()},
[532]871 })
872 }
873 })
874
[101]875 return network, nil
876}
[202]877
[676]878func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
[313]879 if record.ID == 0 {
880 panic("tried updating a new network")
881 }
[202]882
[568]883 // If the realname is reset to the default, just wipe the per-network
884 // setting
885 if record.Realname == u.Realname {
886 record.Realname = ""
887 }
888
[500]889 if err := u.checkNetwork(record); err != nil {
890 return nil, err
891 }
892
[313]893 network := u.getNetworkByID(record.ID)
894 if network == nil {
895 panic("tried updating a non-existing network")
896 }
897
[676]898 if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
[313]899 return nil, err
900 }
901
902 // Most network changes require us to re-connect to the upstream server
903
[478]904 channels := make([]Channel, 0, network.channels.Len())
905 for _, entry := range network.channels.innerMap {
906 ch := entry.value.(*Channel)
[313]907 channels = append(channels, *ch)
908 }
909
910 updatedNetwork := newNetwork(u, record, channels)
911
912 // If we're currently connected, disconnect and perform the necessary
913 // bookkeeping
914 if network.conn != nil {
915 network.stop()
916 // Note: this will set network.conn to nil
917 u.handleUpstreamDisconnected(network.conn)
918 }
919
920 // Patch downstream connections to use our fresh updated network
921 u.forEachDownstream(func(dc *downstreamConn) {
922 if dc.network != nil && dc.network == network {
923 dc.network = updatedNetwork
[202]924 }
[313]925 })
[202]926
[313]927 // We need to remove the network after patching downstream connections,
928 // otherwise they'll get closed
929 u.removeNetwork(network)
[202]930
[644]931 // The filesystem message store needs to be notified whenever the network
932 // is renamed
933 fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
934 if isFS && updatedNetwork.GetName() != network.GetName() {
[666]935 if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
[644]936 network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
937 }
938 }
939
[313]940 // This will re-connect to the upstream server
941 u.addNetwork(updatedNetwork)
942
[535]943 // TODO: only broadcast attributes that have changed
944 idStr := fmt.Sprintf("%v", updatedNetwork.ID)
945 attrs := getNetworkAttrs(updatedNetwork)
946 u.forEachDownstream(func(dc *downstreamConn) {
947 if dc.caps["soju.im/bouncer-networks-notify"] {
948 dc.SendMessage(&irc.Message{
949 Prefix: dc.srv.prefix(),
950 Command: "BOUNCER",
951 Params: []string{"NETWORK", idStr, attrs.String()},
952 })
953 }
954 })
[532]955
[313]956 return updatedNetwork, nil
957}
958
[676]959func (u *user) deleteNetwork(ctx context.Context, id int64) error {
[313]960 network := u.getNetworkByID(id)
961 if network == nil {
962 panic("tried deleting a non-existing network")
[202]963 }
964
[676]965 if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
[313]966 return err
967 }
968
969 u.removeNetwork(network)
[532]970
971 idStr := fmt.Sprintf("%v", network.ID)
972 u.forEachDownstream(func(dc *downstreamConn) {
[535]973 if dc.caps["soju.im/bouncer-networks-notify"] {
[532]974 dc.SendMessage(&irc.Message{
975 Prefix: dc.srv.prefix(),
976 Command: "BOUNCER",
977 Params: []string{"NETWORK", idStr, "*"},
978 })
979 }
980 })
981
[313]982 return nil
[202]983}
[252]984
[676]985func (u *user) updateUser(ctx context.Context, record *User) error {
[572]986 if u.ID != record.ID {
987 panic("ID mismatch when updating user")
988 }
[376]989
[572]990 realnameUpdated := u.Realname != record.Realname
[676]991 if err := u.srv.db.StoreUser(ctx, record); err != nil {
[568]992 return fmt.Errorf("failed to update user %q: %v", u.Username, err)
993 }
[572]994 u.User = *record
[568]995
[572]996 if realnameUpdated {
997 // Re-connect to networks which use the default realname
998 var needUpdate []Network
[768]999 for _, net := range u.networks {
[572]1000 if net.Realname == "" {
1001 needUpdate = append(needUpdate, net.Network)
1002 }
[768]1003 }
[568]1004
[572]1005 var netErr error
1006 for _, net := range needUpdate {
[676]1007 if _, err := u.updateNetwork(ctx, &net); err != nil {
[572]1008 netErr = err
1009 }
[568]1010 }
[572]1011 if netErr != nil {
1012 return netErr
1013 }
[568]1014 }
1015
[572]1016 return nil
[568]1017}
1018
[376]1019func (u *user) stop() {
1020 u.events <- eventStop{}
[377]1021 <-u.done
[376]1022}
[489]1023
1024func (u *user) hasPersistentMsgStore() bool {
1025 if u.msgStore == nil {
1026 return false
1027 }
1028 _, isMem := u.msgStore.(*memoryMessageStore)
1029 return !isMem
1030}
[705]1031
1032// localTCPAddrForHost returns the local address to use when connecting to host.
1033// A nil address is returned when the OS should automatically pick one.
[732]1034func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
[705]1035 upstreamUserIPs := u.srv.Config().UpstreamUserIPs
1036 if len(upstreamUserIPs) == 0 {
1037 return nil, nil
1038 }
1039
[732]1040 ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
[705]1041 if err != nil {
1042 return nil, err
1043 }
1044
1045 wantIPv6 := false
1046 for _, ip := range ips {
1047 if ip.To4() == nil {
1048 wantIPv6 = true
1049 break
1050 }
1051 }
1052
1053 var ipNet *net.IPNet
1054 for _, in := range upstreamUserIPs {
1055 if wantIPv6 == (in.IP.To4() == nil) {
1056 ipNet = in
1057 break
1058 }
1059 }
1060 if ipNet == nil {
1061 return nil, nil
1062 }
1063
1064 var ipInt big.Int
1065 ipInt.SetBytes(ipNet.IP)
1066 ipInt.Add(&ipInt, big.NewInt(u.ID+1))
1067 ip := net.IP(ipInt.Bytes())
1068 if !ipNet.Contains(ip) {
1069 return nil, fmt.Errorf("IP network %v too small", ipNet)
1070 }
1071
1072 return &net.TCPAddr{IP: ip}, nil
1073}