source: code/trunk/config/config.go@ 704

Last change on this file since 704 was 694, checked in by contact, 4 years ago

Add config option to globally disable multi-upstream mode

Closes: https://todo.sr.ht/~emersion/soju/122

File size: 3.2 KB
Line 
1package config
2
3import (
4 "fmt"
5 "net"
6 "os"
7 "strconv"
8
9 "git.sr.ht/~emersion/go-scfg"
10)
11
// IPSet is a list of IP networks that can be queried for membership.
type IPSet []*net.IPNet

// Contains reports whether ip falls within at least one network of the set.
// An empty or nil set contains no address.
func (set IPSet) Contains(ip net.IP) bool {
	for i := range set {
		if set[i].Contains(ip) {
			return true
		}
	}
	return false
}

// loopbackIPs holds the IPv4 loopback network 127.0.0.0/8 and the IPv6
// loopback address ::1/128.
var loopbackIPs = IPSet{
	{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 8*net.IPv4len)},
	{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 8*net.IPv6len)},
}
34
// TLS holds the two file paths given to the "tls" directive.
type TLS struct {
	// CertPath is the certificate file path (first "tls" parameter).
	CertPath string
	// KeyPath is the key file path (second "tls" parameter).
	KeyPath string
}
38
39type Server struct {
40 Listen []string
41 TLS *TLS
42 Hostname string
43 Title string
44 MOTDPath string
45
46 SQLDriver string
47 SQLSource string
48 LogPath string
49
50 HTTPOrigins []string
51 AcceptProxyIPs IPSet
52
53 MaxUserNetworks int
54 MultiUpstream bool
55}
56
57func Defaults() *Server {
58 hostname, err := os.Hostname()
59 if err != nil {
60 hostname = "localhost"
61 }
62 return &Server{
63 Hostname: hostname,
64 SQLDriver: "sqlite3",
65 SQLSource: "soju.db",
66 MaxUserNetworks: -1,
67 MultiUpstream: true,
68 }
69}
70
71func Load(path string) (*Server, error) {
72 cfg, err := scfg.Load(path)
73 if err != nil {
74 return nil, err
75 }
76 return parse(cfg)
77}
78
79func parse(cfg scfg.Block) (*Server, error) {
80 srv := Defaults()
81 for _, d := range cfg {
82 switch d.Name {
83 case "listen":
84 var uri string
85 if err := d.ParseParams(&uri); err != nil {
86 return nil, err
87 }
88 srv.Listen = append(srv.Listen, uri)
89 case "hostname":
90 if err := d.ParseParams(&srv.Hostname); err != nil {
91 return nil, err
92 }
93 case "title":
94 if err := d.ParseParams(&srv.Title); err != nil {
95 return nil, err
96 }
97 case "motd":
98 if err := d.ParseParams(&srv.MOTDPath); err != nil {
99 return nil, err
100 }
101 case "tls":
102 tls := &TLS{}
103 if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
104 return nil, err
105 }
106 srv.TLS = tls
107 case "db":
108 if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
109 return nil, err
110 }
111 case "log":
112 var driver string
113 if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
114 return nil, err
115 }
116 if driver != "fs" {
117 return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
118 }
119 case "http-origin":
120 srv.HTTPOrigins = d.Params
121 case "accept-proxy-ip":
122 srv.AcceptProxyIPs = nil
123 for _, s := range d.Params {
124 if s == "localhost" {
125 srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...)
126 continue
127 }
128 _, n, err := net.ParseCIDR(s)
129 if err != nil {
130 return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
131 }
132 srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n)
133 }
134 case "max-user-networks":
135 var max string
136 if err := d.ParseParams(&max); err != nil {
137 return nil, err
138 }
139 var err error
140 if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
141 return nil, fmt.Errorf("directive %q: %v", d.Name, err)
142 }
143 case "multi-upstream-mode":
144 var str string
145 if err := d.ParseParams(&str); err != nil {
146 return nil, err
147 }
148 v, err := strconv.ParseBool(str)
149 if err != nil {
150 return nil, fmt.Errorf("directive %q: %v", d.Name, err)
151 }
152 srv.MultiUpstream = v
153 default:
154 return nil, fmt.Errorf("unknown directive %q", d.Name)
155 }
156 }
157
158 return srv, nil
159}
Note: See TracBrowser for help on using the repository browser.