source: code/trunk/server.go@ 692

Last change on this file since 692 was 691, checked in by contact, 4 years ago

Allow most config options to be reloaded

Closes: https://todo.sr.ht/~emersion/soju/42

File size: 6.2 KB
Line 
1package soju
2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
22
// Hard-coded tuning parameters.
//
// TODO: make configurable
var (
	retryConnectDelay = time.Minute
	connectTimeout    = 15 * time.Second
	writeTimeout      = 10 * time.Second

	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10

	backlogTimeout                 = 10 * time.Second
	handleDownstreamMessageTimeout = 10 * time.Second

	chatHistoryLimit = 1000
	backlogLimit     = 4000
)
33
// Logger is the logging interface used throughout the server. *log.Logger
// satisfies it (NewServer installs one as the default).
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger is a Logger that prepends a fixed prefix to every message
// before handing it to the wrapped logger.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the wrapped logger with the prefix inserted as the
// first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	v = append([]interface{}{l.prefix}, v...)
	l.logger.Print(v...)
}

// Printf formats the message from format and v, then forwards it to the
// wrapped logger preceded by the prefix.
//
// The caller's message is formatted on its own instead of splicing the prefix
// into format: otherwise explicit argument indexes such as %[1]v in format
// would point at the prefix rather than at the caller's first argument.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	l.logger.Printf("%s%s", l.prefix, fmt.Sprintf(format, v...))
}
55
// Config holds the server settings. The active value can be replaced at
// runtime with Server.SetConfig and read with Server.Config; NewServer starts
// from Hostname "localhost" and MaxUserNetworks -1.
type Config struct {
	// Hostname is used as the Name of the server's IRC message prefix.
	Hostname string
	// Title, LogPath, Debug and MOTD are not read in this file; see their
	// users elsewhere in the package.
	Title   string
	LogPath string
	Debug   bool
	// HTTPOrigins is passed to the WebSocket handshake in ServeHTTP as the
	// allowed origin patterns.
	HTTPOrigins []string
	// AcceptProxyIPs lists the remote IPs whose Forwarded/X-Forwarded-*
	// headers ServeHTTP trusts when determining the client address.
	AcceptProxyIPs config.IPSet
	// MaxUserNetworks defaults to -1 in NewServer. NOTE(review): presumably
	// -1 means "no limit" — confirm where the limit is enforced.
	MaxUserNetworks int
	MOTD            string
}
66
// Server ties together the database, the running users and the listeners
// that accept downstream connections. Create it with NewServer.
type Server struct {
	// Logger receives the server's log output. NewServer sets it to a
	// standard-library logger writing to log.Writer().
	Logger Logger
	Identd *Identd // can be nil

	config    atomic.Value   // *Config; read via Config, replaced via SetConfig
	db        Database       // used by Start, createUser and Shutdown
	stopWG    sync.WaitGroup // counts user goroutines and Serve loops; Shutdown waits on it
	connCount int64          // atomic; downstream connections currently inside handle

	lock      sync.Mutex                // guards listeners and users
	listeners map[net.Listener]struct{} // registered by Serve, closed by Shutdown
	users     map[string]*user          // running users, keyed by username
}
80
81func NewServer(db Database) *Server {
82 srv := &Server{
83 Logger: log.New(log.Writer(), "", log.LstdFlags),
84 db: db,
85 listeners: make(map[net.Listener]struct{}),
86 users: make(map[string]*user),
87 }
88 srv.config.Store(&Config{Hostname: "localhost", MaxUserNetworks: -1})
89 return srv
90}
91
92func (s *Server) prefix() *irc.Prefix {
93 return &irc.Prefix{Name: s.Config().Hostname}
94}
95
96func (s *Server) Config() *Config {
97 return s.config.Load().(*Config)
98}
99
100func (s *Server) SetConfig(cfg *Config) {
101 s.config.Store(cfg)
102}
103
104func (s *Server) Start() error {
105 users, err := s.db.ListUsers(context.TODO())
106 if err != nil {
107 return err
108 }
109
110 s.lock.Lock()
111 for i := range users {
112 s.addUserLocked(&users[i])
113 }
114 s.lock.Unlock()
115
116 return nil
117}
118
119func (s *Server) Shutdown() {
120 s.lock.Lock()
121 for ln := range s.listeners {
122 if err := ln.Close(); err != nil {
123 s.Logger.Printf("failed to stop listener: %v", err)
124 }
125 }
126 for _, u := range s.users {
127 u.events <- eventStop{}
128 }
129 s.lock.Unlock()
130
131 s.stopWG.Wait()
132
133 if err := s.db.Close(); err != nil {
134 s.Logger.Printf("failed to close DB: %v", err)
135 }
136}
137
138func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
139 s.lock.Lock()
140 defer s.lock.Unlock()
141
142 if _, ok := s.users[user.Username]; ok {
143 return nil, fmt.Errorf("user %q already exists", user.Username)
144 }
145
146 err := s.db.StoreUser(ctx, user)
147 if err != nil {
148 return nil, fmt.Errorf("could not create user in db: %v", err)
149 }
150
151 return s.addUserLocked(user), nil
152}
153
154func (s *Server) forEachUser(f func(*user)) {
155 s.lock.Lock()
156 for _, u := range s.users {
157 f(u)
158 }
159 s.lock.Unlock()
160}
161
162func (s *Server) getUser(name string) *user {
163 s.lock.Lock()
164 u := s.users[name]
165 s.lock.Unlock()
166 return u
167}
168
169func (s *Server) addUserLocked(user *User) *user {
170 s.Logger.Printf("starting bouncer for user %q", user.Username)
171 u := newUser(s, user)
172 s.users[u.Username] = u
173
174 s.stopWG.Add(1)
175
176 go func() {
177 defer func() {
178 if err := recover(); err != nil {
179 s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
180 }
181 }()
182
183 u.run()
184
185 s.lock.Lock()
186 delete(s.users, u.Username)
187 s.lock.Unlock()
188
189 s.stopWG.Done()
190 }()
191
192 return u
193}
194
195var lastDownstreamID uint64 = 0
196
197func (s *Server) handle(ic ircConn) {
198 defer func() {
199 if err := recover(); err != nil {
200 s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
201 }
202 }()
203
204 atomic.AddInt64(&s.connCount, 1)
205 id := atomic.AddUint64(&lastDownstreamID, 1)
206 dc := newDownstreamConn(s, ic, id)
207 if err := dc.runUntilRegistered(); err != nil {
208 if !errors.Is(err, io.EOF) {
209 dc.logger.Print(err)
210 }
211 } else {
212 dc.user.events <- eventDownstreamConnected{dc}
213 if err := dc.readMessages(dc.user.events); err != nil {
214 dc.logger.Print(err)
215 }
216 dc.user.events <- eventDownstreamDisconnected{dc}
217 }
218 dc.Close()
219 atomic.AddInt64(&s.connCount, -1)
220}
221
222func (s *Server) Serve(ln net.Listener) error {
223 s.lock.Lock()
224 s.listeners[ln] = struct{}{}
225 s.lock.Unlock()
226
227 s.stopWG.Add(1)
228
229 defer func() {
230 s.lock.Lock()
231 delete(s.listeners, ln)
232 s.lock.Unlock()
233
234 s.stopWG.Done()
235 }()
236
237 for {
238 conn, err := ln.Accept()
239 if isErrClosed(err) {
240 return nil
241 } else if err != nil {
242 return fmt.Errorf("failed to accept connection: %v", err)
243 }
244
245 go s.handle(newNetIRCConn(conn))
246 }
247}
248
249func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
250 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
251 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
252 OriginPatterns: s.Config().HTTPOrigins,
253 })
254 if err != nil {
255 s.Logger.Printf("failed to serve HTTP connection: %v", err)
256 return
257 }
258
259 isProxy := false
260 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
261 if ip := net.ParseIP(host); ip != nil {
262 isProxy = s.Config().AcceptProxyIPs.Contains(ip)
263 }
264 }
265
266 // Only trust the Forwarded header field if this is a trusted proxy IP
267 // to prevent users from spoofing the remote address
268 remoteAddr := req.RemoteAddr
269 if isProxy {
270 forwarded := parseForwarded(req.Header)
271 if forwarded["for"] != "" {
272 remoteAddr = forwarded["for"]
273 }
274 }
275
276 s.handle(newWebsocketIRCConn(conn, remoteAddr))
277}
278
// parseForwarded returns what the closest reverse proxy recorded about the
// request, under the lower-case keys "for", "proto" and "host".
//
// When a Forwarded header (RFC 7239) is present, its parameters are returned.
// Otherwise the X-Forwarded-For/-Proto/-Host headers are used. In both cases
// only the right-most list element of the LAST header line is considered:
// proxies append their own entry after whatever the client sent, so earlier
// elements (and earlier header lines) may have been forged by the client. If
// the Forwarded element cannot be parsed, the result is a nil map, and
// lookups on it return "".
//
// NOTE(review): if the trusted proxy only sets X-Forwarded-For and does not
// strip a client-supplied Forwarded header, that header still wins here.
func parseForwarded(h http.Header) map[string]string {
	forwarded := lastHeaderListElement(h, "Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   lastHeaderListElement(h, "X-Forwarded-For"),
			"proto": lastHeaderListElement(h, "X-Forwarded-Proto"),
			"host":  lastHeaderListElement(h, "X-Forwarded-Host"),
		}
	}
	// Hack to easily parse header parameters
	_, params, _ := mime.ParseMediaType("hack; " + forwarded)
	return params
}

// lastHeaderListElement returns the whitespace-trimmed, right-most
// comma-separated element of the last occurrence of header key in h, or ""
// if the header is absent. Quoted strings are not special-cased, so a comma
// inside quotes would split the value.
func lastHeaderListElement(h http.Header, key string) string {
	values := h[http.CanonicalHeaderKey(key)]
	if len(values) == 0 {
		return ""
	}
	last := values[len(values)-1]
	if i := strings.LastIndexByte(last, ','); i >= 0 {
		last = last[i+1:]
	}
	return strings.TrimSpace(last)
}
292
// ServerStats is the snapshot returned by Server.Stats.
type ServerStats struct {
	Users       int   // number of users registered in the server's user map
	Downstreams int64 // number of downstream connections currently being handled
}
297
298func (s *Server) Stats() *ServerStats {
299 var stats ServerStats
300 s.lock.Lock()
301 stats.Users = len(s.users)
302 s.lock.Unlock()
303 stats.Downstreams = atomic.LoadInt64(&s.connCount)
304 return &stats
305}
Note: See TracBrowser for help on using the repository browser.