source: code/trunk/db.go@ 493

Last change on this file since 493 was 489, checked in by contact, 4 years ago

Save delivery receipts in DB

This avoids losing history on restart for clients that don't
support chathistory.

Closes: https://todo.sr.ht/~emersion/soju/80

File size: 16.2 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "net/url"
8 "strings"
9 "sync"
10 "time"
11
12 _ "github.com/mattn/go-sqlite3"
13)
14
// User is one row of the User table.
type User struct {
	// ID is the row's primary key. StoreUser treats a zero ID as "not stored
	// yet" and inserts a new row, filling in ID afterwards.
	ID       int64
	Username string
	Password string // hashed
	Admin    bool
}
21
// SASL holds the SASL authentication settings attached to a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork accepts "PLAIN",
	// "EXTERNAL" or the empty string and rejects anything else.
	Mechanism string

	// Plain holds the credentials that StoreNetwork persists when Mechanism
	// is "PLAIN".
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
38
// Network is one row of the Network table: an upstream server configured for
// a user.
type Network struct {
	// ID is the row's primary key. StoreNetwork treats a zero ID as "not
	// stored yet" and inserts a new row, filling in ID afterwards.
	ID       int64
	Name     string // optional; GetName falls back to Addr when empty
	Addr     string // host or full URL; see URL for how it is parsed
	Nick     string
	Username string // optional; GetUsername falls back to Nick when empty
	Realname string // optional; GetRealname falls back to Nick when empty
	Pass     string
	// ConnectCommands is persisted as a single string joined with "\r\n".
	ConnectCommands []string
	SASL            SASL
}
50
51func (net *Network) GetName() string {
52 if net.Name != "" {
53 return net.Name
54 }
55 return net.Addr
56}
57
58func (net *Network) URL() (*url.URL, error) {
59 s := net.Addr
60 if !strings.Contains(s, "://") {
61 // This is a raw domain name, make it an URL with the default scheme
62 s = "ircs://" + s
63 }
64
65 u, err := url.Parse(s)
66 if err != nil {
67 return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
68 }
69
70 return u, nil
71}
72
73func (net *Network) GetUsername() string {
74 if net.Username != "" {
75 return net.Username
76 }
77 return net.Nick
78}
79
80func (net *Network) GetRealname() string {
81 if net.Realname != "" {
82 return net.Realname
83 }
84 return net.Nick
85}
86
// MessageFilter is the type of the per-channel relay/reattach/detach
// settings. Values are written to INTEGER columns by StoreChannel, so
// reordering the constants below would change the meaning of stored rows.
type MessageFilter int

const (
	// TODO: use customizable user defaults for FilterDefault
	FilterDefault MessageFilter = iota
	FilterNone
	FilterHighlight
	FilterMessage
)

// parseFilter converts the textual filter name ("default", "none",
// "highlight" or "message") into its MessageFilter value. The match is exact
// and case-sensitive; any other input yields an error that quotes the
// offending string.
func parseFilter(filter string) (MessageFilter, error) {
	var parsed MessageFilter
	switch filter {
	case "default":
		parsed = FilterDefault
	case "none":
		parsed = FilterNone
	case "highlight":
		parsed = FilterHighlight
	case "message":
		parsed = FilterMessage
	default:
		return 0, fmt.Errorf("unknown filter: %q", filter)
	}
	return parsed, nil
}
110
// Channel is one row of the Channel table: a channel saved for a network.
type Channel struct {
	// ID is the row's primary key. StoreChannel treats a zero ID as "not
	// stored yet" and inserts a new row, filling in ID afterwards.
	ID       int64
	Name     string
	Key      string // stored as NULL when empty
	Detached bool

	RelayDetached MessageFilter
	ReattachOn    MessageFilter
	// DetachAfter is persisted in whole seconds; StoreChannel rounds up, so
	// sub-second precision does not survive a store/load round trip.
	DetachAfter time.Duration
	DetachOn    MessageFilter
}
122
// DeliveryReceipt is one row of the DeliveryReceipt table. It associates a
// target and a client with an internal message ID.
type DeliveryReceipt struct {
	ID     int64
	Target string // channel or nick
	// Client is read back by ListDeliveryReceipts (NULL becomes ""). When
	// storing, StoreClientDeliveryReceipts uses its own client argument and
	// ignores this field.
	Client        string
	InternalMsgID string
}
129
// schema creates every table of a fresh database. upgrade executes it when
// PRAGMA user_version is 0 and then sets the version to len(migrations), so
// it has to describe the same tables that applying every entry of migrations
// in order would produce.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`
182
// migrations lists the SQL needed to move an existing database forward one
// schema version at a time. upgrade reads the stored PRAGMA user_version,
// executes every entry from that index to the end inside one transaction and
// then stores len(migrations) as the new version. Entries must therefore only
// ever be appended; editing or reordering an existing entry would leave
// already-upgraded databases out of step.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// Rebuild User with an explicit integer primary key, populated from the
	// old table's rowid.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// Rebuild Network so that its user column holds User.id; the old column
	// is matched against User.username in the JOIN below.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	// Per-channel relay/reattach/detach settings.
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	// Persistent delivery receipts.
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
}
249
// DB wraps a database/sql handle together with a lock that its methods take
// around every access.
type DB struct {
	// lock is held for reading by the List*/Get* methods and for writing by
	// the Store*/Delete* methods, upgrade and Close.
	lock sync.RWMutex
	db   *sql.DB
}
254
255func OpenSQLDB(driver, source string) (*DB, error) {
256 sqlDB, err := sql.Open(driver, source)
257 if err != nil {
258 return nil, err
259 }
260
261 db := &DB{db: sqlDB}
262 if err := db.upgrade(); err != nil {
263 return nil, err
264 }
265
266 return db, nil
267}
268
269func (db *DB) Close() error {
270 db.lock.Lock()
271 defer db.lock.Unlock()
272 return db.db.Close()
273}
274
275func (db *DB) upgrade() error {
276 db.lock.Lock()
277 defer db.lock.Unlock()
278
279 var version int
280 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
281 return fmt.Errorf("failed to query schema version: %v", err)
282 }
283
284 if version == len(migrations) {
285 return nil
286 } else if version > len(migrations) {
287 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
288 }
289
290 tx, err := db.db.Begin()
291 if err != nil {
292 return err
293 }
294 defer tx.Rollback()
295
296 if version == 0 {
297 if _, err := tx.Exec(schema); err != nil {
298 return fmt.Errorf("failed to initialize schema: %v", err)
299 }
300 } else {
301 for i := version; i < len(migrations); i++ {
302 if _, err := tx.Exec(migrations[i]); err != nil {
303 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
304 }
305 }
306 }
307
308 // For some reason prepared statements don't work here
309 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
310 if err != nil {
311 return fmt.Errorf("failed to bump schema version: %v", err)
312 }
313
314 return tx.Commit()
315}
316
// toNullString wraps s in a sql.NullString that is marked invalid (SQL NULL)
// when s is empty and valid otherwise.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	ns.String = s
	ns.Valid = len(s) > 0
	return ns
}
323
324func (db *DB) ListUsers() ([]User, error) {
325 db.lock.RLock()
326 defer db.lock.RUnlock()
327
328 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
329 if err != nil {
330 return nil, err
331 }
332 defer rows.Close()
333
334 var users []User
335 for rows.Next() {
336 var user User
337 var password sql.NullString
338 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
339 return nil, err
340 }
341 user.Password = password.String
342 users = append(users, user)
343 }
344 if err := rows.Err(); err != nil {
345 return nil, err
346 }
347
348 return users, nil
349}
350
351func (db *DB) GetUser(username string) (*User, error) {
352 db.lock.RLock()
353 defer db.lock.RUnlock()
354
355 user := &User{Username: username}
356
357 var password sql.NullString
358 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
359 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
360 return nil, err
361 }
362 user.Password = password.String
363 return user, nil
364}
365
366func (db *DB) StoreUser(user *User) error {
367 db.lock.Lock()
368 defer db.lock.Unlock()
369
370 password := toNullString(user.Password)
371
372 var err error
373 if user.ID != 0 {
374 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
375 password, user.Admin, user.Username)
376 } else {
377 var res sql.Result
378 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
379 user.Username, password, user.Admin)
380 if err != nil {
381 return err
382 }
383 user.ID, err = res.LastInsertId()
384 }
385
386 return err
387}
388
389func (db *DB) DeleteUser(id int64) error {
390 db.lock.Lock()
391 defer db.lock.Unlock()
392
393 tx, err := db.db.Begin()
394 if err != nil {
395 return err
396 }
397 defer tx.Rollback()
398
399 _, err = tx.Exec(`DELETE FROM Channel
400 WHERE id IN (
401 SELECT Channel.id
402 FROM Channel
403 JOIN Network ON Channel.network = Network.id
404 WHERE Network.user = ?
405 )`, id)
406 if err != nil {
407 return err
408 }
409
410 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
411 if err != nil {
412 return err
413 }
414
415 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
416 if err != nil {
417 return err
418 }
419
420 return tx.Commit()
421}
422
423func (db *DB) ListNetworks(userID int64) ([]Network, error) {
424 db.lock.RLock()
425 defer db.lock.RUnlock()
426
427 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
428 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
429 sasl_external_cert, sasl_external_key
430 FROM Network
431 WHERE user = ?`,
432 userID)
433 if err != nil {
434 return nil, err
435 }
436 defer rows.Close()
437
438 var networks []Network
439 for rows.Next() {
440 var net Network
441 var name, username, realname, pass, connectCommands sql.NullString
442 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
443 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
444 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
445 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
446 if err != nil {
447 return nil, err
448 }
449 net.Name = name.String
450 net.Username = username.String
451 net.Realname = realname.String
452 net.Pass = pass.String
453 if connectCommands.Valid {
454 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
455 }
456 net.SASL.Mechanism = saslMechanism.String
457 net.SASL.Plain.Username = saslPlainUsername.String
458 net.SASL.Plain.Password = saslPlainPassword.String
459 networks = append(networks, net)
460 }
461 if err := rows.Err(); err != nil {
462 return nil, err
463 }
464
465 return networks, nil
466}
467
468func (db *DB) StoreNetwork(userID int64, network *Network) error {
469 db.lock.Lock()
470 defer db.lock.Unlock()
471
472 netName := toNullString(network.Name)
473 netUsername := toNullString(network.Username)
474 realname := toNullString(network.Realname)
475 pass := toNullString(network.Pass)
476 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
477
478 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
479 if network.SASL.Mechanism != "" {
480 saslMechanism = toNullString(network.SASL.Mechanism)
481 switch network.SASL.Mechanism {
482 case "PLAIN":
483 saslPlainUsername = toNullString(network.SASL.Plain.Username)
484 saslPlainPassword = toNullString(network.SASL.Plain.Password)
485 network.SASL.External.CertBlob = nil
486 network.SASL.External.PrivKeyBlob = nil
487 case "EXTERNAL":
488 // keep saslPlain* nil
489 default:
490 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
491 }
492 }
493
494 var err error
495 if network.ID != 0 {
496 _, err = db.db.Exec(`UPDATE Network
497 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
498 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
499 sasl_external_cert = ?, sasl_external_key = ?
500 WHERE id = ?`,
501 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
502 saslMechanism, saslPlainUsername, saslPlainPassword,
503 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
504 network.ID)
505 } else {
506 var res sql.Result
507 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
508 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
509 sasl_plain_password, sasl_external_cert, sasl_external_key)
510 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
511 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
512 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
513 network.SASL.External.PrivKeyBlob)
514 if err != nil {
515 return err
516 }
517 network.ID, err = res.LastInsertId()
518 }
519 return err
520}
521
522func (db *DB) DeleteNetwork(id int64) error {
523 db.lock.Lock()
524 defer db.lock.Unlock()
525
526 tx, err := db.db.Begin()
527 if err != nil {
528 return err
529 }
530 defer tx.Rollback()
531
532 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
533 if err != nil {
534 return err
535 }
536
537 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
538 if err != nil {
539 return err
540 }
541
542 return tx.Commit()
543}
544
545func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
546 db.lock.RLock()
547 defer db.lock.RUnlock()
548
549 rows, err := db.db.Query(`SELECT id, name, key, detached, relay_detached, reattach_on, detach_after, detach_on
550 FROM Channel
551 WHERE network = ?`, networkID)
552 if err != nil {
553 return nil, err
554 }
555 defer rows.Close()
556
557 var channels []Channel
558 for rows.Next() {
559 var ch Channel
560 var key sql.NullString
561 var detachAfter int64
562 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
563 return nil, err
564 }
565 ch.Key = key.String
566 ch.DetachAfter = time.Duration(detachAfter) * time.Second
567 channels = append(channels, ch)
568 }
569 if err := rows.Err(); err != nil {
570 return nil, err
571 }
572
573 return channels, nil
574}
575
576func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
577 db.lock.Lock()
578 defer db.lock.Unlock()
579
580 key := toNullString(ch.Key)
581 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
582
583 var err error
584 if ch.ID != 0 {
585 _, err = db.db.Exec(`UPDATE Channel
586 SET network = ?, name = ?, key = ?, detached = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
587 WHERE id = ?`,
588 networkID, ch.Name, key, ch.Detached, ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
589 } else {
590 var res sql.Result
591 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, relay_detached, reattach_on, detach_after, detach_on)
592 VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
593 networkID, ch.Name, key, ch.Detached, ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
594 if err != nil {
595 return err
596 }
597 ch.ID, err = res.LastInsertId()
598 }
599 return err
600}
601
602func (db *DB) DeleteChannel(id int64) error {
603 db.lock.Lock()
604 defer db.lock.Unlock()
605
606 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
607 return err
608}
609
610func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
611 db.lock.RLock()
612 defer db.lock.RUnlock()
613
614 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
615 FROM DeliveryReceipt
616 WHERE network = ?`, networkID)
617 if err != nil {
618 return nil, err
619 }
620 defer rows.Close()
621
622 var receipts []DeliveryReceipt
623 for rows.Next() {
624 var rcpt DeliveryReceipt
625 var client sql.NullString
626 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
627 return nil, err
628 }
629 rcpt.Client = client.String
630 receipts = append(receipts, rcpt)
631 }
632 if err := rows.Err(); err != nil {
633 return nil, err
634 }
635
636 return receipts, nil
637}
638
639func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
640 db.lock.Lock()
641 defer db.lock.Unlock()
642
643 tx, err := db.db.Begin()
644 if err != nil {
645 return err
646 }
647 defer tx.Rollback()
648
649 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
650 networkID, toNullString(client))
651 if err != nil {
652 return err
653 }
654
655 for i := range receipts {
656 rcpt := &receipts[i]
657
658 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
659 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
660 if err != nil {
661 return err
662 }
663 rcpt.ID, err = res.LastInsertId()
664 if err != nil {
665 return err
666 }
667 }
668
669 return tx.Commit()
670}
Note: See TracBrowser for help on using the repository browser.