source: code/trunk/db.go@ 279

Last change on this file since 279 was 267, checked in by contact, 5 years ago

Add network.channels, remove DB.GetChannel

Store the list of configured channels in the network data structure.
This removes the need for a database lookup and will be useful for
detached channels.

File size: 9.0 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a row of the User table. Password holds the hashed password; an
// empty string stands for a NULL column.
type User struct {
	Username, Password string
}
16
// SASL describes how to authenticate to an IRC network. Mechanism names the
// SASL mechanism; Plain carries the credentials used when it is "PLAIN".
type SASL struct {
	Mechanism string
	Plain     struct{ Username, Password string }
}
25
[77]26type Network struct {
[263]27 ID int64
28 Name string
29 Addr string
30 Nick string
31 Username string
32 Realname string
33 Pass string
34 ConnectCommands []string
35 SASL SASL
[77]36}
37
[149]38func (net *Network) GetName() string {
39 if net.Name != "" {
40 return net.Name
41 }
42 return net.Addr
43}
44
// Channel is a row of the Channel table: a channel saved for a network.
// An empty Key stands for a NULL key column.
type Channel struct {
	ID        int64
	Name, Key string
}
50
// schema is the complete SQL schema executed by upgrade on an empty database
// (PRAGMA user_version 0). It already contains every column added by the
// entries of migrations (e.g. Network.connect_commands), so a freshly created
// database is stamped directly with the latest version.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
83
// migrations lists the schema upgrades in order. Entry i upgrades a database
// at user_version i. The schema version this code expects is
// len(migrations).
var migrations = []string{
	// #0 is reserved: version 0 means "empty database" and is handled by
	// executing the full schema instead.
	"",
	`ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)`,
}
88
// DB is soju's persistent store. Every method serializes access to the
// underlying handle through lock: readers take the read lock, writers the
// write lock.
type DB struct {
	lock sync.RWMutex // held by every method while it uses db
	db   *sql.DB      // underlying database handle
}
93
94func OpenSQLDB(driver, source string) (*DB, error) {
[255]95 sqlDB, err := sql.Open(driver, source)
[77]96 if err != nil {
97 return nil, err
98 }
[255]99
100 db := &DB{db: sqlDB}
101 if err := db.upgrade(); err != nil {
102 return nil, err
103 }
104
105 return db, nil
[77]106}
107
108func (db *DB) Close() error {
109 db.lock.Lock()
110 defer db.lock.Unlock()
111 return db.Close()
112}
113
[255]114func (db *DB) upgrade() error {
115 db.lock.Lock()
116 defer db.lock.Unlock()
117
118 var version int
119 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
120 return fmt.Errorf("failed to query schema version: %v", err)
121 }
122
123 if version == len(migrations) {
124 return nil
125 } else if version > len(migrations) {
126 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
127 }
128
129 tx, err := db.db.Begin()
130 if err != nil {
131 return err
132 }
133 defer tx.Rollback()
134
135 if version == 0 {
136 if _, err := tx.Exec(schema); err != nil {
137 return fmt.Errorf("failed to initialize schema: %v", err)
138 }
139 } else {
140 for i := version; i < len(migrations); i++ {
141 if _, err := tx.Exec(migrations[i]); err != nil {
142 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
143 }
144 }
145 }
146
147 // For some reason prepared statements don't work here
148 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
149 if err != nil {
150 return fmt.Errorf("failed to bump schema version: %v", err)
151 }
152
153 return tx.Commit()
154}
155
// fromStringPtr converts a nullable column value to a string. A nil pointer,
// which is what a NULL scans into, becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
162
// toStringPtr is the inverse of fromStringPtr. It maps "" to nil, so that the
// value is stored as NULL, and any other string to a pointer to a copy of it.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
169
[77]170func (db *DB) ListUsers() ([]User, error) {
[81]171 db.lock.RLock()
172 defer db.lock.RUnlock()
[77]173
174 rows, err := db.db.Query("SELECT username, password FROM User")
175 if err != nil {
176 return nil, err
177 }
178 defer rows.Close()
179
180 var users []User
181 for rows.Next() {
182 var user User
183 var password *string
184 if err := rows.Scan(&user.Username, &password); err != nil {
185 return nil, err
186 }
[95]187 user.Password = fromStringPtr(password)
[77]188 users = append(users, user)
189 }
190 if err := rows.Err(); err != nil {
191 return nil, err
192 }
193
194 return users, nil
195}
196
[173]197func (db *DB) GetUser(username string) (*User, error) {
198 db.lock.RLock()
199 defer db.lock.RUnlock()
200
201 user := &User{Username: username}
202
203 var password *string
204 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
205 if err := row.Scan(&password); err != nil {
206 return nil, err
207 }
208 user.Password = fromStringPtr(password)
209 return user, nil
210}
211
[84]212func (db *DB) CreateUser(user *User) error {
213 db.lock.Lock()
214 defer db.lock.Unlock()
215
[95]216 password := toStringPtr(user.Password)
[89]217 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
218 return err
[84]219}
220
[251]221func (db *DB) UpdatePassword(user *User) error {
222 db.lock.Lock()
223 defer db.lock.Unlock()
224
225 password := toStringPtr(user.Password)
226 _, err := db.db.Exec(`UPDATE User
227 SET password = ?
228 WHERE username = ?`,
229 password, user.Username)
230 return err
231}
232
[77]233func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]234 db.lock.RLock()
235 defer db.lock.RUnlock()
[77]236
[118]237 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[263]238 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
[95]239 FROM Network
240 WHERE user = ?`,
241 username)
[77]242 if err != nil {
243 return nil, err
244 }
245 defer rows.Close()
246
247 var networks []Network
248 for rows.Next() {
249 var net Network
[263]250 var name, username, realname, pass, connectCommands *string
[95]251 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]252 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[263]253 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
[95]254 if err != nil {
[77]255 return nil, err
256 }
[118]257 net.Name = fromStringPtr(name)
[95]258 net.Username = fromStringPtr(username)
259 net.Realname = fromStringPtr(realname)
260 net.Pass = fromStringPtr(pass)
[263]261 if connectCommands != nil {
262 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
263 }
[95]264 net.SASL.Mechanism = fromStringPtr(saslMechanism)
265 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
266 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]267 networks = append(networks, net)
268 }
269 if err := rows.Err(); err != nil {
270 return nil, err
271 }
272
273 return networks, nil
274}
275
[90]276func (db *DB) StoreNetwork(username string, network *Network) error {
277 db.lock.Lock()
278 defer db.lock.Unlock()
279
[118]280 netName := toStringPtr(network.Name)
[95]281 netUsername := toStringPtr(network.Username)
282 realname := toStringPtr(network.Realname)
283 pass := toStringPtr(network.Pass)
[263]284 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]285
286 var saslMechanism, saslPlainUsername, saslPlainPassword *string
287 if network.SASL.Mechanism != "" {
288 saslMechanism = &network.SASL.Mechanism
289 switch network.SASL.Mechanism {
290 case "PLAIN":
291 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
292 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]293 default:
294 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]295 }
[90]296 }
297
298 var err error
299 if network.ID != 0 {
[93]300 _, err = db.db.Exec(`UPDATE Network
[263]301 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[95]302 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]303 WHERE id = ?`,
[263]304 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]305 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]306 } else {
307 var res sql.Result
[118]308 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]309 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[95]310 sasl_plain_password)
[263]311 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
312 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]313 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]314 if err != nil {
315 return err
316 }
317 network.ID, err = res.LastInsertId()
318 }
319 return err
320}
321
[202]322func (db *DB) DeleteNetwork(id int64) error {
323 db.lock.Lock()
324 defer db.lock.Unlock()
325
326 tx, err := db.db.Begin()
327 if err != nil {
328 return err
329 }
330 defer tx.Rollback()
331
332 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
333 if err != nil {
334 return err
335 }
336
337 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
338 if err != nil {
339 return err
340 }
341
342 return tx.Commit()
343}
344
[77]345func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]346 db.lock.RLock()
347 defer db.lock.RUnlock()
[77]348
[146]349 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]350 if err != nil {
351 return nil, err
352 }
353 defer rows.Close()
354
355 var channels []Channel
356 for rows.Next() {
357 var ch Channel
[146]358 var key *string
359 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]360 return nil, err
361 }
[146]362 ch.Key = fromStringPtr(key)
[77]363 channels = append(channels, ch)
364 }
365 if err := rows.Err(); err != nil {
366 return nil, err
367 }
368
369 return channels, nil
370}
[89]371
372func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
373 db.lock.Lock()
374 defer db.lock.Unlock()
375
[146]376 key := toStringPtr(ch.Key)
[149]377
378 var err error
379 if ch.ID != 0 {
380 _, err = db.db.Exec(`UPDATE Channel
381 SET network = ?, name = ?, key = ?
382 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
383 } else {
384 var res sql.Result
385 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
386 VALUES (?, ?, ?)`, networkID, ch.Name, key)
387 if err != nil {
388 return err
389 }
390 ch.ID, err = res.LastInsertId()
391 }
[89]392 return err
393}
394
395func (db *DB) DeleteChannel(networkID int64, name string) error {
396 db.lock.Lock()
397 defer db.lock.Unlock()
398
399 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
400 return err
401}
Note: See TracBrowser for help on using the repository browser.