source: code/trunk/server.go@ 764

Last change on this file since 764 was 756, checked in by contact, 4 years ago

server: cleanup user in defer

If a user goroutine panics, this correctly removes it from the
global map instead of leaving a dangling entry behind.

File size: 9.1 KB
RevLine 
[98]1package soju
[1]2
3import (
[652]4 "context"
[656]5 "errors"
[1]6 "fmt"
[656]7 "io"
[37]8 "log"
[472]9 "mime"
[1]10 "net"
[323]11 "net/http"
[689]12 "runtime/debug"
[24]13 "sync"
[323]14 "sync/atomic"
[67]15 "time"
[1]16
[707]17 "github.com/prometheus/client_golang/prometheus"
[708]18 "github.com/prometheus/client_golang/prometheus/promauto"
[1]19 "gopkg.in/irc.v3"
[323]20 "nhooyr.io/websocket"
[370]21
22 "git.sr.ht/~emersion/soju/config"
[1]23)
24
// Package-level tunables. They are variables rather than constants, so they
// can be reassigned at runtime.
//
// TODO: make configurable
var (
	retryConnectMinDelay = time.Minute
	retryConnectMaxDelay = 10 * time.Minute
	retryConnectJitter   = time.Minute

	connectTimeout = 15 * time.Second
	writeTimeout   = 10 * time.Second

	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10

	backlogTimeout                 = 10 * time.Second
	handleDownstreamMessageTimeout = 10 * time.Second
	downstreamRegisterTimeout      = 30 * time.Second

	chatHistoryLimit = 1000
	backlogLimit     = 4000
)
[67]38
// Logger is the logging interface used throughout the server. Printf logs
// unconditionally; Debugf is meant for verbose output that an implementation
// may discard.
type Logger interface {
	Printf(format string, v ...interface{})
	Debugf(format string, v ...interface{})
}

// logger is the default Logger implementation. Printf is provided by the
// embedded *log.Logger; Debugf is only forwarded when debug is set.
type logger struct {
	*log.Logger
	debug bool
}

// Debugf logs the formatted message through the embedded logger, or drops it
// when debug output is disabled.
func (l logger) Debugf(format string, v ...interface{}) {
	if !l.debug {
		return
	}
	l.Logger.Printf(format, v...)
}

// NewLogger returns a Logger that writes to out using the standard log flags
// (date and time). Debugf messages are emitted only when debug is true.
//
// Previously out was ignored and messages always went to the standard
// logger's writer; the supplied writer is now honored.
func NewLogger(out io.Writer, debug bool) Logger {
	return logger{
		Logger: log.New(out, "", log.LstdFlags),
		debug:  debug,
	}
}
62
[21]63type prefixLogger struct {
64 logger Logger
65 prefix string
66}
67
68var _ Logger = (*prefixLogger)(nil)
69
70func (l *prefixLogger) Printf(format string, v ...interface{}) {
71 v = append([]interface{}{l.prefix}, v...)
72 l.logger.Printf("%v"+format, v...)
73}
74
[747]75func (l *prefixLogger) Debugf(format string, v ...interface{}) {
76 v = append([]interface{}{l.prefix}, v...)
77 l.logger.Debugf("%v"+format, v...)
78}
79
// int64Gauge is an integer gauge whose value is only accessed through
// sync/atomic operations, making it usable from multiple goroutines without
// extra locking. The zero value is a gauge at 0.
type int64Gauge struct {
	v int64 // atomic
}

// Add adjusts the gauge by delta; pass a negative delta to decrease it.
func (g *int64Gauge) Add(delta int64) {
	atomic.AddInt64(&g.v, delta)
}

// Value returns the current value of the gauge.
func (g *int64Gauge) Value() int64 {
	return atomic.LoadInt64(&g.v)
}

// Float64 returns the current value converted to float64, the shape expected
// by gauge callbacks.
func (g *int64Gauge) Float64() float64 {
	current := g.Value()
	return float64(current)
}
95
[691]96type Config struct {
[612]97 Hostname string
[662]98 Title string
[612]99 LogPath string
100 HTTPOrigins []string
101 AcceptProxyIPs config.IPSet
102 MaxUserNetworks int
[694]103 MultiUpstream bool
[691]104 MOTD string
[705]105 UpstreamUserIPs []*net.IPNet
[691]106}
[22]107
[691]108type Server struct {
[707]109 Logger Logger
110 Identd *Identd // can be nil
111 MetricsRegistry prometheus.Registerer // can be nil
[691]112
[709]113 config atomic.Value // *Config
114 db Database
115 stopWG sync.WaitGroup
[77]116
[449]117 lock sync.Mutex
118 listeners map[net.Listener]struct{}
119 users map[string]*user
[709]120
121 metrics struct {
122 downstreams int64Gauge
[710]123 upstreams int64Gauge
[711]124
125 upstreamOutMessagesTotal prometheus.Counter
126 upstreamInMessagesTotal prometheus.Counter
127 downstreamOutMessagesTotal prometheus.Counter
128 downstreamInMessagesTotal prometheus.Counter
[734]129
130 upstreamConnectErrorsTotal prometheus.Counter
[709]131 }
[10]132}
133
[531]134func NewServer(db Database) *Server {
[636]135 srv := &Server{
[747]136 Logger: NewLogger(log.Writer(), true),
[691]137 db: db,
138 listeners: make(map[net.Listener]struct{}),
139 users: make(map[string]*user),
[37]140 }
[694]141 srv.config.Store(&Config{
142 Hostname: "localhost",
143 MaxUserNetworks: -1,
144 MultiUpstream: true,
145 })
[636]146 return srv
[37]147}
148
[5]149func (s *Server) prefix() *irc.Prefix {
[691]150 return &irc.Prefix{Name: s.Config().Hostname}
[5]151}
152
[691]153func (s *Server) Config() *Config {
154 return s.config.Load().(*Config)
155}
156
157func (s *Server) SetConfig(cfg *Config) {
158 s.config.Store(cfg)
159}
160
[449]161func (s *Server) Start() error {
[708]162 s.registerMetrics()
163
[652]164 users, err := s.db.ListUsers(context.TODO())
[77]165 if err != nil {
166 return err
167 }
[71]168
[77]169 s.lock.Lock()
[378]170 for i := range users {
171 s.addUserLocked(&users[i])
[71]172 }
[37]173 s.lock.Unlock()
174
[449]175 return nil
[10]176}
177
[708]178func (s *Server) registerMetrics() {
179 factory := promauto.With(s.MetricsRegistry)
180
181 factory.NewGaugeFunc(prometheus.GaugeOpts{
182 Name: "soju_users_active",
183 Help: "Current number of active users",
184 }, func() float64 {
185 s.lock.Lock()
186 n := len(s.users)
187 s.lock.Unlock()
188 return float64(n)
189 })
190
191 factory.NewGaugeFunc(prometheus.GaugeOpts{
192 Name: "soju_downstreams_active",
193 Help: "Current number of downstream connections",
[709]194 }, s.metrics.downstreams.Float64)
[710]195
196 factory.NewGaugeFunc(prometheus.GaugeOpts{
197 Name: "soju_upstreams_active",
198 Help: "Current number of upstream connections",
199 }, s.metrics.upstreams.Float64)
[711]200
201 s.metrics.upstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
202 Name: "soju_upstream_out_messages_total",
203 Help: "Total number of outgoing messages sent to upstream servers",
204 })
205
206 s.metrics.upstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
207 Name: "soju_upstream_in_messages_total",
208 Help: "Total number of incoming messages received from upstream servers",
209 })
210
211 s.metrics.downstreamOutMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
212 Name: "soju_downstream_out_messages_total",
213 Help: "Total number of outgoing messages sent to downstream clients",
214 })
215
216 s.metrics.downstreamInMessagesTotal = factory.NewCounter(prometheus.CounterOpts{
217 Name: "soju_downstream_in_messages_total",
218 Help: "Total number of incoming messages received from downstream clients",
219 })
[734]220
221 s.metrics.upstreamConnectErrorsTotal = factory.NewCounter(prometheus.CounterOpts{
222 Name: "soju_upstream_connect_errors_total",
223 Help: "Total number of upstream connection errors",
224 })
[708]225}
226
[449]227func (s *Server) Shutdown() {
228 s.lock.Lock()
229 for ln := range s.listeners {
230 if err := ln.Close(); err != nil {
231 s.Logger.Printf("failed to stop listener: %v", err)
232 }
233 }
234 for _, u := range s.users {
235 u.events <- eventStop{}
236 }
237 s.lock.Unlock()
238
239 s.stopWG.Wait()
[599]240
241 if err := s.db.Close(); err != nil {
242 s.Logger.Printf("failed to close DB: %v", err)
243 }
[449]244}
245
[680]246func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
[329]247 s.lock.Lock()
248 defer s.lock.Unlock()
249
250 if _, ok := s.users[user.Username]; ok {
251 return nil, fmt.Errorf("user %q already exists", user.Username)
252 }
253
[680]254 err := s.db.StoreUser(ctx, user)
[329]255 if err != nil {
256 return nil, fmt.Errorf("could not create user in db: %v", err)
257 }
258
[378]259 return s.addUserLocked(user), nil
[329]260}
261
[563]262func (s *Server) forEachUser(f func(*user)) {
263 s.lock.Lock()
264 for _, u := range s.users {
265 f(u)
266 }
267 s.lock.Unlock()
268}
269
[38]270func (s *Server) getUser(name string) *user {
271 s.lock.Lock()
272 u := s.users[name]
273 s.lock.Unlock()
274 return u
275}
276
[378]277func (s *Server) addUserLocked(user *User) *user {
278 s.Logger.Printf("starting bouncer for user %q", user.Username)
279 u := newUser(s, user)
280 s.users[u.Username] = u
281
[449]282 s.stopWG.Add(1)
283
[378]284 go func() {
[689]285 defer func() {
286 if err := recover(); err != nil {
287 s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
288 }
[756]289
290 s.lock.Lock()
291 delete(s.users, u.Username)
292 s.lock.Unlock()
293
294 s.stopWG.Done()
[689]295 }()
296
[378]297 u.run()
298 }()
299
300 return u
301}
302
// lastDownstreamID is the identifier handed to the most recently accepted
// downstream connection. It is only modified atomically.
var lastDownstreamID uint64
304
[347]305func (s *Server) handle(ic ircConn) {
[689]306 defer func() {
307 if err := recover(); err != nil {
308 s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
309 }
310 }()
311
[709]312 s.metrics.downstreams.Add(1)
[323]313 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]314 dc := newDownstreamConn(s, ic, id)
[323]315 if err := dc.runUntilRegistered(); err != nil {
[655]316 if !errors.Is(err, io.EOF) {
[746]317 dc.logger.Printf("%v", err)
[655]318 }
[323]319 } else {
320 dc.user.events <- eventDownstreamConnected{dc}
321 if err := dc.readMessages(dc.user.events); err != nil {
[746]322 dc.logger.Printf("%v", err)
[323]323 }
324 dc.user.events <- eventDownstreamDisconnected{dc}
325 }
326 dc.Close()
[709]327 s.metrics.downstreams.Add(-1)
[323]328}
329
[3]330func (s *Server) Serve(ln net.Listener) error {
[449]331 s.lock.Lock()
332 s.listeners[ln] = struct{}{}
333 s.lock.Unlock()
334
335 s.stopWG.Add(1)
336
337 defer func() {
338 s.lock.Lock()
339 delete(s.listeners, ln)
340 s.lock.Unlock()
341
342 s.stopWG.Done()
343 }()
344
[1]345 for {
[323]346 conn, err := ln.Accept()
[601]347 if isErrClosed(err) {
[449]348 return nil
349 } else if err != nil {
[1]350 return fmt.Errorf("failed to accept connection: %v", err)
351 }
352
[347]353 go s.handle(newNetIRCConn(conn))
[1]354 }
355}
[323]356
357func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
358 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]359 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[691]360 OriginPatterns: s.Config().HTTPOrigins,
[323]361 })
362 if err != nil {
363 s.Logger.Printf("failed to serve HTTP connection: %v", err)
364 return
365 }
[345]366
[370]367 isProxy := false
[345]368 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
369 if ip := net.ParseIP(host); ip != nil {
[691]370 isProxy = s.Config().AcceptProxyIPs.Contains(ip)
[345]371 }
372 }
373
[474]374 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]375 // to prevent users from spoofing the remote address
[344]376 remoteAddr := req.RemoteAddr
[472]377 if isProxy {
378 forwarded := parseForwarded(req.Header)
[473]379 if forwarded["for"] != "" {
380 remoteAddr = forwarded["for"]
[472]381 }
[344]382 }
[345]383
[347]384 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]385}
[472]386
// parseForwarded extracts forwarding information from request headers.
//
// When a Forwarded header is present, its parameters are returned as parsed
// by mime.ParseMediaType (keys are lower-cased; the result is nil if the
// header cannot be parsed). Otherwise the "for", "proto" and "host" keys are
// filled from X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host, with
// empty strings for absent headers.
func parseForwarded(h http.Header) map[string]string {
	if raw := h.Get("Forwarded"); raw != "" {
		// Hack to easily parse header parameters: prepend a dummy media type
		// so the media-type parser handles the "key=value;key=value" list.
		_, params, _ := mime.ParseMediaType("hack; " + raw)
		return params
	}

	return map[string]string{
		"for":   h.Get("X-Forwarded-For"),
		"proto": h.Get("X-Forwarded-Proto"),
		"host":  h.Get("X-Forwarded-Host"),
	}
}
[605]400
// ServerStats is a snapshot of the server's activity counters, as returned by
// Server.Stats.
type ServerStats struct {
	// Users is the number of entries in the server's user map.
	Users int

	// Downstreams is the value of the downstream connection gauge.
	Downstreams int64

	// Upstreams is the value of the upstream connection gauge.
	Upstreams int64
}
406
407func (s *Server) Stats() *ServerStats {
408 var stats ServerStats
409 s.lock.Lock()
410 stats.Users = len(s.users)
411 s.lock.Unlock()
[709]412 stats.Downstreams = s.metrics.downstreams.Value()
[710]413 stats.Upstreams = s.metrics.upstreams.Value()
[605]414 return &stats
415}
Note: See TracBrowser for help on using the repository browser.