[822] | 1 | // +build !js
|
---|
| 2 |
|
---|
| 3 | package websocket
|
---|
| 4 |
|
---|
| 5 | import (
|
---|
| 6 | "bufio"
|
---|
| 7 | "bytes"
|
---|
| 8 | "context"
|
---|
| 9 | "crypto/rand"
|
---|
| 10 | "encoding/base64"
|
---|
| 11 | "fmt"
|
---|
| 12 | "io"
|
---|
| 13 | "io/ioutil"
|
---|
| 14 | "net/http"
|
---|
| 15 | "net/url"
|
---|
| 16 | "strings"
|
---|
| 17 | "sync"
|
---|
| 18 | "time"
|
---|
| 19 |
|
---|
| 20 | "nhooyr.io/websocket/internal/errd"
|
---|
| 21 | )
|
---|
| 22 |
|
---|
// DialOptions represents Dial's options.
// Passing a nil *DialOptions to Dial is equivalent to passing a zero DialOptions.
type DialOptions struct {
	// HTTPClient is used to send the handshake request.
	// Its Transport must return writable bodies for WebSocket handshakes;
	// http.Transport does so beginning with Go 1.12.
	// When nil, http.DefaultClient is used.
	//
	// If HTTPClient.Timeout is positive, the handshake is bounded by a context
	// deadline of that duration instead. The request is then sent with a copy
	// of the client whose Timeout is zero.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	// It may be nil. The header map is cloned before the handshake headers
	// (Connection, Upgrade, Sec-WebSocket-Version, Sec-WebSocket-Key, and,
	// when applicable, Sec-WebSocket-Protocol and the compression extension
	// offer) are set on the clone, so those entries overwrite same-named
	// entries supplied here.
	HTTPHeader http.Header

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	// They are sent comma-separated in Sec-WebSocket-Protocol. If the server
	// selects one, it must match one of these entries, compared case-insensitively.
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	// With CompressionDisabled, no compression extension is offered to the server.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
|
---|
| 48 |
|
---|
| 49 | // Dial performs a WebSocket handshake on url.
|
---|
| 50 | //
|
---|
| 51 | // The response is the WebSocket handshake response from the server.
|
---|
| 52 | // You never need to close resp.Body yourself.
|
---|
| 53 | //
|
---|
| 54 | // If an error occurs, the returned response may be non nil.
|
---|
| 55 | // However, you can only read the first 1024 bytes of the body.
|
---|
| 56 | //
|
---|
| 57 | // This function requires at least Go 1.12 as it uses a new feature
|
---|
| 58 | // in net/http to perform WebSocket handshakes.
|
---|
| 59 | // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
|
---|
| 60 | //
|
---|
| 61 | // URLs with http/https schemes will work and are interpreted as ws/wss.
|
---|
| 62 | func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
|
---|
| 63 | return dial(ctx, u, opts, nil)
|
---|
| 64 | }
|
---|
| 65 |
|
---|
| 66 | func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
|
---|
| 67 | defer errd.Wrap(&err, "failed to WebSocket dial")
|
---|
| 68 |
|
---|
| 69 | if opts == nil {
|
---|
| 70 | opts = &DialOptions{}
|
---|
| 71 | }
|
---|
| 72 |
|
---|
| 73 | opts = &*opts
|
---|
| 74 | if opts.HTTPClient == nil {
|
---|
| 75 | opts.HTTPClient = http.DefaultClient
|
---|
| 76 | } else if opts.HTTPClient.Timeout > 0 {
|
---|
| 77 | var cancel context.CancelFunc
|
---|
| 78 |
|
---|
| 79 | ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
|
---|
| 80 | defer cancel()
|
---|
| 81 |
|
---|
| 82 | newClient := *opts.HTTPClient
|
---|
| 83 | newClient.Timeout = 0
|
---|
| 84 | opts.HTTPClient = &newClient
|
---|
| 85 | }
|
---|
| 86 |
|
---|
| 87 | if opts.HTTPHeader == nil {
|
---|
| 88 | opts.HTTPHeader = http.Header{}
|
---|
| 89 | }
|
---|
| 90 |
|
---|
| 91 | secWebSocketKey, err := secWebSocketKey(rand)
|
---|
| 92 | if err != nil {
|
---|
| 93 | return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
|
---|
| 94 | }
|
---|
| 95 |
|
---|
| 96 | var copts *compressionOptions
|
---|
| 97 | if opts.CompressionMode != CompressionDisabled {
|
---|
| 98 | copts = opts.CompressionMode.opts()
|
---|
| 99 | }
|
---|
| 100 |
|
---|
| 101 | resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
|
---|
| 102 | if err != nil {
|
---|
| 103 | return nil, resp, err
|
---|
| 104 | }
|
---|
| 105 | respBody := resp.Body
|
---|
| 106 | resp.Body = nil
|
---|
| 107 | defer func() {
|
---|
| 108 | if err != nil {
|
---|
| 109 | // We read a bit of the body for easier debugging.
|
---|
| 110 | r := io.LimitReader(respBody, 1024)
|
---|
| 111 |
|
---|
| 112 | timer := time.AfterFunc(time.Second*3, func() {
|
---|
| 113 | respBody.Close()
|
---|
| 114 | })
|
---|
| 115 | defer timer.Stop()
|
---|
| 116 |
|
---|
| 117 | b, _ := ioutil.ReadAll(r)
|
---|
| 118 | respBody.Close()
|
---|
| 119 | resp.Body = ioutil.NopCloser(bytes.NewReader(b))
|
---|
| 120 | }
|
---|
| 121 | }()
|
---|
| 122 |
|
---|
| 123 | copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
|
---|
| 124 | if err != nil {
|
---|
| 125 | return nil, resp, err
|
---|
| 126 | }
|
---|
| 127 |
|
---|
| 128 | rwc, ok := respBody.(io.ReadWriteCloser)
|
---|
| 129 | if !ok {
|
---|
| 130 | return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
|
---|
| 131 | }
|
---|
| 132 |
|
---|
| 133 | return newConn(connConfig{
|
---|
| 134 | subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
|
---|
| 135 | rwc: rwc,
|
---|
| 136 | client: true,
|
---|
| 137 | copts: copts,
|
---|
| 138 | flateThreshold: opts.CompressionThreshold,
|
---|
| 139 | br: getBufioReader(rwc),
|
---|
| 140 | bw: getBufioWriter(rwc),
|
---|
| 141 | }), resp, nil
|
---|
| 142 | }
|
---|
| 143 |
|
---|
| 144 | func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
|
---|
| 145 | u, err := url.Parse(urls)
|
---|
| 146 | if err != nil {
|
---|
| 147 | return nil, fmt.Errorf("failed to parse url: %w", err)
|
---|
| 148 | }
|
---|
| 149 |
|
---|
| 150 | switch u.Scheme {
|
---|
| 151 | case "ws":
|
---|
| 152 | u.Scheme = "http"
|
---|
| 153 | case "wss":
|
---|
| 154 | u.Scheme = "https"
|
---|
| 155 | case "http", "https":
|
---|
| 156 | default:
|
---|
| 157 | return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
|
---|
| 158 | }
|
---|
| 159 |
|
---|
| 160 | req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
---|
| 161 | req.Header = opts.HTTPHeader.Clone()
|
---|
| 162 | req.Header.Set("Connection", "Upgrade")
|
---|
| 163 | req.Header.Set("Upgrade", "websocket")
|
---|
| 164 | req.Header.Set("Sec-WebSocket-Version", "13")
|
---|
| 165 | req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
|
---|
| 166 | if len(opts.Subprotocols) > 0 {
|
---|
| 167 | req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
|
---|
| 168 | }
|
---|
| 169 | if copts != nil {
|
---|
| 170 | copts.setHeader(req.Header)
|
---|
| 171 | }
|
---|
| 172 |
|
---|
| 173 | resp, err := opts.HTTPClient.Do(req)
|
---|
| 174 | if err != nil {
|
---|
| 175 | return nil, fmt.Errorf("failed to send handshake request: %w", err)
|
---|
| 176 | }
|
---|
| 177 | return resp, nil
|
---|
| 178 | }
|
---|
| 179 |
|
---|
// secWebSocketKey generates a value for the Sec-WebSocket-Key request header.
// It reads a 16-byte nonce from rr (crypto/rand.Reader when rr is nil) and
// encodes it with padded standard base64. An error is returned if rr cannot
// supply all 16 bytes.
func secWebSocketKey(rr io.Reader) (string, error) {
	src := rr
	if src == nil {
		src = rand.Reader
	}

	var nonce [16]byte
	if _, err := io.ReadFull(src, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
|
---|
| 191 |
|
---|
| 192 | func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
|
---|
| 193 | if resp.StatusCode != http.StatusSwitchingProtocols {
|
---|
| 194 | return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
---|
| 195 | }
|
---|
| 196 |
|
---|
| 197 | if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
|
---|
| 198 | return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
|
---|
| 199 | }
|
---|
| 200 |
|
---|
| 201 | if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
|
---|
| 202 | return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
|
---|
| 203 | }
|
---|
| 204 |
|
---|
| 205 | if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
|
---|
| 206 | return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
|
---|
| 207 | resp.Header.Get("Sec-WebSocket-Accept"),
|
---|
| 208 | secWebSocketKey,
|
---|
| 209 | )
|
---|
| 210 | }
|
---|
| 211 |
|
---|
| 212 | err := verifySubprotocol(opts.Subprotocols, resp)
|
---|
| 213 | if err != nil {
|
---|
| 214 | return nil, err
|
---|
| 215 | }
|
---|
| 216 |
|
---|
| 217 | return verifyServerExtensions(copts, resp.Header)
|
---|
| 218 | }
|
---|
| 219 |
|
---|
// verifySubprotocol ensures the subprotocol chosen by the server was one the
// client offered in subprotos. Matching ignores case. An empty or absent
// Sec-WebSocket-Protocol response header means no subprotocol was chosen and
// is always accepted.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	chosen := resp.Header.Get("Sec-WebSocket-Protocol")

	accepted := chosen == ""
	for i := 0; !accepted && i < len(subprotos); i++ {
		accepted = strings.EqualFold(subprotos[i], chosen)
	}
	if accepted {
		return nil
	}

	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", chosen)
}
|
---|
| 234 |
|
---|
| 235 | func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
|
---|
| 236 | exts := websocketExtensions(h)
|
---|
| 237 | if len(exts) == 0 {
|
---|
| 238 | return nil, nil
|
---|
| 239 | }
|
---|
| 240 |
|
---|
| 241 | ext := exts[0]
|
---|
| 242 | if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
|
---|
| 243 | return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
|
---|
| 244 | }
|
---|
| 245 |
|
---|
| 246 | copts = &*copts
|
---|
| 247 |
|
---|
| 248 | for _, p := range ext.params {
|
---|
| 249 | switch p {
|
---|
| 250 | case "client_no_context_takeover":
|
---|
| 251 | copts.clientNoContextTakeover = true
|
---|
| 252 | continue
|
---|
| 253 | case "server_no_context_takeover":
|
---|
| 254 | copts.serverNoContextTakeover = true
|
---|
| 255 | continue
|
---|
| 256 | }
|
---|
| 257 |
|
---|
| 258 | return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
---|
| 259 | }
|
---|
| 260 |
|
---|
| 261 | return copts, nil
|
---|
| 262 | }
|
---|
| 263 |
|
---|
// bufioReaderPool recycles *bufio.Reader values between connections.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. It reuses a pooled
// reader, reset onto r, when one is available and allocates a new one otherwise.
func getBufioReader(r io.Reader) *bufio.Reader {
	if pooled, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
		pooled.Reset(r)
		return pooled
	}
	return bufio.NewReader(r)
}

// putBufioReader returns br to the pool for reuse by getBufioReader.
// The caller must not use br afterwards.
func putBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
|
---|
| 278 |
|
---|
// bufioWriterPool recycles *bufio.Writer values between connections.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. It reuses a pooled
// writer, reset onto w, when one is available and allocates a new one otherwise.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if pooled, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
		pooled.Reset(w)
		return pooled
	}
	return bufio.NewWriter(w)
}

// putBufioWriter returns bw to the pool for reuse by getBufioWriter.
// The caller must not use bw afterwards.
func putBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}
|
---|