source: code/trunk/conn.go@ 441

Last change on this file since 441 was 415, checked in by hubert, 5 years ago

Make sure that WebSocket messages are valid UTF-8

... by replacing invalid bytes with the REPLACEMENT CHARACTER U+FFFD

This is better than:

  • discarding the whole message, since the user would not see it at all,
  • removing invalid bytes, since the user would not see that they were present,
  • converting the encoding (this is actually not possible).

Contrary to its documentation, strings.ToValidUTF8 doesn't copy the
string if it's valid UTF-8:
<https://golang.org/src/strings/strings.go?s=15815:15861#L623>

File size: 5.4 KB
RevLine 
[210]1package soju
2
3import (
[323]4 "context"
[210]5 "fmt"
[341]6 "io"
[210]7 "net"
[415]8 "strings"
[280]9 "sync"
[210]10 "time"
[415]11 "unicode"
[210]12
13 "gopkg.in/irc.v3"
[323]14 "nhooyr.io/websocket"
[210]15)
16
[315]17// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
18// reading and writing IRC messages.
19type ircConn interface {
20 ReadMessage() (*irc.Message, error)
21 WriteMessage(*irc.Message) error
22 Close() error
[323]23 SetReadDeadline(time.Time) error
[315]24 SetWriteDeadline(time.Time) error
[347]25 RemoteAddr() net.Addr
[383]26 LocalAddr() net.Addr
[315]27}
28
[323]29func newNetIRCConn(c net.Conn) ircConn {
[315]30 type netConn net.Conn
31 return struct {
32 *irc.Conn
33 netConn
34 }{irc.NewConn(c), c}
35}
36
[323]37type websocketIRCConn struct {
38 conn *websocket.Conn
39 readDeadline, writeDeadline time.Time
[347]40 remoteAddr string
[323]41}
42
[347]43func newWebsocketIRCConn(c *websocket.Conn, remoteAddr string) ircConn {
44 return websocketIRCConn{conn: c, remoteAddr: remoteAddr}
[323]45}
46
47func (wic websocketIRCConn) ReadMessage() (*irc.Message, error) {
48 ctx := context.Background()
49 if !wic.readDeadline.IsZero() {
50 var cancel context.CancelFunc
51 ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
52 defer cancel()
53 }
54 _, b, err := wic.conn.Read(ctx)
55 if err != nil {
[341]56 switch websocket.CloseStatus(err) {
57 case websocket.StatusNormalClosure, websocket.StatusGoingAway:
58 return nil, io.EOF
59 default:
60 return nil, err
61 }
[323]62 }
63 return irc.ParseMessage(string(b))
64}
65
66func (wic websocketIRCConn) WriteMessage(msg *irc.Message) error {
[415]67 b := []byte(strings.ToValidUTF8(msg.String(), string(unicode.ReplacementChar)))
[323]68 ctx := context.Background()
69 if !wic.writeDeadline.IsZero() {
70 var cancel context.CancelFunc
71 ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
72 defer cancel()
73 }
74 return wic.conn.Write(ctx, websocket.MessageText, b)
75}
76
77func (wic websocketIRCConn) Close() error {
78 return wic.conn.Close(websocket.StatusNormalClosure, "")
79}
80
81func (wic websocketIRCConn) SetReadDeadline(t time.Time) error {
82 wic.readDeadline = t
83 return nil
84}
85
86func (wic websocketIRCConn) SetWriteDeadline(t time.Time) error {
87 wic.writeDeadline = t
88 return nil
89}
90
[347]91func (wic websocketIRCConn) RemoteAddr() net.Addr {
92 return websocketAddr(wic.remoteAddr)
93}
94
[383]95func (wic websocketIRCConn) LocalAddr() net.Addr {
96 // Behind a reverse HTTP proxy, we don't have access to the real listening
97 // address
98 return websocketAddr("")
99}
100
// websocketAddr is an address whose textual form is the string itself and
// whose network name is always "ws". Its Network and String methods match
// the shape of net.Addr.
type websocketAddr string

// Network returns the fixed network name "ws".
func (websocketAddr) Network() string { return "ws" }

// String returns the address text unchanged.
func (addr websocketAddr) String() string { return string(addr) }
110
// rateLimiter hands out tokens over the channel C. Consumers receive one
// value from C before each rate-limited action.
type rateLimiter struct {
	// C delivers the tokens.
	C <-chan struct{}
	// ticker paces the production of new tokens.
	ticker *time.Ticker
	// stopped is closed by Stop to terminate the refill goroutine.
	stopped chan struct{}
}

// newRateLimiter creates a rateLimiter whose token channel holds at most
// burst tokens and starts out full. A background goroutine then tries to add
// one token per tick of delay; when the channel is already full it waits
// until a token is consumed or the limiter is stopped.
//
// Callers must call Stop to release the ticker and the goroutine.
func newRateLimiter(delay time.Duration, burst int) *rateLimiter {
	tokens := make(chan struct{}, burst)
	for n := burst; n > 0; n-- {
		tokens <- struct{}{}
	}

	tick := time.NewTicker(delay)
	done := make(chan struct{})

	refill := func() {
		for {
			// Wait for the next tick, unless stopped first.
			select {
			case <-done:
				return
			case <-tick.C:
			}
			// Hand out one token, unless stopped first.
			select {
			case <-done:
				return
			case tokens <- struct{}{}:
			}
		}
	}
	go refill()

	return &rateLimiter{
		C:       tokens,
		ticker:  tick,
		stopped: done,
	}
}

// Stop stops the ticker and terminates the refill goroutine. It closes the
// stopped channel, so it must be called at most once per rateLimiter: a
// second call would close an already-closed channel and panic.
func (rl *rateLimiter) Stop() {
	rl.ticker.Stop()
	close(rl.stopped)
}
150
151type connOptions struct {
[402]152 Logger Logger
[398]153 RateLimitDelay time.Duration
154 RateLimitBurst int
155}
156
[210]157type conn struct {
[315]158 conn ircConn
[280]159 srv *Server
160 logger Logger
161
162 lock sync.Mutex
[210]163 outgoing chan<- *irc.Message
[280]164 closed bool
[210]165}
166
[398]167func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
[210]168 outgoing := make(chan *irc.Message, 64)
169 c := &conn{
[315]170 conn: ic,
[210]171 srv: srv,
172 outgoing: outgoing,
[398]173 logger: options.Logger,
[210]174 }
175
176 go func() {
[398]177 var rl *rateLimiter
178 if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 {
179 rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst)
180 defer rl.Stop()
181 }
182
[210]183 for msg := range outgoing {
[398]184 if rl != nil {
185 <-rl.C
186 }
187
[210]188 if c.srv.Debug {
189 c.logger.Printf("sent: %v", msg)
190 }
[315]191 c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
192 if err := c.conn.WriteMessage(msg); err != nil {
[210]193 c.logger.Printf("failed to write message: %v", err)
194 break
195 }
196 }
[315]197 if err := c.conn.Close(); err != nil {
[210]198 c.logger.Printf("failed to close connection: %v", err)
199 } else {
200 c.logger.Printf("connection closed")
201 }
202 // Drain the outgoing channel to prevent SendMessage from blocking
203 for range outgoing {
204 // This space is intentionally left blank
205 }
206 }()
207
208 c.logger.Printf("new connection")
209 return c
210}
211
212func (c *conn) isClosed() bool {
[280]213 c.lock.Lock()
214 defer c.lock.Unlock()
215 return c.closed
[210]216}
217
218// Close closes the connection. It is safe to call from any goroutine.
219func (c *conn) Close() error {
[280]220 c.lock.Lock()
221 defer c.lock.Unlock()
222
223 if c.closed {
[210]224 return fmt.Errorf("connection already closed")
225 }
[280]226
[315]227 err := c.conn.Close()
[280]228 c.closed = true
[210]229 close(c.outgoing)
[312]230 return err
[210]231}
232
233func (c *conn) ReadMessage() (*irc.Message, error) {
[315]234 msg, err := c.conn.ReadMessage()
[210]235 if err != nil {
236 return nil, err
237 }
238
239 if c.srv.Debug {
240 c.logger.Printf("received: %v", msg)
241 }
242
243 return msg, nil
244}
245
246// SendMessage queues a new outgoing message. It is safe to call from any
247// goroutine.
[280]248//
249// If the connection is closed before the message is sent, SendMessage silently
250// drops the message.
[210]251func (c *conn) SendMessage(msg *irc.Message) {
[280]252 c.lock.Lock()
253 defer c.lock.Unlock()
254
255 if c.closed {
[210]256 return
257 }
258 c.outgoing <- msg
259}
[384]260
261func (c *conn) RemoteAddr() net.Addr {
262 return c.conn.RemoteAddr()
263}
264
265func (c *conn) LocalAddr() net.Addr {
266 return c.conn.LocalAddr()
267}
Note: See TracBrowser for help on using the repository browser.