source: code/trunk/upstream.go@ 78

Last change on this file since 78 was 77, checked in by contact, 5 years ago

Add SQLite database

Closes: https://todo.sr.ht/~emersion/jounce/9

File size: 9.6 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "time"
11
12 "gopkg.in/irc.v3"
13)
14
15type upstreamChannel struct {
16 Name string
17 conn *upstreamConn
18 Topic string
19 TopicWho string
20 TopicTime time.Time
21 Status channelStatus
22 modes modeSet
23 Members map[string]membership
24 complete bool
25}
26
27type upstreamConn struct {
28 network *network
29 logger Logger
30 net net.Conn
31 irc *irc.Conn
32 srv *Server
33 user *user
34 messages chan<- *irc.Message
35 ring *Ring
36
37 serverName string
38 availableUserModes string
39 availableChannelModes string
40 channelModesWithParam string
41
42 registered bool
43 nick string
44 username string
45 realname string
46 closed bool
47 modes modeSet
48 channels map[string]*upstreamChannel
49 history map[string]uint64
50}
51
52func connectToUpstream(network *network) (*upstreamConn, error) {
53 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
54
55 addr := network.Addr
56 if !strings.ContainsRune(addr, ':') {
57 addr = addr + ":6697"
58 }
59
60 logger.Printf("connecting to TLS server at address %q", addr)
61 netConn, err := tls.Dial("tcp", addr, nil)
62 if err != nil {
63 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
64 }
65
66 setKeepAlive(netConn)
67
68 msgs := make(chan *irc.Message, 64)
69 uc := &upstreamConn{
70 network: network,
71 logger: logger,
72 net: netConn,
73 irc: irc.NewConn(netConn),
74 srv: network.user.srv,
75 user: network.user,
76 messages: msgs,
77 ring: NewRing(network.user.srv.RingCap),
78 channels: make(map[string]*upstreamChannel),
79 history: make(map[string]uint64),
80 }
81
82 go func() {
83 for msg := range msgs {
84 if uc.srv.Debug {
85 uc.logger.Printf("sent: %v", msg)
86 }
87 if err := uc.irc.WriteMessage(msg); err != nil {
88 uc.logger.Printf("failed to write message: %v", err)
89 }
90 }
91 if err := uc.net.Close(); err != nil {
92 uc.logger.Printf("failed to close connection: %v", err)
93 } else {
94 uc.logger.Printf("connection closed")
95 }
96 }()
97
98 return uc, nil
99}
100
101func (uc *upstreamConn) Close() error {
102 if uc.closed {
103 return fmt.Errorf("upstream connection already closed")
104 }
105 close(uc.messages)
106 uc.closed = true
107 return nil
108}
109
110func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
111 uc.user.forEachDownstream(func(dc *downstreamConn) {
112 if dc.network != nil && dc.network != uc.network {
113 return
114 }
115 f(dc)
116 })
117}
118
119func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
120 ch, ok := uc.channels[name]
121 if !ok {
122 return nil, fmt.Errorf("unknown channel %q", name)
123 }
124 return ch, nil
125}
126
127func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
128 switch msg.Command {
129 case "PING":
130 uc.SendMessage(&irc.Message{
131 Command: "PONG",
132 Params: msg.Params,
133 })
134 return nil
135 case "MODE":
136 if msg.Prefix == nil {
137 return fmt.Errorf("missing prefix")
138 }
139
140 var name, modeStr string
141 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
142 return err
143 }
144
145 if name == msg.Prefix.Name { // user mode change
146 if name != uc.nick {
147 return fmt.Errorf("received MODE message for unknow nick %q", name)
148 }
149 return uc.modes.Apply(modeStr)
150 } else { // channel mode change
151 ch, err := uc.getChannel(name)
152 if err != nil {
153 return err
154 }
155 if err := ch.modes.Apply(modeStr); err != nil {
156 return err
157 }
158
159 uc.forEachDownstream(func(dc *downstreamConn) {
160 dc.SendMessage(&irc.Message{
161 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
162 Command: "MODE",
163 Params: []string{dc.marshalChannel(uc, name), modeStr},
164 })
165 })
166 }
167 case "NOTICE":
168 uc.logger.Print(msg)
169 case irc.RPL_WELCOME:
170 uc.registered = true
171 uc.logger.Printf("connection registered")
172
173 channels, err := uc.srv.db.ListChannels(uc.network.ID)
174 if err != nil {
175 uc.logger.Printf("failed to list channels from database: %v", err)
176 break
177 }
178
179 for _, ch := range channels {
180 uc.SendMessage(&irc.Message{
181 Command: "JOIN",
182 Params: []string{ch.Name},
183 })
184 }
185 case irc.RPL_MYINFO:
186 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
187 return err
188 }
189 if len(msg.Params) > 5 {
190 uc.channelModesWithParam = msg.Params[5]
191 }
192 case "NICK":
193 var newNick string
194 if err := parseMessageParams(msg, &newNick); err != nil {
195 return err
196 }
197
198 if msg.Prefix.Name == uc.nick {
199 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
200 uc.nick = newNick
201 }
202
203 for _, ch := range uc.channels {
204 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
205 delete(ch.Members, msg.Prefix.Name)
206 ch.Members[newNick] = membership
207 }
208 }
209 case "JOIN":
210 if msg.Prefix == nil {
211 return fmt.Errorf("expected a prefix")
212 }
213
214 var channels string
215 if err := parseMessageParams(msg, &channels); err != nil {
216 return err
217 }
218
219 for _, ch := range strings.Split(channels, ",") {
220 if msg.Prefix.Name == uc.nick {
221 uc.logger.Printf("joined channel %q", ch)
222 uc.channels[ch] = &upstreamChannel{
223 Name: ch,
224 conn: uc,
225 Members: make(map[string]membership),
226 }
227 } else {
228 ch, err := uc.getChannel(ch)
229 if err != nil {
230 return err
231 }
232 ch.Members[msg.Prefix.Name] = 0
233 }
234
235 uc.forEachDownstream(func(dc *downstreamConn) {
236 dc.SendMessage(&irc.Message{
237 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
238 Command: "JOIN",
239 Params: []string{dc.marshalChannel(uc, ch)},
240 })
241 })
242 }
243 case "PART":
244 if msg.Prefix == nil {
245 return fmt.Errorf("expected a prefix")
246 }
247
248 var channels string
249 if err := parseMessageParams(msg, &channels); err != nil {
250 return err
251 }
252
253 for _, ch := range strings.Split(channels, ",") {
254 if msg.Prefix.Name == uc.nick {
255 uc.logger.Printf("parted channel %q", ch)
256 delete(uc.channels, ch)
257 } else {
258 ch, err := uc.getChannel(ch)
259 if err != nil {
260 return err
261 }
262 delete(ch.Members, msg.Prefix.Name)
263 }
264
265 uc.forEachDownstream(func(dc *downstreamConn) {
266 dc.SendMessage(&irc.Message{
267 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
268 Command: "PART",
269 Params: []string{dc.marshalChannel(uc, ch)},
270 })
271 })
272 }
273 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
274 var name, topic string
275 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
276 return err
277 }
278 ch, err := uc.getChannel(name)
279 if err != nil {
280 return err
281 }
282 if msg.Command == irc.RPL_TOPIC {
283 ch.Topic = topic
284 } else {
285 ch.Topic = ""
286 }
287 case "TOPIC":
288 var name string
289 if err := parseMessageParams(msg, &name); err != nil {
290 return err
291 }
292 ch, err := uc.getChannel(name)
293 if err != nil {
294 return err
295 }
296 if len(msg.Params) > 1 {
297 ch.Topic = msg.Params[1]
298 } else {
299 ch.Topic = ""
300 }
301 uc.forEachDownstream(func(dc *downstreamConn) {
302 params := []string{dc.marshalChannel(uc, name)}
303 if ch.Topic != "" {
304 params = append(params, ch.Topic)
305 }
306 dc.SendMessage(&irc.Message{
307 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
308 Command: "TOPIC",
309 Params: params,
310 })
311 })
312 case rpl_topicwhotime:
313 var name, who, timeStr string
314 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
315 return err
316 }
317 ch, err := uc.getChannel(name)
318 if err != nil {
319 return err
320 }
321 ch.TopicWho = who
322 sec, err := strconv.ParseInt(timeStr, 10, 64)
323 if err != nil {
324 return fmt.Errorf("failed to parse topic time: %v", err)
325 }
326 ch.TopicTime = time.Unix(sec, 0)
327 case irc.RPL_NAMREPLY:
328 var name, statusStr, members string
329 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
330 return err
331 }
332 ch, err := uc.getChannel(name)
333 if err != nil {
334 return err
335 }
336
337 status, err := parseChannelStatus(statusStr)
338 if err != nil {
339 return err
340 }
341 ch.Status = status
342
343 for _, s := range strings.Split(members, " ") {
344 membership, nick := parseMembershipPrefix(s)
345 ch.Members[nick] = membership
346 }
347 case irc.RPL_ENDOFNAMES:
348 var name string
349 if err := parseMessageParams(msg, nil, &name); err != nil {
350 return err
351 }
352 ch, err := uc.getChannel(name)
353 if err != nil {
354 return err
355 }
356
357 if ch.complete {
358 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
359 }
360 ch.complete = true
361
362 uc.forEachDownstream(func(dc *downstreamConn) {
363 forwardChannel(dc, ch)
364 })
365 case "PRIVMSG":
366 if err := parseMessageParams(msg, nil, nil); err != nil {
367 return err
368 }
369 uc.ring.Produce(msg)
370 case irc.RPL_YOURHOST, irc.RPL_CREATED:
371 // Ignore
372 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
373 // Ignore
374 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
375 // Ignore
376 case rpl_localusers, rpl_globalusers:
377 // Ignore
378 case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
379 // Ignore
380 default:
381 uc.logger.Printf("unhandled upstream message: %v", msg)
382 }
383 return nil
384}
385
386func (uc *upstreamConn) register() {
387 uc.nick = uc.network.Nick
388 uc.username = uc.network.Username
389 if uc.username == "" {
390 uc.username = uc.nick
391 }
392 uc.realname = uc.network.Realname
393 if uc.realname == "" {
394 uc.realname = uc.nick
395 }
396
397 uc.SendMessage(&irc.Message{
398 Command: "NICK",
399 Params: []string{uc.nick},
400 })
401 uc.SendMessage(&irc.Message{
402 Command: "USER",
403 Params: []string{uc.username, "0", "*", uc.realname},
404 })
405}
406
407func (uc *upstreamConn) readMessages() error {
408 for {
409 msg, err := uc.irc.ReadMessage()
410 if err == io.EOF {
411 break
412 } else if err != nil {
413 return fmt.Errorf("failed to read IRC command: %v", err)
414 }
415
416 if uc.srv.Debug {
417 uc.logger.Printf("received: %v", msg)
418 }
419
420 if err := uc.handleMessage(msg); err != nil {
421 uc.logger.Printf("failed to handle message %q: %v", msg, err)
422 }
423 }
424
425 return nil
426}
427
428func (uc *upstreamConn) SendMessage(msg *irc.Message) {
429 uc.messages <- msg
430}
Note: See TracBrowser for help on using the repository browser.