source: code/trunk/server.go@ 651

Last change on this file since 651 was 636, checked in by contact, 4 years ago

Add bouncer MOTD

Closes: https://todo.sr.ht/~emersion/soju/137

File size: 5.6 KB
Rev Line
[98]1package soju
[1]2
import (
	"fmt"
	"log"
	"mime"
	"net"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
18
// Timing and rate-limiting defaults shared by the rest of the package.
//
// TODO: make configurable
var (
	retryConnectDelay    = time.Minute
	connectTimeout       = 15 * time.Second
	writeTimeout         = 10 * time.Second
	upstreamMessageDelay = 2 * time.Second
	upstreamMessageBurst = 10
)
[67]25
// Logger is the minimal logging interface used throughout the server.
// *log.Logger satisfies it.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger is a Logger that prepends a fixed prefix to every message
// before handing it to an underlying Logger.
type prefixLogger struct {
	logger Logger
	prefix string
}

var _ Logger = (*prefixLogger)(nil)

// Print logs v preceded by the prefix. The prefix is passed to the
// underlying logger as an extra leading operand, so it is formatted like
// the other operands. The caller's slice is never modified.
func (l *prefixLogger) Print(v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(args, l.prefix)
	args = append(args, v...)
	l.logger.Print(args...)
}

// Printf logs the formatted message preceded by the prefix. The prefix is
// injected through a leading %v verb rather than concatenated into format,
// so a '%' inside the prefix is never treated as a formatting directive.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	args := make([]interface{}, 0, len(v)+1)
	args = append(append(args, l.prefix), v...)
	l.logger.Printf("%v"+format, args...)
}
47
[10]48type Server struct {
[612]49 Hostname string
50 Logger Logger
51 HistoryLimit int
52 LogPath string
53 Debug bool
54 HTTPOrigins []string
55 AcceptProxyIPs config.IPSet
56 MaxUserNetworks int
57 Identd *Identd // can be nil
[22]58
[605]59 db Database
60 stopWG sync.WaitGroup
61 connCount int64 // atomic
[77]62
[449]63 lock sync.Mutex
64 listeners map[net.Listener]struct{}
65 users map[string]*user
[636]66
67 motd atomic.Value // string
[10]68}
69
[531]70func NewServer(db Database) *Server {
[636]71 srv := &Server{
[612]72 Logger: log.New(log.Writer(), "", log.LstdFlags),
73 HistoryLimit: 1000,
74 MaxUserNetworks: -1,
75 db: db,
76 listeners: make(map[net.Listener]struct{}),
77 users: make(map[string]*user),
[37]78 }
[636]79 srv.motd.Store("")
80 return srv
[37]81}
82
[5]83func (s *Server) prefix() *irc.Prefix {
84 return &irc.Prefix{Name: s.Hostname}
85}
86
[449]87func (s *Server) Start() error {
[77]88 users, err := s.db.ListUsers()
89 if err != nil {
90 return err
91 }
[71]92
[77]93 s.lock.Lock()
[378]94 for i := range users {
95 s.addUserLocked(&users[i])
[71]96 }
[37]97 s.lock.Unlock()
98
[449]99 return nil
[10]100}
101
[449]102func (s *Server) Shutdown() {
103 s.lock.Lock()
104 for ln := range s.listeners {
105 if err := ln.Close(); err != nil {
106 s.Logger.Printf("failed to stop listener: %v", err)
107 }
108 }
109 for _, u := range s.users {
110 u.events <- eventStop{}
111 }
112 s.lock.Unlock()
113
114 s.stopWG.Wait()
[599]115
116 if err := s.db.Close(); err != nil {
117 s.Logger.Printf("failed to close DB: %v", err)
118 }
[449]119}
120
[329]121func (s *Server) createUser(user *User) (*user, error) {
122 s.lock.Lock()
123 defer s.lock.Unlock()
124
125 if _, ok := s.users[user.Username]; ok {
126 return nil, fmt.Errorf("user %q already exists", user.Username)
127 }
128
129 err := s.db.StoreUser(user)
130 if err != nil {
131 return nil, fmt.Errorf("could not create user in db: %v", err)
132 }
133
[378]134 return s.addUserLocked(user), nil
[329]135}
136
[563]137func (s *Server) forEachUser(f func(*user)) {
138 s.lock.Lock()
139 for _, u := range s.users {
140 f(u)
141 }
142 s.lock.Unlock()
143}
144
[38]145func (s *Server) getUser(name string) *user {
146 s.lock.Lock()
147 u := s.users[name]
148 s.lock.Unlock()
149 return u
150}
151
[378]152func (s *Server) addUserLocked(user *User) *user {
153 s.Logger.Printf("starting bouncer for user %q", user.Username)
154 u := newUser(s, user)
155 s.users[u.Username] = u
156
[449]157 s.stopWG.Add(1)
158
[378]159 go func() {
160 u.run()
161
162 s.lock.Lock()
163 delete(s.users, u.Username)
164 s.lock.Unlock()
[449]165
166 s.stopWG.Done()
[378]167 }()
168
169 return u
170}
171
// lastDownstreamID is the most recently assigned downstream connection ID.
// handle increments it with atomic.AddUint64, so the first ID is 1.
var lastDownstreamID uint64
173
[347]174func (s *Server) handle(ic ircConn) {
[605]175 atomic.AddInt64(&s.connCount, 1)
[323]176 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]177 dc := newDownstreamConn(s, ic, id)
[323]178 if err := dc.runUntilRegistered(); err != nil {
179 dc.logger.Print(err)
180 } else {
181 dc.user.events <- eventDownstreamConnected{dc}
182 if err := dc.readMessages(dc.user.events); err != nil {
183 dc.logger.Print(err)
184 }
185 dc.user.events <- eventDownstreamDisconnected{dc}
186 }
187 dc.Close()
[605]188 atomic.AddInt64(&s.connCount, -1)
[323]189}
190
[3]191func (s *Server) Serve(ln net.Listener) error {
[449]192 s.lock.Lock()
193 s.listeners[ln] = struct{}{}
194 s.lock.Unlock()
195
196 s.stopWG.Add(1)
197
198 defer func() {
199 s.lock.Lock()
200 delete(s.listeners, ln)
201 s.lock.Unlock()
202
203 s.stopWG.Done()
204 }()
205
[1]206 for {
[323]207 conn, err := ln.Accept()
[601]208 if isErrClosed(err) {
[449]209 return nil
210 } else if err != nil {
[1]211 return fmt.Errorf("failed to accept connection: %v", err)
212 }
213
[347]214 go s.handle(newNetIRCConn(conn))
[1]215 }
216}
[323]217
218func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
219 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]220 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[323]221 OriginPatterns: s.HTTPOrigins,
222 })
223 if err != nil {
224 s.Logger.Printf("failed to serve HTTP connection: %v", err)
225 return
226 }
[345]227
[370]228 isProxy := false
[345]229 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
230 if ip := net.ParseIP(host); ip != nil {
[370]231 isProxy = s.AcceptProxyIPs.Contains(ip)
[345]232 }
233 }
234
[474]235 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]236 // to prevent users from spoofing the remote address
[344]237 remoteAddr := req.RemoteAddr
[472]238 if isProxy {
239 forwarded := parseForwarded(req.Header)
[473]240 if forwarded["for"] != "" {
241 remoteAddr = forwarded["for"]
[472]242 }
[344]243 }
[345]244
[347]245 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]246}
[472]247
// parseForwarded extracts the client information a reverse proxy recorded
// in the request headers. It reads the standard Forwarded header (RFC 7239).
// When that header is absent or empty, it falls back to the de-facto
// X-Forwarded-For, X-Forwarded-Proto and X-Forwarded-Host headers.
//
// Each of these headers is a comma-separated list that proxies append to,
// and may be repeated. Only the last element is used: it is the one written
// by the proxy that connected to us, whereas earlier elements may have been
// supplied by the client.
//
// Keys of the returned map are lower-cased parameter names such as "for",
// "proto" and "host". The map is nil if the Forwarded element cannot be
// parsed; lookups on it then yield "".
func parseForwarded(h http.Header) map[string]string {
	forwarded := lastHeaderElement(h, "Forwarded")
	if forwarded == "" {
		return map[string]string{
			"for":   lastHeaderElement(h, "X-Forwarded-For"),
			"proto": lastHeaderElement(h, "X-Forwarded-Proto"),
			"host":  lastHeaderElement(h, "X-Forwarded-Host"),
		}
	}
	// A Forwarded element uses the same "key=value;key=value" syntax as
	// MIME media type parameters, so reuse that parser behind a dummy media
	// type. On a parse error params is nil, which callers treat as "no
	// information", so the error is deliberately ignored.
	_, params, _ := mime.ParseMediaType("hack; " + forwarded)
	return params
}

// lastHeaderElement returns the last non-empty element of the
// comma-separated list formed by all values of header name. Repeated header
// lines are combined as if joined with commas. Commas inside double-quoted
// strings do not separate elements. It returns "" when there is no element.
func lastHeaderElement(h http.Header, name string) string {
	list := strings.Join(h.Values(name), ",")

	var last string
	start, quoted := 0, false
	for i := 0; i <= len(list); i++ {
		if i < len(list) {
			switch c := list[i]; {
			case quoted && c == '\\' && i+1 < len(list):
				i++ // skip the escaped character
				continue
			case c == '"':
				quoted = !quoted
				continue
			case quoted || c != ',':
				continue
			}
		}
		// i is at a top-level comma or at the end of the list.
		if elem := strings.TrimSpace(list[start:i]); elem != "" {
			last = elem
		}
		start = i + 1
	}
	return last
}
[605]261
// ServerStats is a point-in-time snapshot of server activity, as returned
// by Server.Stats.
type ServerStats struct {
	Users       int   // number of users currently loaded
	Downstreams int64 // number of downstream connections being handled
}
266
267func (s *Server) Stats() *ServerStats {
268 var stats ServerStats
269 s.lock.Lock()
270 stats.Users = len(s.users)
271 s.lock.Unlock()
272 stats.Downstreams = atomic.LoadInt64(&s.connCount)
273 return &stats
274}
[636]275
276func (s *Server) SetMOTD(motd string) {
277 s.motd.Store(motd)
278}
279
280func (s *Server) MOTD() string {
281 return s.motd.Load().(string)
282}
Note: See TracBrowser for help on using the repository browser.