source: code/trunk/user.go@ 142

Last change to this file as of revision 142 was revision 137, checked in by contact 5 years ago.

Add user.{add,remove}Downstream

File size: 4.4 KB
RevLine 
[101]1package soju
2
3import (
4 "sync"
5 "time"
[103]6
7 "gopkg.in/irc.v3"
[101]8)
9
[103]10type upstreamIncomingMessage struct {
11 msg *irc.Message
12 uc *upstreamConn
13}
14
15type downstreamIncomingMessage struct {
16 msg *irc.Message
17 dc *downstreamConn
18}
19
[101]20type network struct {
21 Network
22 user *user
[131]23
24 lock sync.Mutex
25 conn *upstreamConn
26 history map[string]uint64
[101]27}
28
29func newNetwork(user *user, record *Network) *network {
30 return &network{
31 Network: *record,
32 user: user,
[131]33 history: make(map[string]uint64),
[101]34 }
35}
36
37func (net *network) run() {
38 var lastTry time.Time
39 for {
40 if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay {
41 delay := retryConnectMinDelay - dur
42 net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
43 time.Sleep(delay)
44 }
45 lastTry = time.Now()
46
47 uc, err := connectToUpstream(net)
48 if err != nil {
49 net.user.srv.Logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
50 continue
51 }
52
53 uc.register()
54
[131]55 net.lock.Lock()
[101]56 net.conn = uc
[131]57 net.lock.Unlock()
[101]58
[103]59 if err := uc.readMessages(net.user.upstreamIncoming); err != nil {
[101]60 uc.logger.Printf("failed to handle messages: %v", err)
61 }
62 uc.Close()
63
[131]64 net.lock.Lock()
[101]65 net.conn = nil
[131]66 net.lock.Unlock()
[101]67 }
68}
69
[136]70func (net *network) upstream() *upstreamConn {
71 net.lock.Lock()
72 defer net.lock.Unlock()
73 return net.conn
74}
75
[101]76type user struct {
77 User
78 srv *Server
79
[103]80 upstreamIncoming chan upstreamIncomingMessage
81 downstreamIncoming chan downstreamIncomingMessage
82
[101]83 lock sync.Mutex
84 networks []*network
85 downstreamConns []*downstreamConn
86}
87
88func newUser(srv *Server, record *User) *user {
89 return &user{
[103]90 User: *record,
91 srv: srv,
92 upstreamIncoming: make(chan upstreamIncomingMessage, 64),
93 downstreamIncoming: make(chan downstreamIncomingMessage, 64),
[101]94 }
95}
96
97func (u *user) forEachNetwork(f func(*network)) {
98 u.lock.Lock()
99 for _, network := range u.networks {
100 f(network)
101 }
102 u.lock.Unlock()
103}
104
105func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
106 u.lock.Lock()
107 for _, network := range u.networks {
[136]108 uc := network.upstream()
[101]109 if uc == nil || !uc.registered || uc.closed {
110 continue
111 }
112 f(uc)
113 }
114 u.lock.Unlock()
115}
116
117func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
118 u.lock.Lock()
119 for _, dc := range u.downstreamConns {
120 f(dc)
121 }
122 u.lock.Unlock()
123}
124
125func (u *user) getNetwork(name string) *network {
126 for _, network := range u.networks {
127 if network.Addr == name {
128 return network
129 }
130 }
131 return nil
132}
133
134func (u *user) run() {
135 networks, err := u.srv.db.ListNetworks(u.Username)
136 if err != nil {
137 u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err)
138 return
139 }
140
141 u.lock.Lock()
142 for _, record := range networks {
143 network := newNetwork(u, &record)
144 u.networks = append(u.networks, network)
145
146 go network.run()
147 }
148 u.lock.Unlock()
[103]149
150 for {
151 select {
152 case upstreamMsg := <-u.upstreamIncoming:
153 msg, uc := upstreamMsg.msg, upstreamMsg.uc
[133]154 if uc.closed {
155 uc.logger.Printf("ignoring message on closed connection: %v", msg)
156 break
157 }
[103]158 if err := uc.handleMessage(msg); err != nil {
159 uc.logger.Printf("failed to handle message %q: %v", msg, err)
160 }
161 case downstreamMsg := <-u.downstreamIncoming:
162 msg, dc := downstreamMsg.msg, downstreamMsg.dc
[133]163 if dc.isClosed() {
164 dc.logger.Printf("ignoring message on closed connection: %v", msg)
165 break
166 }
[103]167 err := dc.handleMessage(msg)
168 if ircErr, ok := err.(ircError); ok {
169 ircErr.Message.Prefix = dc.srv.prefix()
170 dc.SendMessage(ircErr.Message)
171 } else if err != nil {
172 dc.logger.Printf("failed to handle message %q: %v", msg, err)
173 dc.Close()
174 }
175 }
176 }
[101]177}
178
[137]179func (u *user) addDownstream(dc *downstreamConn) (first bool) {
180 u.lock.Lock()
181 first = len(dc.user.downstreamConns) == 0
182 u.downstreamConns = append(u.downstreamConns, dc)
183 u.lock.Unlock()
184 return first
185}
186
187func (u *user) removeDownstream(dc *downstreamConn) {
188 u.lock.Lock()
189 for i := range u.downstreamConns {
190 if u.downstreamConns[i] == dc {
191 u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
192 break
193 }
194 }
195 u.lock.Unlock()
196}
197
[120]198func (u *user) createNetwork(net *Network) (*network, error) {
199 network := newNetwork(u, net)
[101]200 err := u.srv.db.StoreNetwork(u.Username, &network.Network)
201 if err != nil {
202 return nil, err
203 }
204 u.lock.Lock()
205 u.networks = append(u.networks, network)
206 u.lock.Unlock()
207 go network.run()
208 return network, nil
209}
Note: See TracBrowser for help on using the repository browser.