source: code/trunk/vendor/nhooyr.io/websocket/dial.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 7.7 KB
RevLine 
[822]1// +build !js
2
3package websocket
4
5import (
6 "bufio"
7 "bytes"
8 "context"
9 "crypto/rand"
10 "encoding/base64"
11 "fmt"
12 "io"
13 "io/ioutil"
14 "net/http"
15 "net/url"
16 "strings"
17 "sync"
18 "time"
19
20 "nhooyr.io/websocket/internal/errd"
21)
22
23// DialOptions represents Dial's options.
24type DialOptions struct {
25 // HTTPClient is used for the connection.
26 // Its Transport must return writable bodies for WebSocket handshakes.
27 // http.Transport does beginning with Go 1.12.
28 HTTPClient *http.Client
29
30 // HTTPHeader specifies the HTTP headers included in the handshake request.
31 HTTPHeader http.Header
32
33 // Subprotocols lists the WebSocket subprotocols to negotiate with the server.
34 Subprotocols []string
35
36 // CompressionMode controls the compression mode.
37 // Defaults to CompressionNoContextTakeover.
38 //
39 // See docs on CompressionMode for details.
40 CompressionMode CompressionMode
41
42 // CompressionThreshold controls the minimum size of a message before compression is applied.
43 //
44 // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
45 // for CompressionContextTakeover.
46 CompressionThreshold int
47}
48
49// Dial performs a WebSocket handshake on url.
50//
51// The response is the WebSocket handshake response from the server.
52// You never need to close resp.Body yourself.
53//
54// If an error occurs, the returned response may be non nil.
55// However, you can only read the first 1024 bytes of the body.
56//
57// This function requires at least Go 1.12 as it uses a new feature
58// in net/http to perform WebSocket handshakes.
59// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
60//
61// URLs with http/https schemes will work and are interpreted as ws/wss.
62func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
63 return dial(ctx, u, opts, nil)
64}
65
66func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
67 defer errd.Wrap(&err, "failed to WebSocket dial")
68
69 if opts == nil {
70 opts = &DialOptions{}
71 }
72
73 opts = &*opts
74 if opts.HTTPClient == nil {
75 opts.HTTPClient = http.DefaultClient
76 } else if opts.HTTPClient.Timeout > 0 {
77 var cancel context.CancelFunc
78
79 ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
80 defer cancel()
81
82 newClient := *opts.HTTPClient
83 newClient.Timeout = 0
84 opts.HTTPClient = &newClient
85 }
86
87 if opts.HTTPHeader == nil {
88 opts.HTTPHeader = http.Header{}
89 }
90
91 secWebSocketKey, err := secWebSocketKey(rand)
92 if err != nil {
93 return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
94 }
95
96 var copts *compressionOptions
97 if opts.CompressionMode != CompressionDisabled {
98 copts = opts.CompressionMode.opts()
99 }
100
101 resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
102 if err != nil {
103 return nil, resp, err
104 }
105 respBody := resp.Body
106 resp.Body = nil
107 defer func() {
108 if err != nil {
109 // We read a bit of the body for easier debugging.
110 r := io.LimitReader(respBody, 1024)
111
112 timer := time.AfterFunc(time.Second*3, func() {
113 respBody.Close()
114 })
115 defer timer.Stop()
116
117 b, _ := ioutil.ReadAll(r)
118 respBody.Close()
119 resp.Body = ioutil.NopCloser(bytes.NewReader(b))
120 }
121 }()
122
123 copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
124 if err != nil {
125 return nil, resp, err
126 }
127
128 rwc, ok := respBody.(io.ReadWriteCloser)
129 if !ok {
130 return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
131 }
132
133 return newConn(connConfig{
134 subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
135 rwc: rwc,
136 client: true,
137 copts: copts,
138 flateThreshold: opts.CompressionThreshold,
139 br: getBufioReader(rwc),
140 bw: getBufioWriter(rwc),
141 }), resp, nil
142}
143
144func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
145 u, err := url.Parse(urls)
146 if err != nil {
147 return nil, fmt.Errorf("failed to parse url: %w", err)
148 }
149
150 switch u.Scheme {
151 case "ws":
152 u.Scheme = "http"
153 case "wss":
154 u.Scheme = "https"
155 case "http", "https":
156 default:
157 return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
158 }
159
160 req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
161 req.Header = opts.HTTPHeader.Clone()
162 req.Header.Set("Connection", "Upgrade")
163 req.Header.Set("Upgrade", "websocket")
164 req.Header.Set("Sec-WebSocket-Version", "13")
165 req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
166 if len(opts.Subprotocols) > 0 {
167 req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
168 }
169 if copts != nil {
170 copts.setHeader(req.Header)
171 }
172
173 resp, err := opts.HTTPClient.Do(req)
174 if err != nil {
175 return nil, fmt.Errorf("failed to send handshake request: %w", err)
176 }
177 return resp, nil
178}
179
// secWebSocketKey returns a fresh Sec-WebSocket-Key value: 16 bytes read from
// rr, encoded with standard base64. A nil rr means crypto/rand.Reader.
// It fails only if 16 bytes cannot be read.
func secWebSocketKey(rr io.Reader) (string, error) {
	src := rr
	if src == nil {
		src = rand.Reader
	}

	var nonce [16]byte
	if _, err := io.ReadFull(src, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
191
192func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
193 if resp.StatusCode != http.StatusSwitchingProtocols {
194 return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
195 }
196
197 if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
198 return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
199 }
200
201 if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
202 return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
203 }
204
205 if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
206 return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
207 resp.Header.Get("Sec-WebSocket-Accept"),
208 secWebSocketKey,
209 )
210 }
211
212 err := verifySubprotocol(opts.Subprotocols, resp)
213 if err != nil {
214 return nil, err
215 }
216
217 return verifyServerExtensions(copts, resp.Header)
218}
219
// verifySubprotocol accepts a response that selects no subprotocol, or one
// that selects a subprotocol from subprotos (compared case-insensitively).
// Any other selection is reported as a protocol violation.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	selected := resp.Header.Get("Sec-WebSocket-Protocol")
	if selected == "" {
		return nil
	}

	offered := false
	for i := 0; i < len(subprotos) && !offered; i++ {
		offered = strings.EqualFold(subprotos[i], selected)
	}
	if offered {
		return nil
	}

	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", selected)
}
234
235func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
236 exts := websocketExtensions(h)
237 if len(exts) == 0 {
238 return nil, nil
239 }
240
241 ext := exts[0]
242 if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
243 return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
244 }
245
246 copts = &*copts
247
248 for _, p := range ext.params {
249 switch p {
250 case "client_no_context_takeover":
251 copts.clientNoContextTakeover = true
252 continue
253 case "server_no_context_takeover":
254 copts.serverNoContextTakeover = true
255 continue
256 }
257
258 return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
259 }
260
261 return copts, nil
262}
263
// bufioReaderPool recycles *bufio.Reader values between connections.
// It has no New func; getBufioReader allocates when the pool is empty.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. It reuses a pooled
// reader (reset onto r) when one is available and allocates a new one with
// the default buffer size otherwise.
func getBufioReader(r io.Reader) *bufio.Reader {
	if pooled, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
		pooled.Reset(r)
		return pooled
	}
	return bufio.NewReader(r)
}

// putBufioReader hands br back to bufioReaderPool so a later getBufioReader
// call can reuse it. The caller must not use br afterwards.
func putBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
278
// bufioWriterPool recycles *bufio.Writer values between connections.
// It has no New func; getBufioWriter allocates when the pool is empty.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. It reuses a pooled
// writer (reset onto w, discarding any unflushed data) when one is available
// and allocates a new one with the default buffer size otherwise.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if pooled, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
		pooled.Reset(w)
		return pooled
	}
	return bufio.NewWriter(w)
}

// putBufioWriter hands bw back to bufioWriterPool so a later getBufioWriter
// call can reuse it. The caller must not use bw afterwards.
func putBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}
Note: See TracBrowser for help on using the repository browser.