source: code/trunk/cmd/soju/main.go@ 767

Last change on this file since 767 was 767, checked in by contact, 3 years ago

Fix incorrect listen addr in error message

The various server goroutines would always capture the last listen
addr in the loop.

File size: 9.1 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "io/ioutil"
9 "log"
10 "net"
11 "net/http"
12 _ "net/http/pprof"
13 "net/url"
14 "os"
15 "os/signal"
16 "strings"
17 "sync/atomic"
18 "syscall"
19 "time"
20
21 "github.com/pires/go-proxyproto"
22 "github.com/prometheus/client_golang/prometheus"
23 "github.com/prometheus/client_golang/prometheus/promhttp"
24
25 "git.sr.ht/~emersion/soju"
26 "git.sr.ht/~emersion/soju/config"
27)
28
// downstreamKeepAlive is the TCP keep-alive period set on the listeners that
// accept downstream TCP connections.
const downstreamKeepAlive time.Duration = time.Hour
31
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeatable command-line flag, preserving the order they were given in.
type stringSliceFlag []string

// String formats the collected values using fmt's default slice notation,
// e.g. "[a b]".
func (v *stringSliceFlag) String() string {
	values := []string(*v)
	return fmt.Sprintf("%v", values)
}

// Set records one more occurrence of the flag. It never fails.
func (v *stringSliceFlag) Set(s string) error {
	grown := append(*v, s)
	*v = grown
	return nil
}
42
// bumpOpenedFileLimit raises the soft RLIMIT_NOFILE limit (maximum number of
// open file descriptors) of the current process to its hard limit.
//
// It returns an error if the current limit cannot be read or the new limit
// cannot be applied; the limit is left unchanged in that case.
func bumpOpenedFileLimit() error {
	var lim syscall.Rlimit
	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim)
	if err != nil {
		return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
	}

	lim.Cur = lim.Max
	err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &lim)
	if err != nil {
		return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
	}
	return nil
}
54
// Values of the -config and -debug command-line flags.
var (
	debug      bool
	configPath string
)

// tlsCert holds the *tls.Certificate currently loaded from the configuration.
// It is an atomic.Value so a configuration reload can replace it while TLS
// handshakes read it concurrently.
var tlsCert atomic.Value
61
62func loadConfig() (*config.Server, *soju.Config, error) {
63 var raw *config.Server
64 if configPath != "" {
65 var err error
66 raw, err = config.Load(configPath)
67 if err != nil {
68 return nil, nil, fmt.Errorf("failed to load config file: %v", err)
69 }
70 } else {
71 raw = config.Defaults()
72 }
73
74 var motd string
75 if raw.MOTDPath != "" {
76 b, err := ioutil.ReadFile(raw.MOTDPath)
77 if err != nil {
78 return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
79 }
80 motd = strings.TrimSuffix(string(b), "\n")
81 }
82
83 if raw.TLS != nil {
84 cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
85 if err != nil {
86 return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
87 }
88 tlsCert.Store(&cert)
89 }
90
91 cfg := &soju.Config{
92 Hostname: raw.Hostname,
93 Title: raw.Title,
94 LogPath: raw.LogPath,
95 HTTPOrigins: raw.HTTPOrigins,
96 AcceptProxyIPs: raw.AcceptProxyIPs,
97 MaxUserNetworks: raw.MaxUserNetworks,
98 MultiUpstream: raw.MultiUpstream,
99 UpstreamUserIPs: raw.UpstreamUserIPs,
100 MOTD: motd,
101 }
102 return raw, cfg, nil
103}
104
105func main() {
106 var listen []string
107 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
108 flag.StringVar(&configPath, "config", "", "path to configuration file")
109 flag.BoolVar(&debug, "debug", false, "enable debug logging")
110 flag.Parse()
111
112 cfg, serverCfg, err := loadConfig()
113 if err != nil {
114 log.Fatal(err)
115 }
116
117 cfg.Listen = append(cfg.Listen, listen...)
118 if len(cfg.Listen) == 0 {
119 cfg.Listen = []string{":6697"}
120 }
121
122 if err := bumpOpenedFileLimit(); err != nil {
123 log.Printf("failed to bump max number of opened files: %v", err)
124 }
125
126 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
127 if err != nil {
128 log.Fatalf("failed to open database: %v", err)
129 }
130
131 var tlsCfg *tls.Config
132 if cfg.TLS != nil {
133 tlsCfg = &tls.Config{
134 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
135 return tlsCert.Load().(*tls.Certificate), nil
136 },
137 }
138 }
139
140 srv := soju.NewServer(db)
141 srv.SetConfig(serverCfg)
142 srv.Logger = soju.NewLogger(log.Writer(), debug)
143
144 for _, listen := range cfg.Listen {
145 listen := listen // copy
146 listenURI := listen
147 if !strings.Contains(listenURI, ":/") {
148 // This is a raw domain name, make it an URL with an empty scheme
149 listenURI = "//" + listenURI
150 }
151 u, err := url.Parse(listenURI)
152 if err != nil {
153 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
154 }
155
156 switch u.Scheme {
157 case "ircs", "":
158 if tlsCfg == nil {
159 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
160 }
161 host := u.Host
162 if _, _, err := net.SplitHostPort(host); err != nil {
163 host = host + ":6697"
164 }
165 ircsTLSCfg := tlsCfg.Clone()
166 ircsTLSCfg.NextProtos = []string{"irc"}
167 lc := net.ListenConfig{
168 KeepAlive: downstreamKeepAlive,
169 }
170 l, err := lc.Listen(context.Background(), "tcp", host)
171 if err != nil {
172 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
173 }
174 ln := tls.NewListener(l, ircsTLSCfg)
175 ln = proxyProtoListener(ln, srv)
176 go func() {
177 if err := srv.Serve(ln); err != nil {
178 log.Printf("serving %q: %v", listen, err)
179 }
180 }()
181 case "irc+insecure":
182 host := u.Host
183 if _, _, err := net.SplitHostPort(host); err != nil {
184 host = host + ":6667"
185 }
186 lc := net.ListenConfig{
187 KeepAlive: downstreamKeepAlive,
188 }
189 ln, err := lc.Listen(context.Background(), "tcp", host)
190 if err != nil {
191 log.Fatalf("failed to start listener on %q: %v", listen, err)
192 }
193 ln = proxyProtoListener(ln, srv)
194 go func() {
195 if err := srv.Serve(ln); err != nil {
196 log.Printf("serving %q: %v", listen, err)
197 }
198 }()
199 case "unix":
200 ln, err := net.Listen("unix", u.Path)
201 if err != nil {
202 log.Fatalf("failed to start listener on %q: %v", listen, err)
203 }
204 ln = proxyProtoListener(ln, srv)
205 go func() {
206 if err := srv.Serve(ln); err != nil {
207 log.Printf("serving %q: %v", listen, err)
208 }
209 }()
210 case "wss":
211 if tlsCfg == nil {
212 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
213 }
214 addr := u.Host
215 if _, _, err := net.SplitHostPort(addr); err != nil {
216 addr = addr + ":https"
217 }
218 httpSrv := http.Server{
219 Addr: addr,
220 TLSConfig: tlsCfg,
221 Handler: srv,
222 }
223 go func() {
224 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
225 log.Fatalf("serving %q: %v", listen, err)
226 }
227 }()
228 case "ws+insecure":
229 addr := u.Host
230 if _, _, err := net.SplitHostPort(addr); err != nil {
231 addr = addr + ":http"
232 }
233 httpSrv := http.Server{
234 Addr: addr,
235 Handler: srv,
236 }
237 go func() {
238 if err := httpSrv.ListenAndServe(); err != nil {
239 log.Fatalf("serving %q: %v", listen, err)
240 }
241 }()
242 case "ident":
243 if srv.Identd == nil {
244 srv.Identd = soju.NewIdentd()
245 }
246
247 host := u.Host
248 if _, _, err := net.SplitHostPort(host); err != nil {
249 host = host + ":113"
250 }
251 ln, err := net.Listen("tcp", host)
252 if err != nil {
253 log.Fatalf("failed to start listener on %q: %v", listen, err)
254 }
255 ln = proxyProtoListener(ln, srv)
256 go func() {
257 if err := srv.Identd.Serve(ln); err != nil {
258 log.Printf("serving %q: %v", listen, err)
259 }
260 }()
261 case "http+prometheus":
262 if srv.MetricsRegistry == nil {
263 srv.MetricsRegistry = prometheus.DefaultRegisterer
264 }
265
266 // Only allow localhost as listening host for security reasons.
267 // Users can always explicitly setup reverse proxies if desirable.
268 hostname, _, err := net.SplitHostPort(u.Host)
269 if err != nil {
270 log.Fatalf("invalid host in URI %q: %v", listen, err)
271 } else if hostname != "localhost" {
272 log.Fatalf("Prometheus listening host must be localhost")
273 }
274
275 metricsHandler := promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
276 MaxRequestsInFlight: 10,
277 Timeout: 10 * time.Second,
278 EnableOpenMetrics: true,
279 })
280 metricsHandler = promhttp.InstrumentMetricHandler(prometheus.DefaultRegisterer, metricsHandler)
281
282 httpSrv := http.Server{
283 Addr: u.Host,
284 Handler: metricsHandler,
285 }
286 go func() {
287 if err := httpSrv.ListenAndServe(); err != nil {
288 log.Fatalf("serving %q: %v", listen, err)
289 }
290 }()
291 case "http+pprof":
292 // Only allow localhost as listening host for security reasons.
293 // Users can always explicitly setup reverse proxies if desirable.
294 hostname, _, err := net.SplitHostPort(u.Host)
295 if err != nil {
296 log.Fatalf("invalid host in URI %q: %v", listen, err)
297 } else if hostname != "localhost" {
298 log.Fatalf("pprof listening host must be localhost")
299 }
300
301 // net/http/pprof registers its handlers in http.DefaultServeMux
302 httpSrv := http.Server{
303 Addr: u.Host,
304 Handler: http.DefaultServeMux,
305 }
306 go func() {
307 if err := httpSrv.ListenAndServe(); err != nil {
308 log.Fatalf("serving %q: %v", listen, err)
309 }
310 }()
311 default:
312 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
313 }
314
315 log.Printf("server listening on %q", listen)
316 }
317
318 if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
319 srv.MetricsRegistry.MustRegister(db.MetricsCollector())
320 }
321
322 sigCh := make(chan os.Signal, 1)
323 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
324
325 if err := srv.Start(); err != nil {
326 log.Fatal(err)
327 }
328
329 for sig := range sigCh {
330 switch sig {
331 case syscall.SIGHUP:
332 log.Print("reloading configuration")
333 _, serverCfg, err := loadConfig()
334 if err != nil {
335 log.Printf("failed to reloading configuration: %v", err)
336 } else {
337 srv.SetConfig(serverCfg)
338 }
339 case syscall.SIGINT, syscall.SIGTERM:
340 log.Print("shutting down server")
341 srv.Shutdown()
342 return
343 }
344 }
345}
346
347func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
348 return &proxyproto.Listener{
349 Listener: ln,
350 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
351 tcpAddr, ok := upstream.(*net.TCPAddr)
352 if !ok {
353 return proxyproto.IGNORE, nil
354 }
355 if srv.Config().AcceptProxyIPs.Contains(tcpAddr.IP) {
356 return proxyproto.USE, nil
357 }
358 return proxyproto.IGNORE, nil
359 },
360 ReadHeaderTimeout: 5 * time.Second,
361 }
362}
Note: See TracBrowser for help on using the repository browser.