1 | // Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification:
|
---|
2 | // https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
|
---|
3 | package proxyproto
|
---|
4 |
|
---|
5 | import (
|
---|
6 | "bufio"
|
---|
7 | "bytes"
|
---|
8 | "errors"
|
---|
9 | "io"
|
---|
10 | "net"
|
---|
11 | "time"
|
---|
12 | )
|
---|
13 |
|
---|
var (
	// Protocol signatures. Every PROXY protocol header on the wire starts with
	// one of them; Read peeks at them to tell the two versions apart.

	// SIGV1 is the signature of a version 1 header: the ASCII bytes "PROXY".
	SIGV1 = []byte{'P', 'R', 'O', 'X', 'Y'}
	// SIGV2 is the 12-byte signature of a version 2 header:
	// "\r\n\r\n\x00\r\nQUIT\n".
	SIGV2 = []byte{'\r', '\n', '\r', '\n', 0x00, '\r', '\n', 'Q', 'U', 'I', 'T', '\n'}
)

// Sentinel errors used by this package. Callers can compare against them with
// == or errors.Is.
var (
	// Errors about version 1 headers.
	ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header")
	ErrVersion1HeaderTooLong  = errors.New("proxyproto: version 1 header must be 107 bytes or less")
	ErrLineMustEndWithCrlf    = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n")

	// Errors about header fields that could not be read.
	ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
	ErrCantReadAddressFamilyAndProtocol  = errors.New("proxyproto: can't read address family or protocol")
	ErrCantReadLength                    = errors.New("proxyproto: can't read length")

	// Errors about Unix socket addresses that could not be resolved.
	ErrCantResolveSourceUnixAddress      = errors.New("proxyproto: can't resolve source Unix address")
	ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address")

	// Errors about the signature, the version and the header contents.
	ErrNoProxyProtocol                      = errors.New("proxyproto: proxy protocol signature not present")
	ErrUnknownProxyProtocolVersion          = errors.New("proxyproto: unknown proxy protocol version")
	ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command")
	ErrUnsupportedAddressFamilyAndProtocol  = errors.New("proxyproto: unsupported address family and protocol")
	ErrInvalidLength                        = errors.New("proxyproto: invalid length")
	ErrInvalidAddress                       = errors.New("proxyproto: invalid address")
	ErrInvalidPortNumber                    = errors.New("proxyproto: invalid port number")

	// ErrSuperfluousProxyHeader reports a PROXY header sent by an upstream
	// that is not allowed to send one.
	ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one")
)
|
---|
36 |
|
---|
37 | // Header is the placeholder for proxy protocol header.
|
---|
38 | type Header struct {
|
---|
39 | Version byte
|
---|
40 | Command ProtocolVersionAndCommand
|
---|
41 | TransportProtocol AddressFamilyAndProtocol
|
---|
42 | SourceAddr net.Addr
|
---|
43 | DestinationAddr net.Addr
|
---|
44 | rawTLVs []byte
|
---|
45 | }
|
---|
46 |
|
---|
47 | // HeaderProxyFromAddrs creates a new PROXY header from a source and a
|
---|
48 | // destination address. If version is zero, the latest protocol version is
|
---|
49 | // used.
|
---|
50 | //
|
---|
51 | // The header is filled on a best-effort basis: if hints cannot be inferred
|
---|
52 | // from the provided addresses, the header will be left unspecified.
|
---|
53 | func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header {
|
---|
54 | if version < 1 || version > 2 {
|
---|
55 | version = 2
|
---|
56 | }
|
---|
57 | h := &Header{
|
---|
58 | Version: version,
|
---|
59 | Command: LOCAL,
|
---|
60 | TransportProtocol: UNSPEC,
|
---|
61 | }
|
---|
62 | switch sourceAddr := sourceAddr.(type) {
|
---|
63 | case *net.TCPAddr:
|
---|
64 | if _, ok := destAddr.(*net.TCPAddr); !ok {
|
---|
65 | break
|
---|
66 | }
|
---|
67 | if len(sourceAddr.IP.To4()) == net.IPv4len {
|
---|
68 | h.TransportProtocol = TCPv4
|
---|
69 | } else if len(sourceAddr.IP) == net.IPv6len {
|
---|
70 | h.TransportProtocol = TCPv6
|
---|
71 | }
|
---|
72 | case *net.UDPAddr:
|
---|
73 | if _, ok := destAddr.(*net.UDPAddr); !ok {
|
---|
74 | break
|
---|
75 | }
|
---|
76 | if len(sourceAddr.IP.To4()) == net.IPv4len {
|
---|
77 | h.TransportProtocol = UDPv4
|
---|
78 | } else if len(sourceAddr.IP) == net.IPv6len {
|
---|
79 | h.TransportProtocol = UDPv6
|
---|
80 | }
|
---|
81 | case *net.UnixAddr:
|
---|
82 | if _, ok := destAddr.(*net.UnixAddr); !ok {
|
---|
83 | break
|
---|
84 | }
|
---|
85 | switch sourceAddr.Net {
|
---|
86 | case "unix":
|
---|
87 | h.TransportProtocol = UnixStream
|
---|
88 | case "unixgram":
|
---|
89 | h.TransportProtocol = UnixDatagram
|
---|
90 | }
|
---|
91 | }
|
---|
92 | if h.TransportProtocol != UNSPEC {
|
---|
93 | h.Command = PROXY
|
---|
94 | h.SourceAddr = sourceAddr
|
---|
95 | h.DestinationAddr = destAddr
|
---|
96 | }
|
---|
97 | return h
|
---|
98 | }
|
---|
99 |
|
---|
100 | func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) {
|
---|
101 | if !header.TransportProtocol.IsStream() {
|
---|
102 | return nil, nil, false
|
---|
103 | }
|
---|
104 | sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr)
|
---|
105 | destAddr, destOK := header.DestinationAddr.(*net.TCPAddr)
|
---|
106 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
107 | }
|
---|
108 |
|
---|
109 | func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) {
|
---|
110 | if !header.TransportProtocol.IsDatagram() {
|
---|
111 | return nil, nil, false
|
---|
112 | }
|
---|
113 | sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr)
|
---|
114 | destAddr, destOK := header.DestinationAddr.(*net.UDPAddr)
|
---|
115 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
116 | }
|
---|
117 |
|
---|
118 | func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) {
|
---|
119 | if !header.TransportProtocol.IsUnix() {
|
---|
120 | return nil, nil, false
|
---|
121 | }
|
---|
122 | sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr)
|
---|
123 | destAddr, destOK := header.DestinationAddr.(*net.UnixAddr)
|
---|
124 | return sourceAddr, destAddr, sourceOK && destOK
|
---|
125 | }
|
---|
126 |
|
---|
127 | func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) {
|
---|
128 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
|
---|
129 | return sourceAddr.IP, destAddr.IP, true
|
---|
130 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
|
---|
131 | return sourceAddr.IP, destAddr.IP, true
|
---|
132 | } else {
|
---|
133 | return nil, nil, false
|
---|
134 | }
|
---|
135 | }
|
---|
136 |
|
---|
137 | func (header *Header) Ports() (sourcePort, destPort int, ok bool) {
|
---|
138 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok {
|
---|
139 | return sourceAddr.Port, destAddr.Port, true
|
---|
140 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok {
|
---|
141 | return sourceAddr.Port, destAddr.Port, true
|
---|
142 | } else {
|
---|
143 | return 0, 0, false
|
---|
144 | }
|
---|
145 | }
|
---|
146 |
|
---|
147 | // EqualTo returns true if headers are equivalent, false otherwise.
|
---|
148 | // Deprecated: use EqualsTo instead. This method will eventually be removed.
|
---|
149 | func (header *Header) EqualTo(otherHeader *Header) bool {
|
---|
150 | return header.EqualsTo(otherHeader)
|
---|
151 | }
|
---|
152 |
|
---|
153 | // EqualsTo returns true if headers are equivalent, false otherwise.
|
---|
154 | func (header *Header) EqualsTo(otherHeader *Header) bool {
|
---|
155 | if otherHeader == nil {
|
---|
156 | return false
|
---|
157 | }
|
---|
158 | // TLVs only exist for version 2
|
---|
159 | if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
|
---|
160 | return false
|
---|
161 | }
|
---|
162 | if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
|
---|
163 | return false
|
---|
164 | }
|
---|
165 | // Return early for header with LOCAL command, which contains no address information
|
---|
166 | if header.Command == LOCAL {
|
---|
167 | return true
|
---|
168 | }
|
---|
169 | return header.SourceAddr.String() == otherHeader.SourceAddr.String() &&
|
---|
170 | header.DestinationAddr.String() == otherHeader.DestinationAddr.String()
|
---|
171 | }
|
---|
172 |
|
---|
173 | // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer.
|
---|
174 | func (header *Header) WriteTo(w io.Writer) (int64, error) {
|
---|
175 | buf, err := header.Format()
|
---|
176 | if err != nil {
|
---|
177 | return 0, err
|
---|
178 | }
|
---|
179 |
|
---|
180 | return bytes.NewBuffer(buf).WriteTo(w)
|
---|
181 | }
|
---|
182 |
|
---|
183 | // Format renders a proxy protocol header in a format to write over the wire.
|
---|
184 | func (header *Header) Format() ([]byte, error) {
|
---|
185 | switch header.Version {
|
---|
186 | case 1:
|
---|
187 | return header.formatVersion1()
|
---|
188 | case 2:
|
---|
189 | return header.formatVersion2()
|
---|
190 | default:
|
---|
191 | return nil, ErrUnknownProxyProtocolVersion
|
---|
192 | }
|
---|
193 | }
|
---|
194 |
|
---|
195 | // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol.
|
---|
196 | func (header *Header) TLVs() ([]TLV, error) {
|
---|
197 | return SplitTLVs(header.rawTLVs)
|
---|
198 | }
|
---|
199 |
|
---|
200 | // SetTLVs sets the TLVs stored in this header. This method replaces any
|
---|
201 | // previous TLV.
|
---|
202 | func (header *Header) SetTLVs(tlvs []TLV) error {
|
---|
203 | raw, err := JoinTLVs(tlvs)
|
---|
204 | if err != nil {
|
---|
205 | return err
|
---|
206 | }
|
---|
207 | header.rawTLVs = raw
|
---|
208 | return nil
|
---|
209 | }
|
---|
210 |
|
---|
211 | // Read identifies the proxy protocol version and reads the remaining of
|
---|
212 | // the header, accordingly.
|
---|
213 | //
|
---|
214 | // If proxy protocol header signature is not present, the reader buffer remains untouched
|
---|
215 | // and is safe for reading outside of this code.
|
---|
216 | //
|
---|
217 | // If proxy protocol header signature is present but an error is raised while processing
|
---|
218 | // the remaining header, assume the reader buffer to be in a corrupt state.
|
---|
219 | // Also, this operation will block until enough bytes are available for peeking.
|
---|
220 | func Read(reader *bufio.Reader) (*Header, error) {
|
---|
221 | // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone.
|
---|
222 | b1, err := reader.Peek(1)
|
---|
223 | if err != nil {
|
---|
224 | if err == io.EOF {
|
---|
225 | return nil, ErrNoProxyProtocol
|
---|
226 | }
|
---|
227 | return nil, err
|
---|
228 | }
|
---|
229 |
|
---|
230 | if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) {
|
---|
231 | signature, err := reader.Peek(5)
|
---|
232 | if err != nil {
|
---|
233 | if err == io.EOF {
|
---|
234 | return nil, ErrNoProxyProtocol
|
---|
235 | }
|
---|
236 | return nil, err
|
---|
237 | }
|
---|
238 | if bytes.Equal(signature[:5], SIGV1) {
|
---|
239 | return parseVersion1(reader)
|
---|
240 | }
|
---|
241 |
|
---|
242 | signature, err = reader.Peek(12)
|
---|
243 | if err != nil {
|
---|
244 | if err == io.EOF {
|
---|
245 | return nil, ErrNoProxyProtocol
|
---|
246 | }
|
---|
247 | return nil, err
|
---|
248 | }
|
---|
249 | if bytes.Equal(signature[:12], SIGV2) {
|
---|
250 | return parseVersion2(reader)
|
---|
251 | }
|
---|
252 | }
|
---|
253 |
|
---|
254 | return nil, ErrNoProxyProtocol
|
---|
255 | }
|
---|
256 |
|
---|
257 | // ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed
|
---|
258 | // there's no proxy protocol header.
|
---|
259 | func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) {
|
---|
260 | type header struct {
|
---|
261 | h *Header
|
---|
262 | e error
|
---|
263 | }
|
---|
264 | read := make(chan *header, 1)
|
---|
265 |
|
---|
266 | go func() {
|
---|
267 | h := &header{}
|
---|
268 | h.h, h.e = Read(reader)
|
---|
269 | read <- h
|
---|
270 | }()
|
---|
271 |
|
---|
272 | timer := time.NewTimer(timeout)
|
---|
273 | select {
|
---|
274 | case result := <-read:
|
---|
275 | timer.Stop()
|
---|
276 | return result.h, result.e
|
---|
277 | case <-timer.C:
|
---|
278 | return nil, ErrNoProxyProtocol
|
---|
279 | }
|
---|
280 | }
|
---|