Changeset 531 in code for trunk/db.go
- Timestamp:
- May 25, 2021, 2:35:39 PM (4 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/db.go
r509 r531 2 2 3 3 import ( 4 "database/sql"5 4 "fmt" 6 "math"7 5 "net/url" 8 6 "strings" 9 "sync"10 7 "time" 8 ) 11 9 12 _ "github.com/mattn/go-sqlite3" 13 ) 10 type Database interface { 11 Close() error 12 13 ListUsers() ([]User, error) 14 GetUser(username string) (*User, error) 15 StoreUser(user *User) error 16 DeleteUser(id int64) error 17 18 ListNetworks(userID int64) ([]Network, error) 19 StoreNetwork(userID int64, network *Network) error 20 DeleteNetwork(id int64) error 21 ListChannels(networkID int64) ([]Channel, error) 22 StoreChannel(networKID int64, ch *Channel) error 23 DeleteChannel(id int64) error 24 25 ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) 26 StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error 27 } 14 28 15 29 type User struct { … … 129 143 InternalMsgID string 130 144 } 131 132 const schema = `133 CREATE TABLE User (134 id INTEGER PRIMARY KEY,135 username VARCHAR(255) NOT NULL UNIQUE,136 password VARCHAR(255),137 admin INTEGER NOT NULL DEFAULT 0138 );139 140 CREATE TABLE Network (141 id INTEGER PRIMARY KEY,142 name VARCHAR(255),143 user INTEGER NOT NULL,144 addr VARCHAR(255) NOT NULL,145 nick VARCHAR(255) NOT NULL,146 username VARCHAR(255),147 realname VARCHAR(255),148 pass VARCHAR(255),149 connect_commands VARCHAR(1023),150 sasl_mechanism VARCHAR(255),151 sasl_plain_username VARCHAR(255),152 sasl_plain_password VARCHAR(255),153 sasl_external_cert BLOB DEFAULT NULL,154 sasl_external_key BLOB DEFAULT NULL,155 FOREIGN KEY(user) REFERENCES User(id),156 UNIQUE(user, addr, nick),157 UNIQUE(user, name)158 );159 160 CREATE TABLE Channel (161 id INTEGER PRIMARY KEY,162 network INTEGER NOT NULL,163 name VARCHAR(255) NOT NULL,164 key VARCHAR(255),165 detached INTEGER NOT NULL DEFAULT 0,166 detached_internal_msgid VARCHAR(255),167 relay_detached INTEGER NOT NULL DEFAULT 0,168 reattach_on INTEGER NOT NULL DEFAULT 0,169 detach_after INTEGER NOT NULL DEFAULT 0,170 detach_on INTEGER NOT NULL DEFAULT 0,171 FOREIGN KEY(network) REFERENCES Network(id),172 UNIQUE(network, name)173 );174 175 CREATE TABLE DeliveryReceipt (176 id INTEGER PRIMARY KEY,177 network INTEGER NOT NULL,178 target VARCHAR(255) NOT NULL,179 client VARCHAR(255),180 internal_msgid VARCHAR(255) NOT NULL,181 FOREIGN KEY(network) REFERENCES Network(id),182 UNIQUE(network, target, client)183 );184 `185 186 var migrations = []string{187 "", // migration #0 is reserved for schema initialization188 "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",189 "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",190 "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",191 "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",192 "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",193 `194 CREATE TABLE UserNew (195 id INTEGER PRIMARY KEY,196 username VARCHAR(255) NOT NULL UNIQUE,197 password VARCHAR(255),198 admin INTEGER NOT NULL DEFAULT 0199 );200 INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;201 DROP TABLE User;202 ALTER TABLE UserNew RENAME TO User;203 `,204 `205 CREATE TABLE NetworkNew (206 id INTEGER PRIMARY KEY,207 name VARCHAR(255),208 user INTEGER NOT NULL,209 addr VARCHAR(255) NOT NULL,210 nick VARCHAR(255) NOT NULL,211 username VARCHAR(255),212 realname VARCHAR(255),213 pass VARCHAR(255),214 connect_commands VARCHAR(1023),215 sasl_mechanism VARCHAR(255),216 sasl_plain_username VARCHAR(255),217 sasl_plain_password VARCHAR(255),218 sasl_external_cert BLOB DEFAULT NULL,219 sasl_external_key BLOB DEFAULT NULL,220 FOREIGN KEY(user) REFERENCES User(id),221 UNIQUE(user, addr, nick),222 UNIQUE(user, name)223 );224 INSERT INTO NetworkNew225 SELECT Network.id, name, User.id as user, addr, nick,226 Network.username, realname, pass, connect_commands,227 sasl_mechanism, sasl_plain_username, sasl_plain_password,228 sasl_external_cert, sasl_external_key229 FROM Network230 JOIN User ON Network.user = User.username;231 DROP TABLE Network;232 ALTER TABLE NetworkNew RENAME TO Network;233 `,234 `235 ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;236 ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;237 ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;238 ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;239 `,240 `241 CREATE TABLE DeliveryReceipt (242 id INTEGER PRIMARY KEY,243 network INTEGER NOT NULL,244 target VARCHAR(255) NOT NULL,245 client VARCHAR(255),246 internal_msgid VARCHAR(255) NOT NULL,247 FOREIGN KEY(network) REFERENCES Network(id),248 UNIQUE(network, target, client)249 );250 `,251 "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",252 }253 254 type DB struct {255 lock sync.RWMutex256 db *sql.DB257 }258 259 func OpenSQLDB(driver, source string) (*DB, error) {260 sqlDB, err := sql.Open(driver, source)261 if err != nil {262 return nil, err263 }264 265 db := &DB{db: sqlDB}266 if err := db.upgrade(); err != nil {267 return nil, err268 }269 270 return db, nil271 }272 273 func (db *DB) Close() error {274 db.lock.Lock()275 defer db.lock.Unlock()276 return db.db.Close()277 }278 279 func (db *DB) upgrade() error {280 db.lock.Lock()281 defer db.lock.Unlock()282 283 var version int284 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {285 return fmt.Errorf("failed to query schema version: %v", err)286 }287 288 if version == len(migrations) {289 return nil290 } else if version > len(migrations) {291 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)292 }293 294 tx, err := db.db.Begin()295 if err != nil {296 return err297 }298 defer tx.Rollback()299 300 if version == 0 {301 if _, err := tx.Exec(schema); err != nil {302 return fmt.Errorf("failed to initialize schema: %v", err)303 }304 } else {305 for i := version; i < len(migrations); i++ {306 if _, err := tx.Exec(migrations[i]); err != nil {307 return fmt.Errorf("failed to execute migration #%v: %v", i, err)308 }309 }310 }311 312 // For some reason prepared statements don't work here313 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))314 if err != nil {315 return fmt.Errorf("failed to bump schema version: %v", err)316 }317 318 return tx.Commit()319 }320 321 func toNullString(s string) sql.NullString {322 return sql.NullString{323 String: s,324 Valid: s != "",325 }326 }327 328 func (db *DB) ListUsers() ([]User, error) {329 db.lock.RLock()330 defer db.lock.RUnlock()331 332 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")333 if err != nil {334 return nil, err335 }336 defer rows.Close()337 338 var users []User339 for rows.Next() {340 var user User341 var password sql.NullString342 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {343 return nil, err344 }345 user.Password = password.String346 users = append(users, user)347 }348 if err := rows.Err(); err != nil {349 return nil, err350 }351 352 return users, nil353 }354 355 func (db *DB) GetUser(username string) (*User, error) {356 db.lock.RLock()357 defer db.lock.RUnlock()358 359 user := &User{Username: username}360 361 var password sql.NullString362 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)363 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {364 return nil, err365 }366 user.Password = password.String367 return user, nil368 }369 370 func (db *DB) StoreUser(user *User) error {371 db.lock.Lock()372 defer db.lock.Unlock()373 374 password := toNullString(user.Password)375 376 var err error377 if user.ID != 0 {378 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",379 password, user.Admin, user.Username)380 } else {381 var res sql.Result382 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",383 user.Username, password, user.Admin)384 if err != nil {385 return err386 }387 user.ID, err = res.LastInsertId()388 }389 390 return err391 }392 393 func (db *DB) DeleteUser(id int64) error {394 db.lock.Lock()395 defer db.lock.Unlock()396 397 tx, err := db.db.Begin()398 if err != nil {399 return err400 }401 defer tx.Rollback()402 403 _, err = tx.Exec(`DELETE FROM Channel404 WHERE id IN (405 SELECT Channel.id406 FROM Channel407 JOIN Network ON Channel.network = Network.id408 WHERE Network.user = ?409 )`, id)410 if err != nil {411 return err412 }413 414 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)415 if err != nil {416 return err417 }418 419 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)420 if err != nil {421 return err422 }423 424 return tx.Commit()425 }426 427 func (db *DB) ListNetworks(userID int64) ([]Network, error) {428 db.lock.RLock()429 defer db.lock.RUnlock()430 431 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,432 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,433 sasl_external_cert, sasl_external_key434 FROM Network435 WHERE user = ?`,436 userID)437 if err != nil {438 return nil, err439 }440 defer rows.Close()441 442 var networks []Network443 for rows.Next() {444 var net Network445 var name, username, realname, pass, connectCommands sql.NullString446 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString447 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,448 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,449 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)450 if err != nil {451 return nil, err452 }453 net.Name = name.String454 net.Username = username.String455 net.Realname = realname.String456 net.Pass = pass.String457 if connectCommands.Valid {458 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")459 }460 net.SASL.Mechanism = saslMechanism.String461 net.SASL.Plain.Username = saslPlainUsername.String462 net.SASL.Plain.Password = saslPlainPassword.String463 networks = append(networks, net)464 }465 if err := rows.Err(); err != nil {466 return nil, err467 }468 469 return networks, nil470 }471 472 func (db *DB) StoreNetwork(userID int64, network *Network) error {473 db.lock.Lock()474 defer db.lock.Unlock()475 476 netName := toNullString(network.Name)477 netUsername := toNullString(network.Username)478 realname := toNullString(network.Realname)479 pass := toNullString(network.Pass)480 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))481 482 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString483 if network.SASL.Mechanism != "" {484 saslMechanism = toNullString(network.SASL.Mechanism)485 switch network.SASL.Mechanism {486 case "PLAIN":487 saslPlainUsername = toNullString(network.SASL.Plain.Username)488 saslPlainPassword = toNullString(network.SASL.Plain.Password)489 network.SASL.External.CertBlob = nil490 network.SASL.External.PrivKeyBlob = nil491 case "EXTERNAL":492 // keep saslPlain* nil493 default:494 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)495 }496 }497 498 var err error499 if network.ID != 0 {500 _, err = db.db.Exec(`UPDATE Network501 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,502 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,503 sasl_external_cert = ?, sasl_external_key = ?504 WHERE id = ?`,505 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,506 saslMechanism, saslPlainUsername, saslPlainPassword,507 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,508 network.ID)509 } else {510 var res sql.Result511 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,512 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,513 sasl_plain_password, sasl_external_cert, sasl_external_key)514 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,515 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,516 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,517 network.SASL.External.PrivKeyBlob)518 if err != nil {519 return err520 }521 network.ID, err = res.LastInsertId()522 }523 return err524 }525 526 func (db *DB) DeleteNetwork(id int64) error {527 db.lock.Lock()528 defer db.lock.Unlock()529 530 tx, err := db.db.Begin()531 if err != nil {532 return err533 }534 defer tx.Rollback()535 536 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)537 if err != nil {538 return err539 }540 541 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)542 if err != nil {543 return err544 }545 546 return tx.Commit()547 }548 549 func (db *DB) ListChannels(networkID int64) ([]Channel, error) {550 db.lock.RLock()551 defer db.lock.RUnlock()552 553 rows, err := db.db.Query(`SELECT554 id, name, key, detached, detached_internal_msgid,555 relay_detached, reattach_on, detach_after, detach_on556 FROM Channel557 WHERE network = ?`, networkID)558 if err != nil {559 return nil, err560 }561 defer rows.Close()562 563 var channels []Channel564 for rows.Next() {565 var ch Channel566 var key, detachedInternalMsgID sql.NullString567 var detachAfter int64568 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {569 return nil, err570 }571 ch.Key = key.String572 ch.DetachedInternalMsgID = detachedInternalMsgID.String573 ch.DetachAfter = time.Duration(detachAfter) * time.Second574 channels = append(channels, ch)575 }576 if err := rows.Err(); err != nil {577 return nil, err578 }579 580 return channels, nil581 }582 583 func (db *DB) StoreChannel(networkID int64, ch *Channel) error {584 db.lock.Lock()585 defer db.lock.Unlock()586 587 key := toNullString(ch.Key)588 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))589 590 var err error591 if ch.ID != 0 {592 _, err = db.db.Exec(`UPDATE Channel593 SET network = ?, name = ?, key = ?, detached = ?, detached_internal_msgid = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?594 WHERE id = ?`,595 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)596 } else {597 var res sql.Result598 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)599 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,600 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)601 if err != nil {602 return err603 }604 ch.ID, err = res.LastInsertId()605 }606 return err607 }608 609 func (db *DB) DeleteChannel(id int64) error {610 db.lock.Lock()611 defer db.lock.Unlock()612 613 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)614 return err615 }616 617 func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {618 db.lock.RLock()619 defer db.lock.RUnlock()620 621 rows, err := db.db.Query(`SELECT id, target, client, internal_msgid622 FROM DeliveryReceipt623 WHERE network = ?`, networkID)624 if err != nil {625 return nil, err626 }627 defer rows.Close()628 629 var receipts []DeliveryReceipt630 for rows.Next() {631 var rcpt DeliveryReceipt632 var client sql.NullString633 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {634 return nil, err635 }636 rcpt.Client = client.String637 receipts = append(receipts, rcpt)638 }639 if err := rows.Err(); err != nil {640 return nil, err641 }642 643 return receipts, nil644 }645 646 func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {647 db.lock.Lock()648 defer db.lock.Unlock()649 650 tx, err := db.db.Begin()651 if err != nil {652 return err653 }654 defer tx.Rollback()655 656 _, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",657 networkID, toNullString(client))658 if err != nil {659 return err660 }661 662 for i := range receipts {663 rcpt := &receipts[i]664 665 res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",666 networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)667 if err != nil {668 return err669 }670 rcpt.ID, err = res.LastInsertId()671 if err != nil {672 return err673 }674 }675 676 return tx.Commit()677 }
Note:
See TracChangeset
for help on using the changeset viewer.