source: code/trunk/conn.go@ 743

Last change on this file since 743 was 741, checked in by contact, 4 years ago

Use golang.org/x/time/rate

Instead of hand-rolling our own rate-limiter based on goroutines,
use golang.org/x/time/rate.

File size: 6.0 KB
Rev Line
[210]1package soju
2
3import (
[323]4 "context"
[210]5 "fmt"
[341]6 "io"
[210]7 "net"
[415]8 "strings"
[280]9 "sync"
[210]10 "time"
[415]11 "unicode"
[210]12
[741]13 "golang.org/x/time/rate"
[210]14 "gopkg.in/irc.v3"
[323]15 "nhooyr.io/websocket"
[210]16)
17
[315]18// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
19// reading and writing IRC messages.
20type ircConn interface {
21 ReadMessage() (*irc.Message, error)
22 WriteMessage(*irc.Message) error
23 Close() error
[323]24 SetReadDeadline(time.Time) error
[315]25 SetWriteDeadline(time.Time) error
[347]26 RemoteAddr() net.Addr
[383]27 LocalAddr() net.Addr
[315]28}
29
[323]30func newNetIRCConn(c net.Conn) ircConn {
[315]31 type netConn net.Conn
32 return struct {
33 *irc.Conn
34 netConn
35 }{irc.NewConn(c), c}
36}
37
[323]38type websocketIRCConn struct {
39 conn *websocket.Conn
40 readDeadline, writeDeadline time.Time
[347]41 remoteAddr string
[323]42}
43
[347]44func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
[465]45 return &websocketIRCConn{conn: c, remoteAddr: remoteAddr}
[323]46}
47
[465]48func (wic *websocketIRCConn) ReadMessage() (*irc.Message, error) {
[323]49 ctx := context.Background()
50 if !wic.readDeadline.IsZero() {
51 var cancel context.CancelFunc
52 ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
53 defer cancel()
54 }
55 _, b, err := wic.conn.Read(ctx)
56 if err != nil {
[341]57 switch websocket.CloseStatus(err) {
58 case websocket.StatusNormalClosure, websocket.StatusGoingAway:
59 return nil, io.EOF
60 default:
61 return nil, err
62 }
[323]63 }
64 return irc.ParseMessage(string(b))
65}
66
[465]67func (wic *websocketIRCConn) WriteMessage(msg *irc.Message) error {
[415]68 b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
[323]69 ctx := context.Background()
70 if !wic.writeDeadline.IsZero() {
71 var cancel context.CancelFunc
72 ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
73 defer cancel()
74 }
75 return wic.conn.Write(ctx, websocket.MessageText, b)
76}
77
// isErrWebSocketClosed reports whether err is non-nil and its message ends
// with the text the websocket package uses when closing a connection whose
// close frame has already been written.
func isErrWebSocketClosed(err error) bool {
	if err == nil {
		return false
	}
	const alreadyClosed = "failed to close WebSocket: already wrote close"
	return strings.HasSuffix(err.Error(), alreadyClosed)
}
81
[465]82func (wic *websocketIRCConn) Close() error {
[594]83 err := wic.conn.Close(websocket.StatusNormalClosure, "")
84 // TODO: remove once this PR is merged:
85 // https://github.com/nhooyr/websocket/pull/303
86 if isErrWebSocketClosed(err) {
87 return nil
88 }
89 return err
[323]90}
91
[465]92func (wic *websocketIRCConn) SetReadDeadline(t time.Time) error {
[323]93 wic.readDeadline = t
94 return nil
95}
96
[465]97func (wic *websocketIRCConn) SetWriteDeadline(t time.Time) error {
[323]98 wic.writeDeadline = t
99 return nil
100}
101
[465]102func (wic *websocketIRCConn) RemoteAddr() net.Addr {
[347]103 return websocketAddr(wic.remoteAddr)
104}
105
[465]106func (wic *websocketIRCConn) LocalAddr() net.Addr {
[383]107 // Behind a reverse HTTP proxy, we don't have access to the real listening
108 // address
109 return websocketAddr("")
110}
111
// websocketAddr is an address of a WebSocket endpoint, stored as a plain
// string. Its Network and String methods match those of net.Addr.
type websocketAddr string

// Network returns the network name, which is always "ws".
func (websocketAddr) Network() string {
	const network = "ws"
	return network
}

// String returns the address exactly as it was stored.
func (wa websocketAddr) String() string {
	addr := string(wa)
	return addr
}
121
[398]122type connOptions struct {
[402]123 Logger Logger
[398]124 RateLimitDelay time.Duration
125 RateLimitBurst int
126}
127
[210]128type conn struct {
[315]129 conn ircConn
[280]130 srv *Server
131 logger Logger
132
133 lock sync.Mutex
[210]134 outgoing chan<- *irc.Message
[280]135 closed bool
[703]136 closedCh chan struct{}
[210]137}
138
[398]139func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
[210]140 outgoing := make(chan *irc.Message, 64)
141 c := &conn{
[315]142 conn: ic,
[210]143 srv: srv,
144 outgoing: outgoing,
[398]145 logger: options.Logger,
[703]146 closedCh: make(chan struct{}),
[210]147 }
148
149 go func() {
[741]150 ctx, cancel := c.NewContext(context.Background())
151 defer cancel()
[398]152
[741]153 rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
[210]154 for msg := range outgoing {
[741]155 if err := rl.Wait(ctx); err != nil {
156 break
[398]157 }
158
[691]159 if c.srv.Config().Debug {
[210]160 c.logger.Printf("sent: %v", msg)
161 }
[315]162 c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
163 if err := c.conn.WriteMessage(msg); err != nil {
[210]164 c.logger.Printf("failed to write message: %v", err)
165 break
166 }
167 }
[519]168 if err := c.conn.Close(); err != nil && !isErrClosed(err) {
[210]169 c.logger.Printf("failed to close connection: %v", err)
170 } else {
171 c.logger.Printf("connection closed")
172 }
173 // Drain the outgoing channel to prevent SendMessage from blocking
174 for range outgoing {
175 // This space is intentionally left blank
176 }
177 }()
178
179 c.logger.Printf("new connection")
180 return c
181}
182
183func (c *conn) isClosed() bool {
[280]184 c.lock.Lock()
185 defer c.lock.Unlock()
186 return c.closed
[210]187}
188
189// Close closes the connection. It is safe to call from any goroutine.
190func (c *conn) Close() error {
[280]191 c.lock.Lock()
192 defer c.lock.Unlock()
193
194 if c.closed {
[210]195 return fmt.Errorf("connection already closed")
196 }
[280]197
[315]198 err := c.conn.Close()
[280]199 c.closed = true
[210]200 close(c.outgoing)
[703]201 close(c.closedCh)
[312]202 return err
[210]203}
204
205func (c *conn) ReadMessage() (*irc.Message, error) {
[315]206 msg, err := c.conn.ReadMessage()
[519]207 if isErrClosed(err) {
208 return nil, io.EOF
209 } else if err != nil {
[210]210 return nil, err
211 }
212
[691]213 if c.srv.Config().Debug {
[210]214 c.logger.Printf("received: %v", msg)
215 }
216
217 return msg, nil
218}
219
220// SendMessage queues a new outgoing message. It is safe to call from any
221// goroutine.
[280]222//
223// If the connection is closed before the message is sent, SendMessage silently
224// drops the message.
[210]225func (c *conn) SendMessage(msg *irc.Message) {
[280]226 c.lock.Lock()
227 defer c.lock.Unlock()
228
229 if c.closed {
[210]230 return
231 }
232 c.outgoing <- msg
233}
[384]234
235func (c *conn) RemoteAddr() net.Addr {
236 return c.conn.RemoteAddr()
237}
238
239func (c *conn) LocalAddr() net.Addr {
240 return c.conn.LocalAddr()
241}
[703]242
243// NewContext returns a copy of the parent context with a new Done channel. The
244// returned context's Done channel is closed when the connection is closed,
245// when the returned cancel function is called, or when the parent context's
246// Done channel is closed, whichever happens first.
247//
248// Canceling this context releases resources associated with it, so code should
249// call cancel as soon as the operations running in this Context complete.
250func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
251 ctx, cancel := context.WithCancel(parent)
252
253 go func() {
254 defer cancel()
255
256 select {
257 case <-ctx.Done():
258 // The parent context has been cancelled, or the caller has called
259 // cancel()
260 case <-c.closedCh:
261 // The connection has been closed
262 }
263 }()
264
265 return ctx, cancel
266}
Note: See TracBrowser for help on using the repository browser.