source: code/trunk/vendor/github.com/pires/go-proxyproto/v2.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 7.3 KB
Rev Line
[822]1package proxyproto
2
3import (
4 "bufio"
5 "bytes"
6 "encoding/binary"
7 "errors"
8 "io"
9 "net"
10)
11
var (
	// Size in bytes of the fixed address block that follows the 16-byte v2
	// preamble (12-byte signature, command, family/protocol, 2-byte length),
	// per address family and excluding any TLVs:
	//   IPv4: 2 x 4-byte address + 2 x 2-byte port
	//   IPv6: 2 x 16-byte address + 2 x 2-byte port
	//   Unix: 2 x 108-byte path
	lengthUnspec = uint16(0)
	lengthV4     = uint16(12)
	lengthV6     = uint16(36)
	lengthUnix   = uint16(216)

	// Big-endian 2-byte encodings of the sizes above, computed once when the
	// package is initialised.
	lengthUnspecBytes = func() []byte {
		var enc [2]byte
		binary.BigEndian.PutUint16(enc[:], lengthUnspec)
		return enc[:]
	}()
	lengthV4Bytes = func() []byte {
		var enc [2]byte
		binary.BigEndian.PutUint16(enc[:], lengthV4)
		return enc[:]
	}()
	lengthV6Bytes = func() []byte {
		var enc [2]byte
		binary.BigEndian.PutUint16(enc[:], lengthV6)
		return enc[:]
	}()
	lengthUnixBytes = func() []byte {
		var enc [2]byte
		binary.BigEndian.PutUint16(enc[:], lengthUnix)
		return enc[:]
	}()

	// errUint16Overflow signals that an address block plus TLVs no longer
	// fits in the 16-bit length field of a v2 header.
	errUint16Overflow = errors.New("proxyproto: uint16 overflow")
)
39
// _ports is a source/destination port pair, each a 16-bit value. It is
// embedded at the end of _addr6. The field order is the on-wire order:
// binary.Read fills fields sequentially, so it must not be rearranged.
type _ports struct {
	SrcPort, DstPort uint16
}

// _addr4 is the 12-byte IPv4 address block of a v2 header: source address,
// destination address, source port, destination port, decoded in that
// order by binary.Read.
type _addr4 struct {
	Src, Dst         [4]byte
	SrcPort, DstPort uint16
}

// _addr6 is the 36-byte IPv6 address block: two 16-byte addresses followed
// by the port pair (promoted from the embedded _ports).
type _addr6 struct {
	Src, Dst [16]byte
	_ports
}

// _addrUnix is the 216-byte Unix address block: two 108-byte path fields,
// each zero-padded (see parseUnixName / formatUnixName).
type _addrUnix struct {
	Src, Dst [108]byte
}
62
63func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
64 // Skip first 12 bytes (signature)
65 for i := 0; i < 12; i++ {
66 if _, err = reader.ReadByte(); err != nil {
67 return nil, ErrCantReadProtocolVersionAndCommand
68 }
69 }
70
71 header = new(Header)
72 header.Version = 2
73
74 // Read the 13th byte, protocol version and command
75 b13, err := reader.ReadByte()
76 if err != nil {
77 return nil, ErrCantReadProtocolVersionAndCommand
78 }
79 header.Command = ProtocolVersionAndCommand(b13)
80 if _, ok := supportedCommand[header.Command]; !ok {
81 return nil, ErrUnsupportedProtocolVersionAndCommand
82 }
83
84 // Read the 14th byte, address family and protocol
85 b14, err := reader.ReadByte()
86 if err != nil {
87 return nil, ErrCantReadAddressFamilyAndProtocol
88 }
89 header.TransportProtocol = AddressFamilyAndProtocol(b14)
90 // UNSPEC is only supported when LOCAL is set.
91 if header.TransportProtocol == UNSPEC && header.Command != LOCAL {
92 return nil, ErrUnsupportedAddressFamilyAndProtocol
93 }
94
95 // Make sure there are bytes available as specified in length
96 var length uint16
97 if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
98 return nil, ErrCantReadLength
99 }
100 if !header.validateLength(length) {
101 return nil, ErrInvalidLength
102 }
103
104 // Return early if the length is zero, which means that
105 // there's no address information and TLVs present for UNSPEC.
106 if length == 0 {
107 return header, nil
108 }
109
110 if _, err := reader.Peek(int(length)); err != nil {
111 return nil, ErrInvalidLength
112 }
113
114 // Length-limited reader for payload section
115 payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader)
116
117 // Read addresses and ports for protocols other than UNSPEC.
118 // Ignore address information for UNSPEC, and skip straight to read TLVs,
119 // since the length is greater than zero.
120 if header.TransportProtocol != UNSPEC {
121 if header.TransportProtocol.IsIPv4() {
122 var addr _addr4
123 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
124 return nil, ErrInvalidAddress
125 }
126 header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
127 header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
128 } else if header.TransportProtocol.IsIPv6() {
129 var addr _addr6
130 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
131 return nil, ErrInvalidAddress
132 }
133 header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
134 header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
135 } else if header.TransportProtocol.IsUnix() {
136 var addr _addrUnix
137 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
138 return nil, ErrInvalidAddress
139 }
140
141 network := "unix"
142 if header.TransportProtocol.IsDatagram() {
143 network = "unixgram"
144 }
145
146 header.SourceAddr = &net.UnixAddr{
147 Net: network,
148 Name: parseUnixName(addr.Src[:]),
149 }
150 header.DestinationAddr = &net.UnixAddr{
151 Net: network,
152 Name: parseUnixName(addr.Dst[:]),
153 }
154 }
155 }
156
157 // Copy bytes for optional Type-Length-Value vector
158 header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice
159 if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF {
160 return nil, err
161 }
162
163 return header, nil
164}
165
166func (header *Header) formatVersion2() ([]byte, error) {
167 var buf bytes.Buffer
168 buf.Write(SIGV2)
169 buf.WriteByte(header.Command.toByte())
170 buf.WriteByte(header.TransportProtocol.toByte())
171 if header.TransportProtocol.IsUnspec() {
172 // For UNSPEC, write no addresses and ports but only TLVs if they are present
173 hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs))
174 if err != nil {
175 return nil, err
176 }
177 buf.Write(hdrLen)
178 } else {
179 var addrSrc, addrDst []byte
180 if header.TransportProtocol.IsIPv4() {
181 hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs))
182 if err != nil {
183 return nil, err
184 }
185 buf.Write(hdrLen)
186 sourceIP, destIP, _ := header.IPs()
187 addrSrc = sourceIP.To4()
188 addrDst = destIP.To4()
189 } else if header.TransportProtocol.IsIPv6() {
190 hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs))
191 if err != nil {
192 return nil, err
193 }
194 buf.Write(hdrLen)
195 sourceIP, destIP, _ := header.IPs()
196 addrSrc = sourceIP.To16()
197 addrDst = destIP.To16()
198 } else if header.TransportProtocol.IsUnix() {
199 buf.Write(lengthUnixBytes)
200 sourceAddr, destAddr, ok := header.UnixAddrs()
201 if !ok {
202 return nil, ErrInvalidAddress
203 }
204 addrSrc = formatUnixName(sourceAddr.Name)
205 addrDst = formatUnixName(destAddr.Name)
206 }
207
208 if addrSrc == nil || addrDst == nil {
209 return nil, ErrInvalidAddress
210 }
211 buf.Write(addrSrc)
212 buf.Write(addrDst)
213
214 if sourcePort, destPort, ok := header.Ports(); ok {
215 portBytes := make([]byte, 2)
216
217 binary.BigEndian.PutUint16(portBytes, uint16(sourcePort))
218 buf.Write(portBytes)
219
220 binary.BigEndian.PutUint16(portBytes, uint16(destPort))
221 buf.Write(portBytes)
222 }
223 }
224
225 if len(header.rawTLVs) > 0 {
226 buf.Write(header.rawTLVs)
227 }
228
229 return buf.Bytes(), nil
230}
231
232func (header *Header) validateLength(length uint16) bool {
233 if header.TransportProtocol.IsIPv4() {
234 return length >= lengthV4
235 } else if header.TransportProtocol.IsIPv6() {
236 return length >= lengthV6
237 } else if header.TransportProtocol.IsUnix() {
238 return length >= lengthUnix
239 } else if header.TransportProtocol.IsUnspec() {
240 return length >= lengthUnspec
241 }
242 return false
243}
244
245// addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow.
246func addTLVLen(cur []byte, tlvLen int) ([]byte, error) {
247 if tlvLen == 0 {
248 return cur, nil
249 }
250 curLen := binary.BigEndian.Uint16(cur)
251 newLen := int(curLen) + tlvLen
252 if newLen >= 1<<16 {
253 return nil, errUint16Overflow
254 }
255 a := make([]byte, 2)
256 binary.BigEndian.PutUint16(a, uint16(newLen))
257 return a, nil
258}
259
260func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr {
261 if transport.IsStream() {
262 return &net.TCPAddr{IP: ip, Port: int(port)}
263 } else if transport.IsDatagram() {
264 return &net.UDPAddr{IP: ip, Port: int(port)}
265 } else {
266 return nil
267 }
268}
269
// parseUnixName converts a zero-padded socket path field into a string. It
// stops at the first NUL byte, or uses all of b when b contains none.
func parseUnixName(b []byte) string {
	if end := bytes.IndexByte(b, 0); end >= 0 {
		b = b[:end]
	}
	return string(b)
}
277
278func formatUnixName(name string) []byte {
279 n := int(lengthUnix) / 2
280 if len(name) >= n {
281 return []byte(name[:n])
282 }
283 pad := make([]byte, n-len(name))
284 return append([]byte(name), pad...)
285}
Note: See TracBrowser for help on using the repository browser.