source: code/trunk/user.go@ 168

Last change on this file since 168 was 168, checked in by contact, 5 years ago

Nuke user.lock

Split user.register into two functions: one to make sure the user is
authenticated, the other to send our current state. This allows us to get
rid of data races by doing the second part in the user goroutine.

Closes: https://todo.sr.ht/~emersion/soju/22

File size: 4.4 KB
Line 
1package soju
2
3import (
4 "sync"
5 "time"
6
7 "gopkg.in/irc.v3"
8)
9
10type event interface{}
11
12type eventUpstreamMessage struct {
13 msg *irc.Message
14 uc *upstreamConn
15}
16
17type eventDownstreamMessage struct {
18 msg *irc.Message
19 dc *downstreamConn
20}
21
22type eventDownstreamConnected struct {
23 dc *downstreamConn
24}
25
26type eventDownstreamDisconnected struct {
27 dc *downstreamConn
28}
29
30type network struct {
31 Network
32 user *user
33 ring *Ring
34
35 lock sync.Mutex
36 conn *upstreamConn
37 history map[string]uint64
38}
39
40func newNetwork(user *user, record *Network) *network {
41 return &network{
42 Network: *record,
43 user: user,
44 ring: NewRing(user.srv.RingCap),
45 history: make(map[string]uint64),
46 }
47}
48
49func (net *network) run() {
50 var lastTry time.Time
51 for {
52 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
53 delay := retryConnectMinDelay - dur
54 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
55 time.Sleep(delay)
56 }
57 lastTry = time.Now()
58
59 uc, err := connectToUpstream(net)
60 if err != nil {
61 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
62 continue
63 }
64
65 uc.register()
66
67 net.lock.Lock()
68 net.conn = uc
69 net.lock.Unlock()
70
71 if err := uc.readMessages(net.user.events); err != nil {
72 uc.logger.Printf("failed to handle messages: %v", err)
73 }
74 uc.Close()
75
76 net.lock.Lock()
77 net.conn = nil
78 net.lock.Unlock()
79 }
80}
81
82func (net *network) upstream() *upstreamConn {
83 net.lock.Lock()
84 defer net.lock.Unlock()
85 return net.conn
86}
87
88type user struct {
89 User
90 srv *Server
91
92 events chan event
93
94 networks []*network
95 downstreamConns []*downstreamConn
96}
97
98func newUser(srv *Server, record *User) *user {
99 return &user{
100 User: *record,
101 srv: srv,
102 events: make(chan event, 64),
103 }
104}
105
106func (u *user) forEachNetwork(f func(*network)) {
107 for _, network := range u.networks {
108 f(network)
109 }
110}
111
112func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
113 for _, network := range u.networks {
114 uc := network.upstream()
115 if uc == nil || !uc.registered || uc.closed {
116 continue
117 }
118 f(uc)
119 }
120}
121
122func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
123 for _, dc := range u.downstreamConns {
124 f(dc)
125 }
126}
127
128func (u *user) getNetwork(name string) *network {
129 for _, network := range u.networks {
130 if network.Addr == name {
131 return network
132 }
133 }
134 return nil
135}
136
137func (u *user) run() {
138 networks, err := u.srv.db.ListNetworks(u.Username)
139 if err != nil {
140 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
141 return
142 }
143
144 for _, record := range networks {
145 network := newNetwork(u, &record)
146 u.networks = append(u.networks, network)
147
148 go network.run()
149 }
150
151 for e := range u.events {
152 switch e := e.(type) {
153 case eventUpstreamMessage:
154 msg, uc := e.msg, e.uc
155 if uc.closed {
156 uc.logger.Printf("ignoring message on closed connection: %v", msg)
157 break
158 }
159 if err := uc.handleMessage(msg); err != nil {
160 uc.logger.Printf("failed to handle message %q: %v", msg, err)
161 }
162 case eventDownstreamConnected:
163 dc := e.dc
164
165 if err := dc.welcome(); err != nil {
166 dc.logger.Printf("failed to handle new registered connection: %v", err)
167 break
168 }
169
170 u.downstreamConns = append(u.downstreamConns, dc)
171 case eventDownstreamDisconnected:
172 dc := e.dc
173 for i := range u.downstreamConns {
174 if u.downstreamConns[i] == dc {
175 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
176 break
177 }
178 }
179 case eventDownstreamMessage:
180 msg, dc := e.msg, e.dc
181 if dc.isClosed() {
182 dc.logger.Printf("ignoring message on closed connection: %v", msg)
183 break
184 }
185 err := dc.handleMessage(msg)
186 if ircErr, ok := err.(ircError); ok {
187 ircErr.Message.Prefix = dc.srv.prefix()
188 dc.SendMessage(ircErr.Message)
189 } else if err != nil {
190 dc.logger.Printf("failed to handle message %q: %v", msg, err)
191 dc.Close()
192 }
193 default:
194 u.srv.Logger.Printf("received unknown event type: %T", e)
195 }
196 }
197}
198
199func (u *user) createNetwork(net *Network) (*network, error) {
200 if net.ID != 0 {
201 panic("tried creating an already-existing network")
202 }
203
204 network := newNetwork(u, net)
205 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
206 if err != nil {
207 return nil, err
208 }
209
210 u.forEachDownstream(func(dc *downstreamConn) {
211 if dc.network == nil {
212 dc.runNetwork(network, false)
213 }
214 })
215
216 u.networks = append(u.networks, network)
217
218 go network.run()
219 return network, nil
220}
Note: See TracBrowser for help on using the repository browser.