1 | package proxyproto
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "bufio"
|
---|
5 | "bytes"
|
---|
6 | "encoding/binary"
|
---|
7 | "errors"
|
---|
8 | "io"
|
---|
9 | "net"
|
---|
10 | )
|
---|
11 |
|
---|
// Sizes, in bytes, of the fixed address block that follows the length field
// of a v2 header for each address family (UNSPEC: none; IPv4: 2x4 address
// bytes + 2x2 port bytes; IPv6: 2x16 + 2x2; UNIX: 2x108 path bytes), plus
// their pre-encoded big-endian wire forms.
var (
	lengthUnspec = uint16(0)
	lengthV4     = uint16(12)
	lengthV6     = uint16(36)
	lengthUnix   = uint16(216)

	lengthUnspecBytes = func() []byte {
		var b [2]byte
		binary.BigEndian.PutUint16(b[:], lengthUnspec)
		return b[:]
	}()
	lengthV4Bytes = func() []byte {
		var b [2]byte
		binary.BigEndian.PutUint16(b[:], lengthV4)
		return b[:]
	}()
	lengthV6Bytes = func() []byte {
		var b [2]byte
		binary.BigEndian.PutUint16(b[:], lengthV6)
		return b[:]
	}()
	lengthUnixBytes = func() []byte {
		var b [2]byte
		binary.BigEndian.PutUint16(b[:], lengthUnix)
		return b[:]
	}()

	// errUint16Overflow is returned by addTLVLen when a length would not fit
	// in the 16-bit length field.
	errUint16Overflow = errors.New("proxyproto: uint16 overflow")
)
|
---|
39 |
|
---|
// _ports is a source/destination port pair, source first. It is embedded in
// _addr6, which is decoded with binary.Read, so field order and sizes define
// its byte layout.
type _ports struct {
	SrcPort, DstPort uint16
}
|
---|
44 |
|
---|
// _addr4 is the 12-byte IPv4 address block of a v2 header as read by
// binary.Read: source and destination IPv4 addresses, then source and
// destination ports. Field order is the wire order.
type _addr4 struct {
	Src, Dst         [4]byte
	SrcPort, DstPort uint16
}
|
---|
51 |
|
---|
52 | type _addr6 struct {
|
---|
53 | Src [16]byte
|
---|
54 | Dst [16]byte
|
---|
55 | _ports
|
---|
56 | }
|
---|
57 |
|
---|
// _addrUnix is the 216-byte UNIX address block of a v2 header: two
// fixed-width socket path fields, source then destination. Unused trailing
// bytes are NUL (see formatUnixName / parseUnixName).
type _addrUnix struct {
	Src, Dst [108]byte
}
|
---|
62 |
|
---|
63 | func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
|
---|
64 | // Skip first 12 bytes (signature)
|
---|
65 | for i := 0; i < 12; i++ {
|
---|
66 | if _, err = reader.ReadByte(); err != nil {
|
---|
67 | return nil, ErrCantReadProtocolVersionAndCommand
|
---|
68 | }
|
---|
69 | }
|
---|
70 |
|
---|
71 | header = new(Header)
|
---|
72 | header.Version = 2
|
---|
73 |
|
---|
74 | // Read the 13th byte, protocol version and command
|
---|
75 | b13, err := reader.ReadByte()
|
---|
76 | if err != nil {
|
---|
77 | return nil, ErrCantReadProtocolVersionAndCommand
|
---|
78 | }
|
---|
79 | header.Command = ProtocolVersionAndCommand(b13)
|
---|
80 | if _, ok := supportedCommand[header.Command]; !ok {
|
---|
81 | return nil, ErrUnsupportedProtocolVersionAndCommand
|
---|
82 | }
|
---|
83 |
|
---|
84 | // Read the 14th byte, address family and protocol
|
---|
85 | b14, err := reader.ReadByte()
|
---|
86 | if err != nil {
|
---|
87 | return nil, ErrCantReadAddressFamilyAndProtocol
|
---|
88 | }
|
---|
89 | header.TransportProtocol = AddressFamilyAndProtocol(b14)
|
---|
90 | // UNSPEC is only supported when LOCAL is set.
|
---|
91 | if header.TransportProtocol == UNSPEC && header.Command != LOCAL {
|
---|
92 | return nil, ErrUnsupportedAddressFamilyAndProtocol
|
---|
93 | }
|
---|
94 |
|
---|
95 | // Make sure there are bytes available as specified in length
|
---|
96 | var length uint16
|
---|
97 | if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
|
---|
98 | return nil, ErrCantReadLength
|
---|
99 | }
|
---|
100 | if !header.validateLength(length) {
|
---|
101 | return nil, ErrInvalidLength
|
---|
102 | }
|
---|
103 |
|
---|
104 | // Return early if the length is zero, which means that
|
---|
105 | // there's no address information and TLVs present for UNSPEC.
|
---|
106 | if length == 0 {
|
---|
107 | return header, nil
|
---|
108 | }
|
---|
109 |
|
---|
110 | if _, err := reader.Peek(int(length)); err != nil {
|
---|
111 | return nil, ErrInvalidLength
|
---|
112 | }
|
---|
113 |
|
---|
114 | // Length-limited reader for payload section
|
---|
115 | payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader)
|
---|
116 |
|
---|
117 | // Read addresses and ports for protocols other than UNSPEC.
|
---|
118 | // Ignore address information for UNSPEC, and skip straight to read TLVs,
|
---|
119 | // since the length is greater than zero.
|
---|
120 | if header.TransportProtocol != UNSPEC {
|
---|
121 | if header.TransportProtocol.IsIPv4() {
|
---|
122 | var addr _addr4
|
---|
123 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
---|
124 | return nil, ErrInvalidAddress
|
---|
125 | }
|
---|
126 | header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
|
---|
127 | header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
|
---|
128 | } else if header.TransportProtocol.IsIPv6() {
|
---|
129 | var addr _addr6
|
---|
130 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
---|
131 | return nil, ErrInvalidAddress
|
---|
132 | }
|
---|
133 | header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
|
---|
134 | header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
|
---|
135 | } else if header.TransportProtocol.IsUnix() {
|
---|
136 | var addr _addrUnix
|
---|
137 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
|
---|
138 | return nil, ErrInvalidAddress
|
---|
139 | }
|
---|
140 |
|
---|
141 | network := "unix"
|
---|
142 | if header.TransportProtocol.IsDatagram() {
|
---|
143 | network = "unixgram"
|
---|
144 | }
|
---|
145 |
|
---|
146 | header.SourceAddr = &net.UnixAddr{
|
---|
147 | Net: network,
|
---|
148 | Name: parseUnixName(addr.Src[:]),
|
---|
149 | }
|
---|
150 | header.DestinationAddr = &net.UnixAddr{
|
---|
151 | Net: network,
|
---|
152 | Name: parseUnixName(addr.Dst[:]),
|
---|
153 | }
|
---|
154 | }
|
---|
155 | }
|
---|
156 |
|
---|
157 | // Copy bytes for optional Type-Length-Value vector
|
---|
158 | header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice
|
---|
159 | if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF {
|
---|
160 | return nil, err
|
---|
161 | }
|
---|
162 |
|
---|
163 | return header, nil
|
---|
164 | }
|
---|
165 |
|
---|
166 | func (header *Header) formatVersion2() ([]byte, error) {
|
---|
167 | var buf bytes.Buffer
|
---|
168 | buf.Write(SIGV2)
|
---|
169 | buf.WriteByte(header.Command.toByte())
|
---|
170 | buf.WriteByte(header.TransportProtocol.toByte())
|
---|
171 | if header.TransportProtocol.IsUnspec() {
|
---|
172 | // For UNSPEC, write no addresses and ports but only TLVs if they are present
|
---|
173 | hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs))
|
---|
174 | if err != nil {
|
---|
175 | return nil, err
|
---|
176 | }
|
---|
177 | buf.Write(hdrLen)
|
---|
178 | } else {
|
---|
179 | var addrSrc, addrDst []byte
|
---|
180 | if header.TransportProtocol.IsIPv4() {
|
---|
181 | hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs))
|
---|
182 | if err != nil {
|
---|
183 | return nil, err
|
---|
184 | }
|
---|
185 | buf.Write(hdrLen)
|
---|
186 | sourceIP, destIP, _ := header.IPs()
|
---|
187 | addrSrc = sourceIP.To4()
|
---|
188 | addrDst = destIP.To4()
|
---|
189 | } else if header.TransportProtocol.IsIPv6() {
|
---|
190 | hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs))
|
---|
191 | if err != nil {
|
---|
192 | return nil, err
|
---|
193 | }
|
---|
194 | buf.Write(hdrLen)
|
---|
195 | sourceIP, destIP, _ := header.IPs()
|
---|
196 | addrSrc = sourceIP.To16()
|
---|
197 | addrDst = destIP.To16()
|
---|
198 | } else if header.TransportProtocol.IsUnix() {
|
---|
199 | buf.Write(lengthUnixBytes)
|
---|
200 | sourceAddr, destAddr, ok := header.UnixAddrs()
|
---|
201 | if !ok {
|
---|
202 | return nil, ErrInvalidAddress
|
---|
203 | }
|
---|
204 | addrSrc = formatUnixName(sourceAddr.Name)
|
---|
205 | addrDst = formatUnixName(destAddr.Name)
|
---|
206 | }
|
---|
207 |
|
---|
208 | if addrSrc == nil || addrDst == nil {
|
---|
209 | return nil, ErrInvalidAddress
|
---|
210 | }
|
---|
211 | buf.Write(addrSrc)
|
---|
212 | buf.Write(addrDst)
|
---|
213 |
|
---|
214 | if sourcePort, destPort, ok := header.Ports(); ok {
|
---|
215 | portBytes := make([]byte, 2)
|
---|
216 |
|
---|
217 | binary.BigEndian.PutUint16(portBytes, uint16(sourcePort))
|
---|
218 | buf.Write(portBytes)
|
---|
219 |
|
---|
220 | binary.BigEndian.PutUint16(portBytes, uint16(destPort))
|
---|
221 | buf.Write(portBytes)
|
---|
222 | }
|
---|
223 | }
|
---|
224 |
|
---|
225 | if len(header.rawTLVs) > 0 {
|
---|
226 | buf.Write(header.rawTLVs)
|
---|
227 | }
|
---|
228 |
|
---|
229 | return buf.Bytes(), nil
|
---|
230 | }
|
---|
231 |
|
---|
232 | func (header *Header) validateLength(length uint16) bool {
|
---|
233 | if header.TransportProtocol.IsIPv4() {
|
---|
234 | return length >= lengthV4
|
---|
235 | } else if header.TransportProtocol.IsIPv6() {
|
---|
236 | return length >= lengthV6
|
---|
237 | } else if header.TransportProtocol.IsUnix() {
|
---|
238 | return length >= lengthUnix
|
---|
239 | } else if header.TransportProtocol.IsUnspec() {
|
---|
240 | return length >= lengthUnspec
|
---|
241 | }
|
---|
242 | return false
|
---|
243 | }
|
---|
244 |
|
---|
245 | // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow.
|
---|
246 | func addTLVLen(cur []byte, tlvLen int) ([]byte, error) {
|
---|
247 | if tlvLen == 0 {
|
---|
248 | return cur, nil
|
---|
249 | }
|
---|
250 | curLen := binary.BigEndian.Uint16(cur)
|
---|
251 | newLen := int(curLen) + tlvLen
|
---|
252 | if newLen >= 1<<16 {
|
---|
253 | return nil, errUint16Overflow
|
---|
254 | }
|
---|
255 | a := make([]byte, 2)
|
---|
256 | binary.BigEndian.PutUint16(a, uint16(newLen))
|
---|
257 | return a, nil
|
---|
258 | }
|
---|
259 |
|
---|
260 | func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr {
|
---|
261 | if transport.IsStream() {
|
---|
262 | return &net.TCPAddr{IP: ip, Port: int(port)}
|
---|
263 | } else if transport.IsDatagram() {
|
---|
264 | return &net.UDPAddr{IP: ip, Port: int(port)}
|
---|
265 | } else {
|
---|
266 | return nil
|
---|
267 | }
|
---|
268 | }
|
---|
269 |
|
---|
// parseUnixName converts a fixed-width, NUL-padded socket path field into a
// string. It keeps the bytes before the first NUL, or all of b when b
// contains no NUL.
func parseUnixName(b []byte) string {
	if end := bytes.IndexByte(b, 0); end >= 0 {
		b = b[:end]
	}
	return string(b)
}
|
---|
277 |
|
---|
278 | func formatUnixName(name string) []byte {
|
---|
279 | n := int(lengthUnix) / 2
|
---|
280 | if len(name) >= n {
|
---|
281 | return []byte(name[:n])
|
---|
282 | }
|
---|
283 | pad := make([]byte, n-len(name))
|
---|
284 | return append([]byte(name), pad...)
|
---|
285 | }
|
---|