Changeset 531 in code for trunk/db.go


Ignore:
Timestamp:
May 25, 2021, 2:35:39 PM (4 years ago)
Author:
sir
Message:

db: refactor into interface

This refactors the SQLite-specific bits into db_sqlite.go. A future
patch will add PostgreSQL support.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/db.go

    r509 r531  
    22
    33import (
    4         "database/sql"
    54        "fmt"
    6         "math"
    75        "net/url"
    86        "strings"
    9         "sync"
    107        "time"
     8)
    119
    12         _ "github.com/mattn/go-sqlite3"
    13 )
     10type Database interface {
     11        Close() error
     12
     13        ListUsers() ([]User, error)
     14        GetUser(username string) (*User, error)
     15        StoreUser(user *User) error
     16        DeleteUser(id int64) error
     17
     18        ListNetworks(userID int64) ([]Network, error)
     19        StoreNetwork(userID int64, network *Network) error
     20        DeleteNetwork(id int64) error
     21        ListChannels(networkID int64) ([]Channel, error)
     22        StoreChannel(networKID int64, ch *Channel) error
     23        DeleteChannel(id int64) error
     24
     25        ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error)
     26        StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error
     27}
    1428
    1529type User struct {
     
    129143        InternalMsgID string
    130144}
    131 
    132 const schema = `
    133 CREATE TABLE User (
    134         id INTEGER PRIMARY KEY,
    135         username VARCHAR(255) NOT NULL UNIQUE,
    136         password VARCHAR(255),
    137         admin INTEGER NOT NULL DEFAULT 0
    138 );
    139 
    140 CREATE TABLE Network (
    141         id INTEGER PRIMARY KEY,
    142         name VARCHAR(255),
    143         user INTEGER NOT NULL,
    144         addr VARCHAR(255) NOT NULL,
    145         nick VARCHAR(255) NOT NULL,
    146         username VARCHAR(255),
    147         realname VARCHAR(255),
    148         pass VARCHAR(255),
    149         connect_commands VARCHAR(1023),
    150         sasl_mechanism VARCHAR(255),
    151         sasl_plain_username VARCHAR(255),
    152         sasl_plain_password VARCHAR(255),
    153         sasl_external_cert BLOB DEFAULT NULL,
    154         sasl_external_key BLOB DEFAULT NULL,
    155         FOREIGN KEY(user) REFERENCES User(id),
    156         UNIQUE(user, addr, nick),
    157         UNIQUE(user, name)
    158 );
    159 
    160 CREATE TABLE Channel (
    161         id INTEGER PRIMARY KEY,
    162         network INTEGER NOT NULL,
    163         name VARCHAR(255) NOT NULL,
    164         key VARCHAR(255),
    165         detached INTEGER NOT NULL DEFAULT 0,
    166         detached_internal_msgid VARCHAR(255),
    167         relay_detached INTEGER NOT NULL DEFAULT 0,
    168         reattach_on INTEGER NOT NULL DEFAULT 0,
    169         detach_after INTEGER NOT NULL DEFAULT 0,
    170         detach_on INTEGER NOT NULL DEFAULT 0,
    171         FOREIGN KEY(network) REFERENCES Network(id),
    172         UNIQUE(network, name)
    173 );
    174 
    175 CREATE TABLE DeliveryReceipt (
    176         id INTEGER PRIMARY KEY,
    177         network INTEGER NOT NULL,
    178         target VARCHAR(255) NOT NULL,
    179         client VARCHAR(255),
    180         internal_msgid VARCHAR(255) NOT NULL,
    181         FOREIGN KEY(network) REFERENCES Network(id),
    182         UNIQUE(network, target, client)
    183 );
    184 `
    185 
    186 var migrations = []string{
    187         "", // migration #0 is reserved for schema initialization
    188         "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
    189         "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
    190         "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
    191         "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
    192         "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
    193         `
    194                 CREATE TABLE UserNew (
    195                         id INTEGER PRIMARY KEY,
    196                         username VARCHAR(255) NOT NULL UNIQUE,
    197                         password VARCHAR(255),
    198                         admin INTEGER NOT NULL DEFAULT 0
    199                 );
    200                 INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
    201                 DROP TABLE User;
    202                 ALTER TABLE UserNew RENAME TO User;
    203         `,
    204         `
    205                 CREATE TABLE NetworkNew (
    206                         id INTEGER PRIMARY KEY,
    207                         name VARCHAR(255),
    208                         user INTEGER NOT NULL,
    209                         addr VARCHAR(255) NOT NULL,
    210                         nick VARCHAR(255) NOT NULL,
    211                         username VARCHAR(255),
    212                         realname VARCHAR(255),
    213                         pass VARCHAR(255),
    214                         connect_commands VARCHAR(1023),
    215                         sasl_mechanism VARCHAR(255),
    216                         sasl_plain_username VARCHAR(255),
    217                         sasl_plain_password VARCHAR(255),
    218                         sasl_external_cert BLOB DEFAULT NULL,
    219                         sasl_external_key BLOB DEFAULT NULL,
    220                         FOREIGN KEY(user) REFERENCES User(id),
    221                         UNIQUE(user, addr, nick),
    222                         UNIQUE(user, name)
    223                 );
    224                 INSERT INTO NetworkNew
    225                         SELECT Network.id, name, User.id as user, addr, nick,
    226                                 Network.username, realname, pass, connect_commands,
    227                                 sasl_mechanism, sasl_plain_username, sasl_plain_password,
    228                                 sasl_external_cert, sasl_external_key
    229                         FROM Network
    230                         JOIN User ON Network.user = User.username;
    231                 DROP TABLE Network;
    232                 ALTER TABLE NetworkNew RENAME TO Network;
    233         `,
    234         `
    235                 ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
    236                 ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
    237                 ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
    238                 ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
    239         `,
    240         `
    241                 CREATE TABLE DeliveryReceipt (
    242                         id INTEGER PRIMARY KEY,
    243                         network INTEGER NOT NULL,
    244                         target VARCHAR(255) NOT NULL,
    245                         client VARCHAR(255),
    246                         internal_msgid VARCHAR(255) NOT NULL,
    247                         FOREIGN KEY(network) REFERENCES Network(id),
    248                         UNIQUE(network, target, client)
    249                 );
    250         `,
    251         "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
    252 }
    253 
    254 type DB struct {
    255         lock sync.RWMutex
    256         db   *sql.DB
    257 }
    258 
    259 func OpenSQLDB(driver, source string) (*DB, error) {
    260         sqlDB, err := sql.Open(driver, source)
    261         if err != nil {
    262                 return nil, err
    263         }
    264 
    265         db := &DB{db: sqlDB}
    266         if err := db.upgrade(); err != nil {
    267                 return nil, err
    268         }
    269 
    270         return db, nil
    271 }
    272 
    273 func (db *DB) Close() error {
    274         db.lock.Lock()
    275         defer db.lock.Unlock()
    276         return db.db.Close()
    277 }
    278 
    279 func (db *DB) upgrade() error {
    280         db.lock.Lock()
    281         defer db.lock.Unlock()
    282 
    283         var version int
    284         if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
    285                 return fmt.Errorf("failed to query schema version: %v", err)
    286         }
    287 
    288         if version == len(migrations) {
    289                 return nil
    290         } else if version > len(migrations) {
    291                 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
    292         }
    293 
    294         tx, err := db.db.Begin()
    295         if err != nil {
    296                 return err
    297         }
    298         defer tx.Rollback()
    299 
    300         if version == 0 {
    301                 if _, err := tx.Exec(schema); err != nil {
    302                         return fmt.Errorf("failed to initialize schema: %v", err)
    303                 }
    304         } else {
    305                 for i := version; i < len(migrations); i++ {
    306                         if _, err := tx.Exec(migrations[i]); err != nil {
    307                                 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
    308                         }
    309                 }
    310         }
    311 
    312         // For some reason prepared statements don't work here
    313         _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
    314         if err != nil {
    315                 return fmt.Errorf("failed to bump schema version: %v", err)
    316         }
    317 
    318         return tx.Commit()
    319 }
    320 
    321 func toNullString(s string) sql.NullString {
    322         return sql.NullString{
    323                 String: s,
    324                 Valid:  s != "",
    325         }
    326 }
    327 
    328 func (db *DB) ListUsers() ([]User, error) {
    329         db.lock.RLock()
    330         defer db.lock.RUnlock()
    331 
    332         rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
    333         if err != nil {
    334                 return nil, err
    335         }
    336         defer rows.Close()
    337 
    338         var users []User
    339         for rows.Next() {
    340                 var user User
    341                 var password sql.NullString
    342                 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
    343                         return nil, err
    344                 }
    345                 user.Password = password.String
    346                 users = append(users, user)
    347         }
    348         if err := rows.Err(); err != nil {
    349                 return nil, err
    350         }
    351 
    352         return users, nil
    353 }
    354 
    355 func (db *DB) GetUser(username string) (*User, error) {
    356         db.lock.RLock()
    357         defer db.lock.RUnlock()
    358 
    359         user := &User{Username: username}
    360 
    361         var password sql.NullString
    362         row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
    363         if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
    364                 return nil, err
    365         }
    366         user.Password = password.String
    367         return user, nil
    368 }
    369 
    370 func (db *DB) StoreUser(user *User) error {
    371         db.lock.Lock()
    372         defer db.lock.Unlock()
    373 
    374         password := toNullString(user.Password)
    375 
    376         var err error
    377         if user.ID != 0 {
    378                 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
    379                         password, user.Admin, user.Username)
    380         } else {
    381                 var res sql.Result
    382                 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
    383                         user.Username, password, user.Admin)
    384                 if err != nil {
    385                         return err
    386                 }
    387                 user.ID, err = res.LastInsertId()
    388         }
    389 
    390         return err
    391 }
    392 
    393 func (db *DB) DeleteUser(id int64) error {
    394         db.lock.Lock()
    395         defer db.lock.Unlock()
    396 
    397         tx, err := db.db.Begin()
    398         if err != nil {
    399                 return err
    400         }
    401         defer tx.Rollback()
    402 
    403         _, err = tx.Exec(`DELETE FROM Channel
    404                 WHERE id IN (
    405                         SELECT Channel.id
    406                         FROM Channel
    407                         JOIN Network ON Channel.network = Network.id
    408                         WHERE Network.user = ?
    409                 )`, id)
    410         if err != nil {
    411                 return err
    412         }
    413 
    414         _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
    415         if err != nil {
    416                 return err
    417         }
    418 
    419         _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
    420         if err != nil {
    421                 return err
    422         }
    423 
    424         return tx.Commit()
    425 }
    426 
    427 func (db *DB) ListNetworks(userID int64) ([]Network, error) {
    428         db.lock.RLock()
    429         defer db.lock.RUnlock()
    430 
    431         rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
    432                         connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
    433                         sasl_external_cert, sasl_external_key
    434                 FROM Network
    435                 WHERE user = ?`,
    436                 userID)
    437         if err != nil {
    438                 return nil, err
    439         }
    440         defer rows.Close()
    441 
    442         var networks []Network
    443         for rows.Next() {
    444                 var net Network
    445                 var name, username, realname, pass, connectCommands sql.NullString
    446                 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
    447                 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
    448                         &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
    449                         &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
    450                 if err != nil {
    451                         return nil, err
    452                 }
    453                 net.Name = name.String
    454                 net.Username = username.String
    455                 net.Realname = realname.String
    456                 net.Pass = pass.String
    457                 if connectCommands.Valid {
    458                         net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
    459                 }
    460                 net.SASL.Mechanism = saslMechanism.String
    461                 net.SASL.Plain.Username = saslPlainUsername.String
    462                 net.SASL.Plain.Password = saslPlainPassword.String
    463                 networks = append(networks, net)
    464         }
    465         if err := rows.Err(); err != nil {
    466                 return nil, err
    467         }
    468 
    469         return networks, nil
    470 }
    471 
    472 func (db *DB) StoreNetwork(userID int64, network *Network) error {
    473         db.lock.Lock()
    474         defer db.lock.Unlock()
    475 
    476         netName := toNullString(network.Name)
    477         netUsername := toNullString(network.Username)
    478         realname := toNullString(network.Realname)
    479         pass := toNullString(network.Pass)
    480         connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
    481 
    482         var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
    483         if network.SASL.Mechanism != "" {
    484                 saslMechanism = toNullString(network.SASL.Mechanism)
    485                 switch network.SASL.Mechanism {
    486                 case "PLAIN":
    487                         saslPlainUsername = toNullString(network.SASL.Plain.Username)
    488                         saslPlainPassword = toNullString(network.SASL.Plain.Password)
    489                         network.SASL.External.CertBlob = nil
    490                         network.SASL.External.PrivKeyBlob = nil
    491                 case "EXTERNAL":
    492                         // keep saslPlain* nil
    493                 default:
    494                         return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
    495                 }
    496         }
    497 
    498         var err error
    499         if network.ID != 0 {
    500                 _, err = db.db.Exec(`UPDATE Network
    501                         SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
    502                                 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
    503                                 sasl_external_cert = ?, sasl_external_key = ?
    504                         WHERE id = ?`,
    505                         netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
    506                         saslMechanism, saslPlainUsername, saslPlainPassword,
    507                         network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
    508                         network.ID)
    509         } else {
    510                 var res sql.Result
    511                 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
    512                                 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
    513                                 sasl_plain_password, sasl_external_cert, sasl_external_key)
    514                         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
    515                         userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
    516                         saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
    517                         network.SASL.External.PrivKeyBlob)
    518                 if err != nil {
    519                         return err
    520                 }
    521                 network.ID, err = res.LastInsertId()
    522         }
    523         return err
    524 }
    525 
    526 func (db *DB) DeleteNetwork(id int64) error {
    527         db.lock.Lock()
    528         defer db.lock.Unlock()
    529 
    530         tx, err := db.db.Begin()
    531         if err != nil {
    532                 return err
    533         }
    534         defer tx.Rollback()
    535 
    536         _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
    537         if err != nil {
    538                 return err
    539         }
    540 
    541         _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
    542         if err != nil {
    543                 return err
    544         }
    545 
    546         return tx.Commit()
    547 }
    548 
    549 func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
    550         db.lock.RLock()
    551         defer db.lock.RUnlock()
    552 
    553         rows, err := db.db.Query(`SELECT
    554                         id, name, key, detached, detached_internal_msgid,
    555                         relay_detached, reattach_on, detach_after, detach_on
    556                 FROM Channel
    557                 WHERE network = ?`, networkID)
    558         if err != nil {
    559                 return nil, err
    560         }
    561         defer rows.Close()
    562 
    563         var channels []Channel
    564         for rows.Next() {
    565                 var ch Channel
    566                 var key, detachedInternalMsgID sql.NullString
    567                 var detachAfter int64
    568                 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
    569                         return nil, err
    570                 }
    571                 ch.Key = key.String
    572                 ch.DetachedInternalMsgID = detachedInternalMsgID.String
    573                 ch.DetachAfter = time.Duration(detachAfter) * time.Second
    574                 channels = append(channels, ch)
    575         }
    576         if err := rows.Err(); err != nil {
    577                 return nil, err
    578         }
    579 
    580         return channels, nil
    581 }
    582 
    583 func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
    584         db.lock.Lock()
    585         defer db.lock.Unlock()
    586 
    587         key := toNullString(ch.Key)
    588         detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
    589 
    590         var err error
    591         if ch.ID != 0 {
    592                 _, err = db.db.Exec(`UPDATE Channel
    593                         SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
    594                         WHERE id = ?`,
    595                         networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
    596         } else {
    597                 var res sql.Result
    598                 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
    599                         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
    600                         networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
    601                 if err != nil {
    602                         return err
    603                 }
    604                 ch.ID, err = res.LastInsertId()
    605         }
    606         return err
    607 }
    608 
    609 func (db *DB) DeleteChannel(id int64) error {
    610         db.lock.Lock()
    611         defer db.lock.Unlock()
    612 
    613         _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
    614         return err
    615 }
    616 
    617 func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
    618         db.lock.RLock()
    619         defer db.lock.RUnlock()
    620 
    621         rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
    622                 FROM DeliveryReceipt
    623                 WHERE network = ?`, networkID)
    624         if err != nil {
    625                 return nil, err
    626         }
    627         defer rows.Close()
    628 
    629         var receipts []DeliveryReceipt
    630         for rows.Next() {
    631                 var rcpt DeliveryReceipt
    632                 var client sql.NullString
    633                 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
    634                         return nil, err
    635                 }
    636                 rcpt.Client = client.String
    637                 receipts = append(receipts, rcpt)
    638         }
    639         if err := rows.Err(); err != nil {
    640                 return nil, err
    641         }
    642 
    643         return receipts, nil
    644 }
    645 
    646 func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
    647         db.lock.Lock()
    648         defer db.lock.Unlock()
    649 
    650         tx, err := db.db.Begin()
    651         if err != nil {
    652                 return err
    653         }
    654         defer tx.Rollback()
    655 
    656         _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
    657                 networkID, toNullString(client))
    658         if err != nil {
    659                 return err
    660         }
    661 
    662         for i := range receipts {
    663                 rcpt := &receipts[i]
    664 
    665                 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
    666                         networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
    667                 if err != nil {
    668                         return err
    669                 }
    670                 rcpt.ID, err = res.LastInsertId()
    671                 if err != nil {
    672                         return err
    673                 }
    674         }
    675 
    676         return tx.Commit()
    677 }
Note: See TracChangeset for help on using the changeset viewer.