source: code/trunk/user.go@ 283

Last change on this file since 283 was 283, checked in by delthas, 5 years ago

Fix joining only one saved channel per network

This fixes a serious bug introduced in 276ce12e: in newNetwork, every map
entry pointed to the same Channel (the address of the shared range loop
variable), so soju joined only a single channel when connecting to an
upstream network.

This also copies the range loop variable in user.run() the same way. That
function currently works correctly (newNetwork copies the record it is
given), so the change is a defensive measure in case the function changes
in the future.

File size: 9.3 KB
RevLine 
[101]1package soju
2
3import (
[218]4 "fmt"
[101]5 "time"
[103]6
7 "gopkg.in/irc.v3"
[101]8)
9
// event is any value sent on user.events. user.run dispatches on the
// concrete type with a type switch and logs unknown types.
type event interface{}

// eventUpstreamMessage carries a message read from an upstream connection.
// user.run hands it to uc.handleMessage unless uc is already closed.
type eventUpstreamMessage struct {
	msg *irc.Message
	uc  *upstreamConn
}

// eventUpstreamConnectionError is sent by network.run when connecting to or
// registering with the upstream server fails.
type eventUpstreamConnectionError struct {
	net *network
	err error
}

// eventUpstreamConnected is sent by network.run once uc has successfully
// registered with the upstream server.
type eventUpstreamConnected struct {
	uc *upstreamConn
}

// eventUpstreamDisconnected is sent by network.run after an established
// upstream connection has been closed.
type eventUpstreamDisconnected struct {
	uc *upstreamConn
}

// eventUpstreamError is sent by network.run when reading messages from an
// established upstream connection fails. It precedes the matching
// eventUpstreamDisconnected.
type eventUpstreamError struct {
	uc  *upstreamConn
	err error
}

// eventDownstreamMessage carries a message received from a downstream
// client. user.run hands it to dc.handleMessage unless dc is already closed.
type eventDownstreamMessage struct {
	msg *irc.Message
	dc  *downstreamConn
}

// eventDownstreamConnected signals a new downstream connection. user.run
// welcomes it and, if that succeeds, adds it to user.downstreamConns.
type eventDownstreamConnected struct {
	dc *downstreamConn
}

// eventDownstreamDisconnected signals that a downstream connection went
// away. user.run removes it and records offline-client history state.
type eventDownstreamDisconnected struct {
	dc *downstreamConn
}
47
// networkHistory holds the message history ring for one entity of a
// network (see network.history) and the ring position of each client that
// went offline.
type networkHistory struct {
	// Indexed by client name. Values are ring cursors: user.run stores
	// ring.Cur() here when a client with that name disconnects.
	offlineClients map[string]uint64
	ring           *Ring // can be nil if there are no offline clients
}
52
// network is the runtime state of one of a user's upstream networks. It
// embeds a copy of the stored Network record.
type network struct {
	Network
	user    *user
	stopped chan struct{} // closed by Stop; network.run returns once it observes it

	// Current upstream connection, or nil when not connected. user.run sets
	// it on eventUpstreamConnected and clears it on eventUpstreamDisconnected.
	conn *upstreamConn
	// Channels stored for this network, indexed by channel name.
	channels map[string]*Channel
	history  map[string]*networkHistory // indexed by entity
	// Clients that went offline while connected to this network, indexed
	// by client name.
	offlineClients map[string]struct{}
	// Last connection/upstream error reported to downstreams. user.run uses
	// it to avoid repeating identical error notices; reset on connect.
	lastError error
}
64
[267]65func newNetwork(user *user, record *Network, channels []Channel) *network {
66 m := make(map[string]*Channel, len(channels))
67 for _, ch := range channels {
[283]68 ch := ch
[267]69 m[ch.Name] = &ch
70 }
71
[101]72 return &network{
[253]73 Network: *record,
74 user: user,
75 stopped: make(chan struct{}),
[267]76 channels: m,
[253]77 history: make(map[string]*networkHistory),
78 offlineClients: make(map[string]struct{}),
[101]79 }
80}
81
[218]82func (net *network) forEachDownstream(f func(*downstreamConn)) {
83 net.user.forEachDownstream(func(dc *downstreamConn) {
84 if dc.network != nil && dc.network != net {
85 return
86 }
87 f(dc)
88 })
89}
90
[101]91func (net *network) run() {
92 var lastTry time.Time
93 for {
[202]94 select {
95 case <-net.stopped:
96 return
97 default:
98 // This space is intentionally left blank
99 }
100
[101]101 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
102 delay := retryConnectMinDelay - dur
103 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
104 time.Sleep(delay)
105 }
106 lastTry = time.Now()
107
108 uc, err := connectToUpstream(net)
109 if err != nil {
110 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
[218]111 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
[101]112 continue
113 }
114
115 uc.register()
[197]116 if err := uc.runUntilRegistered(); err != nil {
117 uc.logger.Printf("failed to register: %v", err)
[218]118 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
[197]119 uc.Close()
120 continue
121 }
[101]122
[196]123 net.user.events <- eventUpstreamConnected{uc}
[165]124 if err := uc.readMessages(net.user.events); err != nil {
[101]125 uc.logger.Printf("failed to handle messages: %v", err)
[218]126 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
[101]127 }
128 uc.Close()
[179]129 net.user.events <- eventUpstreamDisconnected{uc}
[101]130 }
131}
132
[202]133func (net *network) Stop() {
134 select {
135 case <-net.stopped:
136 return
137 default:
138 close(net.stopped)
139 }
140
[279]141 if net.conn != nil {
142 net.conn.Close()
[202]143 }
144}
145
[222]146func (net *network) createUpdateChannel(ch *Channel) error {
[267]147 if current, ok := net.channels[ch.Name]; ok {
148 ch.ID = current.ID // update channel if it already exists
149 }
150 if err := net.user.srv.db.StoreChannel(net.ID, ch); err != nil {
[222]151 return err
152 }
[267]153 net.channels[ch.Name] = ch
154 return nil
[222]155}
156
157func (net *network) deleteChannel(name string) error {
[267]158 if err := net.user.srv.db.DeleteChannel(net.ID, name); err != nil {
159 return err
160 }
161 delete(net.channels, name)
162 return nil
[222]163}
164
// user is the in-memory state of one soju user: a copy of the stored User
// record, its upstream networks, its connected downstream clients and its
// in-progress LIST commands.
type user struct {
	User
	srv *Server

	// Consumed by user.run, which handles each event in turn. newUser
	// creates it with a buffer of 64.
	events chan event

	networks []*network
	// Downstreams that were welcomed successfully; user.run appends on
	// eventDownstreamConnected and removes on eventDownstreamDisconnected.
	downstreamConns []*downstreamConn

	// LIST commands in progress
	pendingLISTs []pendingLIST
}
177
// pendingLIST tracks one downstream LIST command that is split into
// per-upstream LIST commands.
type pendingLIST struct {
	// NOTE(review): presumably the ID of the downstream connection that
	// issued the LIST; confirm against the code that creates entries.
	downstreamID uint64
	// list of per-upstream LIST commands not yet sent or completed
	// NOTE(review): the int64 key presumably identifies the upstream
	// network; confirm.
	pendingCommands map[int64]*irc.Message
}
183
[101]184func newUser(srv *Server, record *User) *user {
185 return &user{
[165]186 User: *record,
187 srv: srv,
188 events: make(chan event, 64),
[101]189 }
190}
191
192func (u *user) forEachNetwork(f func(*network)) {
193 for _, network := range u.networks {
194 f(network)
195 }
196}
197
198func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
199 for _, network := range u.networks {
[279]200 if network.conn == nil {
[101]201 continue
202 }
[279]203 f(network.conn)
[101]204 }
205}
206
207func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
208 for _, dc := range u.downstreamConns {
209 f(dc)
210 }
211}
212
213func (u *user) getNetwork(name string) *network {
214 for _, network := range u.networks {
215 if network.Addr == name {
216 return network
217 }
[201]218 if network.Name != "" && network.Name == name {
219 return network
220 }
[101]221 }
222 return nil
223}
224
// run loads the user's networks from the database and starts a connection
// goroutine (network.run) for each of them. It then handles events from
// u.events one at a time until the channel is closed. If the networks
// cannot be listed, it logs the error and returns immediately.
func (u *user) run() {
	networks, err := u.srv.db.ListNetworks(u.Username)
	if err != nil {
		u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
		return
	}

	for _, record := range networks {
		// newNetwork currently copies *record, but keep a per-iteration
		// copy so that taking &record stays safe if that ever changes.
		record := record
		channels, err := u.srv.db.ListChannels(record.ID)
		if err != nil {
			// Only logged: the network is still created, with no saved
			// channels (channels is nil here).
			u.srv.Logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
		}

		network := newNetwork(u, &record, channels)
		u.networks = append(u.networks, network)

		go network.run()
	}

	for e := range u.events {
		switch e := e.(type) {
		case eventUpstreamConnected:
			uc := e.uc

			uc.network.conn = uc

			uc.updateAway()

			uc.forEachDownstream(func(dc *downstreamConn) {
				dc.updateSupportedCaps()
				sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
			})
			// A successful connection clears the error state, so the next
			// failure is reported even if it repeats an earlier error.
			uc.network.lastError = nil
		case eventUpstreamDisconnected:
			uc := e.uc

			uc.network.conn = nil

			for _, ml := range uc.messageLoggers {
				if err := ml.Close(); err != nil {
					uc.logger.Printf("failed to close message logger: %v", err)
				}
			}

			uc.endPendingLISTs(true)

			uc.forEachDownstream(func(dc *downstreamConn) {
				dc.updateSupportedCaps()
			})

			// network.run sends eventUpstreamError (which notifies
			// downstreams and sets lastError) right before this event when
			// the connection failed. Only send a plain "disconnected"
			// notice when no such error was reported.
			if uc.network.lastError == nil {
				uc.forEachDownstream(func(dc *downstreamConn) {
					sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
				})
			}
		case eventUpstreamConnectionError:
			net := e.net

			// Only notify when the error text changed, so repeated
			// identical failures across retries are not re-announced.
			if net.lastError == nil || net.lastError.Error() != e.err.Error() {
				net.forEachDownstream(func(dc *downstreamConn) {
					sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
				})
			}
			net.lastError = e.err
		case eventUpstreamError:
			uc := e.uc

			uc.forEachDownstream(func(dc *downstreamConn) {
				sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
			})
			uc.network.lastError = e.err
		case eventUpstreamMessage:
			msg, uc := e.msg, e.uc
			// Messages still queued for a connection that has since been
			// closed are dropped.
			if uc.isClosed() {
				uc.logger.Printf("ignoring message on closed connection: %v", msg)
				break
			}
			if err := uc.handleMessage(msg); err != nil {
				uc.logger.Printf("failed to handle message %q: %v", msg, err)
			}
		case eventDownstreamConnected:
			dc := e.dc

			// The connection is only tracked if the welcome succeeds.
			if err := dc.welcome(); err != nil {
				dc.logger.Printf("failed to handle new registered connection: %v", err)
				break
			}

			u.downstreamConns = append(u.downstreamConns, dc)

			// The set of downstreams changed: refresh away state upstream.
			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})

			dc.updateSupportedCaps()
		case eventDownstreamDisconnected:
			dc := e.dc

			for i := range u.downstreamConns {
				if u.downstreamConns[i] == dc {
					u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
					break
				}
			}

			// Save history if we're the last client with this name
			// dc was removed above, so this only sees the remaining
			// connections. A remaining connection with the same client name
			// marks its network as still covered. A nil key means a
			// connection not bound to any network, which covers all of them.
			skipHistory := make(map[*network]bool)
			u.forEachDownstream(func(conn *downstreamConn) {
				if dc.clientName == conn.clientName {
					skipHistory[conn.network] = true
				}
			})

			// NOTE(review): dc.forEachNetwork presumably visits dc.network,
			// or every network when dc is not bound to one; confirm.
			dc.forEachNetwork(func(net *network) {
				if skipHistory[net] || skipHistory[nil] {
					return
				}

				// Mark the client offline and remember, per entity, the
				// current history ring position.
				net.offlineClients[dc.clientName] = struct{}{}
				for _, history := range net.history {
					history.offlineClients[dc.clientName] = history.ring.Cur()
				}
			})

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})
		case eventDownstreamMessage:
			msg, dc := e.msg, e.dc
			if dc.isClosed() {
				dc.logger.Printf("ignoring message on closed connection: %v", msg)
				break
			}
			err := dc.handleMessage(msg)
			// An ircError is sent back to the client with the server prefix.
			// Any other error is logged and the downstream is closed.
			if ircErr, ok := err.(ircError); ok {
				ircErr.Message.Prefix = dc.srv.prefix()
				dc.SendMessage(ircErr.Message)
			} else if err != nil {
				dc.logger.Printf("failed to handle message %q: %v", msg, err)
				dc.Close()
			}
		default:
			u.srv.Logger.Printf("received unknown event type: %T", e)
		}
	}
}
372
[120]373func (u *user) createNetwork(net *Network) (*network, error) {
[144]374 if net.ID != 0 {
375 panic("tried creating an already-existing network")
376 }
377
[267]378 network := newNetwork(u, net, nil)
[101]379 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
380 if err != nil {
381 return nil, err
382 }
[144]383
[101]384 u.networks = append(u.networks, network)
[144]385
[101]386 go network.run()
387 return network, nil
388}
[202]389
390func (u *user) deleteNetwork(id int64) error {
391 for i, net := range u.networks {
392 if net.ID != id {
393 continue
394 }
395
396 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
397 return err
398 }
399
400 u.forEachDownstream(func(dc *downstreamConn) {
401 if dc.network != nil && dc.network == net {
402 dc.Close()
403 }
404 })
405
406 net.Stop()
407 u.networks = append(u.networks[:i], u.networks[i+1:]...)
408 return nil
409 }
410
411 panic("tried deleting a non-existing network")
412}
[252]413
414func (u *user) updatePassword(hashed string) error {
415 u.User.Password = hashed
416 return u.srv.db.UpdatePassword(&u.User)
417}
Note: See TracBrowser for help on using the repository browser.