source: code/trunk/vendor/github.com/pires/go-proxyproto/protocol.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 9.1 KB
RevLine 
[822]1package proxyproto
2
import (
	"bufio"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)
11
// DefaultReadHeaderTimeout is how long header processing waits for the
// header to be read from the wire when Listener.ReadHeaderTimeout is not set
// (i.e. is zero).
//
// It is kept as a package-level variable so it is easy to find and override
// from program code, e.g. by assigning
//
//	proxyproto.DefaultReadHeaderTimeout = time.Second
//
// during start-up, before the listener begins accepting connections.
// It cannot be set with `go build -ldflags -X`: the linker's -X flag only
// applies to string variables, and this is a time.Duration.
var DefaultReadHeaderTimeout = 10 * time.Second
17
18// Listener is used to wrap an underlying listener,
19// whose connections may be using the HAProxy Proxy Protocol.
20// If the connection is using the protocol, the RemoteAddr() will return
21// the correct client address. ReadHeaderTimeout will be applied to all
22// connections in order to prevent blocking operations. If no ReadHeaderTimeout
23// is set, a default of 200ms will be used. This can be disabled by setting the
24// timeout to < 0.
25type Listener struct {
26 Listener net.Listener
27 Policy PolicyFunc
28 ValidateHeader Validator
29 ReadHeaderTimeout time.Duration
30}
31
32// Conn is used to wrap and underlying connection which
33// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
34// return the address of the client instead of the proxy address. Each connection
35// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
36type Conn struct {
37 readDeadline atomic.Value // time.Time
38 once sync.Once
39 readErr error
40 conn net.Conn
41 Validate Validator
42 bufReader *bufio.Reader
43 header *Header
44 ProxyHeaderPolicy Policy
45 readHeaderTimeout time.Duration
46}
47
48// Validator receives a header and decides whether it is a valid one
49// In case the header is not deemed valid it should return an error.
50type Validator func(*Header) error
51
52// ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn()
53func ValidateHeader(v Validator) func(*Conn) {
54 return func(c *Conn) {
55 if v != nil {
56 c.Validate = v
57 }
58 }
59}
60
61// Accept waits for and returns the next connection to the listener.
62func (p *Listener) Accept() (net.Conn, error) {
63 // Get the underlying connection
64 conn, err := p.Listener.Accept()
65 if err != nil {
66 return nil, err
67 }
68
69 proxyHeaderPolicy := USE
70 if p.Policy != nil {
71 proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
72 if err != nil {
73 // can't decide the policy, we can't accept the connection
74 conn.Close()
75 return nil, err
76 }
77 // Handle a connection as a regular one
78 if proxyHeaderPolicy == SKIP {
79 return conn, nil
80 }
81 }
82
83 newConn := NewConn(
84 conn,
85 WithPolicy(proxyHeaderPolicy),
86 ValidateHeader(p.ValidateHeader),
87 )
88
89 // If the ReadHeaderTimeout for the listener is unset, use the default timeout.
90 if p.ReadHeaderTimeout == 0 {
91 p.ReadHeaderTimeout = DefaultReadHeaderTimeout
92 }
93
94 // Set the readHeaderTimeout of the new conn to the value of the listener
95 newConn.readHeaderTimeout = p.ReadHeaderTimeout
96
97 return newConn, nil
98}
99
100// Close closes the underlying listener.
101func (p *Listener) Close() error {
102 return p.Listener.Close()
103}
104
105// Addr returns the underlying listener's network address.
106func (p *Listener) Addr() net.Addr {
107 return p.Listener.Addr()
108}
109
110// NewConn is used to wrap a net.Conn that may be speaking
111// the proxy protocol into a proxyproto.Conn
112func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
113 pConn := &Conn{
114 bufReader: bufio.NewReader(conn),
115 conn: conn,
116 }
117
118 for _, opt := range opts {
119 opt(pConn)
120 }
121
122 return pConn
123}
124
125// Read is check for the proxy protocol header when doing
126// the initial scan. If there is an error parsing the header,
127// it is returned and the socket is closed.
128func (p *Conn) Read(b []byte) (int, error) {
129 p.once.Do(func() {
130 p.readErr = p.readHeader()
131 })
132 if p.readErr != nil {
133 return 0, p.readErr
134 }
135
136 return p.bufReader.Read(b)
137}
138
139// Write wraps original conn.Write
140func (p *Conn) Write(b []byte) (int, error) {
141 return p.conn.Write(b)
142}
143
144// Close wraps original conn.Close
145func (p *Conn) Close() error {
146 return p.conn.Close()
147}
148
149// ProxyHeader returns the proxy protocol header, if any. If an error occurs
150// while reading the proxy header, nil is returned.
151func (p *Conn) ProxyHeader() *Header {
152 p.once.Do(func() { p.readErr = p.readHeader() })
153 return p.header
154}
155
156// LocalAddr returns the address of the server if the proxy
157// protocol is being used, otherwise just returns the address of
158// the socket server. In case an error happens on reading the
159// proxy header the original LocalAddr is returned, not the one
160// from the proxy header even if the proxy header itself is
161// syntactically correct.
162func (p *Conn) LocalAddr() net.Addr {
163 p.once.Do(func() { p.readErr = p.readHeader() })
164 if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
165 return p.conn.LocalAddr()
166 }
167
168 return p.header.DestinationAddr
169}
170
171// RemoteAddr returns the address of the client if the proxy
172// protocol is being used, otherwise just returns the address of
173// the socket peer. In case an error happens on reading the
174// proxy header the original RemoteAddr is returned, not the one
175// from the proxy header even if the proxy header itself is
176// syntactically correct.
177func (p *Conn) RemoteAddr() net.Addr {
178 p.once.Do(func() { p.readErr = p.readHeader() })
179 if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
180 return p.conn.RemoteAddr()
181 }
182
183 return p.header.SourceAddr
184}
185
186// Raw returns the underlying connection which can be casted to
187// a concrete type, allowing access to specialized functions.
188//
189// Use this ONLY if you know exactly what you are doing.
190func (p *Conn) Raw() net.Conn {
191 return p.conn
192}
193
194// TCPConn returns the underlying TCP connection,
195// allowing access to specialized functions.
196//
197// Use this ONLY if you know exactly what you are doing.
198func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
199 conn, ok = p.conn.(*net.TCPConn)
200 return
201}
202
203// UnixConn returns the underlying Unix socket connection,
204// allowing access to specialized functions.
205//
206// Use this ONLY if you know exactly what you are doing.
207func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
208 conn, ok = p.conn.(*net.UnixConn)
209 return
210}
211
212// UDPConn returns the underlying UDP connection,
213// allowing access to specialized functions.
214//
215// Use this ONLY if you know exactly what you are doing.
216func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
217 conn, ok = p.conn.(*net.UDPConn)
218 return
219}
220
221// SetDeadline wraps original conn.SetDeadline
222func (p *Conn) SetDeadline(t time.Time) error {
223 p.readDeadline.Store(t)
224 return p.conn.SetDeadline(t)
225}
226
227// SetReadDeadline wraps original conn.SetReadDeadline
228func (p *Conn) SetReadDeadline(t time.Time) error {
229 // Set a local var that tells us the desired deadline. This is
230 // needed in order to reset the read deadline to the one that is
231 // desired by the user, rather than an empty deadline.
232 p.readDeadline.Store(t)
233 return p.conn.SetReadDeadline(t)
234}
235
236// SetWriteDeadline wraps original conn.SetWriteDeadline
237func (p *Conn) SetWriteDeadline(t time.Time) error {
238 return p.conn.SetWriteDeadline(t)
239}
240
241func (p *Conn) readHeader() error {
242 // If the connection's readHeaderTimeout is more than 0,
243 // push our deadline back to now plus the timeout. This should only
244 // run on the connection, as we don't want to override the previous
245 // read deadline the user may have used.
246 if p.readHeaderTimeout > 0 {
247 if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil {
248 return err
249 }
250 }
251
252 header, err := Read(p.bufReader)
253
254 // If the connection's readHeaderTimeout is more than 0, undo the change to the
255 // deadline that we made above. Because we retain the readDeadline as part of our
256 // SetReadDeadline override, we know the user's desired deadline so we use that.
257 // Therefore, we check whether the error is a net.Timeout and if it is, we decide
258 // the proxy proto does not exist and set the error accordingly.
259 if p.readHeaderTimeout > 0 {
260 t := p.readDeadline.Load()
261 if t == nil {
262 t = time.Time{}
263 }
264 if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil {
265 return err
266 }
267 if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
268 err = ErrNoProxyProtocol
269 }
270 }
271
272 // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
273 // let's act as if there was no error when PROXY protocol is not present.
274 if err == ErrNoProxyProtocol {
275 // but not if it is required that the connection has one
276 if p.ProxyHeaderPolicy == REQUIRE {
277 return err
278 }
279
280 return nil
281 }
282
283 // proxy protocol header was found
284 if err == nil && header != nil {
285 switch p.ProxyHeaderPolicy {
286 case REJECT:
287 // this connection is not allowed to send one
288 return ErrSuperfluousProxyHeader
289 case USE, REQUIRE:
290 if p.Validate != nil {
291 err = p.Validate(header)
292 if err != nil {
293 return err
294 }
295 }
296
297 p.header = header
298 }
299 }
300
301 return err
302}
303
304// ReadFrom implements the io.ReaderFrom ReadFrom method
305func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
306 if rf, ok := p.conn.(io.ReaderFrom); ok {
307 return rf.ReadFrom(r)
308 }
309 return io.Copy(p.conn, r)
310}
311
312// WriteTo implements io.WriterTo
313func (p *Conn) WriteTo(w io.Writer) (int64, error) {
314 p.once.Do(func() { p.readErr = p.readHeader() })
315 if p.readErr != nil {
316 return 0, p.readErr
317 }
318 return p.bufReader.WriteTo(w)
319}
Note: See TracBrowser for help on using the repository browser.