source: code/trunk/cmd/soju/main.go@ 612

Last change on this file since 612 was 612, checked in by contact, 4 years ago

Add max-user-networks config option

File size: 6.1 KB
RevLine 
[98]1package main
2
3import (
[477]4 "context"
[98]5 "crypto/tls"
6 "flag"
[491]7 "fmt"
[98]8 "log"
9 "net"
[323]10 "net/http"
[317]11 "net/url"
[449]12 "os"
13 "os/signal"
[317]14 "strings"
[475]15 "sync/atomic"
[449]16 "syscall"
[477]17 "time"
[98]18
[418]19 "github.com/pires/go-proxyproto"
20
[98]21 "git.sr.ht/~emersion/soju"
22 "git.sr.ht/~emersion/soju/config"
23)
24
// downstreamKeepAlive is the TCP keep-alive period configured on the
// listeners that accept downstream TCP connections.
const downstreamKeepAlive = time.Hour
27
// stringSliceFlag is a flag value that may be given several times on the
// command line; every occurrence is appended to the underlying slice.
type stringSliceFlag []string

// String renders the collected values the way fmt prints a []string,
// e.g. "[a b]".
func (f *stringSliceFlag) String() string {
	values := []string(*f)
	return fmt.Sprintf("%v", values)
}

// Set appends s to the collected values. It never fails.
func (f *stringSliceFlag) Set(s string) error {
	extended := append(*f, s)
	*f = extended
	return nil
}
38
[98]39func main() {
[491]40 var listen []string
41 var configPath string
[98]42 var debug bool
[491]43 flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
[98]44 flag.StringVar(&configPath, "config", "", "path to configuration file")
45 flag.BoolVar(&debug, "debug", false, "enable debug logging")
46 flag.Parse()
47
48 var cfg *config.Server
49 if configPath != "" {
50 var err error
51 cfg, err = config.Load(configPath)
52 if err != nil {
53 log.Fatalf("failed to load config file: %v", err)
54 }
55 } else {
56 cfg = config.Defaults()
57 }
58
[491]59 cfg.Listen = append(cfg.Listen, listen...)
[317]60 if len(cfg.Listen) == 0 {
61 cfg.Listen = []string{":6697"}
62 }
[98]63
[531]64 db, err := soju.OpenSqliteDB(cfg.SQLDriver, cfg.SQLSource)
[98]65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68
[317]69 var tlsCfg *tls.Config
[475]70 var tlsCert atomic.Value
[98]71 if cfg.TLS != nil {
72 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
73 if err != nil {
74 log.Fatalf("failed to load TLS certificate and key: %v", err)
75 }
[476]76 tlsCert.Store(&cert)
[475]77
78 tlsCfg = &tls.Config{
79 GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
80 return tlsCert.Load().(*tls.Certificate), nil
81 },
82 }
[98]83 }
84
85 srv := soju.NewServer(db)
86 // TODO: load from config/DB
87 srv.Hostname = cfg.Hostname
[178]88 srv.LogPath = cfg.LogPath
[323]89 srv.HTTPOrigins = cfg.HTTPOrigins
[417]90 srv.AcceptProxyIPs = cfg.AcceptProxyIPs
[612]91 srv.MaxUserNetworks = cfg.MaxUserNetworks
[98]92 srv.Debug = debug
93
[317]94 for _, listen := range cfg.Listen {
95 listenURI := listen
96 if !strings.Contains(listenURI, ":/") {
97 // This is a raw domain name, make it an URL with an empty scheme
98 listenURI = "//" + listenURI
[98]99 }
[317]100 u, err := url.Parse(listenURI)
101 if err != nil {
102 log.Fatalf("failed to parse listen URI %q: %v", listen, err)
103 }
104
105 switch u.Scheme {
106 case "ircs", "":
107 if tlsCfg == nil {
108 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
109 }
110 host := u.Host
111 if _, _, err := net.SplitHostPort(host); err != nil {
112 host = host + ":6697"
113 }
[470]114 ircsTLSCfg := tlsCfg.Clone()
115 ircsTLSCfg.NextProtos = []string{"irc"}
[477]116 lc := net.ListenConfig{
117 KeepAlive: downstreamKeepAlive,
118 }
119 l, err := lc.Listen(context.Background(), "tcp", host)
[317]120 if err != nil {
121 log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
122 }
[477]123 ln := tls.NewListener(l, ircsTLSCfg)
[418]124 ln = proxyProtoListener(ln, srv)
[317]125 go func() {
[449]126 if err := srv.Serve(ln); err != nil {
127 log.Printf("serving %q: %v", listen, err)
128 }
[317]129 }()
130 case "irc+insecure":
131 host := u.Host
132 if _, _, err := net.SplitHostPort(host); err != nil {
133 host = host + ":6667"
134 }
[477]135 lc := net.ListenConfig{
136 KeepAlive: downstreamKeepAlive,
137 }
138 ln, err := lc.Listen(context.Background(), "tcp", host)
[317]139 if err != nil {
140 log.Fatalf("failed to start listener on %q: %v", listen, err)
141 }
[418]142 ln = proxyProtoListener(ln, srv)
[317]143 go func() {
[449]144 if err := srv.Serve(ln); err != nil {
145 log.Printf("serving %q: %v", listen, err)
146 }
[317]147 }()
[466]148 case "unix":
149 ln, err := net.Listen("unix", u.Path)
150 if err != nil {
151 log.Fatalf("failed to start listener on %q: %v", listen, err)
152 }
153 ln = proxyProtoListener(ln, srv)
154 go func() {
155 if err := srv.Serve(ln); err != nil {
156 log.Printf("serving %q: %v", listen, err)
157 }
158 }()
[323]159 case "wss":
[581]160 if tlsCfg == nil {
161 log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
162 }
[323]163 addr := u.Host
164 if _, _, err := net.SplitHostPort(addr); err != nil {
165 addr = addr + ":https"
166 }
167 httpSrv := http.Server{
168 Addr: addr,
169 TLSConfig: tlsCfg,
170 Handler: srv,
171 }
172 go func() {
[449]173 if err := httpSrv.ListenAndServeTLS("", ""); err != nil {
174 log.Fatalf("serving %q: %v", listen, err)
175 }
[323]176 }()
177 case "ws+insecure":
178 addr := u.Host
179 if _, _, err := net.SplitHostPort(addr); err != nil {
180 addr = addr + ":http"
181 }
182 httpSrv := http.Server{
183 Addr: addr,
184 Handler: srv,
185 }
186 go func() {
[449]187 if err := httpSrv.ListenAndServe(); err != nil {
188 log.Fatalf("serving %q: %v", listen, err)
189 }
[323]190 }()
[385]191 case "ident":
192 if srv.Identd == nil {
193 srv.Identd = soju.NewIdentd()
194 }
195
196 host := u.Host
197 if _, _, err := net.SplitHostPort(host); err != nil {
198 host = host + ":113"
199 }
200 ln, err := net.Listen("tcp", host)
201 if err != nil {
202 log.Fatalf("failed to start listener on %q: %v", listen, err)
203 }
[418]204 ln = proxyProtoListener(ln, srv)
[385]205 go func() {
[449]206 if err := srv.Identd.Serve(ln); err != nil {
207 log.Printf("serving %q: %v", listen, err)
208 }
[385]209 }()
[317]210 default:
211 log.Fatalf("failed to listen on %q: unsupported scheme", listen)
212 }
213
214 log.Printf("server listening on %q", listen)
215 }
[449]216
217 sigCh := make(chan os.Signal, 1)
[475]218 signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
[449]219
220 if err := srv.Start(); err != nil {
221 log.Fatal(err)
222 }
223
[475]224 for sig := range sigCh {
225 switch sig {
226 case syscall.SIGHUP:
227 if cfg.TLS != nil {
228 log.Print("reloading TLS certificate")
229 cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath)
230 if err != nil {
231 log.Printf("failed to reload TLS certificate and key: %v", err)
232 break
233 }
[476]234 tlsCert.Store(&cert)
[475]235 }
236 case syscall.SIGINT, syscall.SIGTERM:
237 log.Print("shutting down server")
238 srv.Shutdown()
239 return
240 }
241 }
[98]242}
[418]243
244func proxyProtoListener(ln net.Listener, srv *soju.Server) net.Listener {
245 return &proxyproto.Listener{
246 Listener: ln,
247 Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
248 tcpAddr, ok := upstream.(*net.TCPAddr)
249 if !ok {
250 return proxyproto.IGNORE, nil
251 }
252 if srv.AcceptProxyIPs.Contains(tcpAddr.IP) {
253 return proxyproto.USE, nil
254 }
255 return proxyproto.IGNORE, nil
256 },
[592]257 ReadHeaderTimeout: 5 * time.Second,
[418]258 }
259}
Note: See TracBrowser for help on using the repository browser.