source: code/trunk/irc.go@ 492

Last change on this file since 492 was 492, checked in by hubert, 4 years ago

Make casemapping work over bytes instead of runes

Fixes a panic in partialCasemap when the input string was invalid UTF-8.

File size: 13.9 KB
Line 
1package soju
2
3import (
4 "fmt"
5 "sort"
6 "strings"
7
8 "gopkg.in/irc.v3"
9)
10
// IRC numeric replies and errors referenced by this package, named after
// their usual RPL_*/ERR_* identifiers.
const (
	rpl_statsping     = "246" // RPL_STATSPING
	rpl_localusers    = "265" // RPL_LOCALUSERS
	rpl_globalusers   = "266" // RPL_GLOBALUSERS
	rpl_creationtime  = "329" // RPL_CREATIONTIME
	rpl_topicwhotime  = "333" // RPL_TOPICWHOTIME
	err_invalidcapcmd = "410" // ERR_INVALIDCAPCMD
)

// Limits on a single IRC message.
const (
	// maxMessageLength is the maximum length of a message line, in bytes.
	// Per RFC 1459 this limit includes the trailing CR-LF.
	maxMessageLength = 512
	// maxMessageParams is the maximum number of parameters in a message.
	maxMessageParams = 15
)

// serverTimeLayout is the Go time layout of server-time tag values, as
// defined in the IRCv3 spec.
const serverTimeLayout = "2006-01-02T15:04:05.000Z"
27
// userModes is a set of user mode letters, stored as a string in which each
// mode byte appears at most once.
type userModes string

// Has reports whether mode c is set.
func (ms userModes) Has(c byte) bool {
	return strings.IndexByte(string(ms), c) >= 0
}

// Add sets mode c. It does nothing if c is already set.
func (ms *userModes) Add(c byte) {
	if !ms.Has(c) {
		// Convert through a byte slice. userModes(c) would yield the UTF-8
		// encoding of rune c, which is two bytes for c >= 0x80. Has and Del
		// would then never find it, and repeated Adds would duplicate it.
		*ms += userModes([]byte{c})
	}
}

// Del unsets mode c. It does nothing if c is not set.
func (ms *userModes) Del(c byte) {
	i := strings.IndexByte(string(*ms), c)
	if i >= 0 {
		*ms = (*ms)[:i] + (*ms)[i+1:]
	}
}

// Apply applies a mode string such as "+iw-x" to ms. Each mode letter is
// added or removed according to the most recent '+' or '-' before it.
//
// It returns an error if s does not start with '+' or '-'. In that case ms is
// left unchanged. An empty s is a no-op.
func (ms *userModes) Apply(s string) error {
	var plusMinus byte
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '+', '-':
			plusMinus = c
		default:
			switch plusMinus {
			case '+':
				ms.Add(c)
			case '-':
				ms.Del(c)
			default:
				return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
			}
		}
	}
	return nil
}
66
// channelModeType classifies a channel mode by how it takes a parameter. The
// four types (A to D) are the standard ones explained in
// https://modern.ircdocs.horse/#mode-message
type channelModeType byte

const (
	// modeTypeA: modes that add or remove an address to or from a list.
	modeTypeA channelModeType = 0
	// modeTypeB: modes that change a setting on a channel, and must always
	// have a parameter.
	modeTypeB channelModeType = 1
	// modeTypeC: modes that change a setting on a channel, and must have a
	// parameter when being set, and no parameter when being unset.
	modeTypeC channelModeType = 2
	// modeTypeD: modes that change a setting on a channel, and must not have
	// a parameter.
	modeTypeD channelModeType = 3
)

// stdChannelModes maps each standard channel mode letter to its type.
var stdChannelModes = map[byte]channelModeType{
	// list modes
	'b': modeTypeA, // ban list
	'e': modeTypeA, // ban exception list
	'I': modeTypeA, // invite exception list

	// modes with a parameter
	'k': modeTypeB, // channel key
	'l': modeTypeC, // channel user limit

	// flag modes
	'i': modeTypeD, // channel is invite-only
	'm': modeTypeD, // channel is moderated
	'n': modeTypeD, // channel has no external messages
	's': modeTypeD, // channel is secret
	't': modeTypeD, // channel has protected topic
}
93
94type channelModes map[byte]string
95
96// applyChannelModes parses a mode string and mode arguments from a MODE message,
97// and applies the corresponding channel mode and user membership changes on that channel.
98//
99// If ch.modes is nil, channel modes are not updated.
100//
101// needMarshaling is a list of indexes of mode arguments that represent entities
102// that must be marshaled when sent downstream.
103func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
104 needMarshaling = make(map[int]struct{}, len(arguments))
105 nextArgument := 0
106 var plusMinus byte
107outer:
108 for i := 0; i < len(modeStr); i++ {
109 mode := modeStr[i]
110 if mode == '+' || mode == '-' {
111 plusMinus = mode
112 continue
113 }
114 if plusMinus != '+' && plusMinus != '-' {
115 return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
116 }
117
118 for _, membership := range ch.conn.availableMemberships {
119 if membership.Mode == mode {
120 if nextArgument >= len(arguments) {
121 return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
122 }
123 member := arguments[nextArgument]
124 m := ch.Members.Value(member)
125 if m != nil {
126 if plusMinus == '+' {
127 m.Add(ch.conn.availableMemberships, membership)
128 } else {
129 // TODO: for upstreams without multi-prefix, query the user modes again
130 m.Remove(membership)
131 }
132 }
133 needMarshaling[nextArgument] = struct{}{}
134 nextArgument++
135 continue outer
136 }
137 }
138
139 mt, ok := ch.conn.availableChannelModes[mode]
140 if !ok {
141 continue
142 }
143 if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
144 if plusMinus == '+' {
145 var argument string
146 // some sentitive arguments (such as channel keys) can be omitted for privacy
147 // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
148 if nextArgument < len(arguments) {
149 argument = arguments[nextArgument]
150 }
151 if ch.modes != nil {
152 ch.modes[mode] = argument
153 }
154 } else {
155 delete(ch.modes, mode)
156 }
157 nextArgument++
158 } else if mt == modeTypeC || mt == modeTypeD {
159 if plusMinus == '+' {
160 if ch.modes != nil {
161 ch.modes[mode] = ""
162 }
163 } else {
164 delete(ch.modes, mode)
165 }
166 }
167 }
168 return needMarshaling, nil
169}
170
171func (cm channelModes) Format() (modeString string, parameters []string) {
172 var modesWithValues strings.Builder
173 var modesWithoutValues strings.Builder
174 parameters = make([]string, 0, 16)
175 for mode, value := range cm {
176 if value != "" {
177 modesWithValues.WriteString(string(mode))
178 parameters = append(parameters, value)
179 } else {
180 modesWithoutValues.WriteString(string(mode))
181 }
182 }
183 modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
184 return
185}
186
// stdChannelTypes lists the standard channel name prefix characters.
const stdChannelTypes = "#&+!"

// channelStatus is a one-character channel status symbol (presumably the one
// carried by RPL_NAMREPLY; confirm against callers).
type channelStatus byte

const (
	channelPublic  channelStatus = '='
	channelSecret  channelStatus = '@'
	channelPrivate channelStatus = '*'
)

// parseChannelStatus parses s as a channel status. s must be exactly one of
// "=", "@" or "*". Any other input, including the empty string, yields an
// error.
func parseChannelStatus(s string) (channelStatus, error) {
	// Check the empty string explicitly: indexing s[0] below would otherwise
	// panic on malformed input.
	if len(s) == 0 {
		return 0, fmt.Errorf("invalid channel status %q: empty string", s)
	}
	if len(s) > 1 {
		return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
	}
	switch cs := channelStatus(s[0]); cs {
	case channelPublic, channelSecret, channelPrivate:
		return cs, nil
	default:
		return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
	}
}
208
// membership is a channel membership level: the mode letter that grants it
// and the prefix character representing it (e.g. '@' for operators).
type membership struct {
	Mode   byte
	Prefix byte
}

// stdMemberships lists the standard membership levels from highest to lowest
// rank.
var stdMemberships = []membership{
	{Mode: 'q', Prefix: '~'}, // founder
	{Mode: 'a', Prefix: '&'}, // protected
	{Mode: 'o', Prefix: '@'}, // operator
	{Mode: 'h', Prefix: '%'}, // halfop
	{Mode: 'v', Prefix: '+'}, // voice
}

// memberships is the set of membership levels a member holds, always kept
// sorted by descending membership rank.
type memberships []membership

// Add inserts newMembership, keeping m in the rank order given by
// availableMemberships (highest rank first). It does nothing if newMembership
// is already present at its ranked position.
func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
	current := *m
	pos := 0
	// Walk the ranking in step with current. pos advances over the levels we
	// already hold and stops where newMembership belongs.
scan:
	for _, candidate := range availableMemberships {
		switch {
		case pos >= len(current):
			break scan
		case current[pos] == candidate:
			if candidate == newMembership {
				return // already held
			}
			pos++
		case candidate == newMembership:
			break scan
		}
	}
	current = append(current, membership{})
	copy(current[pos+1:], current[pos:])
	current[pos] = newMembership
	*m = current
}

// Remove deletes the first occurrence of oldMembership from m, if any.
func (m *memberships) Remove(oldMembership membership) {
	current := *m
	for idx := range current {
		if current[idx] != oldMembership {
			continue
		}
		*m = append(current[:idx], current[idx+1:]...)
		return
	}
}
260
261func (m memberships) Format(dc *downstreamConn) string {
262 if !dc.caps["multi-prefix"] {
263 if len(m) == 0 {
264 return ""
265 }
266 return string(m[0].Prefix)
267 }
268 prefixes := make([]byte, len(m))
269 for i, membership := range m {
270 prefixes[i] = membership.Prefix
271 }
272 return string(prefixes)
273}
274
275func parseMessageParams(msg *irc.Message, out ...*string) error {
276 if len(msg.Params) < len(out) {
277 return newNeedMoreParamsError(msg.Command)
278 }
279 for i := range out {
280 if out[i] != nil {
281 *out[i] = msg.Params[i]
282 }
283 }
284 return nil
285}
286
287func copyClientTags(tags irc.Tags) irc.Tags {
288 t := make(irc.Tags, len(tags))
289 for k, v := range tags {
290 if strings.HasPrefix(k, "+") {
291 t[k] = v
292 }
293 }
294 return t
295}
296
297type batch struct {
298 Type string
299 Params []string
300 Outer *batch // if not-nil, this batch is nested in Outer
301 Label string
302}
303
304func join(channels, keys []string) []*irc.Message {
305 // Put channels with a key first
306 js := joinSorter{channels, keys}
307 sort.Sort(&js)
308
309 // Two spaces because there are three words (JOIN, channels and keys)
310 maxLength := maxMessageLength - (len("JOIN") + 2)
311
312 var msgs []*irc.Message
313 var channelsBuf, keysBuf strings.Builder
314 for i, channel := range channels {
315 key := keys[i]
316
317 n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
318 if key != "" {
319 n += 1 + len(key)
320 }
321
322 if channelsBuf.Len() > 0 && n > maxLength {
323 // No room for the new channel in this message
324 params := []string{channelsBuf.String()}
325 if keysBuf.Len() > 0 {
326 params = append(params, keysBuf.String())
327 }
328 msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
329 channelsBuf.Reset()
330 keysBuf.Reset()
331 }
332
333 if channelsBuf.Len() > 0 {
334 channelsBuf.WriteByte(',')
335 }
336 channelsBuf.WriteString(channel)
337 if key != "" {
338 if keysBuf.Len() > 0 {
339 keysBuf.WriteByte(',')
340 }
341 keysBuf.WriteString(key)
342 }
343 }
344 if channelsBuf.Len() > 0 {
345 params := []string{channelsBuf.String()}
346 if keysBuf.Len() > 0 {
347 params = append(params, keysBuf.String())
348 }
349 msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
350 }
351
352 return msgs
353}
354
355func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
356 maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
357
358 var msgs []*irc.Message
359 for len(tokens) > 0 {
360 var msgTokens []string
361 if len(tokens) > maxTokens {
362 msgTokens = tokens[:maxTokens]
363 tokens = tokens[maxTokens:]
364 } else {
365 msgTokens = tokens
366 tokens = nil
367 }
368
369 msgs = append(msgs, &irc.Message{
370 Prefix: prefix,
371 Command: irc.RPL_ISUPPORT,
372 Params: append(append([]string{nick}, msgTokens...), "are supported"),
373 })
374 }
375
376 return msgs
377}
378
// joinSorter implements sort.Interface over the parallel slices channels and
// keys, swapping them together. Channels that have a key sort before those
// that don't; within each group, channels are ordered by name.
type joinSorter struct {
	channels []string
	keys     []string
}

func (js *joinSorter) Len() int {
	return len(js.channels)
}

func (js *joinSorter) Less(i, j int) bool {
	iHasKey, jHasKey := js.keys[i] != "", js.keys[j] != ""
	if iHasKey == jHasKey {
		return js.channels[i] < js.channels[j]
	}
	// Exactly one of the two has a key: that one goes first.
	return iHasKey
}

func (js *joinSorter) Swap(i, j int) {
	js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
	js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
}
400
401// parseCTCPMessage parses a CTCP message. CTCP is defined in
402// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
403func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
404 if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
405 return "", "", false
406 }
407 text := msg.Params[1]
408
409 if !strings.HasPrefix(text, "\x01") {
410 return "", "", false
411 }
412 text = strings.Trim(text, "\x01")
413
414 words := strings.SplitN(text, " ", 2)
415 cmd = strings.ToUpper(words[0])
416 if len(words) > 1 {
417 params = words[1]
418 }
419
420 return cmd, params, true
421}
422
// casemapping maps a name to its canonical form. Two names are equal under
// the casemapping iff their canonical forms are equal.
type casemapping func(string) string

// casemapNone is the identity casemapping: names are compared byte for byte.
func casemapNone(name string) string {
	return name
}

// casemapEachByte returns s with fold applied to each of its bytes.
//
// Casemappings work on bytes rather than runes, so the output always has the
// same length as the input (partialCasemap relies on this). Invalid UTF-8
// passes through untouched.
func casemapEachByte(s string, fold func(byte) byte) string {
	out := []byte(s)
	for i := range out {
		out[i] = fold(out[i])
	}
	return string(out)
}

// casemapFoldASCII lowers the ASCII letters A-Z and leaves every other byte
// unchanged.
func casemapFoldASCII(b byte) byte {
	if b >= 'A' && b <= 'Z' {
		return b - 'A' + 'a'
	}
	return b
}

// casemapFoldRFC1459Strict extends casemapFoldASCII by folding '{' to '[',
// '}' to ']' and '\\' to '|'. The representative of each pair is not always
// the lower-case one; only which bytes end up equal matters.
func casemapFoldRFC1459Strict(b byte) byte {
	switch b {
	case '{':
		return '['
	case '}':
		return ']'
	case '\\':
		return '|'
	}
	return casemapFoldASCII(b)
}

// casemapFoldRFC1459 extends casemapFoldRFC1459Strict by also folding '~'
// to '^'.
func casemapFoldRFC1459(b byte) byte {
	if b == '~' {
		return '^'
	}
	return casemapFoldRFC1459Strict(b)
}

// casemapASCII of name is the canonical representation of name according to
// the ascii casemapping.
func casemapASCII(name string) string {
	return casemapEachByte(name, casemapFoldASCII)
}

// casemapRFC1459 of name is the canonical representation of name according to
// the rfc1459 casemapping.
func casemapRFC1459(name string) string {
	return casemapEachByte(name, casemapFoldRFC1459)
}

// casemapRFC1459Strict of name is the canonical representation of name
// according to the rfc1459-strict casemapping.
func casemapRFC1459Strict(name string) string {
	return casemapEachByte(name, casemapFoldRFC1459Strict)
}

// parseCasemappingToken returns the casemapping named by tokenValue
// (presumably the value of an ISUPPORT CASEMAPPING token). It returns
// (nil, false) for unsupported names.
func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
	switch tokenValue {
	case "ascii":
		return casemapASCII, true
	case "rfc1459":
		return casemapRFC1459, true
	case "rfc1459-strict":
		return casemapRFC1459Strict, true
	}
	return nil, false
}

// partialCasemap replaces every byte of name that is not an ASCII letter with
// the corresponding byte of higher(name). ASCII letters keep their original
// case. higher must preserve the byte length of its input, as the casemappings
// above do.
func partialCasemap(higher casemapping, name string) string {
	full := higher(name)
	out := []byte(name)
	for i := range out {
		c := out[i]
		isLetter := ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')
		if !isLetter {
			out[i] = full[i]
		}
	}
	return string(out)
}
503
504type casemapMap struct {
505 innerMap map[string]casemapEntry
506 casemap casemapping
507}
508
509type casemapEntry struct {
510 originalKey string
511 value interface{}
512}
513
514func newCasemapMap(size int) casemapMap {
515 return casemapMap{
516 innerMap: make(map[string]casemapEntry, size),
517 casemap: casemapNone,
518 }
519}
520
521func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
522 entry, ok := cm.innerMap[cm.casemap(name)]
523 if !ok {
524 return "", false
525 }
526 return entry.originalKey, true
527}
528
529func (cm *casemapMap) Has(name string) bool {
530 _, ok := cm.innerMap[cm.casemap(name)]
531 return ok
532}
533
534func (cm *casemapMap) Len() int {
535 return len(cm.innerMap)
536}
537
538func (cm *casemapMap) SetValue(name string, value interface{}) {
539 nameCM := cm.casemap(name)
540 entry, ok := cm.innerMap[nameCM]
541 if !ok {
542 cm.innerMap[nameCM] = casemapEntry{
543 originalKey: name,
544 value: value,
545 }
546 return
547 }
548 entry.value = value
549 cm.innerMap[nameCM] = entry
550}
551
552func (cm *casemapMap) Delete(name string) {
553 delete(cm.innerMap, cm.casemap(name))
554}
555
556func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
557 cm.casemap = newCasemap
558 newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
559 for _, entry := range cm.innerMap {
560 newInnerMap[cm.casemap(entry.originalKey)] = entry
561 }
562 cm.innerMap = newInnerMap
563}
564
565type upstreamChannelCasemapMap struct{ casemapMap }
566
567func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
568 entry, ok := cm.innerMap[cm.casemap(name)]
569 if !ok {
570 return nil
571 }
572 return entry.value.(*upstreamChannel)
573}
574
575type channelCasemapMap struct{ casemapMap }
576
577func (cm *channelCasemapMap) Value(name string) *Channel {
578 entry, ok := cm.innerMap[cm.casemap(name)]
579 if !ok {
580 return nil
581 }
582 return entry.value.(*Channel)
583}
584
585type membershipsCasemapMap struct{ casemapMap }
586
587func (cm *membershipsCasemapMap) Value(name string) *memberships {
588 entry, ok := cm.innerMap[cm.casemap(name)]
589 if !ok {
590 return nil
591 }
592 return entry.value.(*memberships)
593}
594
595type deliveredCasemapMap struct{ casemapMap }
596
597func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
598 entry, ok := cm.innerMap[cm.casemap(name)]
599 if !ok {
600 return nil
601 }
602 return entry.value.(deliveredClientMap)
603}
Note: See TracBrowser for help on using the repository browser.