source: code/trunk/db.go@ 319

Last change on this file since 319 was 307, checked in by fox.cpp, 5 years ago

Implement upstream SASL EXTERNAL support

Closes: https://todo.sr.ht/~emersion/soju/47

File size: 10.1 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
12type User struct {
13 Username string
[85]14 Password string // hashed
[77]15}
16
// SASL describes how soju authenticates to an upstream IRC server.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork accepts "PLAIN",
	// "EXTERNAL", or "" (no SASL) and rejects anything else.
	Mechanism string

	// Plain holds the credentials used by the PLAIN mechanism.
	Plain struct {
		Username string
		Password string
	}

	// External holds the TLS client certificate authentication material
	// used by the EXTERNAL mechanism.
	External struct {
		// CertBlob is an X.509 certificate in DER form.
		CertBlob []byte
		// PrivKeyBlob is a PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

// Network is an upstream IRC network configured for a user.
type Network struct {
	// ID is the database row id. Zero means the network has not been
	// stored yet; StoreNetwork inserts it and fills in the new id.
	ID int64
	// Name is an optional display name; see GetName.
	Name     string
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	// ConnectCommands are persisted as a single CRLF-joined column.
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the network's display name. When no explicit Name is
// set it falls back to the network address.
func (n *Network) GetName() string {
	if n.Name == "" {
		return n.Addr
	}
	return n.Name
}
52
// Channel is an IRC channel saved for a network.
type Channel struct {
	// ID is the database row id; zero means not stored yet.
	ID   int64
	Name string
	// Key is the channel key; an empty value is stored as NULL.
	Key string
	// Detached is persisted in an INTEGER column (default 0).
	Detached bool
}
59
// schema is the complete, current database layout.
//
// upgrade executes it in one go on databases whose user_version is 0, and
// never runs migrations for them. It must therefore already contain every
// change that the entries of migrations apply to older databases.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
95
// migrations holds the SQL statements that upgrade an existing database.
//
// The schema version stored in PRAGMA user_version counts the entries that
// have already been applied, and upgrade runs migrations[version:] in
// order. Entries are addressed by index, so new ones may only be appended.
// Each change must also be mirrored in schema for freshly created
// databases.
var migrations = []string{
	// #0 is reserved: a version-0 database is initialized from schema.
	"",
	// #1
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	// #2
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	// #3 and #4: TLS client certificate for SASL EXTERNAL.
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}
103
// DB is soju's persistent store, a thin wrapper around a *sql.DB.
type DB struct {
	// lock guards every access to db: query methods take the read lock,
	// methods that modify data (or the schema) take the write lock.
	lock sync.RWMutex
	db   *sql.DB
}
108
109func OpenSQLDB(driver, source string) (*DB, error) {
[255]110 sqlDB, err := sql.Open(driver, source)
[77]111 if err != nil {
112 return nil, err
113 }
[255]114
115 db := &DB{db: sqlDB}
116 if err := db.upgrade(); err != nil {
117 return nil, err
118 }
119
120 return db, nil
[77]121}
122
123func (db *DB) Close() error {
124 db.lock.Lock()
125 defer db.lock.Unlock()
126 return db.Close()
127}
128
[255]129func (db *DB) upgrade() error {
130 db.lock.Lock()
131 defer db.lock.Unlock()
132
133 var version int
134 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
135 return fmt.Errorf("failed to query schema version: %v", err)
136 }
137
138 if version == len(migrations) {
139 return nil
140 } else if version > len(migrations) {
141 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
142 }
143
144 tx, err := db.db.Begin()
145 if err != nil {
146 return err
147 }
148 defer tx.Rollback()
149
150 if version == 0 {
151 if _, err := tx.Exec(schema); err != nil {
152 return fmt.Errorf("failed to initialize schema: %v", err)
153 }
154 } else {
155 for i := version; i < len(migrations); i++ {
156 if _, err := tx.Exec(migrations[i]); err != nil {
157 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
158 }
159 }
160 }
161
162 // For some reason prepared statements don't work here
163 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
164 if err != nil {
165 return fmt.Errorf("failed to bump schema version: %v", err)
166 }
167
168 return tx.Commit()
169}
170
// fromStringPtr turns a nullable string (as scanned from a column that may
// be NULL) into a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
177
// toStringPtr is the reverse of fromStringPtr. It maps "" to nil, so the
// value is written as NULL, and any other string to a pointer to a copy
// of it.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
184
[77]185func (db *DB) ListUsers() ([]User, error) {
[81]186 db.lock.RLock()
187 defer db.lock.RUnlock()
[77]188
189 rows, err := db.db.Query("SELECT username, password FROM User")
190 if err != nil {
191 return nil, err
192 }
193 defer rows.Close()
194
195 var users []User
196 for rows.Next() {
197 var user User
198 var password *string
199 if err := rows.Scan(&user.Username, &password); err != nil {
200 return nil, err
201 }
[95]202 user.Password = fromStringPtr(password)
[77]203 users = append(users, user)
204 }
205 if err := rows.Err(); err != nil {
206 return nil, err
207 }
208
209 return users, nil
210}
211
[173]212func (db *DB) GetUser(username string) (*User, error) {
213 db.lock.RLock()
214 defer db.lock.RUnlock()
215
216 user := &User{Username: username}
217
218 var password *string
219 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
220 if err := row.Scan(&password); err != nil {
221 return nil, err
222 }
223 user.Password = fromStringPtr(password)
224 return user, nil
225}
226
[84]227func (db *DB) CreateUser(user *User) error {
228 db.lock.Lock()
229 defer db.lock.Unlock()
230
[95]231 password := toStringPtr(user.Password)
[89]232 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
233 return err
[84]234}
235
[251]236func (db *DB) UpdatePassword(user *User) error {
237 db.lock.Lock()
238 defer db.lock.Unlock()
239
240 password := toStringPtr(user.Password)
241 _, err := db.db.Exec(`UPDATE User
242 SET password = ?
243 WHERE username = ?`,
244 password, user.Username)
245 return err
246}
247
[77]248func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]249 db.lock.RLock()
250 defer db.lock.RUnlock()
[77]251
[118]252 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[307]253 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
254 sasl_external_cert, sasl_external_key
[95]255 FROM Network
256 WHERE user = ?`,
257 username)
[77]258 if err != nil {
259 return nil, err
260 }
261 defer rows.Close()
262
263 var networks []Network
264 for rows.Next() {
265 var net Network
[263]266 var name, username, realname, pass, connectCommands *string
[95]267 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]268 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[307]269 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
270 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
[95]271 if err != nil {
[77]272 return nil, err
273 }
[118]274 net.Name = fromStringPtr(name)
[95]275 net.Username = fromStringPtr(username)
276 net.Realname = fromStringPtr(realname)
277 net.Pass = fromStringPtr(pass)
[263]278 if connectCommands != nil {
279 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
280 }
[95]281 net.SASL.Mechanism = fromStringPtr(saslMechanism)
282 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
283 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]284 networks = append(networks, net)
285 }
286 if err := rows.Err(); err != nil {
287 return nil, err
288 }
289
290 return networks, nil
291}
292
[90]293func (db *DB) StoreNetwork(username string, network *Network) error {
294 db.lock.Lock()
295 defer db.lock.Unlock()
296
[118]297 netName := toStringPtr(network.Name)
[95]298 netUsername := toStringPtr(network.Username)
299 realname := toStringPtr(network.Realname)
300 pass := toStringPtr(network.Pass)
[263]301 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]302
303 var saslMechanism, saslPlainUsername, saslPlainPassword *string
304 if network.SASL.Mechanism != "" {
305 saslMechanism = &network.SASL.Mechanism
306 switch network.SASL.Mechanism {
307 case "PLAIN":
308 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
309 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[307]310 network.SASL.External.CertBlob = nil
311 network.SASL.External.PrivKeyBlob = nil
312 case "EXTERNAL":
313 // keep saslPlain* nil
[148]314 default:
315 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]316 }
[90]317 }
318
319 var err error
320 if network.ID != 0 {
[93]321 _, err = db.db.Exec(`UPDATE Network
[263]322 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[307]323 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
324 sasl_external_cert = ?, sasl_external_key = ?
[93]325 WHERE id = ?`,
[263]326 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]327 saslMechanism, saslPlainUsername, saslPlainPassword,
328 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
329 network.ID)
[90]330 } else {
331 var res sql.Result
[118]332 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]333 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[307]334 sasl_plain_password, sasl_external_cert, sasl_external_key)
335 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
[263]336 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[307]337 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
338 network.SASL.External.PrivKeyBlob)
[90]339 if err != nil {
340 return err
341 }
342 network.ID, err = res.LastInsertId()
343 }
344 return err
345}
346
[202]347func (db *DB) DeleteNetwork(id int64) error {
348 db.lock.Lock()
349 defer db.lock.Unlock()
350
351 tx, err := db.db.Begin()
352 if err != nil {
353 return err
354 }
355 defer tx.Rollback()
356
357 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
358 if err != nil {
359 return err
360 }
361
362 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
363 if err != nil {
364 return err
365 }
366
367 return tx.Commit()
368}
369
[77]370func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]371 db.lock.RLock()
372 defer db.lock.RUnlock()
[77]373
[284]374 rows, err := db.db.Query(`SELECT id, name, key, detached
375 FROM Channel
376 WHERE network = ?`, networkID)
[77]377 if err != nil {
378 return nil, err
379 }
380 defer rows.Close()
381
382 var channels []Channel
383 for rows.Next() {
384 var ch Channel
[146]385 var key *string
[284]386 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]387 return nil, err
388 }
[146]389 ch.Key = fromStringPtr(key)
[77]390 channels = append(channels, ch)
391 }
392 if err := rows.Err(); err != nil {
393 return nil, err
394 }
395
396 return channels, nil
397}
[89]398
399func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
400 db.lock.Lock()
401 defer db.lock.Unlock()
402
[146]403 key := toStringPtr(ch.Key)
[149]404
405 var err error
406 if ch.ID != 0 {
407 _, err = db.db.Exec(`UPDATE Channel
[284]408 SET network = ?, name = ?, key = ?, detached = ?
409 WHERE id = ?`,
410 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]411 } else {
412 var res sql.Result
[284]413 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
414 VALUES (?, ?, ?, ?)`,
415 networkID, ch.Name, key, ch.Detached)
[149]416 if err != nil {
417 return err
418 }
419 ch.ID, err = res.LastInsertId()
420 }
[89]421 return err
422}
423
424func (db *DB) DeleteChannel(networkID int64, name string) error {
425 db.lock.Lock()
426 defer db.lock.Unlock()
427
428 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
429 return err
430}
Note: See TracBrowser for help on using the repository browser.