source: code/trunk/db.go@ 402

Last change on this file since 402 was 393, checked in by dan.shick, 5 years ago

Fix store user query values

File size: 11.0 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju account as stored in the User table.
type User struct {
	ID       int64  // SQLite rowid of the row; 0 until the user has been stored
	Username string // login name, the User table's primary key
	Password string // hashed
	Admin    bool   // mirrors the admin column
}
18
// SASL holds the SASL authentication settings of a network. Mechanism names
// the mechanism in use ("" when SASL is not configured); only the nested
// group matching it is meaningful.
type SASL struct {
	Mechanism string

	// Plain holds the credentials for the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

// Network is an IRC network configured by a user, as stored in the Network
// table. Empty string fields correspond to NULL columns.
type Network struct {
	ID              int64
	Name            string
	Addr            string
	Nick            string
	Username        string
	Realname        string
	Pass            string
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the name to display for the network: Name when it is set,
// falling back to Addr otherwise.
func (net *Network) GetName() string {
	display := net.Addr
	if net.Name != "" {
		display = net.Name
	}
	return display
}
54
// Channel is a channel saved for a network, as stored in the Channel table.
type Channel struct {
	ID       int64  // SQLite row ID; 0 until the channel has been stored
	Name     string // unique per network
	Key      string // channel key; "" is stored as NULL
	Detached bool   // mirrors the detached column
}
61
// schema is the complete, current database layout. upgrade executes it only
// when initializing a fresh database (user_version 0) and then records
// len(migrations) as the version, so this text must already include the
// effect of every entry in migrations. Existing databases are instead brought
// up to date by running those migrations.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations lists, in order, the statements that upgrade an existing
// database. upgrade runs migrations[i] for every i from the stored
// user_version up to len(migrations)-1 and then sets user_version to
// len(migrations), so entry i moves a database from version i to i+1.
// Only ever append: existing databases depend on the indices staying stable,
// and schema must be updated to match each new entry.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
[77]108type DB struct {
[81]109 lock sync.RWMutex
[77]110 db *sql.DB
111}
112
// OpenSQLDB opens the database identified by driver and source and upgrades
// its schema to the current version. The caller owns the returned DB and
// must Close it. On error no DB is returned and nothing is left open.
func OpenSQLDB(driver, source string) (*DB, error) {
	sqlDB, err := sql.Open(driver, source)
	if err != nil {
		return nil, err
	}

	db := &DB{db: sqlDB}
	if err := db.upgrade(); err != nil {
		// The *DB is never handed to the caller on failure, so release the
		// underlying handle here rather than leaking it. The upgrade error is
		// the one worth reporting; a secondary Close error is dropped.
		_ = sqlDB.Close()
		return nil, err
	}

	return db, nil
}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
[356]130 return db.db.Close()
[77]131}
132
// upgrade brings the database schema up to date. The schema version is kept
// in SQLite's user_version pragma:
//   - 0 means a fresh database, which gets schema executed directly;
//   - 1..len(migrations)-1 means an older database, which replays
//     migrations from its stored version onward;
//   - len(migrations) means up to date (nothing to do).
//
// Anything else is an error. All changes, including the version bump, are
// made in a single transaction.
func (db *DB) upgrade() error {
	db.lock.Lock()
	defer db.lock.Unlock()

	var version int
	if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
		return fmt.Errorf("failed to query schema version: %w", err)
	}

	switch {
	case version == len(migrations):
		return nil
	case version > len(migrations):
		return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
	case version < 0:
		// user_version is a signed integer; a negative value would otherwise
		// index migrations out of range and panic below.
		return fmt.Errorf("invalid schema version %d", version)
	}

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// Harmless after a successful Commit; undoes partial work on any error.
	defer tx.Rollback()

	if version == 0 {
		if _, err := tx.Exec(schema); err != nil {
			return fmt.Errorf("failed to initialize schema: %w", err)
		}
	} else {
		for i := version; i < len(migrations); i++ {
			if _, err := tx.Exec(migrations[i]); err != nil {
				return fmt.Errorf("failed to execute migration #%v: %w", i, err)
			}
		}
	}

	// For some reason prepared statements don't work here, so the (trusted,
	// integer) version is formatted into the statement text.
	if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations))); err != nil {
		return fmt.Errorf("failed to bump schema version: %w", err)
	}

	return tx.Commit()
}
174
// fromStringPtr converts a nullable string (as scanned from a NULL-able
// column) into a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}
181
// toStringPtr is the reverse of fromStringPtr: "" becomes nil (written as
// NULL), and any other value becomes a pointer to a copy of it.
func toStringPtr(s string) *string {
	var ptr *string
	if s != "" {
		ptr = &s
	}
	return ptr
}
188
[77]189func (db *DB) ListUsers() ([]User, error) {
[81]190 db.lock.RLock()
191 defer db.lock.RUnlock()
[77]192
[382]193 rows, err := db.db.Query("SELECT rowid, username, password, admin FROM User")
[77]194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
[382]203 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
[77]204 return nil, err
205 }
[95]206 user.Password = fromStringPtr(password)
[77]207 users = append(users, user)
208 }
209 if err := rows.Err(); err != nil {
210 return nil, err
211 }
212
213 return users, nil
214}
215
[173]216func (db *DB) GetUser(username string) (*User, error) {
217 db.lock.RLock()
218 defer db.lock.RUnlock()
219
[382]220 user := &User{Username: username}
[173]221
222 var password *string
[382]223 row := db.db.QueryRow("SELECT rowid, password, admin FROM User WHERE username = ?", username)
224 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
[173]225 return nil, err
226 }
227 user.Password = fromStringPtr(password)
228 return user, nil
229}
230
[324]231func (db *DB) StoreUser(user *User) error {
[84]232 db.lock.Lock()
233 defer db.lock.Unlock()
234
[95]235 password := toStringPtr(user.Password)
[84]236
[324]237 var err error
[382]238 if user.ID != 0 {
[327]239 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
240 password, user.Admin, user.Username)
[324]241 } else {
[382]242 var res sql.Result
[393]243 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
[327]244 user.Username, password, user.Admin)
[382]245 if err != nil {
246 return err
[324]247 }
[382]248 user.ID, err = res.LastInsertId()
[324]249 }
[251]250
251 return err
252}
253
[375]254func (db *DB) DeleteUser(username string) error {
255 db.lock.Lock()
256 defer db.lock.Unlock()
257
258 tx, err := db.db.Begin()
259 if err != nil {
260 return err
261 }
262 defer tx.Rollback()
263
264 _, err = tx.Exec(`DELETE FROM Channel
265 WHERE id IN (
266 SELECT Channel.id
267 FROM Channel
268 JOIN Network ON Channel.network = Network.id
269 WHERE Network.user = ?
270 )`, username)
271 if err != nil {
272 return err
273 }
274
275 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
276 if err != nil {
277 return err
278 }
279
280 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
281 if err != nil {
282 return err
283 }
284
285 return tx.Commit()
286}
287
// ListNetworks returns all networks owned by the given user. NULL text
// columns are reported as empty strings, and connect_commands is split on
// "\r\n" back into ConnectCommands (left nil when the column is NULL).
func (db *DB) ListNetworks(username string) ([]Network, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
			connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
			sasl_external_cert, sasl_external_key
		FROM Network
		WHERE user = ?`,
		username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var networks []Network
	for rows.Next() {
		var net Network
		// Nullable columns are scanned through *string. The network's own
		// username column uses netUsername so it does not shadow the
		// function's username parameter (the owning soju user).
		var name, netUsername, realname, pass, connectCommands *string
		var saslMechanism, saslPlainUsername, saslPlainPassword *string
		if err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &netUsername, &realname,
			&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob); err != nil {
			return nil, err
		}
		net.Name = fromStringPtr(name)
		net.Username = fromStringPtr(netUsername)
		net.Realname = fromStringPtr(realname)
		net.Pass = fromStringPtr(pass)
		if connectCommands != nil {
			net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
		}
		net.SASL.Mechanism = fromStringPtr(saslMechanism)
		net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
		net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
		networks = append(networks, net)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return networks, nil
}
332
// StoreNetwork saves network as owned by the given user. A network with a
// zero ID is inserted and network.ID is set to the new row ID; otherwise the
// row with network.ID is overwritten. Empty strings are stored as NULL and
// ConnectCommands are joined with "\r\n".
//
// The accepted SASL mechanisms are "" (none), "PLAIN" and "EXTERNAL"; any
// other value is rejected before touching the database. PLAIN credentials
// are persisted only for "PLAIN", and storing a "PLAIN" network also clears
// network.SASL.External in the caller's struct.
func (db *DB) StoreNetwork(username string, network *Network) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	var (
		name            = toStringPtr(network.Name)
		netUsername     = toStringPtr(network.Username)
		realname        = toStringPtr(network.Realname)
		pass            = toStringPtr(network.Pass)
		connectCommands = toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
	)

	var saslMechanism, saslPlainUsername, saslPlainPassword *string
	switch network.SASL.Mechanism {
	case "":
		// No SASL: every SASL text column stays NULL.
	case "PLAIN":
		saslMechanism = &network.SASL.Mechanism
		saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
		saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
		// PLAIN replaces any client certificate, so drop it before saving.
		network.SASL.External.CertBlob = nil
		network.SASL.External.PrivKeyBlob = nil
	case "EXTERNAL":
		// Credentials live in the blob columns; plain username/password stay NULL.
		saslMechanism = &network.SASL.Mechanism
	default:
		return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
	}

	if network.ID == 0 {
		res, err := db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
				realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
				sasl_plain_password, sasl_external_cert, sasl_external_key)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
			username, name, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob)
		if err != nil {
			return err
		}
		network.ID, err = res.LastInsertId()
		return err
	}

	_, err := db.db.Exec(`UPDATE Network
		SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
			sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
			sasl_external_cert = ?, sasl_external_key = ?
		WHERE id = ?`,
		name, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
		saslMechanism, saslPlainUsername, saslPlainPassword,
		network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
		network.ID)
	return err
}
386
[202]387func (db *DB) DeleteNetwork(id int64) error {
388 db.lock.Lock()
389 defer db.lock.Unlock()
390
391 tx, err := db.db.Begin()
392 if err != nil {
393 return err
394 }
395 defer tx.Rollback()
396
[375]397 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]398 if err != nil {
399 return err
400 }
401
[375]402 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]403 if err != nil {
404 return err
405 }
406
407 return tx.Commit()
408}
409
[77]410func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]411 db.lock.RLock()
412 defer db.lock.RUnlock()
[77]413
[284]414 rows, err := db.db.Query(`SELECT id, name, key, detached
415 FROM Channel
416 WHERE network = ?`, networkID)
[77]417 if err != nil {
418 return nil, err
419 }
420 defer rows.Close()
421
422 var channels []Channel
423 for rows.Next() {
424 var ch Channel
[146]425 var key *string
[284]426 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]427 return nil, err
428 }
[146]429 ch.Key = fromStringPtr(key)
[77]430 channels = append(channels, ch)
431 }
432 if err := rows.Err(); err != nil {
433 return nil, err
434 }
435
436 return channels, nil
437}
[89]438
439func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
440 db.lock.Lock()
441 defer db.lock.Unlock()
442
[146]443 key := toStringPtr(ch.Key)
[149]444
445 var err error
446 if ch.ID != 0 {
447 _, err = db.db.Exec(`UPDATE Channel
[284]448 SET network = ?, name = ?, key = ?, detached = ?
449 WHERE id = ?`,
450 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]451 } else {
452 var res sql.Result
[284]453 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
454 VALUES (?, ?, ?, ?)`,
455 networkID, ch.Name, key, ch.Detached)
[149]456 if err != nil {
457 return err
458 }
459 ch.ID, err = res.LastInsertId()
460 }
[89]461 return err
462}
463
464func (db *DB) DeleteChannel(networkID int64, name string) error {
465 db.lock.Lock()
466 defer db.lock.Unlock()
467
468 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
469 return err
470}
Note: See TracBrowser for help on using the repository browser.