source: code/trunk/upstream.go@ 96

Last change on this file since 96 was 96, checked in by contact, 5 years ago

Update dependencies

go-irc v3.1.1 contains a breaking change.

References: https://github.com/go-irc/irc/issues/76

File size: 15.2 KB
RevLine 
[13]1package jounce
2
3import (
4 "crypto/tls"
[95]5 "encoding/base64"
[13]6 "fmt"
7 "io"
8 "net"
[19]9 "strconv"
[17]10 "strings"
[19]11 "time"
[13]12
[95]13 "github.com/emersion/go-sasl"
[13]14 "gopkg.in/irc.v3"
15)
16
[19]17type upstreamChannel struct {
18 Name string
[46]19 conn *upstreamConn
[19]20 Topic string
21 TopicWho string
22 TopicTime time.Time
23 Status channelStatus
[35]24 modes modeSet
[19]25 Members map[string]membership
[25]26 complete bool
[19]27}
28
[13]29type upstreamConn struct {
[77]30 network *network
[21]31 logger Logger
[19]32 net net.Conn
33 irc *irc.Conn
34 srv *Server
[37]35 user *user
[33]36 messages chan<- *irc.Message
[50]37 ring *Ring
[16]38
39 serverName string
40 availableUserModes string
41 availableChannelModes string
42 channelModesWithParam string
[19]43
44 registered bool
[42]45 nick string
[77]46 username string
47 realname string
[33]48 closed bool
[19]49 modes modeSet
50 channels map[string]*upstreamChannel
[57]51 history map[string]uint64
[92]52 caps map[string]string
[95]53
54 saslClient sasl.Client
55 saslStarted bool
[13]56}
57
[77]58func connectToUpstream(network *network) (*upstreamConn, error) {
59 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
[33]60
[77]61 addr := network.Addr
62 if !strings.ContainsRune(addr, ':') {
63 addr = addr + ":6697"
64 }
65
66 logger.Printf("connecting to TLS server at address %q", addr)
67 netConn, err := tls.Dial("tcp", addr, nil)
[33]68 if err != nil {
[77]69 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
[33]70 }
71
[67]72 setKeepAlive(netConn)
73
[33]74 msgs := make(chan *irc.Message, 64)
[55]75 uc := &upstreamConn{
[79]76 network: network,
[33]77 logger: logger,
78 net: netConn,
79 irc: irc.NewConn(netConn),
[77]80 srv: network.user.srv,
81 user: network.user,
[33]82 messages: msgs,
[77]83 ring: NewRing(network.user.srv.RingCap),
[33]84 channels: make(map[string]*upstreamChannel),
[57]85 history: make(map[string]uint64),
[92]86 caps: make(map[string]string),
[33]87 }
88
89 go func() {
90 for msg := range msgs {
[64]91 if uc.srv.Debug {
92 uc.logger.Printf("sent: %v", msg)
93 }
[55]94 if err := uc.irc.WriteMessage(msg); err != nil {
95 uc.logger.Printf("failed to write message: %v", err)
[33]96 }
97 }
[55]98 if err := uc.net.Close(); err != nil {
99 uc.logger.Printf("failed to close connection: %v", err)
[45]100 } else {
[55]101 uc.logger.Printf("connection closed")
[45]102 }
[33]103 }()
104
[55]105 return uc, nil
[33]106}
107
[55]108func (uc *upstreamConn) Close() error {
109 if uc.closed {
[33]110 return fmt.Errorf("upstream connection already closed")
111 }
[55]112 close(uc.messages)
113 uc.closed = true
[33]114 return nil
115}
116
[73]117func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
118 uc.user.forEachDownstream(func(dc *downstreamConn) {
[77]119 if dc.network != nil && dc.network != uc.network {
[73]120 return
121 }
122 f(dc)
123 })
124}
125
[55]126func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
127 ch, ok := uc.channels[name]
[19]128 if !ok {
129 return nil, fmt.Errorf("unknown channel %q", name)
130 }
131 return ch, nil
132}
133
[55]134func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
[13]135 switch msg.Command {
136 case "PING":
[60]137 uc.SendMessage(&irc.Message{
[13]138 Command: "PONG",
[68]139 Params: msg.Params,
[60]140 })
[33]141 return nil
[17]142 case "MODE":
[69]143 if msg.Prefix == nil {
144 return fmt.Errorf("missing prefix")
145 }
146
[43]147 var name, modeStr string
148 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
149 return err
[17]150 }
[35]151
152 if name == msg.Prefix.Name { // user mode change
[55]153 if name != uc.nick {
[35]154 return fmt.Errorf("received MODE message for unknow nick %q", name)
155 }
[55]156 return uc.modes.Apply(modeStr)
[35]157 } else { // channel mode change
[55]158 ch, err := uc.getChannel(name)
[35]159 if err != nil {
160 return err
161 }
162 if err := ch.modes.Apply(modeStr); err != nil {
163 return err
164 }
[69]165
[73]166 uc.forEachDownstream(func(dc *downstreamConn) {
[69]167 dc.SendMessage(&irc.Message{
168 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
169 Command: "MODE",
170 Params: []string{dc.marshalChannel(uc, name), modeStr},
171 })
172 })
[46]173 }
[18]174 case "NOTICE":
[55]175 uc.logger.Print(msg)
[92]176 case "CAP":
[95]177 var subCmd string
178 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
179 return err
[92]180 }
[95]181 subCmd = strings.ToUpper(subCmd)
182 subParams := msg.Params[2:]
183 switch subCmd {
184 case "LS":
185 if len(subParams) < 1 {
186 return newNeedMoreParamsError(msg.Command)
187 }
188 caps := strings.Fields(subParams[len(subParams)-1])
189 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
[92]190
[95]191 for _, s := range caps {
192 kv := strings.SplitN(s, "=", 2)
193 k := strings.ToLower(kv[0])
194 var v string
195 if len(kv) == 2 {
196 v = kv[1]
197 }
198 uc.caps[k] = v
[92]199 }
200
[95]201 if more {
202 break // wait to receive all capabilities
203 }
204
205 if uc.requestSASL() {
206 uc.SendMessage(&irc.Message{
207 Command: "CAP",
208 Params: []string{"REQ", "sasl"},
209 })
210 break // we'll send CAP END after authentication is completed
211 }
212
[92]213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"END"},
216 })
[95]217 case "ACK", "NAK":
218 if len(subParams) < 1 {
219 return newNeedMoreParamsError(msg.Command)
220 }
221 caps := strings.Fields(subParams[0])
222
223 for _, name := range caps {
224 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
225 return err
226 }
227 }
228
229 if uc.saslClient == nil {
230 uc.SendMessage(&irc.Message{
231 Command: "CAP",
232 Params: []string{"END"},
233 })
234 }
235 default:
236 uc.logger.Printf("unhandled message: %v", msg)
[92]237 }
[95]238 case "AUTHENTICATE":
239 if uc.saslClient == nil {
240 return fmt.Errorf("received unexpected AUTHENTICATE message")
241 }
242
243 // TODO: if a challenge is 400 bytes long, buffer it
244 var challengeStr string
245 if err := parseMessageParams(msg, &challengeStr); err != nil {
246 uc.SendMessage(&irc.Message{
247 Command: "AUTHENTICATE",
248 Params: []string{"*"},
249 })
250 return err
251 }
252
253 var challenge []byte
254 if challengeStr != "+" {
255 var err error
256 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
257 if err != nil {
258 uc.SendMessage(&irc.Message{
259 Command: "AUTHENTICATE",
260 Params: []string{"*"},
261 })
262 return err
263 }
264 }
265
266 var resp []byte
267 var err error
268 if !uc.saslStarted {
269 _, resp, err = uc.saslClient.Start()
270 uc.saslStarted = true
271 } else {
272 resp, err = uc.saslClient.Next(challenge)
273 }
274 if err != nil {
275 uc.SendMessage(&irc.Message{
276 Command: "AUTHENTICATE",
277 Params: []string{"*"},
278 })
279 return err
280 }
281
282 // TODO: send response in multiple chunks if >= 400 bytes
283 var respStr = "+"
284 if resp != nil {
285 respStr = base64.StdEncoding.EncodeToString(resp)
286 }
287
288 uc.SendMessage(&irc.Message{
289 Command: "AUTHENTICATE",
290 Params: []string{respStr},
291 })
292 case rpl_loggedin:
293 var account string
294 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
295 return err
296 }
297 uc.logger.Printf("logged in with account %q", account)
298 case rpl_loggedout:
299 uc.logger.Printf("logged out")
300 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
301 var info string
302 if err := parseMessageParams(msg, nil, &info); err != nil {
303 return err
304 }
305 switch msg.Command {
306 case err_nicklocked:
307 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
308 case err_saslfail:
309 uc.logger.Printf("SASL authentication failed: %v", info)
310 case err_sasltoolong:
311 uc.logger.Printf("SASL message too long: %v", info)
312 }
313
314 uc.saslClient = nil
315 uc.saslStarted = false
316
317 uc.SendMessage(&irc.Message{
318 Command: "CAP",
319 Params: []string{"END"},
320 })
[14]321 case irc.RPL_WELCOME:
[55]322 uc.registered = true
323 uc.logger.Printf("connection registered")
[19]324
[77]325 channels, err := uc.srv.db.ListChannels(uc.network.ID)
326 if err != nil {
327 uc.logger.Printf("failed to list channels from database: %v", err)
328 break
329 }
330
331 for _, ch := range channels {
[60]332 uc.SendMessage(&irc.Message{
[19]333 Command: "JOIN",
[77]334 Params: []string{ch.Name},
[60]335 })
[19]336 }
[16]337 case irc.RPL_MYINFO:
[55]338 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
[43]339 return err
[16]340 }
341 if len(msg.Params) > 5 {
[55]342 uc.channelModesWithParam = msg.Params[5]
[16]343 }
[42]344 case "NICK":
[83]345 if msg.Prefix == nil {
346 return fmt.Errorf("expected a prefix")
347 }
348
[43]349 var newNick string
350 if err := parseMessageParams(msg, &newNick); err != nil {
351 return err
[42]352 }
353
[55]354 if msg.Prefix.Name == uc.nick {
355 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
356 uc.nick = newNick
[42]357 }
358
[55]359 for _, ch := range uc.channels {
[42]360 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
361 delete(ch.Members, msg.Prefix.Name)
362 ch.Members[newNick] = membership
363 }
364 }
[82]365
366 if msg.Prefix.Name != uc.nick {
367 uc.forEachDownstream(func(dc *downstreamConn) {
368 dc.SendMessage(&irc.Message{
369 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
370 Command: "NICK",
371 Params: []string{newNick},
372 })
373 })
374 }
[69]375 case "JOIN":
376 if msg.Prefix == nil {
377 return fmt.Errorf("expected a prefix")
378 }
[42]379
[43]380 var channels string
381 if err := parseMessageParams(msg, &channels); err != nil {
382 return err
[19]383 }
[34]384
[43]385 for _, ch := range strings.Split(channels, ",") {
[55]386 if msg.Prefix.Name == uc.nick {
387 uc.logger.Printf("joined channel %q", ch)
388 uc.channels[ch] = &upstreamChannel{
[34]389 Name: ch,
[55]390 conn: uc,
[34]391 Members: make(map[string]membership),
392 }
393 } else {
[55]394 ch, err := uc.getChannel(ch)
[34]395 if err != nil {
396 return err
397 }
398 ch.Members[msg.Prefix.Name] = 0
[19]399 }
[69]400
[73]401 uc.forEachDownstream(func(dc *downstreamConn) {
[69]402 dc.SendMessage(&irc.Message{
403 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
404 Command: "JOIN",
405 Params: []string{dc.marshalChannel(uc, ch)},
406 })
407 })
[19]408 }
[69]409 case "PART":
410 if msg.Prefix == nil {
411 return fmt.Errorf("expected a prefix")
412 }
[34]413
[43]414 var channels string
415 if err := parseMessageParams(msg, &channels); err != nil {
416 return err
[34]417 }
418
[43]419 for _, ch := range strings.Split(channels, ",") {
[55]420 if msg.Prefix.Name == uc.nick {
421 uc.logger.Printf("parted channel %q", ch)
422 delete(uc.channels, ch)
[34]423 } else {
[55]424 ch, err := uc.getChannel(ch)
[34]425 if err != nil {
426 return err
427 }
428 delete(ch.Members, msg.Prefix.Name)
429 }
[69]430
[73]431 uc.forEachDownstream(func(dc *downstreamConn) {
[69]432 dc.SendMessage(&irc.Message{
433 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
434 Command: "PART",
435 Params: []string{dc.marshalChannel(uc, ch)},
436 })
437 })
[34]438 }
[83]439 case "QUIT":
440 if msg.Prefix == nil {
441 return fmt.Errorf("expected a prefix")
442 }
443
444 if msg.Prefix.Name == uc.nick {
445 uc.logger.Printf("quit")
446 }
447
448 for _, ch := range uc.channels {
449 delete(ch.Members, msg.Prefix.Name)
450 }
451
452 if msg.Prefix.Name != uc.nick {
453 uc.forEachDownstream(func(dc *downstreamConn) {
454 dc.SendMessage(&irc.Message{
455 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
456 Command: "QUIT",
457 Params: msg.Params,
458 })
459 })
460 }
[19]461 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
[43]462 var name, topic string
463 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
464 return err
[19]465 }
[55]466 ch, err := uc.getChannel(name)
[19]467 if err != nil {
468 return err
469 }
470 if msg.Command == irc.RPL_TOPIC {
[43]471 ch.Topic = topic
[19]472 } else {
473 ch.Topic = ""
474 }
475 case "TOPIC":
[43]476 var name string
[74]477 if err := parseMessageParams(msg, &name); err != nil {
[43]478 return err
[19]479 }
[55]480 ch, err := uc.getChannel(name)
[19]481 if err != nil {
482 return err
483 }
484 if len(msg.Params) > 1 {
485 ch.Topic = msg.Params[1]
486 } else {
487 ch.Topic = ""
488 }
[74]489 uc.forEachDownstream(func(dc *downstreamConn) {
490 params := []string{dc.marshalChannel(uc, name)}
491 if ch.Topic != "" {
492 params = append(params, ch.Topic)
493 }
494 dc.SendMessage(&irc.Message{
495 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
496 Command: "TOPIC",
497 Params: params,
498 })
499 })
[19]500 case rpl_topicwhotime:
[43]501 var name, who, timeStr string
502 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
503 return err
[19]504 }
[55]505 ch, err := uc.getChannel(name)
[19]506 if err != nil {
507 return err
508 }
[43]509 ch.TopicWho = who
510 sec, err := strconv.ParseInt(timeStr, 10, 64)
[19]511 if err != nil {
512 return fmt.Errorf("failed to parse topic time: %v", err)
513 }
514 ch.TopicTime = time.Unix(sec, 0)
515 case irc.RPL_NAMREPLY:
[43]516 var name, statusStr, members string
517 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
518 return err
[19]519 }
[55]520 ch, err := uc.getChannel(name)
[19]521 if err != nil {
522 return err
523 }
524
[43]525 status, err := parseChannelStatus(statusStr)
[19]526 if err != nil {
527 return err
528 }
529 ch.Status = status
530
[43]531 for _, s := range strings.Split(members, " ") {
[19]532 membership, nick := parseMembershipPrefix(s)
533 ch.Members[nick] = membership
534 }
535 case irc.RPL_ENDOFNAMES:
[43]536 var name string
537 if err := parseMessageParams(msg, nil, &name); err != nil {
538 return err
[25]539 }
[55]540 ch, err := uc.getChannel(name)
[25]541 if err != nil {
542 return err
543 }
544
[34]545 if ch.complete {
546 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
547 }
[25]548 ch.complete = true
[27]549
[73]550 uc.forEachDownstream(func(dc *downstreamConn) {
[27]551 forwardChannel(dc, ch)
[40]552 })
[36]553 case "PRIVMSG":
[69]554 if err := parseMessageParams(msg, nil, nil); err != nil {
555 return err
556 }
[55]557 uc.ring.Produce(msg)
[16]558 case irc.RPL_YOURHOST, irc.RPL_CREATED:
[14]559 // Ignore
560 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
561 // Ignore
562 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
563 // Ignore
564 case rpl_localusers, rpl_globalusers:
565 // Ignore
[96]566 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
[14]567 // Ignore
[13]568 default:
[95]569 uc.logger.Printf("unhandled message: %v", msg)
[13]570 }
[14]571 return nil
[13]572}
573
[55]574func (uc *upstreamConn) register() {
[77]575 uc.nick = uc.network.Nick
576 uc.username = uc.network.Username
577 if uc.username == "" {
578 uc.username = uc.nick
579 }
580 uc.realname = uc.network.Realname
581 if uc.realname == "" {
582 uc.realname = uc.nick
583 }
584
[60]585 uc.SendMessage(&irc.Message{
[92]586 Command: "CAP",
587 Params: []string{"LS", "302"},
588 })
589
[93]590 if uc.network.Pass != "" {
591 uc.SendMessage(&irc.Message{
592 Command: "PASS",
593 Params: []string{uc.network.Pass},
594 })
595 }
596
[92]597 uc.SendMessage(&irc.Message{
[13]598 Command: "NICK",
[69]599 Params: []string{uc.nick},
[60]600 })
601 uc.SendMessage(&irc.Message{
[13]602 Command: "USER",
[77]603 Params: []string{uc.username, "0", "*", uc.realname},
[60]604 })
[44]605}
[13]606
[95]607func (uc *upstreamConn) requestSASL() bool {
608 if uc.network.SASL.Mechanism == "" {
609 return false
610 }
611
612 v, ok := uc.caps["sasl"]
613 if !ok {
614 return false
615 }
616 if v != "" {
617 mechanisms := strings.Split(v, ",")
618 found := false
619 for _, mech := range mechanisms {
620 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
621 found = true
622 break
623 }
624 }
625 if !found {
626 return false
627 }
628 }
629
630 return true
631}
632
633func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
634 auth := &uc.network.SASL
635 switch name {
636 case "sasl":
637 if !ok {
638 uc.logger.Printf("server refused to acknowledge the SASL capability")
639 return nil
640 }
641
642 switch auth.Mechanism {
643 case "PLAIN":
644 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
645 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
646 default:
647 return fmt.Errorf("unsupported SASL mechanism %q", name)
648 }
649
650 uc.SendMessage(&irc.Message{
651 Command: "AUTHENTICATE",
652 Params: []string{auth.Mechanism},
653 })
654 }
655 return nil
656}
657
[55]658func (uc *upstreamConn) readMessages() error {
[13]659 for {
[55]660 msg, err := uc.irc.ReadMessage()
[13]661 if err == io.EOF {
662 break
663 } else if err != nil {
664 return fmt.Errorf("failed to read IRC command: %v", err)
665 }
666
[64]667 if uc.srv.Debug {
668 uc.logger.Printf("received: %v", msg)
669 }
670
[55]671 if err := uc.handleMessage(msg); err != nil {
672 uc.logger.Printf("failed to handle message %q: %v", msg, err)
[13]673 }
674 }
675
[45]676 return nil
[13]677}
[60]678
679func (uc *upstreamConn) SendMessage(msg *irc.Message) {
680 uc.messages <- msg
681}
Note: See TracBrowser for help on using the repository browser.