1 | // +build !js
|
---|
2 |
|
---|
3 | package websocket
|
---|
4 |
|
---|
5 | import (
|
---|
6 | "bytes"
|
---|
7 | "crypto/sha1"
|
---|
8 | "encoding/base64"
|
---|
9 | "errors"
|
---|
10 | "fmt"
|
---|
11 | "io"
|
---|
12 | "log"
|
---|
13 | "net/http"
|
---|
14 | "net/textproto"
|
---|
15 | "net/url"
|
---|
16 | "path/filepath"
|
---|
17 | "strings"
|
---|
18 |
|
---|
19 | "nhooyr.io/websocket/internal/errd"
|
---|
20 | )
|
---|
21 |
|
---|
22 | // AcceptOptions represents Accept's options.
|
---|
23 | type AcceptOptions struct {
|
---|
24 | // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
|
---|
25 | // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
|
---|
26 | // reject it, close the connection when c.Subprotocol() == "".
|
---|
27 | Subprotocols []string
|
---|
28 |
|
---|
29 | // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
|
---|
30 | //
|
---|
31 | // You probably want to use OriginPatterns instead.
|
---|
32 | InsecureSkipVerify bool
|
---|
33 |
|
---|
34 | // OriginPatterns lists the host patterns for authorized origins.
|
---|
35 | // The request host is always authorized.
|
---|
36 | // Use this to enable cross origin WebSockets.
|
---|
37 | //
|
---|
38 | // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
|
---|
39 | // In such a case, example.com is the origin and chat.example.com is the request host.
|
---|
40 | // One would set this field to []string{"example.com"} to authorize example.com to connect.
|
---|
41 | //
|
---|
42 | // Each pattern is matched case insensitively against the request origin host
|
---|
43 | // with filepath.Match.
|
---|
44 | // See https://golang.org/pkg/path/filepath/#Match
|
---|
45 | //
|
---|
46 | // Please ensure you understand the ramifications of enabling this.
|
---|
47 | // If used incorrectly your WebSocket server will be open to CSRF attacks.
|
---|
48 | //
|
---|
49 | // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
|
---|
50 | // to bring attention to the danger of such a setting.
|
---|
51 | OriginPatterns []string
|
---|
52 |
|
---|
53 | // CompressionMode controls the compression mode.
|
---|
54 | // Defaults to CompressionNoContextTakeover.
|
---|
55 | //
|
---|
56 | // See docs on CompressionMode for details.
|
---|
57 | CompressionMode CompressionMode
|
---|
58 |
|
---|
59 | // CompressionThreshold controls the minimum size of a message before compression is applied.
|
---|
60 | //
|
---|
61 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
|
---|
62 | // for CompressionContextTakeover.
|
---|
63 | CompressionThreshold int
|
---|
64 | }
|
---|
65 |
|
---|
66 | // Accept accepts a WebSocket handshake from a client and upgrades the
|
---|
67 | // the connection to a WebSocket.
|
---|
68 | //
|
---|
69 | // Accept will not allow cross origin requests by default.
|
---|
70 | // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
|
---|
71 | //
|
---|
72 | // Accept will write a response to w on all errors.
|
---|
73 | func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
|
---|
74 | return accept(w, r, opts)
|
---|
75 | }
|
---|
76 |
|
---|
77 | func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
|
---|
78 | defer errd.Wrap(&err, "failed to accept WebSocket connection")
|
---|
79 |
|
---|
80 | if opts == nil {
|
---|
81 | opts = &AcceptOptions{}
|
---|
82 | }
|
---|
83 | opts = &*opts
|
---|
84 |
|
---|
85 | errCode, err := verifyClientRequest(w, r)
|
---|
86 | if err != nil {
|
---|
87 | http.Error(w, err.Error(), errCode)
|
---|
88 | return nil, err
|
---|
89 | }
|
---|
90 |
|
---|
91 | if !opts.InsecureSkipVerify {
|
---|
92 | err = authenticateOrigin(r, opts.OriginPatterns)
|
---|
93 | if err != nil {
|
---|
94 | if errors.Is(err, filepath.ErrBadPattern) {
|
---|
95 | log.Printf("websocket: %v", err)
|
---|
96 | err = errors.New(http.StatusText(http.StatusForbidden))
|
---|
97 | }
|
---|
98 | http.Error(w, err.Error(), http.StatusForbidden)
|
---|
99 | return nil, err
|
---|
100 | }
|
---|
101 | }
|
---|
102 |
|
---|
103 | hj, ok := w.(http.Hijacker)
|
---|
104 | if !ok {
|
---|
105 | err = errors.New("http.ResponseWriter does not implement http.Hijacker")
|
---|
106 | http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
|
---|
107 | return nil, err
|
---|
108 | }
|
---|
109 |
|
---|
110 | w.Header().Set("Upgrade", "websocket")
|
---|
111 | w.Header().Set("Connection", "Upgrade")
|
---|
112 |
|
---|
113 | key := r.Header.Get("Sec-WebSocket-Key")
|
---|
114 | w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
|
---|
115 |
|
---|
116 | subproto := selectSubprotocol(r, opts.Subprotocols)
|
---|
117 | if subproto != "" {
|
---|
118 | w.Header().Set("Sec-WebSocket-Protocol", subproto)
|
---|
119 | }
|
---|
120 |
|
---|
121 | copts, err := acceptCompression(r, w, opts.CompressionMode)
|
---|
122 | if err != nil {
|
---|
123 | return nil, err
|
---|
124 | }
|
---|
125 |
|
---|
126 | w.WriteHeader(http.StatusSwitchingProtocols)
|
---|
127 | // See https://github.com/nhooyr/websocket/issues/166
|
---|
128 | if ginWriter, ok := w.(interface {
|
---|
129 | WriteHeaderNow()
|
---|
130 | }); ok {
|
---|
131 | ginWriter.WriteHeaderNow()
|
---|
132 | }
|
---|
133 |
|
---|
134 | netConn, brw, err := hj.Hijack()
|
---|
135 | if err != nil {
|
---|
136 | err = fmt.Errorf("failed to hijack connection: %w", err)
|
---|
137 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
---|
138 | return nil, err
|
---|
139 | }
|
---|
140 |
|
---|
141 | // https://github.com/golang/go/issues/32314
|
---|
142 | b, _ := brw.Reader.Peek(brw.Reader.Buffered())
|
---|
143 | brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
|
---|
144 |
|
---|
145 | return newConn(connConfig{
|
---|
146 | subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
|
---|
147 | rwc: netConn,
|
---|
148 | client: false,
|
---|
149 | copts: copts,
|
---|
150 | flateThreshold: opts.CompressionThreshold,
|
---|
151 |
|
---|
152 | br: brw.Reader,
|
---|
153 | bw: brw.Writer,
|
---|
154 | }), nil
|
---|
155 | }
|
---|
156 |
|
---|
157 | func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
|
---|
158 | if !r.ProtoAtLeast(1, 1) {
|
---|
159 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
|
---|
160 | }
|
---|
161 |
|
---|
162 | if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
|
---|
163 | w.Header().Set("Connection", "Upgrade")
|
---|
164 | w.Header().Set("Upgrade", "websocket")
|
---|
165 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
|
---|
166 | }
|
---|
167 |
|
---|
168 | if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
|
---|
169 | w.Header().Set("Connection", "Upgrade")
|
---|
170 | w.Header().Set("Upgrade", "websocket")
|
---|
171 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
|
---|
172 | }
|
---|
173 |
|
---|
174 | if r.Method != "GET" {
|
---|
175 | return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
|
---|
176 | }
|
---|
177 |
|
---|
178 | if r.Header.Get("Sec-WebSocket-Version") != "13" {
|
---|
179 | w.Header().Set("Sec-WebSocket-Version", "13")
|
---|
180 | return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
|
---|
181 | }
|
---|
182 |
|
---|
183 | if r.Header.Get("Sec-WebSocket-Key") == "" {
|
---|
184 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
|
---|
185 | }
|
---|
186 |
|
---|
187 | return 0, nil
|
---|
188 | }
|
---|
189 |
|
---|
190 | func authenticateOrigin(r *http.Request, originHosts []string) error {
|
---|
191 | origin := r.Header.Get("Origin")
|
---|
192 | if origin == "" {
|
---|
193 | return nil
|
---|
194 | }
|
---|
195 |
|
---|
196 | u, err := url.Parse(origin)
|
---|
197 | if err != nil {
|
---|
198 | return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
|
---|
199 | }
|
---|
200 |
|
---|
201 | if strings.EqualFold(r.Host, u.Host) {
|
---|
202 | return nil
|
---|
203 | }
|
---|
204 |
|
---|
205 | for _, hostPattern := range originHosts {
|
---|
206 | matched, err := match(hostPattern, u.Host)
|
---|
207 | if err != nil {
|
---|
208 | return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
|
---|
209 | }
|
---|
210 | if matched {
|
---|
211 | return nil
|
---|
212 | }
|
---|
213 | }
|
---|
214 | return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
|
---|
215 | }
|
---|
216 |
|
---|
// match reports whether s matches the filepath.Match pattern. Both pattern
// and s are lower-cased first, which makes the comparison case insensitive.
// The error is whatever filepath.Match returns for the lower-cased pattern.
func match(pattern, s string) (bool, error) {
	loweredPattern := strings.ToLower(pattern)
	loweredName := strings.ToLower(s)
	return filepath.Match(loweredPattern, loweredName)
}
|
---|
220 |
|
---|
221 | func selectSubprotocol(r *http.Request, subprotocols []string) string {
|
---|
222 | cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
|
---|
223 | for _, sp := range subprotocols {
|
---|
224 | for _, cp := range cps {
|
---|
225 | if strings.EqualFold(sp, cp) {
|
---|
226 | return cp
|
---|
227 | }
|
---|
228 | }
|
---|
229 | }
|
---|
230 | return ""
|
---|
231 | }
|
---|
232 |
|
---|
233 | func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
|
---|
234 | if mode == CompressionDisabled {
|
---|
235 | return nil, nil
|
---|
236 | }
|
---|
237 |
|
---|
238 | for _, ext := range websocketExtensions(r.Header) {
|
---|
239 | switch ext.name {
|
---|
240 | case "permessage-deflate":
|
---|
241 | return acceptDeflate(w, ext, mode)
|
---|
242 | // Disabled for now, see https://github.com/nhooyr/websocket/issues/218
|
---|
243 | // case "x-webkit-deflate-frame":
|
---|
244 | // return acceptWebkitDeflate(w, ext, mode)
|
---|
245 | }
|
---|
246 | }
|
---|
247 | return nil, nil
|
---|
248 | }
|
---|
249 |
|
---|
250 | func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
---|
251 | copts := mode.opts()
|
---|
252 |
|
---|
253 | for _, p := range ext.params {
|
---|
254 | switch p {
|
---|
255 | case "client_no_context_takeover":
|
---|
256 | copts.clientNoContextTakeover = true
|
---|
257 | continue
|
---|
258 | case "server_no_context_takeover":
|
---|
259 | copts.serverNoContextTakeover = true
|
---|
260 | continue
|
---|
261 | }
|
---|
262 |
|
---|
263 | if strings.HasPrefix(p, "client_max_window_bits") {
|
---|
264 | // We cannot adjust the read sliding window so cannot make use of this.
|
---|
265 | continue
|
---|
266 | }
|
---|
267 |
|
---|
268 | err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
---|
269 | http.Error(w, err.Error(), http.StatusBadRequest)
|
---|
270 | return nil, err
|
---|
271 | }
|
---|
272 |
|
---|
273 | copts.setHeader(w.Header())
|
---|
274 |
|
---|
275 | return copts, nil
|
---|
276 | }
|
---|
277 |
|
---|
278 | func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
---|
279 | copts := mode.opts()
|
---|
280 | // The peer must explicitly request it.
|
---|
281 | copts.serverNoContextTakeover = false
|
---|
282 |
|
---|
283 | for _, p := range ext.params {
|
---|
284 | if p == "no_context_takeover" {
|
---|
285 | copts.serverNoContextTakeover = true
|
---|
286 | continue
|
---|
287 | }
|
---|
288 |
|
---|
289 | // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
|
---|
290 | // of ignoring it as the draft spec is unclear. It says the server can ignore it
|
---|
291 | // but the server has no way of signalling to the client it was ignored as the parameters
|
---|
292 | // are set one way.
|
---|
293 | // Thus us ignoring it would make the client think we understood it which would cause issues.
|
---|
294 | // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
|
---|
295 | //
|
---|
296 | // Either way, we're only implementing this for webkit which never sends the max_window_bits
|
---|
297 | // parameter so we don't need to worry about it.
|
---|
298 | err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
|
---|
299 | http.Error(w, err.Error(), http.StatusBadRequest)
|
---|
300 | return nil, err
|
---|
301 | }
|
---|
302 |
|
---|
303 | s := "x-webkit-deflate-frame"
|
---|
304 | if copts.clientNoContextTakeover {
|
---|
305 | s += "; no_context_takeover"
|
---|
306 | }
|
---|
307 | w.Header().Set("Sec-WebSocket-Extensions", s)
|
---|
308 |
|
---|
309 | return copts, nil
|
---|
310 | }
|
---|
311 |
|
---|
312 | func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
|
---|
313 | for _, t := range headerTokens(h, key) {
|
---|
314 | if strings.EqualFold(t, token) {
|
---|
315 | return true
|
---|
316 | }
|
---|
317 | }
|
---|
318 | return false
|
---|
319 | }
|
---|
320 |
|
---|
// websocketExtension is one parsed entry of a Sec-WebSocket-Extensions
// header, as produced by websocketExtensions.
type websocketExtension struct {
	// name is the first ";" separated field of the entry.
	name string
	// params holds the remaining ";" separated fields, each trimmed of
	// surrounding whitespace and otherwise left unparsed.
	params []string
}
|
---|
325 |
|
---|
326 | func websocketExtensions(h http.Header) []websocketExtension {
|
---|
327 | var exts []websocketExtension
|
---|
328 | extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
|
---|
329 | for _, extStr := range extStrs {
|
---|
330 | if extStr == "" {
|
---|
331 | continue
|
---|
332 | }
|
---|
333 |
|
---|
334 | vals := strings.Split(extStr, ";")
|
---|
335 | for i := range vals {
|
---|
336 | vals[i] = strings.TrimSpace(vals[i])
|
---|
337 | }
|
---|
338 |
|
---|
339 | e := websocketExtension{
|
---|
340 | name: vals[0],
|
---|
341 | params: vals[1:],
|
---|
342 | }
|
---|
343 |
|
---|
344 | exts = append(exts, e)
|
---|
345 | }
|
---|
346 | return exts
|
---|
347 | }
|
---|
348 |
|
---|
// headerTokens returns the comma separated tokens of every value of the
// header named key in h, in order of appearance. key is canonicalized before
// the lookup. Each token is trimmed of surrounding whitespace; empty tokens
// are kept. It returns nil when the header is absent.
func headerTokens(h http.Header, key string) []string {
	var out []string

	values := h[textproto.CanonicalMIMEHeaderKey(key)]
	for i := range values {
		// Trimming every token also takes care of whitespace surrounding
		// the value as a whole.
		fields := strings.Split(values[i], ",")
		for j := range fields {
			out = append(out, strings.TrimSpace(fields[j]))
		}
	}

	return out
}
|
---|
361 |
|
---|
// keyGUID is the fixed value appended to the client's Sec-WebSocket-Key
// before hashing (the GUID given in RFC 6455 for the opening handshake).
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// secWebSocketAccept computes the Sec-WebSocket-Accept value for the given
// Sec-WebSocket-Key: the base64 (standard encoding) of the SHA-1 digest of
// the key followed by keyGUID. SHA-1 is used because the handshake requires
// it, not as a security primitive.
func secWebSocketAccept(secWebSocketKey string) string {
	challenge := make([]byte, 0, len(secWebSocketKey)+len(keyGUID))
	challenge = append(challenge, secWebSocketKey...)
	challenge = append(challenge, keyGUID...)

	digest := sha1.Sum(challenge)
	return base64.StdEncoding.EncodeToString(digest[:])
}
|
---|