source: code/trunk/db.go@ 410

Last change on this file since 410 was 393, checked in by dan.shick, 5 years ago

Fix store user query values

File size: 11.0 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju account, stored in the User table.
type User struct {
	// ID is the SQLite rowid of the User row. The table has no explicit
	// integer primary key, so ListUsers and GetUser select rowid.
	// A zero ID means the user has not been stored yet; StoreUser inserts it
	// and fills in ID.
	// NOTE(review): SQLite does not guarantee that implicit rowids survive
	// VACUUM. Confirm that callers don't rely on ID being permanent.
	ID       int64
	Username string
	Password string // hashed; an empty value is stored as NULL by StoreUser
	Admin    bool
}
18
// SASL holds the SASL authentication settings of a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork accepts "" (no
	// SASL), "PLAIN" and "EXTERNAL", and rejects anything else.
	Mechanism string

	// Credentials for the PLAIN mechanism. StoreNetwork only persists them
	// when Mechanism is "PLAIN".
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
35
// Network is a network configuration belonging to a user, stored in the
// Network table (whose user column references User.username).
type Network struct {
	// ID is the Network row id. Zero means the network has not been stored
	// yet; StoreNetwork inserts it and fills in ID.
	ID       int64
	Name     string // optional; GetName falls back to Addr when empty
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	// ConnectCommands is persisted as a single column, joined with "\r\n".
	ConnectCommands []string
	SASL            SASL
}
47
48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is a channel saved for a network, stored in the Channel table
// (whose network column references Network.id).
type Channel struct {
	ID       int64  // zero until stored; StoreChannel fills it in on insert
	Name     string
	Key      string // stored as NULL when empty
	Detached bool
}
61
// schema creates the complete, current database layout in one step. upgrade
// executes it on an empty database (user_version 0) and then stamps the
// version as len(migrations). The migrations are never replayed on top of
// it, so this schema must already contain every change made by the
// migrations list. When adding a migration, update this schema too.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations[i] upgrades a database from schema version i to version i+1.
// The version is kept in PRAGMA user_version and managed by upgrade.
// Existing databases are matched to this list by index, so new entries must
// only ever be appended. Never edit or reorder existing entries.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
// DB is soju's database handle. Methods that only read hold lock for
// reading; methods that modify the database (and Close/upgrade) hold it for
// writing.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
114 sqlDB, err := sql.Open(driver, source)
115 if err != nil {
116 return nil, err
117 }
118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
130 return db.db.Close()
131}
132
133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr converts a nullable string, as scanned from a NULL-able
// column, into a plain string. A nil pointer (NULL) becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
181
// toStringPtr is the inverse of fromStringPtr. It maps "" to nil, which is
// stored as NULL, and any other string to a pointer to a fresh copy of it.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
188
189func (db *DB) ListUsers() ([]User, error) {
190 db.lock.RLock()
191 defer db.lock.RUnlock()
192
193 rows, err := db.db.Query("SELECT rowid, username, password, admin FROM User")
194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
203 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
204 return nil, err
205 }
206 user.Password = fromStringPtr(password)
207 users = append(users, user)
208 }
209 if err := rows.Err(); err != nil {
210 return nil, err
211 }
212
213 return users, nil
214}
215
216func (db *DB) GetUser(username string) (*User, error) {
217 db.lock.RLock()
218 defer db.lock.RUnlock()
219
220 user := &User{Username: username}
221
222 var password *string
223 row := db.db.QueryRow("SELECT rowid, password, admin FROM User WHERE username = ?", username)
224 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
225 return nil, err
226 }
227 user.Password = fromStringPtr(password)
228 return user, nil
229}
230
231func (db *DB) StoreUser(user *User) error {
232 db.lock.Lock()
233 defer db.lock.Unlock()
234
235 password := toStringPtr(user.Password)
236
237 var err error
238 if user.ID != 0 {
239 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
240 password, user.Admin, user.Username)
241 } else {
242 var res sql.Result
243 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
244 user.Username, password, user.Admin)
245 if err != nil {
246 return err
247 }
248 user.ID, err = res.LastInsertId()
249 }
250
251 return err
252}
253
254func (db *DB) DeleteUser(username string) error {
255 db.lock.Lock()
256 defer db.lock.Unlock()
257
258 tx, err := db.db.Begin()
259 if err != nil {
260 return err
261 }
262 defer tx.Rollback()
263
264 _, err = tx.Exec(`DELETE FROM Channel
265 WHERE id IN (
266 SELECT Channel.id
267 FROM Channel
268 JOIN Network ON Channel.network = Network.id
269 WHERE Network.user = ?
270 )`, username)
271 if err != nil {
272 return err
273 }
274
275 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
276 if err != nil {
277 return err
278 }
279
280 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
281 if err != nil {
282 return err
283 }
284
285 return tx.Commit()
286}
287
288func (db *DB) ListNetworks(username string) ([]Network, error) {
289 db.lock.RLock()
290 defer db.lock.RUnlock()
291
292 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
293 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
294 sasl_external_cert, sasl_external_key
295 FROM Network
296 WHERE user = ?`,
297 username)
298 if err != nil {
299 return nil, err
300 }
301 defer rows.Close()
302
303 var networks []Network
304 for rows.Next() {
305 var net Network
306 var name, username, realname, pass, connectCommands *string
307 var saslMechanism, saslPlainUsername, saslPlainPassword *string
308 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
309 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
310 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
311 if err != nil {
312 return nil, err
313 }
314 net.Name = fromStringPtr(name)
315 net.Username = fromStringPtr(username)
316 net.Realname = fromStringPtr(realname)
317 net.Pass = fromStringPtr(pass)
318 if connectCommands != nil {
319 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
320 }
321 net.SASL.Mechanism = fromStringPtr(saslMechanism)
322 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
323 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
324 networks = append(networks, net)
325 }
326 if err := rows.Err(); err != nil {
327 return nil, err
328 }
329
330 return networks, nil
331}
332
333func (db *DB) StoreNetwork(username string, network *Network) error {
334 db.lock.Lock()
335 defer db.lock.Unlock()
336
337 netName := toStringPtr(network.Name)
338 netUsername := toStringPtr(network.Username)
339 realname := toStringPtr(network.Realname)
340 pass := toStringPtr(network.Pass)
341 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
342
343 var saslMechanism, saslPlainUsername, saslPlainPassword *string
344 if network.SASL.Mechanism != "" {
345 saslMechanism = &network.SASL.Mechanism
346 switch network.SASL.Mechanism {
347 case "PLAIN":
348 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
349 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
350 network.SASL.External.CertBlob = nil
351 network.SASL.External.PrivKeyBlob = nil
352 case "EXTERNAL":
353 // keep saslPlain* nil
354 default:
355 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
356 }
357 }
358
359 var err error
360 if network.ID != 0 {
361 _, err = db.db.Exec(`UPDATE Network
362 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
363 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
364 sasl_external_cert = ?, sasl_external_key = ?
365 WHERE id = ?`,
366 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
367 saslMechanism, saslPlainUsername, saslPlainPassword,
368 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
369 network.ID)
370 } else {
371 var res sql.Result
372 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
373 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
374 sasl_plain_password, sasl_external_cert, sasl_external_key)
375 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
376 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
377 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
378 network.SASL.External.PrivKeyBlob)
379 if err != nil {
380 return err
381 }
382 network.ID, err = res.LastInsertId()
383 }
384 return err
385}
386
387func (db *DB) DeleteNetwork(id int64) error {
388 db.lock.Lock()
389 defer db.lock.Unlock()
390
391 tx, err := db.db.Begin()
392 if err != nil {
393 return err
394 }
395 defer tx.Rollback()
396
397 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
398 if err != nil {
399 return err
400 }
401
402 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
403 if err != nil {
404 return err
405 }
406
407 return tx.Commit()
408}
409
410func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
411 db.lock.RLock()
412 defer db.lock.RUnlock()
413
414 rows, err := db.db.Query(`SELECT id, name, key, detached
415 FROM Channel
416 WHERE network = ?`, networkID)
417 if err != nil {
418 return nil, err
419 }
420 defer rows.Close()
421
422 var channels []Channel
423 for rows.Next() {
424 var ch Channel
425 var key *string
426 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
427 return nil, err
428 }
429 ch.Key = fromStringPtr(key)
430 channels = append(channels, ch)
431 }
432 if err := rows.Err(); err != nil {
433 return nil, err
434 }
435
436 return channels, nil
437}
438
439func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
440 db.lock.Lock()
441 defer db.lock.Unlock()
442
443 key := toStringPtr(ch.Key)
444
445 var err error
446 if ch.ID != 0 {
447 _, err = db.db.Exec(`UPDATE Channel
448 SET network = ?, name = ?, key = ?, detached = ?
449 WHERE id = ?`,
450 networkID, ch.Name, key, ch.Detached, ch.ID)
451 } else {
452 var res sql.Result
453 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
454 VALUES (?, ?, ?, ?)`,
455 networkID, ch.Name, key, ch.Detached)
456 if err != nil {
457 return err
458 }
459 ch.ID, err = res.LastInsertId()
460 }
461 return err
462}
463
464func (db *DB) DeleteChannel(networkID int64, name string) error {
465 db.lock.Lock()
466 defer db.lock.Unlock()
467
468 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
469 return err
470}
Note: See TracBrowser for help on using the repository browser.