source: code/trunk/conn.go@ 597

Last change on this file since 597 was 594, checked in by contact, 4 years ago

Workaround lack of net.ErrClosed in WebSocket library

File size: 5.8 KB
Line 
1package soju
2
3import (
4 "context"
5 "fmt"
6 "io"
7 "net"
8 "strings"
9 "sync"
10 "time"
11 "unicode"
12
13 "gopkg.in/irc.v3"
14 "nhooyr.io/websocket"
15)
16
17// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
18// reading and writing IRC messages.
19type ircConn interface {
20 ReadMessage() (*irc.Message, error)
21 WriteMessage(*irc.Message) error
22 Close() error
23 SetReadDeadline(time.Time) error
24 SetWriteDeadline(time.Time) error
25 RemoteAddr() net.Addr
26 LocalAddr() net.Addr
27}
28
29func newNetIRCConn(c net.Conn) ircConn {
30 type netConn net.Conn
31 return struct {
32 *irc.Conn
33 netConn
34 }{irc.NewConn(c), c}
35}
36
37type websocketIRCConn struct {
38 conn *websocket.Conn
39 readDeadline, writeDeadline time.Time
40 remoteAddr string
41}
42
43func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
44 return &websocketIRCConn{conn: c, remoteAddr: remoteAddr}
45}
46
47func (wic *websocketIRCConn) ReadMessage() (*irc.Message, error) {
48 ctx := context.Background()
49 if !wic.readDeadline.IsZero() {
50 var cancel context.CancelFunc
51 ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
52 defer cancel()
53 }
54 _, b, err := wic.conn.Read(ctx)
55 if err != nil {
56 switch websocket.CloseStatus(err) {
57 case websocket.StatusNormalClosure, websocket.StatusGoingAway:
58 return nil, io.EOF
59 default:
60 return nil, err
61 }
62 }
63 return irc.ParseMessage(string(b))
64}
65
66func (wic *websocketIRCConn) WriteMessage(msg *irc.Message) error {
67 b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
68 ctx := context.Background()
69 if !wic.writeDeadline.IsZero() {
70 var cancel context.CancelFunc
71 ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
72 defer cancel()
73 }
74 return wic.conn.Write(ctx, websocket.MessageText, b)
75}
76
// isErrWebSocketClosed reports whether err is the error the WebSocket library
// returns when closing a connection on which a close frame was already
// written. It matches on the error text because the library offers no
// sentinel error for this case.
func isErrWebSocketClosed(err error) bool {
	const alreadyClosed = "failed to close WebSocket: already wrote close"
	if err == nil {
		return false
	}
	return strings.HasSuffix(err.Error(), alreadyClosed)
}
80
81func (wic *websocketIRCConn) Close() error {
82 err := wic.conn.Close(websocket.StatusNormalClosure, "")
83 // TODO: remove once this PR is merged:
84 // https://github.com/nhooyr/websocket/pull/303
85 if isErrWebSocketClosed(err) {
86 return nil
87 }
88 return err
89}
90
91func (wic *websocketIRCConn) SetReadDeadline(t time.Time) error {
92 wic.readDeadline = t
93 return nil
94}
95
96func (wic *websocketIRCConn) SetWriteDeadline(t time.Time) error {
97 wic.writeDeadline = t
98 return nil
99}
100
101func (wic *websocketIRCConn) RemoteAddr() net.Addr {
102 return websocketAddr(wic.remoteAddr)
103}
104
105func (wic *websocketIRCConn) LocalAddr() net.Addr {
106 // Behind a reverse HTTP proxy, we don't have access to the real listening
107 // address
108 return websocketAddr("")
109}
110
// websocketAddr is a net.Addr for WebSocket connections. Its network is
// always "ws" and its string form is the wrapped address text.
type websocketAddr string

// Network always reports "ws".
func (wa websocketAddr) Network() string {
	const network = "ws"
	return network
}

// String returns the address text unchanged.
func (wa websocketAddr) String() string {
	return string(wa)
}
120
// rateLimiter is a token bucket. Callers receive from C before each
// rate-limited operation.
type rateLimiter struct {
	// C yields one token per receive. It starts with burst tokens, and a new
	// token is offered every delay until Stop is called.
	C <-chan struct{}

	ticker  *time.Ticker
	stopped chan struct{}
}

// newRateLimiter returns a limiter whose bucket holds burst tokens (initially
// full) and is refilled with one token per delay. It starts a background
// goroutine that runs until Stop is called.
//
// While the bucket is full, the refill goroutine waits with one pending token,
// so slightly more than burst tokens may be obtainable back-to-back after an
// idle period.
func newRateLimiter(delay time.Duration, burst int) *rateLimiter {
	tokens := make(chan struct{}, burst)
	for n := burst; n > 0; n-- {
		tokens <- struct{}{}
	}

	rl := &rateLimiter{
		C:       tokens,
		ticker:  time.NewTicker(delay),
		stopped: make(chan struct{}),
	}
	go rl.refill(tokens)
	return rl
}

// refill adds one token to tokens per tick. It returns as soon as Stop closes
// rl.stopped, including while waiting for room in a full bucket.
func (rl *rateLimiter) refill(tokens chan<- struct{}) {
	for {
		select {
		case <-rl.ticker.C:
		case <-rl.stopped:
			return
		}

		select {
		case tokens <- struct{}{}:
		case <-rl.stopped:
			return
		}
	}
}

// Stop stops the ticker and terminates the refill goroutine. It closes a
// channel, so it must be called at most once.
func (rl *rateLimiter) Stop() {
	rl.ticker.Stop()
	close(rl.stopped)
}
160
161type connOptions struct {
162 Logger Logger
163 RateLimitDelay time.Duration
164 RateLimitBurst int
165}
166
167type conn struct {
168 conn ircConn
169 srv *Server
170 logger Logger
171
172 lock sync.Mutex
173 outgoing chan<- *irc.Message
174 closed bool
175}
176
177func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
178 outgoing := make(chan *irc.Message, 64)
179 c := &conn{
180 conn: ic,
181 srv: srv,
182 outgoing: outgoing,
183 logger: options.Logger,
184 }
185
186 go func() {
187 var rl *rateLimiter
188 if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 {
189 rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst)
190 defer rl.Stop()
191 }
192
193 for msg := range outgoing {
194 if rl != nil {
195 <-rl.C
196 }
197
198 if c.srv.Debug {
199 c.logger.Printf("sent: %v", msg)
200 }
201 c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
202 if err := c.conn.WriteMessage(msg); err != nil {
203 c.logger.Printf("failed to write message: %v", err)
204 break
205 }
206 }
207 if err := c.conn.Close(); err != nil && !isErrClosed(err) {
208 c.logger.Printf("failed to close connection: %v", err)
209 } else {
210 c.logger.Printf("connection closed")
211 }
212 // Drain the outgoing channel to prevent SendMessage from blocking
213 for range outgoing {
214 // This space is intentionally left blank
215 }
216 }()
217
218 c.logger.Printf("new connection")
219 return c
220}
221
222func (c *conn) isClosed() bool {
223 c.lock.Lock()
224 defer c.lock.Unlock()
225 return c.closed
226}
227
228// Close closes the connection. It is safe to call from any goroutine.
229func (c *conn) Close() error {
230 c.lock.Lock()
231 defer c.lock.Unlock()
232
233 if c.closed {
234 return fmt.Errorf("connection already closed")
235 }
236
237 err := c.conn.Close()
238 c.closed = true
239 close(c.outgoing)
240 return err
241}
242
243func (c *conn) ReadMessage() (*irc.Message, error) {
244 msg, err := c.conn.ReadMessage()
245 if isErrClosed(err) {
246 return nil, io.EOF
247 } else if err != nil {
248 return nil, err
249 }
250
251 if c.srv.Debug {
252 c.logger.Printf("received: %v", msg)
253 }
254
255 return msg, nil
256}
257
258// SendMessage queues a new outgoing message. It is safe to call from any
259// goroutine.
260//
261// If the connection is closed before the message is sent, SendMessage silently
262// drops the message.
263func (c *conn) SendMessage(msg *irc.Message) {
264 c.lock.Lock()
265 defer c.lock.Unlock()
266
267 if c.closed {
268 return
269 }
270 c.outgoing <- msg
271}
272
273func (c *conn) RemoteAddr() net.Addr {
274 return c.conn.RemoteAddr()
275}
276
277func (c *conn) LocalAddr() net.Addr {
278 return c.conn.LocalAddr()
279}
Note: See TracBrowser for help on using the repository browser.