[822] | 1 | // +build !js
|
---|
| 2 |
|
---|
| 3 | package websocket
|
---|
| 4 |
|
---|
| 5 | import (
|
---|
| 6 | "context"
|
---|
| 7 | "encoding/binary"
|
---|
| 8 | "errors"
|
---|
| 9 | "fmt"
|
---|
| 10 | "log"
|
---|
| 11 | "time"
|
---|
| 12 |
|
---|
| 13 | "nhooyr.io/websocket/internal/errd"
|
---|
| 14 | )
|
---|
| 15 |
|
---|
| 16 | // Close performs the WebSocket close handshake with the given status code and reason.
|
---|
| 17 | //
|
---|
| 18 | // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
|
---|
| 19 | // the peer to send a close frame.
|
---|
| 20 | // All data messages received from the peer during the close handshake will be discarded.
|
---|
| 21 | //
|
---|
| 22 | // The connection can only be closed once. Additional calls to Close
|
---|
| 23 | // are no-ops.
|
---|
| 24 | //
|
---|
| 25 | // The maximum length of reason must be 125 bytes. Avoid
|
---|
| 26 | // sending a dynamic reason.
|
---|
| 27 | //
|
---|
| 28 | // Close will unblock all goroutines interacting with the connection once
|
---|
| 29 | // complete.
|
---|
| 30 | func (c *Conn) Close(code StatusCode, reason string) error {
|
---|
| 31 | return c.closeHandshake(code, reason)
|
---|
| 32 | }
|
---|
| 33 |
|
---|
| 34 | func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
|
---|
| 35 | defer errd.Wrap(&err, "failed to close WebSocket")
|
---|
| 36 |
|
---|
| 37 | writeErr := c.writeClose(code, reason)
|
---|
| 38 | closeHandshakeErr := c.waitCloseHandshake()
|
---|
| 39 |
|
---|
| 40 | if writeErr != nil {
|
---|
| 41 | return writeErr
|
---|
| 42 | }
|
---|
| 43 |
|
---|
| 44 | if CloseStatus(closeHandshakeErr) == -1 {
|
---|
| 45 | return closeHandshakeErr
|
---|
| 46 | }
|
---|
| 47 |
|
---|
| 48 | return nil
|
---|
| 49 | }
|
---|
| 50 |
|
---|
| 51 | var errAlreadyWroteClose = errors.New("already wrote close")
|
---|
| 52 |
|
---|
| 53 | func (c *Conn) writeClose(code StatusCode, reason string) error {
|
---|
| 54 | c.closeMu.Lock()
|
---|
| 55 | wroteClose := c.wroteClose
|
---|
| 56 | c.wroteClose = true
|
---|
| 57 | c.closeMu.Unlock()
|
---|
| 58 | if wroteClose {
|
---|
| 59 | return errAlreadyWroteClose
|
---|
| 60 | }
|
---|
| 61 |
|
---|
| 62 | ce := CloseError{
|
---|
| 63 | Code: code,
|
---|
| 64 | Reason: reason,
|
---|
| 65 | }
|
---|
| 66 |
|
---|
| 67 | var p []byte
|
---|
| 68 | var marshalErr error
|
---|
| 69 | if ce.Code != StatusNoStatusRcvd {
|
---|
| 70 | p, marshalErr = ce.bytes()
|
---|
| 71 | if marshalErr != nil {
|
---|
| 72 | log.Printf("websocket: %v", marshalErr)
|
---|
| 73 | }
|
---|
| 74 | }
|
---|
| 75 |
|
---|
| 76 | writeErr := c.writeControl(context.Background(), opClose, p)
|
---|
| 77 | if CloseStatus(writeErr) != -1 {
|
---|
| 78 | // Not a real error if it's due to a close frame being received.
|
---|
| 79 | writeErr = nil
|
---|
| 80 | }
|
---|
| 81 |
|
---|
| 82 | // We do this after in case there was an error writing the close frame.
|
---|
| 83 | c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
|
---|
| 84 |
|
---|
| 85 | if marshalErr != nil {
|
---|
| 86 | return marshalErr
|
---|
| 87 | }
|
---|
| 88 | return writeErr
|
---|
| 89 | }
|
---|
| 90 |
|
---|
| 91 | func (c *Conn) waitCloseHandshake() error {
|
---|
| 92 | defer c.close(nil)
|
---|
| 93 |
|
---|
| 94 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
---|
| 95 | defer cancel()
|
---|
| 96 |
|
---|
| 97 | err := c.readMu.lock(ctx)
|
---|
| 98 | if err != nil {
|
---|
| 99 | return err
|
---|
| 100 | }
|
---|
| 101 | defer c.readMu.unlock()
|
---|
| 102 |
|
---|
| 103 | if c.readCloseFrameErr != nil {
|
---|
| 104 | return c.readCloseFrameErr
|
---|
| 105 | }
|
---|
| 106 |
|
---|
| 107 | for {
|
---|
| 108 | h, err := c.readLoop(ctx)
|
---|
| 109 | if err != nil {
|
---|
| 110 | return err
|
---|
| 111 | }
|
---|
| 112 |
|
---|
| 113 | for i := int64(0); i < h.payloadLength; i++ {
|
---|
| 114 | _, err := c.br.ReadByte()
|
---|
| 115 | if err != nil {
|
---|
| 116 | return err
|
---|
| 117 | }
|
---|
| 118 | }
|
---|
| 119 | }
|
---|
| 120 | }
|
---|
| 121 |
|
---|
| 122 | func parseClosePayload(p []byte) (CloseError, error) {
|
---|
| 123 | if len(p) == 0 {
|
---|
| 124 | return CloseError{
|
---|
| 125 | Code: StatusNoStatusRcvd,
|
---|
| 126 | }, nil
|
---|
| 127 | }
|
---|
| 128 |
|
---|
| 129 | if len(p) < 2 {
|
---|
| 130 | return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
|
---|
| 131 | }
|
---|
| 132 |
|
---|
| 133 | ce := CloseError{
|
---|
| 134 | Code: StatusCode(binary.BigEndian.Uint16(p)),
|
---|
| 135 | Reason: string(p[2:]),
|
---|
| 136 | }
|
---|
| 137 |
|
---|
| 138 | if !validWireCloseCode(ce.Code) {
|
---|
| 139 | return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
|
---|
| 140 | }
|
---|
| 141 |
|
---|
| 142 | return ce, nil
|
---|
| 143 | }
|
---|
| 144 |
|
---|
| 145 | // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
|
---|
| 146 | // and https://tools.ietf.org/html/rfc6455#section-7.4.1
|
---|
| 147 | func validWireCloseCode(code StatusCode) bool {
|
---|
| 148 | switch code {
|
---|
| 149 | case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
|
---|
| 150 | return false
|
---|
| 151 | }
|
---|
| 152 |
|
---|
| 153 | if code >= StatusNormalClosure && code <= StatusBadGateway {
|
---|
| 154 | return true
|
---|
| 155 | }
|
---|
| 156 | if code >= 3000 && code <= 4999 {
|
---|
| 157 | return true
|
---|
| 158 | }
|
---|
| 159 |
|
---|
| 160 | return false
|
---|
| 161 | }
|
---|
| 162 |
|
---|
| 163 | func (ce CloseError) bytes() ([]byte, error) {
|
---|
| 164 | p, err := ce.bytesErr()
|
---|
| 165 | if err != nil {
|
---|
| 166 | err = fmt.Errorf("failed to marshal close frame: %w", err)
|
---|
| 167 | ce = CloseError{
|
---|
| 168 | Code: StatusInternalError,
|
---|
| 169 | }
|
---|
| 170 | p, _ = ce.bytesErr()
|
---|
| 171 | }
|
---|
| 172 | return p, err
|
---|
| 173 | }
|
---|
| 174 |
|
---|
| 175 | const maxCloseReason = maxControlPayload - 2
|
---|
| 176 |
|
---|
| 177 | func (ce CloseError) bytesErr() ([]byte, error) {
|
---|
| 178 | if len(ce.Reason) > maxCloseReason {
|
---|
| 179 | return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
|
---|
| 180 | }
|
---|
| 181 |
|
---|
| 182 | if !validWireCloseCode(ce.Code) {
|
---|
| 183 | return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
|
---|
| 184 | }
|
---|
| 185 |
|
---|
| 186 | buf := make([]byte, 2+len(ce.Reason))
|
---|
| 187 | binary.BigEndian.PutUint16(buf, uint16(ce.Code))
|
---|
| 188 | copy(buf[2:], ce.Reason)
|
---|
| 189 | return buf, nil
|
---|
| 190 | }
|
---|
| 191 |
|
---|
| 192 | func (c *Conn) setCloseErr(err error) {
|
---|
| 193 | c.closeMu.Lock()
|
---|
| 194 | c.setCloseErrLocked(err)
|
---|
| 195 | c.closeMu.Unlock()
|
---|
| 196 | }
|
---|
| 197 |
|
---|
| 198 | func (c *Conn) setCloseErrLocked(err error) {
|
---|
| 199 | if c.closeErr == nil {
|
---|
| 200 | c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
|
---|
| 201 | }
|
---|
| 202 | }
|
---|
| 203 |
|
---|
| 204 | func (c *Conn) isClosed() bool {
|
---|
| 205 | select {
|
---|
| 206 | case <-c.closed:
|
---|
| 207 | return true
|
---|
| 208 | default:
|
---|
| 209 | return false
|
---|
| 210 | }
|
---|
| 211 | }
|
---|