source: code/trunk/config/config.go@ 417

Last change on this file since 417 was 371, checked in by contact, 5 years ago

config: make http-origin directive overwrite previous list

Let's be on the safe side and assume the user doesn't mean the union of
all directive values.

File size: 3.0 KB
Line 
1package config
2
3import (
4 "bufio"
5 "fmt"
6 "io"
7 "net"
8 "os"
9
10 "github.com/google/shlex"
11)
12
// IPSet is a list of IP networks that can be queried for membership.
type IPSet []*net.IPNet

// Contains reports whether ip falls inside at least one network of the set.
// A nil or empty set contains no addresses.
func (set IPSet) Contains(ip net.IP) bool {
	for i := range set {
		if set[i].Contains(ip) {
			return true
		}
	}
	return false
}
23
24// loopbackIPs contains the loopback networks 127.0.0.0/8 and ::1/128.
25var loopbackIPs = IPSet{
26 &net.IPNet{
27 IP: net.IP{127, 0, 0, 0},
28 Mask: net.CIDRMask(8, 32),
29 },
30 &net.IPNet{
31 IP: net.IPv6loopback,
32 Mask: net.CIDRMask(128, 128),
33 },
34}
35
// TLS holds the certificate and private key file paths given to the "tls"
// configuration directive, in that order.
type TLS struct {
	CertPath string
	KeyPath  string
}
39
40type Server struct {
41 Listen []string
42 Hostname string
43 TLS *TLS
44 SQLDriver string
45 SQLSource string
46 LogPath string
47 HTTPOrigins []string
48 AcceptProxyIPs IPSet
49}
50
51func Defaults() *Server {
52 hostname, err := os.Hostname()
53 if err != nil {
54 hostname = "localhost"
55 }
56 return &Server{
57 Hostname: hostname,
58 SQLDriver: "sqlite3",
59 SQLSource: "soju.db",
60 AcceptProxyIPs: loopbackIPs,
61 }
62}
63
64func Load(path string) (*Server, error) {
65 f, err := os.Open(path)
66 if err != nil {
67 return nil, err
68 }
69 defer f.Close()
70
71 return Parse(f)
72}
73
74func Parse(r io.Reader) (*Server, error) {
75 scanner := bufio.NewScanner(r)
76
77 var directives []directive
78 for scanner.Scan() {
79 words, err := shlex.Split(scanner.Text())
80 if err != nil {
81 return nil, fmt.Errorf("failed to parse config file: %v", err)
82 } else if len(words) == 0 {
83 continue
84 }
85
86 name, params := words[0], words[1:]
87 directives = append(directives, directive{name, params})
88 }
89 if err := scanner.Err(); err != nil {
90 return nil, fmt.Errorf("failed to read config file: %v", err)
91 }
92
93 srv := Defaults()
94 for _, d := range directives {
95 switch d.Name {
96 case "listen":
97 var uri string
98 if err := d.parseParams(&uri); err != nil {
99 return nil, err
100 }
101 srv.Listen = append(srv.Listen, uri)
102 case "hostname":
103 if err := d.parseParams(&srv.Hostname); err != nil {
104 return nil, err
105 }
106 case "tls":
107 tls := &TLS{}
108 if err := d.parseParams(&tls.CertPath, &tls.KeyPath); err != nil {
109 return nil, err
110 }
111 srv.TLS = tls
112 case "sql":
113 if err := d.parseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
114 return nil, err
115 }
116 case "log":
117 if err := d.parseParams(&srv.LogPath); err != nil {
118 return nil, err
119 }
120 case "http-origin":
121 srv.HTTPOrigins = d.Params
122 case "accept-proxy-ip":
123 srv.AcceptProxyIPs = nil
124 for _, s := range d.Params {
125 _, n, err := net.ParseCIDR(s)
126 if err != nil {
127 return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
128 }
129 srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n)
130 }
131 default:
132 return nil, fmt.Errorf("unknown directive %q", d.Name)
133 }
134 }
135
136 return srv, nil
137}
138
// directive is one tokenized configuration line: the directive name (the
// first word) followed by its parameters (the remaining words).
type directive struct {
	Name   string
	Params []string
}

// parseParams stores d's parameters, in order, into the strings pointed to
// by out. If the parameter count differs from len(out), it returns an error
// and leaves every output untouched.
func (d *directive) parseParams(out ...*string) error {
	if want, got := len(out), len(d.Params); want != got {
		return fmt.Errorf("directive %q has wrong number of parameters: expected %v, got %v", d.Name, want, got)
	}
	for i, dst := range out {
		*dst = d.Params[i]
	}
	return nil
}
Note: See TracBrowser for help on using the repository browser.