source: code/trunk/db.go@ 429

Last change on this file since 429 was 422, checked in by contact, 5 years ago

Switch to sql.NullString

Not really better than what we had before; however, new contributors may
already be familiar with it.

File size: 12.0 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User mirrors a row of the User table. An empty Password is stored as NULL.
type User struct {
	ID                 int64
	Username, Password string // Password holds the hashed password
	Admin              bool
}
18
// SASL holds the SASL configuration stored alongside a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name, such as "PLAIN" or "EXTERNAL".
	// An empty string means SASL is not configured.
	Mechanism string

	// Credentials used by the PLAIN mechanism.
	Plain struct{ Username, Password string }

	// TLS client certificate authentication.
	External struct {
		// CertBlob is an X.509 certificate and PrivKeyBlob a PKCS#8 private
		// key, both in DER form.
		CertBlob, PrivKeyBlob []byte
	}
}

// Network mirrors a row of the Network table. Empty strings correspond to
// NULL columns.
type Network struct {
	ID int64

	// Name may be empty, in which case GetName falls back to Addr.
	Name, Addr, Nick, Username, Realname, Pass string

	// ConnectCommands is persisted as a single CRLF-joined column.
	ConnectCommands []string

	SASL SASL
}

// GetName returns a name for the network: Name when it is set, Addr
// otherwise.
func (net *Network) GetName() string {
	name := net.Name
	if name == "" {
		name = net.Addr
	}
	return name
}
54
// Channel mirrors a row of the Channel table. An empty Key is stored as NULL.
type Channel struct {
	ID        int64
	Name, Key string
	Detached  bool
}
61
// schema is the complete, current database schema. upgrade executes it on a
// database whose user_version is 0 and then sets user_version to
// len(migrations) without replaying any migration. It must therefore always
// equal the result of applying every entry of migrations.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
100
// migrations upgrades an existing database step by step. Entries are
// addressed by index: for a stored user_version v > 0, upgrade executes
// migrations[v], migrations[v+1], ... in order and then sets user_version to
// len(migrations). Because positions are significant, new schema changes
// should be appended at the end and mirrored in schema.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// Rebuild User with an explicit INTEGER PRIMARY KEY id column, copying
	// each old rowid into it.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// Rebuild Network so that its user column holds User.id. The old column
	// held the username, hence the join on User.username.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
}
150
// DB is soju's handle on its SQL database. Read-only methods (List*, Get*)
// hold lock for reading; methods that modify the database or the handle
// (Store*, Delete*, Close, upgrade) hold it exclusively.
type DB struct {
	db   *sql.DB
	lock sync.RWMutex
}
155
156func OpenSQLDB(driver, source string) (*DB, error) {
157 sqlDB, err := sql.Open(driver, source)
158 if err != nil {
159 return nil, err
160 }
161
162 db := &DB{db: sqlDB}
163 if err := db.upgrade(); err != nil {
164 return nil, err
165 }
166
167 return db, nil
168}
169
170func (db *DB) Close() error {
171 db.lock.Lock()
172 defer db.lock.Unlock()
173 return db.db.Close()
174}
175
176func (db *DB) upgrade() error {
177 db.lock.Lock()
178 defer db.lock.Unlock()
179
180 var version int
181 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
182 return fmt.Errorf("failed to query schema version: %v", err)
183 }
184
185 if version == len(migrations) {
186 return nil
187 } else if version > len(migrations) {
188 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
189 }
190
191 tx, err := db.db.Begin()
192 if err != nil {
193 return err
194 }
195 defer tx.Rollback()
196
197 if version == 0 {
198 if _, err := tx.Exec(schema); err != nil {
199 return fmt.Errorf("failed to initialize schema: %v", err)
200 }
201 } else {
202 for i := version; i < len(migrations); i++ {
203 if _, err := tx.Exec(migrations[i]); err != nil {
204 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
205 }
206 }
207 }
208
209 // For some reason prepared statements don't work here
210 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
211 if err != nil {
212 return fmt.Errorf("failed to bump schema version: %v", err)
213 }
214
215 return tx.Commit()
216}
217
// toNullString maps the empty string to SQL NULL and any other string to a
// valid NullString carrying it.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
224
225func (db *DB) ListUsers() ([]User, error) {
226 db.lock.RLock()
227 defer db.lock.RUnlock()
228
229 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
230 if err != nil {
231 return nil, err
232 }
233 defer rows.Close()
234
235 var users []User
236 for rows.Next() {
237 var user User
238 var password sql.NullString
239 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
240 return nil, err
241 }
242 user.Password = password.String
243 users = append(users, user)
244 }
245 if err := rows.Err(); err != nil {
246 return nil, err
247 }
248
249 return users, nil
250}
251
252func (db *DB) GetUser(username string) (*User, error) {
253 db.lock.RLock()
254 defer db.lock.RUnlock()
255
256 user := &User{Username: username}
257
258 var password sql.NullString
259 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
260 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
261 return nil, err
262 }
263 user.Password = password.String
264 return user, nil
265}
266
267func (db *DB) StoreUser(user *User) error {
268 db.lock.Lock()
269 defer db.lock.Unlock()
270
271 password := toNullString(user.Password)
272
273 var err error
274 if user.ID != 0 {
275 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
276 password, user.Admin, user.Username)
277 } else {
278 var res sql.Result
279 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
280 user.Username, password, user.Admin)
281 if err != nil {
282 return err
283 }
284 user.ID, err = res.LastInsertId()
285 }
286
287 return err
288}
289
290func (db *DB) DeleteUser(id int64) error {
291 db.lock.Lock()
292 defer db.lock.Unlock()
293
294 tx, err := db.db.Begin()
295 if err != nil {
296 return err
297 }
298 defer tx.Rollback()
299
300 _, err = tx.Exec(`DELETE FROM Channel
301 WHERE id IN (
302 SELECT Channel.id
303 FROM Channel
304 JOIN Network ON Channel.network = Network.id
305 WHERE Network.user = ?
306 )`, id)
307 if err != nil {
308 return err
309 }
310
311 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
312 if err != nil {
313 return err
314 }
315
316 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
317 if err != nil {
318 return err
319 }
320
321 return tx.Commit()
322}
323
324func (db *DB) ListNetworks(userID int64) ([]Network, error) {
325 db.lock.RLock()
326 defer db.lock.RUnlock()
327
328 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
329 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
330 sasl_external_cert, sasl_external_key
331 FROM Network
332 WHERE user = ?`,
333 userID)
334 if err != nil {
335 return nil, err
336 }
337 defer rows.Close()
338
339 var networks []Network
340 for rows.Next() {
341 var net Network
342 var name, username, realname, pass, connectCommands sql.NullString
343 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
344 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
345 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
346 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
347 if err != nil {
348 return nil, err
349 }
350 net.Name = name.String
351 net.Username = username.String
352 net.Realname = realname.String
353 net.Pass = pass.String
354 if connectCommands.Valid {
355 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
356 }
357 net.SASL.Mechanism = saslMechanism.String
358 net.SASL.Plain.Username = saslPlainUsername.String
359 net.SASL.Plain.Password = saslPlainPassword.String
360 networks = append(networks, net)
361 }
362 if err := rows.Err(); err != nil {
363 return nil, err
364 }
365
366 return networks, nil
367}
368
369func (db *DB) StoreNetwork(userID int64, network *Network) error {
370 db.lock.Lock()
371 defer db.lock.Unlock()
372
373 netName := toNullString(network.Name)
374 netUsername := toNullString(network.Username)
375 realname := toNullString(network.Realname)
376 pass := toNullString(network.Pass)
377 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
378
379 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
380 if network.SASL.Mechanism != "" {
381 saslMechanism = toNullString(network.SASL.Mechanism)
382 switch network.SASL.Mechanism {
383 case "PLAIN":
384 saslPlainUsername = toNullString(network.SASL.Plain.Username)
385 saslPlainPassword = toNullString(network.SASL.Plain.Password)
386 network.SASL.External.CertBlob = nil
387 network.SASL.External.PrivKeyBlob = nil
388 case "EXTERNAL":
389 // keep saslPlain* nil
390 default:
391 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
392 }
393 }
394
395 var err error
396 if network.ID != 0 {
397 _, err = db.db.Exec(`UPDATE Network
398 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
399 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
400 sasl_external_cert = ?, sasl_external_key = ?
401 WHERE id = ?`,
402 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
403 saslMechanism, saslPlainUsername, saslPlainPassword,
404 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
405 network.ID)
406 } else {
407 var res sql.Result
408 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
409 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
410 sasl_plain_password, sasl_external_cert, sasl_external_key)
411 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
412 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
413 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
414 network.SASL.External.PrivKeyBlob)
415 if err != nil {
416 return err
417 }
418 network.ID, err = res.LastInsertId()
419 }
420 return err
421}
422
423func (db *DB) DeleteNetwork(id int64) error {
424 db.lock.Lock()
425 defer db.lock.Unlock()
426
427 tx, err := db.db.Begin()
428 if err != nil {
429 return err
430 }
431 defer tx.Rollback()
432
433 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
434 if err != nil {
435 return err
436 }
437
438 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
439 if err != nil {
440 return err
441 }
442
443 return tx.Commit()
444}
445
446func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
447 db.lock.RLock()
448 defer db.lock.RUnlock()
449
450 rows, err := db.db.Query(`SELECT id, name, key, detached
451 FROM Channel
452 WHERE network = ?`, networkID)
453 if err != nil {
454 return nil, err
455 }
456 defer rows.Close()
457
458 var channels []Channel
459 for rows.Next() {
460 var ch Channel
461 var key sql.NullString
462 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
463 return nil, err
464 }
465 ch.Key = key.String
466 channels = append(channels, ch)
467 }
468 if err := rows.Err(); err != nil {
469 return nil, err
470 }
471
472 return channels, nil
473}
474
475func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
476 db.lock.Lock()
477 defer db.lock.Unlock()
478
479 key := toNullString(ch.Key)
480
481 var err error
482 if ch.ID != 0 {
483 _, err = db.db.Exec(`UPDATE Channel
484 SET network = ?, name = ?, key = ?, detached = ?
485 WHERE id = ?`,
486 networkID, ch.Name, key, ch.Detached, ch.ID)
487 } else {
488 var res sql.Result
489 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
490 VALUES (?, ?, ?, ?)`,
491 networkID, ch.Name, key, ch.Detached)
492 if err != nil {
493 return err
494 }
495 ch.ID, err = res.LastInsertId()
496 }
497 return err
498}
499
500func (db *DB) DeleteChannel(id int64) error {
501 db.lock.Lock()
502 defer db.lock.Unlock()
503
504 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
505 return err
506}
Note: See TracBrowser for help on using the repository browser.