source: code/trunk/upstream.go@ 117

Last change on this file since 117 was 117, checked in by contact, 5 years ago

Add basic infrastructure for bouncer service

File size: 16.0 KB
RevLine 
[98]1package soju
[13]2
3import (
4 "crypto/tls"
[95]5 "encoding/base64"
[13]6 "fmt"
7 "io"
8 "net"
[19]9 "strconv"
[17]10 "strings"
[109]11 "sync"
[19]12 "time"
[13]13
[95]14 "github.com/emersion/go-sasl"
[13]15 "gopkg.in/irc.v3"
16)
17
[19]18type upstreamChannel struct {
19 Name string
[46]20 conn *upstreamConn
[19]21 Topic string
22 TopicWho string
23 TopicTime time.Time
24 Status channelStatus
[35]25 modes modeSet
[19]26 Members map[string]membership
[25]27 complete bool
[19]28}
29
[13]30type upstreamConn struct {
[77]31 network *network
[21]32 logger Logger
[19]33 net net.Conn
34 irc *irc.Conn
35 srv *Server
[37]36 user *user
[102]37 outgoing chan<- *irc.Message
[50]38 ring *Ring
[16]39
40 serverName string
41 availableUserModes string
42 availableChannelModes string
43 channelModesWithParam string
[19]44
45 registered bool
[42]46 nick string
[77]47 username string
48 realname string
[33]49 closed bool
[19]50 modes modeSet
51 channels map[string]*upstreamChannel
[92]52 caps map[string]string
[95]53
54 saslClient sasl.Client
55 saslStarted bool
[109]56
57 lock sync.Mutex
58 history map[string]uint64 // TODO: move to network
[13]59}
60
[77]61func connectToUpstream(network *network) (*upstreamConn, error) {
62 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
[33]63
[77]64 addr := network.Addr
65 if !strings.ContainsRune(addr, ':') {
66 addr = addr + ":6697"
67 }
68
69 logger.Printf("connecting to TLS server at address %q", addr)
70 netConn, err := tls.Dial("tcp", addr, nil)
[33]71 if err != nil {
[77]72 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
[33]73 }
74
[67]75 setKeepAlive(netConn)
76
[102]77 outgoing := make(chan *irc.Message, 64)
[55]78 uc := &upstreamConn{
[79]79 network: network,
[33]80 logger: logger,
81 net: netConn,
82 irc: irc.NewConn(netConn),
[77]83 srv: network.user.srv,
84 user: network.user,
[102]85 outgoing: outgoing,
[77]86 ring: NewRing(network.user.srv.RingCap),
[33]87 channels: make(map[string]*upstreamChannel),
[57]88 history: make(map[string]uint64),
[92]89 caps: make(map[string]string),
[33]90 }
91
92 go func() {
[102]93 for msg := range outgoing {
[64]94 if uc.srv.Debug {
95 uc.logger.Printf("sent: %v", msg)
96 }
[55]97 if err := uc.irc.WriteMessage(msg); err != nil {
98 uc.logger.Printf("failed to write message: %v", err)
[33]99 }
100 }
[55]101 if err := uc.net.Close(); err != nil {
102 uc.logger.Printf("failed to close connection: %v", err)
[45]103 } else {
[55]104 uc.logger.Printf("connection closed")
[45]105 }
[33]106 }()
107
[55]108 return uc, nil
[33]109}
110
[55]111func (uc *upstreamConn) Close() error {
112 if uc.closed {
[33]113 return fmt.Errorf("upstream connection already closed")
114 }
[102]115 close(uc.outgoing)
[55]116 uc.closed = true
[33]117 return nil
118}
119
[73]120func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
121 uc.user.forEachDownstream(func(dc *downstreamConn) {
[77]122 if dc.network != nil && dc.network != uc.network {
[73]123 return
124 }
125 f(dc)
126 })
127}
128
[55]129func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
130 ch, ok := uc.channels[name]
[19]131 if !ok {
132 return nil, fmt.Errorf("unknown channel %q", name)
133 }
134 return ch, nil
135}
136
[55]137func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
[13]138 switch msg.Command {
139 case "PING":
[60]140 uc.SendMessage(&irc.Message{
[13]141 Command: "PONG",
[68]142 Params: msg.Params,
[60]143 })
[33]144 return nil
[17]145 case "MODE":
[69]146 if msg.Prefix == nil {
147 return fmt.Errorf("missing prefix")
148 }
149
[43]150 var name, modeStr string
151 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
152 return err
[17]153 }
[35]154
155 if name == msg.Prefix.Name { // user mode change
[55]156 if name != uc.nick {
[35]157 return fmt.Errorf("received MODE message for unknow nick %q", name)
158 }
[55]159 return uc.modes.Apply(modeStr)
[35]160 } else { // channel mode change
[55]161 ch, err := uc.getChannel(name)
[35]162 if err != nil {
163 return err
164 }
165 if err := ch.modes.Apply(modeStr); err != nil {
166 return err
167 }
[69]168
[73]169 uc.forEachDownstream(func(dc *downstreamConn) {
[69]170 dc.SendMessage(&irc.Message{
171 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
172 Command: "MODE",
173 Params: []string{dc.marshalChannel(uc, name), modeStr},
174 })
175 })
[46]176 }
[18]177 case "NOTICE":
[55]178 uc.logger.Print(msg)
[97]179
180 uc.forEachDownstream(func(dc *downstreamConn) {
181 dc.SendMessage(msg)
182 })
[92]183 case "CAP":
[95]184 var subCmd string
185 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
186 return err
[92]187 }
[95]188 subCmd = strings.ToUpper(subCmd)
189 subParams := msg.Params[2:]
190 switch subCmd {
191 case "LS":
192 if len(subParams) < 1 {
193 return newNeedMoreParamsError(msg.Command)
194 }
195 caps := strings.Fields(subParams[len(subParams)-1])
196 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
[92]197
[95]198 for _, s := range caps {
199 kv := strings.SplitN(s, "=", 2)
200 k := strings.ToLower(kv[0])
201 var v string
202 if len(kv) == 2 {
203 v = kv[1]
204 }
205 uc.caps[k] = v
[92]206 }
207
[95]208 if more {
209 break // wait to receive all capabilities
210 }
211
212 if uc.requestSASL() {
213 uc.SendMessage(&irc.Message{
214 Command: "CAP",
215 Params: []string{"REQ", "sasl"},
216 })
217 break // we'll send CAP END after authentication is completed
218 }
219
[92]220 uc.SendMessage(&irc.Message{
221 Command: "CAP",
222 Params: []string{"END"},
223 })
[95]224 case "ACK", "NAK":
225 if len(subParams) < 1 {
226 return newNeedMoreParamsError(msg.Command)
227 }
228 caps := strings.Fields(subParams[0])
229
230 for _, name := range caps {
231 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
232 return err
233 }
234 }
235
236 if uc.saslClient == nil {
237 uc.SendMessage(&irc.Message{
238 Command: "CAP",
239 Params: []string{"END"},
240 })
241 }
242 default:
243 uc.logger.Printf("unhandled message: %v", msg)
[92]244 }
[95]245 case "AUTHENTICATE":
246 if uc.saslClient == nil {
247 return fmt.Errorf("received unexpected AUTHENTICATE message")
248 }
249
250 // TODO: if a challenge is 400 bytes long, buffer it
251 var challengeStr string
252 if err := parseMessageParams(msg, &challengeStr); err != nil {
253 uc.SendMessage(&irc.Message{
254 Command: "AUTHENTICATE",
255 Params: []string{"*"},
256 })
257 return err
258 }
259
260 var challenge []byte
261 if challengeStr != "+" {
262 var err error
263 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
264 if err != nil {
265 uc.SendMessage(&irc.Message{
266 Command: "AUTHENTICATE",
267 Params: []string{"*"},
268 })
269 return err
270 }
271 }
272
273 var resp []byte
274 var err error
275 if !uc.saslStarted {
276 _, resp, err = uc.saslClient.Start()
277 uc.saslStarted = true
278 } else {
279 resp, err = uc.saslClient.Next(challenge)
280 }
281 if err != nil {
282 uc.SendMessage(&irc.Message{
283 Command: "AUTHENTICATE",
284 Params: []string{"*"},
285 })
286 return err
287 }
288
289 // TODO: send response in multiple chunks if >= 400 bytes
290 var respStr = "+"
291 if resp != nil {
292 respStr = base64.StdEncoding.EncodeToString(resp)
293 }
294
295 uc.SendMessage(&irc.Message{
296 Command: "AUTHENTICATE",
297 Params: []string{respStr},
298 })
299 case rpl_loggedin:
300 var account string
301 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
302 return err
303 }
304 uc.logger.Printf("logged in with account %q", account)
305 case rpl_loggedout:
306 uc.logger.Printf("logged out")
307 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
308 var info string
309 if err := parseMessageParams(msg, nil, &info); err != nil {
310 return err
311 }
312 switch msg.Command {
313 case err_nicklocked:
314 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
315 case err_saslfail:
316 uc.logger.Printf("SASL authentication failed: %v", info)
317 case err_sasltoolong:
318 uc.logger.Printf("SASL message too long: %v", info)
319 }
320
321 uc.saslClient = nil
322 uc.saslStarted = false
323
324 uc.SendMessage(&irc.Message{
325 Command: "CAP",
326 Params: []string{"END"},
327 })
[14]328 case irc.RPL_WELCOME:
[55]329 uc.registered = true
330 uc.logger.Printf("connection registered")
[19]331
[77]332 channels, err := uc.srv.db.ListChannels(uc.network.ID)
333 if err != nil {
334 uc.logger.Printf("failed to list channels from database: %v", err)
335 break
336 }
337
338 for _, ch := range channels {
[60]339 uc.SendMessage(&irc.Message{
[19]340 Command: "JOIN",
[77]341 Params: []string{ch.Name},
[60]342 })
[19]343 }
[16]344 case irc.RPL_MYINFO:
[55]345 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
[43]346 return err
[16]347 }
348 if len(msg.Params) > 5 {
[55]349 uc.channelModesWithParam = msg.Params[5]
[16]350 }
[42]351 case "NICK":
[83]352 if msg.Prefix == nil {
353 return fmt.Errorf("expected a prefix")
354 }
355
[43]356 var newNick string
357 if err := parseMessageParams(msg, &newNick); err != nil {
358 return err
[42]359 }
360
[55]361 if msg.Prefix.Name == uc.nick {
362 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
363 uc.nick = newNick
[42]364 }
365
[55]366 for _, ch := range uc.channels {
[42]367 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
368 delete(ch.Members, msg.Prefix.Name)
369 ch.Members[newNick] = membership
370 }
371 }
[82]372
373 if msg.Prefix.Name != uc.nick {
374 uc.forEachDownstream(func(dc *downstreamConn) {
375 dc.SendMessage(&irc.Message{
376 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
377 Command: "NICK",
378 Params: []string{newNick},
379 })
380 })
381 }
[69]382 case "JOIN":
383 if msg.Prefix == nil {
384 return fmt.Errorf("expected a prefix")
385 }
[42]386
[43]387 var channels string
388 if err := parseMessageParams(msg, &channels); err != nil {
389 return err
[19]390 }
[34]391
[43]392 for _, ch := range strings.Split(channels, ",") {
[55]393 if msg.Prefix.Name == uc.nick {
394 uc.logger.Printf("joined channel %q", ch)
395 uc.channels[ch] = &upstreamChannel{
[34]396 Name: ch,
[55]397 conn: uc,
[34]398 Members: make(map[string]membership),
399 }
400 } else {
[55]401 ch, err := uc.getChannel(ch)
[34]402 if err != nil {
403 return err
404 }
405 ch.Members[msg.Prefix.Name] = 0
[19]406 }
[69]407
[73]408 uc.forEachDownstream(func(dc *downstreamConn) {
[69]409 dc.SendMessage(&irc.Message{
410 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
411 Command: "JOIN",
412 Params: []string{dc.marshalChannel(uc, ch)},
413 })
414 })
[19]415 }
[69]416 case "PART":
417 if msg.Prefix == nil {
418 return fmt.Errorf("expected a prefix")
419 }
[34]420
[43]421 var channels string
422 if err := parseMessageParams(msg, &channels); err != nil {
423 return err
[34]424 }
425
[43]426 for _, ch := range strings.Split(channels, ",") {
[55]427 if msg.Prefix.Name == uc.nick {
428 uc.logger.Printf("parted channel %q", ch)
429 delete(uc.channels, ch)
[34]430 } else {
[55]431 ch, err := uc.getChannel(ch)
[34]432 if err != nil {
433 return err
434 }
435 delete(ch.Members, msg.Prefix.Name)
436 }
[69]437
[73]438 uc.forEachDownstream(func(dc *downstreamConn) {
[69]439 dc.SendMessage(&irc.Message{
440 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
441 Command: "PART",
442 Params: []string{dc.marshalChannel(uc, ch)},
443 })
444 })
[34]445 }
[83]446 case "QUIT":
447 if msg.Prefix == nil {
448 return fmt.Errorf("expected a prefix")
449 }
450
451 if msg.Prefix.Name == uc.nick {
452 uc.logger.Printf("quit")
453 }
454
455 for _, ch := range uc.channels {
456 delete(ch.Members, msg.Prefix.Name)
457 }
458
459 if msg.Prefix.Name != uc.nick {
460 uc.forEachDownstream(func(dc *downstreamConn) {
461 dc.SendMessage(&irc.Message{
462 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
463 Command: "QUIT",
464 Params: msg.Params,
465 })
466 })
467 }
[19]468 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
[43]469 var name, topic string
470 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
471 return err
[19]472 }
[55]473 ch, err := uc.getChannel(name)
[19]474 if err != nil {
475 return err
476 }
477 if msg.Command == irc.RPL_TOPIC {
[43]478 ch.Topic = topic
[19]479 } else {
480 ch.Topic = ""
481 }
482 case "TOPIC":
[43]483 var name string
[74]484 if err := parseMessageParams(msg, &name); err != nil {
[43]485 return err
[19]486 }
[55]487 ch, err := uc.getChannel(name)
[19]488 if err != nil {
489 return err
490 }
491 if len(msg.Params) > 1 {
492 ch.Topic = msg.Params[1]
493 } else {
494 ch.Topic = ""
495 }
[74]496 uc.forEachDownstream(func(dc *downstreamConn) {
497 params := []string{dc.marshalChannel(uc, name)}
498 if ch.Topic != "" {
499 params = append(params, ch.Topic)
500 }
501 dc.SendMessage(&irc.Message{
502 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
503 Command: "TOPIC",
504 Params: params,
505 })
506 })
[19]507 case rpl_topicwhotime:
[43]508 var name, who, timeStr string
509 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
510 return err
[19]511 }
[55]512 ch, err := uc.getChannel(name)
[19]513 if err != nil {
514 return err
515 }
[43]516 ch.TopicWho = who
517 sec, err := strconv.ParseInt(timeStr, 10, 64)
[19]518 if err != nil {
519 return fmt.Errorf("failed to parse topic time: %v", err)
520 }
521 ch.TopicTime = time.Unix(sec, 0)
522 case irc.RPL_NAMREPLY:
[43]523 var name, statusStr, members string
524 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
525 return err
[19]526 }
[55]527 ch, err := uc.getChannel(name)
[19]528 if err != nil {
529 return err
530 }
531
[43]532 status, err := parseChannelStatus(statusStr)
[19]533 if err != nil {
534 return err
535 }
536 ch.Status = status
537
[43]538 for _, s := range strings.Split(members, " ") {
[19]539 membership, nick := parseMembershipPrefix(s)
540 ch.Members[nick] = membership
541 }
542 case irc.RPL_ENDOFNAMES:
[43]543 var name string
544 if err := parseMessageParams(msg, nil, &name); err != nil {
545 return err
[25]546 }
[55]547 ch, err := uc.getChannel(name)
[25]548 if err != nil {
549 return err
550 }
551
[34]552 if ch.complete {
553 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
554 }
[25]555 ch.complete = true
[27]556
[73]557 uc.forEachDownstream(func(dc *downstreamConn) {
[27]558 forwardChannel(dc, ch)
[40]559 })
[36]560 case "PRIVMSG":
[117]561 if msg.Prefix == nil {
562 return fmt.Errorf("expected a prefix")
563 }
564
565 var nick string
566 if err := parseMessageParams(msg, &nick, nil); err != nil {
[69]567 return err
568 }
[117]569
570 if msg.Prefix.Name == serviceNick {
571 uc.logger.Printf("skipping PRIVMSG from soju's service: %v", msg)
572 break
573 }
574 if nick == serviceNick {
575 uc.logger.Printf("skipping PRIVMSG to soju's service: %v", msg)
576 break
577 }
578
[55]579 uc.ring.Produce(msg)
[115]580 case "INVITE":
581 var nick string
582 var channel string
583 if err := parseMessageParams(msg, &nick, &channel); err != nil {
584 return err
585 }
586
587 uc.forEachDownstream(func(dc *downstreamConn) {
588 dc.SendMessage(&irc.Message{
589 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
590 Command: "INVITE",
591 Params: []string{dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)},
592 })
593 })
[16]594 case irc.RPL_YOURHOST, irc.RPL_CREATED:
[14]595 // Ignore
596 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
597 // Ignore
598 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
599 // Ignore
600 case rpl_localusers, rpl_globalusers:
601 // Ignore
[96]602 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
[14]603 // Ignore
[13]604 default:
[95]605 uc.logger.Printf("unhandled message: %v", msg)
[13]606 }
[14]607 return nil
[13]608}
609
[55]610func (uc *upstreamConn) register() {
[77]611 uc.nick = uc.network.Nick
612 uc.username = uc.network.Username
613 if uc.username == "" {
614 uc.username = uc.nick
615 }
616 uc.realname = uc.network.Realname
617 if uc.realname == "" {
618 uc.realname = uc.nick
619 }
620
[60]621 uc.SendMessage(&irc.Message{
[92]622 Command: "CAP",
623 Params: []string{"LS", "302"},
624 })
625
[93]626 if uc.network.Pass != "" {
627 uc.SendMessage(&irc.Message{
628 Command: "PASS",
629 Params: []string{uc.network.Pass},
630 })
631 }
632
[92]633 uc.SendMessage(&irc.Message{
[13]634 Command: "NICK",
[69]635 Params: []string{uc.nick},
[60]636 })
637 uc.SendMessage(&irc.Message{
[13]638 Command: "USER",
[77]639 Params: []string{uc.username, "0", "*", uc.realname},
[60]640 })
[44]641}
[13]642
[95]643func (uc *upstreamConn) requestSASL() bool {
644 if uc.network.SASL.Mechanism == "" {
645 return false
646 }
647
648 v, ok := uc.caps["sasl"]
649 if !ok {
650 return false
651 }
652 if v != "" {
653 mechanisms := strings.Split(v, ",")
654 found := false
655 for _, mech := range mechanisms {
656 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
657 found = true
658 break
659 }
660 }
661 if !found {
662 return false
663 }
664 }
665
666 return true
667}
668
669func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
670 auth := &uc.network.SASL
671 switch name {
672 case "sasl":
673 if !ok {
674 uc.logger.Printf("server refused to acknowledge the SASL capability")
675 return nil
676 }
677
678 switch auth.Mechanism {
679 case "PLAIN":
680 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
681 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
682 default:
683 return fmt.Errorf("unsupported SASL mechanism %q", name)
684 }
685
686 uc.SendMessage(&irc.Message{
687 Command: "AUTHENTICATE",
688 Params: []string{auth.Mechanism},
689 })
690 }
691 return nil
692}
693
[103]694func (uc *upstreamConn) readMessages(ch chan<- upstreamIncomingMessage) error {
[13]695 for {
[55]696 msg, err := uc.irc.ReadMessage()
[13]697 if err == io.EOF {
698 break
699 } else if err != nil {
700 return fmt.Errorf("failed to read IRC command: %v", err)
701 }
702
[64]703 if uc.srv.Debug {
704 uc.logger.Printf("received: %v", msg)
705 }
706
[103]707 ch <- upstreamIncomingMessage{msg, uc}
[13]708 }
709
[45]710 return nil
[13]711}
[60]712
713func (uc *upstreamConn) SendMessage(msg *irc.Message) {
[102]714 uc.outgoing <- msg
[60]715}
Note: See TracBrowser for help on using the repository browser.