source: code/trunk/server.go@ 699

Last change to this file as of revision 699 was revision 694, checked in by contact 4 years ago

Add config option to globally disable multi-upstream mode

Closes: https://todo.sr.ht/~emersion/soju/122

File size: 6.3 KB
Line 
1package soju
2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
22
// TODO: make configurable
//
// Server-wide tunables. They are package-level variables rather than
// constants; NOTE(review): presumably so they can be overridden (e.g. in
// tests) — confirm before turning them into constants.
var (
	retryConnectDelay              = time.Minute
	connectTimeout                 = 15 * time.Second
	writeTimeout                   = 10 * time.Second
	upstreamMessageDelay           = 2 * time.Second
	upstreamMessageBurst           = 10
	backlogTimeout                 = 10 * time.Second
	handleDownstreamMessageTimeout = 10 * time.Second
	chatHistoryLimit               = 1000
	backlogLimit                   = 4000
)
33
// Logger is the minimal logging interface used by the server. *log.Logger
// satisfies it.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger is a Logger that prepends a fixed prefix to every message
// before forwarding it to an underlying Logger.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the underlying logger with the prefix inserted as the
// first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	v = append([]interface{}{l.prefix}, v...)
	l.logger.Print(v...)
}

// Printf formats the message from format and v, then forwards it to the
// underlying logger preceded by the prefix.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	// Format the caller's message on its own: splicing the prefix into the
	// argument list would shift explicit argument indexes such as %[1]v so
	// that they referred to the prefix instead of the caller's arguments.
	l.logger.Printf("%v%s", l.prefix, fmt.Sprintf(format, v...))
}
55
56type Config struct {
57 Hostname string
58 Title string
59 LogPath string
60 Debug bool
61 HTTPOrigins []string
62 AcceptProxyIPs config.IPSet
63 MaxUserNetworks int
64 MultiUpstream bool
65 MOTD string
66}
67
68type Server struct {
69 Logger Logger
70 Identd *Identd // can be nil
71
72 config atomic.Value // *Config
73 db Database
74 stopWG sync.WaitGroup
75 connCount int64 // atomic
76
77 lock sync.Mutex
78 listeners map[net.Listener]struct{}
79 users map[string]*user
80}
81
82func NewServer(db Database) *Server {
83 srv := &Server{
84 Logger: log.New(log.Writer(), "", log.LstdFlags),
85 db: db,
86 listeners: make(map[net.Listener]struct{}),
87 users: make(map[string]*user),
88 }
89 srv.config.Store(&Config{
90 Hostname: "localhost",
91 MaxUserNetworks: -1,
92 MultiUpstream: true,
93 })
94 return srv
95}
96
97func (s *Server) prefix() *irc.Prefix {
98 return &irc.Prefix{Name: s.Config().Hostname}
99}
100
101func (s *Server) Config() *Config {
102 return s.config.Load().(*Config)
103}
104
105func (s *Server) SetConfig(cfg *Config) {
106 s.config.Store(cfg)
107}
108
109func (s *Server) Start() error {
110 users, err := s.db.ListUsers(context.TODO())
111 if err != nil {
112 return err
113 }
114
115 s.lock.Lock()
116 for i := range users {
117 s.addUserLocked(&users[i])
118 }
119 s.lock.Unlock()
120
121 return nil
122}
123
124func (s *Server) Shutdown() {
125 s.lock.Lock()
126 for ln := range s.listeners {
127 if err := ln.Close(); err != nil {
128 s.Logger.Printf("failed to stop listener: %v", err)
129 }
130 }
131 for _, u := range s.users {
132 u.events <- eventStop{}
133 }
134 s.lock.Unlock()
135
136 s.stopWG.Wait()
137
138 if err := s.db.Close(); err != nil {
139 s.Logger.Printf("failed to close DB: %v", err)
140 }
141}
142
143func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
144 s.lock.Lock()
145 defer s.lock.Unlock()
146
147 if _, ok := s.users[user.Username]; ok {
148 return nil, fmt.Errorf("user %q already exists", user.Username)
149 }
150
151 err := s.db.StoreUser(ctx, user)
152 if err != nil {
153 return nil, fmt.Errorf("could not create user in db: %v", err)
154 }
155
156 return s.addUserLocked(user), nil
157}
158
159func (s *Server) forEachUser(f func(*user)) {
160 s.lock.Lock()
161 for _, u := range s.users {
162 f(u)
163 }
164 s.lock.Unlock()
165}
166
167func (s *Server) getUser(name string) *user {
168 s.lock.Lock()
169 u := s.users[name]
170 s.lock.Unlock()
171 return u
172}
173
174func (s *Server) addUserLocked(user *User) *user {
175 s.Logger.Printf("starting bouncer for user %q", user.Username)
176 u := newUser(s, user)
177 s.users[u.Username] = u
178
179 s.stopWG.Add(1)
180
181 go func() {
182 defer func() {
183 if err := recover(); err != nil {
184 s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
185 }
186 }()
187
188 u.run()
189
190 s.lock.Lock()
191 delete(s.users, u.Username)
192 s.lock.Unlock()
193
194 s.stopWG.Done()
195 }()
196
197 return u
198}
199
200var lastDownstreamID uint64 = 0
201
202func (s *Server) handle(ic ircConn) {
203 defer func() {
204 if err := recover(); err != nil {
205 s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
206 }
207 }()
208
209 atomic.AddInt64(&s.connCount, 1)
210 id := atomic.AddUint64(&lastDownstreamID, 1)
211 dc := newDownstreamConn(s, ic, id)
212 if err := dc.runUntilRegistered(); err != nil {
213 if !errors.Is(err, io.EOF) {
214 dc.logger.Print(err)
215 }
216 } else {
217 dc.user.events <- eventDownstreamConnected{dc}
218 if err := dc.readMessages(dc.user.events); err != nil {
219 dc.logger.Print(err)
220 }
221 dc.user.events <- eventDownstreamDisconnected{dc}
222 }
223 dc.Close()
224 atomic.AddInt64(&s.connCount, -1)
225}
226
227func (s *Server) Serve(ln net.Listener) error {
228 s.lock.Lock()
229 s.listeners[ln] = struct{}{}
230 s.lock.Unlock()
231
232 s.stopWG.Add(1)
233
234 defer func() {
235 s.lock.Lock()
236 delete(s.listeners, ln)
237 s.lock.Unlock()
238
239 s.stopWG.Done()
240 }()
241
242 for {
243 conn, err := ln.Accept()
244 if isErrClosed(err) {
245 return nil
246 } else if err != nil {
247 return fmt.Errorf("failed to accept connection: %v", err)
248 }
249
250 go s.handle(newNetIRCConn(conn))
251 }
252}
253
254func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
255 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
256 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
257 OriginPatterns: s.Config().HTTPOrigins,
258 })
259 if err != nil {
260 s.Logger.Printf("failed to serve HTTP connection: %v", err)
261 return
262 }
263
264 isProxy := false
265 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
266 if ip := net.ParseIP(host); ip != nil {
267 isProxy = s.Config().AcceptProxyIPs.Contains(ip)
268 }
269 }
270
271 // Only trust the Forwarded header field if this is a trusted proxy IP
272 // to prevent users from spoofing the remote address
273 remoteAddr := req.RemoteAddr
274 if isProxy {
275 forwarded := parseForwarded(req.Header)
276 if forwarded["for"] != "" {
277 remoteAddr = forwarded["for"]
278 }
279 }
280
281 s.handle(newWebsocketIRCConn(conn, remoteAddr))
282}
283
// parseForwarded extracts proxy-supplied client information from h. It uses
// the RFC 7239 Forwarded header field when present, and otherwise falls back
// to the de-facto X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host
// fields. Keys of the returned map are lowercase ("for", "proto", "host",
// plus any other Forwarded parameter). Missing values read as "", including
// when the Forwarded element cannot be parsed and a nil map is returned.
//
// Each proxy appends its entry to the end of these lists, either in the same
// field value or as an extra header line. Anything earlier may have been
// supplied by the client itself. Only the last element, the one added by the
// proxy directly in front of us, is therefore used.
func parseForwarded(h http.Header) map[string]string {
	forwarded := lastForwardedElement(h, "Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   lastForwardedElement(h, "X-Forwarded-For"),
			"proto": lastForwardedElement(h, "X-Forwarded-Proto"),
			"host":  lastForwardedElement(h, "X-Forwarded-Host"),
		}
	}
	// Hack to easily parse header parameters: a Forwarded element has the
	// same "key=value;key=value" syntax as media-type parameters.
	_, params, err := mime.ParseMediaType("hack; " + forwarded)
	if err != nil {
		// Unparseable element: report nothing rather than a partial result.
		return nil
	}
	return params
}

// lastForwardedElement returns the last non-empty element of the
// comma-separated list formed by all occurrences of header field name in h,
// or "" if there is none.
func lastForwardedElement(h http.Header, name string) string {
	values := h[http.CanonicalHeaderKey(name)]
	for i := len(values) - 1; i >= 0; i-- {
		elems := strings.Split(values[i], ",")
		for j := len(elems) - 1; j >= 0; j-- {
			if elem := strings.TrimSpace(elems[j]); elem != "" {
				return elem
			}
		}
	}
	return ""
}
297
298type ServerStats struct {
299 Users int
300 Downstreams int64
301}
302
303func (s *Server) Stats() *ServerStats {
304 var stats ServerStats
305 s.lock.Lock()
306 stats.Users = len(s.users)
307 s.lock.Unlock()
308 stats.Downstreams = atomic.LoadInt64(&s.connCount)
309 return &stats
310}
Note: See TracBrowser for help on using the repository browser.