source: code/trunk/user.go@ 307

Last change on this file since 307 was 296, checked in by delthas, 5 years ago

Update downstream nicks in single-server mode and after NICK

Previously, the downstream nick was never changed, even when the
downstream sent a NICK message or was in single-server mode with a
different nick.

This adds support for updating the downstream nick in the following
cases:

  • when a downstream sends NICK
  • additionally, in single-server mode:
    • when a downstream connects and its single network is connected
    • when an upstream connects
    • when an upstream sends NICK
File size: 10.5 KB
Rev Line
[101]1package soju
2
3import (
[218]4 "fmt"
[101]5 "time"
[103]6
7 "gopkg.in/irc.v3"
[101]8)
9
// event is the type of the values sent on a user's events channel. user.run
// dispatches on the concrete type of each received event.
type event interface{}
11
12type eventUpstreamMessage struct {
[103]13 msg *irc.Message
14 uc *upstreamConn
15}
16
[218]17type eventUpstreamConnectionError struct {
18 net *network
19 err error
20}
21
[196]22type eventUpstreamConnected struct {
23 uc *upstreamConn
24}
25
[179]26type eventUpstreamDisconnected struct {
27 uc *upstreamConn
28}
29
[218]30type eventUpstreamError struct {
31 uc *upstreamConn
32 err error
33}
34
[165]35type eventDownstreamMessage struct {
[103]36 msg *irc.Message
37 dc *downstreamConn
38}
39
[166]40type eventDownstreamConnected struct {
41 dc *downstreamConn
42}
43
[167]44type eventDownstreamDisconnected struct {
45 dc *downstreamConn
46}
47
[253]48type networkHistory struct {
49 offlineClients map[string]uint64 // indexed by client name
50 ring *Ring // can be nil if there are no offline clients
51}
52
[101]53type network struct {
54 Network
[202]55 user *user
56 stopped chan struct{}
[131]57
[253]58 conn *upstreamConn
[267]59 channels map[string]*Channel
[253]60 history map[string]*networkHistory // indexed by entity
61 offlineClients map[string]struct{} // indexed by client name
62 lastError error
[101]63}
64
[267]65func newNetwork(user *user, record *Network, channels []Channel) *network {
66 m := make(map[string]*Channel, len(channels))
67 for _, ch := range channels {
[283]68 ch := ch
[267]69 m[ch.Name] = &ch
70 }
71
[101]72 return &network{
[253]73 Network: *record,
74 user: user,
75 stopped: make(chan struct{}),
[267]76 channels: m,
[253]77 history: make(map[string]*networkHistory),
78 offlineClients: make(map[string]struct{}),
[101]79 }
80}
81
[218]82func (net *network) forEachDownstream(f func(*downstreamConn)) {
83 net.user.forEachDownstream(func(dc *downstreamConn) {
84 if dc.network != nil && dc.network != net {
85 return
86 }
87 f(dc)
88 })
89}
90
[101]91func (net *network) run() {
92 var lastTry time.Time
93 for {
[202]94 select {
95 case <-net.stopped:
96 return
97 default:
98 // This space is intentionally left blank
99 }
100
[101]101 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
102 delay := retryConnectMinDelay - dur
103 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
104 time.Sleep(delay)
105 }
106 lastTry = time.Now()
107
108 uc, err := connectToUpstream(net)
109 if err != nil {
110 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]111 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[101]112 continue
113 }
114
115 uc.register()
[197]116 if err := uc.runUntilRegistered(); err != nil {
117 uc.logger.Printf("failed to register: %v", err)
[218]118 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
[197]119 uc.Close()
120 continue
121 }
[101]122
[196]123 net.user.events <- eventUpstreamConnected{uc}
[165]124 if err := uc.readMessages(net.user.events); err != nil {
[101]125 uc.logger.Printf("failed to handle messages: %v", err)
[218]126 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]127 }
128 uc.Close()
[179]129 net.user.events <- eventUpstreamDisconnected{uc}
[101]130 }
131}
132
[202]133func (net *network) Stop() {
134 select {
135 case <-net.stopped:
136 return
137 default:
138 close(net.stopped)
139 }
140
[279]141 if net.conn != nil {
142 net.conn.Close()
[202]143 }
144}
145
[222]146func (net *network) createUpdateChannel(ch *Channel) error {
[267]147 if current, ok := net.channels[ch.Name]; ok {
148 ch.ID = current.ID // update channel if it already exists
149 }
150 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
[222]151 return err
152 }
[284]153 prev := net.channels[ch.Name]
[267]154 net.channels[ch.Name] = ch
[284]155
156 if prev != nil && prev.Detached != ch.Detached {
157 history := net.history[ch.Name]
158 if ch.Detached {
159 net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
160 net.forEachDownstream(func(dc *downstreamConn) {
161 net.offlineClients[dc.clientName] = struct{}{}
162 if history != nil {
163 history.offlineClients[dc.clientName] = history.ring.Cur()
164 }
165
166 dc.SendMessage(&irc.Message{
167 Prefix: dc.prefix(),
168 Command: "PART",
169 Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
170 })
171 })
172 } else {
173 net.user.srv.Logger.Printf("network %q: attaching channel %q", net.GetName(), ch.Name)
174
175 var uch *upstreamChannel
176 if net.conn != nil {
177 uch = net.conn.channels[ch.Name]
178 }
179
180 net.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(&irc.Message{
182 Prefix: dc.prefix(),
183 Command: "JOIN",
184 Params: []string{dc.marshalEntity(net, ch.Name)},
185 })
186
187 if uch != nil {
188 forwardChannel(dc, uch)
189 }
190
191 if history != nil {
192 dc.sendNetworkHistory(net)
193 }
194 })
195 }
196 }
197
[267]198 return nil
[222]199}
200
201func (net *network) deleteChannel(name string) error {
[267]202 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
203 return err
204 }
205 delete(net.channels, name)
206 return nil
[222]207}
208
[101]209type user struct {
210 User
211 srv *Server
212
[165]213 events chan event
[103]214
[101]215 networks []*network
216 downstreamConns []*downstreamConn
[177]217
218 // LIST commands in progress
[179]219 pendingLISTs []pendingLIST
[101]220}
221
[177]222type pendingLIST struct {
223 downstreamID uint64
224 // list of per-upstream LIST commands not yet sent or completed
225 pendingCommands map[int64]*irc.Message
226}
227
[101]228func newUser(srv *Server, record *User) *user {
229 return &user{
[165]230 User: *record,
231 srv: srv,
232 events: make(chan event, 64),
[101]233 }
234}
235
236func (u *user) forEachNetwork(f func(*network)) {
237 for _, network := range u.networks {
238 f(network)
239 }
240}
241
242func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
243 for _, network := range u.networks {
[279]244 if network.conn == nil {
[101]245 continue
246 }
[279]247 f(network.conn)
[101]248 }
249}
250
251func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
252 for _, dc := range u.downstreamConns {
253 f(dc)
254 }
255}
256
257func (u *user) getNetwork(name string) *network {
258 for _, network := range u.networks {
259 if network.Addr == name {
260 return network
261 }
[201]262 if network.Name != "" && network.Name == name {
263 return network
264 }
[101]265 }
266 return nil
267}
268
269func (u *user) run() {
270 networks, err := u.srv.db.ListNetworks(u.Username)
271 if err != nil {
272 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
273 return
274 }
275
276 for _, record := range networks {
[283]277 record := record
[267]278 channels, err := u.srv.db.ListChannels(record.ID)
279 if err != nil {
280 u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
281 }
282
283 network := newNetwork(u, &record, channels)
[101]284 u.networks = append(u.networks, network)
285
286 go network.run()
287 }
[103]288
[165]289 for e := range u.events {
290 switch e := e.(type) {
[196]291 case eventUpstreamConnected:
[198]292 uc := e.uc
[199]293
294 uc.network.conn = uc
295
[198]296 uc.updateAway()
[218]297
298 uc.forEachDownstream(func(dc *downstreamConn) {
[276]299 dc.updateSupportedCaps()
[223]300 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
[296]301
302 dc.updateNick()
[218]303 })
304 uc.network.lastError = nil
[179]305 case eventUpstreamDisconnected:
306 uc := e.uc
[199]307
308 uc.network.conn = nil
309
[215]310 for _, ml := range uc.messageLoggers {
311 if err := ml.Close(); err != nil {
312 uc.logger.Printf("failed to close message logger: %v", err)
313 }
[179]314 }
[199]315
[181]316 uc.endPendingLISTs(true)
[218]317
[276]318 uc.forEachDownstream(func(dc *downstreamConn) {
319 dc.updateSupportedCaps()
320 })
321
[218]322 if uc.network.lastError == nil {
323 uc.forEachDownstream(func(dc *downstreamConn) {
[223]324 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
[218]325 })
326 }
327 case eventUpstreamConnectionError:
328 net := e.net
329
330 if net.lastError == nil || net.lastError.Error() != e.err.Error() {
331 net.forEachDownstream(func(dc *downstreamConn) {
[223]332 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
[218]333 })
334 }
335 net.lastError = e.err
336 case eventUpstreamError:
337 uc := e.uc
338
339 uc.forEachDownstream(func(dc *downstreamConn) {
[223]340 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
[218]341 })
342 uc.network.lastError = e.err
[165]343 case eventUpstreamMessage:
344 msg, uc := e.msg, e.uc
[175]345 if uc.isClosed() {
[133]346 uc.logger.Printf("ignoring message on closed connection: %v", msg)
347 break
348 }
[103]349 if err := uc.handleMessage(msg); err != nil {
350 uc.logger.Printf("failed to handle message %q: %v", msg, err)
351 }
[166]352 case eventDownstreamConnected:
353 dc := e.dc
[168]354
355 if err := dc.welcome(); err != nil {
356 dc.logger.Printf("failed to handle new registered connection: %v", err)
357 break
358 }
359
[166]360 u.downstreamConns = append(u.downstreamConns, dc)
[198]361
362 u.forEachUpstream(func(uc *upstreamConn) {
363 uc.updateAway()
364 })
[276]365
366 dc.updateSupportedCaps()
[167]367 case eventDownstreamDisconnected:
368 dc := e.dc
[204]369
[167]370 for i := range u.downstreamConns {
371 if u.downstreamConns[i] == dc {
372 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
373 break
374 }
375 }
[198]376
[253]377 // Save history if we're the last client with this name
378 skipHistory := make(map[*network]bool)
379 u.forEachDownstream(func(conn *downstreamConn) {
380 if dc.clientName == conn.clientName {
381 skipHistory[conn.network] = true
382 }
383 })
384
385 dc.forEachNetwork(func(net *network) {
386 if skipHistory[net] || skipHistory[nil] {
387 return
388 }
389
390 net.offlineClients[dc.clientName] = struct{}{}
[284]391 for target, history := range net.history {
392 if ch, ok := net.channels[target]; ok && ch.Detached {
393 continue
394 }
[253]395 history.offlineClients[dc.clientName] = history.ring.Cur()
396 }
397 })
398
[198]399 u.forEachUpstream(func(uc *upstreamConn) {
400 uc.updateAway()
401 })
[165]402 case eventDownstreamMessage:
403 msg, dc := e.msg, e.dc
[133]404 if dc.isClosed() {
405 dc.logger.Printf("ignoring message on closed connection: %v", msg)
406 break
407 }
[103]408 err := dc.handleMessage(msg)
409 if ircErr, ok := err.(ircError); ok {
410 ircErr.Message.Prefix = dc.srv.prefix()
411 dc.SendMessage(ircErr.Message)
412 } else if err != nil {
413 dc.logger.Printf("failed to handle message %q: %v", msg, err)
414 dc.Close()
415 }
[165]416 default:
417 u.srv.Logger.Printf("received unknown event type: %T", e)
[103]418 }
419 }
[101]420}
421
[120]422func (u *user) createNetwork(net *Network) (*network, error) {
[144]423 if net.ID != 0 {
424 panic("tried creating an already-existing network")
425 }
426
[267]427 network := newNetwork(u, net, nil)
[101]428 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
429 if err != nil {
430 return nil, err
431 }
[144]432
[101]433 u.networks = append(u.networks, network)
[144]434
[101]435 go network.run()
436 return network, nil
437}
[202]438
439func (u *user) deleteNetwork(id int64) error {
440 for i, net := range u.networks {
441 if net.ID != id {
442 continue
443 }
444
445 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
446 return err
447 }
448
449 u.forEachDownstream(func(dc *downstreamConn) {
450 if dc.network != nil && dc.network == net {
451 dc.Close()
452 }
453 })
454
455 net.Stop()
456 u.networks = append(u.networks[:i], u.networks[i+1:]...)
457 return nil
458 }
459
460 panic("tried deleting a non-existing network")
461}
[252]462
463func (u *user) updatePassword(hashed string) error {
464 u.User.Password = hashed
465 return u.srv.db.UpdatePassword(&u.User)
466}
Note: See TracBrowser for help on using the repository browser.