source: code/trunk/server.go@ 704

Last change on this file since 704 was 704, checked in by contact, 4 years ago

Add timeout for downstream connection registration

File size: 6.3 KB
RevLine 
[98]1package soju
[1]2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
22
// Tunable timeouts, rate limits and history sizes used across the bouncer.
//
// TODO: make configurable
var (
	retryConnectDelay              = time.Minute
	connectTimeout                 = 15 * time.Second
	writeTimeout                   = 10 * time.Second
	upstreamMessageDelay           = 2 * time.Second
	upstreamMessageBurst           = 10
	backlogTimeout                 = 10 * time.Second
	handleDownstreamMessageTimeout = 10 * time.Second
	downstreamRegisterTimeout      = 30 * time.Second
	chatHistoryLimit               = 1000
	backlogLimit                   = 4000
)
[67]34
// Logger is the minimal logging interface the server writes to. It is
// satisfied by *log.Logger.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger wraps another Logger and prepends a fixed prefix to every
// message it emits.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print logs v through the wrapped Logger with the prefix inserted as the
// first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(args, l.prefix)
	args = append(args, v...)
	l.logger.Print(args...)
}

// Printf formats v according to format and logs the result, preceded by
// the prefix.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(args, l.prefix)
	args = append(args, v...)
	// The prefix is passed as an argument rather than spliced into format,
	// so any '%' characters it contains are printed literally.
	l.logger.Printf("%v"+format, args...)
}
56
[691]57type Config struct {
[612]58 Hostname string
[662]59 Title string
[612]60 LogPath string
61 Debug bool
62 HTTPOrigins []string
63 AcceptProxyIPs config.IPSet
64 MaxUserNetworks int
[694]65 MultiUpstream bool
[691]66 MOTD string
67}
[22]68
[691]69type Server struct {
70 Logger Logger
71 Identd *Identd // can be nil
72
73 config atomic.Value // *Config
[605]74 db Database
75 stopWG sync.WaitGroup
76 connCount int64 // atomic
[77]77
[449]78 lock sync.Mutex
79 listeners map[net.Listener]struct{}
80 users map[string]*user
[10]81}
82
[531]83func NewServer(db Database) *Server {
[636]84 srv := &Server{
[691]85 Logger: log.New(log.Writer(), "", log.LstdFlags),
86 db: db,
87 listeners: make(map[net.Listener]struct{}),
88 users: make(map[string]*user),
[37]89 }
[694]90 srv.config.Store(&Config{
91 Hostname: "localhost",
92 MaxUserNetworks: -1,
93 MultiUpstream: true,
94 })
[636]95 return srv
[37]96}
97
[5]98func (s *Server) prefix() *irc.Prefix {
[691]99 return &irc.Prefix{Name: s.Config().Hostname}
[5]100}
101
[691]102func (s *Server) Config() *Config {
103 return s.config.Load().(*Config)
104}
105
106func (s *Server) SetConfig(cfg *Config) {
107 s.config.Store(cfg)
108}
109
[449]110func (s *Server) Start() error {
[652]111 users, err := s.db.ListUsers(context.TODO())
[77]112 if err != nil {
113 return err
114 }
[71]115
[77]116 s.lock.Lock()
[378]117 for i := range users {
118 s.addUserLocked(&users[i])
[71]119 }
[37]120 s.lock.Unlock()
121
[449]122 return nil
[10]123}
124
[449]125func (s *Server) Shutdown() {
126 s.lock.Lock()
127 for ln := range s.listeners {
128 if err := ln.Close(); err != nil {
129 s.Logger.Printf("failed to stop listener: %v", err)
130 }
131 }
132 for _, u := range s.users {
133 u.events <- eventStop{}
134 }
135 s.lock.Unlock()
136
137 s.stopWG.Wait()
[599]138
139 if err := s.db.Close(); err != nil {
140 s.Logger.Printf("failed to close DB: %v", err)
141 }
[449]142}
143
[680]144func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
[329]145 s.lock.Lock()
146 defer s.lock.Unlock()
147
148 if _, ok := s.users[user.Username]; ok {
149 return nil, fmt.Errorf("user %q already exists", user.Username)
150 }
151
[680]152 err := s.db.StoreUser(ctx, user)
[329]153 if err != nil {
154 return nil, fmt.Errorf("could not create user in db: %v", err)
155 }
156
[378]157 return s.addUserLocked(user), nil
[329]158}
159
[563]160func (s *Server) forEachUser(f func(*user)) {
161 s.lock.Lock()
162 for _, u := range s.users {
163 f(u)
164 }
165 s.lock.Unlock()
166}
167
[38]168func (s *Server) getUser(name string) *user {
169 s.lock.Lock()
170 u := s.users[name]
171 s.lock.Unlock()
172 return u
173}
174
[378]175func (s *Server) addUserLocked(user *User) *user {
176 s.Logger.Printf("starting bouncer for user %q", user.Username)
177 u := newUser(s, user)
178 s.users[u.Username] = u
179
[449]180 s.stopWG.Add(1)
181
[378]182 go func() {
[689]183 defer func() {
184 if err := recover(); err != nil {
185 s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
186 }
187 }()
188
[378]189 u.run()
190
191 s.lock.Lock()
192 delete(s.users, u.Username)
193 s.lock.Unlock()
[449]194
195 s.stopWG.Done()
[378]196 }()
197
198 return u
199}
200
// lastDownstreamID is the most recently assigned downstream connection ID.
// It is only accessed through sync/atomic.
var lastDownstreamID uint64
202
[347]203func (s *Server) handle(ic ircConn) {
[689]204 defer func() {
205 if err := recover(); err != nil {
206 s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
207 }
208 }()
209
[605]210 atomic.AddInt64(&s.connCount, 1)
[323]211 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]212 dc := newDownstreamConn(s, ic, id)
[323]213 if err := dc.runUntilRegistered(); err != nil {
[655]214 if !errors.Is(err, io.EOF) {
215 dc.logger.Print(err)
216 }
[323]217 } else {
218 dc.user.events <- eventDownstreamConnected{dc}
219 if err := dc.readMessages(dc.user.events); err != nil {
220 dc.logger.Print(err)
221 }
222 dc.user.events <- eventDownstreamDisconnected{dc}
223 }
224 dc.Close()
[605]225 atomic.AddInt64(&s.connCount, -1)
[323]226}
227
[3]228func (s *Server) Serve(ln net.Listener) error {
[449]229 s.lock.Lock()
230 s.listeners[ln] = struct{}{}
231 s.lock.Unlock()
232
233 s.stopWG.Add(1)
234
235 defer func() {
236 s.lock.Lock()
237 delete(s.listeners, ln)
238 s.lock.Unlock()
239
240 s.stopWG.Done()
241 }()
242
[1]243 for {
[323]244 conn, err := ln.Accept()
[601]245 if isErrClosed(err) {
[449]246 return nil
247 } else if err != nil {
[1]248 return fmt.Errorf("failed to accept connection: %v", err)
249 }
250
[347]251 go s.handle(newNetIRCConn(conn))
[1]252 }
253}
[323]254
255func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
256 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]257 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[691]258 OriginPatterns: s.Config().HTTPOrigins,
[323]259 })
260 if err != nil {
261 s.Logger.Printf("failed to serve HTTP connection: %v", err)
262 return
263 }
[345]264
[370]265 isProxy := false
[345]266 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
267 if ip := net.ParseIP(host); ip != nil {
[691]268 isProxy = s.Config().AcceptProxyIPs.Contains(ip)
[345]269 }
270 }
271
[474]272 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]273 // to prevent users from spoofing the remote address
[344]274 remoteAddr := req.RemoteAddr
[472]275 if isProxy {
276 forwarded := parseForwarded(req.Header)
[473]277 if forwarded["for"] != "" {
278 remoteAddr = forwarded["for"]
[472]279 }
[344]280 }
[345]281
[347]282 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]283}
[472]284
// parseForwarded extracts the forwarding parameters ("for", "proto",
// "host", ...) that a reverse proxy attached to a request.
//
// Header selection:
//   - The standard Forwarded header (RFC 7239) is preferred.
//   - If it is absent, the X-Forwarded-For/-Proto/-Host headers are used.
//
// Both kinds may carry a comma-separated list of hops, with each proxy
// appending its own. Only the LAST hop is used, because it was added by
// the proxy nearest to us (the one the caller decided to trust). Earlier
// entries are copied from the incoming request and may have been spoofed
// by the client. If a header appears more than once, its last occurrence
// is used for the same reason.
//
// The returned map is never nil, and absent values read as "". If a
// Forwarded header is present but malformed, the map is empty.
func parseForwarded(h http.Header) map[string]string {
	// lastHop returns the trimmed final comma-separated element of the last
	// occurrence of the named header, or "" if the header is absent. Commas
	// inside quoted Forwarded values are not special-cased.
	lastHop := func(name string) string {
		values := h[http.CanonicalHeaderKey(name)]
		if len(values) == 0 {
			return ""
		}
		v := values[len(values)-1]
		if i := strings.LastIndexByte(v, ','); i >= 0 {
			v = v[i+1:]
		}
		return strings.TrimSpace(v)
	}

	forwarded := lastHop("Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   lastHop("X-Forwarded-For"),
			"proto": lastHop("X-Forwarded-Proto"),
			"host":  lastHop("X-Forwarded-Host"),
		}
	}

	// A Forwarded element is a ";"-separated list of token=value pairs with
	// optional quoting, the same grammar as media-type parameters. Hack:
	// reuse the MIME parser by prefixing a dummy media type.
	_, params, err := mime.ParseMediaType("hack; " + forwarded)
	if err != nil {
		return map[string]string{}
	}
	return params
}
[605]298
// ServerStats is a snapshot of server activity, as returned by
// Server.Stats.
type ServerStats struct {
	Users       int   // number of users currently loaded
	Downstreams int64 // number of live downstream connections
}
303
304func (s *Server) Stats() *ServerStats {
305 var stats ServerStats
306 s.lock.Lock()
307 stats.Users = len(s.users)
308 s.lock.Unlock()
309 stats.Downstreams = atomic.LoadInt64(&s.connCount)
310 return &stats
311}
Note: See TracBrowser for help on using the repository browser.