source: code/trunk/db_sqlite.go@ 599

Last change on this file since 599 was 598, checked in by contact, 4 years ago

db_sqlite: fix realname not fetched in ListUsers

This fixes per-user realname not being used on bouncer startup.

File size: 16.5 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "strings"
8 "sync"
9 "time"
10
11 _ "github.com/mattn/go-sqlite3"
12)
13
// sqliteSchema is the complete set of CREATE TABLE statements for a fresh
// database. upgrade executes it in one go when PRAGMA user_version reports 0,
// instead of replaying the individual entries of sqliteMigrations.
const sqliteSchema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0,
	realname VARCHAR(255)
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	enabled INTEGER NOT NULL DEFAULT 1,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
69
// sqliteMigrations lists the SQL needed to move an existing database forward
// one schema version at a time. When PRAGMA user_version reports a non-zero
// version v, upgrade executes the entries at indexes v..len-1 in order and
// then stores len(sqliteMigrations) as the new user_version, so the length of
// this slice is the current schema version.
var sqliteMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// Rebuilds User with an explicit integer id column, seeded from rowid.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// Rebuilds Network so that its user column holds User.id; the join
	// matches the old Network.user value against User.username.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
	"ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
	"ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
	"ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
}
139
// SqliteDB pairs a database/sql handle with a read/write lock. The methods in
// this file take the lock for reading around SELECT queries and for writing
// around statements that modify the database (and around Close and upgrade).
type SqliteDB struct {
	// lock is acquired by every method before it touches db.
	lock sync.RWMutex
	// db is the handle returned by sql.Open in OpenSqliteDB.
	db *sql.DB
}
144
145func OpenSqliteDB(driver, source string) (Database, error) {
146 sqlSqliteDB, err := sql.Open(driver, source)
147 if err != nil {
148 return nil, err
149 }
150
151 db := &SqliteDB{db: sqlSqliteDB}
152 if err := db.upgrade(); err != nil {
153 sqlSqliteDB.Close()
154 return nil, err
155 }
156
157 return db, nil
158}
159
160func (db *SqliteDB) Close() error {
161 db.lock.Lock()
162 defer db.lock.Unlock()
163 return db.db.Close()
164}
165
166func (db *SqliteDB) upgrade() error {
167 db.lock.Lock()
168 defer db.lock.Unlock()
169
170 var version int
171 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
172 return fmt.Errorf("failed to query schema version: %v", err)
173 }
174
175 if version == len(sqliteMigrations) {
176 return nil
177 } else if version > len(sqliteMigrations) {
178 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(sqliteMigrations), version)
179 }
180
181 tx, err := db.db.Begin()
182 if err != nil {
183 return err
184 }
185 defer tx.Rollback()
186
187 if version == 0 {
188 if _, err := tx.Exec(sqliteSchema); err != nil {
189 return fmt.Errorf("failed to initialize schema: %v", err)
190 }
191 } else {
192 for i := version; i < len(sqliteMigrations); i++ {
193 if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
194 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
195 }
196 }
197 }
198
199 // For some reason prepared statements don't work here
200 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
201 if err != nil {
202 return fmt.Errorf("failed to bump schema version: %v", err)
203 }
204
205 return tx.Commit()
206}
207
// toNullString converts s into a sql.NullString. The empty string is mapped
// to an invalid (NULL) value; any other string is carried through as valid.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	ns.String = s
	ns.Valid = len(s) > 0
	return ns
}
214
215func (db *SqliteDB) ListUsers() ([]User, error) {
216 db.lock.RLock()
217 defer db.lock.RUnlock()
218
219 rows, err := db.db.Query("SELECT id, username, password, admin, realname FROM User")
220 if err != nil {
221 return nil, err
222 }
223 defer rows.Close()
224
225 var users []User
226 for rows.Next() {
227 var user User
228 var password, realname sql.NullString
229 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
230 return nil, err
231 }
232 user.Password = password.String
233 user.Realname = realname.String
234 users = append(users, user)
235 }
236 if err := rows.Err(); err != nil {
237 return nil, err
238 }
239
240 return users, nil
241}
242
243func (db *SqliteDB) GetUser(username string) (*User, error) {
244 db.lock.RLock()
245 defer db.lock.RUnlock()
246
247 user := &User{Username: username}
248
249 var password, realname sql.NullString
250 row := db.db.QueryRow("SELECT id, password, admin, realname FROM User WHERE username = ?", username)
251 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
252 return nil, err
253 }
254 user.Password = password.String
255 user.Realname = realname.String
256 return user, nil
257}
258
259func (db *SqliteDB) StoreUser(user *User) error {
260 db.lock.Lock()
261 defer db.lock.Unlock()
262
263 args := []interface{}{
264 sql.Named("username", user.Username),
265 sql.Named("password", toNullString(user.Password)),
266 sql.Named("admin", user.Admin),
267 sql.Named("realname", toNullString(user.Realname)),
268 }
269
270 var err error
271 if user.ID != 0 {
272 _, err = db.db.Exec("UPDATE User SET password = :password, admin = :admin, realname = :realname WHERE username = :username", args...)
273 } else {
274 var res sql.Result
275 res, err = db.db.Exec("INSERT INTO User(username, password, admin, realname) VALUES (:username, :password, :admin, :realname)", args...)
276 if err != nil {
277 return err
278 }
279 user.ID, err = res.LastInsertId()
280 }
281
282 return err
283}
284
285func (db *SqliteDB) DeleteUser(id int64) error {
286 db.lock.Lock()
287 defer db.lock.Unlock()
288
289 tx, err := db.db.Begin()
290 if err != nil {
291 return err
292 }
293 defer tx.Rollback()
294
295 _, err = tx.Exec(`DELETE FROM DeliveryReceipt
296 WHERE id IN (
297 SELECT DeliveryReceipt.id
298 FROM DeliveryReceipt
299 JOIN Network ON DeliveryReceipt.network = Network.id
300 WHERE Network.user = ?
301 )`, id)
302 if err != nil {
303 return err
304 }
305
306 _, err = tx.Exec(`DELETE FROM Channel
307 WHERE id IN (
308 SELECT Channel.id
309 FROM Channel
310 JOIN Network ON Channel.network = Network.id
311 WHERE Network.user = ?
312 )`, id)
313 if err != nil {
314 return err
315 }
316
317 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
318 if err != nil {
319 return err
320 }
321
322 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
323 if err != nil {
324 return err
325 }
326
327 return tx.Commit()
328}
329
330func (db *SqliteDB) ListNetworks(userID int64) ([]Network, error) {
331 db.lock.RLock()
332 defer db.lock.RUnlock()
333
334 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
335 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
336 sasl_external_cert, sasl_external_key, enabled
337 FROM Network
338 WHERE user = ?`,
339 userID)
340 if err != nil {
341 return nil, err
342 }
343 defer rows.Close()
344
345 var networks []Network
346 for rows.Next() {
347 var net Network
348 var name, username, realname, pass, connectCommands sql.NullString
349 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
350 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
351 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
352 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
353 if err != nil {
354 return nil, err
355 }
356 net.Name = name.String
357 net.Username = username.String
358 net.Realname = realname.String
359 net.Pass = pass.String
360 if connectCommands.Valid {
361 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
362 }
363 net.SASL.Mechanism = saslMechanism.String
364 net.SASL.Plain.Username = saslPlainUsername.String
365 net.SASL.Plain.Password = saslPlainPassword.String
366 networks = append(networks, net)
367 }
368 if err := rows.Err(); err != nil {
369 return nil, err
370 }
371
372 return networks, nil
373}
374
375func (db *SqliteDB) StoreNetwork(userID int64, network *Network) error {
376 db.lock.Lock()
377 defer db.lock.Unlock()
378
379 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
380 if network.SASL.Mechanism != "" {
381 saslMechanism = toNullString(network.SASL.Mechanism)
382 switch network.SASL.Mechanism {
383 case "PLAIN":
384 saslPlainUsername = toNullString(network.SASL.Plain.Username)
385 saslPlainPassword = toNullString(network.SASL.Plain.Password)
386 network.SASL.External.CertBlob = nil
387 network.SASL.External.PrivKeyBlob = nil
388 case "EXTERNAL":
389 // keep saslPlain* nil
390 default:
391 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
392 }
393 }
394
395 args := []interface{}{
396 sql.Named("name", toNullString(network.Name)),
397 sql.Named("addr", network.Addr),
398 sql.Named("nick", network.Nick),
399 sql.Named("username", toNullString(network.Username)),
400 sql.Named("realname", toNullString(network.Realname)),
401 sql.Named("pass", toNullString(network.Pass)),
402 sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
403 sql.Named("sasl_mechanism", saslMechanism),
404 sql.Named("sasl_plain_username", saslPlainUsername),
405 sql.Named("sasl_plain_password", saslPlainPassword),
406 sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
407 sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
408 sql.Named("enabled", network.Enabled),
409
410 sql.Named("id", network.ID), // only for UPDATE
411 sql.Named("user", userID), // only for INSERT
412 }
413
414 var err error
415 if network.ID != 0 {
416 _, err = db.db.Exec(`
417 UPDATE Network
418 SET name = :name, addr = :addr, nick = :nick, username = :username,
419 realname = :realname, pass = :pass, connect_commands = :connect_commands,
420 sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
421 sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
422 enabled = :enabled
423 WHERE id = :id`, args...)
424 } else {
425 var res sql.Result
426 res, err = db.db.Exec(`
427 INSERT INTO Network(user, name, addr, nick, username, realname, pass,
428 connect_commands, sasl_mechanism, sasl_plain_username,
429 sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
430 VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
431 :connect_commands, :sasl_mechanism, :sasl_plain_username,
432 :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
433 args...)
434 if err != nil {
435 return err
436 }
437 network.ID, err = res.LastInsertId()
438 }
439 return err
440}
441
442func (db *SqliteDB) DeleteNetwork(id int64) error {
443 db.lock.Lock()
444 defer db.lock.Unlock()
445
446 tx, err := db.db.Begin()
447 if err != nil {
448 return err
449 }
450 defer tx.Rollback()
451
452 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ?", id)
453 if err != nil {
454 return err
455 }
456
457 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
458 if err != nil {
459 return err
460 }
461
462 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
463 if err != nil {
464 return err
465 }
466
467 return tx.Commit()
468}
469
470func (db *SqliteDB) ListChannels(networkID int64) ([]Channel, error) {
471 db.lock.RLock()
472 defer db.lock.RUnlock()
473
474 rows, err := db.db.Query(`SELECT
475 id, name, key, detached, detached_internal_msgid,
476 relay_detached, reattach_on, detach_after, detach_on
477 FROM Channel
478 WHERE network = ?`, networkID)
479 if err != nil {
480 return nil, err
481 }
482 defer rows.Close()
483
484 var channels []Channel
485 for rows.Next() {
486 var ch Channel
487 var key, detachedInternalMsgID sql.NullString
488 var detachAfter int64
489 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
490 return nil, err
491 }
492 ch.Key = key.String
493 ch.DetachedInternalMsgID = detachedInternalMsgID.String
494 ch.DetachAfter = time.Duration(detachAfter) * time.Second
495 channels = append(channels, ch)
496 }
497 if err := rows.Err(); err != nil {
498 return nil, err
499 }
500
501 return channels, nil
502}
503
504func (db *SqliteDB) StoreChannel(networkID int64, ch *Channel) error {
505 db.lock.Lock()
506 defer db.lock.Unlock()
507
508 args := []interface{}{
509 sql.Named("network", networkID),
510 sql.Named("name", ch.Name),
511 sql.Named("key", toNullString(ch.Key)),
512 sql.Named("detached", ch.Detached),
513 sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
514 sql.Named("relay_detached", ch.RelayDetached),
515 sql.Named("reattach_on", ch.ReattachOn),
516 sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
517 sql.Named("detach_on", ch.DetachOn),
518
519 sql.Named("id", ch.ID), // only for UPDATE
520 }
521
522 var err error
523 if ch.ID != 0 {
524 _, err = db.db.Exec(`UPDATE Channel
525 SET network = :network, name = :name, key = :key, detached = :detached,
526 detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
527 reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
528 WHERE id = :id`, args...)
529 } else {
530 var res sql.Result
531 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
532 VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
533 if err != nil {
534 return err
535 }
536 ch.ID, err = res.LastInsertId()
537 }
538 return err
539}
540
541func (db *SqliteDB) DeleteChannel(id int64) error {
542 db.lock.Lock()
543 defer db.lock.Unlock()
544
545 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
546 return err
547}
548
549func (db *SqliteDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
550 db.lock.RLock()
551 defer db.lock.RUnlock()
552
553 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
554 FROM DeliveryReceipt
555 WHERE network = ?`, networkID)
556 if err != nil {
557 return nil, err
558 }
559 defer rows.Close()
560
561 var receipts []DeliveryReceipt
562 for rows.Next() {
563 var rcpt DeliveryReceipt
564 var client sql.NullString
565 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
566 return nil, err
567 }
568 rcpt.Client = client.String
569 receipts = append(receipts, rcpt)
570 }
571 if err := rows.Err(); err != nil {
572 return nil, err
573 }
574
575 return receipts, nil
576}
577
578func (db *SqliteDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
579 db.lock.Lock()
580 defer db.lock.Unlock()
581
582 tx, err := db.db.Begin()
583 if err != nil {
584 return err
585 }
586 defer tx.Rollback()
587
588 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
589 networkID, toNullString(client))
590 if err != nil {
591 return err
592 }
593
594 for i := range receipts {
595 rcpt := &receipts[i]
596
597 res, err := tx.Exec(`INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
598 VALUES (:network, :target, :client, :internal_msgid)`,
599 sql.Named("network", networkID),
600 sql.Named("target", rcpt.Target),
601 sql.Named("client", toNullString(client)),
602 sql.Named("internal_msgid", rcpt.InternalMsgID))
603 if err != nil {
604 return err
605 }
606 rcpt.ID, err = res.LastInsertId()
607 if err != nil {
608 return err
609 }
610 }
611
612 return tx.Commit()
613}
Note: See TracBrowser for help on using the repository browser.