source: code/trunk/user.go@ 414

Last change on this file since 414 was 409, checked in by contact, 5 years ago

Nuke in-memory ring buffer

Instead, always read chat history from logs. Unify the implicit chat
history (pushing history to clients) and explicit chat history
(via the CHATHISTORY command).

Instead of keeping track of ring buffer cursors for each client, use
message IDs.

If necessary, the ring buffer could be re-introduced behind a
common MessageStore interface (could be useful when on-disk logs are
disabled).

References: https://todo.sr.ht/~emersion/soju/80

File size: 13.3 KB
RevLine 
[101]1package soju
2
3import (
[385]4 "crypto/sha256"
5 "encoding/binary"
[395]6 "encoding/hex"
[218]7 "fmt"
[101]8 "time"
[103]9
10 "gopkg.in/irc.v3"
[101]11)
12
[165]13type event interface{}
14
15type eventUpstreamMessage struct {
[103]16 msg *irc.Message
17 uc *upstreamConn
18}
19
[218]20type eventUpstreamConnectionError struct {
21 net *network
22 err error
23}
24
[196]25type eventUpstreamConnected struct {
26 uc *upstreamConn
27}
28
[179]29type eventUpstreamDisconnected struct {
30 uc *upstreamConn
31}
32
[218]33type eventUpstreamError struct {
34 uc *upstreamConn
35 err error
36}
37
[165]38type eventDownstreamMessage struct {
[103]39 msg *irc.Message
40 dc *downstreamConn
41}
42
[166]43type eventDownstreamConnected struct {
44 dc *downstreamConn
45}
46
[167]47type eventDownstreamDisconnected struct {
48 dc *downstreamConn
49}
50
[376]51type eventStop struct{}
52
// networkHistory holds per-client history state for a single entity of a
// network (network.history is indexed by entity).
type networkHistory struct {
	// clients is indexed by client name. The values are presumably message
	// IDs marking each client's position in the chat history — confirm
	// against the code that reads them.
	clients map[string]string
}
56
[101]57type network struct {
58 Network
[202]59 user *user
60 stopped chan struct{}
[131]61
[253]62 conn *upstreamConn
[267]63 channels map[string]*Channel
[253]64 history map[string]*networkHistory // indexed by entity
65 offlineClients map[string]struct{} // indexed by client name
66 lastError error
[101]67}
68
[267]69func newNetwork(user *user, record *Network, channels []Channel) *network {
70 m := make(map[string]*Channel, len(channels))
71 for _, ch := range channels {
[283]72 ch := ch
[267]73 m[ch.Name] = &ch
74 }
75
[101]76 return &network{
[253]77 Network: *record,
78 user: user,
79 stopped: make(chan struct{}),
[267]80 channels: m,
[253]81 history: make(map[string]*networkHistory),
82 offlineClients: make(map[string]struct{}),
[101]83 }
84}
85
[218]86func (net *network) forEachDownstream(f func(*downstreamConn)) {
87 net.user.forEachDownstream(func(dc *downstreamConn) {
88 if dc.network != nil && dc.network != net {
89 return
90 }
91 f(dc)
92 })
93}
94
[311]95func (net *network) isStopped() bool {
96 select {
97 case <-net.stopped:
98 return true
99 default:
100 return false
101 }
102}
103
[385]104func userIdent(u *User) string {
105 // The ident is a string we will send to upstream servers in clear-text.
106 // For privacy reasons, make sure it doesn't expose any meaningful user
107 // metadata. We just use the base64-encoded hashed ID, so that people don't
108 // start relying on the string being an integer or following a pattern.
109 var b [64]byte
110 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
111 h := sha256.Sum256(b[:])
[395]112 return hex.EncodeToString(h[:16])
[385]113}
114
[101]115func (net *network) run() {
116 var lastTry time.Time
117 for {
[311]118 if net.isStopped() {
[202]119 return
120 }
121
[398]122 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
123 delay := retryConnectDelay - dur
[101]124 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
125 time.Sleep(delay)
126 }
127 lastTry = time.Now()
128
129 uc, err := connectToUpstream(net)
130 if err != nil {
131 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]132 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[101]133 continue
134 }
135
[385]136 if net.user.srv.Identd != nil {
137 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
138 }
139
[101]140 uc.register()
[197]141 if err := uc.runUntilRegistered(); err != nil {
[399]142 text := err.Error()
143 if regErr, ok := err.(registrationError); ok {
144 text = string(regErr)
145 }
146 uc.logger.Printf("failed to register: %v", text)
147 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
[197]148 uc.Close()
149 continue
150 }
[101]151
[311]152 // TODO: this is racy with net.stopped. If the network is stopped
153 // before the user goroutine receives eventUpstreamConnected, the
154 // connection won't be closed.
[196]155 net.user.events <- eventUpstreamConnected{uc}
[165]156 if err := uc.readMessages(net.user.events); err != nil {
[101]157 uc.logger.Printf("failed to handle messages: %v", err)
[218]158 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]159 }
160 uc.Close()
[179]161 net.user.events <- eventUpstreamDisconnected{uc}
[385]162
163 if net.user.srv.Identd != nil {
164 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
165 }
[101]166 }
167}
168
[309]169func (net *network) stop() {
[311]170 if !net.isStopped() {
[202]171 close(net.stopped)
172 }
173
[279]174 if net.conn != nil {
175 net.conn.Close()
[202]176 }
177}
178
[222]179func (net *network) createUpdateChannel(ch *Channel) error {
[267]180 if current, ok := net.channels[ch.Name]; ok {
181 ch.ID = current.ID // update channel if it already exists
182 }
183 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
[222]184 return err
185 }
[284]186 prev := net.channels[ch.Name]
[267]187 net.channels[ch.Name] = ch
[284]188
189 if prev != nil && prev.Detached != ch.Detached {
190 history := net.history[ch.Name]
191 if ch.Detached {
192 net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
193 net.forEachDownstream(func(dc *downstreamConn) {
194 net.offlineClients[dc.clientName] = struct{}{}
195
196 dc.SendMessage(&irc.Message{
197 Prefix: dc.prefix(),
198 Command: "PART",
199 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
200 })
201 })
202 } else {
203 net.user.srv.Logger.Printf("network %q: attaching channel %q", net.GetName(), ch.Name)
204
205 var uch *upstreamChannel
206 if net.conn != nil {
207 uch = net.conn.channels[ch.Name]
208 }
209
210 net.forEachDownstream(func(dc *downstreamConn) {
211 dc.SendMessage(&irc.Message{
212 Prefix: dc.prefix(),
213 Command: "JOIN",
214 Params: []string{dc.marshalEntity(net, ch.Name)},
215 })
216
217 if uch != nil {
218 forwardChannel(dc, uch)
219 }
220
221 if history != nil {
222 dc.sendNetworkHistory(net)
223 }
224 })
225 }
226 }
227
[267]228 return nil
[222]229}
230
231func (net *network) deleteChannel(name string) error {
[267]232 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
233 return err
234 }
235 delete(net.channels, name)
236 return nil
[222]237}
238
[101]239type user struct {
240 User
241 srv *Server
242
[165]243 events chan event
[377]244 done chan struct{}
[103]245
[101]246 networks []*network
247 downstreamConns []*downstreamConn
[177]248
249 // LIST commands in progress
[179]250 pendingLISTs []pendingLIST
[101]251}
252
[177]253type pendingLIST struct {
254 downstreamID uint64
255 // list of per-upstream LIST commands not yet sent or completed
256 pendingCommands map[int64]*irc.Message
257}
258
[101]259func newUser(srv *Server, record *User) *user {
260 return &user{
[165]261 User: *record,
262 srv: srv,
263 events: make(chan event, 64),
[377]264 done: make(chan struct{}),
[101]265 }
266}
267
268func (u *user) forEachNetwork(f func(*network)) {
269 for _, network := range u.networks {
270 f(network)
271 }
272}
273
274func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
275 for _, network := range u.networks {
[279]276 if network.conn == nil {
[101]277 continue
278 }
[279]279 f(network.conn)
[101]280 }
281}
282
283func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
284 for _, dc := range u.downstreamConns {
285 f(dc)
286 }
287}
288
289func (u *user) getNetwork(name string) *network {
290 for _, network := range u.networks {
291 if network.Addr == name {
292 return network
293 }
[201]294 if network.Name != "" && network.Name == name {
295 return network
296 }
[101]297 }
298 return nil
299}
300
[313]301func (u *user) getNetworkByID(id int64) *network {
302 for _, net := range u.networks {
303 if net.ID == id {
304 return net
305 }
306 }
307 return nil
308}
309
[101]310func (u *user) run() {
[377]311 defer close(u.done)
312
[101]313 networks, err := u.srv.db.ListNetworks(u.Username)
314 if err != nil {
315 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
316 return
317 }
318
319 for _, record := range networks {
[283]320 record := record
[267]321 channels, err := u.srv.db.ListChannels(record.ID)
322 if err != nil {
323 u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
[359]324 continue
[267]325 }
326
327 network := newNetwork(u, &record, channels)
[101]328 u.networks = append(u.networks, network)
329
330 go network.run()
331 }
[103]332
[165]333 for e := range u.events {
334 switch e := e.(type) {
[196]335 case eventUpstreamConnected:
[198]336 uc := e.uc
[199]337
338 uc.network.conn = uc
339
[198]340 uc.updateAway()
[218]341
342 uc.forEachDownstream(func(dc *downstreamConn) {
[276]343 dc.updateSupportedCaps()
[223]344 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
[296]345
346 dc.updateNick()
[218]347 })
348 uc.network.lastError = nil
[179]349 case eventUpstreamDisconnected:
[313]350 u.handleUpstreamDisconnected(e.uc)
351 case eventUpstreamConnectionError:
352 net := e.net
[199]353
[313]354 stopped := false
355 select {
356 case <-net.stopped:
357 stopped = true
358 default:
[179]359 }
[199]360
[313]361 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
[218]362 net.forEachDownstream(func(dc *downstreamConn) {
[223]363 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
[218]364 })
365 }
366 net.lastError = e.err
367 case eventUpstreamError:
368 uc := e.uc
369
370 uc.forEachDownstream(func(dc *downstreamConn) {
[223]371 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
[218]372 })
373 uc.network.lastError = e.err
[165]374 case eventUpstreamMessage:
375 msg, uc := e.msg, e.uc
[175]376 if uc.isClosed() {
[133]377 uc.logger.Printf("ignoring message on closed connection: %v", msg)
378 break
379 }
[103]380 if err := uc.handleMessage(msg); err != nil {
381 uc.logger.Printf("failed to handle message %q: %v", msg, err)
382 }
[166]383 case eventDownstreamConnected:
384 dc := e.dc
[168]385
386 if err := dc.welcome(); err != nil {
387 dc.logger.Printf("failed to handle new registered connection: %v", err)
388 break
389 }
390
[166]391 u.downstreamConns = append(u.downstreamConns, dc)
[198]392
393 u.forEachUpstream(func(uc *upstreamConn) {
394 uc.updateAway()
395 })
[276]396
397 dc.updateSupportedCaps()
[167]398 case eventDownstreamDisconnected:
399 dc := e.dc
[204]400
[167]401 for i := range u.downstreamConns {
402 if u.downstreamConns[i] == dc {
403 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
404 break
405 }
406 }
[198]407
[253]408 // Save history if we're the last client with this name
409 skipHistory := make(map[*network]bool)
410 u.forEachDownstream(func(conn *downstreamConn) {
411 if dc.clientName == conn.clientName {
412 skipHistory[conn.network] = true
413 }
414 })
415
416 dc.forEachNetwork(func(net *network) {
417 if skipHistory[net] || skipHistory[nil] {
418 return
419 }
420
421 net.offlineClients[dc.clientName] = struct{}{}
422 })
423
[198]424 u.forEachUpstream(func(uc *upstreamConn) {
425 uc.updateAway()
426 })
[165]427 case eventDownstreamMessage:
428 msg, dc := e.msg, e.dc
[133]429 if dc.isClosed() {
430 dc.logger.Printf("ignoring message on closed connection: %v", msg)
431 break
432 }
[103]433 err := dc.handleMessage(msg)
434 if ircErr, ok := err.(ircError); ok {
435 ircErr.Message.Prefix = dc.srv.prefix()
436 dc.SendMessage(ircErr.Message)
437 } else if err != nil {
438 dc.logger.Printf("failed to handle message %q: %v", msg, err)
439 dc.Close()
440 }
[376]441 case eventStop:
442 u.forEachDownstream(func(dc *downstreamConn) {
443 dc.Close()
444 })
445 for _, n := range u.networks {
446 n.stop()
447 }
448 return
[165]449 default:
450 u.srv.Logger.Printf("received unknown event type: %T", e)
[103]451 }
452 }
[101]453}
454
[313]455func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
456 uc.network.conn = nil
457
458 for _, ml := range uc.messageLoggers {
459 if err := ml.Close(); err != nil {
460 uc.logger.Printf("failed to close message logger: %v", err)
461 }
462 }
463
464 uc.endPendingLISTs(true)
465
466 uc.forEachDownstream(func(dc *downstreamConn) {
467 dc.updateSupportedCaps()
468 })
469
470 if uc.network.lastError == nil {
471 uc.forEachDownstream(func(dc *downstreamConn) {
472 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
473 })
474 }
475}
476
477func (u *user) addNetwork(network *network) {
478 u.networks = append(u.networks, network)
479 go network.run()
480}
481
482func (u *user) removeNetwork(network *network) {
483 network.stop()
484
485 u.forEachDownstream(func(dc *downstreamConn) {
486 if dc.network != nil && dc.network == network {
487 dc.Close()
488 }
489 })
490
491 for i, net := range u.networks {
492 if net == network {
493 u.networks = append(u.networks[:i], u.networks[i+1:]...)
494 return
495 }
496 }
497
498 panic("tried to remove a non-existing network")
499}
500
501func (u *user) createNetwork(record *Network) (*network, error) {
502 if record.ID != 0 {
[144]503 panic("tried creating an already-existing network")
504 }
505
[313]506 network := newNetwork(u, record, nil)
[101]507 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
508 if err != nil {
509 return nil, err
510 }
[144]511
[313]512 u.addNetwork(network)
[144]513
[101]514 return network, nil
515}
[202]516
[313]517func (u *user) updateNetwork(record *Network) (*network, error) {
518 if record.ID == 0 {
519 panic("tried updating a new network")
520 }
[202]521
[313]522 network := u.getNetworkByID(record.ID)
523 if network == nil {
524 panic("tried updating a non-existing network")
525 }
526
527 if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
528 return nil, err
529 }
530
531 // Most network changes require us to re-connect to the upstream server
532
533 channels := make([]Channel, 0, len(network.channels))
534 for _, ch := range network.channels {
535 channels = append(channels, *ch)
536 }
537
538 updatedNetwork := newNetwork(u, record, channels)
539
540 // If we're currently connected, disconnect and perform the necessary
541 // bookkeeping
542 if network.conn != nil {
543 network.stop()
544 // Note: this will set network.conn to nil
545 u.handleUpstreamDisconnected(network.conn)
546 }
547
548 // Patch downstream connections to use our fresh updated network
549 u.forEachDownstream(func(dc *downstreamConn) {
550 if dc.network != nil && dc.network == network {
551 dc.network = updatedNetwork
[202]552 }
[313]553 })
[202]554
[313]555 // We need to remove the network after patching downstream connections,
556 // otherwise they'll get closed
557 u.removeNetwork(network)
[202]558
[313]559 // This will re-connect to the upstream server
560 u.addNetwork(updatedNetwork)
561
562 return updatedNetwork, nil
563}
564
565func (u *user) deleteNetwork(id int64) error {
566 network := u.getNetworkByID(id)
567 if network == nil {
568 panic("tried deleting a non-existing network")
[202]569 }
570
[313]571 if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
572 return err
573 }
574
575 u.removeNetwork(network)
576 return nil
[202]577}
[252]578
579func (u *user) updatePassword(hashed string) error {
580 u.User.Password = hashed
[324]581 return u.srv.db.StoreUser(&u.User)
[252]582}
[376]583
584func (u *user) stop() {
585 u.events <- eventStop{}
[377]586 <-u.done
[376]587}
Note: See TracBrowser for help on using the repository browser.