source: code/trunk/upstream.go@ 42

Last change on this file since 42 was 42, checked in by contact, 5 years ago

Allow changing nickname

File size: 7.5 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "time"
11
12 "gopkg.in/irc.v3"
13)
14
15type upstreamChannel struct {
16 Name string
17 Topic string
18 TopicWho string
19 TopicTime time.Time
20 Status channelStatus
21 modes modeSet
22 Members map[string]membership
23 complete bool
24}
25
26type upstreamConn struct {
27 upstream *Upstream
28 logger Logger
29 net net.Conn
30 irc *irc.Conn
31 srv *Server
32 user *user
33 messages chan<- *irc.Message
34
35 serverName string
36 availableUserModes string
37 availableChannelModes string
38 channelModesWithParam string
39
40 registered bool
41 nick string
42 closed bool
43 modes modeSet
44 channels map[string]*upstreamChannel
45}
46
47func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) {
48 logger := &prefixLogger{u.srv.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
49 logger.Printf("connecting to server")
50
51 netConn, err := tls.Dial("tcp", upstream.Addr, nil)
52 if err != nil {
53 return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
54 }
55
56 msgs := make(chan *irc.Message, 64)
57 conn := &upstreamConn{
58 upstream: upstream,
59 logger: logger,
60 net: netConn,
61 irc: irc.NewConn(netConn),
62 srv: u.srv,
63 user: u,
64 messages: msgs,
65 channels: make(map[string]*upstreamChannel),
66 }
67
68 go func() {
69 for msg := range msgs {
70 if err := conn.irc.WriteMessage(msg); err != nil {
71 conn.logger.Printf("failed to write message: %v", err)
72 }
73 }
74 }()
75
76 return conn, nil
77}
78
79func (c *upstreamConn) Close() error {
80 if c.closed {
81 return fmt.Errorf("upstream connection already closed")
82 }
83 if err := c.net.Close(); err != nil {
84 return err
85 }
86 close(c.messages)
87 c.closed = true
88 return nil
89}
90
91func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
92 ch, ok := c.channels[name]
93 if !ok {
94 return nil, fmt.Errorf("unknown channel %q", name)
95 }
96 return ch, nil
97}
98
99func (c *upstreamConn) handleMessage(msg *irc.Message) error {
100 switch msg.Command {
101 case "PING":
102 // TODO: handle params
103 c.messages <- &irc.Message{
104 Command: "PONG",
105 Params: []string{c.srv.Hostname},
106 }
107 return nil
108 case "MODE":
109 if len(msg.Params) < 2 {
110 return newNeedMoreParamsError(msg.Command)
111 }
112 name := msg.Params[0]
113 modeStr := msg.Params[1]
114
115 if name == msg.Prefix.Name { // user mode change
116 if name != c.nick {
117 return fmt.Errorf("received MODE message for unknow nick %q", name)
118 }
119 return c.modes.Apply(modeStr)
120 } else { // channel mode change
121 ch, err := c.getChannel(name)
122 if err != nil {
123 return err
124 }
125 if err := ch.modes.Apply(modeStr); err != nil {
126 return err
127 }
128
129 c.user.forEachDownstream(func(dc *downstreamConn) {
130 dc.messages <- msg
131 })
132 }
133 case "NOTICE":
134 c.logger.Print(msg)
135 case irc.RPL_WELCOME:
136 c.registered = true
137 c.logger.Printf("connection registered")
138
139 for _, ch := range c.upstream.Channels {
140 c.messages <- &irc.Message{
141 Command: "JOIN",
142 Params: []string{ch},
143 }
144 }
145 case irc.RPL_MYINFO:
146 if len(msg.Params) < 5 {
147 return newNeedMoreParamsError(msg.Command)
148 }
149 c.serverName = msg.Params[1]
150 c.availableUserModes = msg.Params[3]
151 c.availableChannelModes = msg.Params[4]
152 if len(msg.Params) > 5 {
153 c.channelModesWithParam = msg.Params[5]
154 }
155 case "NICK":
156 if len(msg.Params) < 1 {
157 return newNeedMoreParamsError(msg.Command)
158 }
159 newNick := msg.Params[0]
160
161 if msg.Prefix.Name == c.nick {
162 c.logger.Printf("changed nick from %q to %q", c.nick, newNick)
163 c.nick = newNick
164 }
165
166 for _, ch := range c.channels {
167 if membership, ok := ch.Members[msg.Prefix.Name]; ok {
168 delete(ch.Members, msg.Prefix.Name)
169 ch.Members[newNick] = membership
170 }
171 }
172
173 c.user.forEachDownstream(func(dc *downstreamConn) {
174 dc.messages <- msg
175 })
176 case "JOIN":
177 if len(msg.Params) < 1 {
178 return newNeedMoreParamsError(msg.Command)
179 }
180
181 for _, ch := range strings.Split(msg.Params[0], ",") {
182 if msg.Prefix.Name == c.nick {
183 c.logger.Printf("joined channel %q", ch)
184 c.channels[ch] = &upstreamChannel{
185 Name: ch,
186 Members: make(map[string]membership),
187 }
188 } else {
189 ch, err := c.getChannel(ch)
190 if err != nil {
191 return err
192 }
193 ch.Members[msg.Prefix.Name] = 0
194 }
195 }
196
197 c.user.forEachDownstream(func(dc *downstreamConn) {
198 dc.messages <- msg
199 })
200 case "PART":
201 if len(msg.Params) < 1 {
202 return newNeedMoreParamsError(msg.Command)
203 }
204
205 for _, ch := range strings.Split(msg.Params[0], ",") {
206 if msg.Prefix.Name == c.nick {
207 c.logger.Printf("parted channel %q", ch)
208 delete(c.channels, ch)
209 } else {
210 ch, err := c.getChannel(ch)
211 if err != nil {
212 return err
213 }
214 delete(ch.Members, msg.Prefix.Name)
215 }
216 }
217
218 c.user.forEachDownstream(func(dc *downstreamConn) {
219 dc.messages <- msg
220 })
221 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
222 if len(msg.Params) < 3 {
223 return newNeedMoreParamsError(msg.Command)
224 }
225 ch, err := c.getChannel(msg.Params[1])
226 if err != nil {
227 return err
228 }
229 if msg.Command == irc.RPL_TOPIC {
230 ch.Topic = msg.Params[2]
231 } else {
232 ch.Topic = ""
233 }
234 case "TOPIC":
235 if len(msg.Params) < 1 {
236 return newNeedMoreParamsError(msg.Command)
237 }
238 ch, err := c.getChannel(msg.Params[0])
239 if err != nil {
240 return err
241 }
242 if len(msg.Params) > 1 {
243 ch.Topic = msg.Params[1]
244 } else {
245 ch.Topic = ""
246 }
247 case rpl_topicwhotime:
248 if len(msg.Params) < 4 {
249 return newNeedMoreParamsError(msg.Command)
250 }
251 ch, err := c.getChannel(msg.Params[1])
252 if err != nil {
253 return err
254 }
255 ch.TopicWho = msg.Params[2]
256 sec, err := strconv.ParseInt(msg.Params[3], 10, 64)
257 if err != nil {
258 return fmt.Errorf("failed to parse topic time: %v", err)
259 }
260 ch.TopicTime = time.Unix(sec, 0)
261 case irc.RPL_NAMREPLY:
262 if len(msg.Params) < 4 {
263 return newNeedMoreParamsError(msg.Command)
264 }
265 ch, err := c.getChannel(msg.Params[2])
266 if err != nil {
267 return err
268 }
269
270 status, err := parseChannelStatus(msg.Params[1])
271 if err != nil {
272 return err
273 }
274 ch.Status = status
275
276 for _, s := range strings.Split(msg.Params[3], " ") {
277 membership, nick := parseMembershipPrefix(s)
278 ch.Members[nick] = membership
279 }
280 case irc.RPL_ENDOFNAMES:
281 if len(msg.Params) < 2 {
282 return newNeedMoreParamsError(msg.Command)
283 }
284 ch, err := c.getChannel(msg.Params[1])
285 if err != nil {
286 return err
287 }
288
289 if ch.complete {
290 return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
291 }
292 ch.complete = true
293
294 c.user.forEachDownstream(func(dc *downstreamConn) {
295 forwardChannel(dc, ch)
296 })
297 case "PRIVMSG":
298 c.user.forEachDownstream(func(dc *downstreamConn) {
299 dc.messages <- msg
300 })
301 case irc.RPL_YOURHOST, irc.RPL_CREATED:
302 // Ignore
303 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
304 // Ignore
305 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
306 // Ignore
307 case rpl_localusers, rpl_globalusers:
308 // Ignore
309 case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
310 // Ignore
311 default:
312 c.logger.Printf("unhandled upstream message: %v", msg)
313 }
314 return nil
315}
316
317func (c *upstreamConn) readMessages() error {
318 defer c.Close()
319
320 c.nick = c.upstream.Nick
321 c.messages <- &irc.Message{
322 Command: "NICK",
323 Params: []string{c.upstream.Nick},
324 }
325
326 c.messages <- &irc.Message{
327 Command: "USER",
328 Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},
329 }
330
331 for {
332 msg, err := c.irc.ReadMessage()
333 if err == io.EOF {
334 break
335 } else if err != nil {
336 return fmt.Errorf("failed to read IRC command: %v", err)
337 }
338
339 if err := c.handleMessage(msg); err != nil {
340 c.logger.Printf("failed to handle message %q: %v", msg, err)
341 }
342 }
343
344 return c.Close()
345}
Note: See TracBrowser for help on using the repository browser.