[822] | 1 | // +build !js
|
---|
| 2 |
|
---|
| 3 | package websocket
|
---|
| 4 |
|
---|
| 5 | import (
|
---|
| 6 | "bytes"
|
---|
| 7 | "crypto/sha1"
|
---|
| 8 | "encoding/base64"
|
---|
| 9 | "errors"
|
---|
| 10 | "fmt"
|
---|
| 11 | "io"
|
---|
| 12 | "log"
|
---|
| 13 | "net/http"
|
---|
| 14 | "net/textproto"
|
---|
| 15 | "net/url"
|
---|
| 16 | "path/filepath"
|
---|
| 17 | "strings"
|
---|
| 18 |
|
---|
| 19 | "nhooyr.io/websocket/internal/errd"
|
---|
| 20 | )
|
---|
| 21 |
|
---|
| 22 | // AcceptOptions represents Accept's options.
|
---|
| 23 | type AcceptOptions struct {
|
---|
| 24 | // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
|
---|
| 25 | // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
|
---|
| 26 | // reject it, close the connection when c.Subprotocol() == "".
|
---|
| 27 | Subprotocols []string
|
---|
| 28 |
|
---|
| 29 | // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
|
---|
| 30 | //
|
---|
| 31 | // You probably want to use OriginPatterns instead.
|
---|
| 32 | InsecureSkipVerify bool
|
---|
| 33 |
|
---|
| 34 | // OriginPatterns lists the host patterns for authorized origins.
|
---|
| 35 | // The request host is always authorized.
|
---|
| 36 | // Use this to enable cross origin WebSockets.
|
---|
| 37 | //
|
---|
| 38 | // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
|
---|
| 39 | // In such a case, example.com is the origin and chat.example.com is the request host.
|
---|
| 40 | // One would set this field to []string{"example.com"} to authorize example.com to connect.
|
---|
| 41 | //
|
---|
| 42 | // Each pattern is matched case insensitively against the request origin host
|
---|
| 43 | // with filepath.Match.
|
---|
| 44 | // See https://golang.org/pkg/path/filepath/#Match
|
---|
| 45 | //
|
---|
| 46 | // Please ensure you understand the ramifications of enabling this.
|
---|
| 47 | // If used incorrectly your WebSocket server will be open to CSRF attacks.
|
---|
| 48 | //
|
---|
| 49 | // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
|
---|
| 50 | // to bring attention to the danger of such a setting.
|
---|
| 51 | OriginPatterns []string
|
---|
| 52 |
|
---|
| 53 | // CompressionMode controls the compression mode.
|
---|
| 54 | // Defaults to CompressionNoContextTakeover.
|
---|
| 55 | //
|
---|
| 56 | // See docs on CompressionMode for details.
|
---|
| 57 | CompressionMode CompressionMode
|
---|
| 58 |
|
---|
| 59 | // CompressionThreshold controls the minimum size of a message before compression is applied.
|
---|
| 60 | //
|
---|
| 61 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
|
---|
| 62 | // for CompressionContextTakeover.
|
---|
| 63 | CompressionThreshold int
|
---|
| 64 | }
|
---|
| 65 |
|
---|
| 66 | // Accept accepts a WebSocket handshake from a client and upgrades the
|
---|
| 67 | // the connection to a WebSocket.
|
---|
| 68 | //
|
---|
| 69 | // Accept will not allow cross origin requests by default.
|
---|
| 70 | // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
|
---|
| 71 | //
|
---|
| 72 | // Accept will write a response to w on all errors.
|
---|
| 73 | func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
|
---|
| 74 | return accept(w, r, opts)
|
---|
| 75 | }
|
---|
| 76 |
|
---|
| 77 | func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
|
---|
| 78 | defer errd.Wrap(&err, "failed to accept WebSocket connection")
|
---|
| 79 |
|
---|
| 80 | if opts == nil {
|
---|
| 81 | opts = &AcceptOptions{}
|
---|
| 82 | }
|
---|
| 83 | opts = &*opts
|
---|
| 84 |
|
---|
| 85 | errCode, err := verifyClientRequest(w, r)
|
---|
| 86 | if err != nil {
|
---|
| 87 | http.Error(w, err.Error(), errCode)
|
---|
| 88 | return nil, err
|
---|
| 89 | }
|
---|
| 90 |
|
---|
| 91 | if !opts.InsecureSkipVerify {
|
---|
| 92 | err = authenticateOrigin(r, opts.OriginPatterns)
|
---|
| 93 | if err != nil {
|
---|
| 94 | if errors.Is(err, filepath.ErrBadPattern) {
|
---|
| 95 | log.Printf("websocket: %v", err)
|
---|
| 96 | err = errors.New(http.StatusText(http.StatusForbidden))
|
---|
| 97 | }
|
---|
| 98 | http.Error(w, err.Error(), http.StatusForbidden)
|
---|
| 99 | return nil, err
|
---|
| 100 | }
|
---|
| 101 | }
|
---|
| 102 |
|
---|
| 103 | hj, ok := w.(http.Hijacker)
|
---|
| 104 | if !ok {
|
---|
| 105 | err = errors.New("http.ResponseWriter does not implement http.Hijacker")
|
---|
| 106 | http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
|
---|
| 107 | return nil, err
|
---|
| 108 | }
|
---|
| 109 |
|
---|
| 110 | w.Header().Set("Upgrade", "websocket")
|
---|
| 111 | w.Header().Set("Connection", "Upgrade")
|
---|
| 112 |
|
---|
| 113 | key := r.Header.Get("Sec-WebSocket-Key")
|
---|
| 114 | w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
|
---|
| 115 |
|
---|
| 116 | subproto := selectSubprotocol(r, opts.Subprotocols)
|
---|
| 117 | if subproto != "" {
|
---|
| 118 | w.Header().Set("Sec-WebSocket-Protocol", subproto)
|
---|
| 119 | }
|
---|
| 120 |
|
---|
| 121 | copts, err := acceptCompression(r, w, opts.CompressionMode)
|
---|
| 122 | if err != nil {
|
---|
| 123 | return nil, err
|
---|
| 124 | }
|
---|
| 125 |
|
---|
| 126 | w.WriteHeader(http.StatusSwitchingProtocols)
|
---|
| 127 | // See https://github.com/nhooyr/websocket/issues/166
|
---|
| 128 | if ginWriter, ok := w.(interface {
|
---|
| 129 | WriteHeaderNow()
|
---|
| 130 | }); ok {
|
---|
| 131 | ginWriter.WriteHeaderNow()
|
---|
| 132 | }
|
---|
| 133 |
|
---|
| 134 | netConn, brw, err := hj.Hijack()
|
---|
| 135 | if err != nil {
|
---|
| 136 | err = fmt.Errorf("failed to hijack connection: %w", err)
|
---|
| 137 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
---|
| 138 | return nil, err
|
---|
| 139 | }
|
---|
| 140 |
|
---|
| 141 | // https://github.com/golang/go/issues/32314
|
---|
| 142 | b, _ := brw.Reader.Peek(brw.Reader.Buffered())
|
---|
| 143 | brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
|
---|
| 144 |
|
---|
| 145 | return newConn(connConfig{
|
---|
| 146 | subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
|
---|
| 147 | rwc: netConn,
|
---|
| 148 | client: false,
|
---|
| 149 | copts: copts,
|
---|
| 150 | flateThreshold: opts.CompressionThreshold,
|
---|
| 151 |
|
---|
| 152 | br: brw.Reader,
|
---|
| 153 | bw: brw.Writer,
|
---|
| 154 | }), nil
|
---|
| 155 | }
|
---|
| 156 |
|
---|
| 157 | func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
|
---|
| 158 | if !r.ProtoAtLeast(1, 1) {
|
---|
| 159 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
|
---|
| 160 | }
|
---|
| 161 |
|
---|
| 162 | if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
|
---|
| 163 | w.Header().Set("Connection", "Upgrade")
|
---|
| 164 | w.Header().Set("Upgrade", "websocket")
|
---|
| 165 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
|
---|
| 166 | }
|
---|
| 167 |
|
---|
| 168 | if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
|
---|
| 169 | w.Header().Set("Connection", "Upgrade")
|
---|
| 170 | w.Header().Set("Upgrade", "websocket")
|
---|
| 171 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
|
---|
| 172 | }
|
---|
| 173 |
|
---|
| 174 | if r.Method != "GET" {
|
---|
| 175 | return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
|
---|
| 176 | }
|
---|
| 177 |
|
---|
| 178 | if r.Header.Get("Sec-WebSocket-Version") != "13" {
|
---|
| 179 | w.Header().Set("Sec-WebSocket-Version", "13")
|
---|
| 180 | return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
|
---|
| 181 | }
|
---|
| 182 |
|
---|
| 183 | if r.Header.Get("Sec-WebSocket-Key") == "" {
|
---|
| 184 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
|
---|
| 185 | }
|
---|
| 186 |
|
---|
| 187 | return 0, nil
|
---|
| 188 | }
|
---|
| 189 |
|
---|
| 190 | func authenticateOrigin(r *http.Request, originHosts []string) error {
|
---|
| 191 | origin := r.Header.Get("Origin")
|
---|
| 192 | if origin == "" {
|
---|
| 193 | return nil
|
---|
| 194 | }
|
---|
| 195 |
|
---|
| 196 | u, err := url.Parse(origin)
|
---|
| 197 | if err != nil {
|
---|
| 198 | return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
|
---|
| 199 | }
|
---|
| 200 |
|
---|
| 201 | if strings.EqualFold(r.Host, u.Host) {
|
---|
| 202 | return nil
|
---|
| 203 | }
|
---|
| 204 |
|
---|
| 205 | for _, hostPattern := range originHosts {
|
---|
| 206 | matched, err := match(hostPattern, u.Host)
|
---|
| 207 | if err != nil {
|
---|
| 208 | return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
|
---|
| 209 | }
|
---|
| 210 | if matched {
|
---|
| 211 | return nil
|
---|
| 212 | }
|
---|
| 213 | }
|
---|
| 214 | return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
|
---|
| 215 | }
|
---|
| 216 |
|
---|
| 217 | func match(pattern, s string) (bool, error) {
|
---|
| 218 | return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
|
---|
| 219 | }
|
---|
| 220 |
|
---|
| 221 | func selectSubprotocol(r *http.Request, subprotocols []string) string {
|
---|
| 222 | cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
|
---|
| 223 | for _, sp := range subprotocols {
|
---|
| 224 | for _, cp := range cps {
|
---|
| 225 | if strings.EqualFold(sp, cp) {
|
---|
| 226 | return cp
|
---|
| 227 | }
|
---|
| 228 | }
|
---|
| 229 | }
|
---|
| 230 | return ""
|
---|
| 231 | }
|
---|
| 232 |
|
---|
| 233 | func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
|
---|
| 234 | if mode == CompressionDisabled {
|
---|
| 235 | return nil, nil
|
---|
| 236 | }
|
---|
| 237 |
|
---|
| 238 | for _, ext := range websocketExtensions(r.Header) {
|
---|
| 239 | switch ext.name {
|
---|
| 240 | case "permessage-deflate":
|
---|
| 241 | return acceptDeflate(w, ext, mode)
|
---|
| 242 | // Disabled for now, see https://github.com/nhooyr/websocket/issues/218
|
---|
| 243 | // case "x-webkit-deflate-frame":
|
---|
| 244 | // return acceptWebkitDeflate(w, ext, mode)
|
---|
| 245 | }
|
---|
| 246 | }
|
---|
| 247 | return nil, nil
|
---|
| 248 | }
|
---|
| 249 |
|
---|
| 250 | func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
---|
| 251 | copts := mode.opts()
|
---|
| 252 |
|
---|
| 253 | for _, p := range ext.params {
|
---|
| 254 | switch p {
|
---|
| 255 | case "client_no_context_takeover":
|
---|
| 256 | copts.clientNoContextTakeover = true
|
---|
| 257 | continue
|
---|
| 258 | case "server_no_context_takeover":
|
---|
| 259 | copts.serverNoContextTakeover = true
|
---|
| 260 | continue
|
---|
| 261 | }
|
---|
| 262 |
|
---|
| 263 | if strings.HasPrefix(p, "client_max_window_bits") {
|
---|
| 264 | // We cannot adjust the read sliding window so cannot make use of this.
|
---|
| 265 | continue
|
---|
| 266 | }
|
---|
| 267 |
|
---|
| 268 | err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
---|
| 269 | http.Error(w, err.Error(), http.StatusBadRequest)
|
---|
| 270 | return nil, err
|
---|
| 271 | }
|
---|
| 272 |
|
---|
| 273 | copts.setHeader(w.Header())
|
---|
| 274 |
|
---|
| 275 | return copts, nil
|
---|
| 276 | }
|
---|
| 277 |
|
---|
| 278 | func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
---|
| 279 | copts := mode.opts()
|
---|
| 280 | // The peer must explicitly request it.
|
---|
| 281 | copts.serverNoContextTakeover = false
|
---|
| 282 |
|
---|
| 283 | for _, p := range ext.params {
|
---|
| 284 | if p == "no_context_takeover" {
|
---|
| 285 | copts.serverNoContextTakeover = true
|
---|
| 286 | continue
|
---|
| 287 | }
|
---|
| 288 |
|
---|
| 289 | // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
|
---|
| 290 | // of ignoring it as the draft spec is unclear. It says the server can ignore it
|
---|
| 291 | // but the server has no way of signalling to the client it was ignored as the parameters
|
---|
| 292 | // are set one way.
|
---|
| 293 | // Thus us ignoring it would make the client think we understood it which would cause issues.
|
---|
| 294 | // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
|
---|
| 295 | //
|
---|
| 296 | // Either way, we're only implementing this for webkit which never sends the max_window_bits
|
---|
| 297 | // parameter so we don't need to worry about it.
|
---|
| 298 | err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
|
---|
| 299 | http.Error(w, err.Error(), http.StatusBadRequest)
|
---|
| 300 | return nil, err
|
---|
| 301 | }
|
---|
| 302 |
|
---|
| 303 | s := "x-webkit-deflate-frame"
|
---|
| 304 | if copts.clientNoContextTakeover {
|
---|
| 305 | s += "; no_context_takeover"
|
---|
| 306 | }
|
---|
| 307 | w.Header().Set("Sec-WebSocket-Extensions", s)
|
---|
| 308 |
|
---|
| 309 | return copts, nil
|
---|
| 310 | }
|
---|
| 311 |
|
---|
| 312 | func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
|
---|
| 313 | for _, t := range headerTokens(h, key) {
|
---|
| 314 | if strings.EqualFold(t, token) {
|
---|
| 315 | return true
|
---|
| 316 | }
|
---|
| 317 | }
|
---|
| 318 | return false
|
---|
| 319 | }
|
---|
| 320 |
|
---|
| 321 | type websocketExtension struct {
|
---|
| 322 | name string
|
---|
| 323 | params []string
|
---|
| 324 | }
|
---|
| 325 |
|
---|
| 326 | func websocketExtensions(h http.Header) []websocketExtension {
|
---|
| 327 | var exts []websocketExtension
|
---|
| 328 | extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
|
---|
| 329 | for _, extStr := range extStrs {
|
---|
| 330 | if extStr == "" {
|
---|
| 331 | continue
|
---|
| 332 | }
|
---|
| 333 |
|
---|
| 334 | vals := strings.Split(extStr, ";")
|
---|
| 335 | for i := range vals {
|
---|
| 336 | vals[i] = strings.TrimSpace(vals[i])
|
---|
| 337 | }
|
---|
| 338 |
|
---|
| 339 | e := websocketExtension{
|
---|
| 340 | name: vals[0],
|
---|
| 341 | params: vals[1:],
|
---|
| 342 | }
|
---|
| 343 |
|
---|
| 344 | exts = append(exts, e)
|
---|
| 345 | }
|
---|
| 346 | return exts
|
---|
| 347 | }
|
---|
| 348 |
|
---|
| 349 | func headerTokens(h http.Header, key string) []string {
|
---|
| 350 | key = textproto.CanonicalMIMEHeaderKey(key)
|
---|
| 351 | var tokens []string
|
---|
| 352 | for _, v := range h[key] {
|
---|
| 353 | v = strings.TrimSpace(v)
|
---|
| 354 | for _, t := range strings.Split(v, ",") {
|
---|
| 355 | t = strings.TrimSpace(t)
|
---|
| 356 | tokens = append(tokens, t)
|
---|
| 357 | }
|
---|
| 358 | }
|
---|
| 359 | return tokens
|
---|
| 360 | }
|
---|
| 361 |
|
---|
| 362 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
---|
| 363 |
|
---|
| 364 | func secWebSocketAccept(secWebSocketKey string) string {
|
---|
| 365 | h := sha1.New()
|
---|
| 366 | h.Write([]byte(secWebSocketKey))
|
---|
| 367 | h.Write(keyGUID)
|
---|
| 368 |
|
---|
| 369 | return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
---|
| 370 | }
|
---|