source: code/trunk/upstream.go@ 33

Last change on this file since 33 was 33, checked in by contact, 5 years ago

Use a dedicated goroutine to write upstream messages

File size: 5.8 KB
Line 
1package jounce
2
3import (
4 "crypto/tls"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "strings"
10 "time"
11
12 "gopkg.in/irc.v3"
13)
14
15type upstreamChannel struct {
16 Name string
17 Topic string
18 TopicWho string
19 TopicTime time.Time
20 Status channelStatus
21 Members map[string]membership
22 complete bool
23}
24
25type upstreamConn struct {
26 upstream *Upstream
27 logger Logger
28 net net.Conn
29 irc *irc.Conn
30 srv *Server
31 messages chan<- *irc.Message
32
33 serverName string
34 availableUserModes string
35 availableChannelModes string
36 channelModesWithParam string
37
38 registered bool
39 closed bool
40 modes modeSet
41 channels map[string]*upstreamChannel
42}
43
44func connectToUpstream(s *Server, upstream *Upstream) (*upstreamConn, error) {
45 logger := &prefixLogger{s.Logger, fmt.Sprintf("upstream %q: ", upstream.Addr)}
46 logger.Printf("connecting to server")
47
48 netConn, err := tls.Dial("tcp", upstream.Addr, nil)
49 if err != nil {
50 return nil, fmt.Errorf("failed to dial %q: %v", upstream.Addr, err)
51 }
52
53 msgs := make(chan *irc.Message, 64)
54 conn := &upstreamConn{
55 upstream: upstream,
56 logger: logger,
57 net: netConn,
58 irc: irc.NewConn(netConn),
59 srv: s,
60 messages: msgs,
61 channels: make(map[string]*upstreamChannel),
62 }
63
64 go func() {
65 for msg := range msgs {
66 if err := conn.irc.WriteMessage(msg); err != nil {
67 conn.logger.Printf("failed to write message: %v", err)
68 }
69 }
70 }()
71
72 return conn, nil
73}
74
75func (c *upstreamConn) Close() error {
76 if c.closed {
77 return fmt.Errorf("upstream connection already closed")
78 }
79 if err := c.net.Close(); err != nil {
80 return err
81 }
82 close(c.messages)
83 c.closed = true
84 return nil
85}
86
87func (c *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
88 ch, ok := c.channels[name]
89 if !ok {
90 return nil, fmt.Errorf("unknown channel %q", name)
91 }
92 return ch, nil
93}
94
95func (c *upstreamConn) handleMessage(msg *irc.Message) error {
96 switch msg.Command {
97 case "PING":
98 // TODO: handle params
99 c.messages <- &irc.Message{
100 Command: "PONG",
101 Params: []string{c.srv.Hostname},
102 }
103 return nil
104 case "MODE":
105 if len(msg.Params) < 2 {
106 return newNeedMoreParamsError(msg.Command)
107 }
108 if nick := msg.Params[0]; nick != c.upstream.Nick {
109 return fmt.Errorf("received MODE message for unknow nick %q", nick)
110 }
111 return c.modes.Apply(msg.Params[1])
112 case "NOTICE":
113 c.logger.Print(msg)
114 case irc.RPL_WELCOME:
115 c.registered = true
116 c.logger.Printf("connection registered")
117
118 for _, ch := range c.upstream.Channels {
119 c.messages <- &irc.Message{
120 Command: "JOIN",
121 Params: []string{ch},
122 }
123 }
124 case irc.RPL_MYINFO:
125 if len(msg.Params) < 5 {
126 return newNeedMoreParamsError(msg.Command)
127 }
128 c.serverName = msg.Params[1]
129 c.availableUserModes = msg.Params[3]
130 c.availableChannelModes = msg.Params[4]
131 if len(msg.Params) > 5 {
132 c.channelModesWithParam = msg.Params[5]
133 }
134 case "JOIN":
135 if len(msg.Params) < 1 {
136 return newNeedMoreParamsError(msg.Command)
137 }
138 for _, ch := range strings.Split(msg.Params[0], ",") {
139 c.logger.Printf("joined channel %q", ch)
140 c.channels[ch] = &upstreamChannel{
141 Name: ch,
142 Members: make(map[string]membership),
143 }
144 }
145 case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
146 if len(msg.Params) < 3 {
147 return newNeedMoreParamsError(msg.Command)
148 }
149 ch, err := c.getChannel(msg.Params[1])
150 if err != nil {
151 return err
152 }
153 if msg.Command == irc.RPL_TOPIC {
154 ch.Topic = msg.Params[2]
155 } else {
156 ch.Topic = ""
157 }
158 case "TOPIC":
159 if len(msg.Params) < 1 {
160 return newNeedMoreParamsError(msg.Command)
161 }
162 ch, err := c.getChannel(msg.Params[0])
163 if err != nil {
164 return err
165 }
166 if len(msg.Params) > 1 {
167 ch.Topic = msg.Params[1]
168 } else {
169 ch.Topic = ""
170 }
171 case rpl_topicwhotime:
172 if len(msg.Params) < 4 {
173 return newNeedMoreParamsError(msg.Command)
174 }
175 ch, err := c.getChannel(msg.Params[1])
176 if err != nil {
177 return err
178 }
179 ch.TopicWho = msg.Params[2]
180 sec, err := strconv.ParseInt(msg.Params[3], 10, 64)
181 if err != nil {
182 return fmt.Errorf("failed to parse topic time: %v", err)
183 }
184 ch.TopicTime = time.Unix(sec, 0)
185 case irc.RPL_NAMREPLY:
186 if len(msg.Params) < 4 {
187 return newNeedMoreParamsError(msg.Command)
188 }
189 ch, err := c.getChannel(msg.Params[2])
190 if err != nil {
191 return err
192 }
193
194 status, err := parseChannelStatus(msg.Params[1])
195 if err != nil {
196 return err
197 }
198 ch.Status = status
199
200 for _, s := range strings.Split(msg.Params[3], " ") {
201 membership, nick := parseMembershipPrefix(s)
202 ch.Members[nick] = membership
203 }
204 case irc.RPL_ENDOFNAMES:
205 if len(msg.Params) < 2 {
206 return newNeedMoreParamsError(msg.Command)
207 }
208 ch, err := c.getChannel(msg.Params[1])
209 if err != nil {
210 return err
211 }
212
213 ch.complete = true
214
215 c.srv.lock.Lock()
216 for _, dc := range c.srv.downstreamConns {
217 forwardChannel(dc, ch)
218 }
219 c.srv.lock.Unlock()
220 case irc.RPL_YOURHOST, irc.RPL_CREATED:
221 // Ignore
222 case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
223 // Ignore
224 case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
225 // Ignore
226 case rpl_localusers, rpl_globalusers:
227 // Ignore
228 case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
229 // Ignore
230 default:
231 c.logger.Printf("unhandled upstream message: %v", msg)
232 }
233 return nil
234}
235
236func (c *upstreamConn) readMessages() error {
237 defer c.Close()
238
239 c.messages <- &irc.Message{
240 Command: "NICK",
241 Params: []string{c.upstream.Nick},
242 }
243
244 c.messages <- &irc.Message{
245 Command: "USER",
246 Params: []string{c.upstream.Username, "0", "*", c.upstream.Realname},
247 }
248
249 for {
250 msg, err := c.irc.ReadMessage()
251 if err == io.EOF {
252 break
253 } else if err != nil {
254 return fmt.Errorf("failed to read IRC command: %v", err)
255 }
256
257 if err := c.handleMessage(msg); err != nil {
258 c.logger.Printf("failed to handle message %q: %v", msg, err)
259 }
260 }
261
262 return c.Close()
263}
Note: See TracBrowser for help on using the repository browser.