source: code/trunk/db.go@ 262

Last change on this file since 262 was 255, checked in by contact, 5 years ago

Set up DB migration infrastructure

The database is now initialized automatically on first run. The schema
version is stored in SQLite's user_version special field. Migrations are
stored in an array and applied based on the schema version.

File size: 9.1 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is a soju account, persisted as one row of the User table.
type User struct {
	// Username identifies the account (the table's primary key).
	// Password holds the account password in hashed form.
	Username, Password string
}
15
// SASL describes the SASL settings of a network. An empty Mechanism means no
// mechanism is configured; StoreNetwork only accepts "" or "PLAIN".
type SASL struct {
	Mechanism string

	// Plain carries the credentials used by the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
24
25type Network struct {
26 ID int64
27 Name string
28 Addr string
29 Nick string
30 Username string
31 Realname string
32 Pass string
33 SASL SASL
34}
35
36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network, persisted as one row of the
// Channel table.
type Channel struct {
	// ID is the row's primary key; zero means not stored yet (StoreChannel
	// inserts the row and fills it in). Key is optional and stored as NULL
	// when empty.
	ID        int64
	Name, Key string
}
48
var (
	// ErrNoSuchChannel is returned by GetChannel when no channel with the
	// requested name is stored for the network.
	ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
)
50
// schema is the complete, current database schema. upgrade executes it as-is
// on a database whose schema version (SQLite's user_version) is 0, and then
// records the version as len(migrations). It must therefore always describe
// the same schema that applying every entry of migrations to an older
// database produces.
//
// Channel.network references Network(id) without ON DELETE CASCADE, so
// deleting a network has to delete its channels explicitly.
//
// NOTE(review): User.password is NOT NULL, while the Go code writes an empty
// password as NULL (via toStringPtr); confirm that rejecting empty passwords
// is intended.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
82
// migrations[i] holds the SQL that moves the schema from version i to version
// i+1, so len(migrations) is the current schema version. Entry 0 is never
// executed: a version-0 database is initialized by running schema directly.
var migrations = []string{
	// #0: reserved for schema initialization.
	"",
}
86
// DB is soju's persistent store, wrapping a *sql.DB.
type DB struct {
	// lock serializes access to db: the List*/Get* methods hold it for
	// reading, and methods that modify data hold it exclusively.
	lock sync.RWMutex
	db   *sql.DB
}
91
92func OpenSQLDB(driver, source string) (*DB, error) {
93 sqlDB, err := sql.Open(driver, source)
94 if err != nil {
95 return nil, err
96 }
97
98 db := &DB{db: sqlDB}
99 if err := db.upgrade(); err != nil {
100 return nil, err
101 }
102
103 return db, nil
104}
105
106func (db *DB) Close() error {
107 db.lock.Lock()
108 defer db.lock.Unlock()
109 return db.Close()
110}
111
112func (db *DB) upgrade() error {
113 db.lock.Lock()
114 defer db.lock.Unlock()
115
116 var version int
117 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
118 return fmt.Errorf("failed to query schema version: %v", err)
119 }
120
121 if version == len(migrations) {
122 return nil
123 } else if version > len(migrations) {
124 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
125 }
126
127 tx, err := db.db.Begin()
128 if err != nil {
129 return err
130 }
131 defer tx.Rollback()
132
133 if version == 0 {
134 if _, err := tx.Exec(schema); err != nil {
135 return fmt.Errorf("failed to initialize schema: %v", err)
136 }
137 } else {
138 for i := version; i < len(migrations); i++ {
139 if _, err := tx.Exec(migrations[i]); err != nil {
140 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
141 }
142 }
143 }
144
145 // For some reason prepared statements don't work here
146 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
147 if err != nil {
148 return fmt.Errorf("failed to bump schema version: %v", err)
149 }
150
151 return tx.Commit()
152}
153
// fromStringPtr turns a value scanned from a nullable column into a plain
// string: nil (SQL NULL) becomes "", anything else is dereferenced.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
160
// toStringPtr prepares a string for a nullable column: "" becomes nil (stored
// as SQL NULL), and any other value becomes a pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
167
168func (db *DB) ListUsers() ([]User, error) {
169 db.lock.RLock()
170 defer db.lock.RUnlock()
171
172 rows, err := db.db.Query("SELECT username, password FROM User")
173 if err != nil {
174 return nil, err
175 }
176 defer rows.Close()
177
178 var users []User
179 for rows.Next() {
180 var user User
181 var password *string
182 if err := rows.Scan(&user.Username, &password); err != nil {
183 return nil, err
184 }
185 user.Password = fromStringPtr(password)
186 users = append(users, user)
187 }
188 if err := rows.Err(); err != nil {
189 return nil, err
190 }
191
192 return users, nil
193}
194
195func (db *DB) GetUser(username string) (*User, error) {
196 db.lock.RLock()
197 defer db.lock.RUnlock()
198
199 user := &User{Username: username}
200
201 var password *string
202 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
203 if err := row.Scan(&password); err != nil {
204 return nil, err
205 }
206 user.Password = fromStringPtr(password)
207 return user, nil
208}
209
210func (db *DB) CreateUser(user *User) error {
211 db.lock.Lock()
212 defer db.lock.Unlock()
213
214 password := toStringPtr(user.Password)
215 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
216 return err
217}
218
219func (db *DB) UpdatePassword(user *User) error {
220 db.lock.Lock()
221 defer db.lock.Unlock()
222
223 password := toStringPtr(user.Password)
224 _, err := db.db.Exec(`UPDATE User
225 SET password = ?
226 WHERE username = ?`,
227 password, user.Username)
228 return err
229}
230
231func (db *DB) ListNetworks(username string) ([]Network, error) {
232 db.lock.RLock()
233 defer db.lock.RUnlock()
234
235 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
236 sasl_mechanism, sasl_plain_username, sasl_plain_password
237 FROM Network
238 WHERE user = ?`,
239 username)
240 if err != nil {
241 return nil, err
242 }
243 defer rows.Close()
244
245 var networks []Network
246 for rows.Next() {
247 var net Network
248 var name, username, realname, pass *string
249 var saslMechanism, saslPlainUsername, saslPlainPassword *string
250 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
251 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
252 if err != nil {
253 return nil, err
254 }
255 net.Name = fromStringPtr(name)
256 net.Username = fromStringPtr(username)
257 net.Realname = fromStringPtr(realname)
258 net.Pass = fromStringPtr(pass)
259 net.SASL.Mechanism = fromStringPtr(saslMechanism)
260 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
261 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
262 networks = append(networks, net)
263 }
264 if err := rows.Err(); err != nil {
265 return nil, err
266 }
267
268 return networks, nil
269}
270
271func (db *DB) StoreNetwork(username string, network *Network) error {
272 db.lock.Lock()
273 defer db.lock.Unlock()
274
275 netName := toStringPtr(network.Name)
276 netUsername := toStringPtr(network.Username)
277 realname := toStringPtr(network.Realname)
278 pass := toStringPtr(network.Pass)
279
280 var saslMechanism, saslPlainUsername, saslPlainPassword *string
281 if network.SASL.Mechanism != "" {
282 saslMechanism = &network.SASL.Mechanism
283 switch network.SASL.Mechanism {
284 case "PLAIN":
285 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
286 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
287 default:
288 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
289 }
290 }
291
292 var err error
293 if network.ID != 0 {
294 _, err = db.db.Exec(`UPDATE Network
295 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
296 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
297 WHERE id = ?`,
298 netName, network.Addr, network.Nick, netUsername, realname, pass,
299 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
300 } else {
301 var res sql.Result
302 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
303 realname, pass, sasl_mechanism, sasl_plain_username,
304 sasl_plain_password)
305 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
306 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
307 saslMechanism, saslPlainUsername, saslPlainPassword)
308 if err != nil {
309 return err
310 }
311 network.ID, err = res.LastInsertId()
312 }
313 return err
314}
315
316func (db *DB) DeleteNetwork(id int64) error {
317 db.lock.Lock()
318 defer db.lock.Unlock()
319
320 tx, err := db.db.Begin()
321 if err != nil {
322 return err
323 }
324 defer tx.Rollback()
325
326 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
327 if err != nil {
328 return err
329 }
330
331 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
332 if err != nil {
333 return err
334 }
335
336 return tx.Commit()
337}
338
339func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
340 db.lock.RLock()
341 defer db.lock.RUnlock()
342
343 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
344 if err != nil {
345 return nil, err
346 }
347 defer rows.Close()
348
349 var channels []Channel
350 for rows.Next() {
351 var ch Channel
352 var key *string
353 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
354 return nil, err
355 }
356 ch.Key = fromStringPtr(key)
357 channels = append(channels, ch)
358 }
359 if err := rows.Err(); err != nil {
360 return nil, err
361 }
362
363 return channels, nil
364}
365
366func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
367 db.lock.RLock()
368 defer db.lock.RUnlock()
369
370 ch := &Channel{Name: name}
371
372 var key *string
373 row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
374 if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
375 return nil, ErrNoSuchChannel
376 } else if err != nil {
377 return nil, err
378 }
379 ch.Key = fromStringPtr(key)
380 return ch, nil
381}
382
383func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
384 db.lock.Lock()
385 defer db.lock.Unlock()
386
387 key := toStringPtr(ch.Key)
388
389 var err error
390 if ch.ID != 0 {
391 _, err = db.db.Exec(`UPDATE Channel
392 SET network = ?, name = ?, key = ?
393 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
394 } else {
395 var res sql.Result
396 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
397 VALUES (?, ?, ?)`, networkID, ch.Name, key)
398 if err != nil {
399 return err
400 }
401 ch.ID, err = res.LastInsertId()
402 }
403 return err
404}
405
406func (db *DB) DeleteChannel(networkID int64, name string) error {
407 db.lock.Lock()
408 defer db.lock.Unlock()
409
410 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
411 return err
412}
Note: See TracBrowser for help on using the repository browser.