[145] | 1 | package fasthttp
|
---|
| 2 |
|
---|
| 3 | import (
|
---|
| 4 | "bufio"
|
---|
| 5 | "bytes"
|
---|
| 6 | "io"
|
---|
| 7 | "sync"
|
---|
| 8 |
|
---|
| 9 | "github.com/valyala/bytebufferpool"
|
---|
| 10 | )
|
---|
| 11 |
|
---|
| 12 | type requestStream struct {
|
---|
| 13 | header *RequestHeader
|
---|
| 14 | prefetchedBytes *bytes.Reader
|
---|
| 15 | reader *bufio.Reader
|
---|
| 16 | totalBytesRead int
|
---|
| 17 | chunkLeft int
|
---|
| 18 | }
|
---|
| 19 |
|
---|
| 20 | func (rs *requestStream) Read(p []byte) (int, error) {
|
---|
| 21 | var (
|
---|
| 22 | n int
|
---|
| 23 | err error
|
---|
| 24 | )
|
---|
| 25 | if rs.header.contentLength == -1 {
|
---|
| 26 | if rs.chunkLeft == 0 {
|
---|
| 27 | chunkSize, err := parseChunkSize(rs.reader)
|
---|
| 28 | if err != nil {
|
---|
| 29 | return 0, err
|
---|
| 30 | }
|
---|
| 31 | if chunkSize == 0 {
|
---|
| 32 | err = rs.header.ReadTrailer(rs.reader)
|
---|
| 33 | if err != nil && err != io.EOF {
|
---|
| 34 | return 0, err
|
---|
| 35 | }
|
---|
| 36 | return 0, io.EOF
|
---|
| 37 | }
|
---|
| 38 | rs.chunkLeft = chunkSize
|
---|
| 39 | }
|
---|
| 40 | bytesToRead := len(p)
|
---|
| 41 | if rs.chunkLeft < len(p) {
|
---|
| 42 | bytesToRead = rs.chunkLeft
|
---|
| 43 | }
|
---|
| 44 | n, err = rs.reader.Read(p[:bytesToRead])
|
---|
| 45 | rs.totalBytesRead += n
|
---|
| 46 | rs.chunkLeft -= n
|
---|
| 47 | if err == io.EOF {
|
---|
| 48 | err = io.ErrUnexpectedEOF
|
---|
| 49 | }
|
---|
| 50 | if err == nil && rs.chunkLeft == 0 {
|
---|
| 51 | err = readCrLf(rs.reader)
|
---|
| 52 | }
|
---|
| 53 | return n, err
|
---|
| 54 | }
|
---|
| 55 | if rs.totalBytesRead == rs.header.contentLength {
|
---|
| 56 | return 0, io.EOF
|
---|
| 57 | }
|
---|
| 58 | prefetchedSize := int(rs.prefetchedBytes.Size())
|
---|
| 59 | if prefetchedSize > rs.totalBytesRead {
|
---|
| 60 | left := prefetchedSize - rs.totalBytesRead
|
---|
| 61 | if len(p) > left {
|
---|
| 62 | p = p[:left]
|
---|
| 63 | }
|
---|
| 64 | n, err := rs.prefetchedBytes.Read(p)
|
---|
| 65 | rs.totalBytesRead += n
|
---|
| 66 | if n == rs.header.contentLength {
|
---|
| 67 | return n, io.EOF
|
---|
| 68 | }
|
---|
| 69 | return n, err
|
---|
| 70 | } else {
|
---|
| 71 | left := rs.header.contentLength - rs.totalBytesRead
|
---|
| 72 | if len(p) > left {
|
---|
| 73 | p = p[:left]
|
---|
| 74 | }
|
---|
| 75 | n, err = rs.reader.Read(p)
|
---|
| 76 | rs.totalBytesRead += n
|
---|
| 77 | if err != nil {
|
---|
| 78 | return n, err
|
---|
| 79 | }
|
---|
| 80 | }
|
---|
| 81 |
|
---|
| 82 | if rs.totalBytesRead == rs.header.contentLength {
|
---|
| 83 | err = io.EOF
|
---|
| 84 | }
|
---|
| 85 | return n, err
|
---|
| 86 | }
|
---|
| 87 |
|
---|
| 88 | func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream {
|
---|
| 89 | rs := requestStreamPool.Get().(*requestStream)
|
---|
| 90 | rs.prefetchedBytes = bytes.NewReader(b.B)
|
---|
| 91 | rs.reader = r
|
---|
| 92 | rs.header = h
|
---|
| 93 | return rs
|
---|
| 94 | }
|
---|
| 95 |
|
---|
| 96 | func releaseRequestStream(rs *requestStream) {
|
---|
| 97 | rs.prefetchedBytes = nil
|
---|
| 98 | rs.totalBytesRead = 0
|
---|
| 99 | rs.chunkLeft = 0
|
---|
| 100 | rs.reader = nil
|
---|
| 101 | requestStreamPool.Put(rs)
|
---|
| 102 | }
|
---|
| 103 |
|
---|
| 104 | var requestStreamPool = sync.Pool{
|
---|
| 105 | New: func() interface{} {
|
---|
| 106 | return &requestStream{}
|
---|
| 107 | },
|
---|
| 108 | }
|
---|