source: code/trunk/db.go@ 508

Last change on this file since 508 was 497, checked in by contact, 4 years ago

Store last internal msg ID in DB when detaching

References: https://todo.sr.ht/~emersion/soju/98

File size: 16.6 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "math"
7 "net/url"
8 "strings"
9 "sync"
10 "time"
11
12 _ "github.com/mattn/go-sqlite3"
13)
14
// User is a soju account, as stored in the User table.
type User struct {
	ID int64
	// Username is unique across all users; Password holds a hash, never the
	// plain-text password. An empty Password is stored as NULL.
	Username, Password string
	// Admin grants administrator privileges.
	Admin bool
}
21
// SASL holds the SASL authentication settings for an upstream network.
// Mechanism selects which of the nested credential sets applies; an empty
// Mechanism means SASL is not used.
type SASL struct {
	Mechanism string

	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

// Network is an upstream IRC network configured by a user.
type Network struct {
	ID              int64
	Name            string
	Addr            string
	Nick            string
	Username        string
	Realname        string
	Pass            string
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the network's display name, falling back to Addr when no
// explicit Name is set.
func (net *Network) GetName() string {
	if net.Name != "" {
		return net.Name
	}
	return net.Addr
}

// URL parses Addr as a URL. An address without a "://" separator is treated
// as a bare host and given the default "ircs://" scheme.
//
// On failure the returned error wraps the error from url.Parse (a
// *url.Error), so callers can inspect it with errors.As / errors.Unwrap.
func (net *Network) URL() (*url.URL, error) {
	s := net.Addr
	if !strings.Contains(s, "://") {
		// This is a raw domain name, make it an URL with the default scheme
		s = "ircs://" + s
	}

	u, err := url.Parse(s)
	if err != nil {
		return nil, fmt.Errorf("failed to parse upstream server URL: %w", err)
	}

	return u, nil
}

// GetUsername returns the username to present upstream, falling back to Nick
// when no explicit Username is set.
func (net *Network) GetUsername() string {
	if net.Username != "" {
		return net.Username
	}
	return net.Nick
}

// GetRealname returns the real name to present upstream, falling back to Nick
// when no explicit Realname is set.
func (net *Network) GetRealname() string {
	if net.Realname != "" {
		return net.Realname
	}
	return net.Nick
}
86
// MessageFilter selects which messages a per-channel rule applies to.
type MessageFilter int

// MessageFilter values. Their integer values are written directly into the
// Channel table's filter columns, so the order must not change.
const (
	// TODO: use customizable user defaults for FilterDefault
	FilterDefault MessageFilter = iota
	FilterNone
	FilterHighlight
	FilterMessage
)

// parseFilter maps a filter keyword ("default", "none", "highlight" or
// "message") to its MessageFilter. Matching is case-sensitive; any other
// string yields 0 and an error.
func parseFilter(filter string) (MessageFilter, error) {
	var f MessageFilter
	switch filter {
	case "default":
		f = FilterDefault
	case "none":
		f = FilterNone
	case "highlight":
		f = FilterHighlight
	case "message":
		f = FilterMessage
	default:
		return 0, fmt.Errorf("unknown filter: %q", filter)
	}
	return f, nil
}

// Channel is a channel saved for a network, together with its detach
// settings.
type Channel struct {
	ID   int64
	Name string
	Key  string

	Detached bool
	// DetachedInternalMsgID records an internal message ID when the channel
	// is detached (presumably the last one seen — confirm with callers);
	// an empty value is stored as NULL.
	DetachedInternalMsgID string

	RelayDetached MessageFilter
	ReattachOn    MessageFilter
	// DetachAfter is persisted as whole seconds, rounded up.
	DetachAfter time.Duration
	DetachOn    MessageFilter
}

// DeliveryReceipt records, per network target and client, an internal
// message ID.
type DeliveryReceipt struct {
	ID            int64
	Target        string // channel or nick
	Client        string // empty is stored as NULL
	InternalMsgID string
}
131
// schema is the complete database schema, executed as-is by upgrade() on a
// fresh database (PRAGMA user_version 0). upgrade() then stamps that database
// with version len(migrations) without running any migration, so this text
// must stay equivalent to the result of applying every entry of migrations.
//
// NOTE(review): DeliveryReceipt.client is nullable, and SQLite treats NULLs as
// distinct in UNIQUE constraints, so UNIQUE(network, target, client) does not
// prevent duplicate receipts whose client is NULL.
132const schema = `
133CREATE TABLE User (
134 id INTEGER PRIMARY KEY,
135 username VARCHAR(255) NOT NULL UNIQUE,
136 password VARCHAR(255),
137 admin INTEGER NOT NULL DEFAULT 0
138);
139
140CREATE TABLE Network (
141 id INTEGER PRIMARY KEY,
142 name VARCHAR(255),
143 user INTEGER NOT NULL,
144 addr VARCHAR(255) NOT NULL,
145 nick VARCHAR(255) NOT NULL,
146 username VARCHAR(255),
147 realname VARCHAR(255),
148 pass VARCHAR(255),
149 connect_commands VARCHAR(1023),
150 sasl_mechanism VARCHAR(255),
151 sasl_plain_username VARCHAR(255),
152 sasl_plain_password VARCHAR(255),
153 sasl_external_cert BLOB DEFAULT NULL,
154 sasl_external_key BLOB DEFAULT NULL,
155 FOREIGN KEY(user) REFERENCES User(id),
156 UNIQUE(user, addr, nick),
157 UNIQUE(user, name)
158);
159
160CREATE TABLE Channel (
161 id INTEGER PRIMARY KEY,
162 network INTEGER NOT NULL,
163 name VARCHAR(255) NOT NULL,
164 key VARCHAR(255),
165 detached INTEGER NOT NULL DEFAULT 0,
166 detached_internal_msgid VARCHAR(255),
167 relay_detached INTEGER NOT NULL DEFAULT 0,
168 reattach_on INTEGER NOT NULL DEFAULT 0,
169 detach_after INTEGER NOT NULL DEFAULT 0,
170 detach_on INTEGER NOT NULL DEFAULT 0,
171 FOREIGN KEY(network) REFERENCES Network(id),
172 UNIQUE(network, name)
173);
174
175CREATE TABLE DeliveryReceipt (
176 id INTEGER PRIMARY KEY,
177 network INTEGER NOT NULL,
178 target VARCHAR(255) NOT NULL,
179 client VARCHAR(255),
180 internal_msgid VARCHAR(255) NOT NULL,
181 FOREIGN KEY(network) REFERENCES Network(id),
182 UNIQUE(network, target, client)
183);
184`
185
// migrations lists the schema upgrade steps. Entry i upgrades a database from
// PRAGMA user_version i to i+1: upgrade() runs migrations[v:] for a database
// at version v >= 1 and then sets user_version to len(migrations). Existing
// databases record only how many entries were applied, so entries may only be
// appended — never edited, removed or reordered.
186var migrations = []string{
187 "", // migration #0 is reserved for schema initialization
// #1-#5: add columns introduced after the initial schema.
188 "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
189 "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
190 "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
191 "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
192 "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
// #6: rebuild User with an explicit INTEGER PRIMARY KEY id column, copying
// each existing rowid into id.
193 `
194 CREATE TABLE UserNew (
195 id INTEGER PRIMARY KEY,
196 username VARCHAR(255) NOT NULL UNIQUE,
197 password VARCHAR(255),
198 admin INTEGER NOT NULL DEFAULT 0
199 );
200 INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
201 DROP TABLE User;
202 ALTER TABLE UserNew RENAME TO User;
203 `,
// #7: rebuild Network so that Network.user holds User.id; the JOIN maps the
// username previously stored in Network.user to the matching User.id.
// Because it is an inner JOIN, networks whose user matches no username are
// dropped.
204 `
205 CREATE TABLE NetworkNew (
206 id INTEGER PRIMARY KEY,
207 name VARCHAR(255),
208 user INTEGER NOT NULL,
209 addr VARCHAR(255) NOT NULL,
210 nick VARCHAR(255) NOT NULL,
211 username VARCHAR(255),
212 realname VARCHAR(255),
213 pass VARCHAR(255),
214 connect_commands VARCHAR(1023),
215 sasl_mechanism VARCHAR(255),
216 sasl_plain_username VARCHAR(255),
217 sasl_plain_password VARCHAR(255),
218 sasl_external_cert BLOB DEFAULT NULL,
219 sasl_external_key BLOB DEFAULT NULL,
220 FOREIGN KEY(user) REFERENCES User(id),
221 UNIQUE(user, addr, nick),
222 UNIQUE(user, name)
223 );
224 INSERT INTO NetworkNew
225 SELECT Network.id, name, User.id as user, addr, nick,
226 Network.username, realname, pass, connect_commands,
227 sasl_mechanism, sasl_plain_username, sasl_plain_password,
228 sasl_external_cert, sasl_external_key
229 FROM Network
230 JOIN User ON Network.user = User.username;
231 DROP TABLE Network;
232 ALTER TABLE NetworkNew RENAME TO Network;
233 `,
// #8: add the per-channel relay/reattach/detach filter columns.
234 `
235 ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
236 ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
237 ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
238 ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
239 `,
// #9: add the DeliveryReceipt table.
240 `
241 CREATE TABLE DeliveryReceipt (
242 id INTEGER PRIMARY KEY,
243 network INTEGER NOT NULL,
244 target VARCHAR(255) NOT NULL,
245 client VARCHAR(255),
246 internal_msgid VARCHAR(255) NOT NULL,
247 FOREIGN KEY(network) REFERENCES Network(id),
248 UNIQUE(network, target, client)
249 );
250 `,
// #10: store an internal message ID on the channel when it is detached.
251 "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
252}
253
254type DB struct {
255 lock sync.RWMutex
256 db *sql.DB
257}
258
259func OpenSQLDB(driver, source string) (*DB, error) {
260 sqlDB, err := sql.Open(driver, source)
261 if err != nil {
262 return nil, err
263 }
264
265 db := &DB{db: sqlDB}
266 if err := db.upgrade(); err != nil {
267 return nil, err
268 }
269
270 return db, nil
271}
272
273func (db *DB) Close() error {
274 db.lock.Lock()
275 defer db.lock.Unlock()
276 return db.db.Close()
277}
278
279func (db *DB) upgrade() error {
280 db.lock.Lock()
281 defer db.lock.Unlock()
282
283 var version int
284 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
285 return fmt.Errorf("failed to query schema version: %v", err)
286 }
287
288 if version == len(migrations) {
289 return nil
290 } else if version > len(migrations) {
291 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
292 }
293
294 tx, err := db.db.Begin()
295 if err != nil {
296 return err
297 }
298 defer tx.Rollback()
299
300 if version == 0 {
301 if _, err := tx.Exec(schema); err != nil {
302 return fmt.Errorf("failed to initialize schema: %v", err)
303 }
304 } else {
305 for i := version; i < len(migrations); i++ {
306 if _, err := tx.Exec(migrations[i]); err != nil {
307 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
308 }
309 }
310 }
311
312 // For some reason prepared statements don't work here
313 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
314 if err != nil {
315 return fmt.Errorf("failed to bump schema version: %v", err)
316 }
317
318 return tx.Commit()
319}
320
// toNullString converts s for storage in an optional text column. The empty
// string becomes SQL NULL; any other value is stored as-is.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
327
328func (db *DB) ListUsers() ([]User, error) {
329 db.lock.RLock()
330 defer db.lock.RUnlock()
331
332 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
333 if err != nil {
334 return nil, err
335 }
336 defer rows.Close()
337
338 var users []User
339 for rows.Next() {
340 var user User
341 var password sql.NullString
342 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
343 return nil, err
344 }
345 user.Password = password.String
346 users = append(users, user)
347 }
348 if err := rows.Err(); err != nil {
349 return nil, err
350 }
351
352 return users, nil
353}
354
355func (db *DB) GetUser(username string) (*User, error) {
356 db.lock.RLock()
357 defer db.lock.RUnlock()
358
359 user := &User{Username: username}
360
361 var password sql.NullString
362 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
363 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
364 return nil, err
365 }
366 user.Password = password.String
367 return user, nil
368}
369
370func (db *DB) StoreUser(user *User) error {
371 db.lock.Lock()
372 defer db.lock.Unlock()
373
374 password := toNullString(user.Password)
375
376 var err error
377 if user.ID != 0 {
378 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
379 password, user.Admin, user.Username)
380 } else {
381 var res sql.Result
382 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
383 user.Username, password, user.Admin)
384 if err != nil {
385 return err
386 }
387 user.ID, err = res.LastInsertId()
388 }
389
390 return err
391}
392
393func (db *DB) DeleteUser(id int64) error {
394 db.lock.Lock()
395 defer db.lock.Unlock()
396
397 tx, err := db.db.Begin()
398 if err != nil {
399 return err
400 }
401 defer tx.Rollback()
402
403 _, err = tx.Exec(`DELETE FROM Channel
404 WHERE id IN (
405 SELECT Channel.id
406 FROM Channel
407 JOIN Network ON Channel.network = Network.id
408 WHERE Network.user = ?
409 )`, id)
410 if err != nil {
411 return err
412 }
413
414 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
415 if err != nil {
416 return err
417 }
418
419 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
420 if err != nil {
421 return err
422 }
423
424 return tx.Commit()
425}
426
427func (db *DB) ListNetworks(userID int64) ([]Network, error) {
428 db.lock.RLock()
429 defer db.lock.RUnlock()
430
431 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
432 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
433 sasl_external_cert, sasl_external_key
434 FROM Network
435 WHERE user = ?`,
436 userID)
437 if err != nil {
438 return nil, err
439 }
440 defer rows.Close()
441
442 var networks []Network
443 for rows.Next() {
444 var net Network
445 var name, username, realname, pass, connectCommands sql.NullString
446 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
447 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
448 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
449 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
450 if err != nil {
451 return nil, err
452 }
453 net.Name = name.String
454 net.Username = username.String
455 net.Realname = realname.String
456 net.Pass = pass.String
457 if connectCommands.Valid {
458 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
459 }
460 net.SASL.Mechanism = saslMechanism.String
461 net.SASL.Plain.Username = saslPlainUsername.String
462 net.SASL.Plain.Password = saslPlainPassword.String
463 networks = append(networks, net)
464 }
465 if err := rows.Err(); err != nil {
466 return nil, err
467 }
468
469 return networks, nil
470}
471
472func (db *DB) StoreNetwork(userID int64, network *Network) error {
473 db.lock.Lock()
474 defer db.lock.Unlock()
475
476 netName := toNullString(network.Name)
477 netUsername := toNullString(network.Username)
478 realname := toNullString(network.Realname)
479 pass := toNullString(network.Pass)
480 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
481
482 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
483 if network.SASL.Mechanism != "" {
484 saslMechanism = toNullString(network.SASL.Mechanism)
485 switch network.SASL.Mechanism {
486 case "PLAIN":
487 saslPlainUsername = toNullString(network.SASL.Plain.Username)
488 saslPlainPassword = toNullString(network.SASL.Plain.Password)
489 network.SASL.External.CertBlob = nil
490 network.SASL.External.PrivKeyBlob = nil
491 case "EXTERNAL":
492 // keep saslPlain* nil
493 default:
494 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
495 }
496 }
497
498 var err error
499 if network.ID != 0 {
500 _, err = db.db.Exec(`UPDATE Network
501 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
502 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
503 sasl_external_cert = ?, sasl_external_key = ?
504 WHERE id = ?`,
505 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
506 saslMechanism, saslPlainUsername, saslPlainPassword,
507 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
508 network.ID)
509 } else {
510 var res sql.Result
511 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
512 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
513 sasl_plain_password, sasl_external_cert, sasl_external_key)
514 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
515 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
516 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
517 network.SASL.External.PrivKeyBlob)
518 if err != nil {
519 return err
520 }
521 network.ID, err = res.LastInsertId()
522 }
523 return err
524}
525
526func (db *DB) DeleteNetwork(id int64) error {
527 db.lock.Lock()
528 defer db.lock.Unlock()
529
530 tx, err := db.db.Begin()
531 if err != nil {
532 return err
533 }
534 defer tx.Rollback()
535
536 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
537 if err != nil {
538 return err
539 }
540
541 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
542 if err != nil {
543 return err
544 }
545
546 return tx.Commit()
547}
548
549func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
550 db.lock.RLock()
551 defer db.lock.RUnlock()
552
553 rows, err := db.db.Query(`SELECT
554 id, name, key, detached, detached_internal_msgid,
555 relay_detached, reattach_on, detach_after, detach_on
556 FROM Channel
557 WHERE network = ?`, networkID)
558 if err != nil {
559 return nil, err
560 }
561 defer rows.Close()
562
563 var channels []Channel
564 for rows.Next() {
565 var ch Channel
566 var key, detachedInternalMsgID sql.NullString
567 var detachAfter int64
568 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
569 return nil, err
570 }
571 ch.Key = key.String
572 ch.DetachedInternalMsgID = detachedInternalMsgID.String
573 ch.DetachAfter = time.Duration(detachAfter) * time.Second
574 channels = append(channels, ch)
575 }
576 if err := rows.Err(); err != nil {
577 return nil, err
578 }
579
580 return channels, nil
581}
582
583func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
584 db.lock.Lock()
585 defer db.lock.Unlock()
586
587 key := toNullString(ch.Key)
588 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
589
590 var err error
591 if ch.ID != 0 {
592 _, err = db.db.Exec(`UPDATE Channel
593 SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
594 WHERE id = ?`,
595 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
596 } else {
597 var res sql.Result
598 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
599 VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
600 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
601 if err != nil {
602 return err
603 }
604 ch.ID, err = res.LastInsertId()
605 }
606 return err
607}
608
609func (db *DB) DeleteChannel(id int64) error {
610 db.lock.Lock()
611 defer db.lock.Unlock()
612
613 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
614 return err
615}
616
617func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
618 db.lock.RLock()
619 defer db.lock.RUnlock()
620
621 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
622 FROM DeliveryReceipt
623 WHERE network = ?`, networkID)
624 if err != nil {
625 return nil, err
626 }
627 defer rows.Close()
628
629 var receipts []DeliveryReceipt
630 for rows.Next() {
631 var rcpt DeliveryReceipt
632 var client sql.NullString
633 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
634 return nil, err
635 }
636 rcpt.Client = client.String
637 receipts = append(receipts, rcpt)
638 }
639 if err := rows.Err(); err != nil {
640 return nil, err
641 }
642
643 return receipts, nil
644}
645
646func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
647 db.lock.Lock()
648 defer db.lock.Unlock()
649
650 tx, err := db.db.Begin()
651 if err != nil {
652 return err
653 }
654 defer tx.Rollback()
655
656 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
657 networkID, toNullString(client))
658 if err != nil {
659 return err
660 }
661
662 for i := range receipts {
663 rcpt := &receipts[i]
664
665 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
666 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
667 if err != nil {
668 return err
669 }
670 rcpt.ID, err = res.LastInsertId()
671 if err != nil {
672 return err
673 }
674 }
675
676 return tx.Commit()
677}
Note: See TracBrowser for help on using the repository browser.