source: code/trunk/server.go@ 671

Last change on this file since 671 was 670, checked in by contact, 4 years ago

Turn CHATHISTORY and backlog limits into constants

File size: 5.7 KB
Line 
1package soju
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log"
9 "mime"
10 "net"
11 "net/http"
12 "sync"
13 "sync/atomic"
14 "time"
15
16 "gopkg.in/irc.v3"
17 "nhooyr.io/websocket"
18
19 "git.sr.ht/~emersion/soju/config"
20)
21
// Package-level tunables.
//
// TODO: make configurable
//
// NOTE(review): the code consuming these values is not visible in this file;
// the grouping below only reflects what the names suggest.
var (
	// Durations.
	retryConnectDelay    = time.Minute
	connectTimeout       = 15 * time.Second
	writeTimeout         = 10 * time.Second
	upstreamMessageDelay = 2 * time.Second
	messageStoreTimeout  = 10 * time.Second

	// Counts.
	upstreamMessageBurst = 10
	chatHistoryLimit     = 1000
	backlogLimit         = 4000
)
31
// Logger is the logging interface used by the server. NewServer installs a
// *log.Logger from the standard library as the default implementation.
type Logger interface {
	// Print logs the given values.
	Print(args ...interface{})
	// Printf logs a message built from a format string and its arguments.
	Printf(format string, args ...interface{})
}
36
37type prefixLogger struct {
38 logger Logger
39 prefix string
40}
41
42var _ Logger = (*prefixLogger)(nil)
43
44func (l *prefixLogger) Print(v ...interface{}) {
45 v = append([]interface{}{l.prefix}, v...)
46 l.logger.Print(v...)
47}
48
49func (l *prefixLogger) Printf(format string, v ...interface{}) {
50 v = append([]interface{}{l.prefix}, v...)
51 l.logger.Printf("%v"+format, v...)
52}
53
54type Server struct {
55 Hostname string
56 Title string
57 Logger Logger
58 LogPath string
59 Debug bool
60 HTTPOrigins []string
61 AcceptProxyIPs config.IPSet
62 MaxUserNetworks int
63 Identd *Identd // can be nil
64
65 db Database
66 stopWG sync.WaitGroup
67 connCount int64 // atomic
68
69 lock sync.Mutex
70 listeners map[net.Listener]struct{}
71 users map[string]*user
72
73 motd atomic.Value // string
74}
75
76func NewServer(db Database) *Server {
77 srv := &Server{
78 Logger: log.New(log.Writer(), "", log.LstdFlags),
79 MaxUserNetworks: -1,
80 db: db,
81 listeners: make(map[net.Listener]struct{}),
82 users: make(map[string]*user),
83 }
84 srv.motd.Store("")
85 return srv
86}
87
88func (s *Server) prefix() *irc.Prefix {
89 return &irc.Prefix{Name: s.Hostname}
90}
91
92func (s *Server) Start() error {
93 users, err := s.db.ListUsers(context.TODO())
94 if err != nil {
95 return err
96 }
97
98 s.lock.Lock()
99 for i := range users {
100 s.addUserLocked(&users[i])
101 }
102 s.lock.Unlock()
103
104 return nil
105}
106
107func (s *Server) Shutdown() {
108 s.lock.Lock()
109 for ln := range s.listeners {
110 if err := ln.Close(); err != nil {
111 s.Logger.Printf("failed to stop listener: %v", err)
112 }
113 }
114 for _, u := range s.users {
115 u.events <- eventStop{}
116 }
117 s.lock.Unlock()
118
119 s.stopWG.Wait()
120
121 if err := s.db.Close(); err != nil {
122 s.Logger.Printf("failed to close DB: %v", err)
123 }
124}
125
126func (s *Server) createUser(user *User) (*user, error) {
127 s.lock.Lock()
128 defer s.lock.Unlock()
129
130 if _, ok := s.users[user.Username]; ok {
131 return nil, fmt.Errorf("user %q already exists", user.Username)
132 }
133
134 err := s.db.StoreUser(context.TODO(), user)
135 if err != nil {
136 return nil, fmt.Errorf("could not create user in db: %v", err)
137 }
138
139 return s.addUserLocked(user), nil
140}
141
142func (s *Server) forEachUser(f func(*user)) {
143 s.lock.Lock()
144 for _, u := range s.users {
145 f(u)
146 }
147 s.lock.Unlock()
148}
149
150func (s *Server) getUser(name string) *user {
151 s.lock.Lock()
152 u := s.users[name]
153 s.lock.Unlock()
154 return u
155}
156
157func (s *Server) addUserLocked(user *User) *user {
158 s.Logger.Printf("starting bouncer for user %q", user.Username)
159 u := newUser(s, user)
160 s.users[u.Username] = u
161
162 s.stopWG.Add(1)
163
164 go func() {
165 u.run()
166
167 s.lock.Lock()
168 delete(s.users, u.Username)
169 s.lock.Unlock()
170
171 s.stopWG.Done()
172 }()
173
174 return u
175}
176
// lastDownstreamID is the ID given to the most recent downstream connection.
// It is only updated atomically (see handle); the zero value means no
// connection has been handled yet.
var lastDownstreamID uint64
178
179func (s *Server) handle(ic ircConn) {
180 atomic.AddInt64(&s.connCount, 1)
181 id := atomic.AddUint64(&lastDownstreamID, 1)
182 dc := newDownstreamConn(s, ic, id)
183 if err := dc.runUntilRegistered(); err != nil {
184 if !errors.Is(err, io.EOF) {
185 dc.logger.Print(err)
186 }
187 } else {
188 dc.user.events <- eventDownstreamConnected{dc}
189 if err := dc.readMessages(dc.user.events); err != nil {
190 dc.logger.Print(err)
191 }
192 dc.user.events <- eventDownstreamDisconnected{dc}
193 }
194 dc.Close()
195 atomic.AddInt64(&s.connCount, -1)
196}
197
198func (s *Server) Serve(ln net.Listener) error {
199 s.lock.Lock()
200 s.listeners[ln] = struct{}{}
201 s.lock.Unlock()
202
203 s.stopWG.Add(1)
204
205 defer func() {
206 s.lock.Lock()
207 delete(s.listeners, ln)
208 s.lock.Unlock()
209
210 s.stopWG.Done()
211 }()
212
213 for {
214 conn, err := ln.Accept()
215 if isErrClosed(err) {
216 return nil
217 } else if err != nil {
218 return fmt.Errorf("failed to accept connection: %v", err)
219 }
220
221 go s.handle(newNetIRCConn(conn))
222 }
223}
224
225func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
226 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
227 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
228 OriginPatterns: s.HTTPOrigins,
229 })
230 if err != nil {
231 s.Logger.Printf("failed to serve HTTP connection: %v", err)
232 return
233 }
234
235 isProxy := false
236 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
237 if ip := net.ParseIP(host); ip != nil {
238 isProxy = s.AcceptProxyIPs.Contains(ip)
239 }
240 }
241
242 // Only trust the Forwarded header field if this is a trusted proxy IP
243 // to prevent users from spoofing the remote address
244 remoteAddr := req.RemoteAddr
245 if isProxy {
246 forwarded := parseForwarded(req.Header)
247 if forwarded["for"] != "" {
248 remoteAddr = forwarded["for"]
249 }
250 }
251
252 s.handle(newWebsocketIRCConn(conn, remoteAddr))
253}
254
// parseForwarded extracts proxy forwarding parameters from request headers.
//
// If a Forwarded header is present, its parameters are returned as parsed by
// mime.ParseMediaType; the result is nil when the header cannot be parsed.
// Otherwise the legacy X-Forwarded-For, X-Forwarded-Proto and
// X-Forwarded-Host headers are returned under the keys "for", "proto" and
// "host", with empty strings for missing headers.
func parseForwarded(h http.Header) map[string]string {
	if value := h.Get("Forwarded"); value != "" {
		// Hack to easily parse header parameters: prepend a dummy
		// media type so the MIME parameter parser accepts the value.
		_, params, _ := mime.ParseMediaType("hack; " + value)
		return params
	}

	legacy := make(map[string]string, 3)
	legacy["for"] = h.Get("X-Forwarded-For")
	legacy["proto"] = h.Get("X-Forwarded-Proto")
	legacy["host"] = h.Get("X-Forwarded-Host")
	return legacy
}
268
// ServerStats is a snapshot of server-wide counters, as returned by
// Server.Stats.
type ServerStats struct {
	// Users is the number of users registered with the server.
	Users int
	// Downstreams is the number of connections currently being handled.
	Downstreams int64
}
273
274func (s *Server) Stats() *ServerStats {
275 var stats ServerStats
276 s.lock.Lock()
277 stats.Users = len(s.users)
278 s.lock.Unlock()
279 stats.Downstreams = atomic.LoadInt64(&s.connCount)
280 return &stats
281}
282
283func (s *Server) SetMOTD(motd string) {
284 s.motd.Store(motd)
285}
286
287func (s *Server) MOTD() string {
288 return s.motd.Load().(string)
289}
Note: See TracBrowser for help on using the repository browser.