source: code/trunk/user.go@ 415

Last change on this file since 415 was 409, checked in by contact, 5 years ago

Nuke in-memory ring buffer

Instead, always read chat history from logs. Unify the implicit chat
history (pushing history to clients) and explicit chat history
(via the CHATHISTORY command).

Instead of keeping track of ring buffer cursors for each client, use
message IDs.

If necessary, the ring buffer could be re-introduced behind a
common MessageStore interface (could be useful when on-disk logs are
disabled).

References: https://todo.sr.ht/~emersion/soju/80

File size: 13.3 KB
Line 
1package soju
2
3import (
4 "crypto/sha256"
5 "encoding/binary"
6 "encoding/hex"
7 "fmt"
8 "time"
9
10 "gopkg.in/irc.v3"
11)
12
13type event interface{}
14
15type eventUpstreamMessage struct {
16 msg *irc.Message
17 uc *upstreamConn
18}
19
20type eventUpstreamConnectionError struct {
21 net *network
22 err error
23}
24
25type eventUpstreamConnected struct {
26 uc *upstreamConn
27}
28
29type eventUpstreamDisconnected struct {
30 uc *upstreamConn
31}
32
33type eventUpstreamError struct {
34 uc *upstreamConn
35 err error
36}
37
38type eventDownstreamMessage struct {
39 msg *irc.Message
40 dc *downstreamConn
41}
42
43type eventDownstreamConnected struct {
44 dc *downstreamConn
45}
46
47type eventDownstreamDisconnected struct {
48 dc *downstreamConn
49}
50
51type eventStop struct{}
52
53type networkHistory struct {
54 clients map[string]string // indexed by client name
55}
56
57type network struct {
58 Network
59 user *user
60 stopped chan struct{}
61
62 conn *upstreamConn
63 channels map[string]*Channel
64 history map[string]*networkHistory // indexed by entity
65 offlineClients map[string]struct{} // indexed by client name
66 lastError error
67}
68
69func newNetwork(user *user, record *Network, channels []Channel) *network {
70 m := make(map[string]*Channel, len(channels))
71 for _, ch := range channels {
72 ch := ch
73 m[ch.Name] = &ch
74 }
75
76 return &network{
77 Network: *record,
78 user: user,
79 stopped: make(chan struct{}),
80 channels: m,
81 history: make(map[string]*networkHistory),
82 offlineClients: make(map[string]struct{}),
83 }
84}
85
86func (net *network) forEachDownstream(f func(*downstreamConn)) {
87 net.user.forEachDownstream(func(dc *downstreamConn) {
88 if dc.network != nil && dc.network != net {
89 return
90 }
91 f(dc)
92 })
93}
94
95func (net *network) isStopped() bool {
96 select {
97 case <-net.stopped:
98 return true
99 default:
100 return false
101 }
102}
103
104func userIdent(u *User) string {
105 // The ident is a string we will send to upstream servers in clear-text.
106 // For privacy reasons, make sure it doesn't expose any meaningful user
107 // metadata. We just use the base64-encoded hashed ID, so that people don't
108 // start relying on the string being an integer or following a pattern.
109 var b [64]byte
110 binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
111 h := sha256.Sum256(b[:])
112 return hex.EncodeToString(h[:16])
113}
114
115func (net *network) run() {
116 var lastTry time.Time
117 for {
118 if net.isStopped() {
119 return
120 }
121
122 if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
123 delay := retryConnectDelay - dur
124 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
125 time.Sleep(delay)
126 }
127 lastTry = time.Now()
128
129 uc, err := connectToUpstream(net)
130 if err != nil {
131 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
132 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
133 continue
134 }
135
136 if net.user.srv.Identd != nil {
137 net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
138 }
139
140 uc.register()
141 if err := uc.runUntilRegistered(); err != nil {
142 text := err.Error()
143 if regErr, ok := err.(registrationError); ok {
144 text = string(regErr)
145 }
146 uc.logger.Printf("failed to register: %v", text)
147 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
148 uc.Close()
149 continue
150 }
151
152 // TODO: this is racy with net.stopped. If the network is stopped
153 // before the user goroutine receives eventUpstreamConnected, the
154 // connection won't be closed.
155 net.user.events <- eventUpstreamConnected{uc}
156 if err := uc.readMessages(net.user.events); err != nil {
157 uc.logger.Printf("failed to handle messages: %v", err)
158 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
159 }
160 uc.Close()
161 net.user.events <- eventUpstreamDisconnected{uc}
162
163 if net.user.srv.Identd != nil {
164 net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
165 }
166 }
167}
168
169func (net *network) stop() {
170 if !net.isStopped() {
171 close(net.stopped)
172 }
173
174 if net.conn != nil {
175 net.conn.Close()
176 }
177}
178
179func (net *network) createUpdateChannel(ch *Channel) error {
180 if current, ok := net.channels[ch.Name]; ok {
181 ch.ID = current.ID // update channel if it already exists
182 }
183 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
184 return err
185 }
186 prev := net.channels[ch.Name]
187 net.channels[ch.Name] = ch
188
189 if prev != nil && prev.Detached != ch.Detached {
190 history := net.history[ch.Name]
191 if ch.Detached {
192 net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
193 net.forEachDownstream(func(dc *downstreamConn) {
194 net.offlineClients[dc.clientName] = struct{}{}
195
196 dc.SendMessage(&irc.Message{
197 Prefix: dc.prefix(),
198 Command: "PART",
199 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
200 })
201 })
202 } else {
203 net.user.srv.Logger.Printf("network %q: attaching channel %q", net.GetName(), ch.Name)
204
205 var uch *upstreamChannel
206 if net.conn != nil {
207 uch = net.conn.channels[ch.Name]
208 }
209
210 net.forEachDownstream(func(dc *downstreamConn) {
211 dc.SendMessage(&irc.Message{
212 Prefix: dc.prefix(),
213 Command: "JOIN",
214 Params: []string{dc.marshalEntity(net, ch.Name)},
215 })
216
217 if uch != nil {
218 forwardChannel(dc, uch)
219 }
220
221 if history != nil {
222 dc.sendNetworkHistory(net)
223 }
224 })
225 }
226 }
227
228 return nil
229}
230
231func (net *network) deleteChannel(name string) error {
232 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
233 return err
234 }
235 delete(net.channels, name)
236 return nil
237}
238
239type user struct {
240 User
241 srv *Server
242
243 events chan event
244 done chan struct{}
245
246 networks []*network
247 downstreamConns []*downstreamConn
248
249 // LIST commands in progress
250 pendingLISTs []pendingLIST
251}
252
253type pendingLIST struct {
254 downstreamID uint64
255 // list of per-upstream LIST commands not yet sent or completed
256 pendingCommands map[int64]*irc.Message
257}
258
259func newUser(srv *Server, record *User) *user {
260 return &user{
261 User: *record,
262 srv: srv,
263 events: make(chan event, 64),
264 done: make(chan struct{}),
265 }
266}
267
268func (u *user) forEachNetwork(f func(*network)) {
269 for _, network := range u.networks {
270 f(network)
271 }
272}
273
274func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
275 for _, network := range u.networks {
276 if network.conn == nil {
277 continue
278 }
279 f(network.conn)
280 }
281}
282
283func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
284 for _, dc := range u.downstreamConns {
285 f(dc)
286 }
287}
288
289func (u *user) getNetwork(name string) *network {
290 for _, network := range u.networks {
291 if network.Addr == name {
292 return network
293 }
294 if network.Name != "" && network.Name == name {
295 return network
296 }
297 }
298 return nil
299}
300
301func (u *user) getNetworkByID(id int64) *network {
302 for _, net := range u.networks {
303 if net.ID == id {
304 return net
305 }
306 }
307 return nil
308}
309
310func (u *user) run() {
311 defer close(u.done)
312
313 networks, err := u.srv.db.ListNetworks(u.Username)
314 if err != nil {
315 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
316 return
317 }
318
319 for _, record := range networks {
320 record := record
321 channels, err := u.srv.db.ListChannels(record.ID)
322 if err != nil {
323 u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
324 continue
325 }
326
327 network := newNetwork(u, &record, channels)
328 u.networks = append(u.networks, network)
329
330 go network.run()
331 }
332
333 for e := range u.events {
334 switch e := e.(type) {
335 case eventUpstreamConnected:
336 uc := e.uc
337
338 uc.network.conn = uc
339
340 uc.updateAway()
341
342 uc.forEachDownstream(func(dc *downstreamConn) {
343 dc.updateSupportedCaps()
344 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
345
346 dc.updateNick()
347 })
348 uc.network.lastError = nil
349 case eventUpstreamDisconnected:
350 u.handleUpstreamDisconnected(e.uc)
351 case eventUpstreamConnectionError:
352 net := e.net
353
354 stopped := false
355 select {
356 case <-net.stopped:
357 stopped = true
358 default:
359 }
360
361 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
362 net.forEachDownstream(func(dc *downstreamConn) {
363 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
364 })
365 }
366 net.lastError = e.err
367 case eventUpstreamError:
368 uc := e.uc
369
370 uc.forEachDownstream(func(dc *downstreamConn) {
371 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
372 })
373 uc.network.lastError = e.err
374 case eventUpstreamMessage:
375 msg, uc := e.msg, e.uc
376 if uc.isClosed() {
377 uc.logger.Printf("ignoring message on closed connection: %v", msg)
378 break
379 }
380 if err := uc.handleMessage(msg); err != nil {
381 uc.logger.Printf("failed to handle message %q: %v", msg, err)
382 }
383 case eventDownstreamConnected:
384 dc := e.dc
385
386 if err := dc.welcome(); err != nil {
387 dc.logger.Printf("failed to handle new registered connection: %v", err)
388 break
389 }
390
391 u.downstreamConns = append(u.downstreamConns, dc)
392
393 u.forEachUpstream(func(uc *upstreamConn) {
394 uc.updateAway()
395 })
396
397 dc.updateSupportedCaps()
398 case eventDownstreamDisconnected:
399 dc := e.dc
400
401 for i := range u.downstreamConns {
402 if u.downstreamConns[i] == dc {
403 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
404 break
405 }
406 }
407
408 // Save history if we're the last client with this name
409 skipHistory := make(map[*network]bool)
410 u.forEachDownstream(func(conn *downstreamConn) {
411 if dc.clientName == conn.clientName {
412 skipHistory[conn.network] = true
413 }
414 })
415
416 dc.forEachNetwork(func(net *network) {
417 if skipHistory[net] || skipHistory[nil] {
418 return
419 }
420
421 net.offlineClients[dc.clientName] = struct{}{}
422 })
423
424 u.forEachUpstream(func(uc *upstreamConn) {
425 uc.updateAway()
426 })
427 case eventDownstreamMessage:
428 msg, dc := e.msg, e.dc
429 if dc.isClosed() {
430 dc.logger.Printf("ignoring message on closed connection: %v", msg)
431 break
432 }
433 err := dc.handleMessage(msg)
434 if ircErr, ok := err.(ircError); ok {
435 ircErr.Message.Prefix = dc.srv.prefix()
436 dc.SendMessage(ircErr.Message)
437 } else if err != nil {
438 dc.logger.Printf("failed to handle message %q: %v", msg, err)
439 dc.Close()
440 }
441 case eventStop:
442 u.forEachDownstream(func(dc *downstreamConn) {
443 dc.Close()
444 })
445 for _, n := range u.networks {
446 n.stop()
447 }
448 return
449 default:
450 u.srv.Logger.Printf("received unknown event type: %T", e)
451 }
452 }
453}
454
455func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
456 uc.network.conn = nil
457
458 for _, ml := range uc.messageLoggers {
459 if err := ml.Close(); err != nil {
460 uc.logger.Printf("failed to close message logger: %v", err)
461 }
462 }
463
464 uc.endPendingLISTs(true)
465
466 uc.forEachDownstream(func(dc *downstreamConn) {
467 dc.updateSupportedCaps()
468 })
469
470 if uc.network.lastError == nil {
471 uc.forEachDownstream(func(dc *downstreamConn) {
472 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
473 })
474 }
475}
476
477func (u *user) addNetwork(network *network) {
478 u.networks = append(u.networks, network)
479 go network.run()
480}
481
482func (u *user) removeNetwork(network *network) {
483 network.stop()
484
485 u.forEachDownstream(func(dc *downstreamConn) {
486 if dc.network != nil && dc.network == network {
487 dc.Close()
488 }
489 })
490
491 for i, net := range u.networks {
492 if net == network {
493 u.networks = append(u.networks[:i], u.networks[i+1:]...)
494 return
495 }
496 }
497
498 panic("tried to remove a non-existing network")
499}
500
501func (u *user) createNetwork(record *Network) (*network, error) {
502 if record.ID != 0 {
503 panic("tried creating an already-existing network")
504 }
505
506 network := newNetwork(u, record, nil)
507 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
508 if err != nil {
509 return nil, err
510 }
511
512 u.addNetwork(network)
513
514 return network, nil
515}
516
517func (u *user) updateNetwork(record *Network) (*network, error) {
518 if record.ID == 0 {
519 panic("tried updating a new network")
520 }
521
522 network := u.getNetworkByID(record.ID)
523 if network == nil {
524 panic("tried updating a non-existing network")
525 }
526
527 if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
528 return nil, err
529 }
530
531 // Most network changes require us to re-connect to the upstream server
532
533 channels := make([]Channel, 0, len(network.channels))
534 for _, ch := range network.channels {
535 channels = append(channels, *ch)
536 }
537
538 updatedNetwork := newNetwork(u, record, channels)
539
540 // If we're currently connected, disconnect and perform the necessary
541 // bookkeeping
542 if network.conn != nil {
543 network.stop()
544 // Note: this will set network.conn to nil
545 u.handleUpstreamDisconnected(network.conn)
546 }
547
548 // Patch downstream connections to use our fresh updated network
549 u.forEachDownstream(func(dc *downstreamConn) {
550 if dc.network != nil && dc.network == network {
551 dc.network = updatedNetwork
552 }
553 })
554
555 // We need to remove the network after patching downstream connections,
556 // otherwise they'll get closed
557 u.removeNetwork(network)
558
559 // This will re-connect to the upstream server
560 u.addNetwork(updatedNetwork)
561
562 return updatedNetwork, nil
563}
564
565func (u *user) deleteNetwork(id int64) error {
566 network := u.getNetworkByID(id)
567 if network == nil {
568 panic("tried deleting a non-existing network")
569 }
570
571 if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
572 return err
573 }
574
575 u.removeNetwork(network)
576 return nil
577}
578
579func (u *user) updatePassword(hashed string) error {
580 u.User.Password = hashed
581 return u.srv.db.StoreUser(&u.User)
582}
583
584func (u *user) stop() {
585 u.events <- eventStop{}
586 <-u.done
587}
Note: See TracBrowser for help on using the repository browser.