source: code/trunk/db.go@ 298

Last change on this file since 298 was 284, checked in by contact, 5 years ago

Add support for detached channels

Channels can now be detached by leaving them with the reason "detach",
and re-attached by joining them again. Upon detaching, the channel is
no longer forwarded to downstream connections. Upon re-attaching, the
history buffer is sent.

File size: 9.3 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is an account record from the User table. Password holds the hashed
// password, never the plaintext.
type User struct {
	Username, Password string
}
16
// SASL holds the SASL settings of a network. An empty Mechanism means SASL is
// not used; Plain is only consulted when Mechanism is "PLAIN".
type SASL struct {
	Mechanism string

	// Plain carries the credentials for the PLAIN mechanism.
	Plain struct {
		Username, Password string
	}
}

// Network is an upstream network configured for a user, as persisted in the
// Network table.
type Network struct {
	ID              int64    // row ID; zero until the network is first stored
	Name            string   // optional display name, see GetName
	Addr            string   // server address
	Nick            string   // nickname to use
	Username        string   // optional; empty maps to NULL
	Realname        string   // optional; empty maps to NULL
	Pass            string   // optional; empty maps to NULL
	ConnectCommands []string // persisted joined with "\r\n"
	SASL            SASL
}

// GetName returns the name used to identify the network: its explicit Name
// when one is set, otherwise its address.
func (n *Network) GetName() string {
	name := n.Name
	if name == "" {
		name = n.Addr
	}
	return name
}
44
// Channel is a channel saved for a network, as persisted in the Channel table.
type Channel struct {
	ID       int64  // row ID; zero until the channel is first stored
	Name     string // channel name, unique per network
	Key      string // channel key; empty maps to NULL
	Detached bool   // stored in the "detached" column
}
51
// schema is the complete, current database schema. upgrade executes it as-is
// on a database whose PRAGMA user_version is 0 and then sets user_version to
// len(migrations). It must therefore always describe the same tables as
// applying every entry of migrations in order; when adding a migration,
// mirror the change here.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`

// migrations lists the schema changes for existing databases. Entry i moves a
// database from user_version i to i+1. upgrade applies every pending entry in
// one transaction and then records len(migrations) as the new user_version.
// Only append to this list; editing or reordering existing entries would
// desynchronize databases that already applied them.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
}
91
// DB is soju's persistent store. The methods in this file take lock before
// using db: the read lock for queries and the write lock for anything that
// modifies the database or the handle itself.
type DB struct {
	lock sync.RWMutex // guards all access to db
	db   *sql.DB      // underlying database/sql handle
}
96
97func OpenSQLDB(driver, source string) (*DB, error) {
98 sqlDB, err := sql.Open(driver, source)
99 if err != nil {
100 return nil, err
101 }
102
103 db := &DB{db: sqlDB}
104 if err := db.upgrade(); err != nil {
105 return nil, err
106 }
107
108 return db, nil
109}
110
111func (db *DB) Close() error {
112 db.lock.Lock()
113 defer db.lock.Unlock()
114 return db.Close()
115}
116
117func (db *DB) upgrade() error {
118 db.lock.Lock()
119 defer db.lock.Unlock()
120
121 var version int
122 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
123 return fmt.Errorf("failed to query schema version: %v", err)
124 }
125
126 if version == len(migrations) {
127 return nil
128 } else if version > len(migrations) {
129 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
130 }
131
132 tx, err := db.db.Begin()
133 if err != nil {
134 return err
135 }
136 defer tx.Rollback()
137
138 if version == 0 {
139 if _, err := tx.Exec(schema); err != nil {
140 return fmt.Errorf("failed to initialize schema: %v", err)
141 }
142 } else {
143 for i := version; i < len(migrations); i++ {
144 if _, err := tx.Exec(migrations[i]); err != nil {
145 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
146 }
147 }
148 }
149
150 // For some reason prepared statements don't work here
151 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
152 if err != nil {
153 return fmt.Errorf("failed to bump schema version: %v", err)
154 }
155
156 return tx.Commit()
157}
158
// fromStringPtr dereferences ptr, mapping a nil pointer (a NULL column
// scanned into *string) to the empty string.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
165
// toStringPtr is the inverse of fromStringPtr. It returns nil for the empty
// string, which database/sql binds as NULL, and otherwise returns a pointer to
// a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
172
173func (db *DB) ListUsers() ([]User, error) {
174 db.lock.RLock()
175 defer db.lock.RUnlock()
176
177 rows, err := db.db.Query("SELECT username, password FROM User")
178 if err != nil {
179 return nil, err
180 }
181 defer rows.Close()
182
183 var users []User
184 for rows.Next() {
185 var user User
186 var password *string
187 if err := rows.Scan(&user.Username, &password); err != nil {
188 return nil, err
189 }
190 user.Password = fromStringPtr(password)
191 users = append(users, user)
192 }
193 if err := rows.Err(); err != nil {
194 return nil, err
195 }
196
197 return users, nil
198}
199
200func (db *DB) GetUser(username string) (*User, error) {
201 db.lock.RLock()
202 defer db.lock.RUnlock()
203
204 user := &User{Username: username}
205
206 var password *string
207 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
208 if err := row.Scan(&password); err != nil {
209 return nil, err
210 }
211 user.Password = fromStringPtr(password)
212 return user, nil
213}
214
215func (db *DB) CreateUser(user *User) error {
216 db.lock.Lock()
217 defer db.lock.Unlock()
218
219 password := toStringPtr(user.Password)
220 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
221 return err
222}
223
224func (db *DB) UpdatePassword(user *User) error {
225 db.lock.Lock()
226 defer db.lock.Unlock()
227
228 password := toStringPtr(user.Password)
229 _, err := db.db.Exec(`UPDATE User
230 SET password = ?
231 WHERE username = ?`,
232 password, user.Username)
233 return err
234}
235
236func (db *DB) ListNetworks(username string) ([]Network, error) {
237 db.lock.RLock()
238 defer db.lock.RUnlock()
239
240 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
241 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
242 FROM Network
243 WHERE user = ?`,
244 username)
245 if err != nil {
246 return nil, err
247 }
248 defer rows.Close()
249
250 var networks []Network
251 for rows.Next() {
252 var net Network
253 var name, username, realname, pass, connectCommands *string
254 var saslMechanism, saslPlainUsername, saslPlainPassword *string
255 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
256 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
257 if err != nil {
258 return nil, err
259 }
260 net.Name = fromStringPtr(name)
261 net.Username = fromStringPtr(username)
262 net.Realname = fromStringPtr(realname)
263 net.Pass = fromStringPtr(pass)
264 if connectCommands != nil {
265 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
266 }
267 net.SASL.Mechanism = fromStringPtr(saslMechanism)
268 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
269 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
270 networks = append(networks, net)
271 }
272 if err := rows.Err(); err != nil {
273 return nil, err
274 }
275
276 return networks, nil
277}
278
279func (db *DB) StoreNetwork(username string, network *Network) error {
280 db.lock.Lock()
281 defer db.lock.Unlock()
282
283 netName := toStringPtr(network.Name)
284 netUsername := toStringPtr(network.Username)
285 realname := toStringPtr(network.Realname)
286 pass := toStringPtr(network.Pass)
287 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
288
289 var saslMechanism, saslPlainUsername, saslPlainPassword *string
290 if network.SASL.Mechanism != "" {
291 saslMechanism = &network.SASL.Mechanism
292 switch network.SASL.Mechanism {
293 case "PLAIN":
294 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
295 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
296 default:
297 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
298 }
299 }
300
301 var err error
302 if network.ID != 0 {
303 _, err = db.db.Exec(`UPDATE Network
304 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
305 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
306 WHERE id = ?`,
307 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
308 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
309 } else {
310 var res sql.Result
311 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
312 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
313 sasl_plain_password)
314 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
315 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
316 saslMechanism, saslPlainUsername, saslPlainPassword)
317 if err != nil {
318 return err
319 }
320 network.ID, err = res.LastInsertId()
321 }
322 return err
323}
324
325func (db *DB) DeleteNetwork(id int64) error {
326 db.lock.Lock()
327 defer db.lock.Unlock()
328
329 tx, err := db.db.Begin()
330 if err != nil {
331 return err
332 }
333 defer tx.Rollback()
334
335 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
336 if err != nil {
337 return err
338 }
339
340 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
341 if err != nil {
342 return err
343 }
344
345 return tx.Commit()
346}
347
348func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
349 db.lock.RLock()
350 defer db.lock.RUnlock()
351
352 rows, err := db.db.Query(`SELECT id, name, key, detached
353 FROM Channel
354 WHERE network = ?`, networkID)
355 if err != nil {
356 return nil, err
357 }
358 defer rows.Close()
359
360 var channels []Channel
361 for rows.Next() {
362 var ch Channel
363 var key *string
364 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
365 return nil, err
366 }
367 ch.Key = fromStringPtr(key)
368 channels = append(channels, ch)
369 }
370 if err := rows.Err(); err != nil {
371 return nil, err
372 }
373
374 return channels, nil
375}
376
377func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
378 db.lock.Lock()
379 defer db.lock.Unlock()
380
381 key := toStringPtr(ch.Key)
382
383 var err error
384 if ch.ID != 0 {
385 _, err = db.db.Exec(`UPDATE Channel
386 SET network = ?, name = ?, key = ?, detached = ?
387 WHERE id = ?`,
388 networkID, ch.Name, key, ch.Detached, ch.ID)
389 } else {
390 var res sql.Result
391 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
392 VALUES (?, ?, ?, ?)`,
393 networkID, ch.Name, key, ch.Detached)
394 if err != nil {
395 return err
396 }
397 ch.ID, err = res.LastInsertId()
398 }
399 return err
400}
401
402func (db *DB) DeleteChannel(networkID int64, name string) error {
403 db.lock.Lock()
404 defer db.lock.Unlock()
405
406 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
407 return err
408}
Note: See TracBrowser for help on using the repository browser.