source: code/trunk/db_sqlite.go@ 595

Last change on this file since 595 was 595, checked in by hubert, 4 years ago

Fix DeliveryReceipt not being cleaned up

File size: 15.5 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema is the complete, current database schema.
//
// upgrade executes it as one multi-statement Exec when the database reports
// user_version 0 (a fresh database) and then stamps the database with
// len(sqliteMigrations). Existing databases instead replay sqliteMigrations,
// so this text must always describe the same tables that applying every
// migration produces.
//
// NOTE(review): the FOREIGN KEY clauses declare no ON DELETE action, which is
// why DeleteUser and DeleteNetwork remove dependent rows by hand.
//
// NOTE(review): DeliveryReceipt.client is nullable, and SQLite treats NULLs as
// distinct inside UNIQUE constraints, so UNIQUE(network, target, client) does
// not prevent duplicate rows whose client is NULL. StoreClientDeliveryReceipts
// compensates by deleting with "client IS ?" before inserting.
const sqliteSchema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0,
	realname VARCHAR(255)
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	enabled INTEGER NOT NULL DEFAULT 1,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
69
// sqliteMigrations lists, in order, the SQL that moves the schema from
// version i to version i+1 (entry i is executed when the stored version is
// <= i). upgrade applies every entry from the database's current
// user_version onward and then sets user_version to len(sqliteMigrations).
//
// Entries must only ever be appended: existing databases have already run
// the earlier ones, and their index is their identity.
var sqliteMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	// #4
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	// #5
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: rebuild User with an explicit integer primary key, seeded from the
	// old table's rowid.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// #7: rebuild Network so that Network.user holds User.id instead of the
	// username it previously stored (see the JOIN on User.username).
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	// #8: per-channel detach/reattach settings.
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	// #9: delivery receipts table.
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
	// #10
	"ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
	// #11
	"ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
	// #12
	"ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
}
139
// SqliteDB implements Database on top of a SQLite database opened through
// database/sql. Every method acquires lock: read-only queries take it in
// shared mode, anything that writes (or closes the handle) takes it
// exclusively.
type SqliteDB struct {
	db   *sql.DB      // handle returned by sql.Open
	lock sync.RWMutex // guards all access to db
}
144
145func OpenSqliteDB(driver, source string) (Database, error) {
146 sqlSqliteDB, err := sql.Open(driver, source)
147 if err != nil {
148 return nil, err
149 }
150
151 db := &SqliteDB{db: sqlSqliteDB}
152 if err := db.upgrade(); err != nil {
153 sqlSqliteDB.Close()
154 return nil, err
155 }
156
157 return db, nil
158}
159
160func (db *SqliteDB) Close() error {
161 db.lock.Lock()
162 defer db.lock.Unlock()
163 return db.db.Close()
164}
165
166func (db *SqliteDB) upgrade() error {
167 db.lock.Lock()
168 defer db.lock.Unlock()
169
170 var version int
171 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
172 return fmt.Errorf("failed to query schema version: %v", err)
173 }
174
175 if version == len(sqliteMigrations) {
176 return nil
177 } else if version > len(sqliteMigrations) {
178 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
179 }
180
181 tx, err := db.db.Begin()
182 if err != nil {
183 return err
184 }
185 defer tx.Rollback()
186
187 if version == 0 {
188 if _, err := tx.Exec(sqliteSchema); err != nil {
189 return fmt.Errorf("failed to initialize schema: %v", err)
190 }
191 } else {
192 for i := version; i < len(sqliteMigrations); i++ {
193 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
194 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
195 }
196 }
197 }
198
199 // For some reason prepared statements don't work here
200 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
201 if err != nil {
202 return fmt.Errorf("failed to bump schema version: %v", err)
203 }
204
205 return tx.Commit()
206}
207
// toNullString maps the empty string to SQL NULL and any other string to a
// valid sql.NullString carrying that value.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
214
215func (db *SqliteDB) ListUsers() ([]User, error) {
216 db.lock.RLock()
217 defer db.lock.RUnlock()
218
219 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
220 if err != nil {
221 return nil, err
222 }
223 defer rows.Close()
224
225 var users []User
226 for rows.Next() {
227 var user User
228 var password sql.NullString
229 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
230 return nil, err
231 }
232 user.Password = password.String
233 users = append(users, user)
234 }
235 if err := rows.Err(); err != nil {
236 return nil, err
237 }
238
239 return users, nil
240}
241
242func (db *SqliteDB) GetUser(username string) (*User, error) {
243 db.lock.RLock()
244 defer db.lock.RUnlock()
245
246 user := &User{Username: username}
247
248 var password, realname sql.NullString
249 row := db.db.QueryRow("SELECT id, password, admin, realname FROM User WHERE username = ?", username)
250 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
251 return nil, err
252 }
253 user.Password = password.String
254 user.Realname = realname.String
255 return user, nil
256}
257
258func (db *SqliteDB) StoreUser(user *User) error {
259 db.lock.Lock()
260 defer db.lock.Unlock()
261
262 password := toNullString(user.Password)
263 realname := toNullString(user.Realname)
264
265 var err error
266 if user.ID != 0 {
267 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ?, realname = ? WHERE username = ?",
268 password, user.Admin, realname, user.Username)
269 } else {
270 var res sql.Result
271 res, err = db.db.Exec("INSERT INTO User(username, password, admin, realname) VALUES (?, ?, ?, ?)",
272 user.Username, password, user.Admin, realname)
273 if err != nil {
274 return err
275 }
276 user.ID, err = res.LastInsertId()
277 }
278
279 return err
280}
281
282func (db *SqliteDB) DeleteUser(id int64) error {
283 db.lock.Lock()
284 defer db.lock.Unlock()
285
286 tx, err := db.db.Begin()
287 if err != nil {
288 return err
289 }
290 defer tx.Rollback()
291
292 _, err = tx.Exec(`DELETE FROM DeliveryReceipt
293 WHERE id IN (
294 SELECT DeliveryReceipt.id
295 FROM DeliveryReceipt
296 JOIN Network ON DeliveryReceipt.network = Network.id
297 WHERE Network.user = ?
298 )`, id)
299 if err != nil {
300 return err
301 }
302
303 _, err = tx.Exec(`DELETE FROM Channel
304 WHERE id IN (
305 SELECT Channel.id
306 FROM Channel
307 JOIN Network ON Channel.network = Network.id
308 WHERE Network.user = ?
309 )`, id)
310 if err != nil {
311 return err
312 }
313
314 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
315 if err != nil {
316 return err
317 }
318
319 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
320 if err != nil {
321 return err
322 }
323
324 return tx.Commit()
325}
326
327func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
328 db.lock.RLock()
329 defer db.lock.RUnlock()
330
331 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
332 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
333 sasl_external_cert, sasl_external_key, enabled
334 FROM Network
335 WHERE user = ?`,
336 userID)
337 if err != nil {
338 return nil, err
339 }
340 defer rows.Close()
341
342 var networks []Network
343 for rows.Next() {
344 var net Network
345 var name, username, realname, pass, connectCommands sql.NullString
346 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
347 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
348 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
349 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
350 if err != nil {
351 return nil, err
352 }
353 net.Name = name.String
354 net.Username = username.String
355 net.Realname = realname.String
356 net.Pass = pass.String
357 if connectCommands.Valid {
358 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
359 }
360 net.SASL.Mechanism = saslMechanism.String
361 net.SASL.Plain.Username = saslPlainUsername.String
362 net.SASL.Plain.Password = saslPlainPassword.String
363 networks = append(networks, net)
364 }
365 if err := rows.Err(); err != nil {
366 return nil, err
367 }
368
369 return networks, nil
370}
371
372func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
373 db.lock.Lock()
374 defer db.lock.Unlock()
375
376 netName := toNullString(network.Name)
377 netUsername := toNullString(network.Username)
378 realname := toNullString(network.Realname)
379 pass := toNullString(network.Pass)
380 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
381
382 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
383 if network.SASL.Mechanism != "" {
384 saslMechanism = toNullString(network.SASL.Mechanism)
385 switch network.SASL.Mechanism {
386 case "PLAIN":
387 saslPlainUsername = toNullString(network.SASL.Plain.Username)
388 saslPlainPassword = toNullString(network.SASL.Plain.Password)
389 network.SASL.External.CertBlob = nil
390 network.SASL.External.PrivKeyBlob = nil
391 case "EXTERNAL":
392 // keep saslPlain* nil
393 default:
394 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
395 }
396 }
397
398 var err error
399 if network.ID != 0 {
400 _, err = db.db.Exec(`UPDATE Network
401 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
402 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
403 sasl_external_cert = ?, sasl_external_key = ?, enabled = ?
404 WHERE id = ?`,
405 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
406 saslMechanism, saslPlainUsername, saslPlainPassword,
407 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob, network.Enabled,
408 network.ID)
409 } else {
410 var res sql.Result
411 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
412 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
413 sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
414 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
415 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
416 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
417 network.SASL.External.PrivKeyBlob, network.Enabled)
418 if err != nil {
419 return err
420 }
421 network.ID, err = res.LastInsertId()
422 }
423 return err
424}
425
426func (db *SqliteDB) DeleteNetwork(id int64) error {
427 db.lock.Lock()
428 defer db.lock.Unlock()
429
430 tx, err := db.db.Begin()
431 if err != nil {
432 return err
433 }
434 defer tx.Rollback()
435
436 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ?", id)
437 if err != nil {
438 return err
439 }
440
441 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
442 if err != nil {
443 return err
444 }
445
446 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
447 if err != nil {
448 return err
449 }
450
451 return tx.Commit()
452}
453
454func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
455 db.lock.RLock()
456 defer db.lock.RUnlock()
457
458 rows, err := db.db.Query(`SELECT
459 id, name, key, detached, detached_internal_msgid,
460 relay_detached, reattach_on, detach_after, detach_on
461 FROM Channel
462 WHERE network = ?`, networkID)
463 if err != nil {
464 return nil, err
465 }
466 defer rows.Close()
467
468 var channels []Channel
469 for rows.Next() {
470 var ch Channel
471 var key, detachedInternalMsgID sql.NullString
472 var detachAfter int64
473 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
474 return nil, err
475 }
476 ch.Key = key.String
477 ch.DetachedInternalMsgID = detachedInternalMsgID.String
478 ch.DetachAfter = time.Duration(detachAfter) * time.Second
479 channels = append(channels, ch)
480 }
481 if err := rows.Err(); err != nil {
482 return nil, err
483 }
484
485 return channels, nil
486}
487
488func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
489 db.lock.Lock()
490 defer db.lock.Unlock()
491
492 key := toNullString(ch.Key)
493 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
494
495 var err error
496 if ch.ID != 0 {
497 _, err = db.db.Exec(`UPDATE Channel
498 SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
499 WHERE id = ?`,
500 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
501 } else {
502 var res sql.Result
503 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
504 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
505 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
506 if err != nil {
507 return err
508 }
509 ch.ID, err = res.LastInsertId()
510 }
511 return err
512}
513
514func (db *SqliteDB) DeleteChannel(id int64) error {
515 db.lock.Lock()
516 defer db.lock.Unlock()
517
518 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
519 return err
520}
521
522func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
523 db.lock.RLock()
524 defer db.lock.RUnlock()
525
526 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
527 FROM DeliveryReceipt
528 WHERE network = ?`, networkID)
529 if err != nil {
530 return nil, err
531 }
532 defer rows.Close()
533
534 var receipts []DeliveryReceipt
535 for rows.Next() {
536 var rcpt DeliveryReceipt
537 var client sql.NullString
538 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
539 return nil, err
540 }
541 rcpt.Client = client.String
542 receipts = append(receipts, rcpt)
543 }
544 if err := rows.Err(); err != nil {
545 return nil, err
546 }
547
548 return receipts, nil
549}
550
551func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
552 db.lock.Lock()
553 defer db.lock.Unlock()
554
555 tx, err := db.db.Begin()
556 if err != nil {
557 return err
558 }
559 defer tx.Rollback()
560
561 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
562 networkID, toNullString(client))
563 if err != nil {
564 return err
565 }
566
567 for i := range receipts {
568 rcpt := &receipts[i]
569
570 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
571 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
572 if err != nil {
573 return err
574 }
575 rcpt.ID, err = res.LastInsertId()
576 if err != nil {
577 return err
578 }
579 }
580
581 return tx.Commit()
582}
Note: See TracBrowser for help on using the repository browser.