source: code/trunk/cmd/soju/main.go@ 628

Last change on this file since 628 was 620, checked in by hubert, 4 years ago

PostgreSQL support

File size: 6.1 KB
Line 
1package main
2
3import (
4 "context"
5 "crypto/tls"
6 "flag"
7 "fmt"
8 "log"
9 "net"
10 "net/http"
11 "net/url"
12 "os"
13 "os/signal"
14 "strings"
15 "sync/atomic"
16 "syscall"
17 "time"
18
19 "github.com/pires/go-proxyproto"
20
21 "git.sr.ht/~emersion/soju"
22 "git.sr.ht/~emersion/soju/config"
23)
24
// downstreamKeepAlive is the TCP keep-alive period set on the
// net.ListenConfig used for the downstream IRC listeners.
const downstreamKeepAlive = time.Hour
27
// stringSliceFlag is a flag.Value that collects every occurrence of a
// repeatable command-line flag into a string slice.
type stringSliceFlag []string

// String formats the collected values the way fmt prints a []string,
// e.g. "[a b]".
func (f *stringSliceFlag) String() string {
	values := []string(*f)
	return fmt.Sprintf("%v", values)
}

// Set appends s to the collected values. It never fails.
func (f *stringSliceFlag) Set(s string) error {
	grown := append(*f, s)
	*f = grown
	return nil
}
38
39func main() {
40 var listen []string
41 var configPath string
42 var debug bool
43 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
44 flag.StringVar(&configPath, "config", "", "path to configuration file")
45 flag.BoolVar(&debug, "debug", false, "enable debug logging")
46 flag.Parse()
47
48 var cfg *config.Server
49 if configPath != "" {
50 var err error
51 cfg, err = config.Load(configPath)
52 if err != nil {
53 log.Fatalf("failed to load config file: %v", err)
54 }
55 } else {
56 cfg = config.Defaults()
57 }
58
59 cfg.Listen = append(cfg.Listen, listen...)
60 if len(cfg.Listen) == 0 {
61 cfg.Listen = []string{":6697"}
62 }
63
64 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68
69 var tlsCfg *tls.Config
70 var tlsCert atomic.Value
71 if cfg.TLS != nil {
72 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
73 if err != nil {
74 log.Fatalf("failed to load TLS certificate and key: %v", err)
75 }
76 tlsCert.Store(&cert)
77
78 tlsCfg = &tls.Config{
79 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
80 return tlsCert.Load().(*tls.Certificate), nil
81 },
82 }
83 }
84
85 srv := soju.NewServer(db)
86 // TODO: load from config/DB
87 srv.Hostname = cfg.Hostname
88 srv.LogPath = cfg.LogPath
89 srv.HTTPOrigins = cfg.HTTPOrigins
90 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
91 srv.MaxUserNetworks = cfg.MaxUserNetworks
92 srv.Debug = debug
93
94 for _, listen := range cfg.Listen {
95 listenURI := listen
96 if !strings.Contains(listenURI, ":/") {
97 // This is a raw domain name, make it an URL with an empty scheme
98 listenURI = "//" + listenURI
99 }
100 u, err := url.Parse(listenURI)
101 if err != nil {
102 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
103 }
104
105 switch u.Scheme {
106 case "ircs", "":
107 if tlsCfg == nil {
108 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
109 }
110 host := u.Host
111 if _, _, err := net.SplitHostPort(host); err != nil {
112 host = host + ":6697"
113 }
114 ircsTLSCfg := tlsCfg.Clone()
115 ircsTLSCfg.NextProtos = []string{"irc"}
116 lc := net.ListenConfig{
117 KeepAlive: downstreamKeepAlive,
118 }
119 l, err := lc.Listen(context.Background(), "tcp", host)
120 if err != nil {
121 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
122 }
123 ln := tls.NewListener(l, ircsTLSCfg)
124 ln = proxyProtoListener(ln, srv)
125 go func() {
126 if err := srv.Serve(ln); err != nil {
127 log.Printf("serving %q: %v", listen, err)
128 }
129 }()
130 case "irc+insecure":
131 host := u.Host
132 if _, _, err := net.SplitHostPort(host); err != nil {
133 host = host + ":6667"
134 }
135 lc := net.ListenConfig{
136 KeepAlive: downstreamKeepAlive,
137 }
138 ln, err := lc.Listen(context.Background(), "tcp", host)
139 if err != nil {
140 log.Fatalf("failed to start listener on %q: %v", listen, err)
141 }
142 ln = proxyProtoListener(ln, srv)
143 go func() {
144 if err := srv.Serve(ln); err != nil {
145 log.Printf("serving %q: %v", listen, err)
146 }
147 }()
148 case "unix":
149 ln, err := net.Listen("unix", u.Path)
150 if err != nil {
151 log.Fatalf("failed to start listener on %q: %v", listen, err)
152 }
153 ln = proxyProtoListener(ln, srv)
154 go func() {
155 if err := srv.Serve(ln); err != nil {
156 log.Printf("serving %q: %v", listen, err)
157 }
158 }()
159 case "wss":
160 if tlsCfg == nil {
161 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
162 }
163 addr := u.Host
164 if _, _, err := net.SplitHostPort(addr); err != nil {
165 addr = addr + ":https"
166 }
167 httpSrv := http.Server{
168 Addr: addr,
169 TLSConfig: tlsCfg,
170 Handler: srv,
171 }
172 go func() {
173 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
174 log.Fatalf("serving %q: %v", listen, err)
175 }
176 }()
177 case "ws+insecure":
178 addr := u.Host
179 if _, _, err := net.SplitHostPort(addr); err != nil {
180 addr = addr + ":http"
181 }
182 httpSrv := http.Server{
183 Addr: addr,
184 Handler: srv,
185 }
186 go func() {
187 if err := httpSrv.ListenAndServe(); err != nil {
188 log.Fatalf("serving %q: %v", listen, err)
189 }
190 }()
191 case "ident":
192 if srv.Identd == nil {
193 srv.Identd = soju.NewIdentd()
194 }
195
196 host := u.Host
197 if _, _, err := net.SplitHostPort(host); err != nil {
198 host = host + ":113"
199 }
200 ln, err := net.Listen("tcp", host)
201 if err != nil {
202 log.Fatalf("failed to start listener on %q: %v", listen, err)
203 }
204 ln = proxyProtoListener(ln, srv)
205 go func() {
206 if err := srv.Identd.Serve(ln); err != nil {
207 log.Printf("serving %q: %v", listen, err)
208 }
209 }()
210 default:
211 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
212 }
213
214 log.Printf("server listening on %q", listen)
215 }
216
217 sigCh := make(chan os.Signal, 1)
218 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
219
220 if err := srv.Start(); err != nil {
221 log.Fatal(err)
222 }
223
224 for sig := range sigCh {
225 switch sig {
226 case syscall.SIGHUP:
227 if cfg.TLS != nil {
228 log.Print("reloading TLS certificate")
229 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
230 if err != nil {
231 log.Printf("failed to reload TLS certificate and key: %v", err)
232 break
233 }
234 tlsCert.Store(&cert)
235 }
236 case syscall.SIGINT, syscall.SIGTERM:
237 log.Print("shutting down server")
238 srv.Shutdown()
239 return
240 }
241 }
242}
243
244func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
245 return &proxyproto.Listener{
246 Listener: ln,
247 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
248 tcpAddr, ok := upstream.(*net.TCPAddr)
249 if !ok {
250 return proxyproto.IGNORE, nil
251 }
252 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
253 return proxyproto.USE, nil
254 }
255 return proxyproto.IGNORE, nil
256 },
257 ReadHeaderTimeout: 5 * time.Second,
258 }
259}
Note: See TracBrowser for help on using the repository browser.