source: code/trunk/downstream.go@ 88

Last change on this file since 88 was 88, checked in by contact, 5 years ago

Fix nil dereference when network is specified in username

File size: 11.8 KB
Line 
1package jounce
2
3import (
4 "fmt"
5 "io"
6 "net"
7 "strings"
8
9 "golang.org/x/crypto/bcrypt"
10 "gopkg.in/irc.v3"
11)
12
13type ircError struct {
14 Message *irc.Message
15}
16
17func (err ircError) Error() string {
18 return err.Message.String()
19}
20
21func newUnknownCommandError(cmd string) ircError {
22 return ircError{&irc.Message{
23 Command: irc.ERR_UNKNOWNCOMMAND,
24 Params: []string{
25 "*",
26 cmd,
27 "Unknown command",
28 },
29 }}
30}
31
32func newNeedMoreParamsError(cmd string) ircError {
33 return ircError{&irc.Message{
34 Command: irc.ERR_NEEDMOREPARAMS,
35 Params: []string{
36 "*",
37 cmd,
38 "Not enough parameters",
39 },
40 }}
41}
42
43var errAuthFailed = ircError{&irc.Message{
44 Command: irc.ERR_PASSWDMISMATCH,
45 Params: []string{"*", "Invalid username or password"},
46}}
47
48type consumption struct {
49 consumer *RingConsumer
50 upstreamConn *upstreamConn
51}
52
53type downstreamConn struct {
54 net net.Conn
55 irc *irc.Conn
56 srv *Server
57 logger Logger
58 messages chan *irc.Message
59 consumptions chan consumption
60 closed chan struct{}
61
62 registered bool
63 user *user
64 nick string
65 username string
66 realname string
67 password string // empty after authentication
68 network *network // can be nil
69}
70
71func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn {
72 dc := &downstreamConn{
73 net: netConn,
74 irc: irc.NewConn(netConn),
75 srv: srv,
76 logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())},
77 messages: make(chan *irc.Message, 64),
78 consumptions: make(chan consumption),
79 closed: make(chan struct{}),
80 }
81
82 go func() {
83 if err := dc.writeMessages(); err != nil {
84 dc.logger.Printf("failed to write message: %v", err)
85 }
86 if err := dc.net.Close(); err != nil {
87 dc.logger.Printf("failed to close connection: %v", err)
88 } else {
89 dc.logger.Printf("connection closed")
90 }
91 }()
92
93 return dc
94}
95
96func (dc *downstreamConn) prefix() *irc.Prefix {
97 return &irc.Prefix{
98 Name: dc.nick,
99 User: dc.username,
100 // TODO: fill the host?
101 }
102}
103
104func (dc *downstreamConn) marshalChannel(uc *upstreamConn, name string) string {
105 return name
106}
107
108func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
109 dc.user.forEachUpstream(func(uc *upstreamConn) {
110 if dc.network != nil && uc.network != dc.network {
111 return
112 }
113 f(uc)
114 })
115}
116
117func (dc *downstreamConn) unmarshalChannel(name string) (*upstreamConn, string, error) {
118 // TODO: extract network name from channel name if dc.upstream == nil
119 var channel *upstreamChannel
120 var err error
121 dc.forEachUpstream(func(uc *upstreamConn) {
122 if err != nil {
123 return
124 }
125 if ch, ok := uc.channels[name]; ok {
126 if channel != nil {
127 err = fmt.Errorf("ambiguous channel name %q", name)
128 } else {
129 channel = ch
130 }
131 }
132 })
133 if channel == nil {
134 return nil, "", ircError{&irc.Message{
135 Command: irc.ERR_NOSUCHCHANNEL,
136 Params: []string{name, "No such channel"},
137 }}
138 }
139 return channel.conn, channel.Name, nil
140}
141
142func (dc *downstreamConn) marshalNick(uc *upstreamConn, nick string) string {
143 if nick == uc.nick {
144 return dc.nick
145 }
146 return nick
147}
148
149func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix {
150 if prefix.Name == uc.nick {
151 return dc.prefix()
152 }
153 return prefix
154}
155
156func (dc *downstreamConn) isClosed() bool {
157 select {
158 case <-dc.closed:
159 return true
160 default:
161 return false
162 }
163}
164
165func (dc *downstreamConn) readMessages() error {
166 dc.logger.Printf("new connection")
167
168 for {
169 msg, err := dc.irc.ReadMessage()
170 if err == io.EOF {
171 break
172 } else if err != nil {
173 return fmt.Errorf("failed to read IRC command: %v", err)
174 }
175
176 if dc.srv.Debug {
177 dc.logger.Printf("received: %v", msg)
178 }
179
180 err = dc.handleMessage(msg)
181 if ircErr, ok := err.(ircError); ok {
182 ircErr.Message.Prefix = dc.srv.prefix()
183 dc.SendMessage(ircErr.Message)
184 } else if err != nil {
185 return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err)
186 }
187
188 if dc.isClosed() {
189 return nil
190 }
191 }
192
193 return nil
194}
195
196func (dc *downstreamConn) writeMessages() error {
197 for {
198 var err error
199 var closed bool
200 select {
201 case msg := <-dc.messages:
202 if dc.srv.Debug {
203 dc.logger.Printf("sent: %v", msg)
204 }
205 err = dc.irc.WriteMessage(msg)
206 case consumption := <-dc.consumptions:
207 consumer, uc := consumption.consumer, consumption.upstreamConn
208 for {
209 msg := consumer.Peek()
210 if msg == nil {
211 break
212 }
213 msg = msg.Copy()
214 switch msg.Command {
215 case "PRIVMSG":
216 // TODO: detect whether it's a user or a channel
217 msg.Params[0] = dc.marshalChannel(uc, msg.Params[0])
218 default:
219 panic("expected to consume a PRIVMSG message")
220 }
221 if dc.srv.Debug {
222 dc.logger.Printf("sent: %v", msg)
223 }
224 err = dc.irc.WriteMessage(msg)
225 if err != nil {
226 break
227 }
228 consumer.Consume()
229 }
230 case <-dc.closed:
231 closed = true
232 }
233 if err != nil {
234 return err
235 }
236 if closed {
237 break
238 }
239 }
240 return nil
241}
242
243func (dc *downstreamConn) Close() error {
244 if dc.isClosed() {
245 return fmt.Errorf("downstream connection already closed")
246 }
247
248 if u := dc.user; u != nil {
249 u.lock.Lock()
250 for i := range u.downstreamConns {
251 if u.downstreamConns[i] == dc {
252 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
253 break
254 }
255 }
256 u.lock.Unlock()
257 }
258
259 close(dc.closed)
260 return nil
261}
262
263func (dc *downstreamConn) SendMessage(msg *irc.Message) {
264 dc.messages <- msg
265}
266
267func (dc *downstreamConn) handleMessage(msg *irc.Message) error {
268 switch msg.Command {
269 case "QUIT":
270 return dc.Close()
271 case "PING":
272 dc.SendMessage(&irc.Message{
273 Prefix: dc.srv.prefix(),
274 Command: "PONG",
275 Params: msg.Params,
276 })
277 return nil
278 default:
279 if dc.registered {
280 return dc.handleMessageRegistered(msg)
281 } else {
282 return dc.handleMessageUnregistered(msg)
283 }
284 }
285}
286
287func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
288 switch msg.Command {
289 case "NICK":
290 if err := parseMessageParams(msg, &dc.nick); err != nil {
291 return err
292 }
293 case "USER":
294 var username string
295 if err := parseMessageParams(msg, &username, nil, nil, &dc.realname); err != nil {
296 return err
297 }
298 dc.username = "~" + username
299 case "PASS":
300 if err := parseMessageParams(msg, &dc.password); err != nil {
301 return err
302 }
303 default:
304 dc.logger.Printf("unhandled message: %v", msg)
305 return newUnknownCommandError(msg.Command)
306 }
307 if dc.username != "" && dc.nick != "" {
308 return dc.register()
309 }
310 return nil
311}
312
313func (dc *downstreamConn) register() error {
314 username := strings.TrimPrefix(dc.username, "~")
315 var networkName string
316 if i := strings.LastIndexAny(username, "/@"); i >= 0 {
317 networkName = username[i+1:]
318 }
319 if i := strings.IndexAny(username, "/@"); i >= 0 {
320 username = username[:i]
321 }
322
323 password := dc.password
324 dc.password = ""
325
326 u := dc.srv.getUser(username)
327 if u == nil {
328 dc.logger.Printf("failed authentication for %q: unknown username", username)
329 return errAuthFailed
330 }
331
332 err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
333 if err != nil {
334 dc.logger.Printf("failed authentication for %q: %v", username, err)
335 return errAuthFailed
336 }
337
338 var network *network
339 if networkName != "" {
340 network = u.getNetwork(networkName)
341 if network == nil {
342 dc.logger.Printf("failed registration: unknown network %q", networkName)
343 dc.SendMessage(&irc.Message{
344 Prefix: dc.srv.prefix(),
345 Command: irc.ERR_PASSWDMISMATCH,
346 Params: []string{"*", fmt.Sprintf("Unknown network %q", networkName)},
347 })
348 return nil
349 }
350 }
351
352 dc.registered = true
353 dc.user = u
354 dc.network = network
355
356 u.lock.Lock()
357 firstDownstream := len(u.downstreamConns) == 0
358 u.downstreamConns = append(u.downstreamConns, dc)
359 u.lock.Unlock()
360
361 dc.SendMessage(&irc.Message{
362 Prefix: dc.srv.prefix(),
363 Command: irc.RPL_WELCOME,
364 Params: []string{dc.nick, "Welcome to jounce, " + dc.nick},
365 })
366 dc.SendMessage(&irc.Message{
367 Prefix: dc.srv.prefix(),
368 Command: irc.RPL_YOURHOST,
369 Params: []string{dc.nick, "Your host is " + dc.srv.Hostname},
370 })
371 dc.SendMessage(&irc.Message{
372 Prefix: dc.srv.prefix(),
373 Command: irc.RPL_CREATED,
374 Params: []string{dc.nick, "Who cares when the server was created?"},
375 })
376 dc.SendMessage(&irc.Message{
377 Prefix: dc.srv.prefix(),
378 Command: irc.RPL_MYINFO,
379 Params: []string{dc.nick, dc.srv.Hostname, "jounce", "aiwroO", "OovaimnqpsrtklbeI"},
380 })
381 dc.SendMessage(&irc.Message{
382 Prefix: dc.srv.prefix(),
383 Command: irc.ERR_NOMOTD,
384 Params: []string{dc.nick, "No MOTD"},
385 })
386
387 dc.forEachUpstream(func(uc *upstreamConn) {
388 // TODO: fix races accessing upstream connection data
389 for _, ch := range uc.channels {
390 if ch.complete {
391 forwardChannel(dc, ch)
392 }
393 }
394
395 historyName := dc.username
396
397 var seqPtr *uint64
398 if firstDownstream {
399 seq, ok := uc.history[historyName]
400 if ok {
401 seqPtr = &seq
402 }
403 }
404
405 consumer, ch := uc.ring.NewConsumer(seqPtr)
406 go func() {
407 for {
408 var closed bool
409 select {
410 case <-ch:
411 dc.consumptions <- consumption{consumer, uc}
412 case <-dc.closed:
413 closed = true
414 }
415 if closed {
416 break
417 }
418 }
419
420 seq := consumer.Close()
421
422 dc.user.lock.Lock()
423 lastDownstream := len(dc.user.downstreamConns) == 0
424 dc.user.lock.Unlock()
425
426 if lastDownstream {
427 uc.history[historyName] = seq
428 }
429 }()
430 })
431
432 return nil
433}
434
435func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
436 switch msg.Command {
437 case "USER":
438 return ircError{&irc.Message{
439 Command: irc.ERR_ALREADYREGISTERED,
440 Params: []string{dc.nick, "You may not reregister"},
441 }}
442 case "NICK":
443 dc.forEachUpstream(func(uc *upstreamConn) {
444 uc.SendMessage(msg)
445 })
446 case "JOIN", "PART":
447 var name string
448 if err := parseMessageParams(msg, &name); err != nil {
449 return err
450 }
451
452 uc, upstreamName, err := dc.unmarshalChannel(name)
453 if err != nil {
454 return ircError{&irc.Message{
455 Command: irc.ERR_NOSUCHCHANNEL,
456 Params: []string{name, err.Error()},
457 }}
458 }
459
460 uc.SendMessage(&irc.Message{
461 Command: msg.Command,
462 Params: []string{upstreamName},
463 })
464 // TODO: add/remove channel from upstream config
465 case "MODE":
466 if msg.Prefix == nil {
467 return fmt.Errorf("missing prefix")
468 }
469
470 var name string
471 if err := parseMessageParams(msg, &name); err != nil {
472 return err
473 }
474
475 var modeStr string
476 if len(msg.Params) > 1 {
477 modeStr = msg.Params[1]
478 }
479
480 if msg.Prefix.Name != name {
481 uc, upstreamName, err := dc.unmarshalChannel(name)
482 if err != nil {
483 return err
484 }
485
486 if modeStr != "" {
487 uc.SendMessage(&irc.Message{
488 Command: "MODE",
489 Params: []string{upstreamName, modeStr},
490 })
491 } else {
492 ch, ok := uc.channels[upstreamName]
493 if !ok {
494 return ircError{&irc.Message{
495 Command: irc.ERR_NOSUCHCHANNEL,
496 Params: []string{name, "No such channel"},
497 }}
498 }
499
500 dc.SendMessage(&irc.Message{
501 Prefix: dc.srv.prefix(),
502 Command: irc.RPL_CHANNELMODEIS,
503 Params: []string{name, string(ch.modes)},
504 })
505 }
506 } else {
507 if name != dc.nick {
508 return ircError{&irc.Message{
509 Command: irc.ERR_USERSDONTMATCH,
510 Params: []string{dc.nick, "Cannot change mode for other users"},
511 }}
512 }
513
514 if modeStr != "" {
515 dc.forEachUpstream(func(uc *upstreamConn) {
516 uc.SendMessage(&irc.Message{
517 Command: "MODE",
518 Params: []string{uc.nick, modeStr},
519 })
520 })
521 } else {
522 dc.SendMessage(&irc.Message{
523 Prefix: dc.srv.prefix(),
524 Command: irc.RPL_UMODEIS,
525 Params: []string{""}, // TODO
526 })
527 }
528 }
529 case "PRIVMSG":
530 var targetsStr, text string
531 if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
532 return err
533 }
534
535 for _, name := range strings.Split(targetsStr, ",") {
536 uc, upstreamName, err := dc.unmarshalChannel(name)
537 if err != nil {
538 return err
539 }
540
541 uc.SendMessage(&irc.Message{
542 Command: "PRIVMSG",
543 Params: []string{upstreamName, text},
544 })
545 }
546 default:
547 dc.logger.Printf("unhandled message: %v", msg)
548 return newUnknownCommandError(msg.Command)
549 }
550 return nil
551}
Note: See TracBrowser for help on using the repository browser.