source: code/trunk/db.go@ 326

Last change on this file since 326 was 324, checked in by contact, 5 years ago

Introduce User.Created

For Network and Channel, the database only needed to define one Store
operation to create/update a record. However, since User is missing an ID,
we couldn't have a single StoreUser function like other types. We had
CreateUser and UpdatePassword. As new User fields get added (e.g. the
upcoming Admin flag) this isn't sustainable.

We could have CreateUser and UpdateUser, but this wouldn't be consistent
with other types. Instead, introduce User.Created which indicates
whether the record is already stored in the DB. This can be used in a
new StoreUser function to decide whether we need to UPDATE or INSERT
without relying on SQL constraints and INSERT OR UPDATE.

The ListUsers and GetUser functions set User.Created to true.

File size: 10.1 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a record of the User table.
type User struct {
	// Created reports whether the record is already stored in the database.
	// ListUsers and GetUser set it to true, as does StoreUser after a
	// successful INSERT; StoreUser reads it to choose between UPDATE and
	// INSERT.
	Created  bool
	Username string
	Password string // hashed
}
17
// SASL holds the SASL settings stored alongside a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork accepts only "",
	// "PLAIN" and "EXTERNAL" and rejects anything else.
	Mechanism string

	// Plain holds the credentials that StoreNetwork persists when Mechanism
	// is "PLAIN".
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
34
// Network is a record of the Network table.
//
// Empty Name, Username, Realname and Pass values are written as NULL by
// StoreNetwork and read back as empty strings by ListNetworks.
type Network struct {
	// ID is the row ID. StoreNetwork treats zero as "not stored yet": it
	// inserts a new row and fills ID in from the insert result.
	ID       int64
	Name     string
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	// ConnectCommands is stored as a single column, joined with "\r\n".
	ConnectCommands []string
	SASL            SASL
}
46
47func (net *Network) GetName() string {
48 if net.Name != "" {
49 return net.Name
50 }
51 return net.Addr
52}
53
// Channel is a record of the Channel table.
type Channel struct {
	// ID is the row ID. StoreChannel treats zero as "not stored yet": it
	// inserts a new row and fills ID in from the insert result.
	ID   int64
	Name string
	// Key is written as NULL when empty and read back as an empty string.
	Key      string
	Detached bool
}
60
// schema is the SQL script that DB.upgrade executes when the database reports
// user_version 0. After running it, upgrade sets user_version to
// len(migrations) without running any migration, so this script has to
// describe the tables as they look after every migration has been applied.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
96
// migrations lists the SQL statements DB.upgrade runs on an existing
// database. The statement at index i upgrades a database whose user_version
// is i; len(migrations) is therefore the schema version this build expects.
// Entries must only ever be appended, since the index is the version number.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}
104
// DB wraps a database/sql handle. Its methods take lock around every use of
// db: a read lock in the List*/Get* methods and the write lock everywhere
// else.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
109
110func OpenSQLDB(driver, source string) (*DB, error) {
111 sqlDB, err := sql.Open(driver, source)
112 if err != nil {
113 return nil, err
114 }
115
116 db := &DB{db: sqlDB}
117 if err := db.upgrade(); err != nil {
118 return nil, err
119 }
120
121 return db, nil
122}
123
124func (db *DB) Close() error {
125 db.lock.Lock()
126 defer db.lock.Unlock()
127 return db.Close()
128}
129
130func (db *DB) upgrade() error {
131 db.lock.Lock()
132 defer db.lock.Unlock()
133
134 var version int
135 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
136 return fmt.Errorf("failed to query schema version: %v", err)
137 }
138
139 if version == len(migrations) {
140 return nil
141 } else if version > len(migrations) {
142 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
143 }
144
145 tx, err := db.db.Begin()
146 if err != nil {
147 return err
148 }
149 defer tx.Rollback()
150
151 if version == 0 {
152 if _, err := tx.Exec(schema); err != nil {
153 return fmt.Errorf("failed to initialize schema: %v", err)
154 }
155 } else {
156 for i := version; i < len(migrations); i++ {
157 if _, err := tx.Exec(migrations[i]); err != nil {
158 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
159 }
160 }
161 }
162
163 // For some reason prepared statements don't work here
164 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
165 if err != nil {
166 return fmt.Errorf("failed to bump schema version: %v", err)
167 }
168
169 return tx.Commit()
170}
171
// fromStringPtr returns the string ptr points to, or "" when ptr is nil.
// It is the read-side counterpart of toStringPtr for nullable columns.
func fromStringPtr(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}
178
// toStringPtr returns nil for the empty string and a pointer to a copy of s
// otherwise, so that empty values are written to the database as NULL.
func toStringPtr(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}
185
186func (db *DB) ListUsers() ([]User, error) {
187 db.lock.RLock()
188 defer db.lock.RUnlock()
189
190 rows, err := db.db.Query("SELECT username, password FROM User")
191 if err != nil {
192 return nil, err
193 }
194 defer rows.Close()
195
196 var users []User
197 for rows.Next() {
198 var user User
199 var password *string
200 if err := rows.Scan(&user.Username, &password); err != nil {
201 return nil, err
202 }
203 user.Created = true
204 user.Password = fromStringPtr(password)
205 users = append(users, user)
206 }
207 if err := rows.Err(); err != nil {
208 return nil, err
209 }
210
211 return users, nil
212}
213
214func (db *DB) GetUser(username string) (*User, error) {
215 db.lock.RLock()
216 defer db.lock.RUnlock()
217
218 user := &User{Created: true, Username: username}
219
220 var password *string
221 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
222 if err := row.Scan(&password); err != nil {
223 return nil, err
224 }
225 user.Password = fromStringPtr(password)
226 return user, nil
227}
228
229func (db *DB) StoreUser(user *User) error {
230 db.lock.Lock()
231 defer db.lock.Unlock()
232
233 password := toStringPtr(user.Password)
234
235 var err error
236 if user.Created {
237 _, err = db.db.Exec("UPDATE User SET password = ? WHERE username = ?",
238 password, user.Username)
239 } else {
240 _, err = db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)",
241 user.Username, password)
242 if err == nil {
243 user.Created = true
244 }
245 }
246
247 return err
248}
249
250func (db *DB) ListNetworks(username string) ([]Network, error) {
251 db.lock.RLock()
252 defer db.lock.RUnlock()
253
254 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
255 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
256 sasl_external_cert, sasl_external_key
257 FROM Network
258 WHERE user = ?`,
259 username)
260 if err != nil {
261 return nil, err
262 }
263 defer rows.Close()
264
265 var networks []Network
266 for rows.Next() {
267 var net Network
268 var name, username, realname, pass, connectCommands *string
269 var saslMechanism, saslPlainUsername, saslPlainPassword *string
270 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
271 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
272 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
273 if err != nil {
274 return nil, err
275 }
276 net.Name = fromStringPtr(name)
277 net.Username = fromStringPtr(username)
278 net.Realname = fromStringPtr(realname)
279 net.Pass = fromStringPtr(pass)
280 if connectCommands != nil {
281 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
282 }
283 net.SASL.Mechanism = fromStringPtr(saslMechanism)
284 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
285 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
286 networks = append(networks, net)
287 }
288 if err := rows.Err(); err != nil {
289 return nil, err
290 }
291
292 return networks, nil
293}
294
295func (db *DB) StoreNetwork(username string, network *Network) error {
296 db.lock.Lock()
297 defer db.lock.Unlock()
298
299 netName := toStringPtr(network.Name)
300 netUsername := toStringPtr(network.Username)
301 realname := toStringPtr(network.Realname)
302 pass := toStringPtr(network.Pass)
303 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
304
305 var saslMechanism, saslPlainUsername, saslPlainPassword *string
306 if network.SASL.Mechanism != "" {
307 saslMechanism = &network.SASL.Mechanism
308 switch network.SASL.Mechanism {
309 case "PLAIN":
310 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
311 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
312 network.SASL.External.CertBlob = nil
313 network.SASL.External.PrivKeyBlob = nil
314 case "EXTERNAL":
315 // keep saslPlain* nil
316 default:
317 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
318 }
319 }
320
321 var err error
322 if network.ID != 0 {
323 _, err = db.db.Exec(`UPDATE Network
324 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
325 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
326 sasl_external_cert = ?, sasl_external_key = ?
327 WHERE id = ?`,
328 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
329 saslMechanism, saslPlainUsername, saslPlainPassword,
330 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
331 network.ID)
332 } else {
333 var res sql.Result
334 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
335 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
336 sasl_plain_password, sasl_external_cert, sasl_external_key)
337 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
338 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
339 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
340 network.SASL.External.PrivKeyBlob)
341 if err != nil {
342 return err
343 }
344 network.ID, err = res.LastInsertId()
345 }
346 return err
347}
348
349func (db *DB) DeleteNetwork(id int64) error {
350 db.lock.Lock()
351 defer db.lock.Unlock()
352
353 tx, err := db.db.Begin()
354 if err != nil {
355 return err
356 }
357 defer tx.Rollback()
358
359 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
360 if err != nil {
361 return err
362 }
363
364 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
365 if err != nil {
366 return err
367 }
368
369 return tx.Commit()
370}
371
372func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
373 db.lock.RLock()
374 defer db.lock.RUnlock()
375
376 rows, err := db.db.Query(`SELECT id, name, key, detached
377 FROM Channel
378 WHERE network = ?`, networkID)
379 if err != nil {
380 return nil, err
381 }
382 defer rows.Close()
383
384 var channels []Channel
385 for rows.Next() {
386 var ch Channel
387 var key *string
388 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
389 return nil, err
390 }
391 ch.Key = fromStringPtr(key)
392 channels = append(channels, ch)
393 }
394 if err := rows.Err(); err != nil {
395 return nil, err
396 }
397
398 return channels, nil
399}
400
401func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
402 db.lock.Lock()
403 defer db.lock.Unlock()
404
405 key := toStringPtr(ch.Key)
406
407 var err error
408 if ch.ID != 0 {
409 _, err = db.db.Exec(`UPDATE Channel
410 SET network = ?, name = ?, key = ?, detached = ?
411 WHERE id = ?`,
412 networkID, ch.Name, key, ch.Detached, ch.ID)
413 } else {
414 var res sql.Result
415 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
416 VALUES (?, ?, ?, ?)`,
417 networkID, ch.Name, key, ch.Detached)
418 if err != nil {
419 return err
420 }
421 ch.ID, err = res.LastInsertId()
422 }
423 return err
424}
425
426func (db *DB) DeleteChannel(networkID int64, name string) error {
427 db.lock.Lock()
428 defer db.lock.Unlock()
429
430 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
431 return err
432}
Note: See TracBrowser for help on using the repository browser.