source: code/trunk/vendor/nhooyr.io/websocket/accept.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 10.7 KB
Line 
1// +build !js
2
3package websocket
4
5import (
6 "bytes"
7 "crypto/sha1"
8 "encoding/base64"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "net/http"
14 "net/textproto"
15 "net/url"
16 "path/filepath"
17 "strings"
18
19 "nhooyr.io/websocket/internal/errd"
20)
21
22// AcceptOptions represents Accept's options.
23type AcceptOptions struct {
24 // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
25 // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
26 // reject it, close the connection when c.Subprotocol() == "".
27 Subprotocols []string
28
29 // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
30 //
31 // You probably want to use OriginPatterns instead.
32 InsecureSkipVerify bool
33
34 // OriginPatterns lists the host patterns for authorized origins.
35 // The request host is always authorized.
36 // Use this to enable cross origin WebSockets.
37 //
38 // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
39 // In such a case, example.com is the origin and chat.example.com is the request host.
40 // One would set this field to []string{"example.com"} to authorize example.com to connect.
41 //
42 // Each pattern is matched case insensitively against the request origin host
43 // with filepath.Match.
44 // See https://golang.org/pkg/path/filepath/#Match
45 //
46 // Please ensure you understand the ramifications of enabling this.
47 // If used incorrectly your WebSocket server will be open to CSRF attacks.
48 //
49 // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
50 // to bring attention to the danger of such a setting.
51 OriginPatterns []string
52
53 // CompressionMode controls the compression mode.
54 // Defaults to CompressionNoContextTakeover.
55 //
56 // See docs on CompressionMode for details.
57 CompressionMode CompressionMode
58
59 // CompressionThreshold controls the minimum size of a message before compression is applied.
60 //
61 // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
62 // for CompressionContextTakeover.
63 CompressionThreshold int
64}
65
66// Accept accepts a WebSocket handshake from a client and upgrades the
67// the connection to a WebSocket.
68//
69// Accept will not allow cross origin requests by default.
70// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
71//
72// Accept will write a response to w on all errors.
73func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
74 return accept(w, r, opts)
75}
76
77func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
78 defer errd.Wrap(&err, "failed to accept WebSocket connection")
79
80 if opts == nil {
81 opts = &AcceptOptions{}
82 }
83 opts = &*opts
84
85 errCode, err := verifyClientRequest(w, r)
86 if err != nil {
87 http.Error(w, err.Error(), errCode)
88 return nil, err
89 }
90
91 if !opts.InsecureSkipVerify {
92 err = authenticateOrigin(r, opts.OriginPatterns)
93 if err != nil {
94 if errors.Is(err, filepath.ErrBadPattern) {
95 log.Printf("websocket: %v", err)
96 err = errors.New(http.StatusText(http.StatusForbidden))
97 }
98 http.Error(w, err.Error(), http.StatusForbidden)
99 return nil, err
100 }
101 }
102
103 hj, ok := w.(http.Hijacker)
104 if !ok {
105 err = errors.New("http.ResponseWriter does not implement http.Hijacker")
106 http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
107 return nil, err
108 }
109
110 w.Header().Set("Upgrade", "websocket")
111 w.Header().Set("Connection", "Upgrade")
112
113 key := r.Header.Get("Sec-WebSocket-Key")
114 w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
115
116 subproto := selectSubprotocol(r, opts.Subprotocols)
117 if subproto != "" {
118 w.Header().Set("Sec-WebSocket-Protocol", subproto)
119 }
120
121 copts, err := acceptCompression(r, w, opts.CompressionMode)
122 if err != nil {
123 return nil, err
124 }
125
126 w.WriteHeader(http.StatusSwitchingProtocols)
127 // See https://github.com/nhooyr/websocket/issues/166
128 if ginWriter, ok := w.(interface {
129 WriteHeaderNow()
130 }); ok {
131 ginWriter.WriteHeaderNow()
132 }
133
134 netConn, brw, err := hj.Hijack()
135 if err != nil {
136 err = fmt.Errorf("failed to hijack connection: %w", err)
137 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
138 return nil, err
139 }
140
141 // https://github.com/golang/go/issues/32314
142 b, _ := brw.Reader.Peek(brw.Reader.Buffered())
143 brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
144
145 return newConn(connConfig{
146 subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
147 rwc: netConn,
148 client: false,
149 copts: copts,
150 flateThreshold: opts.CompressionThreshold,
151
152 br: brw.Reader,
153 bw: brw.Writer,
154 }), nil
155}
156
157func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
158 if !r.ProtoAtLeast(1, 1) {
159 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
160 }
161
162 if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
163 w.Header().Set("Connection", "Upgrade")
164 w.Header().Set("Upgrade", "websocket")
165 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
166 }
167
168 if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
169 w.Header().Set("Connection", "Upgrade")
170 w.Header().Set("Upgrade", "websocket")
171 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
172 }
173
174 if r.Method != "GET" {
175 return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
176 }
177
178 if r.Header.Get("Sec-WebSocket-Version") != "13" {
179 w.Header().Set("Sec-WebSocket-Version", "13")
180 return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
181 }
182
183 if r.Header.Get("Sec-WebSocket-Key") == "" {
184 return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
185 }
186
187 return 0, nil
188}
189
190func authenticateOrigin(r *http.Request, originHosts []string) error {
191 origin := r.Header.Get("Origin")
192 if origin == "" {
193 return nil
194 }
195
196 u, err := url.Parse(origin)
197 if err != nil {
198 return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
199 }
200
201 if strings.EqualFold(r.Host, u.Host) {
202 return nil
203 }
204
205 for _, hostPattern := range originHosts {
206 matched, err := match(hostPattern, u.Host)
207 if err != nil {
208 return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
209 }
210 if matched {
211 return nil
212 }
213 }
214 return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
215}
216
// match reports whether s matches the shell-style pattern, ignoring case.
// Both arguments are lower-cased and then handed to filepath.Match, whose
// error (for a malformed pattern) is returned unchanged.
func match(pattern, s string) (bool, error) {
	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
}
220
221func selectSubprotocol(r *http.Request, subprotocols []string) string {
222 cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
223 for _, sp := range subprotocols {
224 for _, cp := range cps {
225 if strings.EqualFold(sp, cp) {
226 return cp
227 }
228 }
229 }
230 return ""
231}
232
233func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
234 if mode == CompressionDisabled {
235 return nil, nil
236 }
237
238 for _, ext := range websocketExtensions(r.Header) {
239 switch ext.name {
240 case "permessage-deflate":
241 return acceptDeflate(w, ext, mode)
242 // Disabled for now, see https://github.com/nhooyr/websocket/issues/218
243 // case "x-webkit-deflate-frame":
244 // return acceptWebkitDeflate(w, ext, mode)
245 }
246 }
247 return nil, nil
248}
249
250func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
251 copts := mode.opts()
252
253 for _, p := range ext.params {
254 switch p {
255 case "client_no_context_takeover":
256 copts.clientNoContextTakeover = true
257 continue
258 case "server_no_context_takeover":
259 copts.serverNoContextTakeover = true
260 continue
261 }
262
263 if strings.HasPrefix(p, "client_max_window_bits") {
264 // We cannot adjust the read sliding window so cannot make use of this.
265 continue
266 }
267
268 err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
269 http.Error(w, err.Error(), http.StatusBadRequest)
270 return nil, err
271 }
272
273 copts.setHeader(w.Header())
274
275 return copts, nil
276}
277
278func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
279 copts := mode.opts()
280 // The peer must explicitly request it.
281 copts.serverNoContextTakeover = false
282
283 for _, p := range ext.params {
284 if p == "no_context_takeover" {
285 copts.serverNoContextTakeover = true
286 continue
287 }
288
289 // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
290 // of ignoring it as the draft spec is unclear. It says the server can ignore it
291 // but the server has no way of signalling to the client it was ignored as the parameters
292 // are set one way.
293 // Thus us ignoring it would make the client think we understood it which would cause issues.
294 // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
295 //
296 // Either way, we're only implementing this for webkit which never sends the max_window_bits
297 // parameter so we don't need to worry about it.
298 err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
299 http.Error(w, err.Error(), http.StatusBadRequest)
300 return nil, err
301 }
302
303 s := "x-webkit-deflate-frame"
304 if copts.clientNoContextTakeover {
305 s += "; no_context_takeover"
306 }
307 w.Header().Set("Sec-WebSocket-Extensions", s)
308
309 return copts, nil
310}
311
312func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
313 for _, t := range headerTokens(h, key) {
314 if strings.EqualFold(t, token) {
315 return true
316 }
317 }
318 return false
319}
320
// websocketExtension is one entry parsed from a Sec-WebSocket-Extensions
// header by websocketExtensions.
type websocketExtension struct {
	// name is the text before the first ';' of the entry.
	name string
	// params holds the remaining ';'-separated pieces, each trimmed of
	// surrounding whitespace and otherwise unparsed.
	params []string
}
325
326func websocketExtensions(h http.Header) []websocketExtension {
327 var exts []websocketExtension
328 extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
329 for _, extStr := range extStrs {
330 if extStr == "" {
331 continue
332 }
333
334 vals := strings.Split(extStr, ";")
335 for i := range vals {
336 vals[i] = strings.TrimSpace(vals[i])
337 }
338
339 e := websocketExtension{
340 name: vals[0],
341 params: vals[1:],
342 }
343
344 exts = append(exts, e)
345 }
346 return exts
347}
348
// headerTokens returns every comma-separated token found in the values of
// header key, in order, with surrounding whitespace trimmed.
//
// key is canonicalized before the lookup, so its case does not matter.
// Empty tokens are kept (an empty header value yields one empty token).
// The result is nil when the header is absent.
func headerTokens(h http.Header, key string) []string {
	key = textproto.CanonicalMIMEHeaderKey(key)
	var tokens []string
	for _, v := range h[key] {
		v = strings.TrimSpace(v)
		for _, t := range strings.Split(v, ",") {
			t = strings.TrimSpace(t)
			tokens = append(tokens, t)
		}
	}
	return tokens
}
361
// keyGUID is the fixed GUID appended to the client's Sec-WebSocket-Key
// before hashing (see secWebSocketAccept).
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// secWebSocketAccept computes the Sec-WebSocket-Accept response value for the
// given Sec-WebSocket-Key: the standard base64 encoding of the SHA-1 digest
// of the key followed by keyGUID.
func secWebSocketAccept(secWebSocketKey string) string {
	h := sha1.New()
	// hash.Hash documents that Write never returns an error.
	h.Write([]byte(secWebSocketKey))
	h.Write(keyGUID)

	return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
Note: See TracBrowser for help on using the repository browser.