source: code/trunk/vendor/nhooyr.io/websocket/read.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 10.2 KB
RevLine 
[822]1// +build !js
2
3package websocket
4
5import (
6 "bufio"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "io/ioutil"
12 "strings"
13 "time"
14
15 "nhooyr.io/websocket/internal/errd"
16 "nhooyr.io/websocket/internal/xsync"
17)
18
19// Reader reads from the connection until until there is a WebSocket
20// data message to be read. It will handle ping, pong and close frames as appropriate.
21//
22// It returns the type of the message and an io.Reader to read it.
23// The passed context will also bound the reader.
24// Ensure you read to EOF otherwise the connection will hang.
25//
26// Call CloseRead if you do not expect any data messages from the peer.
27//
28// Only one Reader may be open at a time.
29func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
30 return c.reader(ctx)
31}
32
33// Read is a convenience method around Reader to read a single message
34// from the connection.
35func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
36 typ, r, err := c.Reader(ctx)
37 if err != nil {
38 return 0, nil, err
39 }
40
41 b, err := ioutil.ReadAll(r)
42 return typ, b, err
43}
44
45// CloseRead starts a goroutine to read from the connection until it is closed
46// or a data message is received.
47//
48// Once CloseRead is called you cannot read any messages from the connection.
49// The returned context will be cancelled when the connection is closed.
50//
51// If a data message is received, the connection will be closed with StatusPolicyViolation.
52//
53// Call CloseRead when you do not expect to read any more messages.
54// Since it actively reads from the connection, it will ensure that ping, pong and close
55// frames are responded to. This means c.Ping and c.Close will still work as expected.
56func (c *Conn) CloseRead(ctx context.Context) context.Context {
57 ctx, cancel := context.WithCancel(ctx)
58 go func() {
59 defer cancel()
60 c.Reader(ctx)
61 c.Close(StatusPolicyViolation, "unexpected data message")
62 }()
63 return ctx
64}
65
66// SetReadLimit sets the max number of bytes to read for a single message.
67// It applies to the Reader and Read methods.
68//
69// By default, the connection has a message read limit of 32768 bytes.
70//
71// When the limit is hit, the connection will be closed with StatusMessageTooBig.
72func (c *Conn) SetReadLimit(n int64) {
73 // We add read one more byte than the limit in case
74 // there is a fin frame that needs to be read.
75 c.msgReader.limitReader.limit.Store(n + 1)
76}
77
// defaultReadLimit is the per-message read limit (in bytes) used until
// SetReadLimit is called: 32 KiB.
const defaultReadLimit = 1 << 15
79
80func newMsgReader(c *Conn) *msgReader {
81 mr := &msgReader{
82 c: c,
83 fin: true,
84 }
85 mr.readFunc = mr.read
86
87 mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
88 return mr
89}
90
// resetFlate prepares mr to read a compressed (RSV1) message.
//
// With context takeover, the sliding-window dictionary is initialized first.
// The bufio.Reader over mr.readFunc is created only once and is reused for
// later messages. A flate reader is then fetched over that buffer, seeded
// with the dictionary contents, and installed as the limit reader's source,
// so reads through the limit reader return decompressed data. Finally,
// flateTail is rewound to deflateMessageTail. read serves that tail once the
// final frame's payload is exhausted.
func (mr *msgReader) resetFlate() {
	if mr.flateContextTakeover() {
		// NOTE(review): 32768 is presumably the 32 KiB maximum DEFLATE window;
		// it is not tied to defaultReadLimit despite the equal value.
		mr.dict.init(32768)
	}
	if mr.flateBufio == nil {
		mr.flateBufio = getBufioReader(mr.readFunc)
	}

	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
	// Route reads through the decompressor instead of the raw payload reader.
	mr.limitReader.r = mr.flateReader
	mr.flateTail.Reset(deflateMessageTail)
}
103
104func (mr *msgReader) putFlateReader() {
105 if mr.flateReader != nil {
106 putFlateReader(mr.flateReader)
107 mr.flateReader = nil
108 }
109}
110
111func (mr *msgReader) close() {
112 mr.c.readMu.forceLock()
113 mr.putFlateReader()
114 mr.dict.close()
115 if mr.flateBufio != nil {
116 putBufioReader(mr.flateBufio)
117 }
118
119 if mr.c.client {
120 putBufioReader(mr.c.br)
121 mr.c.br = nil
122 }
123}
124
125func (mr *msgReader) flateContextTakeover() bool {
126 if mr.c.client {
127 return !mr.c.copts.serverNoContextTakeover
128 }
129 return !mr.c.copts.clientNoContextTakeover
130}
131
132func (c *Conn) readRSV1Illegal(h header) bool {
133 // If compression is disabled, rsv1 is illegal.
134 if !c.flate() {
135 return true
136 }
137 // rsv1 is only allowed on data frames beginning messages.
138 if h.opcode != opText && h.opcode != opBinary {
139 return true
140 }
141 return false
142}
143
144func (c *Conn) readLoop(ctx context.Context) (header, error) {
145 for {
146 h, err := c.readFrameHeader(ctx)
147 if err != nil {
148 return header{}, err
149 }
150
151 if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
152 err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
153 c.writeError(StatusProtocolError, err)
154 return header{}, err
155 }
156
157 if !c.client && !h.masked {
158 return header{}, errors.New("received unmasked frame from client")
159 }
160
161 switch h.opcode {
162 case opClose, opPing, opPong:
163 err = c.handleControl(ctx, h)
164 if err != nil {
165 // Pass through CloseErrors when receiving a close frame.
166 if h.opcode == opClose && CloseStatus(err) != -1 {
167 return header{}, err
168 }
169 return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
170 }
171 case opContinuation, opText, opBinary:
172 return h, nil
173 default:
174 err := fmt.Errorf("received unknown opcode %v", h.opcode)
175 c.writeError(StatusProtocolError, err)
176 return header{}, err
177 }
178 }
179}
180
// readFrameHeader reads the next frame header from c.br into c.readHeaderBuf.
//
// Before the read, ctx is handed to the connection's read-timeout machinery
// over c.readTimeout. After a successful read, context.Background() is handed
// over instead, presumably to disarm the timeout. Each hand-off also watches
// c.closed, so a closed connection returns c.closeErr instead of blocking.
//
// If the read fails:
//   - if the connection is closed, c.closeErr is returned;
//   - if ctx is done, ctx.Err() is returned;
//   - otherwise the connection is closed with the raw error, which is
//     returned.
//
// When several of these cases are ready at once, Go's select picks one at
// random.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- ctx:
	}

	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
	if err != nil {
		select {
		case <-c.closed:
			return header{}, c.closeErr
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
			c.close(err)
			return header{}, err
		}
	}

	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- context.Background():
	}

	return h, nil
}
209
// readFramePayload reads exactly len(p) payload bytes from c.br.
//
// It uses the same read-timeout hand-off as readFrameHeader: ctx before the
// read, context.Background() after it. Each hand-off also watches c.closed.
//
// It returns the number of bytes read, even on failure. If the read fails:
//   - if the connection is closed, c.closeErr is returned;
//   - if ctx is done, ctx.Err() is returned;
//   - otherwise the error is wrapped with "failed to read frame payload" (%w),
//     the connection is closed with it, and it is returned.
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case <-c.closed:
		return 0, c.closeErr
	case c.readTimeout <- ctx:
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			return n, ctx.Err()
		default:
			err = fmt.Errorf("failed to read frame payload: %w", err)
			c.close(err)
			return n, err
		}
	}

	select {
	case <-c.closed:
		return n, c.closeErr
	case c.readTimeout <- context.Background():
	}

	// err is nil here.
	return n, err
}
239
// handleControl reads and acts on the payload of a control frame whose header
// h has already been read.
//
// A control frame's payload may be at most maxControlPayload bytes, and the
// frame must not be fragmented. Violations are reported with
// StatusProtocolError. The payload is read into c.readControlBuf under a
// 5 second timeout derived from ctx, and is unmasked if the frame was masked.
//
// Handling by opcode:
//   - ping: answered with a pong carrying the same payload.
//   - pong: if c.activePings has a channel registered for this payload
//     (presumably by Ping), it is signalled without blocking.
//   - close: the payload is parsed, and an invalid payload is a protocol
//     error. Otherwise the close error is recorded with setCloseErr, the
//     received code and reason are echoed back with writeClose, and the
//     connection is closed. For close frames, the returned error is also
//     stored in c.readCloseFrameErr.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			// Non-blocking: drop the signal if nobody is receiving.
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// Only close frames reach this point; ping and pong returned above.
	// The defer reads the named result err, so it records whichever error
	// this function returns from here on.
	defer func() {
		c.readCloseFrameErr = err
	}()

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	c.setCloseErr(err)
	c.writeClose(ce.Code, ce.Reason)
	c.close(err)
	return err
}
299
// reader implements Reader. It waits for the next data message and returns
// its type together with c.msgReader, positioned at the message's payload.
//
// The read lock is held only while the message is being located.
// msgReader.Read takes the lock again for each read.
//
// If the previous message was not read to completion, its remaining frames
// are still unread, so the connection is closed. A continuation frame cannot
// start a message and is reported as a protocol error.
//
// All returned errors are wrapped with "failed to get reader".
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	err = c.readMu.lock(ctx)
	if err != nil {
		return 0, nil, err
	}
	defer c.readMu.unlock()

	if !c.msgReader.fin {
		err = errors.New("previous message not read to completion")
		c.close(fmt.Errorf("failed to get reader: %w", err))
		return 0, nil, err
	}

	h, err := c.readLoop(ctx)
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
		err := errors.New("received continuation frame without text or binary frame")
		c.writeError(StatusProtocolError, err)
		return 0, nil, err
	}

	c.msgReader.reset(ctx, h)

	return MessageType(h.opcode), c.msgReader, nil
}
330
// msgReader reads the payload of the current data message. A single value is
// reused for every message on the connection; reset prepares it for the next
// one.
type msgReader struct {
	c *Conn

	ctx         context.Context // context of the Reader call that opened the current message
	flate       bool            // current message is compressed (RSV1 set on its first frame)
	flateReader io.Reader       // decompressor for the current message; nil when not in use
	flateBufio  *bufio.Reader   // buffers readFunc for the decompressor; created once, kept across messages
	flateTail   strings.Reader  // serves deflateMessageTail after the final frame of a compressed message
	limitReader *limitReader    // enforces the read limit; Read goes through it
	dict        slidingWindow   // decompressed history used as the flate dictionary under context takeover

	fin           bool   // current frame is the message's final frame
	payloadLength int64  // bytes of the current frame's payload not yet read
	maskKey       uint32 // running mask key for the current frame (only advanced on the server side)

	// readFunc is mr.read (not mr.Read) bound once in newMsgReader, to avoid
	// continuous allocations of a fresh method value.
	readFunc readerFunc
}
349
350func (mr *msgReader) reset(ctx context.Context, h header) {
351 mr.ctx = ctx
352 mr.flate = h.rsv1
353 mr.limitReader.reset(mr.readFunc)
354
355 if mr.flate {
356 mr.resetFlate()
357 }
358
359 mr.setFrame(h)
360}
361
362func (mr *msgReader) setFrame(h header) {
363 mr.fin = h.fin
364 mr.payloadLength = h.payloadLength
365 mr.maskKey = h.maskKey
366}
367
368func (mr *msgReader) Read(p []byte) (n int, err error) {
369 err = mr.c.readMu.lock(mr.ctx)
370 if err != nil {
371 return 0, fmt.Errorf("failed to read: %w", err)
372 }
373 defer mr.c.readMu.unlock()
374
375 n, err = mr.limitReader.Read(p)
376 if mr.flate && mr.flateContextTakeover() {
377 p = p[:n]
378 mr.dict.write(p)
379 }
380 if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
381 mr.putFlateReader()
382 return n, io.EOF
383 }
384 if err != nil {
385 err = fmt.Errorf("failed to read: %w", err)
386 mr.c.close(err)
387 }
388 return n, err
389}
390
391func (mr *msgReader) read(p []byte) (int, error) {
392 for {
393 if mr.payloadLength == 0 {
394 if mr.fin {
395 if mr.flate {
396 return mr.flateTail.Read(p)
397 }
398 return 0, io.EOF
399 }
400
401 h, err := mr.c.readLoop(mr.ctx)
402 if err != nil {
403 return 0, err
404 }
405 if h.opcode != opContinuation {
406 err := errors.New("received new data message without finishing the previous message")
407 mr.c.writeError(StatusProtocolError, err)
408 return 0, err
409 }
410 mr.setFrame(h)
411
412 continue
413 }
414
415 if int64(len(p)) > mr.payloadLength {
416 p = p[:mr.payloadLength]
417 }
418
419 n, err := mr.c.readFramePayload(mr.ctx, p)
420 if err != nil {
421 return n, err
422 }
423
424 mr.payloadLength -= int64(n)
425
426 if !mr.c.client {
427 mr.maskKey = mask(mr.maskKey, p)
428 }
429
430 return n, nil
431 }
432}
433
// limitReader caps how many bytes a single message may yield from r.
//
// limit holds the configured cap; SetReadLimit stores the user's limit plus
// one. It is an xsync.Int64, presumably so it can be updated from outside the
// read path. n is the budget remaining for the current message and is
// refilled from limit by reset. Reading once the budget is used up is
// reported to the peer with StatusMessageTooBig.
type limitReader struct {
	c     *Conn
	r     io.Reader
	limit xsync.Int64
	n     int64
}
440
441func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
442 lr := &limitReader{
443 c: c,
444 }
445 lr.limit.Store(limit)
446 lr.reset(r)
447 return lr
448}
449
450func (lr *limitReader) reset(r io.Reader) {
451 lr.n = lr.limit.Load()
452 lr.r = r
453}
454
455func (lr *limitReader) Read(p []byte) (int, error) {
456 if lr.n <= 0 {
457 err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
458 lr.c.writeError(StatusMessageTooBig, err)
459 return 0, err
460 }
461
462 if int64(len(p)) > lr.n {
463 p = p[:lr.n]
464 }
465 n, err := lr.r.Read(p)
466 lr.n -= int64(n)
467 return n, err
468}
469
// readerFunc adapts a plain function to the io.Reader interface.
type readerFunc func(p []byte) (int, error)

// Read forwards p to f and returns whatever f returns.
func (f readerFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
Note: See TracBrowser for help on using the repository browser.