source: code/trunk/db.go@ 320

Last change on this file since 320 was 307, checked in by fox.cpp, 5 years ago

Implement upstream SASL EXTERNAL support

Closes: https://todo.sr.ht/~emersion/soju/47

File size: 10.1 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju user account, as stored in the User table.
type User struct {
	Username string
	Password string // hashed
}
16
// SASL holds the SASL authentication settings for a network. Mechanism
// selects which of the nested credential sets is meaningful; the storage
// layer accepts "PLAIN", "EXTERNAL", or an empty Mechanism (SASL disabled).
type SASL struct {
	Mechanism string

	// Credentials used when Mechanism is "PLAIN".
	Plain struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}
33
// Network is an upstream network configured by a user, as stored in the
// Network table. Optional string fields are persisted as NULL when empty.
type Network struct {
	ID              int64 // row ID; zero until the network has been stored
	Name            string
	Addr            string
	Nick            string
	Username        string
	Realname        string
	Pass            string
	ConnectCommands []string // persisted as a single "\r\n"-joined column
	SASL            SASL
}
45
46func (net *Network) GetName() string {
47 if net.Name != "" {
48 return net.Name
49 }
50 return net.Addr
51}
52
// Channel is a channel saved for a network, as stored in the Channel table.
type Channel struct {
	ID       int64 // row ID; zero until the channel has been stored
	Name     string
	Key      string // persisted as NULL when empty
	Detached bool
}
59
// schema is the complete, current database schema. It is executed as-is
// only when initializing an empty database (PRAGMA user_version 0); existing
// databases are brought up to date by the statements in migrations instead,
// so every change made here must also be expressed as a new migration.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
95
// migrations lists, in order, the schema changes applied to existing
// databases. A database whose PRAGMA user_version is v has already had
// migrations[0:v] applied (index 0 stands for the initial schema), so the
// upgrade runs migrations[v:] and then records len(migrations) as the new
// version. Because of this, entries must only ever be appended, never
// edited or reordered.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}
103
// DB wraps soju's SQL database handle. lock guards every use of db:
// read-only methods take it shared (RLock) and mutating methods take it
// exclusively (Lock).
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
108
109func OpenSQLDB(driver, source string) (*DB, error) {
110 sqlDB, err := sql.Open(driver, source)
111 if err != nil {
112 return nil, err
113 }
114
115 db := &DB{db: sqlDB}
116 if err := db.upgrade(); err != nil {
117 return nil, err
118 }
119
120 return db, nil
121}
122
123func (db *DB) Close() error {
124 db.lock.Lock()
125 defer db.lock.Unlock()
126 return db.Close()
127}
128
129func (db *DB) upgrade() error {
130 db.lock.Lock()
131 defer db.lock.Unlock()
132
133 var version int
134 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
135 return fmt.Errorf("failed to query schema version: %v", err)
136 }
137
138 if version == len(migrations) {
139 return nil
140 } else if version > len(migrations) {
141 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
142 }
143
144 tx, err := db.db.Begin()
145 if err != nil {
146 return err
147 }
148 defer tx.Rollback()
149
150 if version == 0 {
151 if _, err := tx.Exec(schema); err != nil {
152 return fmt.Errorf("failed to initialize schema: %v", err)
153 }
154 } else {
155 for i := version; i < len(migrations); i++ {
156 if _, err := tx.Exec(migrations[i]); err != nil {
157 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
158 }
159 }
160 }
161
162 // For some reason prepared statements don't work here
163 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
164 if err != nil {
165 return fmt.Errorf("failed to bump schema version: %v", err)
166 }
167
168 return tx.Commit()
169}
170
// fromStringPtr converts a nullable column value back to a plain string,
// mapping NULL (a nil pointer) to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
177
// toStringPtr converts a string into a nullable column value: the empty
// string becomes nil (stored as NULL), anything else a pointer to a copy.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
184
185func (db *DB) ListUsers() ([]User, error) {
186 db.lock.RLock()
187 defer db.lock.RUnlock()
188
189 rows, err := db.db.Query("SELECT username, password FROM User")
190 if err != nil {
191 return nil, err
192 }
193 defer rows.Close()
194
195 var users []User
196 for rows.Next() {
197 var user User
198 var password *string
199 if err := rows.Scan(&user.Username, &password); err != nil {
200 return nil, err
201 }
202 user.Password = fromStringPtr(password)
203 users = append(users, user)
204 }
205 if err := rows.Err(); err != nil {
206 return nil, err
207 }
208
209 return users, nil
210}
211
212func (db *DB) GetUser(username string) (*User, error) {
213 db.lock.RLock()
214 defer db.lock.RUnlock()
215
216 user := &User{Username: username}
217
218 var password *string
219 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
220 if err := row.Scan(&password); err != nil {
221 return nil, err
222 }
223 user.Password = fromStringPtr(password)
224 return user, nil
225}
226
227func (db *DB) CreateUser(user *User) error {
228 db.lock.Lock()
229 defer db.lock.Unlock()
230
231 password := toStringPtr(user.Password)
232 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
233 return err
234}
235
236func (db *DB) UpdatePassword(user *User) error {
237 db.lock.Lock()
238 defer db.lock.Unlock()
239
240 password := toStringPtr(user.Password)
241 _, err := db.db.Exec(`UPDATE User
242 SET password = ?
243 WHERE username = ?`,
244 password, user.Username)
245 return err
246}
247
248func (db *DB) ListNetworks(username string) ([]Network, error) {
249 db.lock.RLock()
250 defer db.lock.RUnlock()
251
252 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
253 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
254 sasl_external_cert, sasl_external_key
255 FROM Network
256 WHERE user = ?`,
257 username)
258 if err != nil {
259 return nil, err
260 }
261 defer rows.Close()
262
263 var networks []Network
264 for rows.Next() {
265 var net Network
266 var name, username, realname, pass, connectCommands *string
267 var saslMechanism, saslPlainUsername, saslPlainPassword *string
268 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
269 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
270 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
271 if err != nil {
272 return nil, err
273 }
274 net.Name = fromStringPtr(name)
275 net.Username = fromStringPtr(username)
276 net.Realname = fromStringPtr(realname)
277 net.Pass = fromStringPtr(pass)
278 if connectCommands != nil {
279 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
280 }
281 net.SASL.Mechanism = fromStringPtr(saslMechanism)
282 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
283 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
284 networks = append(networks, net)
285 }
286 if err := rows.Err(); err != nil {
287 return nil, err
288 }
289
290 return networks, nil
291}
292
293func (db *DB) StoreNetwork(username string, network *Network) error {
294 db.lock.Lock()
295 defer db.lock.Unlock()
296
297 netName := toStringPtr(network.Name)
298 netUsername := toStringPtr(network.Username)
299 realname := toStringPtr(network.Realname)
300 pass := toStringPtr(network.Pass)
301 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
302
303 var saslMechanism, saslPlainUsername, saslPlainPassword *string
304 if network.SASL.Mechanism != "" {
305 saslMechanism = &network.SASL.Mechanism
306 switch network.SASL.Mechanism {
307 case "PLAIN":
308 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
309 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
310 network.SASL.External.CertBlob = nil
311 network.SASL.External.PrivKeyBlob = nil
312 case "EXTERNAL":
313 // keep saslPlain* nil
314 default:
315 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
316 }
317 }
318
319 var err error
320 if network.ID != 0 {
321 _, err = db.db.Exec(`UPDATE Network
322 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
323 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
324 sasl_external_cert = ?, sasl_external_key = ?
325 WHERE id = ?`,
326 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
327 saslMechanism, saslPlainUsername, saslPlainPassword,
328 network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
329 network.ID)
330 } else {
331 var res sql.Result
332 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
333 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
334 sasl_plain_password, sasl_external_cert, sasl_external_key)
335 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
336 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
337 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
338 network.SASL.External.PrivKeyBlob)
339 if err != nil {
340 return err
341 }
342 network.ID, err = res.LastInsertId()
343 }
344 return err
345}
346
347func (db *DB) DeleteNetwork(id int64) error {
348 db.lock.Lock()
349 defer db.lock.Unlock()
350
351 tx, err := db.db.Begin()
352 if err != nil {
353 return err
354 }
355 defer tx.Rollback()
356
357 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
358 if err != nil {
359 return err
360 }
361
362 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
363 if err != nil {
364 return err
365 }
366
367 return tx.Commit()
368}
369
370func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
371 db.lock.RLock()
372 defer db.lock.RUnlock()
373
374 rows, err := db.db.Query(`SELECT id, name, key, detached
375 FROM Channel
376 WHERE network = ?`, networkID)
377 if err != nil {
378 return nil, err
379 }
380 defer rows.Close()
381
382 var channels []Channel
383 for rows.Next() {
384 var ch Channel
385 var key *string
386 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
387 return nil, err
388 }
389 ch.Key = fromStringPtr(key)
390 channels = append(channels, ch)
391 }
392 if err := rows.Err(); err != nil {
393 return nil, err
394 }
395
396 return channels, nil
397}
398
399func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
400 db.lock.Lock()
401 defer db.lock.Unlock()
402
403 key := toStringPtr(ch.Key)
404
405 var err error
406 if ch.ID != 0 {
407 _, err = db.db.Exec(`UPDATE Channel
408 SET network = ?, name = ?, key = ?, detached = ?
409 WHERE id = ?`,
410 networkID, ch.Name, key, ch.Detached, ch.ID)
411 } else {
412 var res sql.Result
413 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
414 VALUES (?, ?, ?, ?)`,
415 networkID, ch.Name, key, ch.Detached)
416 if err != nil {
417 return err
418 }
419 ch.ID, err = res.LastInsertId()
420 }
421 return err
422}
423
424func (db *DB) DeleteChannel(networkID int64, name string) error {
425 db.lock.Lock()
426 defer db.lock.Unlock()
427
428 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
429 return err
430}
Note: See TracBrowser for help on using the repository browser.