source: code/trunk/cmd/soju/main.go@ 581

Last change on this file since 581 was 581, checked in by rafael, 4 years ago

Check for TLS config in wss listeners

Previously, http.Server.ListenAndServeTLS would return an unhelpful
error about a failed open. This adds a check, similar to the one in the
ircs case, that makes it clearer to operators that the TLS configuration
is missing.

File size: 6.1 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "log"
9 "net"
10 "net/http"
11 "net/url"
12 "os"
13 "os/signal"
14 "strings"
15 "sync/atomic"
16 "syscall"
17 "time"
18
19 "github.com/pires/go-proxyproto"
20
21 "git.sr.ht/~emersion/soju"
22 "git.sr.ht/~emersion/soju/config"
23)
24
// downstreamKeepAlive is the TCP keep-alive period configured on the
// listeners that accept downstream (client) TCP connections.
const downstreamKeepAlive = time.Hour
27
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeatable command-line flag, in the order they were given.
type stringSliceFlag []string

// String renders the accumulated values using fmt's default formatting for
// a []string, e.g. "[a b]".
func (v *stringSliceFlag) String() string {
	values := []string(*v)
	return fmt.Sprintf("%v", values)
}

// Set records one more occurrence of the flag. It never reports an error.
func (v *stringSliceFlag) Set(s string) error {
	updated := append(*v, s)
	*v = updated
	return nil
}
38
39func main() {
40 var listen []string
41 var configPath string
42 var debug bool
43 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
44 flag.StringVar(&configPath, "config", "", "path to configuration file")
45 flag.BoolVar(&debug, "debug", false, "enable debug logging")
46 flag.Parse()
47
48 var cfg *config.Server
49 if configPath != "" {
50 var err error
51 cfg, err = config.Load(configPath)
52 if err != nil {
53 log.Fatalf("failed to load config file: %v", err)
54 }
55 } else {
56 cfg = config.Defaults()
57 }
58
59 cfg.Listen = append(cfg.Listen, listen...)
60 if len(cfg.Listen) == 0 {
61 cfg.Listen = []string{":6697"}
62 }
63
64 db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource)
65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68
69 var tlsCfg *tls.Config
70 var tlsCert atomic.Value
71 if cfg.TLS != nil {
72 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
73 if err != nil {
74 log.Fatalf("failed to load TLS certificate and key: %v", err)
75 }
76 tlsCert.Store(&cert)
77
78 tlsCfg = &tls.Config{
79 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
80 return tlsCert.Load().(*tls.Certificate), nil
81 },
82 }
83 }
84
85 srv := soju.NewServer(db)
86 // TODO: load from config/DB
87 srv.Hostname = cfg.Hostname
88 srv.LogPath = cfg.LogPath
89 srv.HTTPOrigins = cfg.HTTPOrigins
90 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
91 srv.Debug = debug
92
93 for _, listen := range cfg.Listen {
94 listenURI := listen
95 if !strings.Contains(listenURI, ":/") {
96 // This is a raw domain name, make it an URL with an empty scheme
97 listenURI = "//" + listenURI
98 }
99 u, err := url.Parse(listenURI)
100 if err != nil {
101 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
102 }
103
104 switch u.Scheme {
105 case "ircs", "":
106 if tlsCfg == nil {
107 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
108 }
109 host := u.Host
110 if _, _, err := net.SplitHostPort(host); err != nil {
111 host = host + ":6697"
112 }
113 ircsTLSCfg := tlsCfg.Clone()
114 ircsTLSCfg.NextProtos = []string{"irc"}
115 lc := net.ListenConfig{
116 KeepAlive: downstreamKeepAlive,
117 }
118 l, err := lc.Listen(context.Background(), "tcp", host)
119 if err != nil {
120 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
121 }
122 ln := tls.NewListener(l, ircsTLSCfg)
123 ln = proxyProtoListener(ln, srv)
124 go func() {
125 if err := srv.Serve(ln); err != nil {
126 log.Printf("serving %q: %v", listen, err)
127 }
128 }()
129 case "irc+insecure":
130 host := u.Host
131 if _, _, err := net.SplitHostPort(host); err != nil {
132 host = host + ":6667"
133 }
134 lc := net.ListenConfig{
135 KeepAlive: downstreamKeepAlive,
136 }
137 ln, err := lc.Listen(context.Background(), "tcp", host)
138 if err != nil {
139 log.Fatalf("failed to start listener on %q: %v", listen, err)
140 }
141 ln = proxyProtoListener(ln, srv)
142 go func() {
143 if err := srv.Serve(ln); err != nil {
144 log.Printf("serving %q: %v", listen, err)
145 }
146 }()
147 case "unix":
148 ln, err := net.Listen("unix", u.Path)
149 if err != nil {
150 log.Fatalf("failed to start listener on %q: %v", listen, err)
151 }
152 ln = proxyProtoListener(ln, srv)
153 go func() {
154 if err := srv.Serve(ln); err != nil {
155 log.Printf("serving %q: %v", listen, err)
156 }
157 }()
158 case "wss":
159 if tlsCfg == nil {
160 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
161 }
162 addr := u.Host
163 if _, _, err := net.SplitHostPort(addr); err != nil {
164 addr = addr + ":https"
165 }
166 httpSrv := http.Server{
167 Addr: addr,
168 TLSConfig: tlsCfg,
169 Handler: srv,
170 }
171 go func() {
172 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
173 log.Fatalf("serving %q: %v", listen, err)
174 }
175 }()
176 case "ws+insecure":
177 addr := u.Host
178 if _, _, err := net.SplitHostPort(addr); err != nil {
179 addr = addr + ":http"
180 }
181 httpSrv := http.Server{
182 Addr: addr,
183 Handler: srv,
184 }
185 go func() {
186 if err := httpSrv.ListenAndServe(); err != nil {
187 log.Fatalf("serving %q: %v", listen, err)
188 }
189 }()
190 case "ident":
191 if srv.Identd == nil {
192 srv.Identd = soju.NewIdentd()
193 }
194
195 host := u.Host
196 if _, _, err := net.SplitHostPort(host); err != nil {
197 host = host + ":113"
198 }
199 ln, err := net.Listen("tcp", host)
200 if err != nil {
201 log.Fatalf("failed to start listener on %q: %v", listen, err)
202 }
203 ln = proxyProtoListener(ln, srv)
204 go func() {
205 if err := srv.Identd.Serve(ln); err != nil {
206 log.Printf("serving %q: %v", listen, err)
207 }
208 }()
209 default:
210 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
211 }
212
213 log.Printf("server listening on %q", listen)
214 }
215
216 sigCh := make(chan os.Signal, 1)
217 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
218
219 if err := srv.Start(); err != nil {
220 log.Fatal(err)
221 }
222
223 for sig := range sigCh {
224 switch sig {
225 case syscall.SIGHUP:
226 if cfg.TLS != nil {
227 log.Print("reloading TLS certificate")
228 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
229 if err != nil {
230 log.Printf("failed to reload TLS certificate and key: %v", err)
231 break
232 }
233 tlsCert.Store(&cert)
234 }
235 case syscall.SIGINT, syscall.SIGTERM:
236 log.Print("shutting down server")
237 srv.Shutdown()
238 return
239 }
240 }
241}
242
243func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
244 return &proxyproto.Listener{
245 Listener: ln,
246 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
247 tcpAddr, ok := upstream.(*net.TCPAddr)
248 if !ok {
249 return proxyproto.IGNORE, nil
250 }
251 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
252 return proxyproto.USE, nil
253 }
254 return proxyproto.IGNORE, nil
255 },
256 }
257}
Note: See TracBrowser for help on using the repository browser.