source: code/trunk/server.go@ 743

Last change on this file since 743 was 735, checked in by contact, 4 years ago

Add exponential backoff when re-connecting to upstream

The first reconnection attempt waits for 1min, the second for 2min,
and so on up to 10min. There's a 1min jitter so that multiple failed
connections don't try to reconnect at the exact same time.

Closes: https://todo.sr.ht/~emersion/soju/161

File size: 8.8 KB
RevLine 
[98]1package soju
[1]2
3import (
[652]4 "context"
[656]5 "errors"
[1]6 "fmt"
[656]7 "io"
[37]8 "log"
[472]9 "mime"
[1]10 "net"
[323]11 "net/http"
[689]12 "runtime/debug"
[24]13 "sync"
[323]14 "sync/atomic"
[67]15 "time"
[1]16
[707]17 "github.com/prometheus/client_golang/prometheus"
[708]18 "github.com/prometheus/client_golang/prometheus/promauto"
[1]19 "gopkg.in/irc.v3"
[323]20 "nhooyr.io/websocket"
[370]21
22 "git.sr.ht/~emersion/soju/config"
[1]23)
24
// Package-wide delays, timeouts and limits.
//
// TODO: make configurable
var (
	// Back-off used when re-connecting to an upstream: the delay starts at
	// the minimum, grows up to the maximum, and a random jitter keeps
	// several failed connections from retrying at the same instant.
	retryConnectMinDelay = time.Minute
	retryConnectMaxDelay = 10 * time.Minute
	retryConnectJitter   = time.Minute

	connectTimeout = 15 * time.Second
	writeTimeout   = 10 * time.Second

	// NOTE(review): presumably the rate-limit parameters for messages sent
	// upstream; confirm against the upstream connection code.
	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10

	backlogTimeout                 = 10 * time.Second
	handleDownstreamMessageTimeout = 10 * time.Second
	downstreamRegisterTimeout      = 30 * time.Second

	chatHistoryLimit = 1000
	backlogLimit     = 4000
)
[67]38
// Logger is the logging interface the server writes to. It is the subset of
// *log.Logger that the code in this file needs.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger is a Logger decorator: every message is handed to the wrapped
// logger with a fixed prefix placed in front of it.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the wrapped logger with the prefix as first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	l.logger.Print(l.withPrefix(v)...)
}

// Printf forwards to the wrapped logger. The prefix is rendered through an
// extra leading "%v" verb rather than being spliced into the format string,
// so a prefix containing '%' cannot corrupt the format.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	l.logger.Printf("%v"+format, l.withPrefix(v)...)
}

// withPrefix returns a new slice holding the prefix followed by args. The
// caller's slice is never modified.
func (l *prefixLogger) withPrefix(args []interface{}) []interface{} {
	out := make([]interface{}, 0, len(args)+1)
	out = append(out, l.prefix)
	return append(out, args...)
}
60
// int64Gauge is an integer gauge whose value is only ever read and written
// through sync/atomic. The zero value is a gauge at 0.
type int64Gauge struct {
	v int64 // atomic
}

// Add atomically adds delta (which may be negative) to the gauge.
func (g *int64Gauge) Add(delta int64) {
	atomic.AddInt64(&g.v, delta)
}

// Value atomically reads the current value of the gauge.
func (g *int64Gauge) Value() int64 {
	return atomic.LoadInt64(&g.v)
}

// Float64 returns the current value converted to float64, which is the
// signature expected by prometheus gauge callbacks.
func (g *int64Gauge) Float64() float64 {
	current := g.Value()
	return float64(current)
}
76
[691]77type Config struct {
[612]78 Hostname string
[662]79 Title string
[612]80 LogPath string
81 Debug bool
82 HTTPOrigins []string
83 AcceptProxyIPs config.IPSet
84 MaxUserNetworks int
[694]85 MultiUpstream bool
[691]86 MOTD string
[705]87 UpstreamUserIPs []*net.IPNet
[691]88}
[22]89
[691]90type Server struct {
[707]91 Logger Logger
92 Identd *Identd // can be nil
93 MetricsRegistry prometheus.Registerer // can be nil
[691]94
[709]95 config atomic.Value // *Config
96 db Database
97 stopWG sync.WaitGroup
[77]98
[449]99 lock sync.Mutex
100 listeners map[net.Listener]struct{}
101 users map[string]*user
[709]102
103 metrics struct {
104 downstreams int64Gauge
[710]105 upstreams int64Gauge
[711]106
107 upstreamOutMessagesTotal prometheus.Counter
108 upstreamInMessagesTotal prometheus.Counter
109 downstreamOutMessagesTotal prometheus.Counter
110 downstreamInMessagesTotal prometheus.Counter
[734]111
112 upstreamConnectErrorsTotal prometheus.Counter
[709]113 }
[10]114}
115
[531]116func NewServer(db Database) *Server {
[636]117 srv := &Server{
[691]118 Logger: log.New(log.Writer(), "", log.LstdFlags),
119 db: db,
120 listeners: make(map[net.Listener]struct{}),
121 users: make(map[string]*user),
[37]122 }
[694]123 srv.config.Store(&Config{
124 Hostname: "localhost",
125 MaxUserNetworks: -1,
126 MultiUpstream: true,
127 })
[636]128 return srv
[37]129}
130
[5]131func (s *Server) prefix() *irc.Prefix {
[691]132 return &irc.Prefix{Name: s.Config().Hostname}
[5]133}
134
[691]135func (s *Server) Config() *Config {
136 return s.config.Load().(*Config)
137}
138
139func (s *Server) SetConfig(cfg *Config) {
140 s.config.Store(cfg)
141}
142
[449]143func (s *Server) Start() error {
[708]144 s.registerMetrics()
145
[652]146 users, err := s.db.ListUsers(context.TODO())
[77]147 if err != nil {
148 return err
149 }
[71]150
[77]151 s.lock.Lock()
[378]152 for i := range users {
153 s.addUserLocked(&users[i])
[71]154 }
[37]155 s.lock.Unlock()
156
[449]157 return nil
[10]158}
159
[708]160func (s *Server) registerMetrics() {
161 factory := promauto.With(s.MetricsRegistry)
162
163 factory.NewGaugeFunc(prometheus.GaugeOpts{
164 Name: "soju_users_active",
165 Help: "Current number of active users",
166 }, func() float64 {
167 s.lock.Lock()
168 n := len(s.users)
169 s.lock.Unlock()
170 return float64(n)
171 })
172
173 factory.NewGaugeFunc(prometheus.GaugeOpts{
174 Name: "soju_downstreams_active",
175 Help: "Current number of downstream connections",
[709]176 }, s.metrics.downstreams.Float64)
[710]177
178 factory.NewGaugeFunc(prometheus.GaugeOpts{
179 Name: "soju_upstreams_active",
180 Help: "Current number of upstream connections",
181 }, s.metrics.upstreams.Float64)
[711]182
183 s.metrics.upstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
184 Name: "soju_upstream_out_messages_total",
185 Help: "Total number of outgoing messages sent to upstream servers",
186 })
187
188 s.metrics.upstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
189 Name: "soju_upstream_in_messages_total",
190 Help: "Total number of incoming messages received from upstream servers",
191 })
192
193 s.metrics.downstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
194 Name: "soju_downstream_out_messages_total",
195 Help: "Total number of outgoing messages sent to downstream clients",
196 })
197
198 s.metrics.downstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
199 Name: "soju_downstream_in_messages_total",
200 Help: "Total number of incoming messages received from downstream clients",
201 })
[734]202
203 s.metrics.upstreamConnectErrorsTotal = factory.NewCounter(prometheus.CounterOpts{
204 Name: "soju_upstream_connect_errors_total",
205 Help: "Total number of upstream connection errors",
206 })
[708]207}
208
[449]209func (s *Server) Shutdown() {
210 s.lock.Lock()
211 for ln := range s.listeners {
212 if err := ln.Close(); err != nil {
213 s.Logger.Printf("failed to stop listener: %v", err)
214 }
215 }
216 for _, u := range s.users {
217 u.events <- eventStop{}
218 }
219 s.lock.Unlock()
220
221 s.stopWG.Wait()
[599]222
223 if err := s.db.Close(); err != nil {
224 s.Logger.Printf("failed to close DB: %v", err)
225 }
[449]226}
227
[680]228func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
[329]229 s.lock.Lock()
230 defer s.lock.Unlock()
231
232 if _, ok := s.users[user.Username]; ok {
233 return nil, fmt.Errorf("user %q already exists", user.Username)
234 }
235
[680]236 err := s.db.StoreUser(ctx, user)
[329]237 if err != nil {
238 return nil, fmt.Errorf("could not create user in db: %v", err)
239 }
240
[378]241 return s.addUserLocked(user), nil
[329]242}
243
[563]244func (s *Server) forEachUser(f func(*user)) {
245 s.lock.Lock()
246 for _, u := range s.users {
247 f(u)
248 }
249 s.lock.Unlock()
250}
251
[38]252func (s *Server) getUser(name string) *user {
253 s.lock.Lock()
254 u := s.users[name]
255 s.lock.Unlock()
256 return u
257}
258
[378]259func (s *Server) addUserLocked(user *User) *user {
260 s.Logger.Printf("starting bouncer for user %q", user.Username)
261 u := newUser(s, user)
262 s.users[u.Username] = u
263
[449]264 s.stopWG.Add(1)
265
[378]266 go func() {
[689]267 defer func() {
268 if err := recover(); err != nil {
269 s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
270 }
271 }()
272
[378]273 u.run()
274
275 s.lock.Lock()
276 delete(s.users, u.Username)
277 s.lock.Unlock()
[449]278
279 s.stopWG.Done()
[378]280 }()
281
282 return u
283}
284
// lastDownstreamID is the identifier handed to the most recently accepted
// downstream connection. It is only updated with atomic operations.
var lastDownstreamID uint64
286
[347]287func (s *Server) handle(ic ircConn) {
[689]288 defer func() {
289 if err := recover(); err != nil {
290 s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
291 }
292 }()
293
[709]294 s.metrics.downstreams.Add(1)
[323]295 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]296 dc := newDownstreamConn(s, ic, id)
[323]297 if err := dc.runUntilRegistered(); err != nil {
[655]298 if !errors.Is(err, io.EOF) {
299 dc.logger.Print(err)
300 }
[323]301 } else {
302 dc.user.events <- eventDownstreamConnected{dc}
303 if err := dc.readMessages(dc.user.events); err != nil {
304 dc.logger.Print(err)
305 }
306 dc.user.events <- eventDownstreamDisconnected{dc}
307 }
308 dc.Close()
[709]309 s.metrics.downstreams.Add(-1)
[323]310}
311
[3]312func (s *Server) Serve(ln net.Listener) error {
[449]313 s.lock.Lock()
314 s.listeners[ln] = struct{}{}
315 s.lock.Unlock()
316
317 s.stopWG.Add(1)
318
319 defer func() {
320 s.lock.Lock()
321 delete(s.listeners, ln)
322 s.lock.Unlock()
323
324 s.stopWG.Done()
325 }()
326
[1]327 for {
[323]328 conn, err := ln.Accept()
[601]329 if isErrClosed(err) {
[449]330 return nil
331 } else if err != nil {
[1]332 return fmt.Errorf("failed to accept connection: %v", err)
333 }
334
[347]335 go s.handle(newNetIRCConn(conn))
[1]336 }
337}
[323]338
339func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
340 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]341 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[691]342 OriginPatterns: s.Config().HTTPOrigins,
[323]343 })
344 if err != nil {
345 s.Logger.Printf("failed to serve HTTP connection: %v", err)
346 return
347 }
[345]348
[370]349 isProxy := false
[345]350 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
351 if ip := net.ParseIP(host); ip != nil {
[691]352 isProxy = s.Config().AcceptProxyIPs.Contains(ip)
[345]353 }
354 }
355
[474]356 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]357 // to prevent users from spoofing the remote address
[344]358 remoteAddr := req.RemoteAddr
[472]359 if isProxy {
360 forwarded := parseForwarded(req.Header)
[473]361 if forwarded["for"] != "" {
362 remoteAddr = forwarded["for"]
[472]363 }
[344]364 }
[345]365
[347]366 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]367}
[472]368
// parseForwarded extracts proxy forwarding information from request headers.
//
// When a Forwarded header is present, its parameters are returned as a map
// keyed by parameter name. Otherwise the map holds the "for", "proto" and
// "host" keys, filled from the X-Forwarded-For, X-Forwarded-Proto and
// X-Forwarded-Host headers (empty strings when absent).
func parseForwarded(h http.Header) map[string]string {
	if forwarded := h.Get("Forwarded"); forwarded != "" {
		// Hack to easily parse header parameters: prepend a dummy media type
		// so the value can go through the MIME parameter parser. The error
		// is deliberately dropped; on failure the map is nil, which callers
		// can still index safely.
		_, params, _ := mime.ParseMediaType("hack; " + forwarded)
		return params
	}

	return map[string]string{
		"for":   h.Get("X-Forwarded-For"),
		"proto": h.Get("X-Forwarded-Proto"),
		"host":  h.Get("X-Forwarded-Host"),
	}
}
[605]382
// ServerStats is a snapshot of the server's activity counters, as returned by
// Server.Stats.
type ServerStats struct {
	// Users is the number of users currently running.
	Users int
	// Downstreams is the current number of downstream connections.
	Downstreams int64
	// Upstreams is the current number of upstream connections.
	Upstreams int64
}
388
389func (s *Server) Stats() *ServerStats {
390 var stats ServerStats
391 s.lock.Lock()
392 stats.Users = len(s.users)
393 s.lock.Unlock()
[709]394 stats.Downstreams = s.metrics.downstreams.Value()
[710]395 stats.Upstreams = s.metrics.upstreams.Value()
[605]396 return &stats
397}
Note: See TracBrowser for help on using the repository browser.