[822] | 1 | package websocket // import "nhooyr.io/websocket"
|
---|
| 2 |
|
---|
| 3 | import (
|
---|
| 4 | "bytes"
|
---|
| 5 | "context"
|
---|
| 6 | "errors"
|
---|
| 7 | "fmt"
|
---|
| 8 | "io"
|
---|
| 9 | "net/http"
|
---|
| 10 | "reflect"
|
---|
| 11 | "runtime"
|
---|
| 12 | "strings"
|
---|
| 13 | "sync"
|
---|
| 14 | "syscall/js"
|
---|
| 15 |
|
---|
| 16 | "nhooyr.io/websocket/internal/bpool"
|
---|
| 17 | "nhooyr.io/websocket/internal/wsjs"
|
---|
| 18 | "nhooyr.io/websocket/internal/xsync"
|
---|
| 19 | )
|
---|
| 20 |
|
---|
| 21 | // Conn provides a wrapper around the browser WebSocket API.
|
---|
| 22 | type Conn struct {
|
---|
| 23 | ws wsjs.WebSocket
|
---|
| 24 |
|
---|
| 25 | // read limit for a message in bytes.
|
---|
| 26 | msgReadLimit xsync.Int64
|
---|
| 27 |
|
---|
| 28 | closingMu sync.Mutex
|
---|
| 29 | isReadClosed xsync.Int64
|
---|
| 30 | closeOnce sync.Once
|
---|
| 31 | closed chan struct{}
|
---|
| 32 | closeErrOnce sync.Once
|
---|
| 33 | closeErr error
|
---|
| 34 | closeWasClean bool
|
---|
| 35 |
|
---|
| 36 | releaseOnClose func()
|
---|
| 37 | releaseOnMessage func()
|
---|
| 38 |
|
---|
| 39 | readSignal chan struct{}
|
---|
| 40 | readBufMu sync.Mutex
|
---|
| 41 | readBuf []wsjs.MessageEvent
|
---|
| 42 | }
|
---|
| 43 |
|
---|
| 44 | func (c *Conn) close(err error, wasClean bool) {
|
---|
| 45 | c.closeOnce.Do(func() {
|
---|
| 46 | runtime.SetFinalizer(c, nil)
|
---|
| 47 |
|
---|
| 48 | if !wasClean {
|
---|
| 49 | err = fmt.Errorf("unclean connection close: %w", err)
|
---|
| 50 | }
|
---|
| 51 | c.setCloseErr(err)
|
---|
| 52 | c.closeWasClean = wasClean
|
---|
| 53 | close(c.closed)
|
---|
| 54 | })
|
---|
| 55 | }
|
---|
| 56 |
|
---|
| 57 | func (c *Conn) init() {
|
---|
| 58 | c.closed = make(chan struct{})
|
---|
| 59 | c.readSignal = make(chan struct{}, 1)
|
---|
| 60 |
|
---|
| 61 | c.msgReadLimit.Store(32768)
|
---|
| 62 |
|
---|
| 63 | c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
|
---|
| 64 | err := CloseError{
|
---|
| 65 | Code: StatusCode(e.Code),
|
---|
| 66 | Reason: e.Reason,
|
---|
| 67 | }
|
---|
| 68 | // We do not know if we sent or received this close as
|
---|
| 69 | // its possible the browser triggered it without us
|
---|
| 70 | // explicitly sending it.
|
---|
| 71 | c.close(err, e.WasClean)
|
---|
| 72 |
|
---|
| 73 | c.releaseOnClose()
|
---|
| 74 | c.releaseOnMessage()
|
---|
| 75 | })
|
---|
| 76 |
|
---|
| 77 | c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
|
---|
| 78 | c.readBufMu.Lock()
|
---|
| 79 | defer c.readBufMu.Unlock()
|
---|
| 80 |
|
---|
| 81 | c.readBuf = append(c.readBuf, e)
|
---|
| 82 |
|
---|
| 83 | // Lets the read goroutine know there is definitely something in readBuf.
|
---|
| 84 | select {
|
---|
| 85 | case c.readSignal <- struct{}{}:
|
---|
| 86 | default:
|
---|
| 87 | }
|
---|
| 88 | })
|
---|
| 89 |
|
---|
| 90 | runtime.SetFinalizer(c, func(c *Conn) {
|
---|
| 91 | c.setCloseErr(errors.New("connection garbage collected"))
|
---|
| 92 | c.closeWithInternal()
|
---|
| 93 | })
|
---|
| 94 | }
|
---|
| 95 |
|
---|
| 96 | func (c *Conn) closeWithInternal() {
|
---|
| 97 | c.Close(StatusInternalError, "something went wrong")
|
---|
| 98 | }
|
---|
| 99 |
|
---|
| 100 | // Read attempts to read a message from the connection.
|
---|
| 101 | // The maximum time spent waiting is bounded by the context.
|
---|
| 102 | func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
|
---|
| 103 | if c.isReadClosed.Load() == 1 {
|
---|
| 104 | return 0, nil, errors.New("WebSocket connection read closed")
|
---|
| 105 | }
|
---|
| 106 |
|
---|
| 107 | typ, p, err := c.read(ctx)
|
---|
| 108 | if err != nil {
|
---|
| 109 | return 0, nil, fmt.Errorf("failed to read: %w", err)
|
---|
| 110 | }
|
---|
| 111 | if int64(len(p)) > c.msgReadLimit.Load() {
|
---|
| 112 | err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
|
---|
| 113 | c.Close(StatusMessageTooBig, err.Error())
|
---|
| 114 | return 0, nil, err
|
---|
| 115 | }
|
---|
| 116 | return typ, p, nil
|
---|
| 117 | }
|
---|
| 118 |
|
---|
| 119 | func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
|
---|
| 120 | select {
|
---|
| 121 | case <-ctx.Done():
|
---|
| 122 | c.Close(StatusPolicyViolation, "read timed out")
|
---|
| 123 | return 0, nil, ctx.Err()
|
---|
| 124 | case <-c.readSignal:
|
---|
| 125 | case <-c.closed:
|
---|
| 126 | return 0, nil, c.closeErr
|
---|
| 127 | }
|
---|
| 128 |
|
---|
| 129 | c.readBufMu.Lock()
|
---|
| 130 | defer c.readBufMu.Unlock()
|
---|
| 131 |
|
---|
| 132 | me := c.readBuf[0]
|
---|
| 133 | // We copy the messages forward and decrease the size
|
---|
| 134 | // of the slice to avoid reallocating.
|
---|
| 135 | copy(c.readBuf, c.readBuf[1:])
|
---|
| 136 | c.readBuf = c.readBuf[:len(c.readBuf)-1]
|
---|
| 137 |
|
---|
| 138 | if len(c.readBuf) > 0 {
|
---|
| 139 | // Next time we read, we'll grab the message.
|
---|
| 140 | select {
|
---|
| 141 | case c.readSignal <- struct{}{}:
|
---|
| 142 | default:
|
---|
| 143 | }
|
---|
| 144 | }
|
---|
| 145 |
|
---|
| 146 | switch p := me.Data.(type) {
|
---|
| 147 | case string:
|
---|
| 148 | return MessageText, []byte(p), nil
|
---|
| 149 | case []byte:
|
---|
| 150 | return MessageBinary, p, nil
|
---|
| 151 | default:
|
---|
| 152 | panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
|
---|
| 153 | }
|
---|
| 154 | }
|
---|
| 155 |
|
---|
| 156 | // Ping is mocked out for Wasm.
|
---|
| 157 | func (c *Conn) Ping(ctx context.Context) error {
|
---|
| 158 | return nil
|
---|
| 159 | }
|
---|
| 160 |
|
---|
| 161 | // Write writes a message of the given type to the connection.
|
---|
| 162 | // Always non blocking.
|
---|
| 163 | func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
|
---|
| 164 | err := c.write(ctx, typ, p)
|
---|
| 165 | if err != nil {
|
---|
| 166 | // Have to ensure the WebSocket is closed after a write error
|
---|
| 167 | // to match the Go API. It can only error if the message type
|
---|
| 168 | // is unexpected or the passed bytes contain invalid UTF-8 for
|
---|
| 169 | // MessageText.
|
---|
| 170 | err := fmt.Errorf("failed to write: %w", err)
|
---|
| 171 | c.setCloseErr(err)
|
---|
| 172 | c.closeWithInternal()
|
---|
| 173 | return err
|
---|
| 174 | }
|
---|
| 175 | return nil
|
---|
| 176 | }
|
---|
| 177 |
|
---|
| 178 | func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
|
---|
| 179 | if c.isClosed() {
|
---|
| 180 | return c.closeErr
|
---|
| 181 | }
|
---|
| 182 | switch typ {
|
---|
| 183 | case MessageBinary:
|
---|
| 184 | return c.ws.SendBytes(p)
|
---|
| 185 | case MessageText:
|
---|
| 186 | return c.ws.SendText(string(p))
|
---|
| 187 | default:
|
---|
| 188 | return fmt.Errorf("unexpected message type: %v", typ)
|
---|
| 189 | }
|
---|
| 190 | }
|
---|
| 191 |
|
---|
| 192 | // Close closes the WebSocket with the given code and reason.
|
---|
| 193 | // It will wait until the peer responds with a close frame
|
---|
| 194 | // or the connection is closed.
|
---|
| 195 | // It thus performs the full WebSocket close handshake.
|
---|
| 196 | func (c *Conn) Close(code StatusCode, reason string) error {
|
---|
| 197 | err := c.exportedClose(code, reason)
|
---|
| 198 | if err != nil {
|
---|
| 199 | return fmt.Errorf("failed to close WebSocket: %w", err)
|
---|
| 200 | }
|
---|
| 201 | return nil
|
---|
| 202 | }
|
---|
| 203 |
|
---|
| 204 | func (c *Conn) exportedClose(code StatusCode, reason string) error {
|
---|
| 205 | c.closingMu.Lock()
|
---|
| 206 | defer c.closingMu.Unlock()
|
---|
| 207 |
|
---|
| 208 | ce := fmt.Errorf("sent close: %w", CloseError{
|
---|
| 209 | Code: code,
|
---|
| 210 | Reason: reason,
|
---|
| 211 | })
|
---|
| 212 |
|
---|
| 213 | if c.isClosed() {
|
---|
| 214 | return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
|
---|
| 215 | }
|
---|
| 216 |
|
---|
| 217 | c.setCloseErr(ce)
|
---|
| 218 | err := c.ws.Close(int(code), reason)
|
---|
| 219 | if err != nil {
|
---|
| 220 | return err
|
---|
| 221 | }
|
---|
| 222 |
|
---|
| 223 | <-c.closed
|
---|
| 224 | if !c.closeWasClean {
|
---|
| 225 | return c.closeErr
|
---|
| 226 | }
|
---|
| 227 | return nil
|
---|
| 228 | }
|
---|
| 229 |
|
---|
| 230 | // Subprotocol returns the negotiated subprotocol.
|
---|
| 231 | // An empty string means the default protocol.
|
---|
| 232 | func (c *Conn) Subprotocol() string {
|
---|
| 233 | return c.ws.Subprotocol()
|
---|
| 234 | }
|
---|
| 235 |
|
---|
// DialOptions holds the optional settings accepted by Dial.
// A nil *DialOptions is equivalent to the zero value.
type DialOptions struct {
	// Subprotocols lists the subprotocols to negotiate with the server.
	// They are handed to the underlying WebSocket when it is created.
	Subprotocols []string
}
|
---|
| 241 |
|
---|
| 242 | // Dial creates a new WebSocket connection to the given url with the given options.
|
---|
| 243 | // The passed context bounds the maximum time spent waiting for the connection to open.
|
---|
| 244 | // The returned *http.Response is always nil or a mock. It's only in the signature
|
---|
| 245 | // to match the core API.
|
---|
| 246 | func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
|
---|
| 247 | c, resp, err := dial(ctx, url, opts)
|
---|
| 248 | if err != nil {
|
---|
| 249 | return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
|
---|
| 250 | }
|
---|
| 251 | return c, resp, nil
|
---|
| 252 | }
|
---|
| 253 |
|
---|
| 254 | func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
|
---|
| 255 | if opts == nil {
|
---|
| 256 | opts = &DialOptions{}
|
---|
| 257 | }
|
---|
| 258 |
|
---|
| 259 | url = strings.Replace(url, "http://", "ws://", 1)
|
---|
| 260 | url = strings.Replace(url, "https://", "wss://", 1)
|
---|
| 261 |
|
---|
| 262 | ws, err := wsjs.New(url, opts.Subprotocols)
|
---|
| 263 | if err != nil {
|
---|
| 264 | return nil, nil, err
|
---|
| 265 | }
|
---|
| 266 |
|
---|
| 267 | c := &Conn{
|
---|
| 268 | ws: ws,
|
---|
| 269 | }
|
---|
| 270 | c.init()
|
---|
| 271 |
|
---|
| 272 | opench := make(chan struct{})
|
---|
| 273 | releaseOpen := ws.OnOpen(func(e js.Value) {
|
---|
| 274 | close(opench)
|
---|
| 275 | })
|
---|
| 276 | defer releaseOpen()
|
---|
| 277 |
|
---|
| 278 | select {
|
---|
| 279 | case <-ctx.Done():
|
---|
| 280 | c.Close(StatusPolicyViolation, "dial timed out")
|
---|
| 281 | return nil, nil, ctx.Err()
|
---|
| 282 | case <-opench:
|
---|
| 283 | return c, &http.Response{
|
---|
| 284 | StatusCode: http.StatusSwitchingProtocols,
|
---|
| 285 | }, nil
|
---|
| 286 | case <-c.closed:
|
---|
| 287 | return nil, nil, c.closeErr
|
---|
| 288 | }
|
---|
| 289 | }
|
---|
| 290 |
|
---|
| 291 | // Reader attempts to read a message from the connection.
|
---|
| 292 | // The maximum time spent waiting is bounded by the context.
|
---|
| 293 | func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
|
---|
| 294 | typ, p, err := c.Read(ctx)
|
---|
| 295 | if err != nil {
|
---|
| 296 | return 0, nil, err
|
---|
| 297 | }
|
---|
| 298 | return typ, bytes.NewReader(p), nil
|
---|
| 299 | }
|
---|
| 300 |
|
---|
| 301 | // Writer returns a writer to write a WebSocket data message to the connection.
|
---|
| 302 | // It buffers the entire message in memory and then sends it when the writer
|
---|
| 303 | // is closed.
|
---|
| 304 | func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
|
---|
| 305 | return writer{
|
---|
| 306 | c: c,
|
---|
| 307 | ctx: ctx,
|
---|
| 308 | typ: typ,
|
---|
| 309 | b: bpool.Get(),
|
---|
| 310 | }, nil
|
---|
| 311 | }
|
---|
| 312 |
|
---|
| 313 | type writer struct {
|
---|
| 314 | closed bool
|
---|
| 315 |
|
---|
| 316 | c *Conn
|
---|
| 317 | ctx context.Context
|
---|
| 318 | typ MessageType
|
---|
| 319 |
|
---|
| 320 | b *bytes.Buffer
|
---|
| 321 | }
|
---|
| 322 |
|
---|
| 323 | func (w writer) Write(p []byte) (int, error) {
|
---|
| 324 | if w.closed {
|
---|
| 325 | return 0, errors.New("cannot write to closed writer")
|
---|
| 326 | }
|
---|
| 327 | n, err := w.b.Write(p)
|
---|
| 328 | if err != nil {
|
---|
| 329 | return n, fmt.Errorf("failed to write message: %w", err)
|
---|
| 330 | }
|
---|
| 331 | return n, nil
|
---|
| 332 | }
|
---|
| 333 |
|
---|
| 334 | func (w writer) Close() error {
|
---|
| 335 | if w.closed {
|
---|
| 336 | return errors.New("cannot close closed writer")
|
---|
| 337 | }
|
---|
| 338 | w.closed = true
|
---|
| 339 | defer bpool.Put(w.b)
|
---|
| 340 |
|
---|
| 341 | err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
|
---|
| 342 | if err != nil {
|
---|
| 343 | return fmt.Errorf("failed to close writer: %w", err)
|
---|
| 344 | }
|
---|
| 345 | return nil
|
---|
| 346 | }
|
---|
| 347 |
|
---|
| 348 | // CloseRead implements *Conn.CloseRead for wasm.
|
---|
| 349 | func (c *Conn) CloseRead(ctx context.Context) context.Context {
|
---|
| 350 | c.isReadClosed.Store(1)
|
---|
| 351 |
|
---|
| 352 | ctx, cancel := context.WithCancel(ctx)
|
---|
| 353 | go func() {
|
---|
| 354 | defer cancel()
|
---|
| 355 | c.read(ctx)
|
---|
| 356 | c.Close(StatusPolicyViolation, "unexpected data message")
|
---|
| 357 | }()
|
---|
| 358 | return ctx
|
---|
| 359 | }
|
---|
| 360 |
|
---|
| 361 | // SetReadLimit implements *Conn.SetReadLimit for wasm.
|
---|
| 362 | func (c *Conn) SetReadLimit(n int64) {
|
---|
| 363 | c.msgReadLimit.Store(n)
|
---|
| 364 | }
|
---|
| 365 |
|
---|
| 366 | func (c *Conn) setCloseErr(err error) {
|
---|
| 367 | c.closeErrOnce.Do(func() {
|
---|
| 368 | c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
|
---|
| 369 | })
|
---|
| 370 | }
|
---|
| 371 |
|
---|
| 372 | func (c *Conn) isClosed() bool {
|
---|
| 373 | select {
|
---|
| 374 | case <-c.closed:
|
---|
| 375 | return true
|
---|
| 376 | default:
|
---|
| 377 | return false
|
---|
| 378 | }
|
---|
| 379 | }
|
---|