[822] | 1 | // Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification:
|
---|
| 2 | // https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
|
---|
| 3 | package proxyproto
|
---|
| 4 |
|
---|
| 5 | import (
|
---|
| 6 | "bufio"
|
---|
| 7 | "bytes"
|
---|
| 8 | "errors"
|
---|
| 9 | "io"
|
---|
| 10 | "net"
|
---|
| 11 | "time"
|
---|
| 12 | )
|
---|
| 13 |
|
---|
var (
	// SIGV1 is the 5-byte signature ("PROXY") that opens a version 1 header.
	SIGV1 = []byte{'P', 'R', 'O', 'X', 'Y'}
	// SIGV2 is the 12-byte signature ("\r\n\r\n\x00\r\nQUIT\n") that opens a
	// version 2 header.
	SIGV2 = []byte{'\r', '\n', '\r', '\n', 0x00, '\r', '\n', 'Q', 'U', 'I', 'T', '\n'}

	// Errors reported while reading a version 1 (text) header.
	ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header")
	ErrVersion1HeaderTooLong  = errors.New("proxyproto: version 1 header must be 107 bytes or less")
	ErrLineMustEndWithCrlf    = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n")

	// Errors reported while reading the fixed fields of a version 2 header.
	ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
	ErrCantReadAddressFamilyAndProtocol  = errors.New("proxyproto: can't read address family or protocol")
	ErrCantReadLength                    = errors.New("proxyproto: can't read length")
	ErrCantResolveSourceUnixAddress      = errors.New("proxyproto: can't resolve source Unix address")
	ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address")

	// Errors reported for missing, unsupported or malformed headers.
	ErrNoProxyProtocol                      = errors.New("proxyproto: proxy protocol signature not present")
	ErrUnknownProxyProtocolVersion          = errors.New("proxyproto: unknown proxy protocol version")
	ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command")
	ErrUnsupportedAddressFamilyAndProtocol  = errors.New("proxyproto: unsupported address family and protocol")
	ErrInvalidLength                        = errors.New("proxyproto: invalid length")
	ErrInvalidAddress                       = errors.New("proxyproto: invalid address")
	ErrInvalidPortNumber                    = errors.New("proxyproto: invalid port number")
	ErrSuperfluousProxyHeader               = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one")
)
|
---|
| 36 |
|
---|
| 37 | // Header is the placeholder for proxy protocol header.
|
---|
| 38 | type Header struct {
|
---|
| 39 | Version byte
|
---|
| 40 | Command ProtocolVersionAndCommand
|
---|
| 41 | TransportProtocol AddressFamilyAndProtocol
|
---|
| 42 | SourceAddr net.Addr
|
---|
| 43 | DestinationAddr net.Addr
|
---|
| 44 | rawTLVs []byte
|
---|
| 45 | }
|
---|
| 46 |
|
---|
| 47 | // HeaderProxyFromAddrs creates a new PROXY header from a source and a
|
---|
| 48 | // destination address. If version is zero, the latest protocol version is
|
---|
| 49 | // used.
|
---|
| 50 | //
|
---|
| 51 | // The header is filled on a best-effort basis: if hints cannot be inferred
|
---|
| 52 | // from the provided addresses, the header will be left unspecified.
|
---|
| 53 | func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header {
|
---|
| 54 | if version < 1 || version > 2 {
|
---|
| 55 | version = 2
|
---|
| 56 | }
|
---|
| 57 | h := &Header{
|
---|
| 58 | Version: version,
|
---|
| 59 | Command: LOCAL,
|
---|
| 60 | TransportProtocol: UNSPEC,
|
---|
| 61 | }
|
---|
| 62 | switch sourceAddr := sourceAddr.(type) {
|
---|
| 63 | case *net.TCPAddr:
|
---|
| 64 | if _, ok := destAddr.(*net.TCPAddr); !ok {
|
---|
| 65 | break
|
---|
| 66 | }
|
---|
| 67 | if len(sourceAddr.IP.To4()) == net.IPv4len {
|
---|
| 68 | h.TransportProtocol = TCPv4
|
---|
| 69 | } else if len(sourceAddr.IP) == net.IPv6len {
|
---|
| 70 | h.TransportProtocol = TCPv6
|
---|
| 71 | }
|
---|
| 72 | case *net.UDPAddr:
|
---|
| 73 | if _, ok := destAddr.(*net.UDPAddr); !ok {
|
---|
| 74 | break
|
---|
| 75 | }
|
---|
| 76 | if len(sourceAddr.IP.To4()) == net.IPv4len {
|
---|
| 77 | h.TransportProtocol = UDPv4
|
---|
| 78 | } else if len(sourceAddr.IP) == net.IPv6len {
|
---|
| 79 | h.TransportProtocol = UDPv6
|
---|
| 80 | }
|
---|
| 81 | case *net.UnixAddr:
|
---|
| 82 | if _, ok := destAddr.(*net.UnixAddr); !ok {
|
---|
| 83 | break
|
---|
| 84 | }
|
---|
| 85 | switch sourceAddr.Net {
|
---|
| 86 | case "unix":
|
---|
| 87 | h.TransportProtocol = UnixStream
|
---|
| 88 | case "unixgram":
|
---|
| 89 | h.TransportProtocol = UnixDatagram
|
---|
| 90 | }
|
---|
| 91 | }
|
---|
| 92 | if h.TransportProtocol != UNSPEC {
|
---|
| 93 | h.Command = PROXY
|
---|
| 94 | h.SourceAddr = sourceAddr
|
---|
| 95 | h.DestinationAddr = destAddr
|
---|
| 96 | }
|
---|
| 97 | return h
|
---|
| 98 | }
|
---|
| 99 |
|
---|
| 100 | func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) {
|
---|
| 101 | if !header.TransportProtocol.IsStream() {
|
---|
| 102 | return nil, nil, false
|
---|
| 103 | }
|
---|
| 104 | sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr)
|
---|
| 105 | destAddr, destOK := header.DestinationAddr.(*net.TCPAddr)
|
---|
| 106 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
| 107 | }
|
---|
| 108 |
|
---|
| 109 | func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) {
|
---|
| 110 | if !header.TransportProtocol.IsDatagram() {
|
---|
| 111 | return nil, nil, false
|
---|
| 112 | }
|
---|
| 113 | sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr)
|
---|
| 114 | destAddr, destOK := header.DestinationAddr.(*net.UDPAddr)
|
---|
| 115 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
| 116 | }
|
---|
| 117 |
|
---|
| 118 | func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) {
|
---|
| 119 | if !header.TransportProtocol.IsUnix() {
|
---|
| 120 | return nil, nil, false
|
---|
| 121 | }
|
---|
| 122 | sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr)
|
---|
| 123 | destAddr, destOK := header.DestinationAddr.(*net.UnixAddr)
|
---|
| 124 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
| 125 | }
|
---|
| 126 |
|
---|
| 127 | func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) {
|
---|
| 128 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
|
---|
| 129 | return sourceAddr.IP, destAddr.IP, true
|
---|
| 130 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
|
---|
| 131 | return sourceAddr.IP, destAddr.IP, true
|
---|
| 132 | } else {
|
---|
| 133 | return nil, nil, false
|
---|
| 134 | }
|
---|
| 135 | }
|
---|
| 136 |
|
---|
| 137 | func (header *Header) Ports() (sourcePort, destPort int, ok bool) {
|
---|
| 138 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
|
---|
| 139 | return sourceAddr.Port, destAddr.Port, true
|
---|
| 140 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
|
---|
| 141 | return sourceAddr.Port, destAddr.Port, true
|
---|
| 142 | } else {
|
---|
| 143 | return 0, 0, false
|
---|
| 144 | }
|
---|
| 145 | }
|
---|
| 146 |
|
---|
| 147 | // EqualTo returns true if headers are equivalent, false otherwise.
|
---|
| 148 | // Deprecated: use EqualsTo instead. This method will eventually be removed.
|
---|
| 149 | func (header *Header) EqualTo(otherHeader *Header) bool {
|
---|
| 150 | return header.EqualsTo(otherHeader)
|
---|
| 151 | }
|
---|
| 152 |
|
---|
| 153 | // EqualsTo returns true if headers are equivalent, false otherwise.
|
---|
| 154 | func (header *Header) EqualsTo(otherHeader *Header) bool {
|
---|
| 155 | if otherHeader == nil {
|
---|
| 156 | return false
|
---|
| 157 | }
|
---|
| 158 | // TLVs only exist for version 2
|
---|
| 159 | if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
|
---|
| 160 | return false
|
---|
| 161 | }
|
---|
| 162 | if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
|
---|
| 163 | return false
|
---|
| 164 | }
|
---|
| 165 | // Return early for header with LOCAL command, which contains no address information
|
---|
| 166 | if header.Command == LOCAL {
|
---|
| 167 | return true
|
---|
| 168 | }
|
---|
| 169 | return header.SourceAddr.String() == otherHeader.SourceAddr.String() &&
|
---|
| 170 | header.DestinationAddr.String() == otherHeader.DestinationAddr.String()
|
---|
| 171 | }
|
---|
| 172 |
|
---|
| 173 | // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer.
|
---|
| 174 | func (header *Header) WriteTo(w io.Writer) (int64, error) {
|
---|
| 175 | buf, err := header.Format()
|
---|
| 176 | if err != nil {
|
---|
| 177 | return 0, err
|
---|
| 178 | }
|
---|
| 179 |
|
---|
| 180 | return bytes.NewBuffer(buf).WriteTo(w)
|
---|
| 181 | }
|
---|
| 182 |
|
---|
| 183 | // Format renders a proxy protocol header in a format to write over the wire.
|
---|
| 184 | func (header *Header) Format() ([]byte, error) {
|
---|
| 185 | switch header.Version {
|
---|
| 186 | case 1:
|
---|
| 187 | return header.formatVersion1()
|
---|
| 188 | case 2:
|
---|
| 189 | return header.formatVersion2()
|
---|
| 190 | default:
|
---|
| 191 | return nil, ErrUnknownProxyProtocolVersion
|
---|
| 192 | }
|
---|
| 193 | }
|
---|
| 194 |
|
---|
| 195 | // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol.
|
---|
| 196 | func (header *Header) TLVs() ([]TLV, error) {
|
---|
| 197 | return SplitTLVs(header.rawTLVs)
|
---|
| 198 | }
|
---|
| 199 |
|
---|
| 200 | // SetTLVs sets the TLVs stored in this header. This method replaces any
|
---|
| 201 | // previous TLV.
|
---|
| 202 | func (header *Header) SetTLVs(tlvs []TLV) error {
|
---|
| 203 | raw, err := JoinTLVs(tlvs)
|
---|
| 204 | if err != nil {
|
---|
| 205 | return err
|
---|
| 206 | }
|
---|
| 207 | header.rawTLVs = raw
|
---|
| 208 | return nil
|
---|
| 209 | }
|
---|
| 210 |
|
---|
| 211 | // Read identifies the proxy protocol version and reads the remaining of
|
---|
| 212 | // the header, accordingly.
|
---|
| 213 | //
|
---|
| 214 | // If proxy protocol header signature is not present, the reader buffer remains untouched
|
---|
| 215 | // and is safe for reading outside of this code.
|
---|
| 216 | //
|
---|
| 217 | // If proxy protocol header signature is present but an error is raised while processing
|
---|
| 218 | // the remaining header, assume the reader buffer to be in a corrupt state.
|
---|
| 219 | // Also, this operation will block until enough bytes are available for peeking.
|
---|
| 220 | func Read(reader *bufio.Reader) (*Header, error) {
|
---|
| 221 | // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone.
|
---|
| 222 | b1, err := reader.Peek(1)
|
---|
| 223 | if err != nil {
|
---|
| 224 | if err == io.EOF {
|
---|
| 225 | return nil, ErrNoProxyProtocol
|
---|
| 226 | }
|
---|
| 227 | return nil, err
|
---|
| 228 | }
|
---|
| 229 |
|
---|
| 230 | if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) {
|
---|
| 231 | signature, err := reader.Peek(5)
|
---|
| 232 | if err != nil {
|
---|
| 233 | if err == io.EOF {
|
---|
| 234 | return nil, ErrNoProxyProtocol
|
---|
| 235 | }
|
---|
| 236 | return nil, err
|
---|
| 237 | }
|
---|
| 238 | if bytes.Equal(signature[:5], SIGV1) {
|
---|
| 239 | return parseVersion1(reader)
|
---|
| 240 | }
|
---|
| 241 |
|
---|
| 242 | signature, err = reader.Peek(12)
|
---|
| 243 | if err != nil {
|
---|
| 244 | if err == io.EOF {
|
---|
| 245 | return nil, ErrNoProxyProtocol
|
---|
| 246 | }
|
---|
| 247 | return nil, err
|
---|
| 248 | }
|
---|
| 249 | if bytes.Equal(signature[:12], SIGV2) {
|
---|
| 250 | return parseVersion2(reader)
|
---|
| 251 | }
|
---|
| 252 | }
|
---|
| 253 |
|
---|
| 254 | return nil, ErrNoProxyProtocol
|
---|
| 255 | }
|
---|
| 256 |
|
---|
| 257 | // ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed
|
---|
| 258 | // there's no proxy protocol header.
|
---|
| 259 | func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) {
|
---|
| 260 | type header struct {
|
---|
| 261 | h *Header
|
---|
| 262 | e error
|
---|
| 263 | }
|
---|
| 264 | read := make(chan *header, 1)
|
---|
| 265 |
|
---|
| 266 | go func() {
|
---|
| 267 | h := &header{}
|
---|
| 268 | h.h, h.e = Read(reader)
|
---|
| 269 | read <- h
|
---|
| 270 | }()
|
---|
| 271 |
|
---|
| 272 | timer := time.NewTimer(timeout)
|
---|
| 273 | select {
|
---|
| 274 | case result := <-read:
|
---|
| 275 | timer.Stop()
|
---|
| 276 | return result.h, result.e
|
---|
| 277 | case <-timer.C:
|
---|
| 278 | return nil, ErrNoProxyProtocol
|
---|
| 279 | }
|
---|
| 280 | }
|
---|