source: code/trunk/db_sqlite.go @ 540

Last change to this file as of revision 540 was revision 531, checked in by sir 4 years ago.

db: refactor into interface

This refactors the SQLite-specific bits into db_sqlite.go. A future
patch will add PostgreSQL support.

File size: 14.7 KB
Rev Line
[531]1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema is the complete, current SQLite schema. upgrade executes it
// verbatim to initialize a database whose user_version is 0, and then records
// the version as len(sqliteMigrations). This schema must therefore always match
// the result of applying every entry of sqliteMigrations to an older database.
//
// The foreign keys below declare no ON DELETE action, so dependent rows have
// to be deleted explicitly before their parent row. Nothing in this file
// enables PRAGMA foreign_keys. Whether SQLite actually enforces these
// references depends on how the connection is opened (TODO: confirm).
const sqliteSchema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
67
// sqliteMigrations lists the schema migrations in application order.
//
// Entry i upgrades a database whose user_version is i. upgrade runs the
// entries from the stored version up to the end in one transaction, then sets
// user_version to len(sqliteMigrations). Stored databases record how many
// entries have been applied. Entries must therefore only ever be appended:
// never edit, remove, or reorder them. sqliteSchema must be updated to match
// each new entry.
//
// Entries containing several statements are passed to a single tx.Exec call.
// This presumably relies on the driver executing all of them; verify if the
// driver changes.
var sqliteMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// Rebuild User with an explicit id column, copying each existing row's
	// rowid into it.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// Rebuild Network so that Network.user holds User.id instead of the
	// username. The inner JOIN drops any network whose user column matches no
	// existing username.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
	"ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
}
135
// SqliteDB is the SQLite-backed Database implementation returned by
// OpenSqliteDB.
type SqliteDB struct {
	// lock guards db. Read-only methods hold it for reading. Methods that
	// modify data, run the schema upgrade, or close the handle hold it for
	// writing.
	lock sync.RWMutex
	db   *sql.DB
}
140
141func OpenSqliteDB(driver, source string) (Database, error) {
142 sqlSqliteDB, err := sql.Open(driver, source)
143 if err != nil {
144 return nil, err
145 }
146
147 db := &SqliteDB{db: sqlSqliteDB}
148 if err := db.upgrade(); err != nil {
149 return nil, err
150 }
151
152 return db, nil
153}
154
155func (db *SqliteDB) Close() error {
156 db.lock.Lock()
157 defer db.lock.Unlock()
158 return db.db.Close()
159}
160
161func (db *SqliteDB) upgrade() error {
162 db.lock.Lock()
163 defer db.lock.Unlock()
164
165 var version int
166 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
167 return fmt.Errorf("failed to query schema version: %v", err)
168 }
169
170 if version == len(sqliteMigrations) {
171 return nil
172 } else if version > len(sqliteMigrations) {
173 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
174 }
175
176 tx, err := db.db.Begin()
177 if err != nil {
178 return err
179 }
180 defer tx.Rollback()
181
182 if version == 0 {
183 if _, err := tx.Exec(sqliteSchema); err != nil {
184 return fmt.Errorf("failed to initialize schema: %v", err)
185 }
186 } else {
187 for i := version; i < len(sqliteMigrations); i++ {
188 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
189 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
190 }
191 }
192 }
193
194 // For some reason prepared statements don't work here
195 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
196 if err != nil {
197 return fmt.Errorf("failed to bump schema version: %v", err)
198 }
199
200 return tx.Commit()
201}
202
// toNullString wraps s for use as a nullable SQL text value. The empty string
// becomes SQL NULL (Valid == false), and any other string is stored as-is.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
209
210func (db *SqliteDB) ListUsers() ([]User, error) {
211 db.lock.RLock()
212 defer db.lock.RUnlock()
213
214 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
215 if err != nil {
216 return nil, err
217 }
218 defer rows.Close()
219
220 var users []User
221 for rows.Next() {
222 var user User
223 var password sql.NullString
224 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
225 return nil, err
226 }
227 user.Password = password.String
228 users = append(users, user)
229 }
230 if err := rows.Err(); err != nil {
231 return nil, err
232 }
233
234 return users, nil
235}
236
237func (db *SqliteDB) GetUser(username string) (*User, error) {
238 db.lock.RLock()
239 defer db.lock.RUnlock()
240
241 user := &User{Username: username}
242
243 var password sql.NullString
244 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
245 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
246 return nil, err
247 }
248 user.Password = password.String
249 return user, nil
250}
251
252func (db *SqliteDB) StoreUser(user *User) error {
253 db.lock.Lock()
254 defer db.lock.Unlock()
255
256 password := toNullString(user.Password)
257
258 var err error
259 if user.ID != 0 {
260 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
261 password, user.Admin, user.Username)
262 } else {
263 var res sql.Result
264 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
265 user.Username, password, user.Admin)
266 if err != nil {
267 return err
268 }
269 user.ID, err = res.LastInsertId()
270 }
271
272 return err
273}
274
275func (db *SqliteDB) DeleteUser(id int64) error {
276 db.lock.Lock()
277 defer db.lock.Unlock()
278
279 tx, err := db.db.Begin()
280 if err != nil {
281 return err
282 }
283 defer tx.Rollback()
284
285 _, err = tx.Exec(`DELETE FROM Channel
286 WHERE id IN (
287 SELECT Channel.id
288 FROM Channel
289 JOIN Network ON Channel.network = Network.id
290 WHERE Network.user = ?
291 )`, id)
292 if err != nil {
293 return err
294 }
295
296 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
297 if err != nil {
298 return err
299 }
300
301 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
302 if err != nil {
303 return err
304 }
305
306 return tx.Commit()
307}
308
309func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
310 db.lock.RLock()
311 defer db.lock.RUnlock()
312
313 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
314 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
315 sasl_external_cert, sasl_external_key
316 FROM Network
317 WHERE user = ?`,
318 userID)
319 if err != nil {
320 return nil, err
321 }
322 defer rows.Close()
323
324 var networks []Network
325 for rows.Next() {
326 var net Network
327 var name, username, realname, pass, connectCommands sql.NullString
328 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
329 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
330 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
331 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
332 if err != nil {
333 return nil, err
334 }
335 net.Name = name.String
336 net.Username = username.String
337 net.Realname = realname.String
338 net.Pass = pass.String
339 if connectCommands.Valid {
340 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
341 }
342 net.SASL.Mechanism = saslMechanism.String
343 net.SASL.Plain.Username = saslPlainUsername.String
344 net.SASL.Plain.Password = saslPlainPassword.String
345 networks = append(networks, net)
346 }
347 if err := rows.Err(); err != nil {
348 return nil, err
349 }
350
351 return networks, nil
352}
353
354func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
355 db.lock.Lock()
356 defer db.lock.Unlock()
357
358 netName := toNullString(network.Name)
359 netUsername := toNullString(network.Username)
360 realname := toNullString(network.Realname)
361 pass := toNullString(network.Pass)
362 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
363
364 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
365 if network.SASL.Mechanism != "" {
366 saslMechanism = toNullString(network.SASL.Mechanism)
367 switch network.SASL.Mechanism {
368 case "PLAIN":
369 saslPlainUsername = toNullString(network.SASL.Plain.Username)
370 saslPlainPassword = toNullString(network.SASL.Plain.Password)
371 network.SASL.External.CertBlob = nil
372 network.SASL.External.PrivKeyBlob = nil
373 case "EXTERNAL":
374 // keep saslPlain* nil
375 default:
376 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
377 }
378 }
379
380 var err error
381 if network.ID != 0 {
382 _, err = db.db.Exec(`UPDATE Network
383 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
384 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
385 sasl_external_cert = ?, sasl_external_key = ?
386 WHERE id = ?`,
387 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
388 saslMechanism, saslPlainUsername, saslPlainPassword,
389 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
390 network.ID)
391 } else {
392 var res sql.Result
393 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
394 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
395 sasl_plain_password, sasl_external_cert, sasl_external_key)
396 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
397 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
398 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
399 network.SASL.External.PrivKeyBlob)
400 if err != nil {
401 return err
402 }
403 network.ID, err = res.LastInsertId()
404 }
405 return err
406}
407
408func (db *SqliteDB) DeleteNetwork(id int64) error {
409 db.lock.Lock()
410 defer db.lock.Unlock()
411
412 tx, err := db.db.Begin()
413 if err != nil {
414 return err
415 }
416 defer tx.Rollback()
417
418 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
419 if err != nil {
420 return err
421 }
422
423 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
424 if err != nil {
425 return err
426 }
427
428 return tx.Commit()
429}
430
431func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
432 db.lock.RLock()
433 defer db.lock.RUnlock()
434
435 rows, err := db.db.Query(`SELECT
436 id, name, key, detached, detached_internal_msgid,
437 relay_detached, reattach_on, detach_after, detach_on
438 FROM Channel
439 WHERE network = ?`, networkID)
440 if err != nil {
441 return nil, err
442 }
443 defer rows.Close()
444
445 var channels []Channel
446 for rows.Next() {
447 var ch Channel
448 var key, detachedInternalMsgID sql.NullString
449 var detachAfter int64
450 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
451 return nil, err
452 }
453 ch.Key = key.String
454 ch.DetachedInternalMsgID = detachedInternalMsgID.String
455 ch.DetachAfter = time.Duration(detachAfter) * time.Second
456 channels = append(channels, ch)
457 }
458 if err := rows.Err(); err != nil {
459 return nil, err
460 }
461
462 return channels, nil
463}
464
465func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
466 db.lock.Lock()
467 defer db.lock.Unlock()
468
469 key := toNullString(ch.Key)
470 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
471
472 var err error
473 if ch.ID != 0 {
474 _, err = db.db.Exec(`UPDATE Channel
475 SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
476 WHERE id = ?`,
477 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
478 } else {
479 var res sql.Result
480 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
481 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
482 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
483 if err != nil {
484 return err
485 }
486 ch.ID, err = res.LastInsertId()
487 }
488 return err
489}
490
491func (db *SqliteDB) DeleteChannel(id int64) error {
492 db.lock.Lock()
493 defer db.lock.Unlock()
494
495 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
496 return err
497}
498
499func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
500 db.lock.RLock()
501 defer db.lock.RUnlock()
502
503 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
504 FROM DeliveryReceipt
505 WHERE network = ?`, networkID)
506 if err != nil {
507 return nil, err
508 }
509 defer rows.Close()
510
511 var receipts []DeliveryReceipt
512 for rows.Next() {
513 var rcpt DeliveryReceipt
514 var client sql.NullString
515 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
516 return nil, err
517 }
518 rcpt.Client = client.String
519 receipts = append(receipts, rcpt)
520 }
521 if err := rows.Err(); err != nil {
522 return nil, err
523 }
524
525 return receipts, nil
526}
527
528func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
529 db.lock.Lock()
530 defer db.lock.Unlock()
531
532 tx, err := db.db.Begin()
533 if err != nil {
534 return err
535 }
536 defer tx.Rollback()
537
538 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
539 networkID, toNullString(client))
540 if err != nil {
541 return err
542 }
543
544 for i := range receipts {
545 rcpt := &receipts[i]
546
547 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
548 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
549 if err != nil {
550 return err
551 }
552 rcpt.ID, err = res.LastInsertId()
553 if err != nil {
554 return err
555 }
556 }
557
558 return tx.Commit()
559}
Note: See TracBrowser for help on using the repository browser.