source: code/trunk/server.go@ 653

Last change on this file since 653 was 652, checked in by contact, 4 years ago

Add context args to Database interface

This is a mechanical change, which just lifts up the context.TODO()
calls from inside the DB implementations to the callers.

Future work involves properly wiring up the contexts when it makes
sense.

File size: 5.6 KB
RevLine 
[98]1package soju
[1]2
3import (
[652]4 "context"
[1]5 "fmt"
[37]6 "log"
[472]7 "mime"
[1]8 "net"
[323]9 "net/http"
[24]10 "sync"
[323]11 "sync/atomic"
[67]12 "time"
[1]13
14 "gopkg.in/irc.v3"
[323]15 "nhooyr.io/websocket"
[370]16
17 "git.sr.ht/~emersion/soju/config"
[1]18)
19
// Package-level timing and rate parameters. Their consumers live outside
// this file.
//
// TODO: make configurable
var (
	retryConnectDelay    = time.Minute
	connectTimeout       = 15 * time.Second
	writeTimeout         = 10 * time.Second
	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10
)
[67]26
// Logger is the minimal logging interface used by the server.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger wraps another Logger and puts a fixed prefix in front of
// every message it forwards.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the wrapped logger with the prefix inserted as the
// first operand.
func (l *prefixLogger) Print(v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(args, l.prefix)
	args = append(args, v...)
	l.logger.Print(args...)
}

// Printf forwards the formatted message to the wrapped logger. The prefix
// is passed as an extra leading argument consumed by a "%v" verb placed in
// front of format, so the prefix text itself is never interpreted as a
// format string.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(args, l.prefix)
	args = append(args, v...)
	l.logger.Printf("%v"+format, args...)
}
48
[10]49type Server struct {
[612]50 Hostname string
51 Logger Logger
52 HistoryLimit int
53 LogPath string
54 Debug bool
55 HTTPOrigins []string
56 AcceptProxyIPs config.IPSet
57 MaxUserNetworks int
58 Identd *Identd // can be nil
[22]59
[605]60 db Database
61 stopWG sync.WaitGroup
62 connCount int64 // atomic
[77]63
[449]64 lock sync.Mutex
65 listeners map[net.Listener]struct{}
66 users map[string]*user
[636]67
68 motd atomic.Value // string
[10]69}
70
[531]71func NewServer(db Database) *Server {
[636]72 srv := &Server{
[612]73 Logger: log.New(log.Writer(), "", log.LstdFlags),
74 HistoryLimit: 1000,
75 MaxUserNetworks: -1,
76 db: db,
77 listeners: make(map[net.Listener]struct{}),
78 users: make(map[string]*user),
[37]79 }
[636]80 srv.motd.Store("")
81 return srv
[37]82}
83
[5]84func (s *Server) prefix() *irc.Prefix {
85 return &irc.Prefix{Name: s.Hostname}
86}
87
[449]88func (s *Server) Start() error {
[652]89 users, err := s.db.ListUsers(context.TODO())
[77]90 if err != nil {
91 return err
92 }
[71]93
[77]94 s.lock.Lock()
[378]95 for i := range users {
96 s.addUserLocked(&users[i])
[71]97 }
[37]98 s.lock.Unlock()
99
[449]100 return nil
[10]101}
102
[449]103func (s *Server) Shutdown() {
104 s.lock.Lock()
105 for ln := range s.listeners {
106 if err := ln.Close(); err != nil {
107 s.Logger.Printf("failed to stop listener: %v", err)
108 }
109 }
110 for _, u := range s.users {
111 u.events <- eventStop{}
112 }
113 s.lock.Unlock()
114
115 s.stopWG.Wait()
[599]116
117 if err := s.db.Close(); err != nil {
118 s.Logger.Printf("failed to close DB: %v", err)
119 }
[449]120}
121
[329]122func (s *Server) createUser(user *User) (*user, error) {
123 s.lock.Lock()
124 defer s.lock.Unlock()
125
126 if _, ok := s.users[user.Username]; ok {
127 return nil, fmt.Errorf("user %q already exists", user.Username)
128 }
129
[652]130 err := s.db.StoreUser(context.TODO(), user)
[329]131 if err != nil {
132 return nil, fmt.Errorf("could not create user in db: %v", err)
133 }
134
[378]135 return s.addUserLocked(user), nil
[329]136}
137
[563]138func (s *Server) forEachUser(f func(*user)) {
139 s.lock.Lock()
140 for _, u := range s.users {
141 f(u)
142 }
143 s.lock.Unlock()
144}
145
[38]146func (s *Server) getUser(name string) *user {
147 s.lock.Lock()
148 u := s.users[name]
149 s.lock.Unlock()
150 return u
151}
152
[378]153func (s *Server) addUserLocked(user *User) *user {
154 s.Logger.Printf("starting bouncer for user %q", user.Username)
155 u := newUser(s, user)
156 s.users[u.Username] = u
157
[449]158 s.stopWG.Add(1)
159
[378]160 go func() {
161 u.run()
162
163 s.lock.Lock()
164 delete(s.users, u.Username)
165 s.lock.Unlock()
[449]166
167 s.stopWG.Done()
[378]168 }()
169
170 return u
171}
172
// lastDownstreamID is the most recently assigned downstream connection ID.
// It starts at zero and is incremented atomically by handle.
var lastDownstreamID uint64
174
[347]175func (s *Server) handle(ic ircConn) {
[605]176 atomic.AddInt64(&s.connCount, 1)
[323]177 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]178 dc := newDownstreamConn(s, ic, id)
[323]179 if err := dc.runUntilRegistered(); err != nil {
180 dc.logger.Print(err)
181 } else {
182 dc.user.events <- eventDownstreamConnected{dc}
183 if err := dc.readMessages(dc.user.events); err != nil {
184 dc.logger.Print(err)
185 }
186 dc.user.events <- eventDownstreamDisconnected{dc}
187 }
188 dc.Close()
[605]189 atomic.AddInt64(&s.connCount, -1)
[323]190}
191
[3]192func (s *Server) Serve(ln net.Listener) error {
[449]193 s.lock.Lock()
194 s.listeners[ln] = struct{}{}
195 s.lock.Unlock()
196
197 s.stopWG.Add(1)
198
199 defer func() {
200 s.lock.Lock()
201 delete(s.listeners, ln)
202 s.lock.Unlock()
203
204 s.stopWG.Done()
205 }()
206
[1]207 for {
[323]208 conn, err := ln.Accept()
[601]209 if isErrClosed(err) {
[449]210 return nil
211 } else if err != nil {
[1]212 return fmt.Errorf("failed to accept connection: %v", err)
213 }
214
[347]215 go s.handle(newNetIRCConn(conn))
[1]216 }
217}
[323]218
219func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
220 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]221 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[323]222 OriginPatterns: s.HTTPOrigins,
223 })
224 if err != nil {
225 s.Logger.Printf("failed to serve HTTP connection: %v", err)
226 return
227 }
[345]228
[370]229 isProxy := false
[345]230 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
231 if ip := net.ParseIP(host); ip != nil {
[370]232 isProxy = s.AcceptProxyIPs.Contains(ip)
[345]233 }
234 }
235
[474]236 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]237 // to prevent users from spoofing the remote address
[344]238 remoteAddr := req.RemoteAddr
[472]239 if isProxy {
240 forwarded := parseForwarded(req.Header)
[473]241 if forwarded["for"] != "" {
242 remoteAddr = forwarded["for"]
[472]243 }
[344]244 }
[345]245
[347]246 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]247}
[472]248
// parseForwarded extracts proxy forwarding parameters from h.
//
// If a Forwarded header is present its parameters are returned, with keys
// lower-cased by the MIME parameter parser. Otherwise the "for", "proto"
// and "host" keys are filled from X-Forwarded-For, X-Forwarded-Proto and
// X-Forwarded-Host respectively (empty strings when absent).
//
// Parsing is best-effort: a Forwarded value that cannot be parsed as a
// parameter list (for example one containing several comma-separated
// elements) yields an empty map. The returned map is never nil.
func parseForwarded(h http.Header) map[string]string {
	forwarded := h.Get("Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   h.Get("X-Forwarded-For"),
			"proto": h.Get("X-Forwarded-Proto"),
			"host":  h.Get("X-Forwarded-Host"),
		}
	}

	// Hack to easily parse header parameters: prepend a dummy media type
	// so the value can be handed to the MIME parameter parser.
	_, params, err := mime.ParseMediaType("hack; " + forwarded)
	if err != nil || params == nil {
		// Deliberately not reported: an unparsable header is treated as
		// carrying no forwarding information.
		return map[string]string{}
	}
	return params
}
[605]262
// ServerStats is a snapshot of server-wide counters as returned by
// Server.Stats.
type ServerStats struct {
	// Users is the number of registered users.
	Users int
	// Downstreams is the number of connections currently being handled.
	Downstreams int64
}
267
268func (s *Server) Stats() *ServerStats {
269 var stats ServerStats
270 s.lock.Lock()
271 stats.Users = len(s.users)
272 s.lock.Unlock()
273 stats.Downstreams = atomic.LoadInt64(&s.connCount)
274 return &stats
275}
[636]276
277func (s *Server) SetMOTD(motd string) {
278 s.motd.Store(motd)
279}
280
281func (s *Server) MOTD() string {
282 return s.motd.Load().(string)
283}
Note: See TracBrowser for help on using the repository browser.