source: code/trunk/user.go@ 175

Last change on this file since 175 was 175, checked in by contact, 5 years ago

Fix race condition in upstreamConn.Close

upstreamConn.closed was a bool accessed from different goroutines. Use
the same pattern as downstreamConn instead.

File size: 4.5 KB
Line 
1package soju
2
3import (
4 "sync"
5 "time"
6
7 "gopkg.in/irc.v3"
8)
9
10type event interface{}
11
12type eventUpstreamMessage struct {
13 msg *irc.Message
14 uc *upstreamConn
15}
16
17type eventDownstreamMessage struct {
18 msg *irc.Message
19 dc *downstreamConn
20}
21
22type eventDownstreamConnected struct {
23 dc *downstreamConn
24}
25
26type eventDownstreamDisconnected struct {
27 dc *downstreamConn
28}
29
30type network struct {
31 Network
32 user *user
33 ring *Ring
34
35 lock sync.Mutex
36 conn *upstreamConn
37 history map[string]uint64
38}
39
40func newNetwork(user *user, record *Network) *network {
41 return &network{
42 Network: *record,
43 user: user,
44 ring: NewRing(user.srv.RingCap),
45 history: make(map[string]uint64),
46 }
47}
48
49func (net *network) run() {
50 var lastTry time.Time
51 for {
52 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
53 delay := retryConnectMinDelay - dur
54 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
55 time.Sleep(delay)
56 }
57 lastTry = time.Now()
58
59 uc, err := connectToUpstream(net)
60 if err != nil {
61 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
62 continue
63 }
64
65 uc.register()
66
67 // TODO: wait for the connection to be registered before adding it to
68 // net, otherwise messages might be sent to it while still being
69 // unauthenticated
70 net.lock.Lock()
71 net.conn = uc
72 net.lock.Unlock()
73
74 if err := uc.readMessages(net.user.events); err != nil {
75 uc.logger.Printf("failed to handle messages: %v", err)
76 }
77 uc.Close()
78
79 net.lock.Lock()
80 net.conn = nil
81 net.lock.Unlock()
82 }
83}
84
85func (net *network) upstream() *upstreamConn {
86 net.lock.Lock()
87 defer net.lock.Unlock()
88 return net.conn
89}
90
91type user struct {
92 User
93 srv *Server
94
95 events chan event
96
97 networks []*network
98 downstreamConns []*downstreamConn
99}
100
101func newUser(srv *Server, record *User) *user {
102 return &user{
103 User: *record,
104 srv: srv,
105 events: make(chan event, 64),
106 }
107}
108
109func (u *user) forEachNetwork(f func(*network)) {
110 for _, network := range u.networks {
111 f(network)
112 }
113}
114
115func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
116 for _, network := range u.networks {
117 uc := network.upstream()
118 if uc == nil || !uc.registered {
119 continue
120 }
121 f(uc)
122 }
123}
124
125func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
126 for _, dc := range u.downstreamConns {
127 f(dc)
128 }
129}
130
131func (u *user) getNetwork(name string) *network {
132 for _, network := range u.networks {
133 if network.Addr == name {
134 return network
135 }
136 }
137 return nil
138}
139
140func (u *user) run() {
141 networks, err := u.srv.db.ListNetworks(u.Username)
142 if err != nil {
143 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
144 return
145 }
146
147 for _, record := range networks {
148 network := newNetwork(u, &record)
149 u.networks = append(u.networks, network)
150
151 go network.run()
152 }
153
154 for e := range u.events {
155 switch e := e.(type) {
156 case eventUpstreamMessage:
157 msg, uc := e.msg, e.uc
158 if uc.isClosed() {
159 uc.logger.Printf("ignoring message on closed connection: %v", msg)
160 break
161 }
162 if err := uc.handleMessage(msg); err != nil {
163 uc.logger.Printf("failed to handle message %q: %v", msg, err)
164 }
165 case eventDownstreamConnected:
166 dc := e.dc
167
168 if err := dc.welcome(); err != nil {
169 dc.logger.Printf("failed to handle new registered connection: %v", err)
170 break
171 }
172
173 u.downstreamConns = append(u.downstreamConns, dc)
174 case eventDownstreamDisconnected:
175 dc := e.dc
176 for i := range u.downstreamConns {
177 if u.downstreamConns[i] == dc {
178 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
179 break
180 }
181 }
182 case eventDownstreamMessage:
183 msg, dc := e.msg, e.dc
184 if dc.isClosed() {
185 dc.logger.Printf("ignoring message on closed connection: %v", msg)
186 break
187 }
188 err := dc.handleMessage(msg)
189 if ircErr, ok := err.(ircError); ok {
190 ircErr.Message.Prefix = dc.srv.prefix()
191 dc.SendMessage(ircErr.Message)
192 } else if err != nil {
193 dc.logger.Printf("failed to handle message %q: %v", msg, err)
194 dc.Close()
195 }
196 default:
197 u.srv.Logger.Printf("received unknown event type: %T", e)
198 }
199 }
200}
201
202func (u *user) createNetwork(net *Network) (*network, error) {
203 if net.ID != 0 {
204 panic("tried creating an already-existing network")
205 }
206
207 network := newNetwork(u, net)
208 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
209 if err != nil {
210 return nil, err
211 }
212
213 u.forEachDownstream(func(dc *downstreamConn) {
214 if dc.network == nil {
215 dc.runNetwork(network, false)
216 }
217 })
218
219 u.networks = append(u.networks, network)
220
221 go network.run()
222 return network, nil
223}
Note: See TracBrowser for help on using the repository browser.