source: code/trunk/db_sqlite.go@597

Last change on this file since 597 was 596, checked in by contact, 4 years ago

db_sqlite: switch to sql.Named

This allows us to avoid mixing up arguments.

File size: 16.5 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema creates the latest database schema from scratch. upgrade
// executes it on a fresh database (user_version 0) and then records the latest
// schema version, so this must stay equivalent to the result of applying every
// entry of sqliteMigrations in order.
const sqliteSchema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0,
	realname VARCHAR(255)
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	enabled INTEGER NOT NULL DEFAULT 1,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
69
// sqliteMigrations lists the schema migrations. Entry i upgrades a database
// whose user_version is i to version i+1; upgrade applies every pending entry
// in one transaction and then sets user_version to len(sqliteMigrations).
// Entries may contain several ';'-separated statements. Only append to this
// list: existing entries have already been applied to deployed databases.
var sqliteMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3, #4
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	// #5
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: rebuild User with an explicit INTEGER PRIMARY KEY id, copied from
	// the implicit rowid.
	`
	CREATE TABLE UserNew (
		id INTEGER PRIMARY KEY,
		username VARCHAR(255) NOT NULL UNIQUE,
		password VARCHAR(255),
		admin INTEGER NOT NULL DEFAULT 0
	);
	INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
	DROP TABLE User;
	ALTER TABLE UserNew RENAME TO User;
	`,
	// #7: rebuild Network so that Network.user holds User.id instead of
	// User.username (translated via the JOIN below).
	`
	CREATE TABLE NetworkNew (
		id INTEGER PRIMARY KEY,
		name VARCHAR(255),
		user INTEGER NOT NULL,
		addr VARCHAR(255) NOT NULL,
		nick VARCHAR(255) NOT NULL,
		username VARCHAR(255),
		realname VARCHAR(255),
		pass VARCHAR(255),
		connect_commands VARCHAR(1023),
		sasl_mechanism VARCHAR(255),
		sasl_plain_username VARCHAR(255),
		sasl_plain_password VARCHAR(255),
		sasl_external_cert BLOB DEFAULT NULL,
		sasl_external_key BLOB DEFAULT NULL,
		FOREIGN KEY(user) REFERENCES User(id),
		UNIQUE(user, addr, nick),
		UNIQUE(user, name)
	);
	INSERT INTO NetworkNew
		SELECT Network.id, name, User.id as user, addr, nick,
			Network.username, realname, pass, connect_commands,
			sasl_mechanism, sasl_plain_username, sasl_plain_password,
			sasl_external_cert, sasl_external_key
		FROM Network
		JOIN User ON Network.user = User.username;
	DROP TABLE Network;
	ALTER TABLE NetworkNew RENAME TO Network;
	`,
	// #8: per-channel detach/reattach settings.
	`
	ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
	ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
	ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
	ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	// #9
	`
	CREATE TABLE DeliveryReceipt (
		id INTEGER PRIMARY KEY,
		network INTEGER NOT NULL,
		target VARCHAR(255) NOT NULL,
		client VARCHAR(255),
		internal_msgid VARCHAR(255) NOT NULL,
		FOREIGN KEY(network) REFERENCES Network(id),
		UNIQUE(network, target, client)
	);
	`,
	// #10
	"ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
	// #11
	"ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
	// #12
	"ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
}
139
// SqliteDB is the SQLite-backed Database implementation. Construct it with
// OpenSqliteDB so that the schema is brought up to date first.
type SqliteDB struct {
	// lock is taken for reading by the List*/Get* methods and for writing by
	// the Store*/Delete*, upgrade and Close methods.
	lock sync.RWMutex
	// db is the underlying connection pool.
	db *sql.DB
}
144
145func OpenSqliteDB(driver, source string) (Database, error) {
146 sqlSqliteDB, err := sql.Open(driver, source)
147 if err != nil {
148 return nil, err
149 }
150
151 db := &SqliteDB{db: sqlSqliteDB}
152 if err := db.upgrade(); err != nil {
153 sqlSqliteDB.Close()
154 return nil, err
155 }
156
157 return db, nil
158}
159
160func (db *SqliteDB) Close() error {
161 db.lock.Lock()
162 defer db.lock.Unlock()
163 return db.db.Close()
164}
165
166func (db *SqliteDB) upgrade() error {
167 db.lock.Lock()
168 defer db.lock.Unlock()
169
170 var version int
171 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
172 return fmt.Errorf("failed to query schema version: %v", err)
173 }
174
175 if version == len(sqliteMigrations) {
176 return nil
177 } else if version > len(sqliteMigrations) {
178 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
179 }
180
181 tx, err := db.db.Begin()
182 if err != nil {
183 return err
184 }
185 defer tx.Rollback()
186
187 if version == 0 {
188 if _, err := tx.Exec(sqliteSchema); err != nil {
189 return fmt.Errorf("failed to initialize schema: %v", err)
190 }
191 } else {
192 for i := version; i < len(sqliteMigrations); i++ {
193 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
194 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
195 }
196 }
197 }
198
199 // For some reason prepared statements don't work here
200 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
201 if err != nil {
202 return fmt.Errorf("failed to bump schema version: %v", err)
203 }
204
205 return tx.Commit()
206}
207
// toNullString maps the empty string to SQL NULL and any other string to a
// valid NullString holding it.
func toNullString(s string) sql.NullString {
	var out sql.NullString
	if s == "" {
		return out
	}
	out.String = s
	out.Valid = true
	return out
}
214
215func (db *SqliteDB) ListUsers() ([]User, error) {
216 db.lock.RLock()
217 defer db.lock.RUnlock()
218
219 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
220 if err != nil {
221 return nil, err
222 }
223 defer rows.Close()
224
225 var users []User
226 for rows.Next() {
227 var user User
228 var password sql.NullString
229 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
230 return nil, err
231 }
232 user.Password = password.String
233 users = append(users, user)
234 }
235 if err := rows.Err(); err != nil {
236 return nil, err
237 }
238
239 return users, nil
240}
241
242func (db *SqliteDB) GetUser(username string) (*User, error) {
243 db.lock.RLock()
244 defer db.lock.RUnlock()
245
246 user := &User{Username: username}
247
248 var password, realname sql.NullString
249 row := db.db.QueryRow("SELECT id, password, admin, realname FROM User WHERE username = ?", username)
250 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
251 return nil, err
252 }
253 user.Password = password.String
254 user.Realname = realname.String
255 return user, nil
256}
257
258func (db *SqliteDB) StoreUser(user *User) error {
259 db.lock.Lock()
260 defer db.lock.Unlock()
261
262 args := []interface{}{
263 sql.Named("username", user.Username),
264 sql.Named("password", toNullString(user.Password)),
265 sql.Named("admin", user.Admin),
266 sql.Named("realname", toNullString(user.Realname)),
267 }
268
269 var err error
270 if user.ID != 0 {
271 _, err = db.db.Exec("UPDATE User SET password = :password, admin = :admin, realname = :realname WHERE username = :username", args...)
272 } else {
273 var res sql.Result
274 res, err = db.db.Exec("INSERT INTO User(username, password, admin, realname) VALUES (:username, :password, :admin, :realname)", args...)
275 if err != nil {
276 return err
277 }
278 user.ID, err = res.LastInsertId()
279 }
280
281 return err
282}
283
284func (db *SqliteDB) DeleteUser(id int64) error {
285 db.lock.Lock()
286 defer db.lock.Unlock()
287
288 tx, err := db.db.Begin()
289 if err != nil {
290 return err
291 }
292 defer tx.Rollback()
293
294 _, err = tx.Exec(`DELETE FROM DeliveryReceipt
295 WHERE id IN (
296 SELECT DeliveryReceipt.id
297 FROM DeliveryReceipt
298 JOIN Network ON DeliveryReceipt.network = Network.id
299 WHERE Network.user = ?
300 )`, id)
301 if err != nil {
302 return err
303 }
304
305 _, err = tx.Exec(`DELETE FROM Channel
306 WHERE id IN (
307 SELECT Channel.id
308 FROM Channel
309 JOIN Network ON Channel.network = Network.id
310 WHERE Network.user = ?
311 )`, id)
312 if err != nil {
313 return err
314 }
315
316 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
317 if err != nil {
318 return err
319 }
320
321 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
322 if err != nil {
323 return err
324 }
325
326 return tx.Commit()
327}
328
329func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
330 db.lock.RLock()
331 defer db.lock.RUnlock()
332
333 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
334 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
335 sasl_external_cert, sasl_external_key, enabled
336 FROM Network
337 WHERE user = ?`,
338 userID)
339 if err != nil {
340 return nil, err
341 }
342 defer rows.Close()
343
344 var networks []Network
345 for rows.Next() {
346 var net Network
347 var name, username, realname, pass, connectCommands sql.NullString
348 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
349 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
350 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
351 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
352 if err != nil {
353 return nil, err
354 }
355 net.Name = name.String
356 net.Username = username.String
357 net.Realname = realname.String
358 net.Pass = pass.String
359 if connectCommands.Valid {
360 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
361 }
362 net.SASL.Mechanism = saslMechanism.String
363 net.SASL.Plain.Username = saslPlainUsername.String
364 net.SASL.Plain.Password = saslPlainPassword.String
365 networks = append(networks, net)
366 }
367 if err := rows.Err(); err != nil {
368 return nil, err
369 }
370
371 return networks, nil
372}
373
374func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
375 db.lock.Lock()
376 defer db.lock.Unlock()
377
378 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
379 if network.SASL.Mechanism != "" {
380 saslMechanism = toNullString(network.SASL.Mechanism)
381 switch network.SASL.Mechanism {
382 case "PLAIN":
383 saslPlainUsername = toNullString(network.SASL.Plain.Username)
384 saslPlainPassword = toNullString(network.SASL.Plain.Password)
385 network.SASL.External.CertBlob = nil
386 network.SASL.External.PrivKeyBlob = nil
387 case "EXTERNAL":
388 // keep saslPlain* nil
389 default:
390 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
391 }
392 }
393
394 args := []interface{}{
395 sql.Named("name", toNullString(network.Name)),
396 sql.Named("addr", network.Addr),
397 sql.Named("nick", network.Nick),
398 sql.Named("username", toNullString(network.Username)),
399 sql.Named("realname", toNullString(network.Realname)),
400 sql.Named("pass", toNullString(network.Pass)),
401 sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
402 sql.Named("sasl_mechanism", saslMechanism),
403 sql.Named("sasl_plain_username", saslPlainUsername),
404 sql.Named("sasl_plain_password", saslPlainPassword),
405 sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
406 sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
407 sql.Named("enabled", network.Enabled),
408
409 sql.Named("id", network.ID), // only for UPDATE
410 sql.Named("user", userID), // only for INSERT
411 }
412
413 var err error
414 if network.ID != 0 {
415 _, err = db.db.Exec(`
416 UPDATE Network
417 SET name = :name, addr = :addr, nick = :nick, username = :username,
418 realname = :realname, pass = :pass, connect_commands = :connect_commands,
419 sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
420 sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
421 enabled = :enabled
422 WHERE id = :id`, args...)
423 } else {
424 var res sql.Result
425 res, err = db.db.Exec(`
426 INSERT INTO Network(user, name, addr, nick, username, realname, pass,
427 connect_commands, sasl_mechanism, sasl_plain_username,
428 sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
429 VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
430 :connect_commands, :sasl_mechanism, :sasl_plain_username,
431 :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
432 args...)
433 if err != nil {
434 return err
435 }
436 network.ID, err = res.LastInsertId()
437 }
438 return err
439}
440
441func (db *SqliteDB) DeleteNetwork(id int64) error {
442 db.lock.Lock()
443 defer db.lock.Unlock()
444
445 tx, err := db.db.Begin()
446 if err != nil {
447 return err
448 }
449 defer tx.Rollback()
450
451 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ?", id)
452 if err != nil {
453 return err
454 }
455
456 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
457 if err != nil {
458 return err
459 }
460
461 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
462 if err != nil {
463 return err
464 }
465
466 return tx.Commit()
467}
468
469func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
470 db.lock.RLock()
471 defer db.lock.RUnlock()
472
473 rows, err := db.db.Query(`SELECT
474 id, name, key, detached, detached_internal_msgid,
475 relay_detached, reattach_on, detach_after, detach_on
476 FROM Channel
477 WHERE network = ?`, networkID)
478 if err != nil {
479 return nil, err
480 }
481 defer rows.Close()
482
483 var channels []Channel
484 for rows.Next() {
485 var ch Channel
486 var key, detachedInternalMsgID sql.NullString
487 var detachAfter int64
488 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
489 return nil, err
490 }
491 ch.Key = key.String
492 ch.DetachedInternalMsgID = detachedInternalMsgID.String
493 ch.DetachAfter = time.Duration(detachAfter) * time.Second
494 channels = append(channels, ch)
495 }
496 if err := rows.Err(); err != nil {
497 return nil, err
498 }
499
500 return channels, nil
501}
502
503func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
504 db.lock.Lock()
505 defer db.lock.Unlock()
506
507 args := []interface{}{
508 sql.Named("network", networkID),
509 sql.Named("name", ch.Name),
510 sql.Named("key", toNullString(ch.Key)),
511 sql.Named("detached", ch.Detached),
512 sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
513 sql.Named("relay_detached", ch.RelayDetached),
514 sql.Named("reattach_on", ch.ReattachOn),
515 sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
516 sql.Named("detach_on", ch.DetachOn),
517
518 sql.Named("id", ch.ID), // only for UPDATE
519 }
520
521 var err error
522 if ch.ID != 0 {
523 _, err = db.db.Exec(`UPDATE Channel
524 SET network = :network, name = :name, key = :key, detached = :detached,
525 detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
526 reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
527 WHERE id = :id`, args...)
528 } else {
529 var res sql.Result
530 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
531 VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
532 if err != nil {
533 return err
534 }
535 ch.ID, err = res.LastInsertId()
536 }
537 return err
538}
539
540func (db *SqliteDB) DeleteChannel(id int64) error {
541 db.lock.Lock()
542 defer db.lock.Unlock()
543
544 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
545 return err
546}
547
548func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
549 db.lock.RLock()
550 defer db.lock.RUnlock()
551
552 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
553 FROM DeliveryReceipt
554 WHERE network = ?`, networkID)
555 if err != nil {
556 return nil, err
557 }
558 defer rows.Close()
559
560 var receipts []DeliveryReceipt
561 for rows.Next() {
562 var rcpt DeliveryReceipt
563 var client sql.NullString
564 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
565 return nil, err
566 }
567 rcpt.Client = client.String
568 receipts = append(receipts, rcpt)
569 }
570 if err := rows.Err(); err != nil {
571 return nil, err
572 }
573
574 return receipts, nil
575}
576
577func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
578 db.lock.Lock()
579 defer db.lock.Unlock()
580
581 tx, err := db.db.Begin()
582 if err != nil {
583 return err
584 }
585 defer tx.Rollback()
586
587 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
588 networkID, toNullString(client))
589 if err != nil {
590 return err
591 }
592
593 for i := range receipts {
594 rcpt := &receipts[i]
595
596 res, err := tx.Exec(`INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
597 VALUES (:network, :target, :client, :internal_msgid)`,
598 sql.Named("network", networkID),
599 sql.Named("target", rcpt.Target),
600 sql.Named("client", toNullString(client)),
601 sql.Named("internal_msgid", rcpt.InternalMsgID))
602 if err != nil {
603 return err
604 }
605 rcpt.ID, err = res.LastInsertId()
606 if err != nil {
607 return err
608 }
609 }
610
611 return tx.Commit()
612}
Note: See TracBrowser for help on using the repository browser.