source: code/trunk/cmd/soju/main.go@ 712

Last change on this file since 712 was 712, checked in by contact, 4 years ago

Add Prometheus instrumentation for the database

File size: 8.4 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "io/ioutil"
9 "log"
10 "net"
11 "net/http"
12 "net/url"
13 "os"
14 "os/signal"
15 "strings"
16 "sync/atomic"
17 "syscall"
18 "time"
19
20 "github.com/pires/go-proxyproto"
21 "github.com/prometheus/client_golang/prometheus"
22 "github.com/prometheus/client_golang/prometheus/promhttp"
23
24 "git.sr.ht/~emersion/soju"
25 "git.sr.ht/~emersion/soju/config"
26)
27
// downstreamKeepAlive is the TCP keep-alive period set (via net.ListenConfig)
// on the listeners that accept downstream TCP connections.
const downstreamKeepAlive = time.Hour
30
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeatable command-line flag, in the order they were given.
type stringSliceFlag []string

// String formats the collected values as a bracketed, space-separated list.
//
// The flag package may call String on a zero-valued receiver (including a
// nil pointer) when printing defaults, so a nil receiver is treated as an
// empty list instead of being dereferenced.
func (v *stringSliceFlag) String() string {
	var values []string
	if v != nil {
		values = *v
	}
	return fmt.Sprint(values)
}

// Set appends one more occurrence of the flag. It never fails.
func (v *stringSliceFlag) Set(s string) error {
	*v = append(*v, s)
	return nil
}
41
// bumpOpenedFileLimit raises the soft RLIMIT_NOFILE limit of the current
// process to its hard limit, i.e. the maximum number of files the process
// may keep open. It returns an error if the limit cannot be read or written.
func bumpOpenedFileLimit() error {
	var lim syscall.Rlimit

	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim)
	if err != nil {
		return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
	}

	// Only the soft limit can be raised without privileges; lift it as far
	// as the hard limit allows.
	lim.Cur = lim.Max

	err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &lim)
	if err != nil {
		return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
	}
	return nil
}
53
// Process-wide state shared by flag parsing, loadConfig and main.
var (
	// tlsCert holds the *tls.Certificate most recently loaded by loadConfig.
	// It is an atomic.Value because a SIGHUP reload may replace it while
	// TLS listeners are reading it from their GetCertificate callback.
	tlsCert atomic.Value

	// configPath is the value of the -config flag; an empty string makes
	// loadConfig fall back to the built-in defaults.
	configPath string

	// debug is the value of the -debug flag.
	debug bool
)
60
61func loadConfig() (*config.Server, *soju.Config, error) {
62 var raw *config.Server
63 if configPath != "" {
64 var err error
65 raw, err = config.Load(configPath)
66 if err != nil {
67 return nil, nil, fmt.Errorf("failed to load config file: %v", err)
68 }
69 } else {
70 raw = config.Defaults()
71 }
72
73 var motd string
74 if raw.MOTDPath != "" {
75 b, err := ioutil.ReadFile(raw.MOTDPath)
76 if err != nil {
77 return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
78 }
79 motd = strings.TrimSuffix(string(b), "\n")
80 }
81
82 if raw.TLS != nil {
83 cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
84 if err != nil {
85 return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
86 }
87 tlsCert.Store(&cert)
88 }
89
90 cfg := &soju.Config{
91 Hostname: raw.Hostname,
92 Title: raw.Title,
93 LogPath: raw.LogPath,
94 HTTPOrigins: raw.HTTPOrigins,
95 AcceptProxyIPs: raw.AcceptProxyIPs,
96 MaxUserNetworks: raw.MaxUserNetworks,
97 MultiUpstream: raw.MultiUpstream,
98 UpstreamUserIPs: raw.UpstreamUserIPs,
99 Debug: debug,
100 MOTD: motd,
101 }
102 return raw, cfg, nil
103}
104
105func main() {
106 var listen []string
107 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
108 flag.StringVar(&configPath, "config", "", "path to configuration file")
109 flag.BoolVar(&debug, "debug", false, "enable debug logging")
110 flag.Parse()
111
112 cfg, serverCfg, err := loadConfig()
113 if err != nil {
114 log.Fatal(err)
115 }
116
117 cfg.Listen = append(cfg.Listen, listen...)
118 if len(cfg.Listen) == 0 {
119 cfg.Listen = []string{":6697"}
120 }
121
122 if err := bumpOpenedFileLimit(); err != nil {
123 log.Printf("failed to bump max number of opened files: %v", err)
124 }
125
126 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
127 if err != nil {
128 log.Fatalf("failed to open database: %v", err)
129 }
130
131 var tlsCfg *tls.Config
132 if cfg.TLS != nil {
133 tlsCfg = &tls.Config{
134 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
135 return tlsCert.Load().(*tls.Certificate), nil
136 },
137 }
138 }
139
140 srv := soju.NewServer(db)
141 srv.SetConfig(serverCfg)
142
143 for _, listen := range cfg.Listen {
144 listenURI := listen
145 if !strings.Contains(listenURI, ":/") {
146 // This is a raw domain name, make it an URL with an empty scheme
147 listenURI = "//" + listenURI
148 }
149 u, err := url.Parse(listenURI)
150 if err != nil {
151 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
152 }
153
154 switch u.Scheme {
155 case "ircs", "":
156 if tlsCfg == nil {
157 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
158 }
159 host := u.Host
160 if _, _, err := net.SplitHostPort(host); err != nil {
161 host = host + ":6697"
162 }
163 ircsTLSCfg := tlsCfg.Clone()
164 ircsTLSCfg.NextProtos = []string{"irc"}
165 lc := net.ListenConfig{
166 KeepAlive: downstreamKeepAlive,
167 }
168 l, err := lc.Listen(context.Background(), "tcp", host)
169 if err != nil {
170 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
171 }
172 ln := tls.NewListener(l, ircsTLSCfg)
173 ln = proxyProtoListener(ln, srv)
174 go func() {
175 if err := srv.Serve(ln); err != nil {
176 log.Printf("serving %q: %v", listen, err)
177 }
178 }()
179 case "irc+insecure":
180 host := u.Host
181 if _, _, err := net.SplitHostPort(host); err != nil {
182 host = host + ":6667"
183 }
184 lc := net.ListenConfig{
185 KeepAlive: downstreamKeepAlive,
186 }
187 ln, err := lc.Listen(context.Background(), "tcp", host)
188 if err != nil {
189 log.Fatalf("failed to start listener on %q: %v", listen, err)
190 }
191 ln = proxyProtoListener(ln, srv)
192 go func() {
193 if err := srv.Serve(ln); err != nil {
194 log.Printf("serving %q: %v", listen, err)
195 }
196 }()
197 case "unix":
198 ln, err := net.Listen("unix", u.Path)
199 if err != nil {
200 log.Fatalf("failed to start listener on %q: %v", listen, err)
201 }
202 ln = proxyProtoListener(ln, srv)
203 go func() {
204 if err := srv.Serve(ln); err != nil {
205 log.Printf("serving %q: %v", listen, err)
206 }
207 }()
208 case "wss":
209 if tlsCfg == nil {
210 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
211 }
212 addr := u.Host
213 if _, _, err := net.SplitHostPort(addr); err != nil {
214 addr = addr + ":https"
215 }
216 httpSrv := http.Server{
217 Addr: addr,
218 TLSConfig: tlsCfg,
219 Handler: srv,
220 }
221 go func() {
222 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
223 log.Fatalf("serving %q: %v", listen, err)
224 }
225 }()
226 case "ws+insecure":
227 addr := u.Host
228 if _, _, err := net.SplitHostPort(addr); err != nil {
229 addr = addr + ":http"
230 }
231 httpSrv := http.Server{
232 Addr: addr,
233 Handler: srv,
234 }
235 go func() {
236 if err := httpSrv.ListenAndServe(); err != nil {
237 log.Fatalf("serving %q: %v", listen, err)
238 }
239 }()
240 case "ident":
241 if srv.Identd == nil {
242 srv.Identd = soju.NewIdentd()
243 }
244
245 host := u.Host
246 if _, _, err := net.SplitHostPort(host); err != nil {
247 host = host + ":113"
248 }
249 ln, err := net.Listen("tcp", host)
250 if err != nil {
251 log.Fatalf("failed to start listener on %q: %v", listen, err)
252 }
253 ln = proxyProtoListener(ln, srv)
254 go func() {
255 if err := srv.Identd.Serve(ln); err != nil {
256 log.Printf("serving %q: %v", listen, err)
257 }
258 }()
259 case "http+prometheus":
260 if srv.MetricsRegistry == nil {
261 srv.MetricsRegistry = prometheus.DefaultRegisterer
262 }
263
264 // Only allow localhost as listening host for security reasons.
265 // Users can always explicitly setup reverse proxies if desirable.
266 hostname, _, err := net.SplitHostPort(u.Host)
267 if err != nil {
268 log.Fatalf("invalid host in URI %q: %v", listen, err)
269 } else if hostname != "localhost" {
270 log.Fatalf("Prometheus listening host must be localhost")
271 }
272
273 metricsHandler := promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
274 MaxRequestsInFlight: 10,
275 Timeout: 10 * time.Second,
276 EnableOpenMetrics: true,
277 })
278 metricsHandler = promhttp.InstrumentMetricHandler(prometheus.DefaultRegisterer, metricsHandler)
279
280 httpSrv := http.Server{
281 Addr: u.Host,
282 Handler: metricsHandler,
283 }
284 go func() {
285 if err := httpSrv.ListenAndServe(); err != nil {
286 log.Fatalf("serving %q: %v", listen, err)
287 }
288 }()
289 default:
290 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
291 }
292
293 log.Printf("server listening on %q", listen)
294 }
295
296 if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
297 srv.MetricsRegistry.MustRegister(db.MetricsCollector())
298 }
299
300 sigCh := make(chan os.Signal, 1)
301 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
302
303 if err := srv.Start(); err != nil {
304 log.Fatal(err)
305 }
306
307 for sig := range sigCh {
308 switch sig {
309 case syscall.SIGHUP:
310 log.Print("reloading configuration")
311 _, serverCfg, err := loadConfig()
312 if err != nil {
313 log.Printf("failed to reloading configuration: %v", err)
314 } else {
315 srv.SetConfig(serverCfg)
316 }
317 case syscall.SIGINT, syscall.SIGTERM:
318 log.Print("shutting down server")
319 srv.Shutdown()
320 return
321 }
322 }
323}
324
// proxyProtoListener wraps ln in a PROXY-protocol-aware listener.
//
// The address carried in a PROXY header is used (proxyproto.USE) only when
// the immediate peer is a TCP address whose IP is in the server's
// AcceptProxyIPs setting. Every other peer, including non-TCP ones such as
// Unix sockets, gets proxyproto.IGNORE. ReadHeaderTimeout is set to
// 5 seconds.
func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
	// srv.Config() is consulted each time the policy runs rather than
	// captured once, so the current server configuration is always used.
	policy := func(upstream net.Addr) (proxyproto.Policy, error) {
		peer, isTCP := upstream.(*net.TCPAddr)
		if isTCP && srv.Config().AcceptProxyIPs.Contains(peer.IP) {
			return proxyproto.USE, nil
		}
		return proxyproto.IGNORE, nil
	}

	return &proxyproto.Listener{
		Listener:          ln,
		Policy:            policy,
		ReadHeaderTimeout: 5 * time.Second,
	}
}
Note: See TracBrowser for help on using the repository browser.