source: code/trunk/server.go@ 67

Last change on this file since 67 was 67, checked in by contact, 5 years ago

Enable TCP keep-alive on all connections

File size: 4.1 KB
Line 
1package jounce
2
3import (
4 "fmt"
5 "log"
6 "net"
7 "sync"
8 "time"
9
10 "gopkg.in/irc.v3"
11)
12
// keepAlivePeriod is the interval setKeepAlive passes to
// SetKeepAlivePeriod for every TCP connection it configures.
// TODO: make configurable
var keepAlivePeriod = 60 * time.Second
15
16func setKeepAlive(c net.Conn) error {
17 tcpConn, ok := c.(*net.TCPConn)
18 if !ok {
19 return fmt.Errorf("cannot enable keep-alive on a non-TCP connection")
20 }
21 if err := tcpConn.SetKeepAlive(true); err != nil {
22 return err
23 }
24 return tcpConn.SetKeepAlivePeriod(keepAlivePeriod)
25}
26
// Logger is the minimal logging interface used throughout the server.
// NewServer installs a *log.Logger by default, which satisfies it.
type Logger interface {
	// Print logs v; implementations such as *log.Logger format it like fmt.Print.
	Print(v ...interface{})
	// Printf logs a message built from format and v, like fmt.Printf.
	Printf(format string, v ...interface{})
}
31
// prefixLogger is a Logger that prepends prefix to every message before
// handing it to the wrapped logger.
type prefixLogger struct {
	logger Logger // destination for the prefixed messages
	prefix string // emitted verbatim before each message; prefixLogger inserts no separator itself
}

// Compile-time check that *prefixLogger implements Logger.
var _ Logger = (*prefixLogger)(nil)
38
39func (l *prefixLogger) Print(v ...interface{}) {
40 v = append([]interface{}{l.prefix}, v...)
41 l.logger.Print(v...)
42}
43
44func (l *prefixLogger) Printf(format string, v ...interface{}) {
45 v = append([]interface{}{l.prefix}, v...)
46 l.logger.Printf("%v"+format, v...)
47}
48
// user holds the per-user state: the user's upstream (server-side)
// connections and downstream (client-side) connections.
type user struct {
	username string
	srv      *Server // server this user belongs to

	// lock guards upstreamConns and downstreamConns.
	lock            sync.Mutex
	upstreamConns   []*upstreamConn
	downstreamConns []*downstreamConn
}
57
58func newUser(srv *Server, username string) *user {
59 return &user{
60 username: username,
61 srv: srv,
62 }
63}
64
65func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
66 u.lock.Lock()
67 for _, uc := range u.upstreamConns {
68 if !uc.registered || uc.closed {
69 continue
70 }
71 f(uc)
72 }
73 u.lock.Unlock()
74}
75
76func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
77 u.lock.Lock()
78 for _, dc := range u.downstreamConns {
79 f(dc)
80 }
81 u.lock.Unlock()
82}
83
84func (u *user) getChannel(name string) (*upstreamChannel, error) {
85 var channel *upstreamChannel
86 var err error
87 u.forEachUpstream(func(uc *upstreamConn) {
88 if err != nil {
89 return
90 }
91 if ch, ok := uc.channels[name]; ok {
92 if channel != nil {
93 err = fmt.Errorf("ambiguous channel name %q", name)
94 } else {
95 channel = ch
96 }
97 }
98 })
99 if channel == nil {
100 return nil, ircError{&irc.Message{
101 Command: irc.ERR_NOSUCHCHANNEL,
102 Params: []string{name, "No such channel"},
103 }}
104 }
105 return channel, nil
106}
107
// Upstream describes one upstream IRC server that Server.Run connects to.
type Upstream struct {
	Addr     string   // server address; also included in Run's connection-failure log message
	Nick     string   // presumably the nickname used when registering — TODO confirm
	Username string   // presumably the username sent at registration — TODO confirm
	Realname string   // presumably the real name sent at registration — TODO confirm
	Channels []string // presumably channels to join after registration — TODO confirm
}
115
// Server holds the configuration and runtime state of the server. Create
// one with NewServer. The zero value is not usable, because Run writes into
// the users map, which NewServer initializes.
type Server struct {
	Hostname  string     // used as the Name of the prefix returned by prefix()
	Logger    Logger     // server-wide logger; NewServer defaults it to a standard *log.Logger
	RingCap   int        // NewServer defaults it to 4096; not read in this chunk (presumably a ring-buffer capacity — TODO confirm)
	Debug     bool       // not read in this chunk; presumably enables debug output — TODO confirm
	Upstreams []Upstream // servers Run connects to; TODO: per-user

	lock            sync.Mutex // guards users and downstreamConns
	users           map[string]*user
	downstreamConns []*downstreamConn
}
127
128func NewServer() *Server {
129 return &Server{
130 Logger: log.New(log.Writer(), "", log.LstdFlags),
131 RingCap: 4096,
132 users: make(map[string]*user),
133 }
134}
135
136func (s *Server) prefix() *irc.Prefix {
137 return &irc.Prefix{Name: s.Hostname}
138}
139
140func (s *Server) Run() {
141 // TODO: multi-user
142 u := newUser(s, "jounce")
143
144 s.lock.Lock()
145 s.users[u.username] = u
146 s.lock.Unlock()
147
148 for i := range s.Upstreams {
149 upstream := &s.Upstreams[i]
150 // TODO: retry connecting
151 go func() {
152 uc, err := connectToUpstream(u, upstream)
153 if err != nil {
154 s.Logger.Printf("failed to connect to upstream server %q: %v", upstream.Addr, err)
155 return
156 }
157
158 uc.register()
159
160 u.lock.Lock()
161 u.upstreamConns = append(u.upstreamConns, uc)
162 u.lock.Unlock()
163
164 if err := uc.readMessages(); err != nil {
165 uc.logger.Printf("failed to handle messages: %v", err)
166 }
167 uc.Close()
168
169 u.lock.Lock()
170 for i := range u.upstreamConns {
171 if u.upstreamConns[i] == uc {
172 u.upstreamConns = append(u.upstreamConns[:i], u.upstreamConns[i+1:]...)
173 break
174 }
175 }
176 u.lock.Unlock()
177 }()
178 }
179}
180
181func (s *Server) getUser(name string) *user {
182 s.lock.Lock()
183 u := s.users[name]
184 s.lock.Unlock()
185 return u
186}
187
188func (s *Server) Serve(ln net.Listener) error {
189 for {
190 netConn, err := ln.Accept()
191 if err != nil {
192 return fmt.Errorf("failed to accept connection: %v", err)
193 }
194
195 setKeepAlive(netConn)
196
197 dc := newDownstreamConn(s, netConn)
198 go func() {
199 s.lock.Lock()
200 s.downstreamConns = append(s.downstreamConns, dc)
201 s.lock.Unlock()
202
203 if err := dc.readMessages(); err != nil {
204 dc.logger.Printf("failed to handle messages: %v", err)
205 }
206 dc.Close()
207
208 s.lock.Lock()
209 for i := range s.downstreamConns {
210 if s.downstreamConns[i] == dc {
211 s.downstreamConns = append(s.downstreamConns[:i], s.downstreamConns[i+1:]...)
212 break
213 }
214 }
215 s.lock.Unlock()
216 }()
217 }
218}
Note: See TracBrowser for help on using the repository browser.