source: code/trunk/conn.go@ 777

Last change on this file since 777 was 757, checked in by contact, 4 years ago

Add context to {conn,upstreamConn}.SendMessage

This avoids blocking on upstream message rate limiting for too
long.

File size: 6.1 KB
RevLine 
[210]1package soju
2
3import (
[323]4 "context"
[210]5 "fmt"
[341]6 "io"
[210]7 "net"
[415]8 "strings"
[280]9 "sync"
[210]10 "time"
[415]11 "unicode"
[210]12
[741]13 "golang.org/x/time/rate"
[210]14 "gopkg.in/irc.v3"
[323]15 "nhooyr.io/websocket"
[210]16)
17
[315]18// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
19// reading and writing IRC messages.
20type ircConn interface {
21 ReadMessage() (*irc.Message, error)
22 WriteMessage(*irc.Message) error
23 Close() error
[323]24 SetReadDeadline(time.Time) error
[315]25 SetWriteDeadline(time.Time) error
[347]26 RemoteAddr() net.Addr
[383]27 LocalAddr() net.Addr
[315]28}
29
[323]30func newNetIRCConn(c net.Conn) ircConn {
[315]31 type netConn net.Conn
32 return struct {
33 *irc.Conn
34 netConn
35 }{irc.NewConn(c), c}
36}
37
[323]38type websocketIRCConn struct {
39 conn *websocket.Conn
40 readDeadline, writeDeadline time.Time
[347]41 remoteAddr string
[323]42}
43
[347]44func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
[465]45 return &websocketIRCConn{conn: c, remoteAddr: remoteAddr}
[323]46}
47
[465]48func (wic *websocketIRCConn) ReadMessage() (*irc.Message, error) {
[323]49 ctx := context.Background()
50 if !wic.readDeadline.IsZero() {
51 var cancel context.CancelFunc
52 ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
53 defer cancel()
54 }
55 _, b, err := wic.conn.Read(ctx)
56 if err != nil {
[341]57 switch websocket.CloseStatus(err) {
58 case websocket.StatusNormalClosure, websocket.StatusGoingAway:
59 return nil, io.EOF
60 default:
61 return nil, err
62 }
[323]63 }
64 return irc.ParseMessage(string(b))
65}
66
[465]67func (wic *websocketIRCConn) WriteMessage(msg *irc.Message) error {
[415]68 b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
[323]69 ctx := context.Background()
70 if !wic.writeDeadline.IsZero() {
71 var cancel context.CancelFunc
72 ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
73 defer cancel()
74 }
75 return wic.conn.Write(ctx, websocket.MessageText, b)
76}
77
// isErrWebSocketClosed reports whether err is the "already wrote close" error
// produced when closing a WebSocket whose close frame was already sent. The
// match is purely textual (a message suffix), so it depends on the exact
// wording used by the websocket library.
func isErrWebSocketClosed(err error) bool {
	if err == nil {
		return false
	}
	const suffix = "failed to close WebSocket: already wrote close"
	return strings.HasSuffix(err.Error(), suffix)
}
81
[465]82func (wic *websocketIRCConn) Close() error {
[594]83 err := wic.conn.Close(websocket.StatusNormalClosure, "")
84 // TODO: remove once this PR is merged:
85 // https://github.com/nhooyr/websocket/pull/303
86 if isErrWebSocketClosed(err) {
87 return nil
88 }
89 return err
[323]90}
91
[465]92func (wic *websocketIRCConn) SetReadDeadline(t time.Time) error {
[323]93 wic.readDeadline = t
94 return nil
95}
96
[465]97func (wic *websocketIRCConn) SetWriteDeadline(t time.Time) error {
[323]98 wic.writeDeadline = t
99 return nil
100}
101
[465]102func (wic *websocketIRCConn) RemoteAddr() net.Addr {
[347]103 return websocketAddr(wic.remoteAddr)
104}
105
[465]106func (wic *websocketIRCConn) LocalAddr() net.Addr {
[383]107 // Behind a reverse HTTP proxy, we don't have access to the real listening
108 // address
109 return websocketAddr("")
110}
111
// websocketAddr is a net.Addr for WebSocket endpoints, whose address is only
// known as an opaque string.
type websocketAddr string

// Network implements net.Addr; every websocketAddr reports "ws".
func (websocketAddr) Network() string { return "ws" }

// String implements net.Addr by returning the wrapped string unchanged.
func (wa websocketAddr) String() string { return string(wa) }
121
[398]122type connOptions struct {
[402]123 Logger Logger
[398]124 RateLimitDelay time.Duration
125 RateLimitBurst int
126}
127
[210]128type conn struct {
[315]129 conn ircConn
[280]130 srv *Server
131 logger Logger
132
133 lock sync.Mutex
[210]134 outgoing chan<- *irc.Message
[280]135 closed bool
[703]136 closedCh chan struct{}
[210]137}
138
[398]139func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
[210]140 outgoing := make(chan *irc.Message, 64)
141 c := &conn{
[315]142 conn: ic,
[210]143 srv: srv,
144 outgoing: outgoing,
[398]145 logger: options.Logger,
[703]146 closedCh: make(chan struct{}),
[210]147 }
148
149 go func() {
[741]150 ctx, cancel := c.NewContext(context.Background())
151 defer cancel()
[398]152
[741]153 rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
[210]154 for msg := range outgoing {
[741]155 if err := rl.Wait(ctx); err != nil {
156 break
[398]157 }
158
[747]159 c.logger.Debugf("sent: %v", msg)
[315]160 c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
161 if err := c.conn.WriteMessage(msg); err != nil {
[210]162 c.logger.Printf("failed to write message: %v", err)
163 break
164 }
165 }
[519]166 if err := c.conn.Close(); err != nil && !isErrClosed(err) {
[210]167 c.logger.Printf("failed to close connection: %v", err)
168 } else {
[748]169 c.logger.Debugf("connection closed")
[210]170 }
171 // Drain the outgoing channel to prevent SendMessage from blocking
172 for range outgoing {
173 // This space is intentionally left blank
174 }
175 }()
176
[748]177 c.logger.Debugf("new connection")
[210]178 return c
179}
180
181func (c *conn) isClosed() bool {
[280]182 c.lock.Lock()
183 defer c.lock.Unlock()
184 return c.closed
[210]185}
186
187// Close closes the connection. It is safe to call from any goroutine.
188func (c *conn) Close() error {
[280]189 c.lock.Lock()
190 defer c.lock.Unlock()
191
192 if c.closed {
[210]193 return fmt.Errorf("connection already closed")
194 }
[280]195
[315]196 err := c.conn.Close()
[280]197 c.closed = true
[210]198 close(c.outgoing)
[703]199 close(c.closedCh)
[312]200 return err
[210]201}
202
203func (c *conn) ReadMessage() (*irc.Message, error) {
[315]204 msg, err := c.conn.ReadMessage()
[519]205 if isErrClosed(err) {
206 return nil, io.EOF
207 } else if err != nil {
[210]208 return nil, err
209 }
210
[747]211 c.logger.Debugf("received: %v", msg)
[210]212 return msg, nil
213}
214
215// SendMessage queues a new outgoing message. It is safe to call from any
216// goroutine.
[280]217//
218// If the connection is closed before the message is sent, SendMessage silently
219// drops the message.
[757]220func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
[280]221 c.lock.Lock()
222 defer c.lock.Unlock()
223
224 if c.closed {
[210]225 return
226 }
[757]227
228 select {
229 case c.outgoing <- msg:
230 // Success
231 case <-ctx.Done():
232 c.logger.Printf("failed to send message: %v", ctx.Err())
233 }
[210]234}
[384]235
236func (c *conn) RemoteAddr() net.Addr {
237 return c.conn.RemoteAddr()
238}
239
240func (c *conn) LocalAddr() net.Addr {
241 return c.conn.LocalAddr()
242}
[703]243
244// NewContext returns a copy of the parent context with a new Done channel. The
245// returned context's Done channel is closed when the connection is closed,
246// when the returned cancel function is called, or when the parent context's
247// Done channel is closed, whichever happens first.
248//
249// Canceling this context releases resources associated with it, so code should
250// call cancel as soon as the operations running in this Context complete.
251func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
252 ctx, cancel := context.WithCancel(parent)
253
254 go func() {
255 defer cancel()
256
257 select {
258 case <-ctx.Done():
259 // The parent context has been cancelled, or the caller has called
260 // cancel()
261 case <-c.closedCh:
262 // The connection has been closed
263 }
264 }()
265
266 return ctx, cancel
267}
Note: See TracBrowser for help on using the repository browser.