source: code/trunk/db.go@ 420

Last change on this file since 420 was 420, checked in by contact, 5 years ago

Add id column to User table

We used rowid before, but an explicit ID column is cleaner.

File size: 11.2 KB
Rev Line
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is the in-memory form of one row of the User table.
type User struct {
	// ID is the row's integer primary key. StoreUser treats zero as "not
	// inserted yet" and fills it in after an INSERT.
	ID int64
	// Username is the unique, non-null account name.
	Username string
	// Password is the hashed password. StoreUser writes an empty string as
	// NULL.
	Password string
	// Admin mirrors the admin column.
	Admin bool
}
18
// SASL groups the SASL settings that are persisted together with a Network.
type SASL struct {
	// Mechanism names the SASL mechanism. StoreNetwork accepts "", "PLAIN"
	// and "EXTERNAL" and rejects anything else.
	Mechanism string

	// Plain holds the credentials used by the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// External holds the material for TLS client certificate
	// authentication.
	External struct {
		// CertBlob is the X.509 certificate in DER form.
		CertBlob []byte
		// PrivKeyBlob is the PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
35
[77]36type Network struct {
[263]37 ID int64
38 Name string
39 Addr string
40 Nick string
41 Username string
42 Realname string
43 Pass string
44 ConnectCommands []string
45 SASL SASL
[77]46}
47
[149]48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is the in-memory form of one row of the Channel table.
type Channel struct {
	// ID is the row's integer primary key. StoreChannel treats zero as
	// "not inserted yet" and fills it in after an INSERT.
	ID int64
	// Name is the channel name; it is unique per network.
	Name string
	// Key is the channel key. StoreChannel writes an empty string as NULL.
	Key string
	// Detached mirrors the detached column.
	Detached bool
}
61
// schema is the SQL that creates the User, Network and Channel tables in
// their latest layout. upgrade executes it when the database reports
// user_version 0, i.e. instead of replaying the entries of migrations.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
99
// migrations lists the SQL needed to move an existing database forward one
// schema version at a time. upgrade runs migrations[v:] for a database whose
// user_version is v (v >= 1) and then stores len(migrations) as the new
// user_version, so entries must only ever be appended.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// Rebuild User with an explicit id column, seeded from the old rowid.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
}
119
// DB wraps a database/sql handle together with the lock its methods take
// around every query: readers hold lock for reading, writers (and Close and
// upgrade) hold it exclusively.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
124
125func OpenSQLDB(driver, source string) (*DB, error) {
[255]126 sqlDB, err := sql.Open(driver, source)
[77]127 if err != nil {
128 return nil, err
129 }
[255]130
131 db := &DB{db: sqlDB}
132 if err := db.upgrade(); err != nil {
133 return nil, err
134 }
135
136 return db, nil
[77]137}
138
139func (db *DB) Close() error {
140 db.lock.Lock()
141 defer db.lock.Unlock()
[356]142 return db.db.Close()
[77]143}
144
[255]145func (db *DB) upgrade() error {
146 db.lock.Lock()
147 defer db.lock.Unlock()
148
149 var version int
150 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
151 return fmt.Errorf("failed to query schema version: %v", err)
152 }
153
154 if version == len(migrations) {
155 return nil
156 } else if version > len(migrations) {
157 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
158 }
159
160 tx, err := db.db.Begin()
161 if err != nil {
162 return err
163 }
164 defer tx.Rollback()
165
166 if version == 0 {
167 if _, err := tx.Exec(schema); err != nil {
168 return fmt.Errorf("failed to initialize schema: %v", err)
169 }
170 } else {
171 for i := version; i < len(migrations); i++ {
172 if _, err := tx.Exec(migrations[i]); err != nil {
173 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
174 }
175 }
176 }
177
178 // For some reason prepared statements don't work here
179 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
180 if err != nil {
181 return fmt.Errorf("failed to bump schema version: %v", err)
182 }
183
184 return tx.Commit()
185}
186
// fromStringPtr converts a nullable string, as scanned from a database
// column, to a plain string: nil becomes "", anything else is dereferenced.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
193
// toStringPtr is the inverse of fromStringPtr: it maps "" to nil, so the
// value is stored as NULL, and any other string to a pointer to a copy of it.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
200
[77]201func (db *DB) ListUsers() ([]User, error) {
[81]202 db.lock.RLock()
203 defer db.lock.RUnlock()
[77]204
[420]205 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
[77]206 if err != nil {
207 return nil, err
208 }
209 defer rows.Close()
210
211 var users []User
212 for rows.Next() {
213 var user User
214 var password *string
[382]215 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
[77]216 return nil, err
217 }
[95]218 user.Password = fromStringPtr(password)
[77]219 users = append(users, user)
220 }
221 if err := rows.Err(); err != nil {
222 return nil, err
223 }
224
225 return users, nil
226}
227
[173]228func (db *DB) GetUser(username string) (*User, error) {
229 db.lock.RLock()
230 defer db.lock.RUnlock()
231
[382]232 user := &User{Username: username}
[173]233
234 var password *string
[420]235 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
[382]236 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
[173]237 return nil, err
238 }
239 user.Password = fromStringPtr(password)
240 return user, nil
241}
242
[324]243func (db *DB) StoreUser(user *User) error {
[84]244 db.lock.Lock()
245 defer db.lock.Unlock()
246
[95]247 password := toStringPtr(user.Password)
[84]248
[324]249 var err error
[382]250 if user.ID != 0 {
[327]251 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
252 password, user.Admin, user.Username)
[324]253 } else {
[382]254 var res sql.Result
[393]255 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
[327]256 user.Username, password, user.Admin)
[382]257 if err != nil {
258 return err
[324]259 }
[382]260 user.ID, err = res.LastInsertId()
[324]261 }
[251]262
263 return err
264}
265
[375]266func (db *DB) DeleteUser(username string) error {
267 db.lock.Lock()
268 defer db.lock.Unlock()
269
270 tx, err := db.db.Begin()
271 if err != nil {
272 return err
273 }
274 defer tx.Rollback()
275
276 _, err = tx.Exec(`DELETE FROM Channel
277 WHERE id IN (
278 SELECT Channel.id
279 FROM Channel
280 JOIN Network ON Channel.network = Network.id
281 WHERE Network.user = ?
282 )`, username)
283 if err != nil {
284 return err
285 }
286
287 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
288 if err != nil {
289 return err
290 }
291
292 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
293 if err != nil {
294 return err
295 }
296
297 return tx.Commit()
298}
299
[77]300func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]301 db.lock.RLock()
302 defer db.lock.RUnlock()
[77]303
[118]304 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]305 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
306 sasl_external_cert, sasl_external_key
[95]307 FROM Network
308 WHERE user = ?`,
309 username)
[77]310 if err != nil {
311 return nil, err
312 }
313 defer rows.Close()
314
315 var networks []Network
316 for rows.Next() {
317 var net Network
[263]318 var name, username, realname, pass, connectCommands *string
[95]319 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]320 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]321 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
322 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]323 if err != nil {
[77]324 return nil, err
325 }
[118]326 net.Name = fromStringPtr(name)
[95]327 net.Username = fromStringPtr(username)
328 net.Realname = fromStringPtr(realname)
329 net.Pass = fromStringPtr(pass)
[263]330 if connectCommands != nil {
331 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
332 }
[95]333 net.SASL.Mechanism = fromStringPtr(saslMechanism)
334 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
335 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]336 networks = append(networks, net)
337 }
338 if err := rows.Err(); err != nil {
339 return nil, err
340 }
341
342 return networks, nil
343}
344
[90]345func (db *DB) StoreNetwork(username string, network *Network) error {
346 db.lock.Lock()
347 defer db.lock.Unlock()
348
[118]349 netName := toStringPtr(network.Name)
[95]350 netUsername := toStringPtr(network.Username)
351 realname := toStringPtr(network.Realname)
352 pass := toStringPtr(network.Pass)
[263]353 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]354
355 var saslMechanism, saslPlainUsername, saslPlainPassword *string
356 if network.SASL.Mechanism != "" {
357 saslMechanism = &network.SASL.Mechanism
358 switch network.SASL.Mechanism {
359 case "PLAIN":
360 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
361 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]362 network.SASL.External.CertBlob = nil
363 network.SASL.External.PrivKeyBlob = nil
364 case "EXTERNAL":
365 // keep saslPlain* nil
[148]366 default:
367 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]368 }
[90]369 }
370
371 var err error
372 if network.ID != 0 {
[93]373 _, err = db.db.Exec(`UPDATE Network
[263]374 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]375 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
376 sasl_external_cert = ?, sasl_external_key = ?
[93]377 WHERE id = ?`,
[263]378 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]379 saslMechanism, saslPlainUsername, saslPlainPassword,
380 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
381 network.ID)
[90]382 } else {
383 var res sql.Result
[118]384 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]385 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]386 sasl_plain_password, sasl_external_cert, sasl_external_key)
387 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[263]388 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]389 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
390 network.SASL.External.PrivKeyBlob)
[90]391 if err != nil {
392 return err
393 }
394 network.ID, err = res.LastInsertId()
395 }
396 return err
397}
398
[202]399func (db *DB) DeleteNetwork(id int64) error {
400 db.lock.Lock()
401 defer db.lock.Unlock()
402
403 tx, err := db.db.Begin()
404 if err != nil {
405 return err
406 }
407 defer tx.Rollback()
408
[375]409 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]410 if err != nil {
411 return err
412 }
413
[375]414 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]415 if err != nil {
416 return err
417 }
418
419 return tx.Commit()
420}
421
[77]422func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]423 db.lock.RLock()
424 defer db.lock.RUnlock()
[77]425
[284]426 rows, err := db.db.Query(`SELECT id, name, key, detached
427 FROM Channel
428 WHERE network = ?`, networkID)
[77]429 if err != nil {
430 return nil, err
431 }
432 defer rows.Close()
433
434 var channels []Channel
435 for rows.Next() {
436 var ch Channel
[146]437 var key *string
[284]438 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]439 return nil, err
440 }
[146]441 ch.Key = fromStringPtr(key)
[77]442 channels = append(channels, ch)
443 }
444 if err := rows.Err(); err != nil {
445 return nil, err
446 }
447
448 return channels, nil
449}
[89]450
451func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
452 db.lock.Lock()
453 defer db.lock.Unlock()
454
[146]455 key := toStringPtr(ch.Key)
[149]456
457 var err error
458 if ch.ID != 0 {
459 _, err = db.db.Exec(`UPDATE Channel
[284]460 SET network = ?, name = ?, key = ?, detached = ?
461 WHERE id = ?`,
462 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]463 } else {
464 var res sql.Result
[284]465 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
466 VALUES (?, ?, ?, ?)`,
467 networkID, ch.Name, key, ch.Detached)
[149]468 if err != nil {
469 return err
470 }
471 ch.ID, err = res.LastInsertId()
472 }
[89]473 return err
474}
475
[416]476func (db *DB) DeleteChannel(id int64) error {
[89]477 db.lock.Lock()
478 defer db.lock.Unlock()
479
[416]480 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
[89]481 return err
482}
Note: See TracBrowser for help on using the repository browser.