source: code/trunk/cmd/soju/main.go@ 687

Last change on this file since 687 was 687, checked in by contact, 4 years ago

cmd/soju: bump max number of opened files

The bouncer process may need to keep many file descriptors open. On Linux the
default soft limit is 1024. To support bouncers with many users, raise the
RLIMIT_NOFILE soft limit to the hard limit, as advised in [1].

[1]: http://0pointer.net/blog/file-descriptor-limits.html

File size: 7.1 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "io/ioutil"
9 "log"
10 "net"
11 "net/http"
12 "net/url"
13 "os"
14 "os/signal"
15 "strings"
16 "sync/atomic"
17 "syscall"
18 "time"
19
20 "github.com/pires/go-proxyproto"
21
22 "git.sr.ht/~emersion/soju"
23 "git.sr.ht/~emersion/soju/config"
24)
25
// downstreamKeepAlive is the TCP keep-alive period configured on the
// listeners that accept downstream connections.
const downstreamKeepAlive = time.Hour
28
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeatable flag (such as several -listen arguments) into a slice.
type stringSliceFlag []string

// String formats the collected values like fmt.Sprint of a []string, e.g.
// "[a b]". The flag package may call String on a zero-valued receiver such
// as a nil pointer (see the flag.Value documentation), so a nil receiver is
// reported as an empty slice instead of panicking on the dereference.
func (v *stringSliceFlag) String() string {
	if v == nil {
		return fmt.Sprint([]string(nil))
	}
	return fmt.Sprint([]string(*v))
}

// Set appends s to the collected values. It never returns an error.
func (v *stringSliceFlag) Set(s string) error {
	*v = append(*v, s)
	return nil
}
39
40func loadMOTD(srv *soju.Server, filename string) error {
41 if filename == "" {
42 return nil
43 }
44
45 b, err := ioutil.ReadFile(filename)
46 if err != nil {
47 return err
48 }
49 srv.SetMOTD(strings.TrimSuffix(string(b), "\n"))
50 return nil
51}
52
// bumpOpenedFileLimit raises the soft RLIMIT_NOFILE limit of the current
// process to its hard limit, so the bouncer can keep more file descriptors
// open at once. POSIX allows an unprivileged process to raise its soft limit
// up to the hard limit.
//
// It returns a wrapped error if the current limit cannot be read or the new
// one cannot be applied. The caller in this file treats that as non-fatal.
// NOTE(review): on macOS the hard limit may be RLIM_INFINITY, which setrlimit
// can reject as a soft limit; such an error is returned to the caller.
func bumpOpenedFileLimit() error {
	var rlimit syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
		return fmt.Errorf("failed to get RLIMIT_NOFILE: %w", err)
	}
	// Already at the maximum (e.g. raised by the runtime or the service
	// manager): nothing to do, skip the redundant syscall.
	if rlimit.Cur >= rlimit.Max {
		return nil
	}
	rlimit.Cur = rlimit.Max
	if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
		return fmt.Errorf("failed to set RLIMIT_NOFILE: %w", err)
	}
	return nil
}
64
65func main() {
66 var listen []string
67 var configPath string
68 var debug bool
69 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
70 flag.StringVar(&configPath, "config", "", "path to configuration file")
71 flag.BoolVar(&debug, "debug", false, "enable debug logging")
72 flag.Parse()
73
74 var cfg *config.Server
75 if configPath != "" {
76 var err error
77 cfg, err = config.Load(configPath)
78 if err != nil {
79 log.Fatalf("failed to load config file: %v", err)
80 }
81 } else {
82 cfg = config.Defaults()
83 }
84
85 cfg.Listen = append(cfg.Listen, listen...)
86 if len(cfg.Listen) == 0 {
87 cfg.Listen = []string{":6697"}
88 }
89
90 if err := bumpOpenedFileLimit(); err != nil {
91 log.Printf("failed to bump max number of opened files: %v", err)
92 }
93
94 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
95 if err != nil {
96 log.Fatalf("failed to open database: %v", err)
97 }
98
99 var tlsCfg *tls.Config
100 var tlsCert atomic.Value
101 if cfg.TLS != nil {
102 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
103 if err != nil {
104 log.Fatalf("failed to load TLS certificate and key: %v", err)
105 }
106 tlsCert.Store(&cert)
107
108 tlsCfg = &tls.Config{
109 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
110 return tlsCert.Load().(*tls.Certificate), nil
111 },
112 }
113 }
114
115 srv := soju.NewServer(db)
116 srv.Hostname = cfg.Hostname
117 srv.Title = cfg.Title
118 srv.LogPath = cfg.LogPath
119 srv.HTTPOrigins = cfg.HTTPOrigins
120 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
121 srv.MaxUserNetworks = cfg.MaxUserNetworks
122 srv.Debug = debug
123
124 if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
125 log.Fatalf("failed to load MOTD: %v", err)
126 }
127
128 for _, listen := range cfg.Listen {
129 listenURI := listen
130 if !strings.Contains(listenURI, ":/") {
131 // This is a raw domain name, make it an URL with an empty scheme
132 listenURI = "//" + listenURI
133 }
134 u, err := url.Parse(listenURI)
135 if err != nil {
136 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
137 }
138
139 switch u.Scheme {
140 case "ircs", "":
141 if tlsCfg == nil {
142 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
143 }
144 host := u.Host
145 if _, _, err := net.SplitHostPort(host); err != nil {
146 host = host + ":6697"
147 }
148 ircsTLSCfg := tlsCfg.Clone()
149 ircsTLSCfg.NextProtos = []string{"irc"}
150 lc := net.ListenConfig{
151 KeepAlive: downstreamKeepAlive,
152 }
153 l, err := lc.Listen(context.Background(), "tcp", host)
154 if err != nil {
155 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
156 }
157 ln := tls.NewListener(l, ircsTLSCfg)
158 ln = proxyProtoListener(ln, srv)
159 go func() {
160 if err := srv.Serve(ln); err != nil {
161 log.Printf("serving %q: %v", listen, err)
162 }
163 }()
164 case "irc+insecure":
165 host := u.Host
166 if _, _, err := net.SplitHostPort(host); err != nil {
167 host = host + ":6667"
168 }
169 lc := net.ListenConfig{
170 KeepAlive: downstreamKeepAlive,
171 }
172 ln, err := lc.Listen(context.Background(), "tcp", host)
173 if err != nil {
174 log.Fatalf("failed to start listener on %q: %v", listen, err)
175 }
176 ln = proxyProtoListener(ln, srv)
177 go func() {
178 if err := srv.Serve(ln); err != nil {
179 log.Printf("serving %q: %v", listen, err)
180 }
181 }()
182 case "unix":
183 ln, err := net.Listen("unix", u.Path)
184 if err != nil {
185 log.Fatalf("failed to start listener on %q: %v", listen, err)
186 }
187 ln = proxyProtoListener(ln, srv)
188 go func() {
189 if err := srv.Serve(ln); err != nil {
190 log.Printf("serving %q: %v", listen, err)
191 }
192 }()
193 case "wss":
194 if tlsCfg == nil {
195 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
196 }
197 addr := u.Host
198 if _, _, err := net.SplitHostPort(addr); err != nil {
199 addr = addr + ":https"
200 }
201 httpSrv := http.Server{
202 Addr: addr,
203 TLSConfig: tlsCfg,
204 Handler: srv,
205 }
206 go func() {
207 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
208 log.Fatalf("serving %q: %v", listen, err)
209 }
210 }()
211 case "ws+insecure":
212 addr := u.Host
213 if _, _, err := net.SplitHostPort(addr); err != nil {
214 addr = addr + ":http"
215 }
216 httpSrv := http.Server{
217 Addr: addr,
218 Handler: srv,
219 }
220 go func() {
221 if err := httpSrv.ListenAndServe(); err != nil {
222 log.Fatalf("serving %q: %v", listen, err)
223 }
224 }()
225 case "ident":
226 if srv.Identd == nil {
227 srv.Identd = soju.NewIdentd()
228 }
229
230 host := u.Host
231 if _, _, err := net.SplitHostPort(host); err != nil {
232 host = host + ":113"
233 }
234 ln, err := net.Listen("tcp", host)
235 if err != nil {
236 log.Fatalf("failed to start listener on %q: %v", listen, err)
237 }
238 ln = proxyProtoListener(ln, srv)
239 go func() {
240 if err := srv.Identd.Serve(ln); err != nil {
241 log.Printf("serving %q: %v", listen, err)
242 }
243 }()
244 default:
245 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
246 }
247
248 log.Printf("server listening on %q", listen)
249 }
250
251 sigCh := make(chan os.Signal, 1)
252 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
253
254 if err := srv.Start(); err != nil {
255 log.Fatal(err)
256 }
257
258 for sig := range sigCh {
259 switch sig {
260 case syscall.SIGHUP:
261 log.Print("reloading TLS certificate and MOTD")
262 if cfg.TLS != nil {
263 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
264 if err != nil {
265 log.Printf("failed to reload TLS certificate and key: %v", err)
266 break
267 }
268 tlsCert.Store(&cert)
269 }
270 if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
271 log.Printf("failed to reload MOTD: %v", err)
272 }
273 case syscall.SIGINT, syscall.SIGTERM:
274 log.Print("shutting down server")
275 srv.Shutdown()
276 return
277 }
278 }
279}
280
281func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
282 return &proxyproto.Listener{
283 Listener: ln,
284 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
285 tcpAddr, ok := upstream.(*net.TCPAddr)
286 if !ok {
287 return proxyproto.IGNORE, nil
288 }
289 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
290 return proxyproto.USE, nil
291 }
292 return proxyproto.IGNORE, nil
293 },
294 ReadHeaderTimeout: 5 * time.Second,
295 }
296}
Note: See TracBrowser for help on using the repository browser.