source: code/trunk/conn.go@ 641

Last change on this file since 641 was 594, checked in by contact, 4 years ago

Work around the lack of net.ErrClosed in the WebSocket library

File size: 5.8 KB
RevLine 
[210]1package soju
2
3import (
[323]4 "context"
[210]5 "fmt"
[341]6 "io"
[210]7 "net"
[415]8 "strings"
[280]9 "sync"
[210]10 "time"
[415]11 "unicode"
[210]12
13 "gopkg.in/irc.v3"
[323]14 "nhooyr.io/websocket"
[210]15)
16
[315]17// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
18// reading and writing IRC messages.
19type ircConn interface {
20 ReadMessage() (*irc.Message, error)
21 WriteMessage(*irc.Message) error
22 Close() error
[323]23 SetReadDeadline(time.Time) error
[315]24 SetWriteDeadline(time.Time) error
[347]25 RemoteAddr() net.Addr
[383]26 LocalAddr() net.Addr
[315]27}
28
[323]29func newNetIRCConn(c net.Conn) ircConn {
[315]30 type netConn net.Conn
31 return struct {
32 *irc.Conn
33 netConn
34 }{irc.NewConn(c), c}
35}
36
[323]37type websocketIRCConn struct {
38 conn *websocket.Conn
39 readDeadline, writeDeadline time.Time
[347]40 remoteAddr string
[323]41}
42
[347]43func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
[465]44 return &websocketIRCConn{conn: c, remoteAddr: remoteAddr}
[323]45}
46
[465]47func (wic *websocketIRCConn) ReadMessage() (*irc.Message, error) {
[323]48 ctx := context.Background()
49 if !wic.readDeadline.IsZero() {
50 var cancel context.CancelFunc
51 ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
52 defer cancel()
53 }
54 _, b, err := wic.conn.Read(ctx)
55 if err != nil {
[341]56 switch websocket.CloseStatus(err) {
57 case websocket.StatusNormalClosure, websocket.StatusGoingAway:
58 return nil, io.EOF
59 default:
60 return nil, err
61 }
[323]62 }
63 return irc.ParseMessage(string(b))
64}
65
[465]66func (wic *websocketIRCConn) WriteMessage(msg *irc.Message) error {
[415]67 b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
[323]68 ctx := context.Background()
69 if !wic.writeDeadline.IsZero() {
70 var cancel context.CancelFunc
71 ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
72 defer cancel()
73 }
74 return wic.conn.Write(ctx, websocket.MessageText, b)
75}
76
// isErrWebSocketClosed reports whether err is the error nhooyr.io/websocket
// returns when asked to close a connection that has already written its close
// frame. The library exposes no sentinel error for this case, so the error
// text is matched instead.
func isErrWebSocketClosed(err error) bool {
	if err == nil {
		return false
	}
	const alreadyClosedSuffix = "failed to close WebSocket: already wrote close"
	return strings.HasSuffix(err.Error(), alreadyClosedSuffix)
}
80
[465]81func (wic *websocketIRCConn) Close() error {
[594]82 err := wic.conn.Close(websocket.StatusNormalClosure, "")
83 // TODO: remove once this PR is merged:
84 // https://github.com/nhooyr/websocket/pull/303
85 if isErrWebSocketClosed(err) {
86 return nil
87 }
88 return err
[323]89}
90
[465]91func (wic *websocketIRCConn) SetReadDeadline(t time.Time) error {
[323]92 wic.readDeadline = t
93 return nil
94}
95
[465]96func (wic *websocketIRCConn) SetWriteDeadline(t time.Time) error {
[323]97 wic.writeDeadline = t
98 return nil
99}
100
// RemoteAddr returns the peer address that was passed to newWebsocketIRCConn,
// wrapped as a net.Addr whose network is "ws".
func (wic *websocketIRCConn) RemoteAddr() net.Addr {
	return websocketAddr(wic.remoteAddr)
}

// LocalAddr always returns an empty "ws" address.
func (wic *websocketIRCConn) LocalAddr() net.Addr {
	// Behind a reverse HTTP proxy, we don't have access to the real listening
	// address
	return websocketAddr("")
}
110
// websocketAddr is a net.Addr for WebSocket endpoints. The address text is
// stored and returned verbatim, and the network is always reported as "ws".
type websocketAddr string

// Network reports the pseudo network name "ws".
func (websocketAddr) Network() string { return "ws" }

// String returns the address exactly as it was recorded.
func (addr websocketAddr) String() string { return string(addr) }
120
// rateLimiter is a token bucket. Receive a token from C before each
// rate-limited operation. The bucket starts with burst tokens, gains one token
// every delay, and never holds more than burst.
type rateLimiter struct {
	// C delivers one token per permitted operation.
	C <-chan struct{}

	ticker  *time.Ticker
	stopped chan struct{}
}

// newRateLimiter returns a running rateLimiter that initially holds burst
// tokens and refills at one token per delay. The caller must call Stop to
// release the ticker and the refill goroutine.
//
// delay must be positive (time.NewTicker panics otherwise), and burst must not
// be negative (make panics otherwise).
func newRateLimiter(delay time.Duration, burst int) *rateLimiter {
	tokens := make(chan struct{}, burst)
	for i := 0; i < burst; i++ {
		tokens <- struct{}{}
	}
	ticker := time.NewTicker(delay)
	stopped := make(chan struct{})
	go func() {
		for {
			select {
			case <-ticker.C:
				// Add a token only if the bucket has room. Blocking here while
				// the bucket is full would park an extra token in this
				// goroutine, and let the ticker buffer yet another tick. After
				// an idle period, burst+2 operations could then pass at once.
				select {
				case tokens <- struct{}{}:
				default:
					// Bucket full: drop this tick.
				}
			case <-stopped:
				return
			}
		}
	}()
	return &rateLimiter{
		C:       tokens,
		ticker:  ticker,
		stopped: stopped,
	}
}

// Stop halts the ticker and the refill goroutine. Tokens already buffered in
// C remain receivable. Stop must be called at most once: a second call panics
// because it closes an already-closed channel.
func (rl *rateLimiter) Stop() {
	rl.ticker.Stop()
	close(rl.stopped)
}
160
// connOptions configures newConn. newConn dereferences it, so it must not be
// nil.
type connOptions struct {
	// Logger receives the connection's log lines.
	Logger Logger
	// RateLimitDelay and RateLimitBurst are passed to newRateLimiter to pace
	// outgoing messages. Rate limiting is enabled only when both are positive.
	RateLimitDelay time.Duration
	RateLimitBurst int
}
166
// conn is an IRC connection with an asynchronous writer. ReadMessage reads on
// the caller's goroutine. SendMessage queues messages that a goroutine started
// by newConn writes out.
type conn struct {
	conn   ircConn // underlying transport
	srv    *Server // srv.Debug enables logging of every sent/received message
	logger Logger

	// lock guards closed, and serializes sends on outgoing with Close so that
	// nothing is sent on the channel after it has been closed.
	lock     sync.Mutex
	outgoing chan<- *irc.Message // queue consumed by the writer goroutine
	closed   bool                // set by the first successful Close call
}
176
[398]177func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
[210]178 outgoing := make(chan *irc.Message, 64)
179 c := &conn{
[315]180 conn: ic,
[210]181 srv: srv,
182 outgoing: outgoing,
[398]183 logger: options.Logger,
[210]184 }
185
186 go func() {
[398]187 var rl *rateLimiter
188 if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 {
189 rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst)
190 defer rl.Stop()
191 }
192
[210]193 for msg := range outgoing {
[398]194 if rl != nil {
195 <-rl.C
196 }
197
[210]198 if c.srv.Debug {
199 c.logger.Printf("sent: %v", msg)
200 }
[315]201 c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
202 if err := c.conn.WriteMessage(msg); err != nil {
[210]203 c.logger.Printf("failed to write message: %v", err)
204 break
205 }
206 }
[519]207 if err := c.conn.Close(); err != nil && !isErrClosed(err) {
[210]208 c.logger.Printf("failed to close connection: %v", err)
209 } else {
210 c.logger.Printf("connection closed")
211 }
212 // Drain the outgoing channel to prevent SendMessage from blocking
213 for range outgoing {
214 // This space is intentionally left blank
215 }
216 }()
217
218 c.logger.Printf("new connection")
219 return c
220}
221
222func (c *conn) isClosed() bool {
[280]223 c.lock.Lock()
224 defer c.lock.Unlock()
225 return c.closed
[210]226}
227
228// Close closes the connection. It is safe to call from any goroutine.
229func (c *conn) Close() error {
[280]230 c.lock.Lock()
231 defer c.lock.Unlock()
232
233 if c.closed {
[210]234 return fmt.Errorf("connection already closed")
235 }
[280]236
[315]237 err := c.conn.Close()
[280]238 c.closed = true
[210]239 close(c.outgoing)
[312]240 return err
[210]241}
242
243func (c *conn) ReadMessage() (*irc.Message, error) {
[315]244 msg, err := c.conn.ReadMessage()
[519]245 if isErrClosed(err) {
246 return nil, io.EOF
247 } else if err != nil {
[210]248 return nil, err
249 }
250
251 if c.srv.Debug {
252 c.logger.Printf("received: %v", msg)
253 }
254
255 return msg, nil
256}
257
258// SendMessage queues a new outgoing message. It is safe to call from any
259// goroutine.
[280]260//
261// If the connection is closed before the message is sent, SendMessage silently
262// drops the message.
[210]263func (c *conn) SendMessage(msg *irc.Message) {
[280]264 c.lock.Lock()
265 defer c.lock.Unlock()
266
267 if c.closed {
[210]268 return
269 }
270 c.outgoing <- msg
271}
[384]272
// RemoteAddr returns the remote address reported by the underlying ircConn.
func (c *conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

// LocalAddr returns the local address reported by the underlying ircConn.
func (c *conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}
Note: See TracBrowser for help on using the repository browser.