source: code/trunk/user.go@ 283

Last change on this file since 283 was 283, checked in by delthas, 5 years ago

Fix joining only one saved channel per network

This fixes a serious bug introduced in 276ce12e: newNetwork stored the
address of the range loop variable in its channel map, so every entry
pointed to the same channel, and soju joined only a single channel when
connecting to an upstream network.

This also adds the same per-iteration copy of the loop variable in
user.run(). That function currently works correctly; the copy is a
safeguard in case the function changes in the future.

File size: 9.3 KB
Line 
1package soju
2
3import (
4 "fmt"
5 "time"
6
7 "gopkg.in/irc.v3"
8)
9
10type event interface{}
11
12type eventUpstreamMessage struct {
13 msg *irc.Message
14 uc *upstreamConn
15}
16
17type eventUpstreamConnectionError struct {
18 net *network
19 err error
20}
21
22type eventUpstreamConnected struct {
23 uc *upstreamConn
24}
25
26type eventUpstreamDisconnected struct {
27 uc *upstreamConn
28}
29
30type eventUpstreamError struct {
31 uc *upstreamConn
32 err error
33}
34
35type eventDownstreamMessage struct {
36 msg *irc.Message
37 dc *downstreamConn
38}
39
40type eventDownstreamConnected struct {
41 dc *downstreamConn
42}
43
44type eventDownstreamDisconnected struct {
45 dc *downstreamConn
46}
47
48type networkHistory struct {
49 offlineClients map[string]uint64 // indexed by client name
50 ring *Ring // can be nil if there are no offline clients
51}
52
53type network struct {
54 Network
55 user *user
56 stopped chan struct{}
57
58 conn *upstreamConn
59 channels map[string]*Channel
60 history map[string]*networkHistory // indexed by entity
61 offlineClients map[string]struct{} // indexed by client name
62 lastError error
63}
64
65func newNetwork(user *user, record *Network, channels []Channel) *network {
66 m := make(map[string]*Channel, len(channels))
67 for _, ch := range channels {
68 ch := ch
69 m[ch.Name] = &ch
70 }
71
72 return &network{
73 Network: *record,
74 user: user,
75 stopped: make(chan struct{}),
76 channels: m,
77 history: make(map[string]*networkHistory),
78 offlineClients: make(map[string]struct{}),
79 }
80}
81
82func (net *network) forEachDownstream(f func(*downstreamConn)) {
83 net.user.forEachDownstream(func(dc *downstreamConn) {
84 if dc.network != nil && dc.network != net {
85 return
86 }
87 f(dc)
88 })
89}
90
91func (net *network) run() {
92 var lastTry time.Time
93 for {
94 select {
95 case <-net.stopped:
96 return
97 default:
98 // This space is intentionally left blank
99 }
100
101 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
102 delay := retryConnectMinDelay - dur
103 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
104 time.Sleep(delay)
105 }
106 lastTry = time.Now()
107
108 uc, err := connectToUpstream(net)
109 if err != nil {
110 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
111 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
112 continue
113 }
114
115 uc.register()
116 if err := uc.runUntilRegistered(); err != nil {
117 uc.logger.Printf("failed to register: %v", err)
118 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
119 uc.Close()
120 continue
121 }
122
123 net.user.events <- eventUpstreamConnected{uc}
124 if err := uc.readMessages(net.user.events); err != nil {
125 uc.logger.Printf("failed to handle messages: %v", err)
126 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
127 }
128 uc.Close()
129 net.user.events <- eventUpstreamDisconnected{uc}
130 }
131}
132
133func (net *network) Stop() {
134 select {
135 case <-net.stopped:
136 return
137 default:
138 close(net.stopped)
139 }
140
141 if net.conn != nil {
142 net.conn.Close()
143 }
144}
145
146func (net *network) createUpdateChannel(ch *Channel) error {
147 if current, ok := net.channels[ch.Name]; ok {
148 ch.ID = current.ID // update channel if it already exists
149 }
150 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
151 return err
152 }
153 net.channels[ch.Name] = ch
154 return nil
155}
156
157func (net *network) deleteChannel(name string) error {
158 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
159 return err
160 }
161 delete(net.channels, name)
162 return nil
163}
164
165type user struct {
166 User
167 srv *Server
168
169 events chan event
170
171 networks []*network
172 downstreamConns []*downstreamConn
173
174 // LIST commands in progress
175 pendingLISTs []pendingLIST
176}
177
178type pendingLIST struct {
179 downstreamID uint64
180 // list of per-upstream LIST commands not yet sent or completed
181 pendingCommands map[int64]*irc.Message
182}
183
184func newUser(srv *Server, record *User) *user {
185 return &user{
186 User: *record,
187 srv: srv,
188 events: make(chan event, 64),
189 }
190}
191
192func (u *user) forEachNetwork(f func(*network)) {
193 for _, network := range u.networks {
194 f(network)
195 }
196}
197
198func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
199 for _, network := range u.networks {
200 if network.conn == nil {
201 continue
202 }
203 f(network.conn)
204 }
205}
206
207func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
208 for _, dc := range u.downstreamConns {
209 f(dc)
210 }
211}
212
213func (u *user) getNetwork(name string) *network {
214 for _, network := range u.networks {
215 if network.Addr == name {
216 return network
217 }
218 if network.Name != "" && network.Name == name {
219 return network
220 }
221 }
222 return nil
223}
224
// run loads the user's saved networks and starts a connection goroutine
// (network.run) for each. It then processes events from u.events until
// that channel is closed.
//
// Events are handled one at a time in this loop. This is where the fields
// of u and of its networks (conn, lastError, offlineClients, ...) are
// updated in response to upstream and downstream activity.
//
// Note: a bare "break" inside a case below leaves the switch (skipping the
// rest of that case), not the surrounding for loop.
225func (u *user) run() {
226 networks, err := u.srv.db.ListNetworks(u.Username)
227 if err != nil {
228 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
229 return
230 }
231
232 for _, record := range networks {
// Per-iteration copy of the loop variable. This is a safeguard in case
// &record ever outlives the iteration; newNetwork currently copies *record.
233 record := record
234 channels, err := u.srv.db.ListChannels(record.ID)
// A channel-listing failure is only logged. The network is still started
// with whatever ListChannels returned.
235 if err != nil {
236 u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
237 }
238
239 network := newNetwork(u, &record, channels)
240 u.networks = append(u.networks, network)
241
242 go network.run()
243 }
244
245 for e := range u.events {
246 switch e := e.(type) {
// Upstream finished registering: publish it as the network's connection,
// refresh away state and client capabilities, notify clients, and clear
// the last error.
247 case eventUpstreamConnected:
248 uc := e.uc
249
250 uc.network.conn = uc
251
252 uc.updateAway()
253
254 uc.forEachDownstream(func(dc *downstreamConn) {
255 dc.updateSupportedCaps()
256 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
257 })
258 uc.network.lastError = nil
// Upstream closed: detach it from the network, close its message loggers,
// and end its pending LISTs.
259 case eventUpstreamDisconnected:
260 uc := e.uc
261
262 uc.network.conn = nil
263
264 for _, ml := range uc.messageLoggers {
265 if err := ml.Close(); err != nil {
266 uc.logger.Printf("failed to close message logger: %v", err)
267 }
268 }
269
270 uc.endPendingLISTs(true)
271
272 uc.forEachDownstream(func(dc *downstreamConn) {
273 dc.updateSupportedCaps()
274 })
275
// When eventUpstreamError already sent a "disconnected from ...: <err>"
// notice, lastError is set and the plain notice is skipped to avoid a
// duplicate.
276 if uc.network.lastError == nil {
277 uc.forEachDownstream(func(dc *downstreamConn) {
278 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
279 })
280 }
// Connect/register failure from network.run. Clients are notified only
// when the error text differs from the previous one, so repeated identical
// failures across retries are not re-announced.
281 case eventUpstreamConnectionError:
282 net := e.net
283
284 if net.lastError == nil || net.lastError.Error() != e.err.Error() {
285 net.forEachDownstream(func(dc *downstreamConn) {
286 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
287 })
288 }
289 net.lastError = e.err
290 case eventUpstreamError:
291 uc := e.uc
292
293 uc.forEachDownstream(func(dc *downstreamConn) {
294 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
295 })
296 uc.network.lastError = e.err
// Messages still queued for an already-closed upstream are dropped.
297 case eventUpstreamMessage:
298 msg, uc := e.msg, e.uc
299 if uc.isClosed() {
300 uc.logger.Printf("ignoring message on closed connection: %v", msg)
301 break
302 }
303 if err := uc.handleMessage(msg); err != nil {
304 uc.logger.Printf("failed to handle message %q: %v", msg, err)
305 }
// New client: it is tracked only if welcome succeeds. Upstream away state
// is then refreshed, since it depends on the set of connected clients.
306 case eventDownstreamConnected:
307 dc := e.dc
308
309 if err := dc.welcome(); err != nil {
310 dc.logger.Printf("failed to handle new registered connection: %v", err)
311 break
312 }
313
314 u.downstreamConns = append(u.downstreamConns, dc)
315
316 u.forEachUpstream(func(uc *upstreamConn) {
317 uc.updateAway()
318 })
319
320 dc.updateSupportedCaps()
321 case eventDownstreamDisconnected:
322 dc := e.dc
323
// Order-preserving removal of dc from the tracked clients.
324 for i := range u.downstreamConns {
325 if u.downstreamConns[i] == dc {
326 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
327 break
328 }
329 }
330
331 // Save history if we're the last client with this name
// skipHistory[n] is true when a still-connected client with the same name
// is bound to network n. The nil key means such a client is bound to no
// single network (network.forEachDownstream gives it every network), so
// history is skipped for all networks.
332 skipHistory := make(map[*network]bool)
333 u.forEachDownstream(func(conn *downstreamConn) {
334 if dc.clientName == conn.clientName {
335 skipHistory[conn.network] = true
336 }
337 })
338
339 dc.forEachNetwork(func(net *network) {
340 if skipHistory[net] || skipHistory[nil] {
341 return
342 }
343
// Mark the client offline and remember the current ring position of every
// history entry for it.
// NOTE(review): networkHistory.ring is documented as possibly nil; confirm
// Ring.Cur is safe on a nil receiver here.
344 net.offlineClients[dc.clientName] = struct{}{}
345 for _, history := range net.history {
346 history.offlineClients[dc.clientName] = history.ring.Cur()
347 }
348 })
349
350 u.forEachUpstream(func(uc *upstreamConn) {
351 uc.updateAway()
352 })
// Client message: an ircError is sent back to the client with the server
// prefix; any other error is logged and closes the client connection.
353 case eventDownstreamMessage:
354 msg, dc := e.msg, e.dc
355 if dc.isClosed() {
356 dc.logger.Printf("ignoring message on closed connection: %v", msg)
357 break
358 }
359 err := dc.handleMessage(msg)
360 if ircErr, ok := err.(ircError); ok {
361 ircErr.Message.Prefix = dc.srv.prefix()
362 dc.SendMessage(ircErr.Message)
363 } else if err != nil {
364 dc.logger.Printf("failed to handle message %q: %v", msg, err)
365 dc.Close()
366 }
367 default:
368 u.srv.Logger.Printf("received unknown event type: %T", e)
369 }
370 }
371}
372
373func (u *user) createNetwork(net *Network) (*network, error) {
374 if net.ID != 0 {
375 panic("tried creating an already-existing network")
376 }
377
378 network := newNetwork(u, net, nil)
379 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
380 if err != nil {
381 return nil, err
382 }
383
384 u.networks = append(u.networks, network)
385
386 go network.run()
387 return network, nil
388}
389
390func (u *user) deleteNetwork(id int64) error {
391 for i, net := range u.networks {
392 if net.ID != id {
393 continue
394 }
395
396 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
397 return err
398 }
399
400 u.forEachDownstream(func(dc *downstreamConn) {
401 if dc.network != nil && dc.network == net {
402 dc.Close()
403 }
404 })
405
406 net.Stop()
407 u.networks = append(u.networks[:i], u.networks[i+1:]...)
408 return nil
409 }
410
411 panic("tried deleting a non-existing network")
412}
413
414func (u *user) updatePassword(hashed string) error {
415 u.User.Password = hashed
416 return u.srv.db.UpdatePassword(&u.User)
417}
Note: See TracBrowser for help on using the repository browser.