source: code/trunk/db.go@ 290

Last change on this file since 290 was 284, checked in by contact, 5 years ago

Add support for detached channels

Channels can now be detached by leaving them with the reason "detach",
and re-attached by joining them again. Upon detaching, the channel is
no longer forwarded to downstream connections. Upon re-attaching, the
history buffer is sent.

File size: 9.3 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is an account as stored in the User table. Password holds the
// hashed password, never the plaintext.
type User struct {
	Username, Password string
}
16
// SASL holds a network's SASL authentication settings. An empty Mechanism
// means SASL is not used; Plain carries the credentials used with the
// "PLAIN" mechanism.
type SASL struct {
	Mechanism string

	Plain struct {
		Username, Password string
	}
}
25
[77]26type Network struct {
[263]27 ID int64
28 Name string
29 Addr string
30 Nick string
31 Username string
32 Realname string
33 Pass string
34 ConnectCommands []string
35 SASL SASL
[77]36}
37
[149]38func (net *Network) GetName() string {
39 if net.Name != "" {
40 return net.Name
41 }
42 return net.Addr
43}
44
// Channel is a channel saved for a network, mirroring one row of the
// Channel table. An empty Key is persisted as NULL; Detached records whether
// the channel is currently detached.
type Channel struct {
	ID        int64
	Name, Key string
	Detached  bool
}
51
// schema is the complete, current database schema. upgrade executes it as a
// single script when the database is brand new (user_version 0) and then
// records len(migrations) as the version, skipping every migration. So this
// script must always already include the effect of every entry in
// migrations.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached INTEGER NOT NULL DEFAULT 0,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`
85
// migrations holds the incremental schema changes for existing databases.
// The database's PRAGMA user_version stores how far along this list it is:
// upgrade executes migrations[version:] in order and then sets user_version
// to len(migrations). Because an entry's index is its version number, new
// entries must only be appended, and schema must be updated to match.
var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
}
91
// DB is the persistence layer on top of a *sql.DB. Its methods serialize
// access through lock: read-only queries take it shared, and anything that
// writes takes it exclusively.
type DB struct {
	lock sync.RWMutex // guards all use of db
	db   *sql.DB      // underlying connection pool
}
96
97func OpenSQLDB(driver, source string) (*DB, error) {
[255]98 sqlDB, err := sql.Open(driver, source)
[77]99 if err != nil {
100 return nil, err
101 }
[255]102
103 db := &DB{db: sqlDB}
104 if err := db.upgrade(); err != nil {
105 return nil, err
106 }
107
108 return db, nil
[77]109}
110
111func (db *DB) Close() error {
112 db.lock.Lock()
113 defer db.lock.Unlock()
114 return db.Close()
115}
116
[255]117func (db *DB) upgrade() error {
118 db.lock.Lock()
119 defer db.lock.Unlock()
120
121 var version int
122 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
123 return fmt.Errorf("failed to query schema version: %v", err)
124 }
125
126 if version == len(migrations) {
127 return nil
128 } else if version > len(migrations) {
129 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
130 }
131
132 tx, err := db.db.Begin()
133 if err != nil {
134 return err
135 }
136 defer tx.Rollback()
137
138 if version == 0 {
139 if _, err := tx.Exec(schema); err != nil {
140 return fmt.Errorf("failed to initialize schema: %v", err)
141 }
142 } else {
143 for i := version; i < len(migrations); i++ {
144 if _, err := tx.Exec(migrations[i]); err != nil {
145 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
146 }
147 }
148 }
149
150 // For some reason prepared statements don't work here
151 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
152 if err != nil {
153 return fmt.Errorf("failed to bump schema version: %v", err)
154 }
155
156 return tx.Commit()
157}
158
// fromStringPtr turns a nullable string (as scanned from a nullable column)
// into a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
165
// toStringPtr turns a plain string into a nullable one for use as a query
// argument: "" becomes nil (so it can be written as NULL), and anything else
// becomes a pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
172
[77]173func (db *DB) ListUsers() ([]User, error) {
[81]174 db.lock.RLock()
175 defer db.lock.RUnlock()
[77]176
177 rows, err := db.db.Query("SELECT username, password FROM User")
178 if err != nil {
179 return nil, err
180 }
181 defer rows.Close()
182
183 var users []User
184 for rows.Next() {
185 var user User
186 var password *string
187 if err := rows.Scan(&user.Username, &password); err != nil {
188 return nil, err
189 }
[95]190 user.Password = fromStringPtr(password)
[77]191 users = append(users, user)
192 }
193 if err := rows.Err(); err != nil {
194 return nil, err
195 }
196
197 return users, nil
198}
199
[173]200func (db *DB) GetUser(username string) (*User, error) {
201 db.lock.RLock()
202 defer db.lock.RUnlock()
203
204 user := &User{Username: username}
205
206 var password *string
207 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
208 if err := row.Scan(&password); err != nil {
209 return nil, err
210 }
211 user.Password = fromStringPtr(password)
212 return user, nil
213}
214
[84]215func (db *DB) CreateUser(user *User) error {
216 db.lock.Lock()
217 defer db.lock.Unlock()
218
[95]219 password := toStringPtr(user.Password)
[89]220 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
221 return err
[84]222}
223
[251]224func (db *DB) UpdatePassword(user *User) error {
225 db.lock.Lock()
226 defer db.lock.Unlock()
227
228 password := toStringPtr(user.Password)
229 _, err := db.db.Exec(`UPDATE User
230 SET password = ?
231 WHERE username = ?`,
232 password, user.Username)
233 return err
234}
235
[77]236func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]237 db.lock.RLock()
238 defer db.lock.RUnlock()
[77]239
[118]240 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[263]241 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
[95]242 FROM Network
243 WHERE user = ?`,
244 username)
[77]245 if err != nil {
246 return nil, err
247 }
248 defer rows.Close()
249
250 var networks []Network
251 for rows.Next() {
252 var net Network
[263]253 var name, username, realname, pass, connectCommands *string
[95]254 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]255 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[263]256 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
[95]257 if err != nil {
[77]258 return nil, err
259 }
[118]260 net.Name = fromStringPtr(name)
[95]261 net.Username = fromStringPtr(username)
262 net.Realname = fromStringPtr(realname)
263 net.Pass = fromStringPtr(pass)
[263]264 if connectCommands != nil {
265 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
266 }
[95]267 net.SASL.Mechanism = fromStringPtr(saslMechanism)
268 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
269 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]270 networks = append(networks, net)
271 }
272 if err := rows.Err(); err != nil {
273 return nil, err
274 }
275
276 return networks, nil
277}
278
[90]279func (db *DB) StoreNetwork(username string, network *Network) error {
280 db.lock.Lock()
281 defer db.lock.Unlock()
282
[118]283 netName := toStringPtr(network.Name)
[95]284 netUsername := toStringPtr(network.Username)
285 realname := toStringPtr(network.Realname)
286 pass := toStringPtr(network.Pass)
[263]287 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]288
289 var saslMechanism, saslPlainUsername, saslPlainPassword *string
290 if network.SASL.Mechanism != "" {
291 saslMechanism = &network.SASL.Mechanism
292 switch network.SASL.Mechanism {
293 case "PLAIN":
294 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
295 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]296 default:
297 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]298 }
[90]299 }
300
301 var err error
302 if network.ID != 0 {
[93]303 _, err = db.db.Exec(`UPDATE Network
[263]304 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[95]305 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]306 WHERE id = ?`,
[263]307 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]308 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]309 } else {
310 var res sql.Result
[118]311 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]312 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[95]313 sasl_plain_password)
[263]314 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
315 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]316 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]317 if err != nil {
318 return err
319 }
320 network.ID, err = res.LastInsertId()
321 }
322 return err
323}
324
[202]325func (db *DB) DeleteNetwork(id int64) error {
326 db.lock.Lock()
327 defer db.lock.Unlock()
328
329 tx, err := db.db.Begin()
330 if err != nil {
331 return err
332 }
333 defer tx.Rollback()
334
335 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
336 if err != nil {
337 return err
338 }
339
340 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
341 if err != nil {
342 return err
343 }
344
345 return tx.Commit()
346}
347
[77]348func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]349 db.lock.RLock()
350 defer db.lock.RUnlock()
[77]351
[284]352 rows, err := db.db.Query(`SELECT id, name, key, detached
353 FROM Channel
354 WHERE network = ?`, networkID)
[77]355 if err != nil {
356 return nil, err
357 }
358 defer rows.Close()
359
360 var channels []Channel
361 for rows.Next() {
362 var ch Channel
[146]363 var key *string
[284]364 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil {
[77]365 return nil, err
366 }
[146]367 ch.Key = fromStringPtr(key)
[77]368 channels = append(channels, ch)
369 }
370 if err := rows.Err(); err != nil {
371 return nil, err
372 }
373
374 return channels, nil
375}
[89]376
377func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
378 db.lock.Lock()
379 defer db.lock.Unlock()
380
[146]381 key := toStringPtr(ch.Key)
[149]382
383 var err error
384 if ch.ID != 0 {
385 _, err = db.db.Exec(`UPDATE Channel
[284]386 SET network = ?, name = ?, key = ?, detached = ?
387 WHERE id = ?`,
388 networkID, ch.Name, key, ch.Detached, ch.ID)
[149]389 } else {
390 var res sql.Result
[284]391 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key, detached)
392 VALUES (?, ?, ?, ?)`,
393 networkID, ch.Name, key, ch.Detached)
[149]394 if err != nil {
395 return err
396 }
397 ch.ID, err = res.LastInsertId()
398 }
[89]399 return err
400}
401
402func (db *DB) DeleteChannel(networkID int64, name string) error {
403 db.lock.Lock()
404 defer db.lock.Unlock()
405
406 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
407 return err
408}
Note: See TracBrowser for help on using the repository browser.