1 | package proxyproto
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "fmt"
|
---|
5 | "net"
|
---|
6 | "strings"
|
---|
7 | )
|
---|
8 |
|
---|
9 | // PolicyFunc can be used to decide whether to trust the PROXY info from
|
---|
10 | // upstream. If set, the connecting address is passed in as an argument.
|
---|
11 | //
|
---|
12 | // See below for the different policies.
|
---|
13 | //
|
---|
14 | // In case an error is returned the connection is denied.
|
---|
15 | type PolicyFunc func(upstream net.Addr) (Policy, error)
|
---|
16 |
|
---|
// Policy describes how a connection, and the address carried in its PROXY
// header, is treated.
type Policy int

const (
	// USE takes the address from the PROXY header.
	USE Policy = iota

	// IGNORE accepts the connection but disregards the address from the
	// PROXY header.
	IGNORE

	// REJECT refuses the connection when a PROXY header is sent.
	//
	// Note: only the first read on the connection returns an error when a
	// PROXY header is present; subsequent reads do not. Code using the
	// connection is responsible for handling that case properly.
	REJECT

	// REQUIRE demands that the connection sends a PROXY header and rejects
	// it when the header is absent.
	//
	// Note: only the first read on the connection returns an error when a
	// PROXY header is not present; subsequent reads do not. Code using the
	// connection is responsible for handling that case properly.
	REQUIRE

	// SKIP accepts a connection without requiring the PROXY header.
	//
	// Note: SkipProxyHeaderForCIDR shows an example usage.
	SKIP
)
40 |
|
---|
41 | // SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
|
---|
42 | // connection from a skipHeaderCIDR without requiring a PROXY header, e.g.
|
---|
43 | // Kubernetes pods local traffic. The def is a policy to use when an upstream
|
---|
44 | // address doesn't match the skipHeaderCIDR.
|
---|
45 | func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
|
---|
46 | return func(upstream net.Addr) (Policy, error) {
|
---|
47 | ip, err := ipFromAddr(upstream)
|
---|
48 | if err != nil {
|
---|
49 | return def, err
|
---|
50 | }
|
---|
51 |
|
---|
52 | if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) {
|
---|
53 | return SKIP, nil
|
---|
54 | }
|
---|
55 |
|
---|
56 | return def, nil
|
---|
57 | }
|
---|
58 | }
|
---|
59 |
|
---|
60 | // WithPolicy adds given policy to a connection when passed as option to NewConn()
|
---|
61 | func WithPolicy(p Policy) func(*Conn) {
|
---|
62 | return func(c *Conn) {
|
---|
63 | c.ProxyHeaderPolicy = p
|
---|
64 | }
|
---|
65 | }
|
---|
66 |
|
---|
67 | // LaxWhiteListPolicy returns a PolicyFunc which decides whether the
|
---|
68 | // upstream ip is allowed to send a proxy header based on a list of allowed
|
---|
69 | // IP addresses and IP ranges. In case upstream IP is not in list the proxy
|
---|
70 | // header will be ignored. If one of the provided IP addresses or IP ranges
|
---|
71 | // is invalid it will return an error instead of a PolicyFunc.
|
---|
72 | func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) {
|
---|
73 | allowFrom, err := parse(allowed)
|
---|
74 | if err != nil {
|
---|
75 | return nil, err
|
---|
76 | }
|
---|
77 |
|
---|
78 | return whitelistPolicy(allowFrom, IGNORE), nil
|
---|
79 | }
|
---|
80 |
|
---|
81 | // MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one
|
---|
82 | // of the provided IP addresses or IP ranges is invalid.
|
---|
83 | func MustLaxWhiteListPolicy(allowed []string) PolicyFunc {
|
---|
84 | pfunc, err := LaxWhiteListPolicy(allowed)
|
---|
85 | if err != nil {
|
---|
86 | panic(err)
|
---|
87 | }
|
---|
88 |
|
---|
89 | return pfunc
|
---|
90 | }
|
---|
91 |
|
---|
92 | // StrictWhiteListPolicy returns a PolicyFunc which decides whether the
|
---|
93 | // upstream ip is allowed to send a proxy header based on a list of allowed
|
---|
94 | // IP addresses and IP ranges. In case upstream IP is not in list reading on
|
---|
95 | // the connection will be refused on the first read. Please note: subsequent
|
---|
96 | // reads do not error. It is the task of the code using the connection to
|
---|
97 | // handle that case properly. If one of the provided IP addresses or IP
|
---|
98 | // ranges is invalid it will return an error instead of a PolicyFunc.
|
---|
99 | func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) {
|
---|
100 | allowFrom, err := parse(allowed)
|
---|
101 | if err != nil {
|
---|
102 | return nil, err
|
---|
103 | }
|
---|
104 |
|
---|
105 | return whitelistPolicy(allowFrom, REJECT), nil
|
---|
106 | }
|
---|
107 |
|
---|
108 | // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic
|
---|
109 | // if one of the provided IP addresses or IP ranges is invalid.
|
---|
110 | func MustStrictWhiteListPolicy(allowed []string) PolicyFunc {
|
---|
111 | pfunc, err := StrictWhiteListPolicy(allowed)
|
---|
112 | if err != nil {
|
---|
113 | panic(err)
|
---|
114 | }
|
---|
115 |
|
---|
116 | return pfunc
|
---|
117 | }
|
---|
118 |
|
---|
119 | func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc {
|
---|
120 | return func(upstream net.Addr) (Policy, error) {
|
---|
121 | upstreamIP, err := ipFromAddr(upstream)
|
---|
122 | if err != nil {
|
---|
123 | // something is wrong with the source IP, better reject the connection
|
---|
124 | return REJECT, err
|
---|
125 | }
|
---|
126 |
|
---|
127 | for _, allowFrom := range allowed {
|
---|
128 | if allowFrom(upstreamIP) {
|
---|
129 | return USE, nil
|
---|
130 | }
|
---|
131 | }
|
---|
132 |
|
---|
133 | return def, nil
|
---|
134 | }
|
---|
135 | }
|
---|
136 |
|
---|
// parse converts a list of IP addresses and CIDR ranges into matcher
// functions, one per entry and in the same order. An entry that has a "/"
// anywhere after its first character is parsed as a CIDR range and matches
// every IP inside that range; any other entry is parsed as a single IP
// address and matches only that address.
//
// It returns an error naming the first entry that cannot be parsed. For CIDR
// entries the underlying parse error is wrapped, so callers can inspect it
// with errors.Is / errors.As.
func parse(allowed []string) ([]func(net.IP) bool, error) {
	matchers := make([]func(net.IP) bool, 0, len(allowed))

	for _, entry := range allowed {
		if strings.LastIndex(entry, "/") > 0 {
			_, ipRange, err := net.ParseCIDR(entry)
			if err != nil {
				return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %w", entry, err)
			}

			matchers = append(matchers, ipRange.Contains)

			continue
		}

		ip := net.ParseIP(entry)
		if ip == nil {
			return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", entry)
		}

		matchers = append(matchers, ip.Equal)
	}

	return matchers, nil
}
159 |
|
---|
// ipFromAddr extracts the IP of upstream from its "host:port" string form.
//
// It returns an error when upstream is nil, when its string form cannot be
// split into host and port, or when the host part is not an IP address. An
// IPv6 zone suffix (as in "fe80::1%eth0") is dropped before parsing, because
// net.ParseIP does not accept zoned addresses.
func ipFromAddr(upstream net.Addr) (net.IP, error) {
	if upstream == nil {
		// Calling String() on a nil interface would panic.
		return nil, fmt.Errorf("proxyproto: upstream address is nil")
	}

	host, _, err := net.SplitHostPort(upstream.String())
	if err != nil {
		return nil, err
	}

	// The zone only scopes the address to an interface; it is not part of
	// the IP that the policies compare against.
	if zoneStart := strings.IndexByte(host, '%'); zoneStart >= 0 {
		host = host[:zoneStart]
	}

	upstreamIP := net.ParseIP(host)
	if upstreamIP == nil {
		return nil, fmt.Errorf("proxyproto: invalid IP address")
	}

	return upstreamIP, nil
}