source: code/trunk/db.go@ 380

Last change on this file since 380 was 375, checked in by contact, 5 years ago

Add DB.DeleteUser

File size: 10.9 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is an account row of the User table.
//
// Created reports whether the row already exists in the database: ListUsers
// and GetUser set it, and StoreUser uses it to choose between UPDATE and
// INSERT. Password holds the hashed password and Admin is persisted in the
// admin column.
type User struct {
	Created            bool
	Username, Password string // Password is hashed
	Admin              bool
}
18
// SASL holds the SASL authentication settings of a network. Mechanism names
// the mechanism in use; StoreNetwork accepts "", "PLAIN" and "EXTERNAL", and
// only the credential group matching the mechanism is meaningful.
type SASL struct {
	Mechanism string

	// Plain carries the credentials for the PLAIN mechanism.
	Plain struct {
		Username, Password string
	}

	// External carries the TLS client certificate authentication material.
	External struct {
		CertBlob    []byte // X.509 certificate in DER form
		PrivKeyBlob []byte // PKCS#8 private key in DER form
	}
}
35
36type Network struct {
37 ID int64
38 Name string
39 Addr string
40 Nick string
41 Username string
42 Realname string
43 Pass string
44 ConnectCommands []string
45 SASL SASL
46}
47
48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is a per-network channel row of the Channel table.
type Channel struct {
	// ID is the Channel table's primary key; 0 means not yet stored.
	ID int64

	Name string

	// Key is the channel key; StoreChannel stores an empty key as NULL.
	Key string

	Detached bool
}
61
// schema is the complete, current database layout.
//
// upgrade executes it in a single Exec call when the database reports schema
// version 0. It then records len(migrations) as the version without running
// any migration. Therefore schema must always describe the layout reached by
// applying every entry of migrations; update both together.
//
// NOTE(review): SQLite only enforces the FOREIGN KEY clauses below when
// PRAGMA foreign_keys is ON, and nothing in this file enables it. DeleteUser
// and DeleteNetwork therefore delete child rows explicitly.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations upgrades an existing database one schema version at a time.
//
// For a database at version v >= 1, upgrade runs migrations[v],
// migrations[v+1], ... in order and then records len(migrations) as the new
// version. Consequences:
//   - The index of an entry is the version it upgrades from, so entries must
//     only ever be appended, never reordered or removed.
//   - Each appended change must also be reflected in schema, because fresh
//     databases are initialized from schema instead of these statements.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
// DB is the persistent store for users, networks and channels.
//
// Every method serializes access through lock. Read-only queries (List*,
// GetUser) take the read lock; methods that write, upgrade or close the
// handle take the write lock.
type DB struct {
	lock sync.RWMutex // guards all use of db below
	db   *sql.DB
}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
114 sqlDB, err := sql.Open(driver, source)
115 if err != nil {
116 return nil, err
117 }
118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
130 return db.db.Close()
131}
132
133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr turns a nullable string column value into a plain string,
// mapping NULL (a nil pointer) to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
181
// toStringPtr is the inverse of fromStringPtr. It returns nil (stored as
// NULL) for the empty string and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
188
189func (db *DB) ListUsers() ([]User, error) {
190 db.lock.RLock()
191 defer db.lock.RUnlock()
192
193 rows, err := db.db.Query("SELECT username, password, admin FROM User")
194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
203 if err := rows.Scan(&user.Username, &password, &user.Admin); err != nil {
204 return nil, err
205 }
206 user.Created = true
207 user.Password = fromStringPtr(password)
208 users = append(users, user)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return users, nil
215}
216
217func (db *DB) GetUser(username string) (*User, error) {
218 db.lock.RLock()
219 defer db.lock.RUnlock()
220
221 user := &User{Created: true, Username: username}
222
223 var password *string
224 row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
225 if err := row.Scan(&password, &user.Admin); err != nil {
226 return nil, err
227 }
228 user.Password = fromStringPtr(password)
229 return user, nil
230}
231
232func (db *DB) StoreUser(user *User) error {
233 db.lock.Lock()
234 defer db.lock.Unlock()
235
236 password := toStringPtr(user.Password)
237
238 var err error
239 if user.Created {
240 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
241 password, user.Admin, user.Username)
242 } else {
243 _, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
244 user.Username, password, user.Admin)
245 if err == nil {
246 user.Created = true
247 }
248 }
249
250 return err
251}
252
253func (db *DB) DeleteUser(username string) error {
254 db.lock.Lock()
255 defer db.lock.Unlock()
256
257 tx, err := db.db.Begin()
258 if err != nil {
259 return err
260 }
261 defer tx.Rollback()
262
263 _, err = tx.Exec(`DELETE FROM Channel
264 WHERE id IN (
265 SELECT Channel.id
266 FROM Channel
267 JOIN Network ON Channel.network = Network.id
268 WHERE Network.user = ?
269 )`, username)
270 if err != nil {
271 return err
272 }
273
274 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
275 if err != nil {
276 return err
277 }
278
279 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
280 if err != nil {
281 return err
282 }
283
284 return tx.Commit()
285}
286
287func (db *DB) ListNetworks(username string) ([]Network, error) {
288 db.lock.RLock()
289 defer db.lock.RUnlock()
290
291 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
292 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
293 sasl_external_cert, sasl_external_key
294 FROM Network
295 WHERE user = ?`,
296 username)
297 if err != nil {
298 return nil, err
299 }
300 defer rows.Close()
301
302 var networks []Network
303 for rows.Next() {
304 var net Network
305 var name, username, realname, pass, connectCommands *string
306 var saslMechanism, saslPlainUsername, saslPlainPassword *string
307 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
308 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
309 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
310 if err != nil {
311 return nil, err
312 }
313 net.Name = fromStringPtr(name)
314 net.Username = fromStringPtr(username)
315 net.Realname = fromStringPtr(realname)
316 net.Pass = fromStringPtr(pass)
317 if connectCommands != nil {
318 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
319 }
320 net.SASL.Mechanism = fromStringPtr(saslMechanism)
321 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
322 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
323 networks = append(networks, net)
324 }
325 if err := rows.Err(); err != nil {
326 return nil, err
327 }
328
329 return networks, nil
330}
331
332func (db *DB) StoreNetwork(username string, network *Network) error {
333 db.lock.Lock()
334 defer db.lock.Unlock()
335
336 netName := toStringPtr(network.Name)
337 netUsername := toStringPtr(network.Username)
338 realname := toStringPtr(network.Realname)
339 pass := toStringPtr(network.Pass)
340 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
341
342 var saslMechanism, saslPlainUsername, saslPlainPassword *string
343 if network.SASL.Mechanism != "" {
344 saslMechanism = &network.SASL.Mechanism
345 switch network.SASL.Mechanism {
346 case "PLAIN":
347 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
348 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
349 network.SASL.External.CertBlob = nil
350 network.SASL.External.PrivKeyBlob = nil
351 case "EXTERNAL":
352 // keep saslPlain* nil
353 default:
354 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
355 }
356 }
357
358 var err error
359 if network.ID != 0 {
360 _, err = db.db.Exec(`UPDATE Network
361 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
362 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
363 sasl_external_cert = ?, sasl_external_key = ?
364 WHERE id = ?`,
365 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
366 saslMechanism, saslPlainUsername, saslPlainPassword,
367 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
368 network.ID)
369 } else {
370 var res sql.Result
371 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
372 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
373 sasl_plain_password, sasl_external_cert, sasl_external_key)
374 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
375 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
376 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
377 network.SASL.External.PrivKeyBlob)
378 if err != nil {
379 return err
380 }
381 network.ID, err = res.LastInsertId()
382 }
383 return err
384}
385
386func (db *DB) DeleteNetwork(id int64) error {
387 db.lock.Lock()
388 defer db.lock.Unlock()
389
390 tx, err := db.db.Begin()
391 if err != nil {
392 return err
393 }
394 defer tx.Rollback()
395
396 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
397 if err != nil {
398 return err
399 }
400
401 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
402 if err != nil {
403 return err
404 }
405
406 return tx.Commit()
407}
408
409func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
410 db.lock.RLock()
411 defer db.lock.RUnlock()
412
413 rows, err := db.db.Query(`SELECT id, name, key, detached
414 FROM Channel
415 WHERE network = ?`, networkID)
416 if err != nil {
417 return nil, err
418 }
419 defer rows.Close()
420
421 var channels []Channel
422 for rows.Next() {
423 var ch Channel
424 var key *string
425 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
426 return nil, err
427 }
428 ch.Key = fromStringPtr(key)
429 channels = append(channels, ch)
430 }
431 if err := rows.Err(); err != nil {
432 return nil, err
433 }
434
435 return channels, nil
436}
437
438func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
439 db.lock.Lock()
440 defer db.lock.Unlock()
441
442 key := toStringPtr(ch.Key)
443
444 var err error
445 if ch.ID != 0 {
446 _, err = db.db.Exec(`UPDATE Channel
447 SET network = ?, name = ?, key = ?, detached = ?
448 WHERE id = ?`,
449 networkID, ch.Name, key, ch.Detached, ch.ID)
450 } else {
451 var res sql.Result
452 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
453 VALUES (?, ?, ?, ?)`,
454 networkID, ch.Name, key, ch.Detached)
455 if err != nil {
456 return err
457 }
458 ch.ID, err = res.LastInsertId()
459 }
460 return err
461}
462
463func (db *DB) DeleteChannel(networkID int64, name string) error {
464 db.lock.Lock()
465 defer db.lock.Unlock()
466
467 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
468 return err
469}
Note: See TracBrowser for help on using the repository browser.