source: code/trunk/server.go@ 655

Last change on this file since 655 was 655, checked in by contact, 4 years ago

Get rid of io.EOF errors in logs

Closes: https://todo.sr.ht/~emersion/soju/150

File size: 5.7 KB
RevLine 
[98]1package soju
[1]2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"

	"git.sr.ht/~emersion/soju/config"
)
19
// TODO: make configurable
//
// Timing and rate defaults. None of these are read in this part of the file,
// so the per-variable notes below are inferred from the names and should be
// confirmed against the code that uses them.
var (
	// retryConnectDelay is one minute.
	// NOTE(review): presumably the pause before retrying a connection.
	retryConnectDelay = time.Minute

	// connectTimeout is fifteen seconds.
	// NOTE(review): presumably the limit for establishing a connection.
	connectTimeout = 15 * time.Second

	// writeTimeout is ten seconds.
	// NOTE(review): presumably the deadline applied to writes.
	writeTimeout = 10 * time.Second

	// upstreamMessageDelay is two seconds.
	// NOTE(review): presumably the spacing enforced between upstream messages.
	upstreamMessageDelay = 2 * time.Second

	// upstreamMessageBurst is ten.
	// NOTE(review): presumably the number of messages allowed before
	// upstreamMessageDelay applies.
	upstreamMessageBurst = 10
)
[67]26
// Logger is the small logging surface the server writes to. NewServer
// installs a standard library *log.Logger, which satisfies it.
type Logger interface {
	Print(v ...interface{})
	Printf(format string, v ...interface{})
}

// prefixLogger wraps another Logger and puts a fixed prefix in front of
// every message before forwarding it.
type prefixLogger struct {
	logger Logger
	prefix string
}

// Compile-time check that *prefixLogger implements Logger.
var _ Logger = (*prefixLogger)(nil)

// Print forwards v to the wrapped logger with the prefix inserted as the
// first operand. The caller's slice is never modified.
func (l *prefixLogger) Print(v ...interface{}) {
	operands := make([]interface{}, 0, len(v)+1)
	operands = append(operands, l.prefix)
	operands = append(operands, v...)
	l.logger.Print(operands...)
}

// Printf forwards a formatted message to the wrapped logger. The prefix is
// passed as an extra leading argument rendered by a "%v" verb prepended to
// format, so it is never interpreted as part of the format string itself.
func (l *prefixLogger) Printf(format string, v ...interface{}) {
	operands := make([]interface{}, 0, len(v)+1)
	operands = append(operands, l.prefix)
	operands = append(operands, v...)
	l.logger.Printf("%v"+format, operands...)
}
48
[10]49type Server struct {
[612]50 Hostname string
51 Logger Logger
52 HistoryLimit int
53 LogPath string
54 Debug bool
55 HTTPOrigins []string
56 AcceptProxyIPs config.IPSet
57 MaxUserNetworks int
58 Identd *Identd // can be nil
[22]59
[605]60 db Database
61 stopWG sync.WaitGroup
62 connCount int64 // atomic
[77]63
[449]64 lock sync.Mutex
65 listeners map[net.Listener]struct{}
66 users map[string]*user
[636]67
68 motd atomic.Value // string
[10]69}
70
[531]71func NewServer(db Database) *Server {
[636]72 srv := &Server{
[612]73 Logger: log.New(log.Writer(), "", log.LstdFlags),
74 HistoryLimit: 1000,
75 MaxUserNetworks: -1,
76 db: db,
77 listeners: make(map[net.Listener]struct{}),
78 users: make(map[string]*user),
[37]79 }
[636]80 srv.motd.Store("")
81 return srv
[37]82}
83
[5]84func (s *Server) prefix() *irc.Prefix {
85 return &irc.Prefix{Name: s.Hostname}
86}
87
[449]88func (s *Server) Start() error {
[652]89 users, err := s.db.ListUsers(context.TODO())
[77]90 if err != nil {
91 return err
92 }
[71]93
[77]94 s.lock.Lock()
[378]95 for i := range users {
96 s.addUserLocked(&users[i])
[71]97 }
[37]98 s.lock.Unlock()
99
[449]100 return nil
[10]101}
102
[449]103func (s *Server) Shutdown() {
104 s.lock.Lock()
105 for ln := range s.listeners {
106 if err := ln.Close(); err != nil {
107 s.Logger.Printf("failed to stop listener: %v", err)
108 }
109 }
110 for _, u := range s.users {
111 u.events <- eventStop{}
112 }
113 s.lock.Unlock()
114
115 s.stopWG.Wait()
[599]116
117 if err := s.db.Close(); err != nil {
118 s.Logger.Printf("failed to close DB: %v", err)
119 }
[449]120}
121
[329]122func (s *Server) createUser(user *User) (*user, error) {
123 s.lock.Lock()
124 defer s.lock.Unlock()
125
126 if _, ok := s.users[user.Username]; ok {
127 return nil, fmt.Errorf("user %q already exists", user.Username)
128 }
129
[652]130 err := s.db.StoreUser(context.TODO(), user)
[329]131 if err != nil {
132 return nil, fmt.Errorf("could not create user in db: %v", err)
133 }
134
[378]135 return s.addUserLocked(user), nil
[329]136}
137
[563]138func (s *Server) forEachUser(f func(*user)) {
139 s.lock.Lock()
140 for _, u := range s.users {
141 f(u)
142 }
143 s.lock.Unlock()
144}
145
[38]146func (s *Server) getUser(name string) *user {
147 s.lock.Lock()
148 u := s.users[name]
149 s.lock.Unlock()
150 return u
151}
152
[378]153func (s *Server) addUserLocked(user *User) *user {
154 s.Logger.Printf("starting bouncer for user %q", user.Username)
155 u := newUser(s, user)
156 s.users[u.Username] = u
157
[449]158 s.stopWG.Add(1)
159
[378]160 go func() {
161 u.run()
162
163 s.lock.Lock()
164 delete(s.users, u.Username)
165 s.lock.Unlock()
[449]166
167 s.stopWG.Done()
[378]168 }()
169
170 return u
171}
172
[323]173var lastDownstreamID uint64 = 0
174
[347]175func (s *Server) handle(ic ircConn) {
[605]176 atomic.AddInt64(&s.connCount, 1)
[323]177 id := atomic.AddUint64(&lastDownstreamID, 1)
[347]178 dc := newDownstreamConn(s, ic, id)
[323]179 if err := dc.runUntilRegistered(); err != nil {
[655]180 if !errors.Is(err, io.EOF) {
181 dc.logger.Print(err)
182 }
[323]183 } else {
184 dc.user.events <- eventDownstreamConnected{dc}
185 if err := dc.readMessages(dc.user.events); err != nil {
186 dc.logger.Print(err)
187 }
188 dc.user.events <- eventDownstreamDisconnected{dc}
189 }
190 dc.Close()
[605]191 atomic.AddInt64(&s.connCount, -1)
[323]192}
193
[3]194func (s *Server) Serve(ln net.Listener) error {
[449]195 s.lock.Lock()
196 s.listeners[ln] = struct{}{}
197 s.lock.Unlock()
198
199 s.stopWG.Add(1)
200
201 defer func() {
202 s.lock.Lock()
203 delete(s.listeners, ln)
204 s.lock.Unlock()
205
206 s.stopWG.Done()
207 }()
208
[1]209 for {
[323]210 conn, err := ln.Accept()
[601]211 if isErrClosed(err) {
[449]212 return nil
213 } else if err != nil {
[1]214 return fmt.Errorf("failed to accept connection: %v", err)
215 }
216
[347]217 go s.handle(newNetIRCConn(conn))
[1]218 }
219}
[323]220
221func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
222 conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
[597]223 Subprotocols: []string{"text.ircv3.net"}, // non-compliant, fight me
[323]224 OriginPatterns: s.HTTPOrigins,
225 })
226 if err != nil {
227 s.Logger.Printf("failed to serve HTTP connection: %v", err)
228 return
229 }
[345]230
[370]231 isProxy := false
[345]232 if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
233 if ip := net.ParseIP(host); ip != nil {
[370]234 isProxy = s.AcceptProxyIPs.Contains(ip)
[345]235 }
236 }
237
[474]238 // Only trust the Forwarded header field if this is a trusted proxy IP
[345]239 // to prevent users from spoofing the remote address
[344]240 remoteAddr := req.RemoteAddr
[472]241 if isProxy {
242 forwarded := parseForwarded(req.Header)
[473]243 if forwarded["for"] != "" {
244 remoteAddr = forwarded["for"]
[472]245 }
[344]246 }
[345]247
[347]248 s.handle(newWebsocketIRCConn(conn, remoteAddr))
[323]249}
[472]250
// parseForwarded extracts forwarding parameters from the request headers.
//
// When a Forwarded header is present, its parameters are returned keyed by
// lower-cased parameter name; if the header cannot be parsed the parse error
// is discarded and whatever map the parser produced (possibly nil) is
// returned. When there is no Forwarded header, the result always has exactly
// the keys "for", "proto" and "host", filled from X-Forwarded-For,
// X-Forwarded-Proto and X-Forwarded-Host (empty string when absent).
func parseForwarded(h http.Header) map[string]string {
	if value := h.Get("Forwarded"); value != "" {
		// Hack to easily parse header parameters: prepend a dummy media
		// type so the MIME parser accepts the parameter list.
		_, params, _ := mime.ParseMediaType("hack; " + value)
		return params
	}

	legacy := make(map[string]string, 3)
	legacy["for"] = h.Get("X-Forwarded-For")
	legacy["proto"] = h.Get("X-Forwarded-Proto")
	legacy["host"] = h.Get("X-Forwarded-Host")
	return legacy
}
[605]264
265type ServerStats struct {
266 Users int
267 Downstreams int64
268}
269
270func (s *Server) Stats() *ServerStats {
271 var stats ServerStats
272 s.lock.Lock()
273 stats.Users = len(s.users)
274 s.lock.Unlock()
275 stats.Downstreams = atomic.LoadInt64(&s.connCount)
276 return &stats
277}
[636]278
279func (s *Server) SetMOTD(motd string) {
280 s.motd.Store(motd)
281}
282
283func (s *Server) MOTD() string {
284 return s.motd.Load().(string)
285}
Note: See TracBrowser for help on using the repository browser.