source: code/trunk/cmd/soju/main.go@ 531

Last change on this file since 531 was 531, checked in by sir, 4 years ago

db: refactor into interface

This refactors the SQLite-specific bits into db_sqlite.go. A future
patch will add PostgreSQL support.

File size: 6.0 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "log"
9 "net"
10 "net/http"
11 "net/url"
12 "os"
13 "os/signal"
14 "strings"
15 "sync/atomic"
16 "syscall"
17 "time"
18
19 "github.com/pires/go-proxyproto"
20
21 "git.sr.ht/~emersion/soju"
22 "git.sr.ht/~emersion/soju/config"
23)
24
// downstreamKeepAlive is the TCP keep-alive period configured on the
// net.ListenConfig used for downstream TCP listeners.
const downstreamKeepAlive = time.Hour
27
// stringSliceFlag is a flag.Value for a repeatable command-line flag: every
// occurrence of the flag is appended, in order, to the slice.
type stringSliceFlag []string

// String renders the collected values with fmt's default slice formatting,
// e.g. "[a b]".
func (f *stringSliceFlag) String() string {
	values := []string(*f)
	return fmt.Sprint(values)
}

// Set records one occurrence of the flag. It never returns an error.
func (f *stringSliceFlag) Set(value string) error {
	updated := append(*f, value)
	*f = updated
	return nil
}
38
39func main() {
40 var listen []string
41 var configPath string
42 var debug bool
43 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
44 flag.StringVar(&configPath, "config", "", "path to configuration file")
45 flag.BoolVar(&debug, "debug", false, "enable debug logging")
46 flag.Parse()
47
48 var cfg *config.Server
49 if configPath != "" {
50 var err error
51 cfg, err = config.Load(configPath)
52 if err != nil {
53 log.Fatalf("failed to load config file: %v", err)
54 }
55 } else {
56 cfg = config.Defaults()
57 }
58
59 cfg.Listen = append(cfg.Listen, listen...)
60 if len(cfg.Listen) == 0 {
61 cfg.Listen = []string{":6697"}
62 }
63
64 db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource)
65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68
69 var tlsCfg *tls.Config
70 var tlsCert atomic.Value
71 if cfg.TLS != nil {
72 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
73 if err != nil {
74 log.Fatalf("failed to load TLS certificate and key: %v", err)
75 }
76 tlsCert.Store(&cert)
77
78 tlsCfg = &tls.Config{
79 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
80 return tlsCert.Load().(*tls.Certificate), nil
81 },
82 }
83 }
84
85 srv := soju.NewServer(db)
86 // TODO: load from config/DB
87 srv.Hostname = cfg.Hostname
88 srv.LogPath = cfg.LogPath
89 srv.HTTPOrigins = cfg.HTTPOrigins
90 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
91 srv.Debug = debug
92
93 for _, listen := range cfg.Listen {
94 listenURI := listen
95 if !strings.Contains(listenURI, ":/") {
96 // This is a raw domain name, make it an URL with an empty scheme
97 listenURI = "//" + listenURI
98 }
99 u, err := url.Parse(listenURI)
100 if err != nil {
101 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
102 }
103
104 switch u.Scheme {
105 case "ircs", "":
106 if tlsCfg == nil {
107 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
108 }
109 host := u.Host
110 if _, _, err := net.SplitHostPort(host); err != nil {
111 host = host + ":6697"
112 }
113 ircsTLSCfg := tlsCfg.Clone()
114 ircsTLSCfg.NextProtos = []string{"irc"}
115 lc := net.ListenConfig{
116 KeepAlive: downstreamKeepAlive,
117 }
118 l, err := lc.Listen(context.Background(), "tcp", host)
119 if err != nil {
120 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
121 }
122 ln := tls.NewListener(l, ircsTLSCfg)
123 ln = proxyProtoListener(ln, srv)
124 go func() {
125 if err := srv.Serve(ln); err != nil {
126 log.Printf("serving %q: %v", listen, err)
127 }
128 }()
129 case "irc+insecure":
130 host := u.Host
131 if _, _, err := net.SplitHostPort(host); err != nil {
132 host = host + ":6667"
133 }
134 lc := net.ListenConfig{
135 KeepAlive: downstreamKeepAlive,
136 }
137 ln, err := lc.Listen(context.Background(), "tcp", host)
138 if err != nil {
139 log.Fatalf("failed to start listener on %q: %v", listen, err)
140 }
141 ln = proxyProtoListener(ln, srv)
142 go func() {
143 if err := srv.Serve(ln); err != nil {
144 log.Printf("serving %q: %v", listen, err)
145 }
146 }()
147 case "unix":
148 ln, err := net.Listen("unix", u.Path)
149 if err != nil {
150 log.Fatalf("failed to start listener on %q: %v", listen, err)
151 }
152 ln = proxyProtoListener(ln, srv)
153 go func() {
154 if err := srv.Serve(ln); err != nil {
155 log.Printf("serving %q: %v", listen, err)
156 }
157 }()
158 case "wss":
159 addr := u.Host
160 if _, _, err := net.SplitHostPort(addr); err != nil {
161 addr = addr + ":https"
162 }
163 httpSrv := http.Server{
164 Addr: addr,
165 TLSConfig: tlsCfg,
166 Handler: srv,
167 }
168 go func() {
169 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
170 log.Fatalf("serving %q: %v", listen, err)
171 }
172 }()
173 case "ws+insecure":
174 addr := u.Host
175 if _, _, err := net.SplitHostPort(addr); err != nil {
176 addr = addr + ":http"
177 }
178 httpSrv := http.Server{
179 Addr: addr,
180 Handler: srv,
181 }
182 go func() {
183 if err := httpSrv.ListenAndServe(); err != nil {
184 log.Fatalf("serving %q: %v", listen, err)
185 }
186 }()
187 case "ident":
188 if srv.Identd == nil {
189 srv.Identd = soju.NewIdentd()
190 }
191
192 host := u.Host
193 if _, _, err := net.SplitHostPort(host); err != nil {
194 host = host + ":113"
195 }
196 ln, err := net.Listen("tcp", host)
197 if err != nil {
198 log.Fatalf("failed to start listener on %q: %v", listen, err)
199 }
200 ln = proxyProtoListener(ln, srv)
201 go func() {
202 if err := srv.Identd.Serve(ln); err != nil {
203 log.Printf("serving %q: %v", listen, err)
204 }
205 }()
206 default:
207 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
208 }
209
210 log.Printf("server listening on %q", listen)
211 }
212
213 sigCh := make(chan os.Signal, 1)
214 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
215
216 if err := srv.Start(); err != nil {
217 log.Fatal(err)
218 }
219
220 for sig := range sigCh {
221 switch sig {
222 case syscall.SIGHUP:
223 if cfg.TLS != nil {
224 log.Print("reloading TLS certificate")
225 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
226 if err != nil {
227 log.Printf("failed to reload TLS certificate and key: %v", err)
228 break
229 }
230 tlsCert.Store(&cert)
231 }
232 case syscall.SIGINT, syscall.SIGTERM:
233 log.Print("shutting down server")
234 srv.Shutdown()
235 return
236 }
237 }
238}
239
240func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
241 return &proxyproto.Listener{
242 Listener: ln,
243 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
244 tcpAddr, ok := upstream.(*net.TCPAddr)
245 if !ok {
246 return proxyproto.IGNORE, nil
247 }
248 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
249 return proxyproto.USE, nil
250 }
251 return proxyproto.IGNORE, nil
252 },
253 }
254}
Note: See TracBrowser for help on using the repository browser.