source: code/trunk/db.go@ 263

Last change on this file since 263 was 263, checked in by delthas, 5 years ago

Add support for custom network on-connect commands

Some servers use custom IRC bots with custom commands for registering with
specific services after connecting.

This adds support for setting custom raw IRC messages that will be
sent after registering with a network.

It also adds support for a custom flag.Value type for string
slice flags (flags taking several string values).

File size: 9.5 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[263]6 "strings"
[77]7 "sync"
8
9 _ "github.com/mattn/go-sqlite3"
10)
11
// User is a soju account. Password holds the hashed password, never the
// plaintext.
type User struct {
	Username string
	Password string // hashed
}

// SASL holds the SASL settings stored for a network. Only the credentials
// matching Mechanism are meaningful.
type SASL struct {
	Mechanism string

	// Plain holds the credentials used by the "PLAIN" mechanism.
	Plain struct {
		Username string
		Password string
	}
}

// Network is an upstream IRC network configured for a user.
// ConnectCommands are raw IRC messages to send once registered.
type Network struct {
	ID              int64
	Name            string
	Addr            string
	Nick            string
	Username        string
	Realname        string
	Pass            string
	ConnectCommands []string
	SASL            SASL
}

// GetName returns the name to display for the network: its Name when one
// is set, otherwise its address.
func (n *Network) GetName() string {
	display := n.Name
	if display == "" {
		display = n.Addr
	}
	return display
}

// Channel is a channel stored for a network. Key is empty when the channel
// has no key.
type Channel struct {
	ID   int64
	Name string
	Key  string
}
50
// ErrNoSuchChannel is returned by GetChannel when no channel matches the
// requested network and name.
var ErrNoSuchChannel = fmt.Errorf("soju: no such channel")

// schema is the complete, current database layout. It is executed as-is on a
// fresh database (user_version 0). Existing databases are instead brought up
// to date through migrations, so the two must describe the same final layout.
const schema = `
CREATE TABLE User (
	username VARCHAR(255) PRIMARY KEY,
	password VARCHAR(255) NOT NULL
);

CREATE TABLE Network (
	id INTEGER PRIMARY KEY,
	name VARCHAR(255),
	user VARCHAR(255) NOT NULL,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);

CREATE TABLE Channel (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);
`

// migrations upgrades an existing database one version at a time:
// migrations[i] takes the schema from version i to version i+1. The current
// schema version is len(migrations), so appending an entry bumps it.
var migrations = []string{
	// #0 is reserved for schema initialization (see schema above).
	"",
	// #1: per-network commands sent after registration.
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
}
90
// DB wraps the soju SQL database. Every method in this file takes lock
// before touching db: read-only queries hold it shared (RLock), and
// statements that modify data hold it exclusively (Lock).
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
95
96func OpenSQLDB(driver, source string) (*DB, error) {
[255]97 sqlDB, err := sql.Open(driver, source)
[77]98 if err != nil {
99 return nil, err
100 }
[255]101
102 db := &DB{db: sqlDB}
103 if err := db.upgrade(); err != nil {
104 return nil, err
105 }
106
107 return db, nil
[77]108}
109
110func (db *DB) Close() error {
111 db.lock.Lock()
112 defer db.lock.Unlock()
113 return db.Close()
114}
115
[255]116func (db *DB) upgrade() error {
117 db.lock.Lock()
118 defer db.lock.Unlock()
119
120 var version int
121 if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
122 return fmt.Errorf("failed to query schema version: %v", err)
123 }
124
125 if version == len(migrations) {
126 return nil
127 } else if version > len(migrations) {
128 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(migrations), version)
129 }
130
131 tx, err := db.db.Begin()
132 if err != nil {
133 return err
134 }
135 defer tx.Rollback()
136
137 if version == 0 {
138 if _, err := tx.Exec(schema); err != nil {
139 return fmt.Errorf("failed to initialize schema: %v", err)
140 }
141 } else {
142 for i := version; i < len(migrations); i++ {
143 if _, err := tx.Exec(migrations[i]); err != nil {
144 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
145 }
146 }
147 }
148
149 // For some reason prepared statements don't work here
150 _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
151 if err != nil {
152 return fmt.Errorf("failed to bump schema version: %v", err)
153 }
154
155 return tx.Commit()
156}
157
// fromStringPtr converts a nullable column value back to a plain string,
// mapping nil (SQL NULL) to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
164
// toStringPtr prepares a string for a nullable column. "" maps to nil, so it
// is stored as SQL NULL. Any other value maps to a pointer to its own copy.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
171
[77]172func (db *DB) ListUsers() ([]User, error) {
[81]173 db.lock.RLock()
174 defer db.lock.RUnlock()
[77]175
176 rows, err := db.db.Query("SELECT username, password FROM User")
177 if err != nil {
178 return nil, err
179 }
180 defer rows.Close()
181
182 var users []User
183 for rows.Next() {
184 var user User
185 var password *string
186 if err := rows.Scan(&user.Username, &password); err != nil {
187 return nil, err
188 }
[95]189 user.Password = fromStringPtr(password)
[77]190 users = append(users, user)
191 }
192 if err := rows.Err(); err != nil {
193 return nil, err
194 }
195
196 return users, nil
197}
198
[173]199func (db *DB) GetUser(username string) (*User, error) {
200 db.lock.RLock()
201 defer db.lock.RUnlock()
202
203 user := &User{Username: username}
204
205 var password *string
206 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
207 if err := row.Scan(&password); err != nil {
208 return nil, err
209 }
210 user.Password = fromStringPtr(password)
211 return user, nil
212}
213
[84]214func (db *DB) CreateUser(user *User) error {
215 db.lock.Lock()
216 defer db.lock.Unlock()
217
[95]218 password := toStringPtr(user.Password)
[89]219 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
220 return err
[84]221}
222
[251]223func (db *DB) UpdatePassword(user *User) error {
224 db.lock.Lock()
225 defer db.lock.Unlock()
226
227 password := toStringPtr(user.Password)
228 _, err := db.db.Exec(`UPDATE User
229 SET password = ?
230 WHERE username = ?`,
231 password, user.Username)
232 return err
233}
234
[77]235func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]236 db.lock.RLock()
237 defer db.lock.RUnlock()
[77]238
[118]239 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[263]240 connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
[95]241 FROM Network
242 WHERE user = ?`,
243 username)
[77]244 if err != nil {
245 return nil, err
246 }
247 defer rows.Close()
248
249 var networks []Network
250 for rows.Next() {
251 var net Network
[263]252 var name, username, realname, pass, connectCommands *string
[95]253 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]254 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[263]255 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
[95]256 if err != nil {
[77]257 return nil, err
258 }
[118]259 net.Name = fromStringPtr(name)
[95]260 net.Username = fromStringPtr(username)
261 net.Realname = fromStringPtr(realname)
262 net.Pass = fromStringPtr(pass)
[263]263 if connectCommands != nil {
264 net.ConnectCommands = strings.Split(*connectCommands, "\r\n")
265 }
[95]266 net.SASL.Mechanism = fromStringPtr(saslMechanism)
267 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
268 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]269 networks = append(networks, net)
270 }
271 if err := rows.Err(); err != nil {
272 return nil, err
273 }
274
275 return networks, nil
276}
277
[90]278func (db *DB) StoreNetwork(username string, network *Network) error {
279 db.lock.Lock()
280 defer db.lock.Unlock()
281
[118]282 netName := toStringPtr(network.Name)
[95]283 netUsername := toStringPtr(network.Username)
284 realname := toStringPtr(network.Realname)
285 pass := toStringPtr(network.Pass)
[263]286 connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
[95]287
288 var saslMechanism, saslPlainUsername, saslPlainPassword *string
289 if network.SASL.Mechanism != "" {
290 saslMechanism = &network.SASL.Mechanism
291 switch network.SASL.Mechanism {
292 case "PLAIN":
293 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
294 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]295 default:
296 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]297 }
[90]298 }
299
300 var err error
301 if network.ID != 0 {
[93]302 _, err = db.db.Exec(`UPDATE Network
[263]303 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
[95]304 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]305 WHERE id = ?`,
[263]306 netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]307 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]308 } else {
309 var res sql.Result
[118]310 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[263]311 realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
[95]312 sasl_plain_password)
[263]313 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
314 username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
[95]315 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]316 if err != nil {
317 return err
318 }
319 network.ID, err = res.LastInsertId()
320 }
321 return err
322}
323
[202]324func (db *DB) DeleteNetwork(id int64) error {
325 db.lock.Lock()
326 defer db.lock.Unlock()
327
328 tx, err := db.db.Begin()
329 if err != nil {
330 return err
331 }
332 defer tx.Rollback()
333
334 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
335 if err != nil {
336 return err
337 }
338
339 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
340 if err != nil {
341 return err
342 }
343
344 return tx.Commit()
345}
346
[77]347func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]348 db.lock.RLock()
349 defer db.lock.RUnlock()
[77]350
[146]351 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]352 if err != nil {
353 return nil, err
354 }
355 defer rows.Close()
356
357 var channels []Channel
358 for rows.Next() {
359 var ch Channel
[146]360 var key *string
361 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]362 return nil, err
363 }
[146]364 ch.Key = fromStringPtr(key)
[77]365 channels = append(channels, ch)
366 }
367 if err := rows.Err(); err != nil {
368 return nil, err
369 }
370
371 return channels, nil
372}
[89]373
[207]374func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
375 db.lock.RLock()
376 defer db.lock.RUnlock()
377
378 ch := &Channel{Name: name}
379
380 var key *string
381 row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
382 if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
383 return nil, ErrNoSuchChannel
384 } else if err != nil {
385 return nil, err
386 }
387 ch.Key = fromStringPtr(key)
388 return ch, nil
389}
390
[89]391func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
392 db.lock.Lock()
393 defer db.lock.Unlock()
394
[146]395 key := toStringPtr(ch.Key)
[149]396
397 var err error
398 if ch.ID != 0 {
399 _, err = db.db.Exec(`UPDATE Channel
400 SET network = ?, name = ?, key = ?
401 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
402 } else {
403 var res sql.Result
404 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
405 VALUES (?, ?, ?)`, networkID, ch.Name, key)
406 if err != nil {
407 return err
408 }
409 ch.ID, err = res.LastInsertId()
410 }
[89]411 return err
412}
413
414func (db *DB) DeleteChannel(networkID int64, name string) error {
415 db.lock.Lock()
416 defer db.lock.Unlock()
417
418 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
419 return err
420}
Note: See TracBrowser for help on using the repository browser.