source: code/trunk/upstream.go@ 67

Last change on this file since 67 was 67, checked in by contact, 5 years ago

Enable TCP keep-alive on all connections

File size: 8.4 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "time"
11
12 "gopkg.in/irc.v3"
13)
14
15type upstreamChannel struct {
16 Name string
17 conn *upstreamConn
18 Topic string
19 TopicWho string
20 TopicTime time.Time
21 Status channelStatus
22 modes modeSet
23 Members map[string]membership
24 complete bool
25}
26
27type upstreamConn struct {
28 upstream *Upstream
29 logger Logger
30 net net.Conn
31 irc *irc.Conn
32 srv *Server
33 user *user
34 messages chan<- *irc.Message
35 ring *Ring
36
37 serverName string
38 availableUserModes string
39 availableChannelModes string
40 channelModesWithParam string
41
42 registered bool
43 nick string
44 closed bool
45 modes modeSet
46 channels map[string]*upstreamChannel
47 history map[string]uint64
48}
49
50func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) {
51 logger := &prefixLogger{u.srv.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
52 logger.Printf("connecting to server")
53
54 netConn, err := tls.Dial("tcp", upstream.Addr, nil)
55 if err != nil {
56 return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
57 }
58
59 setKeepAlive(netConn)
60
61 msgs := make(chan *irc.Message, 64)
62 uc := &upstreamConn{
63 upstream: upstream,
64 logger: logger,
65 net: netConn,
66 irc: irc.NewConn(netConn),
67 srv: u.srv,
68 user: u,
69 messages: msgs,
70 ring: NewRing(u.srv.RingCap),
71 channels: make(map[string]*upstreamChannel),
72 history: make(map[string]uint64),
73 }
74
75 go func() {
76 for msg := range msgs {
77 if uc.srv.Debug {
78 uc.logger.Printf("sent: %v", msg)
79 }
80 if err := uc.irc.WriteMessage(msg); err != nil {
81 uc.logger.Printf("failed to write message: %v", err)
82 }
83 }
84 if err := uc.net.Close(); err != nil {
85 uc.logger.Printf("failed to close connection: %v", err)
86 } else {
87 uc.logger.Printf("connection closed")
88 }
89 }()
90
91 return uc, nil
92}
93
94func (uc *upstreamConn) Close() error {
95 if uc.closed {
96 return fmt.Errorf("upstream connection already closed")
97 }
98 close(uc.messages)
99 uc.closed = true
100 return nil
101}
102
103func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
104 ch, ok := uc.channels[name]
105 if !ok {
106 return nil, fmt.Errorf("unknown channel %q", name)
107 }
108 return ch, nil
109}
110
111func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
112 switch msg.Command {
113 case "PING":
114 var from, to string
115 if len(msg.Params) >= 1 {
116 from = msg.Params[0]
117 }
118 if len(msg.Params) >= 2 {
119 to = msg.Params[1]
120 }
121
122 if to != "" && to != uc.srv.Hostname {
123 return fmt.Errorf("invalid PING destination %q", to)
124 }
125
126 params := []string{uc.srv.Hostname}
127 if from != "" {
128 params = append(params, from)
129 }
130 uc.SendMessage(&irc.Message{
131 Command: "PONG",
132 Params: params,
133 })
134 return nil
135 case "MODE":
136 var name, modeStr string
137 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
138 return err
139 }
140
141 if name == msg.Prefix.Name { // user mode change
142 if name != uc.nick {
143 return fmt.Errorf("received MODE message for unknow nick %q", name)
144 }
145 return uc.modes.Apply(modeStr)
146 } else { // channel mode change
147 ch, err := uc.getChannel(name)
148 if err != nil {
149 return err
150 }
151 if err := ch.modes.Apply(modeStr); err != nil {
152 return err
153 }
154 }
155
156 uc.user.forEachDownstream(func(dc *downstreamConn) {
157 dc.SendMessage(msg)
158 })
159 case "NOTICE":
160 uc.logger.Print(msg)
161 case irc.RPL_WELCOME:
162 uc.registered = true
163 uc.logger.Printf("connection registered")
164
165 for _, ch := range uc.upstream.Channels {
166 uc.SendMessage(&irc.Message{
167 Command: "JOIN",
168 Params: []string{ch},
169 })
170 }
171 case irc.RPL_MYINFO:
172 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
173 return err
174 }
175 if len(msg.Params) > 5 {
176 uc.channelModesWithParam = msg.Params[5]
177 }
178 case "NICK":
179 var newNick string
180 if err := parseMessageParams(msg, &newNick); err != nil {
181 return err
182 }
183
184 if msg.Prefix.Name == uc.nick {
185 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
186 uc.nick = newNick
187 }
188
189 for _, ch := range uc.channels {
190 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
191 delete(ch.Members, msg.Prefix.Name)
192 ch.Members[newNick] = membership
193 }
194 }
195
196 uc.user.forEachDownstream(func(dc *downstreamConn) {
197 dc.SendMessage(msg)
198 })
199 case "JOIN":
200 var channels string
201 if err := parseMessageParams(msg, &channels); err != nil {
202 return err
203 }
204
205 for _, ch := range strings.Split(channels, ",") {
206 if msg.Prefix.Name == uc.nick {
207 uc.logger.Printf("joined channel %q", ch)
208 uc.channels[ch] = &upstreamChannel{
209 Name: ch,
210 conn: uc,
211 Members: make(map[string]membership),
212 }
213 } else {
214 ch, err := uc.getChannel(ch)
215 if err != nil {
216 return err
217 }
218 ch.Members[msg.Prefix.Name] = 0
219 }
220 }
221
222 uc.user.forEachDownstream(func(dc *downstreamConn) {
223 dc.SendMessage(msg)
224 })
225 case "PART":
226 var channels string
227 if err := parseMessageParams(msg, &channels); err != nil {
228 return err
229 }
230
231 for _, ch := range strings.Split(channels, ",") {
232 if msg.Prefix.Name == uc.nick {
233 uc.logger.Printf("parted channel %q", ch)
234 delete(uc.channels, ch)
235 } else {
236 ch, err := uc.getChannel(ch)
237 if err != nil {
238 return err
239 }
240 delete(ch.Members, msg.Prefix.Name)
241 }
242 }
243
244 uc.user.forEachDownstream(func(dc *downstreamConn) {
245 dc.SendMessage(msg)
246 })
247 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
248 var name, topic string
249 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
250 return err
251 }
252 ch, err := uc.getChannel(name)
253 if err != nil {
254 return err
255 }
256 if msg.Command == irc.RPL_TOPIC {
257 ch.Topic = topic
258 } else {
259 ch.Topic = ""
260 }
261 case "TOPIC":
262 var name string
263 if err := parseMessageParams(msg, nil, &name); err != nil {
264 return err
265 }
266 ch, err := uc.getChannel(name)
267 if err != nil {
268 return err
269 }
270 if len(msg.Params) > 1 {
271 ch.Topic = msg.Params[1]
272 } else {
273 ch.Topic = ""
274 }
275 case rpl_topicwhotime:
276 var name, who, timeStr string
277 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
278 return err
279 }
280 ch, err := uc.getChannel(name)
281 if err != nil {
282 return err
283 }
284 ch.TopicWho = who
285 sec, err := strconv.ParseInt(timeStr, 10, 64)
286 if err != nil {
287 return fmt.Errorf("failed to parse topic time: %v", err)
288 }
289 ch.TopicTime = time.Unix(sec, 0)
290 case irc.RPL_NAMREPLY:
291 var name, statusStr, members string
292 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
293 return err
294 }
295 ch, err := uc.getChannel(name)
296 if err != nil {
297 return err
298 }
299
300 status, err := parseChannelStatus(statusStr)
301 if err != nil {
302 return err
303 }
304 ch.Status = status
305
306 for _, s := range strings.Split(members, " ") {
307 membership, nick := parseMembershipPrefix(s)
308 ch.Members[nick] = membership
309 }
310 case irc.RPL_ENDOFNAMES:
311 var name string
312 if err := parseMessageParams(msg, nil, &name); err != nil {
313 return err
314 }
315 ch, err := uc.getChannel(name)
316 if err != nil {
317 return err
318 }
319
320 if ch.complete {
321 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
322 }
323 ch.complete = true
324
325 uc.user.forEachDownstream(func(dc *downstreamConn) {
326 forwardChannel(dc, ch)
327 })
328 case "PRIVMSG":
329 uc.ring.Produce(msg)
330 case irc.RPL_YOURHOST, irc.RPL_CREATED:
331 // Ignore
332 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
333 // Ignore
334 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
335 // Ignore
336 case rpl_localusers, rpl_globalusers:
337 // Ignore
338 case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
339 // Ignore
340 default:
341 uc.logger.Printf("unhandled upstream message: %v", msg)
342 }
343 return nil
344}
345
346func (uc *upstreamConn) register() {
347 uc.nick = uc.upstream.Nick
348 uc.SendMessage(&irc.Message{
349 Command: "NICK",
350 Params: []string{uc.upstream.Nick},
351 })
352 uc.SendMessage(&irc.Message{
353 Command: "USER",
354 Params: []string{uc.upstream.Username, "0", "*", uc.upstream.Realname},
355 })
356}
357
358func (uc *upstreamConn) readMessages() error {
359 for {
360 msg, err := uc.irc.ReadMessage()
361 if err == io.EOF {
362 break
363 } else if err != nil {
364 return fmt.Errorf("failed to read IRC command: %v", err)
365 }
366
367 if uc.srv.Debug {
368 uc.logger.Printf("received: %v", msg)
369 }
370
371 if err := uc.handleMessage(msg); err != nil {
372 uc.logger.Printf("failed to handle message %q: %v", msg, err)
373 }
374 }
375
376 return nil
377}
378
379func (uc *upstreamConn) SendMessage(msg *irc.Message) {
380 uc.messages <- msg
381}
Note: See TracBrowser for help on using the repository browser.