source: code/trunk/server.go@ 771

Last change on this file since 771 was 766, checked in by contact, 3 years ago

Retry on temporary net.Listener failure

Instead of stopping to listen, retry on temporary failure. This
can happen when running out of FDs.

Closes: https://todo.sr.ht/~emersion/soju/183

File size: 9.8 KB
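
The change described above is implemented by the retryListener type in the listing below. As a condensed, standalone sketch of the same pattern (the function name here is illustrative and not part of soju; it assumes the net and time imports), accepting with exponential backoff on temporary errors looks roughly like this:

// acceptWithRetry is an illustrative helper, not part of this file: it
// retries Accept with exponential backoff (capped at one second) when the
// error is temporary, e.g. when the process has run out of file descriptors.
func acceptWithRetry(ln net.Listener) (net.Conn, error) {
	var delay time.Duration
	for {
		conn, err := ln.Accept()
		ne, ok := err.(net.Error)
		if !ok || !ne.Temporary() {
			// Success or a permanent error: hand it back to the caller.
			return conn, err
		}
		if delay == 0 {
			delay = 5 * time.Millisecond
		} else {
			delay *= 2
		}
		if delay > time.Second {
			delay = time.Second
		}
		time.Sleep(delay)
	}
}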
package soju

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)

// TODO: make configurable
var retryConnectMinDelay = time.Minute
var retryConnectMaxDelay = 10 * time.Minute
var retryConnectJitter = time.Minute
var connectTimeout = 15 * time.Second
var writeTimeout = 10 * time.Second
var upstreamMessageDelay = 2 * time.Second
var upstreamMessageBurst = 10
var backlogTimeout = 10 * time.Second
var handleDownstreamMessageTimeout = 10 * time.Second
var downstreamRegisterTimeout = 30 * time.Second
var chatHistoryLimit = 1000
var backlogLimit = 4000

type Logger interface {
	Printf(format string, v ...interface{})
	Debugf(format string, v ...interface{})
}

type logger struct {
	*log.Logger
	debug bool
}

func (l logger) Debugf(format string, v ...interface{}) {
	if !l.debug {
		return
	}
	l.Logger.Printf(format, v...)
}

func NewLogger(out io.Writer, debug bool) Logger {
	return logger{
		Logger: log.New(out, "", log.LstdFlags),
		debug:  debug,
	}
}

type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

func (l *prefixLogger) Printf(format string, v ...interface{}) {
	v = append([]interface{}{l.prefix}, v...)
	l.logger.Printf("%v"+format, v...)
}

func (l *prefixLogger) Debugf(format string, v ...interface{}) {
	v = append([]interface{}{l.prefix}, v...)
	l.logger.Debugf("%v"+format, v...)
}

type int64Gauge struct {
	v int64 // atomic
}

func (g *int64Gauge) Add(delta int64) {
	atomic.AddInt64(&g.v, delta)
}

func (g *int64Gauge) Value() int64 {
	return atomic.LoadInt64(&g.v)
}

func (g *int64Gauge) Float64() float64 {
	return float64(g.Value())
}

// retryListener wraps a net.Listener and retries Accept on temporary
// errors (e.g. when the process runs out of file descriptors), backing
// off exponentially instead of giving up on the listener.
type retryListener struct {
	net.Listener
	Logger Logger

	delay time.Duration
}

func (ln *retryListener) Accept() (net.Conn, error) {
	for {
		conn, err := ln.Listener.Accept()
		if ne, ok := err.(net.Error); ok && ne.Temporary() {
			if ln.delay == 0 {
				ln.delay = 5 * time.Millisecond
			} else {
				ln.delay *= 2
			}
			if max := 1 * time.Second; ln.delay > max {
				ln.delay = max
			}
			if ln.Logger != nil {
				ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err)
			}
			time.Sleep(ln.delay)
		} else {
			ln.delay = 0
			return conn, err
		}
	}
}

type Config struct {
	Hostname        string
	Title           string
	LogPath         string
	HTTPOrigins     []string
	AcceptProxyIPs  config.IPSet
	MaxUserNetworks int
	MultiUpstream   bool
	MOTD            string
	UpstreamUserIPs []*net.IPNet
}

type Server struct {
	Logger          Logger
	Identd          *Identd               // can be nil
	MetricsRegistry prometheus.Registerer // can be nil

	config atomic.Value // *Config
	db     Database
	stopWG sync.WaitGroup

	lock      sync.Mutex
	listeners map[net.Listener]struct{}
	users     map[string]*user

	metrics struct {
		downstreams int64Gauge
		upstreams   int64Gauge

		upstreamOutMessagesTotal   prometheus.Counter
		upstreamInMessagesTotal    prometheus.Counter
		downstreamOutMessagesTotal prometheus.Counter
		downstreamInMessagesTotal  prometheus.Counter

		upstreamConnectErrorsTotal prometheus.Counter
	}
}

func NewServer(db Database) *Server {
	srv := &Server{
		Logger:    NewLogger(log.Writer(), true),
		db:        db,
		listeners: make(map[net.Listener]struct{}),
		users:     make(map[string]*user),
	}
	srv.config.Store(&Config{
		Hostname:        "localhost",
		MaxUserNetworks: -1,
		MultiUpstream:   true,
	})
	return srv
}

func (s *Server) prefix() *irc.Prefix {
	return &irc.Prefix{Name: s.Config().Hostname}
}

func (s *Server) Config() *Config {
	return s.config.Load().(*Config)
}

func (s *Server) SetConfig(cfg *Config) {
	s.config.Store(cfg)
}

func (s *Server) Start() error {
	s.registerMetrics()

	users, err := s.db.ListUsers(context.TODO())
	if err != nil {
		return err
	}

	s.lock.Lock()
	for i := range users {
		s.addUserLocked(&users[i])
	}
	s.lock.Unlock()

	return nil
}

func (s *Server) registerMetrics() {
	factory := promauto.With(s.MetricsRegistry)

	factory.NewGaugeFunc(prometheus.GaugeOpts{
		Name: "soju_users_active",
		Help: "Current number of active users",
	}, func() float64 {
		s.lock.Lock()
		n := len(s.users)
		s.lock.Unlock()
		return float64(n)
	})

	factory.NewGaugeFunc(prometheus.GaugeOpts{
		Name: "soju_downstreams_active",
		Help: "Current number of downstream connections",
	}, s.metrics.downstreams.Float64)

	factory.NewGaugeFunc(prometheus.GaugeOpts{
		Name: "soju_upstreams_active",
		Help: "Current number of upstream connections",
	}, s.metrics.upstreams.Float64)

	s.metrics.upstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "soju_upstream_out_messages_total",
		Help: "Total number of outgoing messages sent to upstream servers",
	})

	s.metrics.upstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "soju_upstream_in_messages_total",
		Help: "Total number of incoming messages received from upstream servers",
	})

	s.metrics.downstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "soju_downstream_out_messages_total",
		Help: "Total number of outgoing messages sent to downstream clients",
	})

	s.metrics.downstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "soju_downstream_in_messages_total",
		Help: "Total number of incoming messages received from downstream clients",
	})

	s.metrics.upstreamConnectErrorsTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "soju_upstream_connect_errors_total",
		Help: "Total number of upstream connection errors",
	})
}

func (s *Server) Shutdown() {
	s.lock.Lock()
	for ln := range s.listeners {
		if err := ln.Close(); err != nil {
			s.Logger.Printf("failed to stop listener: %v", err)
		}
	}
	for _, u := range s.users {
		u.events <- eventStop{}
	}
	s.lock.Unlock()

	s.stopWG.Wait()

	if err := s.db.Close(); err != nil {
		s.Logger.Printf("failed to close DB: %v", err)
	}
}

func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	if _, ok := s.users[user.Username]; ok {
		return nil, fmt.Errorf("user %q already exists", user.Username)
	}

	err := s.db.StoreUser(ctx, user)
	if err != nil {
		return nil, fmt.Errorf("could not create user in db: %v", err)
	}

	return s.addUserLocked(user), nil
}

func (s *Server) forEachUser(f func(*user)) {
	s.lock.Lock()
	for _, u := range s.users {
		f(u)
	}
	s.lock.Unlock()
}

func (s *Server) getUser(name string) *user {
	s.lock.Lock()
	u := s.users[name]
	s.lock.Unlock()
	return u
}

func (s *Server) addUserLocked(user *User) *user {
	s.Logger.Printf("starting bouncer for user %q", user.Username)
	u := newUser(s, user)
	s.users[u.Username] = u

	s.stopWG.Add(1)

	go func() {
		defer func() {
			if err := recover(); err != nil {
				s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
			}

			s.lock.Lock()
			delete(s.users, u.Username)
			s.lock.Unlock()

			s.stopWG.Done()
		}()

		u.run()
	}()

	return u
}

var lastDownstreamID uint64 = 0

func (s *Server) handle(ic ircConn) {
	defer func() {
		if err := recover(); err != nil {
			s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
		}
	}()

	s.metrics.downstreams.Add(1)
	id := atomic.AddUint64(&lastDownstreamID, 1)
	dc := newDownstreamConn(s, ic, id)
	if err := dc.runUntilRegistered(); err != nil {
		if !errors.Is(err, io.EOF) {
			dc.logger.Printf("%v", err)
		}
	} else {
		dc.user.events <- eventDownstreamConnected{dc}
		if err := dc.readMessages(dc.user.events); err != nil {
			dc.logger.Printf("%v", err)
		}
		dc.user.events <- eventDownstreamDisconnected{dc}
	}
	dc.Close()
	s.metrics.downstreams.Add(-1)
}

func (s *Server) Serve(ln net.Listener) error {
	ln = &retryListener{
		Listener: ln,
		Logger:   &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())},
	}

	s.lock.Lock()
	s.listeners[ln] = struct{}{}
	s.lock.Unlock()

	s.stopWG.Add(1)

	defer func() {
		s.lock.Lock()
		delete(s.listeners, ln)
		s.lock.Unlock()

		s.stopWG.Done()
	}()

	for {
		conn, err := ln.Accept()
		if isErrClosed(err) {
			return nil
		} else if err != nil {
			return fmt.Errorf("failed to accept connection: %v", err)
		}

		go s.handle(newNetIRCConn(conn))
	}
}

func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
		Subprotocols:   []string{"text.ircv3.net"}, // non-compliant, fight me
		OriginPatterns: s.Config().HTTPOrigins,
	})
	if err != nil {
		s.Logger.Printf("failed to serve HTTP connection: %v", err)
		return
	}

	isProxy := false
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if ip := net.ParseIP(host); ip != nil {
			isProxy = s.Config().AcceptProxyIPs.Contains(ip)
		}
	}

	// Only trust the Forwarded header field if this is a trusted proxy IP
	// to prevent users from spoofing the remote address
	remoteAddr := req.RemoteAddr
	if isProxy {
		forwarded := parseForwarded(req.Header)
		if forwarded["for"] != "" {
			remoteAddr = forwarded["for"]
		}
	}

	s.handle(newWebsocketIRCConn(conn, remoteAddr))
}

func parseForwarded(h http.Header) map[string]string {
	forwarded := h.Get("Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   h.Get("X-Forwarded-For"),
			"proto": h.Get("X-Forwarded-Proto"),
			"host":  h.Get("X-Forwarded-Host"),
		}
	}
	// Hack to easily parse header parameters
	_, params, _ := mime.ParseMediaType("hack; " + forwarded)
	return params
}

type ServerStats struct {
	Users       int
	Downstreams int64
	Upstreams   int64
}

func (s *Server) Stats() *ServerStats {
	var stats ServerStats
	s.lock.Lock()
	stats.Users = len(s.users)
	s.lock.Unlock()
	stats.Downstreams = s.metrics.downstreams.Value()
	stats.Upstreams = s.metrics.upstreams.Value()
	return &stats
}
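
For reference, a minimal sketch of how the exported API in this file might be wired up from outside the package (assuming the soju module root and net are imported). The listen address, hostname, and the way the Database value is obtained are assumptions for illustration, not taken from this file:

// runBouncer is a hypothetical caller of the Server API above; how db is
// opened is out of scope for this file.
func runBouncer(db soju.Database) error {
	srv := soju.NewServer(db)
	srv.SetConfig(&soju.Config{
		Hostname:        "irc.example.org", // assumed value
		MaxUserNetworks: -1,
		MultiUpstream:   true,
	})
	if err := srv.Start(); err != nil {
		return err
	}
	defer srv.Shutdown()

	// Plain-text IRC listener; Serve wraps it in a retryListener internally.
	ln, err := net.Listen("tcp", ":6667")
	if err != nil {
		return err
	}
	return srv.Serve(ln)
}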