source: code/trunk/upstream.go@ 99

Last change on this file since 99 was 98, checked in by contact, 5 years ago

Rename project to soju

File size: 15.2 KB
RevLine 
[98]1package soju
[13]2
3import (
4 "crypto/tls"
[95]5 "encoding/base64"
[13]6 "fmt"
7 "io"
8 "net"
[19]9 "strconv"
[17]10 "strings"
[19]11 "time"
[13]12
[95]13 "github.com/emersion/go-sasl"
[13]14 "gopkg.in/irc.v3"
15)
16
[19]17type upstreamChannel struct {
18 Name string
[46]19 conn *upstreamConn
[19]20 Topic string
21 TopicWho string
22 TopicTime time.Time
23 Status channelStatus
[35]24 modes modeSet
[19]25 Members map[string]membership
[25]26 complete bool
[19]27}
28
[13]29type upstreamConn struct {
[77]30 network *network
[21]31 logger Logger
[19]32 net net.Conn
33 irc *irc.Conn
34 srv *Server
[37]35 user *user
[33]36 messages chan<- *irc.Message
[50]37 ring *Ring
[16]38
39 serverName string
40 availableUserModes string
41 availableChannelModes string
42 channelModesWithParam string
[19]43
44 registered bool
[42]45 nick string
[77]46 username string
47 realname string
[33]48 closed bool
[19]49 modes modeSet
50 channels map[string]*upstreamChannel
[57]51 history map[string]uint64
[92]52 caps map[string]string
[95]53
54 saslClient sasl.Client
55 saslStarted bool
[13]56}
57
[77]58func connectToUpstream(network *network) (*upstreamConn, error) {
59 logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
[33]60
[77]61 addr := network.Addr
62 if !strings.ContainsRune(addr, ':') {
63 addr = addr + ":6697"
64 }
65
66 logger.Printf("connecting to TLS server at address %q", addr)
67 netConn, err := tls.Dial("tcp", addr, nil)
[33]68 if err != nil {
[77]69 return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
[33]70 }
71
[67]72 setKeepAlive(netConn)
73
[33]74 msgs := make(chan *irc.Message, 64)
[55]75 uc := &upstreamConn{
[79]76 network: network,
[33]77 logger: logger,
78 net: netConn,
79 irc: irc.NewConn(netConn),
[77]80 srv: network.user.srv,
81 user: network.user,
[33]82 messages: msgs,
[77]83 ring: NewRing(network.user.srv.RingCap),
[33]84 channels: make(map[string]*upstreamChannel),
[57]85 history: make(map[string]uint64),
[92]86 caps: make(map[string]string),
[33]87 }
88
89 go func() {
90 for msg := range msgs {
[64]91 if uc.srv.Debug {
92 uc.logger.Printf("sent: %v", msg)
93 }
[55]94 if err := uc.irc.WriteMessage(msg); err != nil {
95 uc.logger.Printf("failed to write message: %v", err)
[33]96 }
97 }
[55]98 if err := uc.net.Close(); err != nil {
99 uc.logger.Printf("failed to close connection: %v", err)
[45]100 } else {
[55]101 uc.logger.Printf("connection closed")
[45]102 }
[33]103 }()
104
[55]105 return uc, nil
[33]106}
107
[55]108func (uc *upstreamConn) Close() error {
109 if uc.closed {
[33]110 return fmt.Errorf("upstream connection already closed")
111 }
[55]112 close(uc.messages)
113 uc.closed = true
[33]114 return nil
115}
116
[73]117func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
118 uc.user.forEachDownstream(func(dc *downstreamConn) {
[77]119 if dc.network != nil && dc.network != uc.network {
[73]120 return
121 }
122 f(dc)
123 })
124}
125
[55]126func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
127 ch, ok := uc.channels[name]
[19]128 if !ok {
129 return nil, fmt.Errorf("unknown channel %q", name)
130 }
131 return ch, nil
132}
133
[55]134func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
[13]135 switch msg.Command {
136 case "PING":
[60]137 uc.SendMessage(&irc.Message{
[13]138 Command: "PONG",
[68]139 Params: msg.Params,
[60]140 })
[33]141 return nil
[17]142 case "MODE":
[69]143 if msg.Prefix == nil {
144 return fmt.Errorf("missing prefix")
145 }
146
[43]147 var name, modeStr string
148 if err := parseMessageParams(msg, &name, &modeStr); err != nil {
149 return err
[17]150 }
[35]151
152 if name == msg.Prefix.Name { // user mode change
[55]153 if name != uc.nick {
[35]154 return fmt.Errorf("received MODE message for unknow nick %q", name)
155 }
[55]156 return uc.modes.Apply(modeStr)
[35]157 } else { // channel mode change
[55]158 ch, err := uc.getChannel(name)
[35]159 if err != nil {
160 return err
161 }
162 if err := ch.modes.Apply(modeStr); err != nil {
163 return err
164 }
[69]165
[73]166 uc.forEachDownstream(func(dc *downstreamConn) {
[69]167 dc.SendMessage(&irc.Message{
168 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
169 Command: "MODE",
170 Params: []string{dc.marshalChannel(uc, name), modeStr},
171 })
172 })
[46]173 }
[18]174 case "NOTICE":
[55]175 uc.logger.Print(msg)
[97]176
177 uc.forEachDownstream(func(dc *downstreamConn) {
178 dc.SendMessage(msg)
179 })
[92]180 case "CAP":
[95]181 var subCmd string
182 if err := parseMessageParams(msg, nil, &subCmd); err != nil {
183 return err
[92]184 }
[95]185 subCmd = strings.ToUpper(subCmd)
186 subParams := msg.Params[2:]
187 switch subCmd {
188 case "LS":
189 if len(subParams) < 1 {
190 return newNeedMoreParamsError(msg.Command)
191 }
192 caps := strings.Fields(subParams[len(subParams)-1])
193 more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
[92]194
[95]195 for _, s := range caps {
196 kv := strings.SplitN(s, "=", 2)
197 k := strings.ToLower(kv[0])
198 var v string
199 if len(kv) == 2 {
200 v = kv[1]
201 }
202 uc.caps[k] = v
[92]203 }
204
[95]205 if more {
206 break // wait to receive all capabilities
207 }
208
209 if uc.requestSASL() {
210 uc.SendMessage(&irc.Message{
211 Command: "CAP",
212 Params: []string{"REQ", "sasl"},
213 })
214 break // we'll send CAP END after authentication is completed
215 }
216
[92]217 uc.SendMessage(&irc.Message{
218 Command: "CAP",
219 Params: []string{"END"},
220 })
[95]221 case "ACK", "NAK":
222 if len(subParams) < 1 {
223 return newNeedMoreParamsError(msg.Command)
224 }
225 caps := strings.Fields(subParams[0])
226
227 for _, name := range caps {
228 if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
229 return err
230 }
231 }
232
233 if uc.saslClient == nil {
234 uc.SendMessage(&irc.Message{
235 Command: "CAP",
236 Params: []string{"END"},
237 })
238 }
239 default:
240 uc.logger.Printf("unhandled message: %v", msg)
[92]241 }
[95]242 case "AUTHENTICATE":
243 if uc.saslClient == nil {
244 return fmt.Errorf("received unexpected AUTHENTICATE message")
245 }
246
247 // TODO: if a challenge is 400 bytes long, buffer it
248 var challengeStr string
249 if err := parseMessageParams(msg, &challengeStr); err != nil {
250 uc.SendMessage(&irc.Message{
251 Command: "AUTHENTICATE",
252 Params: []string{"*"},
253 })
254 return err
255 }
256
257 var challenge []byte
258 if challengeStr != "+" {
259 var err error
260 challenge, err = base64.StdEncoding.DecodeString(challengeStr)
261 if err != nil {
262 uc.SendMessage(&irc.Message{
263 Command: "AUTHENTICATE",
264 Params: []string{"*"},
265 })
266 return err
267 }
268 }
269
270 var resp []byte
271 var err error
272 if !uc.saslStarted {
273 _, resp, err = uc.saslClient.Start()
274 uc.saslStarted = true
275 } else {
276 resp, err = uc.saslClient.Next(challenge)
277 }
278 if err != nil {
279 uc.SendMessage(&irc.Message{
280 Command: "AUTHENTICATE",
281 Params: []string{"*"},
282 })
283 return err
284 }
285
286 // TODO: send response in multiple chunks if >= 400 bytes
287 var respStr = "+"
288 if resp != nil {
289 respStr = base64.StdEncoding.EncodeToString(resp)
290 }
291
292 uc.SendMessage(&irc.Message{
293 Command: "AUTHENTICATE",
294 Params: []string{respStr},
295 })
296 case rpl_loggedin:
297 var account string
298 if err := parseMessageParams(msg, nil, nil, &account); err != nil {
299 return err
300 }
301 uc.logger.Printf("logged in with account %q", account)
302 case rpl_loggedout:
303 uc.logger.Printf("logged out")
304 case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
305 var info string
306 if err := parseMessageParams(msg, nil, &info); err != nil {
307 return err
308 }
309 switch msg.Command {
310 case err_nicklocked:
311 uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
312 case err_saslfail:
313 uc.logger.Printf("SASL authentication failed: %v", info)
314 case err_sasltoolong:
315 uc.logger.Printf("SASL message too long: %v", info)
316 }
317
318 uc.saslClient = nil
319 uc.saslStarted = false
320
321 uc.SendMessage(&irc.Message{
322 Command: "CAP",
323 Params: []string{"END"},
324 })
[14]325 case irc.RPL_WELCOME:
[55]326 uc.registered = true
327 uc.logger.Printf("connection registered")
[19]328
[77]329 channels, err := uc.srv.db.ListChannels(uc.network.ID)
330 if err != nil {
331 uc.logger.Printf("failed to list channels from database: %v", err)
332 break
333 }
334
335 for _, ch := range channels {
[60]336 uc.SendMessage(&irc.Message{
[19]337 Command: "JOIN",
[77]338 Params: []string{ch.Name},
[60]339 })
[19]340 }
[16]341 case irc.RPL_MYINFO:
[55]342 if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, &uc.availableChannelModes); err != nil {
[43]343 return err
[16]344 }
345 if len(msg.Params) > 5 {
[55]346 uc.channelModesWithParam = msg.Params[5]
[16]347 }
[42]348 case "NICK":
[83]349 if msg.Prefix == nil {
350 return fmt.Errorf("expected a prefix")
351 }
352
[43]353 var newNick string
354 if err := parseMessageParams(msg, &newNick); err != nil {
355 return err
[42]356 }
357
[55]358 if msg.Prefix.Name == uc.nick {
359 uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
360 uc.nick = newNick
[42]361 }
362
[55]363 for _, ch := range uc.channels {
[42]364 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
365 delete(ch.Members, msg.Prefix.Name)
366 ch.Members[newNick] = membership
367 }
368 }
[82]369
370 if msg.Prefix.Name != uc.nick {
371 uc.forEachDownstream(func(dc *downstreamConn) {
372 dc.SendMessage(&irc.Message{
373 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
374 Command: "NICK",
375 Params: []string{newNick},
376 })
377 })
378 }
[69]379 case "JOIN":
380 if msg.Prefix == nil {
381 return fmt.Errorf("expected a prefix")
382 }
[42]383
[43]384 var channels string
385 if err := parseMessageParams(msg, &channels); err != nil {
386 return err
[19]387 }
[34]388
[43]389 for _, ch := range strings.Split(channels, ",") {
[55]390 if msg.Prefix.Name == uc.nick {
391 uc.logger.Printf("joined channel %q", ch)
392 uc.channels[ch] = &upstreamChannel{
[34]393 Name: ch,
[55]394 conn: uc,
[34]395 Members: make(map[string]membership),
396 }
397 } else {
[55]398 ch, err := uc.getChannel(ch)
[34]399 if err != nil {
400 return err
401 }
402 ch.Members[msg.Prefix.Name] = 0
[19]403 }
[69]404
[73]405 uc.forEachDownstream(func(dc *downstreamConn) {
[69]406 dc.SendMessage(&irc.Message{
407 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
408 Command: "JOIN",
409 Params: []string{dc.marshalChannel(uc, ch)},
410 })
411 })
[19]412 }
[69]413 case "PART":
414 if msg.Prefix == nil {
415 return fmt.Errorf("expected a prefix")
416 }
[34]417
[43]418 var channels string
419 if err := parseMessageParams(msg, &channels); err != nil {
420 return err
[34]421 }
422
[43]423 for _, ch := range strings.Split(channels, ",") {
[55]424 if msg.Prefix.Name == uc.nick {
425 uc.logger.Printf("parted channel %q", ch)
426 delete(uc.channels, ch)
[34]427 } else {
[55]428 ch, err := uc.getChannel(ch)
[34]429 if err != nil {
430 return err
431 }
432 delete(ch.Members, msg.Prefix.Name)
433 }
[69]434
[73]435 uc.forEachDownstream(func(dc *downstreamConn) {
[69]436 dc.SendMessage(&irc.Message{
437 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
438 Command: "PART",
439 Params: []string{dc.marshalChannel(uc, ch)},
440 })
441 })
[34]442 }
[83]443 case "QUIT":
444 if msg.Prefix == nil {
445 return fmt.Errorf("expected a prefix")
446 }
447
448 if msg.Prefix.Name == uc.nick {
449 uc.logger.Printf("quit")
450 }
451
452 for _, ch := range uc.channels {
453 delete(ch.Members, msg.Prefix.Name)
454 }
455
456 if msg.Prefix.Name != uc.nick {
457 uc.forEachDownstream(func(dc *downstreamConn) {
458 dc.SendMessage(&irc.Message{
459 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
460 Command: "QUIT",
461 Params: msg.Params,
462 })
463 })
464 }
[19]465 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
[43]466 var name, topic string
467 if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
468 return err
[19]469 }
[55]470 ch, err := uc.getChannel(name)
[19]471 if err != nil {
472 return err
473 }
474 if msg.Command == irc.RPL_TOPIC {
[43]475 ch.Topic = topic
[19]476 } else {
477 ch.Topic = ""
478 }
479 case "TOPIC":
[43]480 var name string
[74]481 if err := parseMessageParams(msg, &name); err != nil {
[43]482 return err
[19]483 }
[55]484 ch, err := uc.getChannel(name)
[19]485 if err != nil {
486 return err
487 }
488 if len(msg.Params) > 1 {
489 ch.Topic = msg.Params[1]
490 } else {
491 ch.Topic = ""
492 }
[74]493 uc.forEachDownstream(func(dc *downstreamConn) {
494 params := []string{dc.marshalChannel(uc, name)}
495 if ch.Topic != "" {
496 params = append(params, ch.Topic)
497 }
498 dc.SendMessage(&irc.Message{
499 Prefix: dc.marshalUserPrefix(uc, msg.Prefix),
500 Command: "TOPIC",
501 Params: params,
502 })
503 })
[19]504 case rpl_topicwhotime:
[43]505 var name, who, timeStr string
506 if err := parseMessageParams(msg, nil, &name, &who, &timeStr); err != nil {
507 return err
[19]508 }
[55]509 ch, err := uc.getChannel(name)
[19]510 if err != nil {
511 return err
512 }
[43]513 ch.TopicWho = who
514 sec, err := strconv.ParseInt(timeStr, 10, 64)
[19]515 if err != nil {
516 return fmt.Errorf("failed to parse topic time: %v", err)
517 }
518 ch.TopicTime = time.Unix(sec, 0)
519 case irc.RPL_NAMREPLY:
[43]520 var name, statusStr, members string
521 if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
522 return err
[19]523 }
[55]524 ch, err := uc.getChannel(name)
[19]525 if err != nil {
526 return err
527 }
528
[43]529 status, err := parseChannelStatus(statusStr)
[19]530 if err != nil {
531 return err
532 }
533 ch.Status = status
534
[43]535 for _, s := range strings.Split(members, " ") {
[19]536 membership, nick := parseMembershipPrefix(s)
537 ch.Members[nick] = membership
538 }
539 case irc.RPL_ENDOFNAMES:
[43]540 var name string
541 if err := parseMessageParams(msg, nil, &name); err != nil {
542 return err
[25]543 }
[55]544 ch, err := uc.getChannel(name)
[25]545 if err != nil {
546 return err
547 }
548
[34]549 if ch.complete {
550 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
551 }
[25]552 ch.complete = true
[27]553
[73]554 uc.forEachDownstream(func(dc *downstreamConn) {
[27]555 forwardChannel(dc, ch)
[40]556 })
[36]557 case "PRIVMSG":
[69]558 if err := parseMessageParams(msg, nil, nil); err != nil {
559 return err
560 }
[55]561 uc.ring.Produce(msg)
[16]562 case irc.RPL_YOURHOST, irc.RPL_CREATED:
[14]563 // Ignore
564 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
565 // Ignore
566 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
567 // Ignore
568 case rpl_localusers, rpl_globalusers:
569 // Ignore
[96]570 case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
[14]571 // Ignore
[13]572 default:
[95]573 uc.logger.Printf("unhandled message: %v", msg)
[13]574 }
[14]575 return nil
[13]576}
577
[55]578func (uc *upstreamConn) register() {
[77]579 uc.nick = uc.network.Nick
580 uc.username = uc.network.Username
581 if uc.username == "" {
582 uc.username = uc.nick
583 }
584 uc.realname = uc.network.Realname
585 if uc.realname == "" {
586 uc.realname = uc.nick
587 }
588
[60]589 uc.SendMessage(&irc.Message{
[92]590 Command: "CAP",
591 Params: []string{"LS", "302"},
592 })
593
[93]594 if uc.network.Pass != "" {
595 uc.SendMessage(&irc.Message{
596 Command: "PASS",
597 Params: []string{uc.network.Pass},
598 })
599 }
600
[92]601 uc.SendMessage(&irc.Message{
[13]602 Command: "NICK",
[69]603 Params: []string{uc.nick},
[60]604 })
605 uc.SendMessage(&irc.Message{
[13]606 Command: "USER",
[77]607 Params: []string{uc.username, "0", "*", uc.realname},
[60]608 })
[44]609}
[13]610
[95]611func (uc *upstreamConn) requestSASL() bool {
612 if uc.network.SASL.Mechanism == "" {
613 return false
614 }
615
616 v, ok := uc.caps["sasl"]
617 if !ok {
618 return false
619 }
620 if v != "" {
621 mechanisms := strings.Split(v, ",")
622 found := false
623 for _, mech := range mechanisms {
624 if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
625 found = true
626 break
627 }
628 }
629 if !found {
630 return false
631 }
632 }
633
634 return true
635}
636
637func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
638 auth := &uc.network.SASL
639 switch name {
640 case "sasl":
641 if !ok {
642 uc.logger.Printf("server refused to acknowledge the SASL capability")
643 return nil
644 }
645
646 switch auth.Mechanism {
647 case "PLAIN":
648 uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
649 uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
650 default:
651 return fmt.Errorf("unsupported SASL mechanism %q", name)
652 }
653
654 uc.SendMessage(&irc.Message{
655 Command: "AUTHENTICATE",
656 Params: []string{auth.Mechanism},
657 })
658 }
659 return nil
660}
661
[55]662func (uc *upstreamConn) readMessages() error {
[13]663 for {
[55]664 msg, err := uc.irc.ReadMessage()
[13]665 if err == io.EOF {
666 break
667 } else if err != nil {
668 return fmt.Errorf("failed to read IRC command: %v", err)
669 }
670
[64]671 if uc.srv.Debug {
672 uc.logger.Printf("received: %v", msg)
673 }
674
[55]675 if err := uc.handleMessage(msg); err != nil {
676 uc.logger.Printf("failed to handle message %q: %v", msg, err)
[13]677 }
678 }
679
[45]680 return nil
[13]681}
[60]682
683func (uc *upstreamConn) SendMessage(msg *irc.Message) {
684 uc.messages <- msg
685}
Note: See TracBrowser for help on using the repository browser.