source: code/trunk/user.go@ 103

Last change on this file since 103 was 103, checked in by contact, 5 years ago

Per-user dispatcher goroutine

This allows message handlers to read upstream/downstream connection
information without causing race conditions.

References: https://todo.sr.ht/~emersion/soju/1

File size: 3.6 KB
Rev Line
[101]1package soju
2
3import (
4 "sync"
5 "time"
[103]6
7 "gopkg.in/irc.v3"
[101]8)
9
[103]10type upstreamIncomingMessage struct {
11 msg *irc.Message
12 uc *upstreamConn
13}
14
15type downstreamIncomingMessage struct {
16 msg *irc.Message
17 dc *downstreamConn
18}
19
[101]20type network struct {
21 Network
22 user *user
23 conn *upstreamConn
24}
25
26func newNetwork(user *user, record *Network) *network {
27 return &network{
28 Network: *record,
29 user: user,
30 }
31}
32
33func (net *network) run() {
34 var lastTry time.Time
35 for {
36 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
37 delay := retryConnectMinDelay - dur
38 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
39 time.Sleep(delay)
40 }
41 lastTry = time.Now()
42
43 uc, err := connectToUpstream(net)
44 if err != nil {
45 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
46 continue
47 }
48
49 uc.register()
50
51 net.user.lock.Lock()
52 net.conn = uc
53 net.user.lock.Unlock()
54
[103]55 if err := uc.readMessages(net.user.upstreamIncoming); err != nil {
[101]56 uc.logger.Printf("failed to handle messages: %v", err)
57 }
58 uc.Close()
59
60 net.user.lock.Lock()
61 net.conn = nil
62 net.user.lock.Unlock()
63 }
64}
65
66type user struct {
67 User
68 srv *Server
69
[103]70 upstreamIncoming chan upstreamIncomingMessage
71 downstreamIncoming chan downstreamIncomingMessage
72
[101]73 lock sync.Mutex
74 networks []*network
75 downstreamConns []*downstreamConn
76}
77
78func newUser(srv *Server, record *User) *user {
79 return &user{
[103]80 User: *record,
81 srv: srv,
82 upstreamIncoming: make(chan upstreamIncomingMessage, 64),
83 downstreamIncoming: make(chan downstreamIncomingMessage, 64),
[101]84 }
85}
86
87func (u *user) forEachNetwork(f func(*network)) {
88 u.lock.Lock()
89 for _, network := range u.networks {
90 f(network)
91 }
92 u.lock.Unlock()
93}
94
95func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
96 u.lock.Lock()
97 for _, network := range u.networks {
98 uc := network.conn
99 if uc == nil || !uc.registered || uc.closed {
100 continue
101 }
102 f(uc)
103 }
104 u.lock.Unlock()
105}
106
107func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
108 u.lock.Lock()
109 for _, dc := range u.downstreamConns {
110 f(dc)
111 }
112 u.lock.Unlock()
113}
114
115func (u *user) getNetwork(name string) *network {
116 for _, network := range u.networks {
117 if network.Addr == name {
118 return network
119 }
120 }
121 return nil
122}
123
124func (u *user) run() {
125 networks, err := u.srv.db.ListNetworks(u.Username)
126 if err != nil {
127 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
128 return
129 }
130
131 u.lock.Lock()
132 for _, record := range networks {
133 network := newNetwork(u, &record)
134 u.networks = append(u.networks, network)
135
136 go network.run()
137 }
138 u.lock.Unlock()
[103]139
140 for {
141 select {
142 case upstreamMsg := <-u.upstreamIncoming:
143 msg, uc := upstreamMsg.msg, upstreamMsg.uc
144 if err := uc.handleMessage(msg); err != nil {
145 uc.logger.Printf("failed to handle message %q: %v", msg, err)
146 }
147 case downstreamMsg := <-u.downstreamIncoming:
148 msg, dc := downstreamMsg.msg, downstreamMsg.dc
149 err := dc.handleMessage(msg)
150 if ircErr, ok := err.(ircError); ok {
151 ircErr.Message.Prefix = dc.srv.prefix()
152 dc.SendMessage(ircErr.Message)
153 } else if err != nil {
154 dc.logger.Printf("failed to handle message %q: %v", msg, err)
155 dc.Close()
156 }
157 }
158 }
[101]159}
160
161func (u *user) createNetwork(addr, nick string) (*network, error) {
162 network := newNetwork(u, &Network{
163 Addr: addr,
164 Nick: nick,
165 })
166 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
167 if err != nil {
168 return nil, err
169 }
170 u.lock.Lock()
171 u.networks = append(u.networks, network)
172 u.lock.Unlock()
173 go network.run()
174 return network, nil
175}
Note: See TracBrowser for help on using the repository browser.