source: code/trunk/db.go@ 175

Last change on this file since 175 was 173, checked in by contact, 5 years ago

Stop accessing user data in downstreamConn.authenticate

This becomes racy once user.Password is updated on-the-fly.

File size: 6.1 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[77]6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account record as stored in the User table.
type User struct {
	// Username identifies the account; it is the lookup key used by GetUser.
	Username string
	// Password is the hashed password. It is empty when the stored value is
	// NULL.
	Password string // hashed
}
15
// SASL holds the SASL authentication settings attached to a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name. An empty string means no SASL
	// settings are configured.
	Mechanism string

	// Plain holds the credentials used with the "PLAIN" mechanism.
	Plain struct {
		Username string
		Password string
	}
}
24
[77]25type Network struct {
26 ID int64
[118]27 Name string
[77]28 Addr string
29 Nick string
30 Username string
31 Realname string
[93]32 Pass string
[95]33 SASL SASL
[77]34}
35
[149]36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel record as stored in the Channel table.
type Channel struct {
	// ID is the database row ID. Zero means the channel has not been stored
	// yet; StoreChannel fills it in on insert.
	ID   int64
	Name string
	// Key is optional: an empty string is stored as NULL.
	Key string
}
48
// DB wraps an SQL database handle and serializes access to it.
//
// Because it contains a lock, a DB must not be copied; use it through a
// pointer.
type DB struct {
	// lock is read-locked by the query methods and write-locked by the
	// methods that modify the database or close it.
	lock sync.RWMutex
	db   *sql.DB
}
53
54func OpenSQLDB(driver, source string) (*DB, error) {
55 db, err := sql.Open(driver, source)
56 if err != nil {
57 return nil, err
58 }
59 return &DB{db: db}, nil
60}
61
62func (db *DB) Close() error {
63 db.lock.Lock()
64 defer db.lock.Unlock()
65 return db.Close()
66}
67
// fromStringPtr converts a nullable string to a plain string: nil becomes
// the empty string, anything else is dereferenced.
func fromStringPtr(ptr *string) string {
	if ptr == nil {
		return ""
	}
	return *ptr
}
74
// toStringPtr converts a plain string to a nullable one: the empty string
// becomes nil, anything else becomes a pointer to a copy of s.
func toStringPtr(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
81
[77]82func (db *DB) ListUsers() ([]User, error) {
[81]83 db.lock.RLock()
84 defer db.lock.RUnlock()
[77]85
86 rows, err := db.db.Query("SELECT username, password FROM User")
87 if err != nil {
88 return nil, err
89 }
90 defer rows.Close()
91
92 var users []User
93 for rows.Next() {
94 var user User
95 var password *string
96 if err := rows.Scan(&user.Username, &password); err != nil {
97 return nil, err
98 }
[95]99 user.Password = fromStringPtr(password)
[77]100 users = append(users, user)
101 }
102 if err := rows.Err(); err != nil {
103 return nil, err
104 }
105
106 return users, nil
107}
108
[173]109func (db *DB) GetUser(username string) (*User, error) {
110 db.lock.RLock()
111 defer db.lock.RUnlock()
112
113 user := &User{Username: username}
114
115 var password *string
116 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
117 if err := row.Scan(&password); err != nil {
118 return nil, err
119 }
120 user.Password = fromStringPtr(password)
121 return user, nil
122}
123
[84]124func (db *DB) CreateUser(user *User) error {
125 db.lock.Lock()
126 defer db.lock.Unlock()
127
[95]128 password := toStringPtr(user.Password)
[89]129 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
130 return err
[84]131}
132
[77]133func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]134 db.lock.RLock()
135 defer db.lock.RUnlock()
[77]136
[118]137 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]138 sasl_mechanism, sasl_plain_username, sasl_plain_password
139 FROM Network
140 WHERE user = ?`,
141 username)
[77]142 if err != nil {
143 return nil, err
144 }
145 defer rows.Close()
146
147 var networks []Network
148 for rows.Next() {
149 var net Network
[118]150 var name, username, realname, pass *string
[95]151 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]152 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]153 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
154 if err != nil {
[77]155 return nil, err
156 }
[118]157 net.Name = fromStringPtr(name)
[95]158 net.Username = fromStringPtr(username)
159 net.Realname = fromStringPtr(realname)
160 net.Pass = fromStringPtr(pass)
161 net.SASL.Mechanism = fromStringPtr(saslMechanism)
162 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
163 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]164 networks = append(networks, net)
165 }
166 if err := rows.Err(); err != nil {
167 return nil, err
168 }
169
170 return networks, nil
171}
172
[90]173func (db *DB) StoreNetwork(username string, network *Network) error {
174 db.lock.Lock()
175 defer db.lock.Unlock()
176
[118]177 netName := toStringPtr(network.Name)
[95]178 netUsername := toStringPtr(network.Username)
179 realname := toStringPtr(network.Realname)
180 pass := toStringPtr(network.Pass)
181
182 var saslMechanism, saslPlainUsername, saslPlainPassword *string
183 if network.SASL.Mechanism != "" {
184 saslMechanism = &network.SASL.Mechanism
185 switch network.SASL.Mechanism {
186 case "PLAIN":
187 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
188 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]189 default:
190 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]191 }
[90]192 }
193
194 var err error
195 if network.ID != 0 {
[93]196 _, err = db.db.Exec(`UPDATE Network
[118]197 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]198 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]199 WHERE id = ?`,
[118]200 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]201 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]202 } else {
203 var res sql.Result
[118]204 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]205 realname, pass, sasl_mechanism, sasl_plain_username,
206 sasl_plain_password)
[118]207 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
208 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]209 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]210 if err != nil {
211 return err
212 }
213 network.ID, err = res.LastInsertId()
214 }
215 return err
216}
217
[77]218func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]219 db.lock.RLock()
220 defer db.lock.RUnlock()
[77]221
[146]222 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]223 if err != nil {
224 return nil, err
225 }
226 defer rows.Close()
227
228 var channels []Channel
229 for rows.Next() {
230 var ch Channel
[146]231 var key *string
232 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]233 return nil, err
234 }
[146]235 ch.Key = fromStringPtr(key)
[77]236 channels = append(channels, ch)
237 }
238 if err := rows.Err(); err != nil {
239 return nil, err
240 }
241
242 return channels, nil
243}
[89]244
245func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
246 db.lock.Lock()
247 defer db.lock.Unlock()
248
[146]249 key := toStringPtr(ch.Key)
[149]250
251 var err error
252 if ch.ID != 0 {
253 _, err = db.db.Exec(`UPDATE Channel
254 SET network = ?, name = ?, key = ?
255 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
256 } else {
257 var res sql.Result
258 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
259 VALUES (?, ?, ?)`, networkID, ch.Name, key)
260 if err != nil {
261 return err
262 }
263 ch.ID, err = res.LastInsertId()
264 }
[89]265 return err
266}
267
268func (db *DB) DeleteChannel(networkID int64, name string) error {
269 db.lock.Lock()
270 defer db.lock.Unlock()
271
272 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
273 return err
274}
Note: See TracBrowser for help on using the repository browser.