source: code/trunk/server.go@ 654

Last change on this file since 654 was 652, checked in by contact, 4 years ago

Add context args to Database interface

This is a mechanical change that just lifts the context.TODO()
calls from inside the DB implementations up to the callers.

Future work involves properly wiring up the contexts when it makes
sense.

File size: 5.6 KB
Line 
1package soju
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "mime"
8 "net"
9 "net/http"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "gopkg.in/irc.v3"
15 "nhooyr.io/websocket"
16
17 "git.sr.ht/~emersion/soju/config"
18)
19
// Timing and rate-limit settings, named after what they control; see their
// use sites for the exact semantics.
//
// TODO: make configurable
var (
	retryConnectDelay    = time.Minute
	connectTimeout       = 15 * time.Second
	writeTimeout         = 10 * time.Second
	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10
)
26
// Logger is the minimal logging interface used by the server; *log.Logger
// satisfies it.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger is a Logger that prepends a fixed prefix to every message
// before forwarding it to an underlying Logger. The prefix is inserted
// verbatim, so it should carry any separator (e.g. a trailing space) the
// caller wants between it and the message.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the underlying logger with the prefix added as an
// extra leading operand. Because the prefix is a string, fmt.Sprint-style
// loggers insert no space between it and the first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	l.logger.Print(append([]interface{}{l.prefix}, v...)...)
}

// Printf formats the caller's message first and then forwards
// prefix+message to the underlying logger.
//
// Formatting before prefixing keeps explicit argument indexes such as
// %[1]v pointing at the caller's own arguments (prepending the prefix to
// the argument list would shift them by one). The formatted message is
// passed as an argument rather than as part of a format string, so any '%'
// it contains is printed literally.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	l.logger.Printf("%s%s", l.prefix, fmt.Sprintf(format, v...))
}
48
49type Server struct {
50 Hostname string
51 Logger Logger
52 HistoryLimit int
53 LogPath string
54 Debug bool
55 HTTPOrigins []string
56 AcceptProxyIPs config.IPSet
57 MaxUserNetworks int
58 Identd *Identd // can be nil
59
60 db Database
61 stopWG sync.WaitGroup
62 connCount int64 // atomic
63
64 lock sync.Mutex
65 listeners map[net.Listener]struct{}
66 users map[string]*user
67
68 motd atomic.Value // string
69}
70
71func NewServer(db Database) *Server {
72 srv := &Server{
73 Logger: log.New(log.Writer(), "", log.LstdFlags),
74 HistoryLimit: 1000,
75 MaxUserNetworks: -1,
76 db: db,
77 listeners: make(map[net.Listener]struct{}),
78 users: make(map[string]*user),
79 }
80 srv.motd.Store("")
81 return srv
82}
83
84func (s *Server) prefix() *irc.Prefix {
85 return &irc.Prefix{Name: s.Hostname}
86}
87
88func (s *Server) Start() error {
89 users, err := s.db.ListUsers(context.TODO())
90 if err != nil {
91 return err
92 }
93
94 s.lock.Lock()
95 for i := range users {
96 s.addUserLocked(&users[i])
97 }
98 s.lock.Unlock()
99
100 return nil
101}
102
103func (s *Server) Shutdown() {
104 s.lock.Lock()
105 for ln := range s.listeners {
106 if err := ln.Close(); err != nil {
107 s.Logger.Printf("failed to stop listener: %v", err)
108 }
109 }
110 for _, u := range s.users {
111 u.events <- eventStop{}
112 }
113 s.lock.Unlock()
114
115 s.stopWG.Wait()
116
117 if err := s.db.Close(); err != nil {
118 s.Logger.Printf("failed to close DB: %v", err)
119 }
120}
121
122func (s *Server) createUser(user *User) (*user, error) {
123 s.lock.Lock()
124 defer s.lock.Unlock()
125
126 if _, ok := s.users[user.Username]; ok {
127 return nil, fmt.Errorf("user %q already exists", user.Username)
128 }
129
130 err := s.db.StoreUser(context.TODO(), user)
131 if err != nil {
132 return nil, fmt.Errorf("could not create user in db: %v", err)
133 }
134
135 return s.addUserLocked(user), nil
136}
137
138func (s *Server) forEachUser(f func(*user)) {
139 s.lock.Lock()
140 for _, u := range s.users {
141 f(u)
142 }
143 s.lock.Unlock()
144}
145
146func (s *Server) getUser(name string) *user {
147 s.lock.Lock()
148 u := s.users[name]
149 s.lock.Unlock()
150 return u
151}
152
153func (s *Server) addUserLocked(user *User) *user {
154 s.Logger.Printf("starting bouncer for user %q", user.Username)
155 u := newUser(s, user)
156 s.users[u.Username] = u
157
158 s.stopWG.Add(1)
159
160 go func() {
161 u.run()
162
163 s.lock.Lock()
164 delete(s.users, u.Username)
165 s.lock.Unlock()
166
167 s.stopWG.Done()
168 }()
169
170 return u
171}
172
173var lastDownstreamID uint64 = 0
174
175func (s *Server) handle(ic ircConn) {
176 atomic.AddInt64(&s.connCount, 1)
177 id := atomic.AddUint64(&lastDownstreamID, 1)
178 dc := newDownstreamConn(s, ic, id)
179 if err := dc.runUntilRegistered(); err != nil {
180 dc.logger.Print(err)
181 } else {
182 dc.user.events <- eventDownstreamConnected{dc}
183 if err := dc.readMessages(dc.user.events); err != nil {
184 dc.logger.Print(err)
185 }
186 dc.user.events <- eventDownstreamDisconnected{dc}
187 }
188 dc.Close()
189 atomic.AddInt64(&s.connCount, -1)
190}
191
192func (s *Server) Serve(ln net.Listener) error {
193 s.lock.Lock()
194 s.listeners[ln] = struct{}{}
195 s.lock.Unlock()
196
197 s.stopWG.Add(1)
198
199 defer func() {
200 s.lock.Lock()
201 delete(s.listeners, ln)
202 s.lock.Unlock()
203
204 s.stopWG.Done()
205 }()
206
207 for {
208 conn, err := ln.Accept()
209 if isErrClosed(err) {
210 return nil
211 } else if err != nil {
212 return fmt.Errorf("failed to accept connection: %v", err)
213 }
214
215 go s.handle(newNetIRCConn(conn))
216 }
217}
218
219func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
220 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
221 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
222 OriginPatterns: s.HTTPOrigins,
223 })
224 if err != nil {
225 s.Logger.Printf("failed to serve HTTP connection: %v", err)
226 return
227 }
228
229 isProxy := false
230 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
231 if ip := net.ParseIP(host); ip != nil {
232 isProxy = s.AcceptProxyIPs.Contains(ip)
233 }
234 }
235
236 // Only trust the Forwarded header field if this is a trusted proxy IP
237 // to prevent users from spoofing the remote address
238 remoteAddr := req.RemoteAddr
239 if isProxy {
240 forwarded := parseForwarded(req.Header)
241 if forwarded["for"] != "" {
242 remoteAddr = forwarded["for"]
243 }
244 }
245
246 s.handle(newWebsocketIRCConn(conn, remoteAddr))
247}
248
// parseForwarded extracts proxy-supplied client information from h.
//
// If a Forwarded header is present, its first value is parsed as
// "key=value; key=value" pairs. Keys are lower-cased and quoted values are
// unquoted. If the header is malformed, the result is nil, which reads
// like an empty map. Without a Forwarded header, the result always has
// exactly the keys "for", "proto" and "host", taken from
// X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host (possibly
// empty).
func parseForwarded(h http.Header) map[string]string {
	if fwd := h.Get("Forwarded"); fwd != "" {
		// Forwarded pairs share MIME parameter syntax, so reuse
		// mime.ParseMediaType with a dummy media type. The error is
		// ignored on purpose: on failure params is nil and lookups simply
		// find nothing.
		_, params, _ := mime.ParseMediaType("hack; " + fwd)
		return params
	}

	return map[string]string{
		"for":   h.Get("X-Forwarded-For"),
		"proto": h.Get("X-Forwarded-Proto"),
		"host":  h.Get("X-Forwarded-Host"),
	}
}
262
263type ServerStats struct {
264 Users int
265 Downstreams int64
266}
267
268func (s *Server) Stats() *ServerStats {
269 var stats ServerStats
270 s.lock.Lock()
271 stats.Users = len(s.users)
272 s.lock.Unlock()
273 stats.Downstreams = atomic.LoadInt64(&s.connCount)
274 return &stats
275}
276
277func (s *Server) SetMOTD(motd string) {
278 s.motd.Store(motd)
279}
280
281func (s *Server) MOTD() string {
282 return s.motd.Load().(string)
283}
Note: See TracBrowser for help on using the repository browser.