source: code/trunk/db_sqlite.go@ 640

Last change on this file since 640 was 620, checked in by hubert, 4 years ago

PostgreSQL support

File size: 16.9 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema creates the complete, current set of tables. upgrade executes
// it on a database whose user_version is 0 and then sets user_version to
// len(sqliteMigrations) without running any migration, so this schema must
// stay equivalent to the result of applying every entry of sqliteMigrations.
//
// Tables: User; Network (rows reference User.id); Channel and DeliveryReceipt
// (rows reference Network.id). Network.connect_commands holds the commands
// joined with "\r\n"; Channel.detach_after holds a duration in whole seconds.
//
// NOTE(review): nothing in this file enables PRAGMA foreign_keys; the Delete*
// methods remove dependent rows explicitly instead of relying on cascades.
const sqliteSchema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0,
	realname VARCHAR(255)
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	enabled INTEGER NOT NULL DEFAULT 1,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
69
// sqliteMigrations lists every schema change, in order. Progress is tracked
// with SQLite's PRAGMA user_version: a database at version v has already
// applied entries 0..v-1, so upgrade executes entries v..len(sqliteMigrations)-1
// in one transaction and then sets user_version to len(sqliteMigrations).
// A fresh database (version 0) gets sqliteSchema instead, which is why entry
// 0 is empty.
//
// Because entries are identified only by their index, entries already applied
// to existing databases must not be edited, removed or reordered; new changes
// are appended here and mirrored in sqliteSchema.
var sqliteMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: rebuild User with an explicit "id INTEGER PRIMARY KEY", copying each
	// row's rowid into it. (Presumably the previous User table had no such
	// column — TODO confirm against the older schema.)
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// #7: rebuild Network so that Network.user holds User.id; the JOIN on
	// User.username shows the column previously held the username.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	// #8: per-channel relay/reattach/detach settings.
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	// #9: delivery receipts table.
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
	"ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
	"ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
	"ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
}
139
// SqliteDB is the Database implementation backed by SQLite, accessed through
// database/sql with the "sqlite3" driver (see OpenSqliteDB).
type SqliteDB struct {
	// lock guards db: methods that modify data (and Close and upgrade) take
	// the write lock, read-only queries take the read lock.
	lock sync.RWMutex
	db   *sql.DB
}
144
145func OpenSqliteDB(source string) (Database, error) {
146 sqlSqliteDB, err := sql.Open("sqlite3", source)
147 if err != nil {
148 return nil, err
149 }
150
151 db := &SqliteDB{db: sqlSqliteDB}
152 if err := db.upgrade(); err != nil {
153 sqlSqliteDB.Close()
154 return nil, err
155 }
156
157 return db, nil
158}
159
160func (db *SqliteDB) Close() error {
161 db.lock.Lock()
162 defer db.lock.Unlock()
163 return db.db.Close()
164}
165
166func (db *SqliteDB) upgrade() error {
167 db.lock.Lock()
168 defer db.lock.Unlock()
169
170 var version int
171 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
172 return fmt.Errorf("failed to query schema version: %v", err)
173 }
174
175 if version == len(sqliteMigrations) {
176 return nil
177 } else if version > len(sqliteMigrations) {
178 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
179 }
180
181 tx, err := db.db.Begin()
182 if err != nil {
183 return err
184 }
185 defer tx.Rollback()
186
187 if version == 0 {
188 if _, err := tx.Exec(sqliteSchema); err != nil {
189 return fmt.Errorf("failed to initialize schema: %v", err)
190 }
191 } else {
192 for i := version; i < len(sqliteMigrations); i++ {
193 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
194 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
195 }
196 }
197 }
198
199 // For some reason prepared statements don't work here
200 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
201 if err != nil {
202 return fmt.Errorf("failed to bump schema version: %v", err)
203 }
204
205 return tx.Commit()
206}
207
208func (db *SqliteDB) Stats() (*DatabaseStats, error) {
209 db.lock.RLock()
210 defer db.lock.RUnlock()
211
212 var stats DatabaseStats
213 row := db.db.QueryRow(`SELECT
214 (SELECT COUNT(*) FROM User) AS users,
215 (SELECT COUNT(*) FROM Network) AS networks,
216 (SELECT COUNT(*) FROM Channel) AS channels`)
217 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
218 return nil, err
219 }
220
221 return &stats, nil
222}
223
// toNullString converts s for storage in a nullable text column: the empty
// string becomes SQL NULL, any other value is stored as-is.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
230
231func (db *SqliteDB) ListUsers() ([]User, error) {
232 db.lock.RLock()
233 defer db.lock.RUnlock()
234
235 rows, err := db.db.Query("SELECT id, username, password, admin, realname FROM User")
236 if err != nil {
237 return nil, err
238 }
239 defer rows.Close()
240
241 var users []User
242 for rows.Next() {
243 var user User
244 var password, realname sql.NullString
245 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
246 return nil, err
247 }
248 user.Password = password.String
249 user.Realname = realname.String
250 users = append(users, user)
251 }
252 if err := rows.Err(); err != nil {
253 return nil, err
254 }
255
256 return users, nil
257}
258
259func (db *SqliteDB) GetUser(username string) (*User, error) {
260 db.lock.RLock()
261 defer db.lock.RUnlock()
262
263 user := &User{Username: username}
264
265 var password, realname sql.NullString
266 row := db.db.QueryRow("SELECT id, password, admin, realname FROM User WHERE username = ?", username)
267 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
268 return nil, err
269 }
270 user.Password = password.String
271 user.Realname = realname.String
272 return user, nil
273}
274
275func (db *SqliteDB) StoreUser(user *User) error {
276 db.lock.Lock()
277 defer db.lock.Unlock()
278
279 args := []interface{}{
280 sql.Named("username", user.Username),
281 sql.Named("password", toNullString(user.Password)),
282 sql.Named("admin", user.Admin),
283 sql.Named("realname", toNullString(user.Realname)),
284 }
285
286 var err error
287 if user.ID != 0 {
288 _, err = db.db.Exec("UPDATE User SET password = :password, admin = :admin, realname = :realname WHERE username = :username", args...)
289 } else {
290 var res sql.Result
291 res, err = db.db.Exec("INSERT INTO User(username, password, admin, realname) VALUES (:username, :password, :admin, :realname)", args...)
292 if err != nil {
293 return err
294 }
295 user.ID, err = res.LastInsertId()
296 }
297
298 return err
299}
300
301func (db *SqliteDB) DeleteUser(id int64) error {
302 db.lock.Lock()
303 defer db.lock.Unlock()
304
305 tx, err := db.db.Begin()
306 if err != nil {
307 return err
308 }
309 defer tx.Rollback()
310
311 _, err = tx.Exec(`DELETE FROM DeliveryReceipt
312 WHERE id IN (
313 SELECT DeliveryReceipt.id
314 FROM DeliveryReceipt
315 JOIN Network ON DeliveryReceipt.network = Network.id
316 WHERE Network.user = ?
317 )`, id)
318 if err != nil {
319 return err
320 }
321
322 _, err = tx.Exec(`DELETE FROM Channel
323 WHERE id IN (
324 SELECT Channel.id
325 FROM Channel
326 JOIN Network ON Channel.network = Network.id
327 WHERE Network.user = ?
328 )`, id)
329 if err != nil {
330 return err
331 }
332
333 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
334 if err != nil {
335 return err
336 }
337
338 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
339 if err != nil {
340 return err
341 }
342
343 return tx.Commit()
344}
345
346func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
347 db.lock.RLock()
348 defer db.lock.RUnlock()
349
350 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
351 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
352 sasl_external_cert, sasl_external_key, enabled
353 FROM Network
354 WHERE user = ?`,
355 userID)
356 if err != nil {
357 return nil, err
358 }
359 defer rows.Close()
360
361 var networks []Network
362 for rows.Next() {
363 var net Network
364 var name, username, realname, pass, connectCommands sql.NullString
365 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
366 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
367 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
368 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
369 if err != nil {
370 return nil, err
371 }
372 net.Name = name.String
373 net.Username = username.String
374 net.Realname = realname.String
375 net.Pass = pass.String
376 if connectCommands.Valid {
377 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
378 }
379 net.SASL.Mechanism = saslMechanism.String
380 net.SASL.Plain.Username = saslPlainUsername.String
381 net.SASL.Plain.Password = saslPlainPassword.String
382 networks = append(networks, net)
383 }
384 if err := rows.Err(); err != nil {
385 return nil, err
386 }
387
388 return networks, nil
389}
390
391func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
392 db.lock.Lock()
393 defer db.lock.Unlock()
394
395 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
396 if network.SASL.Mechanism != "" {
397 saslMechanism = toNullString(network.SASL.Mechanism)
398 switch network.SASL.Mechanism {
399 case "PLAIN":
400 saslPlainUsername = toNullString(network.SASL.Plain.Username)
401 saslPlainPassword = toNullString(network.SASL.Plain.Password)
402 network.SASL.External.CertBlob = nil
403 network.SASL.External.PrivKeyBlob = nil
404 case "EXTERNAL":
405 // keep saslPlain* nil
406 default:
407 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
408 }
409 }
410
411 args := []interface{}{
412 sql.Named("name", toNullString(network.Name)),
413 sql.Named("addr", network.Addr),
414 sql.Named("nick", network.Nick),
415 sql.Named("username", toNullString(network.Username)),
416 sql.Named("realname", toNullString(network.Realname)),
417 sql.Named("pass", toNullString(network.Pass)),
418 sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
419 sql.Named("sasl_mechanism", saslMechanism),
420 sql.Named("sasl_plain_username", saslPlainUsername),
421 sql.Named("sasl_plain_password", saslPlainPassword),
422 sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
423 sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
424 sql.Named("enabled", network.Enabled),
425
426 sql.Named("id", network.ID), // only for UPDATE
427 sql.Named("user", userID), // only for INSERT
428 }
429
430 var err error
431 if network.ID != 0 {
432 _, err = db.db.Exec(`
433 UPDATE Network
434 SET name = :name, addr = :addr, nick = :nick, username = :username,
435 realname = :realname, pass = :pass, connect_commands = :connect_commands,
436 sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
437 sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
438 enabled = :enabled
439 WHERE id = :id`, args...)
440 } else {
441 var res sql.Result
442 res, err = db.db.Exec(`
443 INSERT INTO Network(user, name, addr, nick, username, realname, pass,
444 connect_commands, sasl_mechanism, sasl_plain_username,
445 sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
446 VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
447 :connect_commands, :sasl_mechanism, :sasl_plain_username,
448 :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
449 args...)
450 if err != nil {
451 return err
452 }
453 network.ID, err = res.LastInsertId()
454 }
455 return err
456}
457
458func (db *SqliteDB) DeleteNetwork(id int64) error {
459 db.lock.Lock()
460 defer db.lock.Unlock()
461
462 tx, err := db.db.Begin()
463 if err != nil {
464 return err
465 }
466 defer tx.Rollback()
467
468 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ?", id)
469 if err != nil {
470 return err
471 }
472
473 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
474 if err != nil {
475 return err
476 }
477
478 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
479 if err != nil {
480 return err
481 }
482
483 return tx.Commit()
484}
485
486func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
487 db.lock.RLock()
488 defer db.lock.RUnlock()
489
490 rows, err := db.db.Query(`SELECT
491 id, name, key, detached, detached_internal_msgid,
492 relay_detached, reattach_on, detach_after, detach_on
493 FROM Channel
494 WHERE network = ?`, networkID)
495 if err != nil {
496 return nil, err
497 }
498 defer rows.Close()
499
500 var channels []Channel
501 for rows.Next() {
502 var ch Channel
503 var key, detachedInternalMsgID sql.NullString
504 var detachAfter int64
505 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
506 return nil, err
507 }
508 ch.Key = key.String
509 ch.DetachedInternalMsgID = detachedInternalMsgID.String
510 ch.DetachAfter = time.Duration(detachAfter) * time.Second
511 channels = append(channels, ch)
512 }
513 if err := rows.Err(); err != nil {
514 return nil, err
515 }
516
517 return channels, nil
518}
519
520func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
521 db.lock.Lock()
522 defer db.lock.Unlock()
523
524 args := []interface{}{
525 sql.Named("network", networkID),
526 sql.Named("name", ch.Name),
527 sql.Named("key", toNullString(ch.Key)),
528 sql.Named("detached", ch.Detached),
529 sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
530 sql.Named("relay_detached", ch.RelayDetached),
531 sql.Named("reattach_on", ch.ReattachOn),
532 sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
533 sql.Named("detach_on", ch.DetachOn),
534
535 sql.Named("id", ch.ID), // only for UPDATE
536 }
537
538 var err error
539 if ch.ID != 0 {
540 _, err = db.db.Exec(`UPDATE Channel
541 SET network = :network, name = :name, key = :key, detached = :detached,
542 detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
543 reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
544 WHERE id = :id`, args...)
545 } else {
546 var res sql.Result
547 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
548 VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
549 if err != nil {
550 return err
551 }
552 ch.ID, err = res.LastInsertId()
553 }
554 return err
555}
556
557func (db *SqliteDB) DeleteChannel(id int64) error {
558 db.lock.Lock()
559 defer db.lock.Unlock()
560
561 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
562 return err
563}
564
565func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
566 db.lock.RLock()
567 defer db.lock.RUnlock()
568
569 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
570 FROM DeliveryReceipt
571 WHERE network = ?`, networkID)
572 if err != nil {
573 return nil, err
574 }
575 defer rows.Close()
576
577 var receipts []DeliveryReceipt
578 for rows.Next() {
579 var rcpt DeliveryReceipt
580 var client sql.NullString
581 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
582 return nil, err
583 }
584 rcpt.Client = client.String
585 receipts = append(receipts, rcpt)
586 }
587 if err := rows.Err(); err != nil {
588 return nil, err
589 }
590
591 return receipts, nil
592}
593
594func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
595 db.lock.Lock()
596 defer db.lock.Unlock()
597
598 tx, err := db.db.Begin()
599 if err != nil {
600 return err
601 }
602 defer tx.Rollback()
603
604 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
605 networkID, toNullString(client))
606 if err != nil {
607 return err
608 }
609
610 for i := range receipts {
611 rcpt := &receipts[i]
612
613 res, err := tx.Exec(`INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
614 VALUES (:network, :target, :client, :internal_msgid)`,
615 sql.Named("network", networkID),
616 sql.Named("target", rcpt.Target),
617 sql.Named("client", toNullString(client)),
618 sql.Named("internal_msgid", rcpt.InternalMsgID))
619 if err != nil {
620 return err
621 }
622 rcpt.ID, err = res.LastInsertId()
623 if err != nil {
624 return err
625 }
626 }
627
628 return tx.Commit()
629}
Note: See TracBrowser for help on using the repository browser.