source: code/trunk/db.go@ 377

Last change on this file since 377 was 375, checked in by contact, 5 years ago

Add DB.DeleteUser

File size: 10.9 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a user account, persisted as a row of the User table.
//
// Created reports whether the row already exists: StoreUser issues an INSERT
// when it is false (and sets it after a successful INSERT) and an UPDATE
// otherwise. Users returned by ListUsers and GetUser always have Created set.
// Admin is stored in the admin column, which defaults to 0 (false).
type User struct {
	Created  bool
	Username string
	Password string // hashed
	Admin    bool
}
18
// SASL holds a network's SASL authentication settings. It is persisted in the
// sasl_* columns of the Network table.
//
// StoreNetwork accepts Mechanism values "" (no SASL), "PLAIN" and "EXTERNAL"
// and rejects anything else. Plain credentials are only persisted when
// Mechanism is "PLAIN".
type SASL struct {
	Mechanism string

	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
35
// Network is a network configuration belonging to a user, persisted as a row
// of the Network table.
//
// ID is the row ID. Zero means the network has not been stored yet, in which
// case StoreNetwork inserts it and fills ID in. Empty optional strings are
// stored as NULL. ConnectCommands is stored as a single "\r\n"-joined string
// and split back on read.
type Network struct {
	ID              int64
	Name            string
	Addr            string
	Nick            string
	Username        string
	Realname        string
	Pass            string
	ConnectCommands []string
	SASL            SASL
}
47
[149]48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is a channel saved for a network, persisted as a row of the Channel
// table.
//
// ID is the row ID. Zero means the channel has not been stored yet, in which
// case StoreChannel inserts it and fills ID in. An empty Key is stored as
// NULL. Detached is stored in the detached column, which defaults to 0.
type Channel struct {
	ID       int64
	Name     string
	Key      string
	Detached bool
}
61
// schema creates a fresh database from scratch.
//
// upgrade runs it when the stored schema version is 0. It then sets the
// version to len(migrations) without running any migration. schema must
// therefore already contain every change made by the entries in migrations.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations upgrade an existing database one version at a time.
//
// Entry i moves the schema from version i to version i+1, and len(migrations)
// is the latest version. Entries must only ever be appended. Each one must
// also be reflected in schema, which new databases use instead of replaying
// the migrations.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
// DB is soju's persistent store.
//
// lock guards all use of db. Read-only methods (ListUsers, GetUser,
// ListNetworks, ListChannels) take the read lock. Methods that modify,
// upgrade or close the database take the write lock.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
[255]114 sqlDB, err := sql.Open(driver, source)
[77]115 if err != nil {
116 return nil, err
117 }
[255]118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
[77]125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
[356]130 return db.db.Close()
[77]131}
132
[255]133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr turns a nullable string, as scanned from a nullable column,
// into a plain string. nil becomes "".
func fromStringPtr(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}
181
// toStringPtr is the inverse of fromStringPtr. "" becomes nil, which
// database/sql writes as NULL. Any other string becomes a pointer to its own
// copy.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
188
[77]189func (db *DB) ListUsers() ([]User, error) {
[81]190 db.lock.RLock()
191 defer db.lock.RUnlock()
[77]192
[327]193 rows, err := db.db.Query("SELECT username, password, admin FROM User")
[77]194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
[327]203 if err := rows.Scan(&user.Username, &password, &user.Admin); err != nil {
[77]204 return nil, err
205 }
[324]206 user.Created = true
[95]207 user.Password = fromStringPtr(password)
[77]208 users = append(users, user)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return users, nil
215}
216
[173]217func (db *DB) GetUser(username string) (*User, error) {
218 db.lock.RLock()
219 defer db.lock.RUnlock()
220
[324]221 user := &User{Created: true, Username: username}
[173]222
223 var password *string
[327]224 row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
225 if err := row.Scan(&password, &user.Admin); err != nil {
[173]226 return nil, err
227 }
228 user.Password = fromStringPtr(password)
229 return user, nil
230}
231
[324]232func (db *DB) StoreUser(user *User) error {
[84]233 db.lock.Lock()
234 defer db.lock.Unlock()
235
[95]236 password := toStringPtr(user.Password)
[84]237
[324]238 var err error
239 if user.Created {
[327]240 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
241 password, user.Admin, user.Username)
[324]242 } else {
[327]243 _, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
244 user.Username, password, user.Admin)
[324]245 if err == nil {
246 user.Created = true
247 }
248 }
[251]249
250 return err
251}
252
[375]253func (db *DB) DeleteUser(username string) error {
254 db.lock.Lock()
255 defer db.lock.Unlock()
256
257 tx, err := db.db.Begin()
258 if err != nil {
259 return err
260 }
261 defer tx.Rollback()
262
263 _, err = tx.Exec(`DELETE FROM Channel
264 WHERE id IN (
265 SELECT Channel.id
266 FROM Channel
267 JOIN Network ON Channel.network = Network.id
268 WHERE Network.user = ?
269 )`, username)
270 if err != nil {
271 return err
272 }
273
274 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
275 if err != nil {
276 return err
277 }
278
279 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
280 if err != nil {
281 return err
282 }
283
284 return tx.Commit()
285}
286
[77]287func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]288 db.lock.RLock()
289 defer db.lock.RUnlock()
[77]290
[118]291 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]292 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
293 sasl_external_cert, sasl_external_key
[95]294 FROM Network
295 WHERE user = ?`,
296 username)
[77]297 if err != nil {
298 return nil, err
299 }
300 defer rows.Close()
301
302 var networks []Network
303 for rows.Next() {
304 var net Network
[263]305 var name, username, realname, pass, connectCommands *string
[95]306 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]307 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]308 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
309 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]310 if err != nil {
[77]311 return nil, err
312 }
[118]313 net.Name = fromStringPtr(name)
[95]314 net.Username = fromStringPtr(username)
315 net.Realname = fromStringPtr(realname)
316 net.Pass = fromStringPtr(pass)
[263]317 if connectCommands != nil {
318 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
319 }
[95]320 net.SASL.Mechanism = fromStringPtr(saslMechanism)
321 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
322 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]323 networks = append(networks, net)
324 }
325 if err := rows.Err(); err != nil {
326 return nil, err
327 }
328
329 return networks, nil
330}
331
[90]332func (db *DB) StoreNetwork(username string, network *Network) error {
333 db.lock.Lock()
334 defer db.lock.Unlock()
335
[118]336 netName := toStringPtr(network.Name)
[95]337 netUsername := toStringPtr(network.Username)
338 realname := toStringPtr(network.Realname)
339 pass := toStringPtr(network.Pass)
[263]340 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]341
342 var saslMechanism, saslPlainUsername, saslPlainPassword *string
343 if network.SASL.Mechanism != "" {
344 saslMechanism = &network.SASL.Mechanism
345 switch network.SASL.Mechanism {
346 case "PLAIN":
347 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
348 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]349 network.SASL.External.CertBlob = nil
350 network.SASL.External.PrivKeyBlob = nil
351 case "EXTERNAL":
352 // keep saslPlain* nil
[148]353 default:
354 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]355 }
[90]356 }
357
358 var err error
359 if network.ID != 0 {
[93]360 _, err = db.db.Exec(`UPDATE Network
[263]361 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]362 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
363 sasl_external_cert = ?, sasl_external_key = ?
[93]364 WHERE id = ?`,
[263]365 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]366 saslMechanism, saslPlainUsername, saslPlainPassword,
367 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
368 network.ID)
[90]369 } else {
370 var res sql.Result
[118]371 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]372 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]373 sasl_plain_password, sasl_external_cert, sasl_external_key)
374 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[263]375 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]376 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
377 network.SASL.External.PrivKeyBlob)
[90]378 if err != nil {
379 return err
380 }
381 network.ID, err = res.LastInsertId()
382 }
383 return err
384}
385
[202]386func (db *DB) DeleteNetwork(id int64) error {
387 db.lock.Lock()
388 defer db.lock.Unlock()
389
390 tx, err := db.db.Begin()
391 if err != nil {
392 return err
393 }
394 defer tx.Rollback()
395
[375]396 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]397 if err != nil {
398 return err
399 }
400
[375]401 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]402 if err != nil {
403 return err
404 }
405
406 return tx.Commit()
407}
408
[77]409func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]410 db.lock.RLock()
411 defer db.lock.RUnlock()
[77]412
[284]413 rows, err := db.db.Query(`SELECT id, name, key, detached
414 FROM Channel
415 WHERE network = ?`, networkID)
[77]416 if err != nil {
417 return nil, err
418 }
419 defer rows.Close()
420
421 var channels []Channel
422 for rows.Next() {
423 var ch Channel
[146]424 var key *string
[284]425 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]426 return nil, err
427 }
[146]428 ch.Key = fromStringPtr(key)
[77]429 channels = append(channels, ch)
430 }
431 if err := rows.Err(); err != nil {
432 return nil, err
433 }
434
435 return channels, nil
436}
[89]437
438func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
439 db.lock.Lock()
440 defer db.lock.Unlock()
441
[146]442 key := toStringPtr(ch.Key)
[149]443
444 var err error
445 if ch.ID != 0 {
446 _, err = db.db.Exec(`UPDATE Channel
[284]447 SET network = ?, name = ?, key = ?, detached = ?
448 WHERE id = ?`,
449 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]450 } else {
451 var res sql.Result
[284]452 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
453 VALUES (?, ?, ?, ?)`,
454 networkID, ch.Name, key, ch.Detached)
[149]455 if err != nil {
456 return err
457 }
458 ch.ID, err = res.LastInsertId()
459 }
[89]460 return err
461}
462
463func (db *DB) DeleteChannel(networkID int64, name string) error {
464 db.lock.Lock()
465 defer db.lock.Unlock()
466
467 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
468 return err
469}
Note: See TracBrowser for help on using the repository browser.