source: code/trunk/cmd/soju/main.go@ 671

Last change on this file since 671 was 662, checked in by contact, 4 years ago

Add title config option

Closes: https://todo.sr.ht/~emersion/soju/146

File size: 6.6 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "io/ioutil"
9 "log"
10 "net"
11 "net/http"
12 "net/url"
13 "os"
14 "os/signal"
15 "strings"
16 "sync/atomic"
17 "syscall"
18 "time"
19
20 "github.com/pires/go-proxyproto"
21
22 "git.sr.ht/~emersion/soju"
23 "git.sr.ht/~emersion/soju/config"
24)
25
// downstreamKeepAlive is the TCP keep-alive period configured (via
// net.ListenConfig.KeepAlive) on the TCP listeners that accept downstream
// IRC connections.
const downstreamKeepAlive = time.Hour
28
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeatable command-line flag, preserving the order they were given in.
type stringSliceFlag []string

// String formats the accumulated values. The flag package documents that it
// may call String on a zero-valued receiver, such as a nil pointer, so that
// case returns "" instead of panicking on the dereference.
func (v *stringSliceFlag) String() string {
	if v == nil {
		return ""
	}
	return fmt.Sprint([]string(*v))
}

// Set appends s to the accumulated values. It never returns an error.
func (v *stringSliceFlag) Set(s string) error {
	*v = append(*v, s)
	return nil
}
39
40func loadMOTD(srv *soju.Server, filename string) error {
41 if filename == "" {
42 return nil
43 }
44
45 b, err := ioutil.ReadFile(filename)
46 if err != nil {
47 return err
48 }
49 srv.SetMOTD(strings.TrimSuffix(string(b), "\n"))
50 return nil
51}
52
53func main() {
54 var listen []string
55 var configPath string
56 var debug bool
57 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
58 flag.StringVar(&configPath, "config", "", "path to configuration file")
59 flag.BoolVar(&debug, "debug", false, "enable debug logging")
60 flag.Parse()
61
62 var cfg *config.Server
63 if configPath != "" {
64 var err error
65 cfg, err = config.Load(configPath)
66 if err != nil {
67 log.Fatalf("failed to load config file: %v", err)
68 }
69 } else {
70 cfg = config.Defaults()
71 }
72
73 cfg.Listen = append(cfg.Listen, listen...)
74 if len(cfg.Listen) == 0 {
75 cfg.Listen = []string{":6697"}
76 }
77
78 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
79 if err != nil {
80 log.Fatalf("failed to open database: %v", err)
81 }
82
83 var tlsCfg *tls.Config
84 var tlsCert atomic.Value
85 if cfg.TLS != nil {
86 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
87 if err != nil {
88 log.Fatalf("failed to load TLS certificate and key: %v", err)
89 }
90 tlsCert.Store(&cert)
91
92 tlsCfg = &tls.Config{
93 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
94 return tlsCert.Load().(*tls.Certificate), nil
95 },
96 }
97 }
98
99 srv := soju.NewServer(db)
100 srv.Hostname = cfg.Hostname
101 srv.Title = cfg.Title
102 srv.LogPath = cfg.LogPath
103 srv.HTTPOrigins = cfg.HTTPOrigins
104 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
105 srv.MaxUserNetworks = cfg.MaxUserNetworks
106 srv.Debug = debug
107
108 if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
109 log.Fatalf("failed to load MOTD: %v", err)
110 }
111
112 for _, listen := range cfg.Listen {
113 listenURI := listen
114 if !strings.Contains(listenURI, ":/") {
115 // This is a raw domain name, make it an URL with an empty scheme
116 listenURI = "//" + listenURI
117 }
118 u, err := url.Parse(listenURI)
119 if err != nil {
120 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
121 }
122
123 switch u.Scheme {
124 case "ircs", "":
125 if tlsCfg == nil {
126 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
127 }
128 host := u.Host
129 if _, _, err := net.SplitHostPort(host); err != nil {
130 host = host + ":6697"
131 }
132 ircsTLSCfg := tlsCfg.Clone()
133 ircsTLSCfg.NextProtos = []string{"irc"}
134 lc := net.ListenConfig{
135 KeepAlive: downstreamKeepAlive,
136 }
137 l, err := lc.Listen(context.Background(), "tcp", host)
138 if err != nil {
139 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
140 }
141 ln := tls.NewListener(l, ircsTLSCfg)
142 ln = proxyProtoListener(ln, srv)
143 go func() {
144 if err := srv.Serve(ln); err != nil {
145 log.Printf("serving %q: %v", listen, err)
146 }
147 }()
148 case "irc+insecure":
149 host := u.Host
150 if _, _, err := net.SplitHostPort(host); err != nil {
151 host = host + ":6667"
152 }
153 lc := net.ListenConfig{
154 KeepAlive: downstreamKeepAlive,
155 }
156 ln, err := lc.Listen(context.Background(), "tcp", host)
157 if err != nil {
158 log.Fatalf("failed to start listener on %q: %v", listen, err)
159 }
160 ln = proxyProtoListener(ln, srv)
161 go func() {
162 if err := srv.Serve(ln); err != nil {
163 log.Printf("serving %q: %v", listen, err)
164 }
165 }()
166 case "unix":
167 ln, err := net.Listen("unix", u.Path)
168 if err != nil {
169 log.Fatalf("failed to start listener on %q: %v", listen, err)
170 }
171 ln = proxyProtoListener(ln, srv)
172 go func() {
173 if err := srv.Serve(ln); err != nil {
174 log.Printf("serving %q: %v", listen, err)
175 }
176 }()
177 case "wss":
178 if tlsCfg == nil {
179 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
180 }
181 addr := u.Host
182 if _, _, err := net.SplitHostPort(addr); err != nil {
183 addr = addr + ":https"
184 }
185 httpSrv := http.Server{
186 Addr: addr,
187 TLSConfig: tlsCfg,
188 Handler: srv,
189 }
190 go func() {
191 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
192 log.Fatalf("serving %q: %v", listen, err)
193 }
194 }()
195 case "ws+insecure":
196 addr := u.Host
197 if _, _, err := net.SplitHostPort(addr); err != nil {
198 addr = addr + ":http"
199 }
200 httpSrv := http.Server{
201 Addr: addr,
202 Handler: srv,
203 }
204 go func() {
205 if err := httpSrv.ListenAndServe(); err != nil {
206 log.Fatalf("serving %q: %v", listen, err)
207 }
208 }()
209 case "ident":
210 if srv.Identd == nil {
211 srv.Identd = soju.NewIdentd()
212 }
213
214 host := u.Host
215 if _, _, err := net.SplitHostPort(host); err != nil {
216 host = host + ":113"
217 }
218 ln, err := net.Listen("tcp", host)
219 if err != nil {
220 log.Fatalf("failed to start listener on %q: %v", listen, err)
221 }
222 ln = proxyProtoListener(ln, srv)
223 go func() {
224 if err := srv.Identd.Serve(ln); err != nil {
225 log.Printf("serving %q: %v", listen, err)
226 }
227 }()
228 default:
229 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
230 }
231
232 log.Printf("server listening on %q", listen)
233 }
234
235 sigCh := make(chan os.Signal, 1)
236 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
237
238 if err := srv.Start(); err != nil {
239 log.Fatal(err)
240 }
241
242 for sig := range sigCh {
243 switch sig {
244 case syscall.SIGHUP:
245 log.Print("reloading TLS certificate and MOTD")
246 if cfg.TLS != nil {
247 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
248 if err != nil {
249 log.Printf("failed to reload TLS certificate and key: %v", err)
250 break
251 }
252 tlsCert.Store(&cert)
253 }
254 if err := loadMOTD(srv, cfg.MOTDPath); err != nil {
255 log.Printf("failed to reload MOTD: %v", err)
256 }
257 case syscall.SIGINT, syscall.SIGTERM:
258 log.Print("shutting down server")
259 srv.Shutdown()
260 return
261 }
262 }
263}
264
265func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
266 return &proxyproto.Listener{
267 Listener: ln,
268 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
269 tcpAddr, ok := upstream.(*net.TCPAddr)
270 if !ok {
271 return proxyproto.IGNORE, nil
272 }
273 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
274 return proxyproto.USE, nil
275 }
276 return proxyproto.IGNORE, nil
277 },
278 ReadHeaderTimeout: 5 * time.Second,
279 }
280}
Note: See TracBrowser for help on using the repository browser.