source: code/trunk/server.go@ 671

Last change on this file since 671 was 670, checked in by contact, 4 years ago

Turn CHATHISTORY and backlog limits into constants

File size: 5.7 KB
RevLine 
[98]1package soju
[1]2
3import (
[652]4 "context"
[656]5 "errors"
[1]6 "fmt"
[656]7 "io"
[37]8 "log"
[472]9 "mime"
[1]10 "net"
[323]11 "net/http"
[24]12 "sync"
[323]13 "sync/atomic"
[67]14 "time"
[1]15
16 "gopkg.in/irc.v3"
[323]17 "nhooyr.io/websocket"
[370]18
19 "git.sr.ht/~emersion/soju/config"
[1]20)
21
// Package-level tunables. None of them is read in this file; see their users
// elsewhere in the package for how each one is applied.
//
// TODO: make configurable
var (
	retryConnectDelay    = time.Minute
	connectTimeout       = 15 * time.Second
	writeTimeout         = 10 * time.Second
	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10
	messageStoreTimeout  = 10 * time.Second
	chatHistoryLimit     = 1000
	backlogLimit         = 4000
)
[67]31
// Logger is the logging interface the server writes to. NewServer installs a
// standard library *log.Logger, which provides both methods.
type Logger interface {
	// Print logs its arguments.
	Print(args ...interface{})
	// Printf logs its arguments according to a format string.
	Printf(format string, args ...interface{})
}
36
[21]37type prefixLogger struct {
38 logger Logger
39 prefix string
40}
41
42var _ Logger = (*prefixLogger)(nil)
43
44func (l *prefixLogger) Print(v ...interface{}) {
45 v = append([]interface{}{l.prefix}, v...)
46 l.logger.Print(v...)
47}
48
49func (l *prefixLogger) Printf(format string, v ...interface{}) {
50 v = append([]interface{}{l.prefix}, v...)
51 l.logger.Printf("%v"+format, v...)
52}
53
[10]54type Server struct {
[612]55 Hostname string
[662]56 Title string
[612]57 Logger Logger
58 LogPath string
59 Debug bool
60 HTTPOrigins []string
61 AcceptProxyIPs config.IPSet
62 MaxUserNetworks int
63 Identd *Identd // can be nil
[22]64
[605]65 db Database
66 stopWG sync.WaitGroup
67 connCount int64 // atomic
[77]68
[449]69 lock sync.Mutex
70 listeners map[net.Listener]struct{}
71 users map[string]*user
[636]72
73 motd atomic.Value // string
[10]74}
75
[531]76func NewServer(db Database) *Server {
[636]77 srv := &Server{
[612]78 Logger: log.New(log.Writer(), "", log.LstdFlags),
79 MaxUserNetworks: -1,
80 db: db,
81 listeners: make(map[net.Listener]struct{}),
82 users: make(map[string]*user),
[37]83 }
[636]84 srv.motd.Store("")
85 return srv
[37]86}
87
[5]88func (s *Server) prefix() *irc.Prefix {
89 return &irc.Prefix{Name: s.Hostname}
90}
91
[449]92func (s *Server) Start() error {
[652]93 users, err := s.db.ListUsers(context.TODO())
[77]94 if err != nil {
95 return err
96 }
[71]97
[77]98 s.lock.Lock()
[378]99 for i := range users {
100 s.addUserLocked(&users[i])
[71]101 }
[37]102 s.lock.Unlock()
103
[449]104 return nil
[10]105}
106
[449]107func (s *Server) Shutdown() {
108 s.lock.Lock()
109 for ln := range s.listeners {
110 if err := ln.Close(); err != nil {
111 s.Logger.Printf("failed to stop listener: %v", err)
112 }
113 }
114 for _, u := range s.users {
115 u.events <- eventStop{}
116 }
117 s.lock.Unlock()
118
119 s.stopWG.Wait()
[599]120
121 if err := s.db.Close(); err != nil {
122 s.Logger.Printf("failed to close DB: %v", err)
123 }
[449]124}
125
[329]126func (s *Server) createUser(user *User) (*user, error) {
127 s.lock.Lock()
128 defer s.lock.Unlock()
129
130 if _, ok := s.users[user.Username]; ok {
131 return nil, fmt.Errorf("user %q already exists", user.Username)
132 }
133
[652]134 err := s.db.StoreUser(context.TODO(), user)
[329]135 if err != nil {
136 return nil, fmt.Errorf("could not create user in db: %v", err)
137 }
138
[378]139 return s.addUserLocked(user), nil
[329]140}
141
[563]142func (s *Server) forEachUser(f func(*user)) {
143 s.lock.Lock()
144 for _, u := range s.users {
145 f(u)
146 }
147 s.lock.Unlock()
148}
149
[38]150func (s *Server) getUser(name string) *user {
151 s.lock.Lock()
152 u := s.users[name]
153 s.lock.Unlock()
154 return u
155}
156
[378]157func (s *Server) addUserLocked(user *User) *user {
158 s.Logger.Printf("starting bouncer for user %q", user.Username)
159 u := newUser(s, user)
160 s.users[u.Username] = u
161
[449]162 s.stopWG.Add(1)
163
[378]164 go func() {
165 u.run()
166
167 s.lock.Lock()
168 delete(s.users, u.Username)
169 s.lock.Unlock()
[449]170
171 s.stopWG.Done()
[378]172 }()
173
174 return u
175}
176
// lastDownstreamID is the ID handed to the most recently accepted downstream
// connection. It is incremented atomically; the zero value means no
// connection has been accepted yet.
var lastDownstreamID uint64
178
[347]179func (s *Server) handle(ic ircConn) {
[605]180 atomic.AddInt64(&s.connCount, 1)
[323]181 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]182 dc := newDownstreamConn(s, ic, id)
[323]183 if err := dc.runUntilRegistered(); err != nil {
[655]184 if !errors.Is(err, io.EOF) {
185 dc.logger.Print(err)
186 }
[323]187 } else {
188 dc.user.events <- eventDownstreamConnected{dc}
189 if err := dc.readMessages(dc.user.events); err != nil {
190 dc.logger.Print(err)
191 }
192 dc.user.events <- eventDownstreamDisconnected{dc}
193 }
194 dc.Close()
[605]195 atomic.AddInt64(&s.connCount, -1)
[323]196}
197
[3]198func (s *Server) Serve(ln net.Listener) error {
[449]199 s.lock.Lock()
200 s.listeners[ln] = struct{}{}
201 s.lock.Unlock()
202
203 s.stopWG.Add(1)
204
205 defer func() {
206 s.lock.Lock()
207 delete(s.listeners, ln)
208 s.lock.Unlock()
209
210 s.stopWG.Done()
211 }()
212
[1]213 for {
[323]214 conn, err := ln.Accept()
[601]215 if isErrClosed(err) {
[449]216 return nil
217 } else if err != nil {
[1]218 return fmt.Errorf("failed to accept connection: %v", err)
219 }
220
[347]221 go s.handle(newNetIRCConn(conn))
[1]222 }
223}
[323]224
225func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
226 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]227 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[323]228 OriginPatterns: s.HTTPOrigins,
229 })
230 if err != nil {
231 s.Logger.Printf("failed to serve HTTP connection: %v", err)
232 return
233 }
[345]234
[370]235 isProxy := false
[345]236 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
237 if ip := net.ParseIP(host); ip != nil {
[370]238 isProxy = s.AcceptProxyIPs.Contains(ip)
[345]239 }
240 }
241
[474]242 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]243 // to prevent users from spoofing the remote address
[344]244 remoteAddr := req.RemoteAddr
[472]245 if isProxy {
246 forwarded := parseForwarded(req.Header)
[473]247 if forwarded["for"] != "" {
248 remoteAddr = forwarded["for"]
[472]249 }
[344]250 }
[345]251
[347]252 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]253}
[472]254
// parseForwarded extracts forwarding parameters from request headers.
//
// If a Forwarded header is present, its parameters are returned as parsed
// by mime.ParseMediaType; when that parse fails, the returned map is
// whatever ParseMediaType produced, which may be nil. Without a Forwarded
// header, the map always holds the keys "for", "proto" and "host", filled
// from X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host (empty when
// the header is absent).
func parseForwarded(h http.Header) map[string]string {
	if raw := h.Get("Forwarded"); raw != "" {
		// Hack to easily parse header parameters: prepend a dummy media
		// type so the header reads as "type; key=value; ...".
		_, params, _ := mime.ParseMediaType("hack; " + raw)
		return params
	}

	legacy := make(map[string]string, 3)
	legacy["for"] = h.Get("X-Forwarded-For")
	legacy["proto"] = h.Get("X-Forwarded-Proto")
	legacy["host"] = h.Get("X-Forwarded-Host")
	return legacy
}
[605]268
// ServerStats is a point-in-time summary of server activity, as returned by
// Server.Stats.
type ServerStats struct {
	// Users is the number of users currently registered on the server.
	Users int
	// Downstreams is the number of downstream connections being handled.
	Downstreams int64
}
273
274func (s *Server) Stats() *ServerStats {
275 var stats ServerStats
276 s.lock.Lock()
277 stats.Users = len(s.users)
278 s.lock.Unlock()
279 stats.Downstreams = atomic.LoadInt64(&s.connCount)
280 return &stats
281}
[636]282
283func (s *Server) SetMOTD(motd string) {
284 s.motd.Store(motd)
285}
286
287func (s *Server) MOTD() string {
288 return s.motd.Load().(string)
289}
Note: See TracBrowser for help on using the repository browser.