source: code/trunk/db.go@ 356

Last change on this file since 356 was 356, checked in by contact, 5 years ago

Fix deadlock in DB.Close

This method was calling itself, instead of the underlying SQLite
database's Close method.

File size: 10.3 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju account, mirroring one row of the User table.
type User struct {
	// Created reports whether a row for this user already exists;
	// StoreUser inserts when it is false and updates when it is true.
	Created bool

	// Username is the User table's primary key. Password holds the
	// hashed password ("" when the column is NULL).
	Username, Password string

	// Admin mirrors the admin column (INTEGER, default 0 in the schema).
	Admin bool
}
18
// SASL holds a network's SASL authentication settings.
type SASL struct {
	// Mechanism names the SASL mechanism; StoreNetwork accepts "",
	// "PLAIN" and "EXTERNAL".
	Mechanism string

	// Plain carries the credentials used by the PLAIN mechanism.
	Plain struct {
		Username, Password string
	}

	// External carries the TLS client certificate used for
	// certificate-based (EXTERNAL) authentication.
	External struct {
		CertBlob    []byte // X.509 certificate, DER-encoded
		PrivKeyBlob []byte // PKCS#8 private key, DER-encoded
	}
}
35
36type Network struct {
37 ID int64
38 Name string
39 Addr string
40 Nick string
41 Username string
42 Realname string
43 Pass string
44 ConnectCommands []string
45 SASL SASL
46}
47
48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is a channel saved for a network, mirroring one row of the
// Channel table.
type Channel struct {
	// ID is the Channel row id; zero means the channel has not been
	// stored yet.
	ID int64
	// Name is unique per network; Key is optional and written as NULL
	// when empty.
	Name, Key string
	// Detached mirrors the detached column (default 0).
	Detached bool
}
61
// schema is the complete, current database layout. upgrade executes it
// verbatim on a fresh database (PRAGMA user_version 0). Existing databases
// are brought up to date through migrations instead, so any change made
// here must also be added as a new migrations entry.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations lists, in order, the schema changes applied to databases
// created by older versions. upgrade stores len(migrations) in PRAGMA
// user_version and, for a database at version v, runs migrations[v:].
// Entries must therefore only ever be appended; editing or reordering
// existing ones would desynchronize deployed databases.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
// DB is soju's persistent store, wrapping a database/sql handle.
type DB struct {
	// lock guards db: read-only queries hold it shared (RLock), while
	// writes, schema upgrades and Close hold it exclusively.
	lock sync.RWMutex
	db   *sql.DB
}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
114 sqlDB, err := sql.Open(driver, source)
115 if err != nil {
116 return nil, err
117 }
118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
130 return db.db.Close()
131}
132
133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr dereferences ptr, mapping a nil pointer (a NULL column
// scanned into *string) to the empty string.
func fromStringPtr(ptr *string) string {
	var value string
	if ptr != nil {
		value = *ptr
	}
	return value
}
181
// toStringPtr returns a pointer to a copy of s, or nil when s is empty, so
// that empty strings are passed to the database as NULL.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
188
189func (db *DB) ListUsers() ([]User, error) {
190 db.lock.RLock()
191 defer db.lock.RUnlock()
192
193 rows, err := db.db.Query("SELECT username, password, admin FROM User")
194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
203 if err := rows.Scan(&user.Username, &password, &user.Admin); err != nil {
204 return nil, err
205 }
206 user.Created = true
207 user.Password = fromStringPtr(password)
208 users = append(users, user)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return users, nil
215}
216
217func (db *DB) GetUser(username string) (*User, error) {
218 db.lock.RLock()
219 defer db.lock.RUnlock()
220
221 user := &User{Created: true, Username: username}
222
223 var password *string
224 row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
225 if err := row.Scan(&password, &user.Admin); err != nil {
226 return nil, err
227 }
228 user.Password = fromStringPtr(password)
229 return user, nil
230}
231
232func (db *DB) StoreUser(user *User) error {
233 db.lock.Lock()
234 defer db.lock.Unlock()
235
236 password := toStringPtr(user.Password)
237
238 var err error
239 if user.Created {
240 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
241 password, user.Admin, user.Username)
242 } else {
243 _, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
244 user.Username, password, user.Admin)
245 if err == nil {
246 user.Created = true
247 }
248 }
249
250 return err
251}
252
253func (db *DB) ListNetworks(username string) ([]Network, error) {
254 db.lock.RLock()
255 defer db.lock.RUnlock()
256
257 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
258 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
259 sasl_external_cert, sasl_external_key
260 FROM Network
261 WHERE user = ?`,
262 username)
263 if err != nil {
264 return nil, err
265 }
266 defer rows.Close()
267
268 var networks []Network
269 for rows.Next() {
270 var net Network
271 var name, username, realname, pass, connectCommands *string
272 var saslMechanism, saslPlainUsername, saslPlainPassword *string
273 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
274 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
275 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
276 if err != nil {
277 return nil, err
278 }
279 net.Name = fromStringPtr(name)
280 net.Username = fromStringPtr(username)
281 net.Realname = fromStringPtr(realname)
282 net.Pass = fromStringPtr(pass)
283 if connectCommands != nil {
284 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
285 }
286 net.SASL.Mechanism = fromStringPtr(saslMechanism)
287 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
288 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
289 networks = append(networks, net)
290 }
291 if err := rows.Err(); err != nil {
292 return nil, err
293 }
294
295 return networks, nil
296}
297
298func (db *DB) StoreNetwork(username string, network *Network) error {
299 db.lock.Lock()
300 defer db.lock.Unlock()
301
302 netName := toStringPtr(network.Name)
303 netUsername := toStringPtr(network.Username)
304 realname := toStringPtr(network.Realname)
305 pass := toStringPtr(network.Pass)
306 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
307
308 var saslMechanism, saslPlainUsername, saslPlainPassword *string
309 if network.SASL.Mechanism != "" {
310 saslMechanism = &network.SASL.Mechanism
311 switch network.SASL.Mechanism {
312 case "PLAIN":
313 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
314 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
315 network.SASL.External.CertBlob = nil
316 network.SASL.External.PrivKeyBlob = nil
317 case "EXTERNAL":
318 // keep saslPlain* nil
319 default:
320 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
321 }
322 }
323
324 var err error
325 if network.ID != 0 {
326 _, err = db.db.Exec(`UPDATE Network
327 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
328 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
329 sasl_external_cert = ?, sasl_external_key = ?
330 WHERE id = ?`,
331 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
332 saslMechanism, saslPlainUsername, saslPlainPassword,
333 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
334 network.ID)
335 } else {
336 var res sql.Result
337 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
338 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
339 sasl_plain_password, sasl_external_cert, sasl_external_key)
340 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
341 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
342 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
343 network.SASL.External.PrivKeyBlob)
344 if err != nil {
345 return err
346 }
347 network.ID, err = res.LastInsertId()
348 }
349 return err
350}
351
352func (db *DB) DeleteNetwork(id int64) error {
353 db.lock.Lock()
354 defer db.lock.Unlock()
355
356 tx, err := db.db.Begin()
357 if err != nil {
358 return err
359 }
360 defer tx.Rollback()
361
362 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
363 if err != nil {
364 return err
365 }
366
367 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
368 if err != nil {
369 return err
370 }
371
372 return tx.Commit()
373}
374
375func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
376 db.lock.RLock()
377 defer db.lock.RUnlock()
378
379 rows, err := db.db.Query(`SELECT id, name, key, detached
380 FROM Channel
381 WHERE network = ?`, networkID)
382 if err != nil {
383 return nil, err
384 }
385 defer rows.Close()
386
387 var channels []Channel
388 for rows.Next() {
389 var ch Channel
390 var key *string
391 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
392 return nil, err
393 }
394 ch.Key = fromStringPtr(key)
395 channels = append(channels, ch)
396 }
397 if err := rows.Err(); err != nil {
398 return nil, err
399 }
400
401 return channels, nil
402}
403
404func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
405 db.lock.Lock()
406 defer db.lock.Unlock()
407
408 key := toStringPtr(ch.Key)
409
410 var err error
411 if ch.ID != 0 {
412 _, err = db.db.Exec(`UPDATE Channel
413 SET network = ?, name = ?, key = ?, detached = ?
414 WHERE id = ?`,
415 networkID, ch.Name, key, ch.Detached, ch.ID)
416 } else {
417 var res sql.Result
418 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
419 VALUES (?, ?, ?, ?)`,
420 networkID, ch.Name, key, ch.Detached)
421 if err != nil {
422 return err
423 }
424 ch.ID, err = res.LastInsertId()
425 }
426 return err
427}
428
429func (db *DB) DeleteChannel(networkID int64, name string) error {
430 db.lock.Lock()
431 defer db.lock.Unlock()
432
433 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
434 return err
435}
Note: See TracBrowser for help on using the repository browser.