1 | package irc
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "context"
|
---|
5 | "errors"
|
---|
6 | "fmt"
|
---|
7 | "io"
|
---|
8 | "sync"
|
---|
9 | "time"
|
---|
10 | )
|
---|
11 |
|
---|
// ClientConfig is a structure used to configure a Client.
type ClientConfig struct {
	// General connection information.

	// Nick is sent in the NICK command during registration and is the
	// initial value reported by CurrentNick once Run starts.
	Nick string

	// Pass, when non-empty, is sent in a PASS command before registration.
	Pass string

	// User is the username sent as the first parameter of the USER command.
	User string

	// Name is the real name sent as the trailing parameter of the USER
	// command.
	Name string

	// Connection settings

	// PingFrequency is how often a PING is sent to the server. Zero or a
	// negative value disables the ping loop entirely.
	PingFrequency time.Duration

	// PingTimeout is how long to wait for the PONG matching a PING before the
	// run is ended with a "ping timeout" error.
	// NOTE(review): this value is passed straight to time.NewTimer, so
	// leaving it at zero while PingFrequency is set appears to make every ping
	// time out immediately — confirm whether a default is applied elsewhere.
	PingTimeout time.Duration

	// SendLimit is how frequent messages can be sent. If this is zero,
	// there will be no limit.
	SendLimit time.Duration

	// SendBurst is the number of messages which can be sent in a burst.
	// It sizes the limiter's token buffer: one send slot is added per
	// SendLimit tick, up to SendBurst of them, and they can then be used back
	// to back. With 0, a tick only grants a send if a writer is already
	// waiting. Only used when SendLimit is set.
	SendBurst int

	// Handler is used for message dispatching. Every message read from the
	// connection is passed to Handler.Handle after the client's internal
	// filters have run; if Handler is nil, messages are not dispatched.
	Handler Handler
}
|
---|
34 |
|
---|
// cap tracks the negotiation state of a single IRCv3 capability. Client keeps
// one per capability name in its caps map.
// NOTE(review): this type name shadows the builtin cap function throughout
// the package; renaming it (e.g. capState) would need a package-wide change.
type cap struct {
	// Requested means that this cap was requested by the user
	// (set by CapRequest).
	Requested bool

	// Required will be true if this cap is non-optional (set by CapRequest;
	// once true it is never cleared by a later CapRequest).
	Required bool

	// Enabled means that this cap was accepted by the server
	// (presumably set by the CAP reply handling outside this view).
	Enabled bool

	// Available means that the server supports this cap
	// (presumably set from the CAP LS reply outside this view).
	Available bool
}
|
---|
48 |
|
---|
// Client is a wrapper around Conn which is designed to make common operations
// much simpler.
type Client struct {
	*Conn

	// rwc is the stream the Conn was built on; RunContext closes it when the
	// run ends.
	rwc io.ReadWriteCloser

	// config is the configuration passed to NewClient.
	config ClientConfig

	// Internal state
	//
	// NOTE(review): several of these fields (limiter, currentNick, caps) are
	// read and written from different goroutines — arbitrary writers, the
	// read loop, the limiter goroutine — without any locking. Confirm with
	// `go test -race`.

	// currentNick is the nick the client believes it is using; RunContext
	// resets it to config.Nick.
	currentNick string

	// limiter hands out send tokens when SendLimit is set. When it is nil
	// (or closed), writes proceed without waiting.
	limiter chan struct{}

	// incomingPongChan is read by the ping loop, which matches each token
	// against the outstanding PINGs; presumably fed by the PONG filter.
	incomingPongChan chan string

	// errChan receives errors from the helper goroutines; RunContext stops
	// on the first one it receives.
	errChan chan error

	// caps holds per-capability negotiation state, keyed by CAP name.
	caps map[string]cap

	// remainingCapResponses counts the CAP replies still expected: one for
	// CAP LS plus one per CAP REQ sent.
	remainingCapResponses int

	// NOTE(review): connected is not read or written anywhere in this view —
	// possibly used by filters elsewhere; confirm or remove.
	connected bool
}
|
---|
65 |
|
---|
66 | // NewClient creates a client given an io stream and a client config.
|
---|
67 | func NewClient(rwc io.ReadWriteCloser, config ClientConfig) *Client {
|
---|
68 | c := &Client{
|
---|
69 | Conn: NewConn(rwc),
|
---|
70 | rwc: rwc,
|
---|
71 | config: config,
|
---|
72 | errChan: make(chan error, 1),
|
---|
73 | caps: make(map[string]cap),
|
---|
74 | }
|
---|
75 |
|
---|
76 | // Replace the writer writeCallback with one of our own
|
---|
77 | c.Conn.Writer.writeCallback = c.writeCallback
|
---|
78 |
|
---|
79 | return c
|
---|
80 | }
|
---|
81 |
|
---|
82 | func (c *Client) writeCallback(w *Writer, line string) error {
|
---|
83 | if c.limiter != nil {
|
---|
84 | <-c.limiter
|
---|
85 | }
|
---|
86 |
|
---|
87 | _, err := w.writer.Write([]byte(line + "\r\n"))
|
---|
88 | if err != nil {
|
---|
89 | c.sendError(err)
|
---|
90 | }
|
---|
91 | return err
|
---|
92 | }
|
---|
93 |
|
---|
94 | // maybeStartLimiter will start a ticker which will limit how quickly messages
|
---|
95 | // can be written to the connection if the SendLimit is set in the config.
|
---|
96 | func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, exiting chan struct{}) {
|
---|
97 | if c.config.SendLimit == 0 {
|
---|
98 | return
|
---|
99 | }
|
---|
100 |
|
---|
101 | wg.Add(1)
|
---|
102 |
|
---|
103 | // If SendBurst is 0, this will be unbuffered, so keep that in mind.
|
---|
104 | c.limiter = make(chan struct{}, c.config.SendBurst)
|
---|
105 | limitTick := time.NewTicker(c.config.SendLimit)
|
---|
106 |
|
---|
107 | go func() {
|
---|
108 | defer wg.Done()
|
---|
109 |
|
---|
110 | var done bool
|
---|
111 | for !done {
|
---|
112 | select {
|
---|
113 | case <-limitTick.C:
|
---|
114 | select {
|
---|
115 | case c.limiter <- struct{}{}:
|
---|
116 | default:
|
---|
117 | }
|
---|
118 | case <-exiting:
|
---|
119 | done = true
|
---|
120 | }
|
---|
121 | }
|
---|
122 |
|
---|
123 | limitTick.Stop()
|
---|
124 | close(c.limiter)
|
---|
125 | c.limiter = nil
|
---|
126 | }()
|
---|
127 | }
|
---|
128 |
|
---|
129 | // maybeStartPingLoop will start a goroutine to send out PING messages at the
|
---|
130 | // PingFrequency in the config if the frequency is not 0.
|
---|
131 | func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) {
|
---|
132 | if c.config.PingFrequency <= 0 {
|
---|
133 | return
|
---|
134 | }
|
---|
135 |
|
---|
136 | wg.Add(1)
|
---|
137 |
|
---|
138 | c.incomingPongChan = make(chan string, 5)
|
---|
139 |
|
---|
140 | go func() {
|
---|
141 | defer wg.Done()
|
---|
142 |
|
---|
143 | pingHandlers := make(map[string]chan struct{})
|
---|
144 | ticker := time.NewTicker(c.config.PingFrequency)
|
---|
145 |
|
---|
146 | defer ticker.Stop()
|
---|
147 |
|
---|
148 | for {
|
---|
149 | select {
|
---|
150 | case <-ticker.C:
|
---|
151 | // Each time we get a tick, we send off a ping and start a
|
---|
152 | // goroutine to handle the pong.
|
---|
153 | timestamp := time.Now().Unix()
|
---|
154 | pongChan := make(chan struct{}, 1)
|
---|
155 | pingHandlers[fmt.Sprintf("%d", timestamp)] = pongChan
|
---|
156 | wg.Add(1)
|
---|
157 | go c.handlePing(timestamp, pongChan, wg, exiting)
|
---|
158 | case data := <-c.incomingPongChan:
|
---|
159 | // Make sure the pong gets routed to the correct
|
---|
160 | // goroutine.
|
---|
161 |
|
---|
162 | c := pingHandlers[data]
|
---|
163 | delete(pingHandlers, data)
|
---|
164 |
|
---|
165 | if c != nil {
|
---|
166 | c <- struct{}{}
|
---|
167 | }
|
---|
168 | case <-exiting:
|
---|
169 | return
|
---|
170 | }
|
---|
171 | }
|
---|
172 | }()
|
---|
173 | }
|
---|
174 |
|
---|
175 | func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) {
|
---|
176 | defer wg.Done()
|
---|
177 |
|
---|
178 | c.Writef("PING :%d", timestamp)
|
---|
179 |
|
---|
180 | timer := time.NewTimer(c.config.PingTimeout)
|
---|
181 | defer timer.Stop()
|
---|
182 |
|
---|
183 | select {
|
---|
184 | case <-timer.C:
|
---|
185 | c.sendError(errors.New("ping timeout"))
|
---|
186 | case <-pongChan:
|
---|
187 | return
|
---|
188 | case <-exiting:
|
---|
189 | return
|
---|
190 | }
|
---|
191 | }
|
---|
192 |
|
---|
193 | // maybeStartCapHandshake will run a CAP LS and all the relevant CAP REQ
|
---|
194 | // commands if there are any CAPs requested.
|
---|
195 | func (c *Client) maybeStartCapHandshake() {
|
---|
196 | if len(c.caps) == 0 {
|
---|
197 | return
|
---|
198 | }
|
---|
199 |
|
---|
200 | c.Write("CAP LS")
|
---|
201 | c.remainingCapResponses = 1 // We count the CAP LS response as a normal response
|
---|
202 | for key, cap := range c.caps {
|
---|
203 | if cap.Requested {
|
---|
204 | c.Writef("CAP REQ :%s", key)
|
---|
205 | c.remainingCapResponses++
|
---|
206 | }
|
---|
207 | }
|
---|
208 | }
|
---|
209 |
|
---|
210 | // CapRequest allows you to request IRCv3 capabilities from the server during
|
---|
211 | // the handshake. The behavior is undefined if this is called before the
|
---|
212 | // handshake completes so it is recommended that this be called before Run. If
|
---|
213 | // the CAP is marked as required, the client will exit if that CAP could not be
|
---|
214 | // negotiated during the handshake.
|
---|
215 | func (c *Client) CapRequest(capName string, required bool) {
|
---|
216 | cap := c.caps[capName]
|
---|
217 | cap.Requested = true
|
---|
218 | cap.Required = cap.Required || required
|
---|
219 | c.caps[capName] = cap
|
---|
220 | }
|
---|
221 |
|
---|
222 | // CapEnabled allows you to check if a CAP is enabled for this connection. Note
|
---|
223 | // that it will not be populated until after the CAP handshake is done, so it is
|
---|
224 | // recommended to wait to check this until after a message like 001.
|
---|
225 | func (c *Client) CapEnabled(capName string) bool {
|
---|
226 | return c.caps[capName].Enabled
|
---|
227 | }
|
---|
228 |
|
---|
229 | // CapAvailable allows you to check if a CAP is available on this server. Note
|
---|
230 | // that it will not be populated until after the CAP handshake is done, so it is
|
---|
231 | // recommended to wait to check this until after a message like 001.
|
---|
232 | func (c *Client) CapAvailable(capName string) bool {
|
---|
233 | return c.caps[capName].Available
|
---|
234 | }
|
---|
235 |
|
---|
236 | func (c *Client) sendError(err error) {
|
---|
237 | select {
|
---|
238 | case c.errChan <- err:
|
---|
239 | default:
|
---|
240 | }
|
---|
241 | }
|
---|
242 |
|
---|
243 | func (c *Client) startReadLoop(wg *sync.WaitGroup, exiting chan struct{}) {
|
---|
244 | wg.Add(1)
|
---|
245 |
|
---|
246 | go func() {
|
---|
247 | defer wg.Done()
|
---|
248 |
|
---|
249 | for {
|
---|
250 | select {
|
---|
251 | case <-exiting:
|
---|
252 | return
|
---|
253 | default:
|
---|
254 | m, err := c.ReadMessage()
|
---|
255 | if err != nil {
|
---|
256 | c.sendError(err)
|
---|
257 | break
|
---|
258 | }
|
---|
259 |
|
---|
260 | if f, ok := clientFilters[m.Command]; ok {
|
---|
261 | f(c, m)
|
---|
262 | }
|
---|
263 |
|
---|
264 | if c.config.Handler != nil {
|
---|
265 | c.config.Handler.Handle(c, m)
|
---|
266 | }
|
---|
267 | }
|
---|
268 | }
|
---|
269 | }()
|
---|
270 | }
|
---|
271 |
|
---|
272 | // Run starts the main loop for this IRC connection. Note that it may break in
|
---|
273 | // strange and unexpected ways if it is called again before the first connection
|
---|
274 | // exits.
|
---|
275 | func (c *Client) Run() error {
|
---|
276 | return c.RunContext(context.TODO())
|
---|
277 | }
|
---|
278 |
|
---|
279 | // RunContext is the same as Run but a context.Context can be passed in for
|
---|
280 | // cancelation.
|
---|
281 | func (c *Client) RunContext(ctx context.Context) error {
|
---|
282 | // exiting is used by the main goroutine here to ensure any sub-goroutines
|
---|
283 | // get closed when exiting.
|
---|
284 | exiting := make(chan struct{})
|
---|
285 | var wg sync.WaitGroup
|
---|
286 |
|
---|
287 | c.maybeStartLimiter(&wg, exiting)
|
---|
288 | c.maybeStartPingLoop(&wg, exiting)
|
---|
289 |
|
---|
290 | c.currentNick = c.config.Nick
|
---|
291 |
|
---|
292 | if c.config.Pass != "" {
|
---|
293 | c.Writef("PASS :%s", c.config.Pass)
|
---|
294 | }
|
---|
295 |
|
---|
296 | c.maybeStartCapHandshake()
|
---|
297 |
|
---|
298 | // This feels wrong because it results in CAP LS, CAP REQ, NICK, USER, CAP
|
---|
299 | // END, but it works and lets us keep the code a bit simpler.
|
---|
300 | c.Writef("NICK :%s", c.config.Nick)
|
---|
301 | c.Writef("USER %s 0 * :%s", c.config.User, c.config.Name)
|
---|
302 |
|
---|
303 | // Now that the handshake is pretty much done, we can start listening for
|
---|
304 | // messages.
|
---|
305 | c.startReadLoop(&wg, exiting)
|
---|
306 |
|
---|
307 | // Wait for an error from any goroutine or for the context to time out, then
|
---|
308 | // signal we're exiting and wait for the goroutines to exit.
|
---|
309 | var err error
|
---|
310 | select {
|
---|
311 | case err = <-c.errChan:
|
---|
312 | case <-ctx.Done():
|
---|
313 | }
|
---|
314 |
|
---|
315 | close(exiting)
|
---|
316 | c.rwc.Close()
|
---|
317 | wg.Wait()
|
---|
318 |
|
---|
319 | return err
|
---|
320 | }
|
---|
321 |
|
---|
322 | // CurrentNick returns what the nick of the client is known to be at this point
|
---|
323 | // in time.
|
---|
324 | func (c *Client) CurrentNick() string {
|
---|
325 | return c.currentNick
|
---|
326 | }
|
---|
327 |
|
---|
328 | // FromChannel takes a Message representing a PRIVMSG and returns if that
|
---|
329 | // message came from a channel or directly from a user.
|
---|
330 | func (c *Client) FromChannel(m *Message) bool {
|
---|
331 | if len(m.Params) < 1 {
|
---|
332 | return false
|
---|
333 | }
|
---|
334 |
|
---|
335 | // The first param is the target, so if this doesn't match the current nick,
|
---|
336 | // the message came from a channel.
|
---|
337 | return m.Params[0] != c.currentNick
|
---|
338 | }
|
---|