source: code/trunk/db.go@ 420

Last change on this file since 420 was 420, checked in by contact, 5 years ago

Add id column to User table

We used rowid before, but an explicit ID column is cleaner.

File size: 11.2 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is one row of the User table: a bouncer account.
type User struct {
	// ID is the row's primary key. Zero means the user has not been
	// stored yet; StoreUser inserts in that case and fills it in.
	ID int64
	// Username is unique across all users.
	Username string
	// Password is the hashed password; an empty value is stored as NULL.
	Password string
	// Admin is persisted in the admin column (0/1).
	Admin bool
}
18
// SASL holds a network's SASL authentication settings.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork accepts "" (no
	// SASL), "PLAIN" and "EXTERNAL".
	Mechanism string

	// Plain holds the credentials used by the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// External holds the TLS client certificate authentication material.
	External struct {
		// CertBlob is the X.509 certificate in DER form.
		CertBlob []byte
		// PrivKeyBlob is the PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

// Network is one row of the Network table: an upstream server configured
// for a user.
type Network struct {
	ID       int64 // primary key; zero until first stored
	Name     string
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	// ConnectCommands are persisted joined with "\r\n" in one column.
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the network's display name: Name when it is set,
// otherwise the server address.
func (n *Network) GetName() string {
	if n.Name == "" {
		return n.Addr
	}
	return n.Name
}
54
// Channel is one row of the Channel table: a channel saved for a network.
type Channel struct {
	ID       int64  // primary key; zero until the channel is first stored
	Name     string // unique per network
	Key      string // channel key; an empty key is stored as NULL
	Detached bool
}
61
// schema is the complete, current database layout.
//
// upgrade executes it verbatim on a database whose user_version is 0, then
// records len(migrations) as the version without running any migration.
// It must therefore always describe exactly the tables that applying every
// entry of migrations to an older database produces. When adding a
// migration, update this schema in the same change.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
99
// migrations lists the schema upgrade steps. Entry i moves a database from
// user_version i to i+1, and upgrade runs every entry from the stored
// version onward. Entries must only ever be appended; editing or reordering
// an existing one would desynchronize databases already past it. After
// appending, update schema so fresh databases match.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: add Network.connect_commands.
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2: add Channel.detached.
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3 and #4: add the SASL EXTERNAL certificate and key columns.
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	// #5: add User.admin.
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: give User an explicit id INTEGER PRIMARY KEY column. SQLite's
	// ALTER TABLE cannot add a PRIMARY KEY column, so the table is rebuilt
	// and the old implicit rowid values are copied into id, which keeps
	// existing user ids stable.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
}
119
// DB is the bouncer's persistent store, wrapping a database/sql handle.
//
// Every method goes through lock. Read-only queries take the read lock and
// mutations take the write lock, so writes are serialized with each other
// and with reads.
type DB struct {
	lock sync.RWMutex // guards every use of db
	db   *sql.DB
}
124
125func OpenSQLDB(driver, source string) (*DB, error) {
126 sqlDB, err := sql.Open(driver, source)
127 if err != nil {
128 return nil, err
129 }
130
131 db := &DB{db: sqlDB}
132 if err := db.upgrade(); err != nil {
133 return nil, err
134 }
135
136 return db, nil
137}
138
139func (db *DB) Close() error {
140 db.lock.Lock()
141 defer db.lock.Unlock()
142 return db.db.Close()
143}
144
145func (db *DB) upgrade() error {
146 db.lock.Lock()
147 defer db.lock.Unlock()
148
149 var version int
150 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
151 return fmt.Errorf("failed to query schema version: %v", err)
152 }
153
154 if version == len(migrations) {
155 return nil
156 } else if version > len(migrations) {
157 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
158 }
159
160 tx, err := db.db.Begin()
161 if err != nil {
162 return err
163 }
164 defer tx.Rollback()
165
166 if version == 0 {
167 if _, err := tx.Exec(schema); err != nil {
168 return fmt.Errorf("failed to initialize schema: %v", err)
169 }
170 } else {
171 for i := version; i < len(migrations); i++ {
172 if _, err := tx.Exec(migrations[i]); err != nil {
173 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
174 }
175 }
176 }
177
178 // For some reason prepared statements don't work here
179 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
180 if err != nil {
181 return fmt.Errorf("failed to bump schema version: %v", err)
182 }
183
184 return tx.Commit()
185}
186
// fromStringPtr returns the string ptr points to, or "" when ptr is nil.
// It is used after scanning nullable text columns into *string, where a
// NULL value leaves the pointer nil.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
193
// toStringPtr is the inverse of fromStringPtr. It returns nil for the empty
// string, so the value is written as SQL NULL, and a pointer to a copy of s
// otherwise.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
200
201func (db *DB) ListUsers() ([]User, error) {
202 db.lock.RLock()
203 defer db.lock.RUnlock()
204
205 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
206 if err != nil {
207 return nil, err
208 }
209 defer rows.Close()
210
211 var users []User
212 for rows.Next() {
213 var user User
214 var password *string
215 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
216 return nil, err
217 }
218 user.Password = fromStringPtr(password)
219 users = append(users, user)
220 }
221 if err := rows.Err(); err != nil {
222 return nil, err
223 }
224
225 return users, nil
226}
227
228func (db *DB) GetUser(username string) (*User, error) {
229 db.lock.RLock()
230 defer db.lock.RUnlock()
231
232 user := &User{Username: username}
233
234 var password *string
235 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
236 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
237 return nil, err
238 }
239 user.Password = fromStringPtr(password)
240 return user, nil
241}
242
243func (db *DB) StoreUser(user *User) error {
244 db.lock.Lock()
245 defer db.lock.Unlock()
246
247 password := toStringPtr(user.Password)
248
249 var err error
250 if user.ID != 0 {
251 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
252 password, user.Admin, user.Username)
253 } else {
254 var res sql.Result
255 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
256 user.Username, password, user.Admin)
257 if err != nil {
258 return err
259 }
260 user.ID, err = res.LastInsertId()
261 }
262
263 return err
264}
265
266func (db *DB) DeleteUser(username string) error {
267 db.lock.Lock()
268 defer db.lock.Unlock()
269
270 tx, err := db.db.Begin()
271 if err != nil {
272 return err
273 }
274 defer tx.Rollback()
275
276 _, err = tx.Exec(`DELETE FROM Channel
277 WHERE id IN (
278 SELECT Channel.id
279 FROM Channel
280 JOIN Network ON Channel.network = Network.id
281 WHERE Network.user = ?
282 )`, username)
283 if err != nil {
284 return err
285 }
286
287 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
288 if err != nil {
289 return err
290 }
291
292 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
293 if err != nil {
294 return err
295 }
296
297 return tx.Commit()
298}
299
300func (db *DB) ListNetworks(username string) ([]Network, error) {
301 db.lock.RLock()
302 defer db.lock.RUnlock()
303
304 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
305 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
306 sasl_external_cert, sasl_external_key
307 FROM Network
308 WHERE user = ?`,
309 username)
310 if err != nil {
311 return nil, err
312 }
313 defer rows.Close()
314
315 var networks []Network
316 for rows.Next() {
317 var net Network
318 var name, username, realname, pass, connectCommands *string
319 var saslMechanism, saslPlainUsername, saslPlainPassword *string
320 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
321 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
322 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
323 if err != nil {
324 return nil, err
325 }
326 net.Name = fromStringPtr(name)
327 net.Username = fromStringPtr(username)
328 net.Realname = fromStringPtr(realname)
329 net.Pass = fromStringPtr(pass)
330 if connectCommands != nil {
331 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
332 }
333 net.SASL.Mechanism = fromStringPtr(saslMechanism)
334 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
335 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
336 networks = append(networks, net)
337 }
338 if err := rows.Err(); err != nil {
339 return nil, err
340 }
341
342 return networks, nil
343}
344
345func (db *DB) StoreNetwork(username string, network *Network) error {
346 db.lock.Lock()
347 defer db.lock.Unlock()
348
349 netName := toStringPtr(network.Name)
350 netUsername := toStringPtr(network.Username)
351 realname := toStringPtr(network.Realname)
352 pass := toStringPtr(network.Pass)
353 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
354
355 var saslMechanism, saslPlainUsername, saslPlainPassword *string
356 if network.SASL.Mechanism != "" {
357 saslMechanism = &network.SASL.Mechanism
358 switch network.SASL.Mechanism {
359 case "PLAIN":
360 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
361 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
362 network.SASL.External.CertBlob = nil
363 network.SASL.External.PrivKeyBlob = nil
364 case "EXTERNAL":
365 // keep saslPlain* nil
366 default:
367 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
368 }
369 }
370
371 var err error
372 if network.ID != 0 {
373 _, err = db.db.Exec(`UPDATE Network
374 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
375 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
376 sasl_external_cert = ?, sasl_external_key = ?
377 WHERE id = ?`,
378 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
379 saslMechanism, saslPlainUsername, saslPlainPassword,
380 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
381 network.ID)
382 } else {
383 var res sql.Result
384 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
385 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
386 sasl_plain_password, sasl_external_cert, sasl_external_key)
387 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
388 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
389 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
390 network.SASL.External.PrivKeyBlob)
391 if err != nil {
392 return err
393 }
394 network.ID, err = res.LastInsertId()
395 }
396 return err
397}
398
399func (db *DB) DeleteNetwork(id int64) error {
400 db.lock.Lock()
401 defer db.lock.Unlock()
402
403 tx, err := db.db.Begin()
404 if err != nil {
405 return err
406 }
407 defer tx.Rollback()
408
409 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
410 if err != nil {
411 return err
412 }
413
414 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
415 if err != nil {
416 return err
417 }
418
419 return tx.Commit()
420}
421
422func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
423 db.lock.RLock()
424 defer db.lock.RUnlock()
425
426 rows, err := db.db.Query(`SELECT id, name, key, detached
427 FROM Channel
428 WHERE network = ?`, networkID)
429 if err != nil {
430 return nil, err
431 }
432 defer rows.Close()
433
434 var channels []Channel
435 for rows.Next() {
436 var ch Channel
437 var key *string
438 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
439 return nil, err
440 }
441 ch.Key = fromStringPtr(key)
442 channels = append(channels, ch)
443 }
444 if err := rows.Err(); err != nil {
445 return nil, err
446 }
447
448 return channels, nil
449}
450
451func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
452 db.lock.Lock()
453 defer db.lock.Unlock()
454
455 key := toStringPtr(ch.Key)
456
457 var err error
458 if ch.ID != 0 {
459 _, err = db.db.Exec(`UPDATE Channel
460 SET network = ?, name = ?, key = ?, detached = ?
461 WHERE id = ?`,
462 networkID, ch.Name, key, ch.Detached, ch.ID)
463 } else {
464 var res sql.Result
465 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
466 VALUES (?, ?, ?, ?)`,
467 networkID, ch.Name, key, ch.Detached)
468 if err != nil {
469 return err
470 }
471 ch.ID, err = res.LastInsertId()
472 }
473 return err
474}
475
476func (db *DB) DeleteChannel(id int64) error {
477 db.lock.Lock()
478 defer db.lock.Unlock()
479
480 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
481 return err
482}
Note: See TracBrowser for help on using the repository browser.