source: code/trunk/irc.go@ 771

Last change on this file since 771 was 761, checked in by contact, 3 years ago

Handle upstream multi-line SASL

References: https://todo.sr.ht/~emersion/soju/173

File size: 18.2 KB
RevLine 
[98]1package soju
[20]2
3import (
4 "fmt"
[350]5 "sort"
[20]6 "strings"
[516]7 "time"
[498]8 "unicode"
9 "unicode/utf8"
[43]10
11 "gopkg.in/irc.v3"
[20]12)
13
// IRC numerics referenced by this package, presumably the ones that
// gopkg.in/irc.v3 does not provide as named constants. Kept sorted by
// numeric value.
const (
	rpl_statsping     = "246" // RPL_STATSPING
	rpl_localusers    = "265" // RPL_LOCALUSERS
	rpl_globalusers   = "266" // RPL_GLOBALUSERS
	rpl_creationtime  = "329" // RPL_CREATIONTIME
	rpl_whoisaccount  = "330" // RPL_WHOISACCOUNT
	rpl_topicwhotime  = "333" // RPL_TOPICWHOTIME
	rpl_whospcrpl     = "354" // RPL_WHOSPCRPL, the WHOX reply
	err_invalidcapcmd = "410" // ERR_INVALIDCAPCMD
)
24
// Protocol size limits.
const (
	maxMessageLength = 512 // maximum length of an IRC message, in bytes
	maxMessageParams = 15  // maximum number of parameters in one IRC message
	maxSASLLength    = 400 // presumably the size of one AUTHENTICATE chunk for multi-line SASL
)

// serverTimeLayout is the time layout of the server-time tag, as defined in
// the IRCv3 spec (millisecond precision). The trailing 'Z' is a literal
// character in this layout, so times must be converted to UTC before being
// formatted with it.
const serverTimeLayout = "2006-01-02T15:04:05.000Z"
33
// userModes is a set of user mode letters, stored as a string with one byte
// per mode (for instance "iw").
type userModes string

// Has reports whether mode c is in the set.
func (ms userModes) Has(c byte) bool {
	return strings.IndexByte(string(ms), c) != -1
}

// Add appends mode c to the set unless it is already present.
func (ms *userModes) Add(c byte) {
	if ms.Has(c) {
		return
	}
	*ms += userModes(c)
}

// Del removes mode c from the set. It is a no-op when c is absent.
func (ms *userModes) Del(c byte) {
	idx := strings.IndexByte(string(*ms), c)
	if idx < 0 {
		return
	}
	cur := *ms
	*ms = cur[:idx] + cur[idx+1:]
}

// Apply updates the set from a mode string such as "+iw-x". Every mode letter
// must be preceded (somewhere earlier in s) by a '+' or '-'. Otherwise an error
// is returned, and changes applied before the offending letter are kept.
func (ms *userModes) Apply(s string) error {
	var sign byte
	for _, c := range []byte(s) {
		if c == '+' || c == '-' {
			sign = c
			continue
		}
		switch sign {
		case '+':
			ms.Add(c)
		case '-':
			ms.Del(c)
		default:
			return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
		}
	}
	return nil
}
72
// channelModeType classifies a channel mode by how it takes a parameter.
type channelModeType byte

// Standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message
const (
	// modeTypeA: modes that add or remove an address to or from a list.
	modeTypeA channelModeType = iota
	// modeTypeB: modes that change a setting on a channel and must always have a parameter.
	modeTypeB
	// modeTypeC: modes that change a setting on a channel and must have a
	// parameter when being set, and no parameter when being unset.
	modeTypeC
	// modeTypeD: modes that change a setting on a channel and must not have a parameter.
	modeTypeD
)

// stdChannelModes maps the standard channel mode letters to their type.
var stdChannelModes = map[byte]channelModeType{
	// List modes.
	'b': modeTypeA, // ban list
	'e': modeTypeA, // ban exception list
	'I': modeTypeA, // invite exception list

	// Parameter always required.
	'k': modeTypeB, // channel key

	// Parameter only when set.
	'l': modeTypeC, // channel user limit

	// No parameter.
	'i': modeTypeD, // channel is invite-only
	'm': modeTypeD, // channel is moderated
	'n': modeTypeD, // channel has no external messages
	's': modeTypeD, // channel is secret
	't': modeTypeD, // channel has protected topic
}
99
// channelModes maps each channel mode letter currently set to its parameter,
// or to "" when the mode is set without a parameter (or its parameter is
// unknown).
type channelModes map[byte]string
101
[293]102// applyChannelModes parses a mode string and mode arguments from a MODE message,
103// and applies the corresponding channel mode and user membership changes on that channel.
104//
105// If ch.modes is nil, channel modes are not updated.
106//
107// needMarshaling is a list of indexes of mode arguments that represent entities
108// that must be marshaled when sent downstream.
109func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
110 needMarshaling = make(map[int]struct{}, len(arguments))
[139]111 nextArgument := 0
112 var plusMinus byte
[293]113outer:
[139]114 for i := 0; i < len(modeStr); i++ {
115 mode := modeStr[i]
116 if mode == '+' || mode == '-' {
117 plusMinus = mode
118 continue
119 }
120 if plusMinus != '+' && plusMinus != '-' {
[293]121 return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
[139]122 }
123
[293]124 for _, membership := range ch.conn.availableMemberships {
125 if membership.Mode == mode {
126 if nextArgument >= len(arguments) {
127 return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
128 }
129 member := arguments[nextArgument]
[478]130 m := ch.Members.Value(member)
131 if m != nil {
[293]132 if plusMinus == '+' {
[478]133 m.Add(ch.conn.availableMemberships, membership)
[293]134 } else {
135 // TODO: for upstreams without multi-prefix, query the user modes again
[478]136 m.Remove(membership)
[293]137 }
138 }
139 needMarshaling[nextArgument] = struct{}{}
140 nextArgument++
141 continue outer
142 }
143 }
144
145 mt, ok := ch.conn.availableChannelModes[mode]
[139]146 if !ok {
147 continue
148 }
[673]149 if mt == modeTypeA {
150 nextArgument++
151 } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
[139]152 if plusMinus == '+' {
153 var argument string
154 // some sentitive arguments (such as channel keys) can be omitted for privacy
155 // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
156 if nextArgument < len(arguments) {
157 argument = arguments[nextArgument]
158 }
[293]159 if ch.modes != nil {
160 ch.modes[mode] = argument
161 }
[139]162 } else {
[293]163 delete(ch.modes, mode)
[139]164 }
165 nextArgument++
166 } else if mt == modeTypeC || mt == modeTypeD {
167 if plusMinus == '+' {
[293]168 if ch.modes != nil {
169 ch.modes[mode] = ""
170 }
[139]171 } else {
[293]172 delete(ch.modes, mode)
[139]173 }
174 }
175 }
[293]176 return needMarshaling, nil
[139]177}
178
179func (cm channelModes) Format() (modeString string, parameters []string) {
180 var modesWithValues strings.Builder
181 var modesWithoutValues strings.Builder
182 parameters = make([]string, 0, 16)
183 for mode, value := range cm {
184 if value != "" {
185 modesWithValues.WriteString(string(mode))
186 parameters = append(parameters, value)
187 } else {
188 modesWithoutValues.WriteString(string(mode))
189 }
190 }
191 modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
192 return
193}
194
// stdChannelTypes lists the standard channel name prefix characters.
const stdChannelTypes = "#&+!"

// channelStatus is the one-character channel visibility symbol (presumably the
// one carried by RPL_NAMREPLY).
type channelStatus byte

const (
	channelPublic  channelStatus = '='
	channelSecret  channelStatus = '@'
	channelPrivate channelStatus = '*'
)

// parseChannelStatus parses a channel status symbol. It returns an error if s
// is empty, longer than one character, or not one of the known symbols.
func parseChannelStatus(s string) (channelStatus, error) {
	if s == "" {
		// Guard needed: s[0] below would panic on an empty string.
		return 0, fmt.Errorf("invalid channel status %q: empty string", s)
	}
	if len(s) > 1 {
		return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
	}
	switch cs := channelStatus(s[0]); cs {
	case channelPublic, channelSecret, channelPrivate:
		return cs, nil
	default:
		return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
	}
}
216
// membership is a channel membership level: the mode letter that grants it
// and the prefix character shown in front of the member's nick.
type membership struct {
	Mode   byte
	Prefix byte
}

// stdMemberships lists the standard memberships from highest to lowest rank.
var stdMemberships = []membership{
	{'q', '~'}, // founder
	{'a', '&'}, // protected
	{'o', '@'}, // operator
	{'h', '%'}, // halfop
	{'v', '+'}, // voice
}

// memberships always sorted by descending membership rank, where rank is the
// order of the availableMemberships list passed to Add.
type memberships []membership

// Add inserts newMembership at the position dictated by its rank in
// availableMemberships. It is a no-op if newMembership is already present.
func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
	cur := *m
	pos := 0
	for _, candidate := range availableMemberships {
		if pos >= len(cur) {
			break
		}
		held := cur[pos] == candidate
		isNew := candidate == newMembership
		if held && isNew {
			// we already have this membership
			return
		}
		if held {
			pos++
			continue
		}
		if isNew {
			break
		}
	}
	// Shift the tail right by one and place newMembership at pos.
	cur = append(cur, membership{})
	copy(cur[pos+1:], cur[pos:])
	cur[pos] = newMembership
	*m = cur
}

// Remove deletes the first occurrence of oldMembership, if any. The backing
// array is modified in place.
func (m *memberships) Remove(oldMembership membership) {
	for idx := range *m {
		if (*m)[idx] == oldMembership {
			*m = append((*m)[:idx], (*m)[idx+1:]...)
			return
		}
	}
}
268
269func (m memberships) Format(dc *downstreamConn) string {
270 if !dc.caps["multi-prefix"] {
271 if len(m) == 0 {
272 return ""
273 }
274 return string(m[0].Prefix)
275 }
276 prefixes := make([]byte, len(m))
277 for i, membership := range m {
278 prefixes[i] = membership.Prefix
279 }
280 return string(prefixes)
281}
282
[43]283func parseMessageParams(msg *irc.Message, out ...*string) error {
284 if len(msg.Params) < len(out) {
285 return newNeedMoreParamsError(msg.Command)
286 }
287 for i := range out {
288 if out[i] != nil {
289 *out[i] = msg.Params[i]
290 }
291 }
292 return nil
293}
[153]294
[303]295func copyClientTags(tags irc.Tags) irc.Tags {
296 t := make(irc.Tags, len(tags))
297 for k, v := range tags {
298 if strings.HasPrefix(k, "+") {
299 t[k] = v
300 }
301 }
302 return t
303}
304
// batch describes an open message batch.
type batch struct {
	Type   string   // batch type
	Params []string // additional batch parameters
	Outer  *batch   // if not-nil, this batch is nested in Outer
	Label  string
}
[193]311
[350]312func join(channels, keys []string) []*irc.Message {
313 // Put channels with a key first
314 js := joinSorter{channels, keys}
315 sort.Sort(&js)
316
317 // Two spaces because there are three words (JOIN, channels and keys)
318 maxLength := maxMessageLength - (len("JOIN") + 2)
319
320 var msgs []*irc.Message
321 var channelsBuf, keysBuf strings.Builder
322 for i, channel := range channels {
323 key := keys[i]
324
325 n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
326 if key != "" {
327 n += 1 + len(key)
328 }
329
330 if channelsBuf.Len() > 0 && n > maxLength {
331 // No room for the new channel in this message
332 params := []string{channelsBuf.String()}
333 if keysBuf.Len() > 0 {
334 params = append(params, keysBuf.String())
335 }
336 msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
337 channelsBuf.Reset()
338 keysBuf.Reset()
339 }
340
341 if channelsBuf.Len() > 0 {
342 channelsBuf.WriteByte(',')
343 }
344 channelsBuf.WriteString(channel)
345 if key != "" {
346 if keysBuf.Len() > 0 {
347 keysBuf.WriteByte(',')
348 }
349 keysBuf.WriteString(key)
350 }
351 }
352 if channelsBuf.Len() > 0 {
353 params := []string{channelsBuf.String()}
354 if keysBuf.Len() > 0 {
355 params = append(params, keysBuf.String())
356 }
357 msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
358 }
359
360 return msgs
361}
362
[463]363func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
364 maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
365
366 var msgs []*irc.Message
367 for len(tokens) > 0 {
368 var msgTokens []string
369 if len(tokens) > maxTokens {
370 msgTokens = tokens[:maxTokens]
371 tokens = tokens[maxTokens:]
372 } else {
373 msgTokens = tokens
374 tokens = nil
375 }
376
377 msgs = append(msgs, &irc.Message{
378 Prefix: prefix,
379 Command: irc.RPL_ISUPPORT,
380 Params: append(append([]string{nick}, msgTokens...), "are supported"),
381 })
382 }
383
384 return msgs
385}
386
[636]387func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message {
388 var msgs []*irc.Message
389 msgs = append(msgs, &irc.Message{
390 Prefix: prefix,
391 Command: irc.RPL_MOTDSTART,
392 Params: []string{nick, fmt.Sprintf("- Message of the Day -")},
393 })
394
395 for _, l := range strings.Split(motd, "\n") {
396 msgs = append(msgs, &irc.Message{
397 Prefix: prefix,
398 Command: irc.RPL_MOTD,
399 Params: []string{nick, l},
400 })
401 }
402
403 msgs = append(msgs, &irc.Message{
404 Prefix: prefix,
405 Command: irc.RPL_ENDOFMOTD,
406 Params: []string{nick, "End of /MOTD command."},
407 })
408
409 return msgs
410}
411
[684]412func generateMonitor(subcmd string, targets []string) []*irc.Message {
413 maxLength := maxMessageLength - len("MONITOR "+subcmd+" ")
414
415 var msgs []*irc.Message
416 var buf []string
417 n := 0
418 for _, target := range targets {
419 if n+len(target)+1 > maxLength {
420 msgs = append(msgs, &irc.Message{
421 Command: "MONITOR",
422 Params: []string{subcmd, strings.Join(buf, ",")},
423 })
424 buf = buf[:0]
425 n = 0
426 }
427
428 buf = append(buf, target)
429 n += len(target) + 1
430 }
431
432 if len(buf) > 0 {
433 msgs = append(msgs, &irc.Message{
434 Command: "MONITOR",
435 Params: []string{subcmd, strings.Join(buf, ",")},
436 })
437 }
438
439 return msgs
440}
441
// joinSorter implements sort.Interface over parallel channels/keys slices.
// Channels that have a key (non-empty keys[i]) sort before those that do not,
// then channels are ordered by name. Swaps move keys along with their channel.
type joinSorter struct {
	channels []string
	keys     []string
}

func (js *joinSorter) Len() int {
	return len(js.channels)
}

func (js *joinSorter) Less(i, j int) bool {
	iKeyed, jKeyed := js.keys[i] != "", js.keys[j] != ""
	if iKeyed == jKeyed {
		return js.channels[i] < js.channels[j]
	}
	// Only one of the channels has a key
	return iKeyed
}

func (js *joinSorter) Swap(i, j int) {
	js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
	js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
}
[392]463
464// parseCTCPMessage parses a CTCP message. CTCP is defined in
465// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
466func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
467 if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
468 return "", "", false
469 }
470 text := msg.Params[1]
471
472 if !strings.HasPrefix(text, "\x01") {
473 return "", "", false
474 }
475 text = strings.Trim(text, "\x01")
476
477 words := strings.SplitN(text, " ", 2)
478 cmd = strings.ToUpper(words[0])
479 if len(words) > 1 {
480 params = words[1]
481 }
482
483 return cmd, params, true
484}
[478]485
// casemapping maps a name to its canonical form. Two names are considered
// equal when their canonical forms are equal.
type casemapping func(string) string

// casemapNone is the identity casemapping: names are compared byte-for-byte.
func casemapNone(name string) string {
	return name
}
491
// casemapASCII of name is the canonical representation of name according to
// the ascii casemapping: bytes 'A'-'Z' are lowered, every other byte is kept.
func casemapASCII(name string) string {
	buf := []byte(name)
	for i := range buf {
		if c := buf[i]; c >= 'A' && c <= 'Z' {
			buf[i] = c - 'A' + 'a'
		}
	}
	return string(buf)
}
503
// casemapRFC1459 of name is the canonical representation of name according to
// the rfc1459 casemapping. 'A'-'Z' are lowered, and each equivalent pair
// collapses onto one representative: '{'->'[', '}'->']', '\\'->'|', '~'->'^'.
func casemapRFC1459(name string) string {
	out := []byte(name)
	for i, c := range out {
		switch {
		case 'A' <= c && c <= 'Z':
			out[i] = c + ('a' - 'A')
		case c == '{':
			out[i] = '['
		case c == '}':
			out[i] = ']'
		case c == '\\':
			out[i] = '|'
		case c == '~':
			out[i] = '^'
		}
	}
	return string(out)
}
523
// casemapRFC1459Strict of name is the canonical representation of name
// according to the rfc1459-strict casemapping. It is the same as rfc1459
// except that '~' and '^' are not equivalent: 'A'-'Z' are lowered, and
// '{'->'[', '}'->']', '\\'->'|'.
func casemapRFC1459Strict(name string) string {
	out := []byte(name)
	for i, c := range out {
		switch {
		case 'A' <= c && c <= 'Z':
			out[i] = c + ('a' - 'A')
		case c == '{':
			out[i] = '['
		case c == '}':
			out[i] = ']'
		case c == '\\':
			out[i] = '|'
		}
	}
	return string(out)
}
541
542func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
543 switch tokenValue {
544 case "ascii":
545 casemap = casemapASCII
546 case "rfc1459":
547 casemap = casemapRFC1459
548 case "rfc1459-strict":
549 casemap = casemapRFC1459Strict
550 default:
551 return nil, false
552 }
553 return casemap, true
554}
555
556func partialCasemap(higher casemapping, name string) string {
[492]557 nameFullyCM := []byte(higher(name))
558 nameBytes := []byte(name)
559 for i, r := range nameBytes {
560 if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') {
561 nameBytes[i] = nameFullyCM[i]
[478]562 }
563 }
[492]564 return string(nameBytes)
[478]565}
566
567type casemapMap struct {
568 innerMap map[string]casemapEntry
569 casemap casemapping
570}
571
572type casemapEntry struct {
573 originalKey string
574 value interface{}
575}
576
577func newCasemapMap(size int) casemapMap {
578 return casemapMap{
579 innerMap: make(map[string]casemapEntry, size),
580 casemap: casemapNone,
581 }
582}
583
584func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
585 entry, ok := cm.innerMap[cm.casemap(name)]
586 if !ok {
587 return "", false
588 }
589 return entry.originalKey, true
590}
591
592func (cm *casemapMap) Has(name string) bool {
593 _, ok := cm.innerMap[cm.casemap(name)]
594 return ok
595}
596
597func (cm *casemapMap) Len() int {
598 return len(cm.innerMap)
599}
600
601func (cm *casemapMap) SetValue(name string, value interface{}) {
602 nameCM := cm.casemap(name)
603 entry, ok := cm.innerMap[nameCM]
604 if !ok {
605 cm.innerMap[nameCM] = casemapEntry{
606 originalKey: name,
607 value: value,
608 }
609 return
610 }
611 entry.value = value
612 cm.innerMap[nameCM] = entry
613}
614
615func (cm *casemapMap) Delete(name string) {
616 delete(cm.innerMap, cm.casemap(name))
617}
618
619func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
620 cm.casemap = newCasemap
621 newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
622 for _, entry := range cm.innerMap {
623 newInnerMap[cm.casemap(entry.originalKey)] = entry
624 }
625 cm.innerMap = newInnerMap
626}
627
628type upstreamChannelCasemapMap struct{ casemapMap }
629
630func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
631 entry, ok := cm.innerMap[cm.casemap(name)]
632 if !ok {
633 return nil
634 }
635 return entry.value.(*upstreamChannel)
636}
637
638type channelCasemapMap struct{ casemapMap }
639
640func (cm *channelCasemapMap) Value(name string) *Channel {
641 entry, ok := cm.innerMap[cm.casemap(name)]
642 if !ok {
643 return nil
644 }
645 return entry.value.(*Channel)
646}
647
648type membershipsCasemapMap struct{ casemapMap }
649
650func (cm *membershipsCasemapMap) Value(name string) *memberships {
651 entry, ok := cm.innerMap[cm.casemap(name)]
652 if !ok {
653 return nil
654 }
655 return entry.value.(*memberships)
656}
657
[480]658type deliveredCasemapMap struct{ casemapMap }
[478]659
[480]660func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
[478]661 entry, ok := cm.innerMap[cm.casemap(name)]
662 if !ok {
663 return nil
664 }
[480]665 return entry.value.(deliveredClientMap)
[478]666}
[498]667
[684]668type monitorCasemapMap struct{ casemapMap }
669
670func (cm *monitorCasemapMap) Value(name string) (online bool) {
671 entry, ok := cm.innerMap[cm.casemap(name)]
672 if !ok {
673 return false
674 }
675 return entry.value.(bool)
676}
677
// isWordBoundary reports whether r separates words for highlight detection.
// Letters and numbers belong to words, and so do '-', '_' and '|' (presumably
// because they are common in nicknames). U+00A0 (no-break space) and every
// other rune, including the zero rune, count as boundaries.
func isWordBoundary(r rune) bool {
	switch r {
	case '-', '_', '|':
		return false
	case '\u00A0':
		return true
	default:
		return !unicode.IsLetter(r) && !unicode.IsNumber(r)
	}
}

// isHighlight reports whether nick occurs in text as a whole word, i.e. with a
// word boundary (see isWordBoundary) or the start/end of text on both sides.
// The comparison is case-sensitive. An empty nick never highlights.
func isHighlight(text, nick string) bool {
	if nick == "" {
		// strings.Index matches "" everywhere; without this guard the search
		// below would never make progress.
		return false
	}

	// Offsets are kept relative to the full text so that the rune on the left
	// of a candidate is always the real one. Re-slicing text after each
	// rejected match would lose it (e.g. "foofoo" would highlight "foo").
	offset := 0
	for {
		i := strings.Index(text[offset:], nick)
		if i < 0 {
			return false
		}
		start := offset + i
		end := start + len(nick)

		// The zero rune stands for the start/end of text, which is a boundary.
		var left, right rune
		if start > 0 {
			left, _ = utf8.DecodeLastRuneInString(text[:start])
		}
		if end < len(text) {
			right, _ = utf8.DecodeRuneInString(text[end:])
		}
		if isWordBoundary(left) && isWordBoundary(right) {
			return true
		}

		// Retry one byte further so that overlapping occurrences are checked.
		offset = start + 1
	}
}
[516]711
712// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound.
713// The zero time is returned on error.
714func parseChatHistoryBound(param string) time.Time {
715 parts := strings.SplitN(param, "=", 2)
716 if len(parts) != 2 {
717 return time.Time{}
718 }
719 switch parts[0] {
720 case "timestamp":
721 timestamp, err := time.Parse(serverTimeLayout, parts[1])
722 if err != nil {
723 return time.Time{}
724 }
725 return timestamp
726 default:
727 return time.Time{}
728 }
729}
[660]730
// whoxInfo carries the data needed to answer a WHO query about one user, as
// consumed by generateWHOXReply.
type whoxInfo struct {
	Token    string // echoed back for the WHOX 't' field
	Username string
	Hostname string
	Server   string
	Nickname string
	Flags    string
	Account  string // "" or "*" is reported as "0" (no account) in WHOX replies
	Realname string
}
741
742func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message {
743 if fields == "" {
744 return &irc.Message{
745 Prefix: prefix,
746 Command: irc.RPL_WHOREPLY,
747 Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname},
748 }
749 }
750
751 fieldSet := make(map[byte]bool)
752 for i := 0; i < len(fields); i++ {
753 fieldSet[fields[i]] = true
754 }
755
756 var params []string
757 if fieldSet['t'] {
758 params = append(params, info.Token)
759 }
760 if fieldSet['c'] {
761 params = append(params, "*")
762 }
763 if fieldSet['u'] {
764 params = append(params, info.Username)
765 }
766 if fieldSet['i'] {
767 params = append(params, "255.255.255.255")
768 }
769 if fieldSet['h'] {
770 params = append(params, info.Hostname)
771 }
772 if fieldSet['s'] {
773 params = append(params, info.Server)
774 }
775 if fieldSet['n'] {
776 params = append(params, info.Nickname)
777 }
778 if fieldSet['f'] {
779 params = append(params, info.Flags)
780 }
781 if fieldSet['d'] {
782 params = append(params, "0")
783 }
784 if fieldSet['l'] { // idle time
785 params = append(params, "0")
786 }
787 if fieldSet['a'] {
788 account := "0" // WHOX uses "0" to mean "no account"
789 if info.Account != "" && info.Account != "*" {
790 account = info.Account
791 }
792 params = append(params, account)
793 }
794 if fieldSet['o'] {
795 params = append(params, "0")
796 }
797 if fieldSet['r'] {
798 params = append(params, info.Realname)
799 }
800
801 return &irc.Message{
802 Prefix: prefix,
803 Command: rpl_whospcrpl,
804 Params: append([]string{nick}, params...),
805 }
806}
[662]807
// isupportEncoder escapes space and backslash as \x20 and \x5C. Replacements
// are done in a single pass, so an inserted backslash is never escaped again.
var isupportEncoder = strings.NewReplacer(
	" ", `\x20`,
	`\`, `\x5C`,
)

// encodeISUPPORT escapes s for use as an ISUPPORT token value.
func encodeISUPPORT(s string) string {
	encoded := isupportEncoder.Replace(s)
	return encoded
}
Note: See TracBrowser for help on using the repository browser.