source: code/trunk/db.go@ 279

Last change on this file since 279 was 267, checked in by contact, 5 years ago

Add network.channels, remove DB.GetChannel

Store the list of configured channels in the network data structure.
This removes the need for a database lookup and will be useful for
detached channels.

File size: 9.0 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju user account, as stored in the User table.
type User struct {
	// Username identifies the account (primary key of the User table).
	Username string

	// Password is the hashed password.
	Password string
}
16
// SASL holds the SASL authentication settings of a network. An empty
// Mechanism means no SASL mechanism is configured.
type SASL struct {
	Mechanism string

	// Plain holds the credentials used by the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
25
26type Network struct {
27 ID int64
28 Name string
29 Addr string
30 Nick string
31 Username string
32 Realname string
33 Pass string
34 ConnectCommands []string
35 SASL SASL
36}
37
38func (net *Network) GetName() string {
39 if net.Name != "" {
40 return net.Name
41 }
42 return net.Addr
43}
44
// Channel is a channel saved for a network, as stored in the Channel table.
// An empty Key means no key is stored (NULL in the database).
type Channel struct {
	ID int64

	Name, Key string
}
50
// schema creates the tables of a fresh database (PRAGMA user_version 0).
//
// It must already include every change made by the entries in migrations:
// upgrade executes either this schema (for a version-0 database) or the
// pending migrations, never both. For example, the connect_commands column
// added by migration #1 is also declared here.
//
// NOTE(review): User.password is NOT NULL, but CreateUser and UpdatePassword
// pass the password through toStringPtr, which turns "" into NULL, so storing
// an empty password is rejected by the database — confirm this is intended.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
83
// migrations lists schema upgrade statements, indexed by the schema version
// they upgrade from. A fully upgraded database has PRAGMA user_version equal
// to len(migrations).
//
// Entry 0 is never executed, because a version-0 database is initialized
// from schema instead. Append new migrations at the end, and keep schema in
// sync with them.
var migrations = []string{
	0: "", // reserved for schema initialization
	1: "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
}
88
// DB wraps the SQL database handle used by soju.
type DB struct {
	// lock is taken exclusively by methods that write and shared by
	// methods that only read.
	lock sync.RWMutex

	db *sql.DB
}
93
94func OpenSQLDB(driver, source string) (*DB, error) {
95 sqlDB, err := sql.Open(driver, source)
96 if err != nil {
97 return nil, err
98 }
99
100 db := &DB{db: sqlDB}
101 if err := db.upgrade(); err != nil {
102 return nil, err
103 }
104
105 return db, nil
106}
107
108func (db *DB) Close() error {
109 db.lock.Lock()
110 defer db.lock.Unlock()
111 return db.Close()
112}
113
114func (db *DB) upgrade() error {
115 db.lock.Lock()
116 defer db.lock.Unlock()
117
118 var version int
119 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
120 return fmt.Errorf("failed to query schema version: %v", err)
121 }
122
123 if version == len(migrations) {
124 return nil
125 } else if version > len(migrations) {
126 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
127 }
128
129 tx, err := db.db.Begin()
130 if err != nil {
131 return err
132 }
133 defer tx.Rollback()
134
135 if version == 0 {
136 if _, err := tx.Exec(schema); err != nil {
137 return fmt.Errorf("failed to initialize schema: %v", err)
138 }
139 } else {
140 for i := version; i < len(migrations); i++ {
141 if _, err := tx.Exec(migrations[i]); err != nil {
142 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
143 }
144 }
145 }
146
147 // For some reason prepared statements don't work here
148 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
149 if err != nil {
150 return fmt.Errorf("failed to bump schema version: %v", err)
151 }
152
153 return tx.Commit()
154}
155
// fromStringPtr converts a nullable string, as scanned from a nullable SQL
// column, into a plain string. A nil pointer becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
162
// toStringPtr is the inverse of fromStringPtr. It returns nil (stored as SQL
// NULL) for the empty string, and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
169
170func (db *DB) ListUsers() ([]User, error) {
171 db.lock.RLock()
172 defer db.lock.RUnlock()
173
174 rows, err := db.db.Query("SELECT username, password FROM User")
175 if err != nil {
176 return nil, err
177 }
178 defer rows.Close()
179
180 var users []User
181 for rows.Next() {
182 var user User
183 var password *string
184 if err := rows.Scan(&user.Username, &password); err != nil {
185 return nil, err
186 }
187 user.Password = fromStringPtr(password)
188 users = append(users, user)
189 }
190 if err := rows.Err(); err != nil {
191 return nil, err
192 }
193
194 return users, nil
195}
196
197func (db *DB) GetUser(username string) (*User, error) {
198 db.lock.RLock()
199 defer db.lock.RUnlock()
200
201 user := &User{Username: username}
202
203 var password *string
204 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
205 if err := row.Scan(&password); err != nil {
206 return nil, err
207 }
208 user.Password = fromStringPtr(password)
209 return user, nil
210}
211
212func (db *DB) CreateUser(user *User) error {
213 db.lock.Lock()
214 defer db.lock.Unlock()
215
216 password := toStringPtr(user.Password)
217 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
218 return err
219}
220
221func (db *DB) UpdatePassword(user *User) error {
222 db.lock.Lock()
223 defer db.lock.Unlock()
224
225 password := toStringPtr(user.Password)
226 _, err := db.db.Exec(`UPDATE User
227 SET password = ?
228 WHERE username = ?`,
229 password, user.Username)
230 return err
231}
232
233func (db *DB) ListNetworks(username string) ([]Network, error) {
234 db.lock.RLock()
235 defer db.lock.RUnlock()
236
237 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
238 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
239 FROM Network
240 WHERE user = ?`,
241 username)
242 if err != nil {
243 return nil, err
244 }
245 defer rows.Close()
246
247 var networks []Network
248 for rows.Next() {
249 var net Network
250 var name, username, realname, pass, connectCommands *string
251 var saslMechanism, saslPlainUsername, saslPlainPassword *string
252 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
253 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
254 if err != nil {
255 return nil, err
256 }
257 net.Name = fromStringPtr(name)
258 net.Username = fromStringPtr(username)
259 net.Realname = fromStringPtr(realname)
260 net.Pass = fromStringPtr(pass)
261 if connectCommands != nil {
262 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
263 }
264 net.SASL.Mechanism = fromStringPtr(saslMechanism)
265 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
266 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
267 networks = append(networks, net)
268 }
269 if err := rows.Err(); err != nil {
270 return nil, err
271 }
272
273 return networks, nil
274}
275
276func (db *DB) StoreNetwork(username string, network *Network) error {
277 db.lock.Lock()
278 defer db.lock.Unlock()
279
280 netName := toStringPtr(network.Name)
281 netUsername := toStringPtr(network.Username)
282 realname := toStringPtr(network.Realname)
283 pass := toStringPtr(network.Pass)
284 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
285
286 var saslMechanism, saslPlainUsername, saslPlainPassword *string
287 if network.SASL.Mechanism != "" {
288 saslMechanism = &network.SASL.Mechanism
289 switch network.SASL.Mechanism {
290 case "PLAIN":
291 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
292 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
293 default:
294 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
295 }
296 }
297
298 var err error
299 if network.ID != 0 {
300 _, err = db.db.Exec(`UPDATE Network
301 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
302 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
303 WHERE id = ?`,
304 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
305 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
306 } else {
307 var res sql.Result
308 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
309 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
310 sasl_plain_password)
311 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
312 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
313 saslMechanism, saslPlainUsername, saslPlainPassword)
314 if err != nil {
315 return err
316 }
317 network.ID, err = res.LastInsertId()
318 }
319 return err
320}
321
322func (db *DB) DeleteNetwork(id int64) error {
323 db.lock.Lock()
324 defer db.lock.Unlock()
325
326 tx, err := db.db.Begin()
327 if err != nil {
328 return err
329 }
330 defer tx.Rollback()
331
332 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
333 if err != nil {
334 return err
335 }
336
337 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
338 if err != nil {
339 return err
340 }
341
342 return tx.Commit()
343}
344
345func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
346 db.lock.RLock()
347 defer db.lock.RUnlock()
348
349 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
350 if err != nil {
351 return nil, err
352 }
353 defer rows.Close()
354
355 var channels []Channel
356 for rows.Next() {
357 var ch Channel
358 var key *string
359 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
360 return nil, err
361 }
362 ch.Key = fromStringPtr(key)
363 channels = append(channels, ch)
364 }
365 if err := rows.Err(); err != nil {
366 return nil, err
367 }
368
369 return channels, nil
370}
371
372func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
373 db.lock.Lock()
374 defer db.lock.Unlock()
375
376 key := toStringPtr(ch.Key)
377
378 var err error
379 if ch.ID != 0 {
380 _, err = db.db.Exec(`UPDATE Channel
381 SET network = ?, name = ?, key = ?
382 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
383 } else {
384 var res sql.Result
385 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
386 VALUES (?, ?, ?)`, networkID, ch.Name, key)
387 if err != nil {
388 return err
389 }
390 ch.ID, err = res.LastInsertId()
391 }
392 return err
393}
394
395func (db *DB) DeleteChannel(networkID int64, name string) error {
396 db.lock.Lock()
397 defer db.lock.Unlock()
398
399 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
400 return err
401}
Note: See TracBrowser for help on using the repository browser.