source: code/trunk/server.go@ 690

Last change on this file since 690 was 689, checked in by contact, 4 years ago

Add panic handlers for user and downstream goroutines

This only brings down a single user or downstream on panic, instead
of bringing down the whole bouncer.

Closes: https://todo.sr.ht/~emersion/soju/139

File size: 6.1 KB
RevLine 
[98]1package soju
[1]2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
22
[67]23// TODO: make configurable
[398]24var retryConnectDelay = time.Minute
[206]25var connectTimeout = 15 * time.Second
[205]26var writeTimeout = 10 * time.Second
[398]27var upstreamMessageDelay = 2 * time.Second
28var upstreamMessageBurst = 10
[675]29var backlogTimeout = 10 * time.Second
30var handleDownstreamMessageTimeout = 10 * time.Second
[670]31var chatHistoryLimit = 1000
32var backlogLimit = 4000
[67]33
// Logger is the minimal logging interface used by the server. *log.Logger
// satisfies it (NewServer installs one by default).
type Logger interface {
	// Print logs its operands.
	Print(v ...interface{})
	// Printf logs a message built from a printf-style format string.
	Printf(format string, v ...interface{})
}
38
[21]39type prefixLogger struct {
40 logger Logger
41 prefix string
42}
43
44var _ Logger = (*prefixLogger)(nil)
45
46func (l *prefixLogger) Print(v ...interface{}) {
47 v = append([]interface{}{l.prefix}, v...)
48 l.logger.Print(v...)
49}
50
51func (l *prefixLogger) Printf(format string, v ...interface{}) {
52 v = append([]interface{}{l.prefix}, v...)
53 l.logger.Printf("%v"+format, v...)
54}
55
[10]56type Server struct {
[612]57 Hostname string
[662]58 Title string
[612]59 Logger Logger
60 LogPath string
61 Debug bool
62 HTTPOrigins []string
63 AcceptProxyIPs config.IPSet
64 MaxUserNetworks int
65 Identd *Identd // can be nil
[22]66
[605]67 db Database
68 stopWG sync.WaitGroup
69 connCount int64 // atomic
[77]70
[449]71 lock sync.Mutex
72 listeners map[net.Listener]struct{}
73 users map[string]*user
[636]74
75 motd atomic.Value // string
[10]76}
77
// NewServer returns a Server backed by db. The server logs through a
// *log.Logger writing to the standard logger's output, has MaxUserNetworks
// set to -1 and starts with an empty MOTD.
func NewServer(db Database) *Server {
	s := new(Server)
	s.Logger = log.New(log.Writer(), "", log.LstdFlags)
	s.MaxUserNetworks = -1
	s.db = db
	s.listeners = make(map[net.Listener]struct{})
	s.users = make(map[string]*user)
	s.motd.Store("")
	return s
}
89
[5]90func (s *Server) prefix() *irc.Prefix {
91 return &irc.Prefix{Name: s.Hostname}
92}
93
// Start loads every user from the database and starts a bouncer goroutine
// for each one. If listing users fails, the error is returned and no user
// is started.
func (s *Server) Start() error {
	users, err := s.db.ListUsers(context.TODO())
	if err != nil {
		return err
	}

	s.lock.Lock()
	for i := 0; i < len(users); i++ {
		// Take the address of the slice element, not of a loop copy.
		s.addUserLocked(&users[i])
	}
	s.lock.Unlock()

	return nil
}
108
[449]109func (s *Server) Shutdown() {
110 s.lock.Lock()
111 for ln := range s.listeners {
112 if err := ln.Close(); err != nil {
113 s.Logger.Printf("failed to stop listener: %v", err)
114 }
115 }
116 for _, u := range s.users {
117 u.events <- eventStop{}
118 }
119 s.lock.Unlock()
120
121 s.stopWG.Wait()
[599]122
123 if err := s.db.Close(); err != nil {
124 s.Logger.Printf("failed to close DB: %v", err)
125 }
[449]126}
127
// createUser stores user in the database and starts its bouncer.
//
// It fails if a user with the same username is already running, or if
// the database write fails. The database error is wrapped with %w so
// callers can inspect it with errors.Is/errors.As; the message text is
// unchanged.
func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	// The check, the DB write and the registration all happen under
	// s.lock, so two concurrent creations of the same name cannot both
	// succeed.
	if _, ok := s.users[user.Username]; ok {
		return nil, fmt.Errorf("user %q already exists", user.Username)
	}

	if err := s.db.StoreUser(ctx, user); err != nil {
		return nil, fmt.Errorf("could not create user in db: %w", err)
	}

	return s.addUserLocked(user), nil
}
143
[563]144func (s *Server) forEachUser(f func(*user)) {
145 s.lock.Lock()
146 for _, u := range s.users {
147 f(u)
148 }
149 s.lock.Unlock()
150}
151
[38]152func (s *Server) getUser(name string) *user {
153 s.lock.Lock()
154 u := s.users[name]
155 s.lock.Unlock()
156 return u
157}
158
// addUserLocked registers user and starts its bouncer goroutine. The caller
// must hold s.lock.
//
// The goroutine is tracked by s.stopWG and removes the user from s.users
// when u.run returns. A panic in u.run is recovered and logged, so it only
// brings down this user. The cleanup lives in the same deferred function,
// so it also runs after a panic. Otherwise a panicked user would stay in
// s.users and stopWG would never reach zero, making Shutdown hang forever.
func (s *Server) addUserLocked(user *User) *user {
	s.Logger.Printf("starting bouncer for user %q", user.Username)
	u := newUser(s, user)
	s.users[u.Username] = u

	s.stopWG.Add(1)

	go func() {
		defer func() {
			if err := recover(); err != nil {
				// debug.Stack returns []byte: %s prints the trace, while %v
				// would print a list of byte values.
				s.Logger.Printf("panic serving user %q: %v\n%s", user.Username, err, debug.Stack())
			}

			s.lock.Lock()
			delete(s.users, u.Username)
			s.lock.Unlock()

			s.stopWG.Done()
		}()

		u.run()
	}()

	return u
}
184
// lastDownstreamID is the most recently allocated downstream connection ID.
// It is only modified with atomic.AddUint64, so the first ID handed out is 1.
var lastDownstreamID uint64
186
// handle serves a single downstream connection until it ends.
//
// It waits for the client to register. It then announces the connection
// to dc.user through its events channel and reads messages into that
// channel. When reading stops, the user is notified of the disconnection
// and the connection is closed.
//
// A panic is recovered and logged so that it only tears down this
// downstream. Every cleanup step is deferred so that it also runs when a
// panic unwinds the function: the disconnect notification, dc.Close, and
// the s.connCount decrement. Otherwise a panicking downstream would leak
// its connection, skew Stats, and stay attached to its user.
func (s *Server) handle(ic ircConn) {
	defer func() {
		if err := recover(); err != nil {
			// debug.Stack returns []byte: %s prints the trace, while %v
			// would print a list of byte values.
			s.Logger.Printf("panic serving downstream %q: %v\n%s", ic.RemoteAddr(), err, debug.Stack())
		}
	}()

	atomic.AddInt64(&s.connCount, 1)
	defer atomic.AddInt64(&s.connCount, -1)

	id := atomic.AddUint64(&lastDownstreamID, 1)
	dc := newDownstreamConn(s, ic, id)
	defer dc.Close()

	if err := dc.runUntilRegistered(); err != nil {
		// EOF just means the client went away before registering.
		if !errors.Is(err, io.EOF) {
			dc.logger.Print(err)
		}
		return
	}

	dc.user.events <- eventDownstreamConnected{dc}
	// Deferred last so it runs first: the user is told about the
	// disconnection before dc is closed, as in the non-panicking path.
	defer func() {
		dc.user.events <- eventDownstreamDisconnected{dc}
	}()

	if err := dc.readMessages(dc.user.events); err != nil {
		dc.logger.Print(err)
	}
}
211
// Serve accepts downstream connections on ln and handles each one in its
// own goroutine.
//
// The listener is registered so that Shutdown can close it, and the call
// is tracked by s.stopWG. Serve returns nil once isErrClosed reports that
// the listener was closed, and an error if Accept fails for any other
// reason.
func (s *Server) Serve(ln net.Listener) error {
	s.lock.Lock()
	s.listeners[ln] = struct{}{}
	s.lock.Unlock()

	s.stopWG.Add(1)
	defer func() {
		s.lock.Lock()
		delete(s.listeners, ln)
		s.lock.Unlock()

		s.stopWG.Done()
	}()

	for {
		conn, err := ln.Accept()
		switch {
		case isErrClosed(err):
			return nil
		case err != nil:
			return fmt.Errorf("failed to accept connection: %v", err)
		}

		go s.handle(newNetIRCConn(conn))
	}
}
[323]238
// ServeHTTP upgrades an HTTP request to a WebSocket connection and serves
// it as a downstream IRC connection on the calling goroutine.
//
// If the direct peer's IP is in s.AcceptProxyIPs, the client address is
// taken from the "for" parameter returned by parseForwarded. Otherwise
// req.RemoteAddr is used as is, so that clients cannot spoof their own
// address.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	opts := &websocket.AcceptOptions{
		Subprotocols:   []string{"text.ircv3.net"}, // non-compliant, fight me
		OriginPatterns: s.HTTPOrigins,
	}
	conn, err := websocket.Accept(w, req, opts)
	if err != nil {
		s.Logger.Printf("failed to serve HTTP connection: %v", err)
		return
	}

	trustedPeer := false
	if host, _, splitErr := net.SplitHostPort(req.RemoteAddr); splitErr == nil {
		ip := net.ParseIP(host)
		trustedPeer = ip != nil && s.AcceptProxyIPs.Contains(ip)
	}

	remoteAddr := req.RemoteAddr
	if trustedPeer {
		if clientAddr := parseForwarded(req.Header)["for"]; clientAddr != "" {
			remoteAddr = clientAddr
		}
	}

	s.handle(newWebsocketIRCConn(conn, remoteAddr))
}
[472]268
// parseForwarded extracts the forwarding parameters ("for", "proto",
// "host", ...) added by a reverse proxy. Keys are lower-case.
//
// The standard Forwarded header field (RFC 7239) is used when present.
// Otherwise the X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host
// fields are read; in that case the map always has the three keys "for",
// "proto" and "host", possibly with empty values.
//
// Proxies append to these lists, so a client can pre-fill them with
// arbitrary values. Only the last list element is used, because it is the
// one added by the proxy directly in front of us. All field lines are
// combined first, so an extra line added by the proxy is not ignored.
// Callers must still check that the peer is a trusted proxy before using
// the result.
//
// A nil map is returned if the Forwarded field cannot be parsed.
func parseForwarded(h http.Header) map[string]string {
	forwarded := lastHeaderListElement(h, "Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   lastHeaderListElement(h, "X-Forwarded-For"),
			"proto": lastHeaderListElement(h, "X-Forwarded-Proto"),
			"host":  lastHeaderListElement(h, "X-Forwarded-Host"),
		}
	}
	// Hack to easily parse header parameters: a Forwarded element uses the
	// same "key=value;key=value" syntax as media type parameters.
	_, params, _ := mime.ParseMediaType("hack; " + forwarded)
	return params
}

// lastHeaderListElement returns the trimmed last element of the
// comma-separated list formed by all lines of header field key.
// Commas inside double-quoted strings do not split elements.
func lastHeaderListElement(h http.Header, key string) string {
	v := strings.Join(h[http.CanonicalHeaderKey(key)], ",")
	start := 0
	quoted := false
	for i := 0; i < len(v); i++ {
		switch v[i] {
		case '\\':
			if quoted {
				i++ // skip the escaped character
			}
		case '"':
			quoted = !quoted
		case ',':
			if !quoted {
				start = i + 1
			}
		}
	}
	return strings.TrimSpace(v[start:])
}
[605]282
// ServerStats is a point-in-time snapshot of server activity, as returned
// by Server.Stats.
type ServerStats struct {
	Users       int   // number of users registered with the server
	Downstreams int64 // number of downstream connections being handled
}
287
288func (s *Server) Stats() *ServerStats {
289 var stats ServerStats
290 s.lock.Lock()
291 stats.Users = len(s.users)
292 s.lock.Unlock()
293 stats.Downstreams = atomic.LoadInt64(&s.connCount)
294 return &stats
295}
[636]296
297func (s *Server) SetMOTD(motd string) {
298 s.motd.Store(motd)
299}
300
301func (s *Server) MOTD() string {
302 return s.motd.Load().(string)
303}
Note: See TracBrowser for help on using the repository browser.