source: code/trunk/cmd/soju/main.go@ 746

Last change on this file since 746 was 713, checked in by contact, 4 years ago

Add pprof HTTP server

This enables production debugging of the bouncer.

Closes: https://todo.sr.ht/~emersion/soju/155

File size: 9.1 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "io/ioutil"
9 "log"
10 "net"
11 "net/http"
12 _ "net/http/pprof"
13 "net/url"
14 "os"
15 "os/signal"
16 "strings"
17 "sync/atomic"
18 "syscall"
19 "time"
20
21 "github.com/pires/go-proxyproto"
22 "github.com/prometheus/client_golang/prometheus"
23 "github.com/prometheus/client_golang/prometheus/promhttp"
24
25 "git.sr.ht/~emersion/soju"
26 "git.sr.ht/~emersion/soju/config"
27)
28
// downstreamKeepAlive is the TCP keep-alive period set on the net.ListenConfig
// of the plain-TCP and TLS listeners that accept downstream connections.
const downstreamKeepAlive = time.Hour
31
// stringSliceFlag is a flag.Value that collects every occurrence of a
// repeatable command-line flag, in the order the occurrences were given.
type stringSliceFlag []string

// String formats the collected values using fmt's default slice notation,
// e.g. "[a b]".
func (v *stringSliceFlag) String() string {
	values := []string(*v)
	return fmt.Sprintf("%v", values)
}

// Set appends s to the collected values. It never returns an error.
func (v *stringSliceFlag) Set(s string) error {
	updated := append(*v, s)
	*v = updated
	return nil
}
42
// bumpOpenedFileLimit raises the soft limit on open file descriptors
// (RLIMIT_NOFILE) to the hard limit, so the process can keep more files and
// sockets open at once. When the soft limit already equals the hard limit it
// returns nil without calling setrlimit.
//
// Errors wrap the underlying syscall error with %w.
func bumpOpenedFileLimit() error {
	var rlimit syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
		return fmt.Errorf("failed to get RLIMIT_NOFILE: %w", err)
	}
	if rlimit.Cur == rlimit.Max {
		// Already at the hard limit: nothing to raise.
		return nil
	}
	rlimit.Cur = rlimit.Max
	if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
		return fmt.Errorf("failed to set RLIMIT_NOFILE: %w", err)
	}
	return nil
}
54
// configPath is the value of the -config flag. When empty, loadConfig uses
// the built-in defaults instead of reading a file.
var configPath string

// debug is the value of the -debug flag; loadConfig copies it into
// soju.Config.Debug.
var debug bool

// tlsCert holds the current TLS certificate as a *tls.Certificate. loadConfig
// replaces it whenever the configuration has a TLS section, and the TLS
// config's GetCertificate callback reads it on each handshake.
var tlsCert atomic.Value // *tls.Certificate
61
62func loadConfig() (*config.Server, *soju.Config, error) {
63 var raw *config.Server
64 if configPath != "" {
65 var err error
66 raw, err = config.Load(configPath)
67 if err != nil {
68 return nil, nil, fmt.Errorf("failed to load config file: %v", err)
69 }
70 } else {
71 raw = config.Defaults()
72 }
73
74 var motd string
75 if raw.MOTDPath != "" {
76 b, err := ioutil.ReadFile(raw.MOTDPath)
77 if err != nil {
78 return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
79 }
80 motd = strings.TrimSuffix(string(b), "\n")
81 }
82
83 if raw.TLS != nil {
84 cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
85 if err != nil {
86 return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
87 }
88 tlsCert.Store(&cert)
89 }
90
91 cfg := &soju.Config{
92 Hostname: raw.Hostname,
93 Title: raw.Title,
94 LogPath: raw.LogPath,
95 HTTPOrigins: raw.HTTPOrigins,
96 AcceptProxyIPs: raw.AcceptProxyIPs,
97 MaxUserNetworks: raw.MaxUserNetworks,
98 MultiUpstream: raw.MultiUpstream,
99 UpstreamUserIPs: raw.UpstreamUserIPs,
100 Debug: debug,
101 MOTD: motd,
102 }
103 return raw, cfg, nil
104}
105
106func main() {
107 var listen []string
108 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
109 flag.StringVar(&configPath, "config", "", "path to configuration file")
110 flag.BoolVar(&debug, "debug", false, "enable debug logging")
111 flag.Parse()
112
113 cfg, serverCfg, err := loadConfig()
114 if err != nil {
115 log.Fatal(err)
116 }
117
118 cfg.Listen = append(cfg.Listen, listen...)
119 if len(cfg.Listen) == 0 {
120 cfg.Listen = []string{":6697"}
121 }
122
123 if err := bumpOpenedFileLimit(); err != nil {
124 log.Printf("failed to bump max number of opened files: %v", err)
125 }
126
127 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
128 if err != nil {
129 log.Fatalf("failed to open database: %v", err)
130 }
131
132 var tlsCfg *tls.Config
133 if cfg.TLS != nil {
134 tlsCfg = &tls.Config{
135 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
136 return tlsCert.Load().(*tls.Certificate), nil
137 },
138 }
139 }
140
141 srv := soju.NewServer(db)
142 srv.SetConfig(serverCfg)
143
144 for _, listen := range cfg.Listen {
145 listenURI := listen
146 if !strings.Contains(listenURI, ":/") {
147 // This is a raw domain name, make it an URL with an empty scheme
148 listenURI = "//" + listenURI
149 }
150 u, err := url.Parse(listenURI)
151 if err != nil {
152 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
153 }
154
155 switch u.Scheme {
156 case "ircs", "":
157 if tlsCfg == nil {
158 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
159 }
160 host := u.Host
161 if _, _, err := net.SplitHostPort(host); err != nil {
162 host = host + ":6697"
163 }
164 ircsTLSCfg := tlsCfg.Clone()
165 ircsTLSCfg.NextProtos = []string{"irc"}
166 lc := net.ListenConfig{
167 KeepAlive: downstreamKeepAlive,
168 }
169 l, err := lc.Listen(context.Background(), "tcp", host)
170 if err != nil {
171 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
172 }
173 ln := tls.NewListener(l, ircsTLSCfg)
174 ln = proxyProtoListener(ln, srv)
175 go func() {
176 if err := srv.Serve(ln); err != nil {
177 log.Printf("serving %q: %v", listen, err)
178 }
179 }()
180 case "irc+insecure":
181 host := u.Host
182 if _, _, err := net.SplitHostPort(host); err != nil {
183 host = host + ":6667"
184 }
185 lc := net.ListenConfig{
186 KeepAlive: downstreamKeepAlive,
187 }
188 ln, err := lc.Listen(context.Background(), "tcp", host)
189 if err != nil {
190 log.Fatalf("failed to start listener on %q: %v", listen, err)
191 }
192 ln = proxyProtoListener(ln, srv)
193 go func() {
194 if err := srv.Serve(ln); err != nil {
195 log.Printf("serving %q: %v", listen, err)
196 }
197 }()
198 case "unix":
199 ln, err := net.Listen("unix", u.Path)
200 if err != nil {
201 log.Fatalf("failed to start listener on %q: %v", listen, err)
202 }
203 ln = proxyProtoListener(ln, srv)
204 go func() {
205 if err := srv.Serve(ln); err != nil {
206 log.Printf("serving %q: %v", listen, err)
207 }
208 }()
209 case "wss":
210 if tlsCfg == nil {
211 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
212 }
213 addr := u.Host
214 if _, _, err := net.SplitHostPort(addr); err != nil {
215 addr = addr + ":https"
216 }
217 httpSrv := http.Server{
218 Addr: addr,
219 TLSConfig: tlsCfg,
220 Handler: srv,
221 }
222 go func() {
223 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
224 log.Fatalf("serving %q: %v", listen, err)
225 }
226 }()
227 case "ws+insecure":
228 addr := u.Host
229 if _, _, err := net.SplitHostPort(addr); err != nil {
230 addr = addr + ":http"
231 }
232 httpSrv := http.Server{
233 Addr: addr,
234 Handler: srv,
235 }
236 go func() {
237 if err := httpSrv.ListenAndServe(); err != nil {
238 log.Fatalf("serving %q: %v", listen, err)
239 }
240 }()
241 case "ident":
242 if srv.Identd == nil {
243 srv.Identd = soju.NewIdentd()
244 }
245
246 host := u.Host
247 if _, _, err := net.SplitHostPort(host); err != nil {
248 host = host + ":113"
249 }
250 ln, err := net.Listen("tcp", host)
251 if err != nil {
252 log.Fatalf("failed to start listener on %q: %v", listen, err)
253 }
254 ln = proxyProtoListener(ln, srv)
255 go func() {
256 if err := srv.Identd.Serve(ln); err != nil {
257 log.Printf("serving %q: %v", listen, err)
258 }
259 }()
260 case "http+prometheus":
261 if srv.MetricsRegistry == nil {
262 srv.MetricsRegistry = prometheus.DefaultRegisterer
263 }
264
265 // Only allow localhost as listening host for security reasons.
266 // Users can always explicitly setup reverse proxies if desirable.
267 hostname, _, err := net.SplitHostPort(u.Host)
268 if err != nil {
269 log.Fatalf("invalid host in URI %q: %v", listen, err)
270 } else if hostname != "localhost" {
271 log.Fatalf("Prometheus listening host must be localhost")
272 }
273
274 metricsHandler := promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
275 MaxRequestsInFlight: 10,
276 Timeout: 10 * time.Second,
277 EnableOpenMetrics: true,
278 })
279 metricsHandler = promhttp.InstrumentMetricHandler(prometheus.DefaultRegisterer, metricsHandler)
280
281 httpSrv := http.Server{
282 Addr: u.Host,
283 Handler: metricsHandler,
284 }
285 go func() {
286 if err := httpSrv.ListenAndServe(); err != nil {
287 log.Fatalf("serving %q: %v", listen, err)
288 }
289 }()
290 case "http+pprof":
291 // Only allow localhost as listening host for security reasons.
292 // Users can always explicitly setup reverse proxies if desirable.
293 hostname, _, err := net.SplitHostPort(u.Host)
294 if err != nil {
295 log.Fatalf("invalid host in URI %q: %v", listen, err)
296 } else if hostname != "localhost" {
297 log.Fatalf("pprof listening host must be localhost")
298 }
299
300 // net/http/pprof registers its handlers in http.DefaultServeMux
301 httpSrv := http.Server{
302 Addr: u.Host,
303 Handler: http.DefaultServeMux,
304 }
305 go func() {
306 if err := httpSrv.ListenAndServe(); err != nil {
307 log.Fatalf("serving %q: %v", listen, err)
308 }
309 }()
310 default:
311 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
312 }
313
314 log.Printf("server listening on %q", listen)
315 }
316
317 if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
318 srv.MetricsRegistry.MustRegister(db.MetricsCollector())
319 }
320
321 sigCh := make(chan os.Signal, 1)
322 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
323
324 if err := srv.Start(); err != nil {
325 log.Fatal(err)
326 }
327
328 for sig := range sigCh {
329 switch sig {
330 case syscall.SIGHUP:
331 log.Print("reloading configuration")
332 _, serverCfg, err := loadConfig()
333 if err != nil {
334 log.Printf("failed to reloading configuration: %v", err)
335 } else {
336 srv.SetConfig(serverCfg)
337 }
338 case syscall.SIGINT, syscall.SIGTERM:
339 log.Print("shutting down server")
340 srv.Shutdown()
341 return
342 }
343 }
344}
345
346func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
347 return &proxyproto.Listener{
348 Listener: ln,
349 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
350 tcpAddr, ok := upstream.(*net.TCPAddr)
351 if !ok {
352 return proxyproto.IGNORE, nil
353 }
354 if srv.Config().AcceptProxyIPs.Contains(tcpAddr.IP) {
355 return proxyproto.USE, nil
356 }
357 return proxyproto.IGNORE, nil
358 },
359 ReadHeaderTimeout: 5 * time.Second,
360 }
361}
Note: See TracBrowser for help on using the repository browser.