source: code/trunk/db.go@ 384

Last change on this file since 384 was 382, checked in by contact, 5 years ago

Add User.ID

For now it's just a new field that'll be useful to generate user ident
strings. It uses the SQLite implicit rowid column. In the future the DB
interface will need to be updated to use user IDs instead of usernames.

File size: 11.0 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is an account as stored in the User table.
type User struct {
	// ID is the SQLite rowid of the user's row. It is zero for a user that
	// has not been inserted yet; StoreUser fills it in on insertion.
	ID       int64
	Username string
	Password string // hashed
	Admin    bool
}
18
// SASL holds the SASL authentication settings of a network.
type SASL struct {
	// Mechanism is the SASL mechanism name. The storage layer accepts
	// "PLAIN", "EXTERNAL", or the empty string for no SASL.
	Mechanism string

	// Plain holds the credentials used with the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
35
[77]36type Network struct {
[263]37 ID int64
38 Name string
39 Addr string
40 Nick string
41 Username string
42 Realname string
43 Pass string
44 ConnectCommands []string
45 SASL SASL
[77]46}
47
[149]48func (net *Network) GetName() string {
49 if net.Name != "" {
50 return net.Name
51 }
52 return net.Addr
53}
54
// Channel is a channel of a network as stored in the Channel table.
type Channel struct {
	// ID is the primary key of the channel's row. It is zero for a channel
	// that has not been inserted yet; StoreChannel fills it in on insertion.
	ID   int64
	Name string
	// Key is the channel key; an empty key is stored as NULL.
	Key      string
	Detached bool
}
61
// schema creates every table from scratch. It is executed only when the
// database reports schema version 0; it must therefore already include every
// column that the entries in migrations add to older databases.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations lists the statements that bring an existing database up to
// date. The database's user_version is the number of entries that have
// already been applied, so upgrade executes migrations[version:] and then
// stores len(migrations) as the new version. Entries must only ever be
// appended, never reordered or removed.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
// DB wraps an SQL database handle. Every method takes lock before touching
// the handle: listing and lookup methods take it for reading, while methods
// that modify the database (and Close) take it for writing.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
[255]114 sqlDB, err := sql.Open(driver, source)
[77]115 if err != nil {
116 return nil, err
117 }
[255]118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
[77]125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
[356]130 return db.db.Close()
[77]131}
132
[255]133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr converts a nullable string, as scanned from a database
// column, to a plain string. A nil pointer (SQL NULL) becomes "".
func fromStringPtr(ptr *string) string {
	if ptr == nil {
		return ""
	}
	return *ptr
}
181
// toStringPtr converts a string to a nullable string suitable for use as a
// query argument. The empty string becomes nil (SQL NULL); any other value
// is returned as a pointer to a copy of s.
func toStringPtr(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
188
[77]189func (db *DB) ListUsers() ([]User, error) {
[81]190 db.lock.RLock()
191 defer db.lock.RUnlock()
[77]192
[382]193 rows, err := db.db.Query("SELECT rowid, username, password, admin FROM User")
[77]194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
[382]203 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
[77]204 return nil, err
205 }
[95]206 user.Password = fromStringPtr(password)
[77]207 users = append(users, user)
208 }
209 if err := rows.Err(); err != nil {
210 return nil, err
211 }
212
213 return users, nil
214}
215
[173]216func (db *DB) GetUser(username string) (*User, error) {
217 db.lock.RLock()
218 defer db.lock.RUnlock()
219
[382]220 user := &User{Username: username}
[173]221
222 var password *string
[382]223 row := db.db.QueryRow("SELECT rowid, password, admin FROM User WHERE username = ?", username)
224 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
[173]225 return nil, err
226 }
227 user.Password = fromStringPtr(password)
228 return user, nil
229}
230
[324]231func (db *DB) StoreUser(user *User) error {
[84]232 db.lock.Lock()
233 defer db.lock.Unlock()
234
[95]235 password := toStringPtr(user.Password)
[84]236
[324]237 var err error
[382]238 if user.ID != 0 {
[327]239 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
240 password, user.Admin, user.Username)
[324]241 } else {
[382]242 var res sql.Result
243 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?, ?)",
[327]244 user.Username, password, user.Admin)
[382]245 if err != nil {
246 return err
[324]247 }
[382]248 user.ID, err = res.LastInsertId()
[324]249 }
[251]250
251 return err
252}
253
[375]254func (db *DB) DeleteUser(username string) error {
255 db.lock.Lock()
256 defer db.lock.Unlock()
257
258 tx, err := db.db.Begin()
259 if err != nil {
260 return err
261 }
262 defer tx.Rollback()
263
264 _, err = tx.Exec(`DELETE FROM Channel
265 WHERE id IN (
266 SELECT Channel.id
267 FROM Channel
268 JOIN Network ON Channel.network = Network.id
269 WHERE Network.user = ?
270 )`, username)
271 if err != nil {
272 return err
273 }
274
275 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username)
276 if err != nil {
277 return err
278 }
279
280 _, err = tx.Exec("DELETE FROM User WHERE username = ?", username)
281 if err != nil {
282 return err
283 }
284
285 return tx.Commit()
286}
287
[77]288func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]289 db.lock.RLock()
290 defer db.lock.RUnlock()
[77]291
[118]292 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]293 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
294 sasl_external_cert, sasl_external_key
[95]295 FROM Network
296 WHERE user = ?`,
297 username)
[77]298 if err != nil {
299 return nil, err
300 }
301 defer rows.Close()
302
303 var networks []Network
304 for rows.Next() {
305 var net Network
[263]306 var name, username, realname, pass, connectCommands *string
[95]307 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]308 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]309 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
310 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]311 if err != nil {
[77]312 return nil, err
313 }
[118]314 net.Name = fromStringPtr(name)
[95]315 net.Username = fromStringPtr(username)
316 net.Realname = fromStringPtr(realname)
317 net.Pass = fromStringPtr(pass)
[263]318 if connectCommands != nil {
319 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
320 }
[95]321 net.SASL.Mechanism = fromStringPtr(saslMechanism)
322 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
323 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]324 networks = append(networks, net)
325 }
326 if err := rows.Err(); err != nil {
327 return nil, err
328 }
329
330 return networks, nil
331}
332
[90]333func (db *DB) StoreNetwork(username string, network *Network) error {
334 db.lock.Lock()
335 defer db.lock.Unlock()
336
[118]337 netName := toStringPtr(network.Name)
[95]338 netUsername := toStringPtr(network.Username)
339 realname := toStringPtr(network.Realname)
340 pass := toStringPtr(network.Pass)
[263]341 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]342
343 var saslMechanism, saslPlainUsername, saslPlainPassword *string
344 if network.SASL.Mechanism != "" {
345 saslMechanism = &network.SASL.Mechanism
346 switch network.SASL.Mechanism {
347 case "PLAIN":
348 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
349 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]350 network.SASL.External.CertBlob = nil
351 network.SASL.External.PrivKeyBlob = nil
352 case "EXTERNAL":
353 // keep saslPlain* nil
[148]354 default:
355 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]356 }
[90]357 }
358
359 var err error
360 if network.ID != 0 {
[93]361 _, err = db.db.Exec(`UPDATE Network
[263]362 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]363 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
364 sasl_external_cert = ?, sasl_external_key = ?
[93]365 WHERE id = ?`,
[263]366 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]367 saslMechanism, saslPlainUsername, saslPlainPassword,
368 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
369 network.ID)
[90]370 } else {
371 var res sql.Result
[118]372 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]373 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]374 sasl_plain_password, sasl_external_cert, sasl_external_key)
375 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[263]376 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]377 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
378 network.SASL.External.PrivKeyBlob)
[90]379 if err != nil {
380 return err
381 }
382 network.ID, err = res.LastInsertId()
383 }
384 return err
385}
386
[202]387func (db *DB) DeleteNetwork(id int64) error {
388 db.lock.Lock()
389 defer db.lock.Unlock()
390
391 tx, err := db.db.Begin()
392 if err != nil {
393 return err
394 }
395 defer tx.Rollback()
396
[375]397 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]398 if err != nil {
399 return err
400 }
401
[375]402 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]403 if err != nil {
404 return err
405 }
406
407 return tx.Commit()
408}
409
[77]410func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]411 db.lock.RLock()
412 defer db.lock.RUnlock()
[77]413
[284]414 rows, err := db.db.Query(`SELECT id, name, key, detached
415 FROM Channel
416 WHERE network = ?`, networkID)
[77]417 if err != nil {
418 return nil, err
419 }
420 defer rows.Close()
421
422 var channels []Channel
423 for rows.Next() {
424 var ch Channel
[146]425 var key *string
[284]426 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]427 return nil, err
428 }
[146]429 ch.Key = fromStringPtr(key)
[77]430 channels = append(channels, ch)
431 }
432 if err := rows.Err(); err != nil {
433 return nil, err
434 }
435
436 return channels, nil
437}
[89]438
439func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
440 db.lock.Lock()
441 defer db.lock.Unlock()
442
[146]443 key := toStringPtr(ch.Key)
[149]444
445 var err error
446 if ch.ID != 0 {
447 _, err = db.db.Exec(`UPDATE Channel
[284]448 SET network = ?, name = ?, key = ?, detached = ?
449 WHERE id = ?`,
450 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]451 } else {
452 var res sql.Result
[284]453 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
454 VALUES (?, ?, ?, ?)`,
455 networkID, ch.Name, key, ch.Detached)
[149]456 if err != nil {
457 return err
458 }
459 ch.ID, err = res.LastInsertId()
460 }
[89]461 return err
462}
463
464func (db *DB) DeleteChannel(networkID int64, name string) error {
465 db.lock.Lock()
466 defer db.lock.Unlock()
467
468 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
469 return err
470}
Note: See TracBrowser for help on using the repository browser.