source: code/trunk/cmd/soju/main.go@ 491

Last change on this file since 491 was 491, checked in by contact, 4 years ago

cmd/soju: allow specifying -listen multiple times

Closes: https://todo.sr.ht/~emersion/soju/67

File size: 6.0 KB
RevLine 
[98]1package main
2
3import (
[477]4 "context"
[98]5 "crypto/tls"
6 "flag"
[491]7 "fmt"
[98]8 "log"
9 "net"
[323]10 "net/http"
[317]11 "net/url"
[449]12 "os"
13 "os/signal"
[317]14 "strings"
[475]15 "sync/atomic"
[449]16 "syscall"
[477]17 "time"
[98]18
[418]19 "github.com/pires/go-proxyproto"
20
[98]21 "git.sr.ht/~emersion/soju"
22 "git.sr.ht/~emersion/soju/config"
23)
24
// downstreamKeepAlive is the TCP keep-alive interval for downstream TCP
// connections. It is used as the net.ListenConfig.KeepAlive value when
// creating the "ircs" and "irc+insecure" listeners.
const downstreamKeepAlive = 1 * time.Hour
27
// stringSliceFlag is a flag.Value that collects every occurrence of a
// repeatable command-line flag into a slice of strings, in the order given.
type stringSliceFlag []string

// String returns the collected values formatted like a []string, e.g.
// "[a b]". The flag package documents that String may be called on a
// zero-valued receiver such as a nil pointer, so a nil receiver is reported
// as an empty list instead of being dereferenced.
func (v *stringSliceFlag) String() string {
	if v == nil {
		return fmt.Sprint([]string(nil))
	}
	return fmt.Sprint([]string(*v))
}

// Set appends s to the list of collected values. It never fails.
func (v *stringSliceFlag) Set(s string) error {
	*v = append(*v, s)
	return nil
}
38
[98]39func main() {
[491]40 var listen []string
41 var configPath string
[98]42 var debug bool
[491]43 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
[98]44 flag.StringVar(&configPath, "config", "", "path to configuration file")
45 flag.BoolVar(&debug, "debug", false, "enable debug logging")
46 flag.Parse()
47
48 var cfg *config.Server
49 if configPath != "" {
50 var err error
51 cfg, err = config.Load(configPath)
52 if err != nil {
53 log.Fatalf("failed to load config file: %v", err)
54 }
55 } else {
56 cfg = config.Defaults()
57 }
58
[491]59 cfg.Listen = append(cfg.Listen, listen...)
[317]60 if len(cfg.Listen) == 0 {
61 cfg.Listen = []string{":6697"}
62 }
[98]63
64 db, err := soju.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68
[317]69 var tlsCfg *tls.Config
[475]70 var tlsCert atomic.Value
[98]71 if cfg.TLS != nil {
72 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
73 if err != nil {
74 log.Fatalf("failed to load TLS certificate and key: %v", err)
75 }
[476]76 tlsCert.Store(&cert)
[475]77
78 tlsCfg = &tls.Config{
79 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
80 return tlsCert.Load().(*tls.Certificate), nil
81 },
82 }
[98]83 }
84
85 srv := soju.NewServer(db)
86 // TODO: load from config/DB
87 srv.Hostname = cfg.Hostname
[178]88 srv.LogPath = cfg.LogPath
[323]89 srv.HTTPOrigins = cfg.HTTPOrigins
[417]90 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
[98]91 srv.Debug = debug
92
[317]93 for _, listen := range cfg.Listen {
94 listenURI := listen
95 if !strings.Contains(listenURI, ":/") {
96 // This is a raw domain name, make it an URL with an empty scheme
97 listenURI = "//" + listenURI
[98]98 }
[317]99 u, err := url.Parse(listenURI)
100 if err != nil {
101 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
102 }
103
104 switch u.Scheme {
105 case "ircs", "":
106 if tlsCfg == nil {
107 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
108 }
109 host := u.Host
110 if _, _, err := net.SplitHostPort(host); err != nil {
111 host = host + ":6697"
112 }
[470]113 ircsTLSCfg := tlsCfg.Clone()
114 ircsTLSCfg.NextProtos = []string{"irc"}
[477]115 lc := net.ListenConfig{
116 KeepAlive: downstreamKeepAlive,
117 }
118 l, err := lc.Listen(context.Background(), "tcp", host)
[317]119 if err != nil {
120 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
121 }
[477]122 ln := tls.NewListener(l, ircsTLSCfg)
[418]123 ln = proxyProtoListener(ln, srv)
[317]124 go func() {
[449]125 if err := srv.Serve(ln); err != nil {
126 log.Printf("serving %q: %v", listen, err)
127 }
[317]128 }()
129 case "irc+insecure":
130 host := u.Host
131 if _, _, err := net.SplitHostPort(host); err != nil {
132 host = host + ":6667"
133 }
[477]134 lc := net.ListenConfig{
135 KeepAlive: downstreamKeepAlive,
136 }
137 ln, err := lc.Listen(context.Background(), "tcp", host)
[317]138 if err != nil {
139 log.Fatalf("failed to start listener on %q: %v", listen, err)
140 }
[418]141 ln = proxyProtoListener(ln, srv)
[317]142 go func() {
[449]143 if err := srv.Serve(ln); err != nil {
144 log.Printf("serving %q: %v", listen, err)
145 }
[317]146 }()
[466]147 case "unix":
148 ln, err := net.Listen("unix", u.Path)
149 if err != nil {
150 log.Fatalf("failed to start listener on %q: %v", listen, err)
151 }
152 ln = proxyProtoListener(ln, srv)
153 go func() {
154 if err := srv.Serve(ln); err != nil {
155 log.Printf("serving %q: %v", listen, err)
156 }
157 }()
[323]158 case "wss":
159 addr := u.Host
160 if _, _, err := net.SplitHostPort(addr); err != nil {
161 addr = addr + ":https"
162 }
163 httpSrv := http.Server{
164 Addr: addr,
165 TLSConfig: tlsCfg,
166 Handler: srv,
167 }
168 go func() {
[449]169 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
170 log.Fatalf("serving %q: %v", listen, err)
171 }
[323]172 }()
173 case "ws+insecure":
174 addr := u.Host
175 if _, _, err := net.SplitHostPort(addr); err != nil {
176 addr = addr + ":http"
177 }
178 httpSrv := http.Server{
179 Addr: addr,
180 Handler: srv,
181 }
182 go func() {
[449]183 if err := httpSrv.ListenAndServe(); err != nil {
184 log.Fatalf("serving %q: %v", listen, err)
185 }
[323]186 }()
[385]187 case "ident":
188 if srv.Identd == nil {
189 srv.Identd = soju.NewIdentd()
190 }
191
192 host := u.Host
193 if _, _, err := net.SplitHostPort(host); err != nil {
194 host = host + ":113"
195 }
196 ln, err := net.Listen("tcp", host)
197 if err != nil {
198 log.Fatalf("failed to start listener on %q: %v", listen, err)
199 }
[418]200 ln = proxyProtoListener(ln, srv)
[385]201 go func() {
[449]202 if err := srv.Identd.Serve(ln); err != nil {
203 log.Printf("serving %q: %v", listen, err)
204 }
[385]205 }()
[317]206 default:
207 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
208 }
209
210 log.Printf("server listening on %q", listen)
211 }
[449]212
213 sigCh := make(chan os.Signal, 1)
[475]214 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
[449]215
216 if err := srv.Start(); err != nil {
217 log.Fatal(err)
218 }
219
[475]220 for sig := range sigCh {
221 switch sig {
222 case syscall.SIGHUP:
223 if cfg.TLS != nil {
224 log.Print("reloading TLS certificate")
225 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
226 if err != nil {
227 log.Printf("failed to reload TLS certificate and key: %v", err)
228 break
229 }
[476]230 tlsCert.Store(&cert)
[475]231 }
232 case syscall.SIGINT, syscall.SIGTERM:
233 log.Print("shutting down server")
234 srv.Shutdown()
235 return
236 }
237 }
[98]238}
[418]239
240func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
241 return &proxyproto.Listener{
242 Listener: ln,
243 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
244 tcpAddr, ok := upstream.(*net.TCPAddr)
245 if !ok {
246 return proxyproto.IGNORE, nil
247 }
248 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
249 return proxyproto.USE, nil
250 }
251 return proxyproto.IGNORE, nil
252 },
253 }
254}
Note: See TracBrowser for help on using the repository browser.