source: code/trunk/user.go@ 352

Last change on this file since 352 was 324, checked in by contact, 5 years ago

Introduce User.Created

For Network and Channel, the database only needed to define one Store
operation to create or update a record. However, since User has no ID
field, we couldn't have a single StoreUser function like the other types;
instead we had CreateUser and UpdatePassword. As new User fields get added
(e.g. the upcoming Admin flag), this approach isn't sustainable.

We could have CreateUser and UpdateUser, but this wouldn't be consistent
with the other types. Instead, introduce User.Created, which indicates
whether the record is already stored in the DB. A new StoreUser function
can use it to decide whether to UPDATE or INSERT, without relying on SQL
constraints and an "INSERT OR UPDATE" (upsert) statement.

The ListUsers and GetUser functions set User.Created to true.

File size: 12.5 KB
RevLine 
[101]1package soju
2
3import (
[218]4 "fmt"
[101]5 "time"
[103]6
7 "gopkg.in/irc.v3"
[101]8)
9
[165]10type event interface{}
11
12type eventUpstreamMessage struct {
[103]13 msg *irc.Message
14 uc *upstreamConn
15}
16
[218]17type eventUpstreamConnectionError struct {
18 net *network
19 err error
20}
21
[196]22type eventUpstreamConnected struct {
23 uc *upstreamConn
24}
25
[179]26type eventUpstreamDisconnected struct {
27 uc *upstreamConn
28}
29
[218]30type eventUpstreamError struct {
31 uc *upstreamConn
32 err error
33}
34
[165]35type eventDownstreamMessage struct {
[103]36 msg *irc.Message
37 dc *downstreamConn
38}
39
[166]40type eventDownstreamConnected struct {
41 dc *downstreamConn
42}
43
[167]44type eventDownstreamDisconnected struct {
45 dc *downstreamConn
46}
47
[253]48type networkHistory struct {
49 offlineClients map[string]uint64 // indexed by client name
50 ring *Ring // can be nil if there are no offline clients
51}
52
[101]53type network struct {
54 Network
[202]55 user *user
56 stopped chan struct{}
[131]57
[253]58 conn *upstreamConn
[267]59 channels map[string]*Channel
[253]60 history map[string]*networkHistory // indexed by entity
61 offlineClients map[string]struct{} // indexed by client name
62 lastError error
[101]63}
64
[267]65func newNetwork(user *user, record *Network, channels []Channel) *network {
66 m := make(map[string]*Channel, len(channels))
67 for _, ch := range channels {
[283]68 ch := ch
[267]69 m[ch.Name] = &ch
70 }
71
[101]72 return &network{
[253]73 Network: *record,
74 user: user,
75 stopped: make(chan struct{}),
[267]76 channels: m,
[253]77 history: make(map[string]*networkHistory),
78 offlineClients: make(map[string]struct{}),
[101]79 }
80}
81
[218]82func (net *network) forEachDownstream(f func(*downstreamConn)) {
83 net.user.forEachDownstream(func(dc *downstreamConn) {
84 if dc.network != nil && dc.network != net {
85 return
86 }
87 f(dc)
88 })
89}
90
[311]91func (net *network) isStopped() bool {
92 select {
93 case <-net.stopped:
94 return true
95 default:
96 return false
97 }
98}
99
[101]100func (net *network) run() {
101 var lastTry time.Time
102 for {
[311]103 if net.isStopped() {
[202]104 return
105 }
106
[101]107 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
108 delay := retryConnectMinDelay - dur
109 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
110 time.Sleep(delay)
111 }
112 lastTry = time.Now()
113
114 uc, err := connectToUpstream(net)
115 if err != nil {
116 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]117 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[101]118 continue
119 }
120
121 uc.register()
[197]122 if err := uc.runUntilRegistered(); err != nil {
123 uc.logger.Printf("failed to register: %v", err)
[218]124 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
[197]125 uc.Close()
126 continue
127 }
[101]128
[311]129 // TODO: this is racy with net.stopped. If the network is stopped
130 // before the user goroutine receives eventUpstreamConnected, the
131 // connection won't be closed.
[196]132 net.user.events <- eventUpstreamConnected{uc}
[165]133 if err := uc.readMessages(net.user.events); err != nil {
[101]134 uc.logger.Printf("failed to handle messages: %v", err)
[218]135 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]136 }
137 uc.Close()
[179]138 net.user.events <- eventUpstreamDisconnected{uc}
[101]139 }
140}
141
[309]142func (net *network) stop() {
[311]143 if !net.isStopped() {
[202]144 close(net.stopped)
145 }
146
[279]147 if net.conn != nil {
148 net.conn.Close()
[202]149 }
150}
151
[222]152func (net *network) createUpdateChannel(ch *Channel) error {
[267]153 if current, ok := net.channels[ch.Name]; ok {
154 ch.ID = current.ID // update channel if it already exists
155 }
156 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
[222]157 return err
158 }
[284]159 prev := net.channels[ch.Name]
[267]160 net.channels[ch.Name] = ch
[284]161
162 if prev != nil && prev.Detached != ch.Detached {
163 history := net.history[ch.Name]
164 if ch.Detached {
165 net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
166 net.forEachDownstream(func(dc *downstreamConn) {
167 net.offlineClients[dc.clientName] = struct{}{}
168 if history != nil {
169 history.offlineClients[dc.clientName] = history.ring.Cur()
170 }
171
172 dc.SendMessage(&irc.Message{
173 Prefix: dc.prefix(),
174 Command: "PART",
175 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
176 })
177 })
178 } else {
179 net.user.srv.Logger.Printf("network %q: attaching channel %q", net.GetName(), ch.Name)
180
181 var uch *upstreamChannel
182 if net.conn != nil {
183 uch = net.conn.channels[ch.Name]
184 }
185
186 net.forEachDownstream(func(dc *downstreamConn) {
187 dc.SendMessage(&irc.Message{
188 Prefix: dc.prefix(),
189 Command: "JOIN",
190 Params: []string{dc.marshalEntity(net, ch.Name)},
191 })
192
193 if uch != nil {
194 forwardChannel(dc, uch)
195 }
196
197 if history != nil {
198 dc.sendNetworkHistory(net)
199 }
200 })
201 }
202 }
203
[267]204 return nil
[222]205}
206
207func (net *network) deleteChannel(name string) error {
[267]208 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
209 return err
210 }
211 delete(net.channels, name)
212 return nil
[222]213}
214
[101]215type user struct {
216 User
217 srv *Server
218
[165]219 events chan event
[103]220
[101]221 networks []*network
222 downstreamConns []*downstreamConn
[177]223
224 // LIST commands in progress
[179]225 pendingLISTs []pendingLIST
[101]226}
227
[177]228type pendingLIST struct {
229 downstreamID uint64
230 // list of per-upstream LIST commands not yet sent or completed
231 pendingCommands map[int64]*irc.Message
232}
233
[101]234func newUser(srv *Server, record *User) *user {
235 return &user{
[165]236 User: *record,
237 srv: srv,
238 events: make(chan event, 64),
[101]239 }
240}
241
242func (u *user) forEachNetwork(f func(*network)) {
243 for _, network := range u.networks {
244 f(network)
245 }
246}
247
248func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
249 for _, network := range u.networks {
[279]250 if network.conn == nil {
[101]251 continue
252 }
[279]253 f(network.conn)
[101]254 }
255}
256
257func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
258 for _, dc := range u.downstreamConns {
259 f(dc)
260 }
261}
262
263func (u *user) getNetwork(name string) *network {
264 for _, network := range u.networks {
265 if network.Addr == name {
266 return network
267 }
[201]268 if network.Name != "" && network.Name == name {
269 return network
270 }
[101]271 }
272 return nil
273}
274
[313]275func (u *user) getNetworkByID(id int64) *network {
276 for _, net := range u.networks {
277 if net.ID == id {
278 return net
279 }
280 }
281 return nil
282}
283
[101]284func (u *user) run() {
285 networks, err := u.srv.db.ListNetworks(u.Username)
286 if err != nil {
287 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
288 return
289 }
290
291 for _, record := range networks {
[283]292 record := record
[267]293 channels, err := u.srv.db.ListChannels(record.ID)
294 if err != nil {
295 u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
296 }
297
298 network := newNetwork(u, &record, channels)
[101]299 u.networks = append(u.networks, network)
300
301 go network.run()
302 }
[103]303
[165]304 for e := range u.events {
305 switch e := e.(type) {
[196]306 case eventUpstreamConnected:
[198]307 uc := e.uc
[199]308
309 uc.network.conn = uc
310
[198]311 uc.updateAway()
[218]312
313 uc.forEachDownstream(func(dc *downstreamConn) {
[276]314 dc.updateSupportedCaps()
[223]315 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
[296]316
317 dc.updateNick()
[218]318 })
319 uc.network.lastError = nil
[179]320 case eventUpstreamDisconnected:
[313]321 u.handleUpstreamDisconnected(e.uc)
322 case eventUpstreamConnectionError:
323 net := e.net
[199]324
[313]325 stopped := false
326 select {
327 case <-net.stopped:
328 stopped = true
329 default:
[179]330 }
[199]331
[313]332 if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
[218]333 net.forEachDownstream(func(dc *downstreamConn) {
[223]334 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
[218]335 })
336 }
337 net.lastError = e.err
338 case eventUpstreamError:
339 uc := e.uc
340
341 uc.forEachDownstream(func(dc *downstreamConn) {
[223]342 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
[218]343 })
344 uc.network.lastError = e.err
[165]345 case eventUpstreamMessage:
346 msg, uc := e.msg, e.uc
[175]347 if uc.isClosed() {
[133]348 uc.logger.Printf("ignoring message on closed connection: %v", msg)
349 break
350 }
[103]351 if err := uc.handleMessage(msg); err != nil {
352 uc.logger.Printf("failed to handle message %q: %v", msg, err)
353 }
[166]354 case eventDownstreamConnected:
355 dc := e.dc
[168]356
357 if err := dc.welcome(); err != nil {
358 dc.logger.Printf("failed to handle new registered connection: %v", err)
359 break
360 }
361
[166]362 u.downstreamConns = append(u.downstreamConns, dc)
[198]363
364 u.forEachUpstream(func(uc *upstreamConn) {
365 uc.updateAway()
366 })
[276]367
368 dc.updateSupportedCaps()
[167]369 case eventDownstreamDisconnected:
370 dc := e.dc
[204]371
[167]372 for i := range u.downstreamConns {
373 if u.downstreamConns[i] == dc {
374 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
375 break
376 }
377 }
[198]378
[253]379 // Save history if we're the last client with this name
380 skipHistory := make(map[*network]bool)
381 u.forEachDownstream(func(conn *downstreamConn) {
382 if dc.clientName == conn.clientName {
383 skipHistory[conn.network] = true
384 }
385 })
386
387 dc.forEachNetwork(func(net *network) {
388 if skipHistory[net] || skipHistory[nil] {
389 return
390 }
391
392 net.offlineClients[dc.clientName] = struct{}{}
[284]393 for target, history := range net.history {
394 if ch, ok := net.channels[target]; ok && ch.Detached {
395 continue
396 }
[253]397 history.offlineClients[dc.clientName] = history.ring.Cur()
398 }
399 })
400
[198]401 u.forEachUpstream(func(uc *upstreamConn) {
402 uc.updateAway()
403 })
[165]404 case eventDownstreamMessage:
405 msg, dc := e.msg, e.dc
[133]406 if dc.isClosed() {
407 dc.logger.Printf("ignoring message on closed connection: %v", msg)
408 break
409 }
[103]410 err := dc.handleMessage(msg)
411 if ircErr, ok := err.(ircError); ok {
412 ircErr.Message.Prefix = dc.srv.prefix()
413 dc.SendMessage(ircErr.Message)
414 } else if err != nil {
415 dc.logger.Printf("failed to handle message %q: %v", msg, err)
416 dc.Close()
417 }
[165]418 default:
419 u.srv.Logger.Printf("received unknown event type: %T", e)
[103]420 }
421 }
[101]422}
423
[313]424func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
425 uc.network.conn = nil
426
427 for _, ml := range uc.messageLoggers {
428 if err := ml.Close(); err != nil {
429 uc.logger.Printf("failed to close message logger: %v", err)
430 }
431 }
432
433 uc.endPendingLISTs(true)
434
435 uc.forEachDownstream(func(dc *downstreamConn) {
436 dc.updateSupportedCaps()
437 })
438
439 if uc.network.lastError == nil {
440 uc.forEachDownstream(func(dc *downstreamConn) {
441 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
442 })
443 }
444}
445
446func (u *user) addNetwork(network *network) {
447 u.networks = append(u.networks, network)
448 go network.run()
449}
450
451func (u *user) removeNetwork(network *network) {
452 network.stop()
453
454 u.forEachDownstream(func(dc *downstreamConn) {
455 if dc.network != nil && dc.network == network {
456 dc.Close()
457 }
458 })
459
460 for i, net := range u.networks {
461 if net == network {
462 u.networks = append(u.networks[:i], u.networks[i+1:]...)
463 return
464 }
465 }
466
467 panic("tried to remove a non-existing network")
468}
469
470func (u *user) createNetwork(record *Network) (*network, error) {
471 if record.ID != 0 {
[144]472 panic("tried creating an already-existing network")
473 }
474
[313]475 network := newNetwork(u, record, nil)
[101]476 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
477 if err != nil {
478 return nil, err
479 }
[144]480
[313]481 u.addNetwork(network)
[144]482
[101]483 return network, nil
484}
[202]485
[313]486func (u *user) updateNetwork(record *Network) (*network, error) {
487 if record.ID == 0 {
488 panic("tried updating a new network")
489 }
[202]490
[313]491 network := u.getNetworkByID(record.ID)
492 if network == nil {
493 panic("tried updating a non-existing network")
494 }
495
496 if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
497 return nil, err
498 }
499
500 // Most network changes require us to re-connect to the upstream server
501
502 channels := make([]Channel, 0, len(network.channels))
503 for _, ch := range network.channels {
504 channels = append(channels, *ch)
505 }
506
507 updatedNetwork := newNetwork(u, record, channels)
508
509 // If we're currently connected, disconnect and perform the necessary
510 // bookkeeping
511 if network.conn != nil {
512 network.stop()
513 // Note: this will set network.conn to nil
514 u.handleUpstreamDisconnected(network.conn)
515 }
516
517 // Patch downstream connections to use our fresh updated network
518 u.forEachDownstream(func(dc *downstreamConn) {
519 if dc.network != nil && dc.network == network {
520 dc.network = updatedNetwork
[202]521 }
[313]522 })
[202]523
[313]524 // We need to remove the network after patching downstream connections,
525 // otherwise they'll get closed
526 u.removeNetwork(network)
[202]527
[313]528 // This will re-connect to the upstream server
529 u.addNetwork(updatedNetwork)
530
531 return updatedNetwork, nil
532}
533
534func (u *user) deleteNetwork(id int64) error {
535 network := u.getNetworkByID(id)
536 if network == nil {
537 panic("tried deleting a non-existing network")
[202]538 }
539
[313]540 if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
541 return err
542 }
543
544 u.removeNetwork(network)
545 return nil
[202]546}
[252]547
548func (u *user) updatePassword(hashed string) error {
549 u.User.Password = hashed
[324]550 return u.srv.db.StoreUser(&u.User)
[252]551}
Note: See TracBrowser for help on using the repository browser.