source: code/trunk/db.go@ 462

Last change on this file since 462 was 457, checked in by contact, 4 years ago

Add Network.{URL,GetUsername,GetRealname}

Just a bunch of helpers that can be re-used.

File size: 14.1 KB
Rev Line
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[434]6 "math"
[457]7 "net/url"
[263]8 "strings"
[77]9 "sync"
[434]10 "time"
[77]11
12 _ "github.com/mattn/go-sqlite3"
13)
14
// User is one row of the User table.
type User struct {
	// ID is the row identifier. StoreUser inserts a new row when it is zero
	// and updates the existing row otherwise.
	ID int64
	// Username is the account name; the schema requires it to be unique.
	Username string
	// Password is the hashed password. An empty string is stored as NULL.
	Password string
	// Admin mirrors the admin column of the User table.
	Admin bool
}
21
// SASL holds the SASL settings of a network. Mechanism selects which of the
// nested credential sets is used; StoreNetwork accepts "", "PLAIN" and
// "EXTERNAL".
type SASL struct {
	// Mechanism is the SASL mechanism name, or "" when SASL is not used.
	Mechanism string

	// Plain holds the credentials for the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// External holds the material for TLS client certificate
	// authentication (the EXTERNAL mechanism).
	External struct {
		// CertBlob is the X.509 certificate in DER form.
		CertBlob []byte
		// PrivKeyBlob is the PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
38
[77]39type Network struct {
[263]40 ID int64
41 Name string
42 Addr string
43 Nick string
44 Username string
45 Realname string
46 Pass string
47 ConnectCommands []string
48 SASL SASL
[77]49}
50
[149]51func (net *Network) GetName() string {
52 if net.Name != "" {
53 return net.Name
54 }
55 return net.Addr
56}
57
[457]58func (net *Network) URL() (*url.URL, error) {
59 s := net.Addr
60 if !strings.Contains(s, "://") {
61 // This is a raw domain name, make it an URL with the default scheme
62 s = "ircs://" + s
63 }
64
65 u, err := url.Parse(s)
66 if err != nil {
67 return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
68 }
69
70 return u, nil
71}
72
73func (net *Network) GetUsername() string {
74 if net.Username != "" {
75 return net.Username
76 }
77 return net.Nick
78}
79
80func (net *Network) GetRealname() string {
81 if net.Realname != "" {
82 return net.Realname
83 }
84 return net.Nick
85}
86
// MessageFilter is an integer-coded filter setting. Channel stores several of
// these and they are persisted as integers in the Channel table.
type MessageFilter int

// The known filter values. Their numeric values are written to the database,
// so the order of this list must not change.
const (
	// TODO: use customizable user defaults for FilterDefault
	FilterDefault MessageFilter = iota
	FilterNone
	FilterHighlight
	FilterMessage
)

// parseFilter converts the textual name of a filter ("default", "none",
// "highlight" or "message") into its MessageFilter value. The match is exact;
// any other input yields an error and the zero filter.
func parseFilter(filter string) (MessageFilter, error) {
	// Lookup table indexed by filter value.
	names := [...]string{
		FilterDefault:   "default",
		FilterNone:      "none",
		FilterHighlight: "highlight",
		FilterMessage:   "message",
	}
	for value, name := range names {
		if name == filter {
			return MessageFilter(value), nil
		}
	}
	return 0, fmt.Errorf("unknown filter: %q", filter)
}
110
[77]111type Channel struct {
[284]112 ID int64
113 Name string
114 Key string
115 Detached bool
[434]116
117 RelayDetached MessageFilter
118 ReattachOn MessageFilter
119 DetachAfter time.Duration
120 DetachOn MessageFilter
[77]121}
122
// schema is the SQL executed by upgrade when the database reports
// user_version 0, i.e. when it has not been initialized yet. It creates the
// User, Network and Channel tables, including the columns that the entries in
// migrations add to older databases.
const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user INTEGER NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(id),
	UNIQUE(user, addr, nick),
	UNIQUE(user, name)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
165
// migrations lists the SQL scripts that bring an existing database up to
// date, one schema version per entry. upgrade runs the entries starting at
// index user_version and then sets user_version to len(migrations), so
// entries must only ever be appended, never reordered or removed.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: per-network connect commands.
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2: detached flag on channels.
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3 and #4: certificate and key for SASL EXTERNAL.
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
	// #5: admin flag on users.
	"ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
	// #6: rebuild User with an explicit integer primary key, seeded from
	// the old rowid.
	`
		CREATE TABLE UserNew (
			id INTEGER PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			admin INTEGER NOT NULL DEFAULT 0
		);
		INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
		DROP TABLE User;
		ALTER TABLE UserNew RENAME TO User;
	`,
	// #7: rebuild Network so that its user column references User.id; the
	// old rows are matched to users by username.
	`
		CREATE TABLE NetworkNew (
			id INTEGER PRIMARY KEY,
			name VARCHAR(255),
			user INTEGER NOT NULL,
			addr VARCHAR(255) NOT NULL,
			nick VARCHAR(255) NOT NULL,
			username VARCHAR(255),
			realname VARCHAR(255),
			pass VARCHAR(255),
			connect_commands VARCHAR(1023),
			sasl_mechanism VARCHAR(255),
			sasl_plain_username VARCHAR(255),
			sasl_plain_password VARCHAR(255),
			sasl_external_cert BLOB DEFAULT NULL,
			sasl_external_key BLOB DEFAULT NULL,
			FOREIGN KEY(user) REFERENCES User(id),
			UNIQUE(user, addr, nick),
			UNIQUE(user, name)
		);
		INSERT INTO NetworkNew
			SELECT Network.id, name, User.id as user, addr, nick,
				Network.username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password,
				sasl_external_cert, sasl_external_key
			FROM Network
			JOIN User ON Network.user = User.username;
		DROP TABLE Network;
		ALTER TABLE NetworkNew RENAME TO Network;
	`,
	// #8: per-channel filter and auto-detach settings.
	`
		ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
}
221
// DB wraps a SQL database handle. Every method takes lock around its use of
// the handle: read-only queries take the read lock, everything else the write
// lock.
type DB struct {
	// lock guards db.
	lock sync.RWMutex
	// db is the underlying database/sql handle.
	db *sql.DB
}
226
227func OpenSQLDB(driver, source string) (*DB, error) {
[255]228 sqlDB, err := sql.Open(driver, source)
[77]229 if err != nil {
230 return nil, err
231 }
[255]232
233 db := &DB{db: sqlDB}
234 if err := db.upgrade(); err != nil {
235 return nil, err
236 }
237
238 return db, nil
[77]239}
240
241func (db *DB) Close() error {
242 db.lock.Lock()
243 defer db.lock.Unlock()
[356]244 return db.db.Close()
[77]245}
246
[255]247func (db *DB) upgrade() error {
248 db.lock.Lock()
249 defer db.lock.Unlock()
250
251 var version int
252 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
253 return fmt.Errorf("failed to query schema version: %v", err)
254 }
255
256 if version == len(migrations) {
257 return nil
258 } else if version > len(migrations) {
259 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
260 }
261
262 tx, err := db.db.Begin()
263 if err != nil {
264 return err
265 }
266 defer tx.Rollback()
267
268 if version == 0 {
269 if _, err := tx.Exec(schema); err != nil {
270 return fmt.Errorf("failed to initialize schema: %v", err)
271 }
272 } else {
273 for i := version; i < len(migrations); i++ {
274 if _, err := tx.Exec(migrations[i]); err != nil {
275 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
276 }
277 }
278 }
279
280 // For some reason prepared statements don't work here
281 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
282 if err != nil {
283 return fmt.Errorf("failed to bump schema version: %v", err)
284 }
285
286 return tx.Commit()
287}
288
// toNullString converts s into a sql.NullString, mapping the empty string to
// NULL (Valid == false) and any other value to a valid string.
func toNullString(s string) sql.NullString {
	var ns sql.NullString
	if s != "" {
		ns.String = s
		ns.Valid = true
	}
	return ns
}
295
[77]296func (db *DB) ListUsers() ([]User, error) {
[81]297 db.lock.RLock()
298 defer db.lock.RUnlock()
[77]299
[420]300 rows, err := db.db.Query("SELECT id, username, password, admin FROM User")
[77]301 if err != nil {
302 return nil, err
303 }
304 defer rows.Close()
305
306 var users []User
307 for rows.Next() {
308 var user User
[422]309 var password sql.NullString
[382]310 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil {
[77]311 return nil, err
312 }
[422]313 user.Password = password.String
[77]314 users = append(users, user)
315 }
316 if err := rows.Err(); err != nil {
317 return nil, err
318 }
319
320 return users, nil
321}
322
[173]323func (db *DB) GetUser(username string) (*User, error) {
324 db.lock.RLock()
325 defer db.lock.RUnlock()
326
[382]327 user := &User{Username: username}
[173]328
[422]329 var password sql.NullString
[420]330 row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username)
[382]331 if err := row.Scan(&user.ID, &password, &user.Admin); err != nil {
[173]332 return nil, err
333 }
[422]334 user.Password = password.String
[173]335 return user, nil
336}
337
[324]338func (db *DB) StoreUser(user *User) error {
[84]339 db.lock.Lock()
340 defer db.lock.Unlock()
341
[422]342 password := toNullString(user.Password)
[84]343
[324]344 var err error
[382]345 if user.ID != 0 {
[327]346 _, err = db.db.Exec("UPDATE User SET password = ?, admin = ? WHERE username = ?",
347 password, user.Admin, user.Username)
[324]348 } else {
[382]349 var res sql.Result
[393]350 res, err = db.db.Exec("INSERT INTO User(username, password, admin) VALUES (?, ?, ?)",
[327]351 user.Username, password, user.Admin)
[382]352 if err != nil {
353 return err
[324]354 }
[382]355 user.ID, err = res.LastInsertId()
[324]356 }
[251]357
358 return err
359}
360
[421]361func (db *DB) DeleteUser(id int64) error {
[375]362 db.lock.Lock()
363 defer db.lock.Unlock()
364
365 tx, err := db.db.Begin()
366 if err != nil {
367 return err
368 }
369 defer tx.Rollback()
370
371 _, err = tx.Exec(`DELETE FROM Channel
372 WHERE id IN (
373 SELECT Channel.id
374 FROM Channel
375 JOIN Network ON Channel.network = Network.id
376 WHERE Network.user = ?
[421]377 )`, id)
[375]378 if err != nil {
379 return err
380 }
381
[421]382 _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id)
[375]383 if err != nil {
384 return err
385 }
386
[421]387 _, err = tx.Exec("DELETE FROM User WHERE id = ?", id)
[375]388 if err != nil {
389 return err
390 }
391
392 return tx.Commit()
393}
394
[421]395func (db *DB) ListNetworks(userID int64) ([]Network, error) {
[81]396 db.lock.RLock()
397 defer db.lock.RUnlock()
[77]398
[118]399 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]400 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
401 sasl_external_cert, sasl_external_key
[95]402 FROM Network
403 WHERE user = ?`,
[421]404 userID)
[77]405 if err != nil {
406 return nil, err
407 }
408 defer rows.Close()
409
410 var networks []Network
411 for rows.Next() {
412 var net Network
[422]413 var name, username, realname, pass, connectCommands sql.NullString
414 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
[118]415 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]416 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
417 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]418 if err != nil {
[77]419 return nil, err
420 }
[422]421 net.Name = name.String
422 net.Username = username.String
423 net.Realname = realname.String
424 net.Pass = pass.String
425 if connectCommands.Valid {
426 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
[263]427 }
[422]428 net.SASL.Mechanism = saslMechanism.String
429 net.SASL.Plain.Username = saslPlainUsername.String
430 net.SASL.Plain.Password = saslPlainPassword.String
[77]431 networks = append(networks, net)
432 }
433 if err := rows.Err(); err != nil {
434 return nil, err
435 }
436
437 return networks, nil
438}
439
[421]440func (db *DB) StoreNetwork(userID int64, network *Network) error {
[90]441 db.lock.Lock()
442 defer db.lock.Unlock()
443
[422]444 netName := toNullString(network.Name)
445 netUsername := toNullString(network.Username)
446 realname := toNullString(network.Realname)
447 pass := toNullString(network.Pass)
448 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
[95]449
[422]450 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
[95]451 if network.SASL.Mechanism != "" {
[422]452 saslMechanism = toNullString(network.SASL.Mechanism)
[95]453 switch network.SASL.Mechanism {
454 case "PLAIN":
[422]455 saslPlainUsername = toNullString(network.SASL.Plain.Username)
456 saslPlainPassword = toNullString(network.SASL.Plain.Password)
[307]457 network.SASL.External.CertBlob = nil
458 network.SASL.External.PrivKeyBlob = nil
459 case "EXTERNAL":
460 // keep saslPlain* nil
[148]461 default:
462 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]463 }
[90]464 }
465
466 var err error
467 if network.ID != 0 {
[93]468 _, err = db.db.Exec(`UPDATE Network
[263]469 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]470 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
471 sasl_external_cert = ?, sasl_external_key = ?
[93]472 WHERE id = ?`,
[263]473 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]474 saslMechanism, saslPlainUsername, saslPlainPassword,
475 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
476 network.ID)
[90]477 } else {
478 var res sql.Result
[118]479 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]480 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]481 sasl_plain_password, sasl_external_cert, sasl_external_key)
482 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[421]483 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]484 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
485 network.SASL.External.PrivKeyBlob)
[90]486 if err != nil {
487 return err
488 }
489 network.ID, err = res.LastInsertId()
490 }
491 return err
492}
493
[202]494func (db *DB) DeleteNetwork(id int64) error {
495 db.lock.Lock()
496 defer db.lock.Unlock()
497
498 tx, err := db.db.Begin()
499 if err != nil {
500 return err
501 }
502 defer tx.Rollback()
503
[375]504 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
[202]505 if err != nil {
506 return err
507 }
508
[375]509 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
[202]510 if err != nil {
511 return err
512 }
513
514 return tx.Commit()
515}
516
[77]517func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]518 db.lock.RLock()
519 defer db.lock.RUnlock()
[77]520
[434]521 rows, err := db.db.Query(`SELECT id, name, key, detached, relay_detached, reattach_on, detach_after, detach_on
[284]522 FROM Channel
523 WHERE network = ?`, networkID)
[77]524 if err != nil {
525 return nil, err
526 }
527 defer rows.Close()
528
529 var channels []Channel
530 for rows.Next() {
531 var ch Channel
[422]532 var key sql.NullString
[434]533 var detachAfter int64
534 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
[77]535 return nil, err
536 }
[422]537 ch.Key = key.String
[434]538 ch.DetachAfter = time.Duration(detachAfter) * time.Second
[77]539 channels = append(channels, ch)
540 }
541 if err := rows.Err(); err != nil {
542 return nil, err
543 }
544
545 return channels, nil
546}
[89]547
548func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
549 db.lock.Lock()
550 defer db.lock.Unlock()
551
[422]552 key := toNullString(ch.Key)
[434]553 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
[149]554
555 var err error
556 if ch.ID != 0 {
557 _, err = db.db.Exec(`UPDATE Channel
[434]558 SET network = ?, name = ?, key = ?, detached = ?, relay_detached = ?, reattach_on = ?, detach_after = ?, detach_on = ?
[284]559 WHERE id = ?`,
[434]560 networkID, ch.Name, key, ch.Detached, ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn, ch.ID)
[149]561 } else {
562 var res sql.Result
[434]563 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached, relay_detached, reattach_on, detach_after, detach_on)
564 VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
565 networkID, ch.Name, key, ch.Detached, ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
[149]566 if err != nil {
567 return err
568 }
569 ch.ID, err = res.LastInsertId()
570 }
[89]571 return err
572}
573
[416]574func (db *DB) DeleteChannel(id int64) error {
[89]575 db.lock.Lock()
576 defer db.lock.Unlock()
577
[416]578 _, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
[89]579 return err
580}
Note: See TracBrowser for help on using the repository browser.