source: code/trunk/db.go@ 363

Last change on this file since 363 was 356, checked in by contact, 5 years ago

Fix deadlock in DB.Close

This method was calling itself, instead of the underlying SQLite
database's Close method.

File size: 10.3 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju account as stored in the User table.
type User struct {
	// Created reports whether a row for this user already exists in the
	// database; StoreUser uses it to choose between UPDATE and INSERT.
	Created bool
	// Username is the primary key of the User table.
	Username string
	// Password is the hashed password.
	Password string
	// Admin marks the account as an administrator.
	Admin bool
}
18
// SASL holds the SASL credentials used when connecting to a network.
// Mechanism selects which nested credential set is relevant; StoreNetwork
// accepts "PLAIN", "EXTERNAL" or "" (no SASL).
type SASL struct {
	Mechanism string

	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

// Network is one row of the Network table: an upstream connection
// configured for a soju user.
type Network struct {
	ID       int64
	Name     string
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	// ConnectCommands is persisted as a single column joined with "\r\n".
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the network's display name: its explicit Name when set,
// otherwise its Addr.
func (net *Network) GetName() string {
	name := net.Name
	if name == "" {
		name = net.Addr
	}
	return name
}
54
// Channel is one row of the Channel table, belonging to a Network.
type Channel struct {
	ID   int64
	Name string
	// Key is the channel key; StoreChannel writes an empty Key as NULL.
	Key      string
	Detached bool
}
61
// schema is the complete database layout. upgrade executes it in one Exec
// when the database reports user_version 0 (a fresh database). Existing
// databases are instead brought forward by migrations, so every column added
// there must also appear here.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL,
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
98
// migrations lists the schema changes for databases created by older
// versions. upgrade applies migrations[v:] to a database whose user_version
// is v (v > 0), then sets user_version to len(migrations). Entry i therefore
// moves a database from version i to i+1.
//
// Only append to this list: reordering or editing existing entries would make
// deployed databases skip or repeat steps. Mirror every change in schema.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
}
107
[77]108type DB struct {
[81]109 lock sync.RWMutex
[77]110 db *sql.DB
111}
112
113func OpenSQLDB(driver, source string) (*DB, error) {
[255]114 sqlDB, err := sql.Open(driver, source)
[77]115 if err != nil {
116 return nil, err
117 }
[255]118
119 db := &DB{db: sqlDB}
120 if err := db.upgrade(); err != nil {
121 return nil, err
122 }
123
124 return db, nil
[77]125}
126
127func (db *DB) Close() error {
128 db.lock.Lock()
129 defer db.lock.Unlock()
[356]130 return db.db.Close()
[77]131}
132
[255]133func (db *DB) upgrade() error {
134 db.lock.Lock()
135 defer db.lock.Unlock()
136
137 var version int
138 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
139 return fmt.Errorf("failed to query schema version: %v", err)
140 }
141
142 if version == len(migrations) {
143 return nil
144 } else if version > len(migrations) {
145 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
146 }
147
148 tx, err := db.db.Begin()
149 if err != nil {
150 return err
151 }
152 defer tx.Rollback()
153
154 if version == 0 {
155 if _, err := tx.Exec(schema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %v", err)
157 }
158 } else {
159 for i := version; i < len(migrations); i++ {
160 if _, err := tx.Exec(migrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 // For some reason prepared statements don't work here
167 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
// fromStringPtr dereferences ptr, mapping a nil pointer (a NULL column
// scanned into *string) to the empty string.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
181
// toStringPtr is the inverse of fromStringPtr. It returns nil for the empty
// string, so the value is passed to Exec as NULL. Otherwise it returns a
// pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
188
[77]189func (db *DB) ListUsers() ([]User, error) {
[81]190 db.lock.RLock()
191 defer db.lock.RUnlock()
[77]192
[327]193 rows, err := db.db.Query("SELECT username, password, admin FROM User")
[77]194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var users []User
200 for rows.Next() {
201 var user User
202 var password *string
[327]203 if err := rows.Scan(&user.Username, &password, &user.Admin); err != nil {
[77]204 return nil, err
205 }
[324]206 user.Created = true
[95]207 user.Password = fromStringPtr(password)
[77]208 users = append(users, user)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return users, nil
215}
216
[173]217func (db *DB) GetUser(username string) (*User, error) {
218 db.lock.RLock()
219 defer db.lock.RUnlock()
220
[324]221 user := &User{Created: true, Username: username}
[173]222
223 var password *string
[327]224 row := db.db.QueryRow("SELECT password, admin FROM User WHERE username = ?", username)
225 if err := row.Scan(&password, &user.Admin); err != nil {
[173]226 return nil, err
227 }
228 user.Password = fromStringPtr(password)
229 return user, nil
230}
231
[324]232func (db *DB) StoreUser(user *User) error {
[84]233 db.lock.Lock()
234 defer db.lock.Unlock()
235
[95]236 password := toStringPtr(user.Password)
[84]237
[324]238 var err error
239 if user.Created {
[327]240 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
241 password, user.Admin, user.Username)
[324]242 } else {
[327]243 _, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
244 user.Username, password, user.Admin)
[324]245 if err == nil {
246 user.Created = true
247 }
248 }
[251]249
250 return err
251}
252
[77]253func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]254 db.lock.RLock()
255 defer db.lock.RUnlock()
[77]256
[118]257 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]258 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
259 sasl_external_cert, sasl_external_key
[95]260 FROM Network
261 WHERE user = ?`,
262 username)
[77]263 if err != nil {
264 return nil, err
265 }
266 defer rows.Close()
267
268 var networks []Network
269 for rows.Next() {
270 var net Network
[263]271 var name, username, realname, pass, connectCommands *string
[95]272 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]273 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]274 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
275 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]276 if err != nil {
[77]277 return nil, err
278 }
[118]279 net.Name = fromStringPtr(name)
[95]280 net.Username = fromStringPtr(username)
281 net.Realname = fromStringPtr(realname)
282 net.Pass = fromStringPtr(pass)
[263]283 if connectCommands != nil {
284 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
285 }
[95]286 net.SASL.Mechanism = fromStringPtr(saslMechanism)
287 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
288 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]289 networks = append(networks, net)
290 }
291 if err := rows.Err(); err != nil {
292 return nil, err
293 }
294
295 return networks, nil
296}
297
[90]298func (db *DB) StoreNetwork(username string, network *Network) error {
299 db.lock.Lock()
300 defer db.lock.Unlock()
301
[118]302 netName := toStringPtr(network.Name)
[95]303 netUsername := toStringPtr(network.Username)
304 realname := toStringPtr(network.Realname)
305 pass := toStringPtr(network.Pass)
[263]306 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]307
308 var saslMechanism, saslPlainUsername, saslPlainPassword *string
309 if network.SASL.Mechanism != "" {
310 saslMechanism = &network.SASL.Mechanism
311 switch network.SASL.Mechanism {
312 case "PLAIN":
313 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
314 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]315 network.SASL.External.CertBlob = nil
316 network.SASL.External.PrivKeyBlob = nil
317 case "EXTERNAL":
318 // keep saslPlain* nil
[148]319 default:
320 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]321 }
[90]322 }
323
324 var err error
325 if network.ID != 0 {
[93]326 _, err = db.db.Exec(`UPDATE Network
[263]327 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]328 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
329 sasl_external_cert = ?, sasl_external_key = ?
[93]330 WHERE id = ?`,
[263]331 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]332 saslMechanism, saslPlainUsername, saslPlainPassword,
333 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
334 network.ID)
[90]335 } else {
336 var res sql.Result
[118]337 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]338 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]339 sasl_plain_password, sasl_external_cert, sasl_external_key)
340 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[263]341 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]342 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
343 network.SASL.External.PrivKeyBlob)
[90]344 if err != nil {
345 return err
346 }
347 network.ID, err = res.LastInsertId()
348 }
349 return err
350}
351
[202]352func (db *DB) DeleteNetwork(id int64) error {
353 db.lock.Lock()
354 defer db.lock.Unlock()
355
356 tx, err := db.db.Begin()
357 if err != nil {
358 return err
359 }
360 defer tx.Rollback()
361
362 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
363 if err != nil {
364 return err
365 }
366
367 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
368 if err != nil {
369 return err
370 }
371
372 return tx.Commit()
373}
374
[77]375func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]376 db.lock.RLock()
377 defer db.lock.RUnlock()
[77]378
[284]379 rows, err := db.db.Query(`SELECT id, name, key, detached
380 FROM Channel
381 WHERE network = ?`, networkID)
[77]382 if err != nil {
383 return nil, err
384 }
385 defer rows.Close()
386
387 var channels []Channel
388 for rows.Next() {
389 var ch Channel
[146]390 var key *string
[284]391 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]392 return nil, err
393 }
[146]394 ch.Key = fromStringPtr(key)
[77]395 channels = append(channels, ch)
396 }
397 if err := rows.Err(); err != nil {
398 return nil, err
399 }
400
401 return channels, nil
402}
[89]403
404func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
405 db.lock.Lock()
406 defer db.lock.Unlock()
407
[146]408 key := toStringPtr(ch.Key)
[149]409
410 var err error
411 if ch.ID != 0 {
412 _, err = db.db.Exec(`UPDATE Channel
[284]413 SET network = ?, name = ?, key = ?, detached = ?
414 WHERE id = ?`,
415 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]416 } else {
417 var res sql.Result
[284]418 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
419 VALUES (?, ?, ?, ?)`,
420 networkID, ch.Name, key, ch.Detached)
[149]421 if err != nil {
422 return err
423 }
424 ch.ID, err = res.LastInsertId()
425 }
[89]426 return err
427}
428
429func (db *DB) DeleteChannel(networkID int64, name string) error {
430 db.lock.Lock()
431 defer db.lock.Unlock()
432
433 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
434 return err
435}
Note: See TracBrowser for help on using the repository browser.