source: code/trunk/user.go@ 202

Last change to this file as of revision 202 was made in revision 202, checked in by contact 5 years ago.

Add "network delete" service command

And add all the infrastructure required to stop and delete networks.

References: https://todo.sr.ht/~emersion/soju/17

File size: 6.1 KB
Line 
1package soju
2
3import (
4 "sync"
5 "time"
6
7 "gopkg.in/irc.v3"
8)
9
10type event interface{}
11
12type eventUpstreamMessage struct {
13 msg *irc.Message
14 uc *upstreamConn
15}
16
17type eventUpstreamConnected struct {
18 uc *upstreamConn
19}
20
21type eventUpstreamDisconnected struct {
22 uc *upstreamConn
23}
24
25type eventDownstreamMessage struct {
26 msg *irc.Message
27 dc *downstreamConn
28}
29
30type eventDownstreamConnected struct {
31 dc *downstreamConn
32}
33
34type eventDownstreamDisconnected struct {
35 dc *downstreamConn
36}
37
38type network struct {
39 Network
40 user *user
41 ring *Ring
42 stopped chan struct{}
43
44 lock sync.Mutex
45 conn *upstreamConn
46 history map[string]uint64
47}
48
49func newNetwork(user *user, record *Network) *network {
50 return &network{
51 Network: *record,
52 user: user,
53 ring: NewRing(user.srv.RingCap),
54 stopped: make(chan struct{}),
55 history: make(map[string]uint64),
56 }
57}
58
59func (net *network) run() {
60 var lastTry time.Time
61 for {
62 select {
63 case <-net.stopped:
64 return
65 default:
66 // This space is intentionally left blank
67 }
68
69 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
70 delay := retryConnectMinDelay - dur
71 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
72 time.Sleep(delay)
73 }
74 lastTry = time.Now()
75
76 uc, err := connectToUpstream(net)
77 if err != nil {
78 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
79 continue
80 }
81
82 uc.register()
83 if err := uc.runUntilRegistered(); err != nil {
84 uc.logger.Printf("failed to register: %v", err)
85 uc.Close()
86 continue
87 }
88
89 net.user.events <- eventUpstreamConnected{uc}
90 if err := uc.readMessages(net.user.events); err != nil {
91 uc.logger.Printf("failed to handle messages: %v", err)
92 }
93 uc.Close()
94 net.user.events <- eventUpstreamDisconnected{uc}
95 }
96}
97
98func (net *network) upstream() *upstreamConn {
99 net.lock.Lock()
100 defer net.lock.Unlock()
101 return net.conn
102}
103
104func (net *network) Stop() {
105 select {
106 case <-net.stopped:
107 return
108 default:
109 close(net.stopped)
110 }
111
112 if uc := net.upstream(); uc != nil {
113 uc.Close()
114 }
115}
116
117type user struct {
118 User
119 srv *Server
120
121 events chan event
122
123 networks []*network
124 downstreamConns []*downstreamConn
125
126 // LIST commands in progress
127 pendingLISTs []pendingLIST
128}
129
130type pendingLIST struct {
131 downstreamID uint64
132 // list of per-upstream LIST commands not yet sent or completed
133 pendingCommands map[int64]*irc.Message
134}
135
136func newUser(srv *Server, record *User) *user {
137 return &user{
138 User: *record,
139 srv: srv,
140 events: make(chan event, 64),
141 }
142}
143
144func (u *user) forEachNetwork(f func(*network)) {
145 for _, network := range u.networks {
146 f(network)
147 }
148}
149
150func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
151 for _, network := range u.networks {
152 uc := network.upstream()
153 if uc == nil {
154 continue
155 }
156 f(uc)
157 }
158}
159
160func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
161 for _, dc := range u.downstreamConns {
162 f(dc)
163 }
164}
165
166func (u *user) getNetwork(name string) *network {
167 for _, network := range u.networks {
168 if network.Addr == name {
169 return network
170 }
171 if network.Name != "" && network.Name == name {
172 return network
173 }
174 }
175 return nil
176}
177
178func (u *user) run() {
179 networks, err := u.srv.db.ListNetworks(u.Username)
180 if err != nil {
181 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
182 return
183 }
184
185 for _, record := range networks {
186 network := newNetwork(u, &record)
187 u.networks = append(u.networks, network)
188
189 go network.run()
190 }
191
192 for e := range u.events {
193 switch e := e.(type) {
194 case eventUpstreamConnected:
195 uc := e.uc
196
197 uc.network.lock.Lock()
198 uc.network.conn = uc
199 uc.network.lock.Unlock()
200
201 uc.updateAway()
202 case eventUpstreamDisconnected:
203 uc := e.uc
204
205 uc.network.lock.Lock()
206 uc.network.conn = nil
207 uc.network.lock.Unlock()
208
209 for _, log := range uc.logs {
210 log.file.Close()
211 }
212
213 uc.endPendingLISTs(true)
214 case eventUpstreamMessage:
215 msg, uc := e.msg, e.uc
216 if uc.isClosed() {
217 uc.logger.Printf("ignoring message on closed connection: %v", msg)
218 break
219 }
220 if err := uc.handleMessage(msg); err != nil {
221 uc.logger.Printf("failed to handle message %q: %v", msg, err)
222 }
223 case eventDownstreamConnected:
224 dc := e.dc
225
226 if err := dc.welcome(); err != nil {
227 dc.logger.Printf("failed to handle new registered connection: %v", err)
228 break
229 }
230
231 u.downstreamConns = append(u.downstreamConns, dc)
232
233 u.forEachUpstream(func(uc *upstreamConn) {
234 uc.updateAway()
235 })
236 case eventDownstreamDisconnected:
237 dc := e.dc
238 for i := range u.downstreamConns {
239 if u.downstreamConns[i] == dc {
240 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
241 break
242 }
243 }
244
245 u.forEachUpstream(func(uc *upstreamConn) {
246 uc.updateAway()
247 })
248 case eventDownstreamMessage:
249 msg, dc := e.msg, e.dc
250 if dc.isClosed() {
251 dc.logger.Printf("ignoring message on closed connection: %v", msg)
252 break
253 }
254 err := dc.handleMessage(msg)
255 if ircErr, ok := err.(ircError); ok {
256 ircErr.Message.Prefix = dc.srv.prefix()
257 dc.SendMessage(ircErr.Message)
258 } else if err != nil {
259 dc.logger.Printf("failed to handle message %q: %v", msg, err)
260 dc.Close()
261 }
262 default:
263 u.srv.Logger.Printf("received unknown event type: %T", e)
264 }
265 }
266}
267
268func (u *user) createNetwork(net *Network) (*network, error) {
269 if net.ID != 0 {
270 panic("tried creating an already-existing network")
271 }
272
273 network := newNetwork(u, net)
274 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
275 if err != nil {
276 return nil, err
277 }
278
279 u.forEachDownstream(func(dc *downstreamConn) {
280 if dc.network == nil {
281 dc.runNetwork(network, false)
282 }
283 })
284
285 u.networks = append(u.networks, network)
286
287 go network.run()
288 return network, nil
289}
290
291func (u *user) deleteNetwork(id int64) error {
292 for i, net := range u.networks {
293 if net.ID != id {
294 continue
295 }
296
297 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
298 return err
299 }
300
301 u.forEachDownstream(func(dc *downstreamConn) {
302 if dc.network != nil && dc.network == net {
303 dc.Close()
304 }
305 })
306
307 net.Stop()
308 u.networks = append(u.networks[:i], u.networks[i+1:]...)
309 return nil
310 }
311
312 panic("tried deleting a non-existing network")
313}
Note: See TracBrowser for help on using the repository browser.