source: code/trunk/db_sqlite.go@ 557

Last change on this file since 557 was 542, checked in by contact, 4 years ago

Allow networks to be disabled

File size: 14.9 KB
RevLine 
[531]1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema creates the full, current database schema in one step. upgrade
// executes it only for a brand-new database (PRAGMA user_version = 0) and then
// stamps the database as version len(sqliteMigrations), skipping every
// migration. It must therefore always describe exactly the same end state as
// running all of sqliteMigrations: when adding a migration, update this
// schema too.
14const sqliteSchema = `
15CREATE TABLE User (
16 id INTEGER PRIMARY KEY,
17 username VARCHAR(255) NOT NULL UNIQUE,
18 password VARCHAR(255),
19 admin INTEGER NOT NULL DEFAULT 0
20);
21
22CREATE TABLE Network (
23 id INTEGER PRIMARY KEY,
24 name VARCHAR(255),
25 user INTEGER NOT NULL,
26 addr VARCHAR(255) NOT NULL,
27 nick VARCHAR(255) NOT NULL,
28 username VARCHAR(255),
29 realname VARCHAR(255),
30 pass VARCHAR(255),
31 connect_commands VARCHAR(1023),
32 sasl_mechanism VARCHAR(255),
33 sasl_plain_username VARCHAR(255),
34 sasl_plain_password VARCHAR(255),
35 sasl_external_cert BLOB DEFAULT NULL,
36 sasl_external_key BLOB DEFAULT NULL,
[542]37 enabled INTEGER NOT NULL DEFAULT 1,
[531]38 FOREIGN KEY(user) REFERENCES User(id),
39 UNIQUE(user, addr, nick),
40 UNIQUE(user, name)
41);
42
43CREATE TABLE Channel (
44 id INTEGER PRIMARY KEY,
45 network INTEGER NOT NULL,
46 name VARCHAR(255) NOT NULL,
47 key VARCHAR(255),
48 detached INTEGER NOT NULL DEFAULT 0,
49 detached_internal_msgid VARCHAR(255),
50 relay_detached INTEGER NOT NULL DEFAULT 0,
51 reattach_on INTEGER NOT NULL DEFAULT 0,
52 detach_after INTEGER NOT NULL DEFAULT 0,
53 detach_on INTEGER NOT NULL DEFAULT 0,
54 FOREIGN KEY(network) REFERENCES Network(id),
55 UNIQUE(network, name)
56);
57
58CREATE TABLE DeliveryReceipt (
59 id INTEGER PRIMARY KEY,
60 network INTEGER NOT NULL,
61 target VARCHAR(255) NOT NULL,
62 client VARCHAR(255),
63 internal_msgid VARCHAR(255) NOT NULL,
64 FOREIGN KEY(network) REFERENCES Network(id),
65 UNIQUE(network, target, client)
66);
67`
68
// sqliteMigrations lists the incremental schema upgrades. The schema version
// stored in PRAGMA user_version is the number of entries applied, so entry i
// upgrades a database from version i to version i+1, and the latest version is
// len(sqliteMigrations). Entry 0 is a placeholder because version 0 means
// "empty database", which upgrade handles with sqliteSchema instead.
//
// This list is append-only: existing databases only replay entries at or
// after their recorded version, so editing or reordering an earlier entry
// would never reach them. Each entry is run with a single Exec and may contain
// several ';'-separated statements.
69var sqliteMigrations = []string{
70 "", // migration #0 is reserved for schema initialization
71 "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
72 "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
73 "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
74 "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
75 "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
76 `
77 CREATE TABLE UserNew (
78 id INTEGER PRIMARY KEY,
79 username VARCHAR(255) NOT NULL UNIQUE,
80 password VARCHAR(255),
81 admin INTEGER NOT NULL DEFAULT 0
82 );
83 INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
84 DROP TABLE User;
85 ALTER TABLE UserNew RENAME TO User;
86 `,
87 `
88 CREATE TABLE NetworkNew (
89 id INTEGER PRIMARY KEY,
90 name VARCHAR(255),
91 user INTEGER NOT NULL,
92 addr VARCHAR(255) NOT NULL,
93 nick VARCHAR(255) NOT NULL,
94 username VARCHAR(255),
95 realname VARCHAR(255),
96 pass VARCHAR(255),
97 connect_commands VARCHAR(1023),
98 sasl_mechanism VARCHAR(255),
99 sasl_plain_username VARCHAR(255),
100 sasl_plain_password VARCHAR(255),
101 sasl_external_cert BLOB DEFAULT NULL,
102 sasl_external_key BLOB DEFAULT NULL,
103 FOREIGN KEY(user) REFERENCES User(id),
104 UNIQUE(user, addr, nick),
105 UNIQUE(user, name)
106 );
107 INSERT INTO NetworkNew
108 SELECT Network.id, name, User.id as user, addr, nick,
109 Network.username, realname, pass, connect_commands,
110 sasl_mechanism, sasl_plain_username, sasl_plain_password,
111 sasl_external_cert, sasl_external_key
112 FROM Network
113 JOIN User ON Network.user = User.username;
114 DROP TABLE Network;
115 ALTER TABLE NetworkNew RENAME TO Network;
116 `,
117 `
118 ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
119 ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
120 ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
121 ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
122 `,
123 `
124 CREATE TABLE DeliveryReceipt (
125 id INTEGER PRIMARY KEY,
126 network INTEGER NOT NULL,
127 target VARCHAR(255) NOT NULL,
128 client VARCHAR(255),
129 internal_msgid VARCHAR(255) NOT NULL,
130 FOREIGN KEY(network) REFERENCES Network(id),
131 UNIQUE(network, target, client)
132 );
133 `,
134 "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
[542]135 "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
[531]136}
137
// SqliteDB is the SQLite-backed implementation of the database layer.
//
// Every method acquires lock before touching db: read-only queries take the
// read lock, while writes, schema upgrades and Close take the write lock, so
// writers are serialized at the Go level.
type SqliteDB struct {
	// lock serializes access to db as described above.
	lock sync.RWMutex
	// db is the database/sql handle opened by OpenSqliteDB.
	db *sql.DB
}
142
143func OpenSqliteDB(driver, source string) (Database, error) {
144 sqlSqliteDB, err := sql.Open(driver, source)
145 if err != nil {
146 return nil, err
147 }
148
149 db := &SqliteDB{db: sqlSqliteDB}
150 if err := db.upgrade(); err != nil {
151 return nil, err
152 }
153
154 return db, nil
155}
156
157func (db *SqliteDB) Close() error {
158 db.lock.Lock()
159 defer db.lock.Unlock()
160 return db.db.Close()
161}
162
163func (db *SqliteDB) upgrade() error {
164 db.lock.Lock()
165 defer db.lock.Unlock()
166
167 var version int
168 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
169 return fmt.Errorf("failed to query schema version: %v", err)
170 }
171
172 if version == len(sqliteMigrations) {
173 return nil
174 } else if version > len(sqliteMigrations) {
175 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
176 }
177
178 tx, err := db.db.Begin()
179 if err != nil {
180 return err
181 }
182 defer tx.Rollback()
183
184 if version == 0 {
185 if _, err := tx.Exec(sqliteSchema); err != nil {
186 return fmt.Errorf("failed to initialize schema: %v", err)
187 }
188 } else {
189 for i := version; i < len(sqliteMigrations); i++ {
190 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
191 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
192 }
193 }
194 }
195
196 // For some reason prepared statements don't work here
197 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
198 if err != nil {
199 return fmt.Errorf("failed to bump schema version: %v", err)
200 }
201
202 return tx.Commit()
203}
204
// toNullString wraps s in a sql.NullString, mapping the empty string to SQL
// NULL so optional text columns store NULL instead of "".
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
211
212func (db *SqliteDB) ListUsers() ([]User, error) {
213 db.lock.RLock()
214 defer db.lock.RUnlock()
215
216 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
217 if err != nil {
218 return nil, err
219 }
220 defer rows.Close()
221
222 var users []User
223 for rows.Next() {
224 var user User
225 var password sql.NullString
226 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
227 return nil, err
228 }
229 user.Password = password.String
230 users = append(users, user)
231 }
232 if err := rows.Err(); err != nil {
233 return nil, err
234 }
235
236 return users, nil
237}
238
239func (db *SqliteDB) GetUser(username string) (*User, error) {
240 db.lock.RLock()
241 defer db.lock.RUnlock()
242
243 user := &User{Username: username}
244
245 var password sql.NullString
246 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
247 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
248 return nil, err
249 }
250 user.Password = password.String
251 return user, nil
252}
253
254func (db *SqliteDB) StoreUser(user *User) error {
255 db.lock.Lock()
256 defer db.lock.Unlock()
257
258 password := toNullString(user.Password)
259
260 var err error
261 if user.ID != 0 {
262 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
263 password, user.Admin, user.Username)
264 } else {
265 var res sql.Result
266 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
267 user.Username, password, user.Admin)
268 if err != nil {
269 return err
270 }
271 user.ID, err = res.LastInsertId()
272 }
273
274 return err
275}
276
277func (db *SqliteDB) DeleteUser(id int64) error {
278 db.lock.Lock()
279 defer db.lock.Unlock()
280
281 tx, err := db.db.Begin()
282 if err != nil {
283 return err
284 }
285 defer tx.Rollback()
286
287 _, err = tx.Exec(`DELETE FROM Channel
288 WHERE id IN (
289 SELECT Channel.id
290 FROM Channel
291 JOIN Network ON Channel.network = Network.id
292 WHERE Network.user = ?
293 )`, id)
294 if err != nil {
295 return err
296 }
297
298 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
299 if err != nil {
300 return err
301 }
302
303 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
304 if err != nil {
305 return err
306 }
307
308 return tx.Commit()
309}
310
311func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
312 db.lock.RLock()
313 defer db.lock.RUnlock()
314
315 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
316 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
[542]317 sasl_external_cert, sasl_external_key, enabled
[531]318 FROM Network
319 WHERE user = ?`,
320 userID)
321 if err != nil {
322 return nil, err
323 }
324 defer rows.Close()
325
326 var networks []Network
327 for rows.Next() {
328 var net Network
329 var name, username, realname, pass, connectCommands sql.NullString
330 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
331 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
332 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
[542]333 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
[531]334 if err != nil {
335 return nil, err
336 }
337 net.Name = name.String
338 net.Username = username.String
339 net.Realname = realname.String
340 net.Pass = pass.String
341 if connectCommands.Valid {
342 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
343 }
344 net.SASL.Mechanism = saslMechanism.String
345 net.SASL.Plain.Username = saslPlainUsername.String
346 net.SASL.Plain.Password = saslPlainPassword.String
347 networks = append(networks, net)
348 }
349 if err := rows.Err(); err != nil {
350 return nil, err
351 }
352
353 return networks, nil
354}
355
356func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
357 db.lock.Lock()
358 defer db.lock.Unlock()
359
360 netName := toNullString(network.Name)
361 netUsername := toNullString(network.Username)
362 realname := toNullString(network.Realname)
363 pass := toNullString(network.Pass)
364 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
365
366 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
367 if network.SASL.Mechanism != "" {
368 saslMechanism = toNullString(network.SASL.Mechanism)
369 switch network.SASL.Mechanism {
370 case "PLAIN":
371 saslPlainUsername = toNullString(network.SASL.Plain.Username)
372 saslPlainPassword = toNullString(network.SASL.Plain.Password)
373 network.SASL.External.CertBlob = nil
374 network.SASL.External.PrivKeyBlob = nil
375 case "EXTERNAL":
376 // keep saslPlain* nil
377 default:
378 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
379 }
380 }
381
382 var err error
383 if network.ID != 0 {
384 _, err = db.db.Exec(`UPDATE Network
385 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
386 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
[542]387 sasl_external_cert = ?, sasl_external_key = ?, enabled = ?
[531]388 WHERE id = ?`,
389 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
390 saslMechanism, saslPlainUsername, saslPlainPassword,
[542]391 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob, network.Enabled,
[531]392 network.ID)
393 } else {
394 var res sql.Result
395 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
396 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[542]397 sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
398 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[531]399 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
400 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
[542]401 network.SASL.External.PrivKeyBlob, network.Enabled)
[531]402 if err != nil {
403 return err
404 }
405 network.ID, err = res.LastInsertId()
406 }
407 return err
408}
409
410func (db *SqliteDB) DeleteNetwork(id int64) error {
411 db.lock.Lock()
412 defer db.lock.Unlock()
413
414 tx, err := db.db.Begin()
415 if err != nil {
416 return err
417 }
418 defer tx.Rollback()
419
420 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
421 if err != nil {
422 return err
423 }
424
425 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
426 if err != nil {
427 return err
428 }
429
430 return tx.Commit()
431}
432
433func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
434 db.lock.RLock()
435 defer db.lock.RUnlock()
436
437 rows, err := db.db.Query(`SELECT
438 id, name, key, detached, detached_internal_msgid,
439 relay_detached, reattach_on, detach_after, detach_on
440 FROM Channel
441 WHERE network = ?`, networkID)
442 if err != nil {
443 return nil, err
444 }
445 defer rows.Close()
446
447 var channels []Channel
448 for rows.Next() {
449 var ch Channel
450 var key, detachedInternalMsgID sql.NullString
451 var detachAfter int64
452 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
453 return nil, err
454 }
455 ch.Key = key.String
456 ch.DetachedInternalMsgID = detachedInternalMsgID.String
457 ch.DetachAfter = time.Duration(detachAfter) * time.Second
458 channels = append(channels, ch)
459 }
460 if err := rows.Err(); err != nil {
461 return nil, err
462 }
463
464 return channels, nil
465}
466
467func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
468 db.lock.Lock()
469 defer db.lock.Unlock()
470
471 key := toNullString(ch.Key)
472 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
473
474 var err error
475 if ch.ID != 0 {
476 _, err = db.db.Exec(`UPDATE Channel
477 SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
478 WHERE id = ?`,
479 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
480 } else {
481 var res sql.Result
482 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
483 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
484 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
485 if err != nil {
486 return err
487 }
488 ch.ID, err = res.LastInsertId()
489 }
490 return err
491}
492
493func (db *SqliteDB) DeleteChannel(id int64) error {
494 db.lock.Lock()
495 defer db.lock.Unlock()
496
497 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
498 return err
499}
500
501func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
502 db.lock.RLock()
503 defer db.lock.RUnlock()
504
505 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
506 FROM DeliveryReceipt
507 WHERE network = ?`, networkID)
508 if err != nil {
509 return nil, err
510 }
511 defer rows.Close()
512
513 var receipts []DeliveryReceipt
514 for rows.Next() {
515 var rcpt DeliveryReceipt
516 var client sql.NullString
517 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
518 return nil, err
519 }
520 rcpt.Client = client.String
521 receipts = append(receipts, rcpt)
522 }
523 if err := rows.Err(); err != nil {
524 return nil, err
525 }
526
527 return receipts, nil
528}
529
530func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
531 db.lock.Lock()
532 defer db.lock.Unlock()
533
534 tx, err := db.db.Begin()
535 if err != nil {
536 return err
537 }
538 defer tx.Rollback()
539
540 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
541 networkID, toNullString(client))
542 if err != nil {
543 return err
544 }
545
546 for i := range receipts {
547 rcpt := &receipts[i]
548
549 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
550 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
551 if err != nil {
552 return err
553 }
554 rcpt.ID, err = res.LastInsertId()
555 if err != nil {
556 return err
557 }
558 }
559
560 return tx.Commit()
561}
Note: See TracBrowser for help on using the repository browser.