source: code/trunk/db.go@421

Last change on this file since 421 was 421, checked in by contact, 5 years ago

Switch DB API to user IDs

This commit changes the Network schema to use user IDs instead of
usernames. While at it, a new UNIQUE(user, name) constraint ensures
there is no conflict with custom network names.

Closes: https://todo.sr.ht/~emersion/soju/86
References: https://todo.sr.ht/~emersion/soju/29

File size: 12.1 KB
Rev Line
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is one row of the User table.
type User struct {
	ID       int64  // row ID; 0 means not stored yet (StoreUser then inserts)
	Username string // unique across all users
	Password string // hashed
	Admin    bool
}
18
// SASL is a Network's SASL configuration. Mechanism selects which of the
// credential sets below applies; StoreNetwork accepts "", "PLAIN" and
// "EXTERNAL" and rejects any other value.
type SASL struct {
	Mechanism string

	// Plain holds the username/password pair for the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication (EXTERNAL mechanism).
	External struct {
		// CertBlob is the X.509 certificate in DER form.
		CertBlob []byte
		// PrivKeyBlob is the PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
35
[77]36type Network struct {
[263]37 ID int64
38 Name string
39 Addr string
40 Nick string
41 Username string
42 Realname string
43 Pass string
44 ConnectCommands []string
45 SASL SASL
[77]46}
47
[149]48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is one row of the Channel table: a channel saved for a network.
type Channel struct {
	ID       int64  // row ID; 0 means not stored yet (StoreChannel then inserts)
	Name     string // unique per network
	Key      string // stored as NULL when empty
	Detached bool
}
61
// schema is the complete, current database schema. upgrade executes it as-is
// on a fresh database (user_version 0) and then records len(migrations) as
// the schema version. It must therefore always match the result of applying
// every entry of migrations to an older database.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
100
// migrations lists the schema migrations. For i >= 1, migrations[i] upgrades a
// database whose user_version is i to version i+1. upgrade applies
// migrations[version:] inside a single transaction and then sets user_version
// to len(migrations).
//
// Because user_version indexes into this slice, new migrations must only ever
// be appended, and schema must be updated to match their combined result.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: add Network.connect_commands.
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2: add Channel.detached.
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3 and #4: add the SASL EXTERNAL certificate and private key columns.
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	// #5: add User.admin.
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: rebuild User with an explicit "id INTEGER PRIMARY KEY" column,
	// copying each existing row's rowid into id.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// #7: rebuild Network so that its user column holds User.id instead of
	// the username, and add the UNIQUE(user, name) constraint.
	// NOTE(review): the inner JOIN drops every Network row whose user value
	// matches no User.username; Channel rows that reference such networks are
	// left behind.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
}
150
// DB is soju's persistent store, backed by a *sql.DB.
//
// Every method goes through lock. Read-only queries take the read lock.
// Writes, Close and the schema upgrade take the write lock, so a write never
// overlaps any other operation on the same DB.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
155
156func OpenSQLDB(driver, source string) (*DB, error) {
[255]157 sqlDB, err := sql.Open(driver, source)
[77]158 if err != nil {
159 return nil, err
160 }
[255]161
162 db := &DB{db: sqlDB}
163 if err := db.upgrade(); err != nil {
164 return nil, err
165 }
166
167 return db, nil
[77]168}
169
170func (db *DB) Close() error {
171 db.lock.Lock()
172 defer db.lock.Unlock()
[356]173 return db.db.Close()
[77]174}
175
[255]176func (db *DB) upgrade() error {
177 db.lock.Lock()
178 defer db.lock.Unlock()
179
180 var version int
181 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
182 return fmt.Errorf("failed to query schema version: %v", err)
183 }
184
185 if version == len(migrations) {
186 return nil
187 } else if version > len(migrations) {
188 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
189 }
190
191 tx, err := db.db.Begin()
192 if err != nil {
193 return err
194 }
195 defer tx.Rollback()
196
197 if version == 0 {
198 if _, err := tx.Exec(schema); err != nil {
199 return fmt.Errorf("failed to initialize schema: %v", err)
200 }
201 } else {
202 for i := version; i < len(migrations); i++ {
203 if _, err := tx.Exec(migrations[i]); err != nil {
204 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
205 }
206 }
207 }
208
209 // For some reason prepared statements don't work here
210 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
211 if err != nil {
212 return fmt.Errorf("failed to bump schema version: %v", err)
213 }
214
215 return tx.Commit()
216}
217
// fromStringPtr dereferences ptr and maps a nil pointer to "". It is used
// after scanning nullable text columns into *string.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
224
// toStringPtr is the counterpart of fromStringPtr. It returns nil for the
// empty string, so empty fields are written as NULL, and otherwise returns a
// pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
231
[77]232func (db *DB) ListUsers() ([]User, error) {
[81]233 db.lock.RLock()
234 defer db.lock.RUnlock()
[77]235
[420]236 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
[77]237 if err != nil {
238 return nil, err
239 }
240 defer rows.Close()
241
242 var users []User
243 for rows.Next() {
244 var user User
245 var password *string
[382]246 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
[77]247 return nil, err
248 }
[95]249 user.Password = fromStringPtr(password)
[77]250 users = append(users, user)
251 }
252 if err := rows.Err(); err != nil {
253 return nil, err
254 }
255
256 return users, nil
257}
258
[173]259func (db *DB) GetUser(username string) (*User, error) {
260 db.lock.RLock()
261 defer db.lock.RUnlock()
262
[382]263 user := &User{Username: username}
[173]264
265 var password *string
[420]266 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
[382]267 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
[173]268 return nil, err
269 }
270 user.Password = fromStringPtr(password)
271 return user, nil
272}
273
[324]274func (db *DB) StoreUser(user *User) error {
[84]275 db.lock.Lock()
276 defer db.lock.Unlock()
277
[95]278 password := toStringPtr(user.Password)
[84]279
[324]280 var err error
[382]281 if user.ID != 0 {
[327]282 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
283 password, user.Admin, user.Username)
[324]284 } else {
[382]285 var res sql.Result
[393]286 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
[327]287 user.Username, password, user.Admin)
[382]288 if err != nil {
289 return err
[324]290 }
[382]291 user.ID, err = res.LastInsertId()
[324]292 }
[251]293
294 return err
295}
296
[421]297func (db *DB) DeleteUser(id int64) error {
[375]298 db.lock.Lock()
299 defer db.lock.Unlock()
300
301 tx, err := db.db.Begin()
302 if err != nil {
303 return err
304 }
305 defer tx.Rollback()
306
307 _, err = tx.Exec(`DELETE FROM Channel
308 WHERE id IN (
309 SELECT Channel.id
310 FROM Channel
311 JOIN Network ON Channel.network = Network.id
312 WHERE Network.user = ?
[421]313 )`, id)
[375]314 if err != nil {
315 return err
316 }
317
[421]318 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
[375]319 if err != nil {
320 return err
321 }
322
[421]323 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
[375]324 if err != nil {
325 return err
326 }
327
328 return tx.Commit()
329}
330
[421]331func (db *DB) ListNetworks(userID int64) ([]Network, error) {
[81]332 db.lock.RLock()
333 defer db.lock.RUnlock()
[77]334
[118]335 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]336 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
337 sasl_external_cert, sasl_external_key
[95]338 FROM Network
339 WHERE user = ?`,
[421]340 userID)
[77]341 if err != nil {
342 return nil, err
343 }
344 defer rows.Close()
345
346 var networks []Network
347 for rows.Next() {
348 var net Network
[263]349 var name, username, realname, pass, connectCommands *string
[95]350 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]351 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]352 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
353 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]354 if err != nil {
[77]355 return nil, err
356 }
[118]357 net.Name = fromStringPtr(name)
[95]358 net.Username = fromStringPtr(username)
359 net.Realname = fromStringPtr(realname)
360 net.Pass = fromStringPtr(pass)
[263]361 if connectCommands != nil {
362 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
363 }
[95]364 net.SASL.Mechanism = fromStringPtr(saslMechanism)
365 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
366 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]367 networks = append(networks, net)
368 }
369 if err := rows.Err(); err != nil {
370 return nil, err
371 }
372
373 return networks, nil
374}
375
[421]376func (db *DB) StoreNetwork(userID int64, network *Network) error {
[90]377 db.lock.Lock()
378 defer db.lock.Unlock()
379
[118]380 netName := toStringPtr(network.Name)
[95]381 netUsername := toStringPtr(network.Username)
382 realname := toStringPtr(network.Realname)
383 pass := toStringPtr(network.Pass)
[263]384 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]385
386 var saslMechanism, saslPlainUsername, saslPlainPassword *string
387 if network.SASL.Mechanism != "" {
388 saslMechanism = &network.SASL.Mechanism
389 switch network.SASL.Mechanism {
390 case "PLAIN":
391 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
392 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]393 network.SASL.External.CertBlob = nil
394 network.SASL.External.PrivKeyBlob = nil
395 case "EXTERNAL":
396 // keep saslPlain* nil
[148]397 default:
398 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]399 }
[90]400 }
401
402 var err error
403 if network.ID != 0 {
[93]404 _, err = db.db.Exec(`UPDATE Network
[263]405 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]406 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
407 sasl_external_cert = ?, sasl_external_key = ?
[93]408 WHERE id = ?`,
[263]409 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]410 saslMechanism, saslPlainUsername, saslPlainPassword,
411 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
412 network.ID)
[90]413 } else {
414 var res sql.Result
[118]415 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]416 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]417 sasl_plain_password, sasl_external_cert, sasl_external_key)
418 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[421]419 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]420 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
421 network.SASL.External.PrivKeyBlob)
[90]422 if err != nil {
423 return err
424 }
425 network.ID, err = res.LastInsertId()
426 }
427 return err
428}
429
[202]430func (db *DB) DeleteNetwork(id int64) error {
431 db.lock.Lock()
432 defer db.lock.Unlock()
433
434 tx, err := db.db.Begin()
435 if err != nil {
436 return err
437 }
438 defer tx.Rollback()
439
[375]440 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]441 if err != nil {
442 return err
443 }
444
[375]445 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]446 if err != nil {
447 return err
448 }
449
450 return tx.Commit()
451}
452
[77]453func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]454 db.lock.RLock()
455 defer db.lock.RUnlock()
[77]456
[284]457 rows, err := db.db.Query(`SELECT id, name, key, detached
458 FROM Channel
459 WHERE network = ?`, networkID)
[77]460 if err != nil {
461 return nil, err
462 }
463 defer rows.Close()
464
465 var channels []Channel
466 for rows.Next() {
467 var ch Channel
[146]468 var key *string
[284]469 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]470 return nil, err
471 }
[146]472 ch.Key = fromStringPtr(key)
[77]473 channels = append(channels, ch)
474 }
475 if err := rows.Err(); err != nil {
476 return nil, err
477 }
478
479 return channels, nil
480}
[89]481
482func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
483 db.lock.Lock()
484 defer db.lock.Unlock()
485
[146]486 key := toStringPtr(ch.Key)
[149]487
488 var err error
489 if ch.ID != 0 {
490 _, err = db.db.Exec(`UPDATE Channel
[284]491 SET network = ?, name = ?, key = ?, detached = ?
492 WHERE id = ?`,
493 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]494 } else {
495 var res sql.Result
[284]496 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
497 VALUES (?, ?, ?, ?)`,
498 networkID, ch.Name, key, ch.Detached)
[149]499 if err != nil {
500 return err
501 }
502 ch.ID, err = res.LastInsertId()
503 }
[89]504 return err
505}
506
[416]507func (db *DB) DeleteChannel(id int64) error {
[89]508 db.lock.Lock()
509 defer db.lock.Unlock()
510
[416]511 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
[89]512 return err
513}
Note: See TracBrowser for help on using the repository browser.