source: code/trunk/config/config.go@ 636

Last changed in revision 636, checked in by contact 4 years ago.

Add bouncer MOTD

Closes: https://todo.sr.ht/~emersion/soju/137

File size: 2.8 KB
Line 
1package config
2
3import (
4 "fmt"
5 "net"
6 "os"
7 "strconv"
8
9 "git.sr.ht/~emersion/go-scfg"
10)
11
// IPSet is a list of IP networks.
type IPSet []*net.IPNet

// Contains reports whether ip falls inside at least one network of the
// set. An empty (or nil) set contains no addresses.
func (set IPSet) Contains(ip net.IP) bool {
	for i := range set {
		if set[i].Contains(ip) {
			return true
		}
	}
	return false
}
22
23// loopbackIPs contains the loopback networks 127.0.0.0/8 and ::1/128.
24var loopbackIPs = IPSet{
25 &net.IPNet{
26 IP: net.IP{127, 0, 0, 0},
27 Mask: net.CIDRMask(8, 32),
28 },
29 &net.IPNet{
30 IP: net.IPv6loopback,
31 Mask: net.CIDRMask(128, 128),
32 },
33}
34
// TLS holds the two file paths given to the "tls" directive.
type TLS struct {
	CertPath string // first parameter: certificate file path
	KeyPath  string // second parameter: key file path
}
38
39type Server struct {
40 Listen []string
41 Hostname string
42 TLS *TLS
43 MOTDPath string
44
45 SQLDriver string
46 SQLSource string
47 LogPath string
48
49 HTTPOrigins []string
50 AcceptProxyIPs IPSet
51
52 MaxUserNetworks int
53}
54
55func Defaults() *Server {
56 hostname, err := os.Hostname()
57 if err != nil {
58 hostname = "localhost"
59 }
60 return &Server{
61 Hostname: hostname,
62 SQLDriver: "sqlite3",
63 SQLSource: "soju.db",
64 MaxUserNetworks: -1,
65 }
66}
67
68func Load(path string) (*Server, error) {
69 cfg, err := scfg.Load(path)
70 if err != nil {
71 return nil, err
72 }
73 return parse(cfg)
74}
75
76func parse(cfg scfg.Block) (*Server, error) {
77 srv := Defaults()
78 for _, d := range cfg {
79 switch d.Name {
80 case "listen":
81 var uri string
82 if err := d.ParseParams(&uri); err != nil {
83 return nil, err
84 }
85 srv.Listen = append(srv.Listen, uri)
86 case "hostname":
87 if err := d.ParseParams(&srv.Hostname); err != nil {
88 return nil, err
89 }
90 case "tls":
91 tls := &TLS{}
92 if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
93 return nil, err
94 }
95 srv.TLS = tls
96 case "db":
97 if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
98 return nil, err
99 }
100 case "log":
101 var driver string
102 if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
103 return nil, err
104 }
105 if driver != "fs" {
106 return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
107 }
108 case "http-origin":
109 srv.HTTPOrigins = d.Params
110 case "accept-proxy-ip":
111 srv.AcceptProxyIPs = nil
112 for _, s := range d.Params {
113 if s == "localhost" {
114 srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...)
115 continue
116 }
117 _, n, err := net.ParseCIDR(s)
118 if err != nil {
119 return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
120 }
121 srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n)
122 }
123 case "max-user-networks":
124 var max string
125 if err := d.ParseParams(&max); err != nil {
126 return nil, err
127 }
128 var err error
129 if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
130 return nil, fmt.Errorf("directive %q: %v", d.Name, err)
131 }
132 case "motd":
133 if err := d.ParseParams(&srv.MOTDPath); err != nil {
134 return nil, err
135 }
136 default:
137 return nil, fmt.Errorf("unknown directive %q", d.Name)
138 }
139 }
140
141 return srv, nil
142}
Note: See TracBrowser for help on using the repository browser.