source: code/trunk/vendor/nhooyr.io/websocket/write.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 8.0 KB
RevLine 
[822]1// +build !js
2
3package websocket
4
5import (
6 "bufio"
7 "context"
8 "crypto/rand"
9 "encoding/binary"
10 "errors"
11 "fmt"
12 "io"
13 "time"
14
15 "github.com/klauspost/compress/flate"
16
17 "nhooyr.io/websocket/internal/errd"
18)
19
20// Writer returns a writer bounded by the context that will write
21// a WebSocket message of type dataType to the connection.
22//
23// You must close the writer once you have written the entire message.
24//
25// Only one writer can be open at a time, multiple calls will block until the previous writer
26// is closed.
27func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
28 w, err := c.writer(ctx, typ)
29 if err != nil {
30 return nil, fmt.Errorf("failed to get writer: %w", err)
31 }
32 return w, nil
33}
34
35// Write writes a message to the connection.
36//
37// See the Writer method if you want to stream a message.
38//
39// If compression is disabled or the threshold is not met, then it
40// will write the message in a single frame.
41func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
42 _, err := c.write(ctx, typ, p)
43 if err != nil {
44 return fmt.Errorf("failed to write msg: %w", err)
45 }
46 return nil
47}
48
49type msgWriter struct {
50 mw *msgWriterState
51 closed bool
52}
53
54func (mw *msgWriter) Write(p []byte) (int, error) {
55 if mw.closed {
56 return 0, errors.New("cannot use closed writer")
57 }
58 return mw.mw.Write(p)
59}
60
61func (mw *msgWriter) Close() error {
62 if mw.closed {
63 return errors.New("cannot use closed writer")
64 }
65 mw.closed = true
66 return mw.mw.Close()
67}
68
69type msgWriterState struct {
70 c *Conn
71
72 mu *mu
73 writeMu *mu
74
75 ctx context.Context
76 opcode opcode
77 flate bool
78
79 trimWriter *trimLastFourBytesWriter
80 dict slidingWindow
81}
82
83func newMsgWriterState(c *Conn) *msgWriterState {
84 mw := &msgWriterState{
85 c: c,
86 mu: newMu(c),
87 writeMu: newMu(c),
88 }
89 return mw
90}
91
92func (mw *msgWriterState) ensureFlate() {
93 if mw.trimWriter == nil {
94 mw.trimWriter = &trimLastFourBytesWriter{
95 w: writerFunc(mw.write),
96 }
97 }
98
99 mw.dict.init(8192)
100 mw.flate = true
101}
102
103func (mw *msgWriterState) flateContextTakeover() bool {
104 if mw.c.client {
105 return !mw.c.copts.clientNoContextTakeover
106 }
107 return !mw.c.copts.serverNoContextTakeover
108}
109
110func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
111 err := c.msgWriterState.reset(ctx, typ)
112 if err != nil {
113 return nil, err
114 }
115 return &msgWriter{
116 mw: c.msgWriterState,
117 closed: false,
118 }, nil
119}
120
121func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
122 mw, err := c.writer(ctx, typ)
123 if err != nil {
124 return 0, err
125 }
126
127 if !c.flate() {
128 defer c.msgWriterState.mu.unlock()
129 return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
130 }
131
132 n, err := mw.Write(p)
133 if err != nil {
134 return n, err
135 }
136
137 err = mw.Close()
138 return n, err
139}
140
141func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
142 err := mw.mu.lock(ctx)
143 if err != nil {
144 return err
145 }
146
147 mw.ctx = ctx
148 mw.opcode = opcode(typ)
149 mw.flate = false
150
151 mw.trimWriter.reset()
152
153 return nil
154}
155
156// Write writes the given bytes to the WebSocket connection.
157func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
158 err = mw.writeMu.lock(mw.ctx)
159 if err != nil {
160 return 0, fmt.Errorf("failed to write: %w", err)
161 }
162 defer mw.writeMu.unlock()
163
164 defer func() {
165 if err != nil {
166 err = fmt.Errorf("failed to write: %w", err)
167 mw.c.close(err)
168 }
169 }()
170
171 if mw.c.flate() {
172 // Only enables flate if the length crosses the
173 // threshold on the first frame
174 if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
175 mw.ensureFlate()
176 }
177 }
178
179 if mw.flate {
180 err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf)
181 if err != nil {
182 return 0, err
183 }
184 mw.dict.write(p)
185 return len(p), nil
186 }
187
188 return mw.write(p)
189}
190
191func (mw *msgWriterState) write(p []byte) (int, error) {
192 n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
193 if err != nil {
194 return n, fmt.Errorf("failed to write data frame: %w", err)
195 }
196 mw.opcode = opContinuation
197 return n, nil
198}
199
200// Close flushes the frame to the connection.
201func (mw *msgWriterState) Close() (err error) {
202 defer errd.Wrap(&err, "failed to close writer")
203
204 err = mw.writeMu.lock(mw.ctx)
205 if err != nil {
206 return err
207 }
208 defer mw.writeMu.unlock()
209
210 _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
211 if err != nil {
212 return fmt.Errorf("failed to write fin frame: %w", err)
213 }
214
215 if mw.flate && !mw.flateContextTakeover() {
216 mw.dict.close()
217 }
218 mw.mu.unlock()
219 return nil
220}
221
222func (mw *msgWriterState) close() {
223 if mw.c.client {
224 mw.c.writeFrameMu.forceLock()
225 putBufioWriter(mw.c.bw)
226 }
227
228 mw.writeMu.forceLock()
229 mw.dict.close()
230}
231
232func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
233 ctx, cancel := context.WithTimeout(ctx, time.Second*5)
234 defer cancel()
235
236 _, err := c.writeFrame(ctx, true, false, opcode, p)
237 if err != nil {
238 return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
239 }
240 return nil
241}
242
243// frame handles all writes to the connection.
244func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
245 err = c.writeFrameMu.lock(ctx)
246 if err != nil {
247 return 0, err
248 }
249 defer c.writeFrameMu.unlock()
250
251 // If the state says a close has already been written, we wait until
252 // the connection is closed and return that error.
253 //
254 // However, if the frame being written is a close, that means its the close from
255 // the state being set so we let it go through.
256 c.closeMu.Lock()
257 wroteClose := c.wroteClose
258 c.closeMu.Unlock()
259 if wroteClose && opcode != opClose {
260 select {
261 case <-ctx.Done():
262 return 0, ctx.Err()
263 case <-c.closed:
264 return 0, c.closeErr
265 }
266 }
267
268 select {
269 case <-c.closed:
270 return 0, c.closeErr
271 case c.writeTimeout <- ctx:
272 }
273
274 defer func() {
275 if err != nil {
276 select {
277 case <-c.closed:
278 err = c.closeErr
279 case <-ctx.Done():
280 err = ctx.Err()
281 }
282 c.close(err)
283 err = fmt.Errorf("failed to write frame: %w", err)
284 }
285 }()
286
287 c.writeHeader.fin = fin
288 c.writeHeader.opcode = opcode
289 c.writeHeader.payloadLength = int64(len(p))
290
291 if c.client {
292 c.writeHeader.masked = true
293 _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
294 if err != nil {
295 return 0, fmt.Errorf("failed to generate masking key: %w", err)
296 }
297 c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
298 }
299
300 c.writeHeader.rsv1 = false
301 if flate && (opcode == opText || opcode == opBinary) {
302 c.writeHeader.rsv1 = true
303 }
304
305 err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
306 if err != nil {
307 return 0, err
308 }
309
310 n, err := c.writeFramePayload(p)
311 if err != nil {
312 return n, err
313 }
314
315 if c.writeHeader.fin {
316 err = c.bw.Flush()
317 if err != nil {
318 return n, fmt.Errorf("failed to flush: %w", err)
319 }
320 }
321
322 select {
323 case <-c.closed:
324 return n, c.closeErr
325 case c.writeTimeout <- context.Background():
326 }
327
328 return n, nil
329}
330
331func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
332 defer errd.Wrap(&err, "failed to write frame payload")
333
334 if !c.writeHeader.masked {
335 return c.bw.Write(p)
336 }
337
338 maskKey := c.writeHeader.maskKey
339 for len(p) > 0 {
340 // If the buffer is full, we need to flush.
341 if c.bw.Available() == 0 {
342 err = c.bw.Flush()
343 if err != nil {
344 return n, err
345 }
346 }
347
348 // Start of next write in the buffer.
349 i := c.bw.Buffered()
350
351 j := len(p)
352 if j > c.bw.Available() {
353 j = c.bw.Available()
354 }
355
356 _, err := c.bw.Write(p[:j])
357 if err != nil {
358 return n, err
359 }
360
361 maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
362
363 p = p[j:]
364 n += j
365 }
366
367 return n, nil
368}
369
// writerFunc adapts an ordinary function to the io.Writer interface.
type writerFunc func(p []byte) (int, error)

// Write calls f with p and returns whatever f returns.
func (f writerFunc) Write(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
375
376// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
377// and returns it.
378func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
379 var writeBuf []byte
380 bw.Reset(writerFunc(func(p2 []byte) (int, error) {
381 writeBuf = p2[:cap(p2)]
382 return len(p2), nil
383 }))
384
385 bw.WriteByte(0)
386 bw.Flush()
387
388 bw.Reset(w)
389
390 return writeBuf
391}
392
393func (c *Conn) writeError(code StatusCode, err error) {
394 c.setCloseErr(err)
395 c.writeClose(code, err.Error())
396 c.close(nil)
397}
Note: See TracBrowser for help on using the repository browser.