source: code/trunk/user.go@ 223

Last change on this file since 223 was 223, checked in by contact, 5 years ago

Use Network.GetName in network status NOTICE messages

File size: 8.2 KB
Line 
1package soju
2
3import (
4 "fmt"
5 "sync"
6 "time"
7
8 "gopkg.in/irc.v3"
9)
10
11type event interface{}
12
13type eventUpstreamMessage struct {
14 msg *irc.Message
15 uc *upstreamConn
16}
17
18type eventUpstreamConnectionError struct {
19 net *network
20 err error
21}
22
23type eventUpstreamConnected struct {
24 uc *upstreamConn
25}
26
27type eventUpstreamDisconnected struct {
28 uc *upstreamConn
29}
30
31type eventUpstreamError struct {
32 uc *upstreamConn
33 err error
34}
35
36type eventDownstreamMessage struct {
37 msg *irc.Message
38 dc *downstreamConn
39}
40
41type eventDownstreamConnected struct {
42 dc *downstreamConn
43}
44
45type eventDownstreamDisconnected struct {
46 dc *downstreamConn
47}
48
49type network struct {
50 Network
51 user *user
52 ring *Ring
53 stopped chan struct{}
54
55 history map[string]uint64
56 lastError error
57
58 lock sync.Mutex
59 conn *upstreamConn
60}
61
62func newNetwork(user *user, record *Network) *network {
63 return &network{
64 Network: *record,
65 user: user,
66 ring: NewRing(user.srv.RingCap),
67 stopped: make(chan struct{}),
68 history: make(map[string]uint64),
69 }
70}
71
72func (net *network) forEachDownstream(f func(*downstreamConn)) {
73 net.user.forEachDownstream(func(dc *downstreamConn) {
74 if dc.network != nil && dc.network != net {
75 return
76 }
77 f(dc)
78 })
79}
80
81func (net *network) run() {
82 var lastTry time.Time
83 for {
84 select {
85 case <-net.stopped:
86 return
87 default:
88 // This space is intentionally left blank
89 }
90
91 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
92 delay := retryConnectMinDelay - dur
93 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
94 time.Sleep(delay)
95 }
96 lastTry = time.Now()
97
98 uc, err := connectToUpstream(net)
99 if err != nil {
100 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
101 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
102 continue
103 }
104
105 uc.register()
106 if err := uc.runUntilRegistered(); err != nil {
107 uc.logger.Printf("failed to register: %v", err)
108 net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", err)}
109 uc.Close()
110 continue
111 }
112
113 net.user.events <- eventUpstreamConnected{uc}
114 if err := uc.readMessages(net.user.events); err != nil {
115 uc.logger.Printf("failed to handle messages: %v", err)
116 net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
117 }
118 uc.Close()
119 net.user.events <- eventUpstreamDisconnected{uc}
120 }
121}
122
123func (net *network) upstream() *upstreamConn {
124 net.lock.Lock()
125 defer net.lock.Unlock()
126 return net.conn
127}
128
129func (net *network) Stop() {
130 select {
131 case <-net.stopped:
132 return
133 default:
134 close(net.stopped)
135 }
136
137 if uc := net.upstream(); uc != nil {
138 uc.Close()
139 }
140}
141
142func (net *network) createUpdateChannel(ch *Channel) error {
143 if dbCh, err := net.user.srv.db.GetChannel(net.ID, ch.Name); err == nil {
144 ch.ID = dbCh.ID
145 } else if err != ErrNoSuchChannel {
146 return err
147 }
148 return net.user.srv.db.StoreChannel(net.ID, ch)
149}
150
151func (net *network) deleteChannel(name string) error {
152 return net.user.srv.db.DeleteChannel(net.ID, name)
153}
154
155type user struct {
156 User
157 srv *Server
158
159 events chan event
160
161 networks []*network
162 downstreamConns []*downstreamConn
163
164 // LIST commands in progress
165 pendingLISTs []pendingLIST
166}
167
168type pendingLIST struct {
169 downstreamID uint64
170 // list of per-upstream LIST commands not yet sent or completed
171 pendingCommands map[int64]*irc.Message
172}
173
174func newUser(srv *Server, record *User) *user {
175 return &user{
176 User: *record,
177 srv: srv,
178 events: make(chan event, 64),
179 }
180}
181
182func (u *user) forEachNetwork(f func(*network)) {
183 for _, network := range u.networks {
184 f(network)
185 }
186}
187
188func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
189 for _, network := range u.networks {
190 uc := network.upstream()
191 if uc == nil {
192 continue
193 }
194 f(uc)
195 }
196}
197
198func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
199 for _, dc := range u.downstreamConns {
200 f(dc)
201 }
202}
203
204func (u *user) getNetwork(name string) *network {
205 for _, network := range u.networks {
206 if network.Addr == name {
207 return network
208 }
209 if network.Name != "" && network.Name == name {
210 return network
211 }
212 }
213 return nil
214}
215
216func (u *user) run() {
217 networks, err := u.srv.db.ListNetworks(u.Username)
218 if err != nil {
219 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
220 return
221 }
222
223 for _, record := range networks {
224 network := newNetwork(u, &record)
225 u.networks = append(u.networks, network)
226
227 go network.run()
228 }
229
230 for e := range u.events {
231 switch e := e.(type) {
232 case eventUpstreamConnected:
233 uc := e.uc
234
235 uc.network.lock.Lock()
236 uc.network.conn = uc
237 uc.network.lock.Unlock()
238
239 uc.updateAway()
240
241 uc.forEachDownstream(func(dc *downstreamConn) {
242 sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
243 })
244 uc.network.lastError = nil
245 case eventUpstreamDisconnected:
246 uc := e.uc
247
248 uc.network.lock.Lock()
249 uc.network.conn = nil
250 uc.network.lock.Unlock()
251
252 for _, ml := range uc.messageLoggers {
253 if err := ml.Close(); err != nil {
254 uc.logger.Printf("failed to close message logger: %v", err)
255 }
256 }
257
258 uc.endPendingLISTs(true)
259
260 if uc.network.lastError == nil {
261 uc.forEachDownstream(func(dc *downstreamConn) {
262 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
263 })
264 }
265 case eventUpstreamConnectionError:
266 net := e.net
267
268 if net.lastError == nil || net.lastError.Error() != e.err.Error() {
269 net.forEachDownstream(func(dc *downstreamConn) {
270 sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
271 })
272 }
273 net.lastError = e.err
274 case eventUpstreamError:
275 uc := e.uc
276
277 uc.forEachDownstream(func(dc *downstreamConn) {
278 sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
279 })
280 uc.network.lastError = e.err
281 case eventUpstreamMessage:
282 msg, uc := e.msg, e.uc
283 if uc.isClosed() {
284 uc.logger.Printf("ignoring message on closed connection: %v", msg)
285 break
286 }
287 if err := uc.handleMessage(msg); err != nil {
288 uc.logger.Printf("failed to handle message %q: %v", msg, err)
289 }
290 case eventDownstreamConnected:
291 dc := e.dc
292
293 if err := dc.welcome(); err != nil {
294 dc.logger.Printf("failed to handle new registered connection: %v", err)
295 break
296 }
297
298 u.downstreamConns = append(u.downstreamConns, dc)
299
300 u.forEachUpstream(func(uc *upstreamConn) {
301 uc.updateAway()
302 })
303 case eventDownstreamDisconnected:
304 dc := e.dc
305
306 for net, rc := range dc.ringConsumers {
307 seq := rc.Close()
308 net.history[dc.clientName] = seq
309 }
310
311 for i := range u.downstreamConns {
312 if u.downstreamConns[i] == dc {
313 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
314 break
315 }
316 }
317
318 u.forEachUpstream(func(uc *upstreamConn) {
319 uc.updateAway()
320 })
321 case eventDownstreamMessage:
322 msg, dc := e.msg, e.dc
323 if dc.isClosed() {
324 dc.logger.Printf("ignoring message on closed connection: %v", msg)
325 break
326 }
327 err := dc.handleMessage(msg)
328 if ircErr, ok := err.(ircError); ok {
329 ircErr.Message.Prefix = dc.srv.prefix()
330 dc.SendMessage(ircErr.Message)
331 } else if err != nil {
332 dc.logger.Printf("failed to handle message %q: %v", msg, err)
333 dc.Close()
334 }
335 default:
336 u.srv.Logger.Printf("received unknown event type: %T", e)
337 }
338 }
339}
340
341func (u *user) createNetwork(net *Network) (*network, error) {
342 if net.ID != 0 {
343 panic("tried creating an already-existing network")
344 }
345
346 network := newNetwork(u, net)
347 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
348 if err != nil {
349 return nil, err
350 }
351
352 u.forEachDownstream(func(dc *downstreamConn) {
353 if dc.network == nil {
354 dc.runNetwork(network, false)
355 }
356 })
357
358 u.networks = append(u.networks, network)
359
360 go network.run()
361 return network, nil
362}
363
364func (u *user) deleteNetwork(id int64) error {
365 for i, net := range u.networks {
366 if net.ID != id {
367 continue
368 }
369
370 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
371 return err
372 }
373
374 u.forEachDownstream(func(dc *downstreamConn) {
375 if dc.network != nil && dc.network == net {
376 dc.Close()
377 }
378 })
379
380 net.Stop()
381 net.ring.Close()
382 u.networks = append(u.networks[:i], u.networks[i+1:]...)
383 return nil
384 }
385
386 panic("tried deleting a non-existing network")
387}
Note: See TracBrowser for help on using the repository browser.