source: code/trunk/db.go@ 251

Last change on this file since 251 was 251, checked in by admin, 5 years ago

Add ability to change password

File size: 7.2 KB
Line 
1package soju
2
import (
	"database/sql"
	"errors"
	"fmt"
	"sync"

	_ "github.com/mattn/go-sqlite3"
)
10
// User is an account record from the User table.
type User struct {
	// Username identifies the account.
	Username string
	// Password is the hashed password, or "" when none is set (a NULL
	// column is read back as "").
	Password string
}
15
// SASL holds the SASL authentication settings of a network.
type SASL struct {
	// Mechanism names the SASL mechanism, e.g. "PLAIN". An empty string
	// means SASL is not configured.
	Mechanism string

	// Plain carries the credentials used when Mechanism is "PLAIN".
	Plain struct {
		Username, Password string
	}
}
24
25type Network struct {
26 ID int64
27 Name string
28 Addr string
29 Nick string
30 Username string
31 Realname string
32 Pass string
33 SASL SASL
34}
35
36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network in the Channel table.
type Channel struct {
	ID   int64  // database row ID; 0 until the channel is first stored
	Name string
	Key  string // optional key; "" is stored as NULL
}
48
49var ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
50
// DB is the storage handle. It wraps an *sql.DB and guards every access with
// an RWMutex: the read-only methods take the read lock and the mutating
// methods take the write lock.
// NOTE(review): *sql.DB is already safe for concurrent use; the extra lock
// presumably serializes writers for the SQLite backend. Confirm before
// removing it.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
55
56func OpenSQLDB(driver, source string) (*DB, error) {
57 db, err := sql.Open(driver, source)
58 if err != nil {
59 return nil, err
60 }
61 return &DB{db: db}, nil
62}
63
64func (db *DB) Close() error {
65 db.lock.Lock()
66 defer db.lock.Unlock()
67 return db.Close()
68}
69
// fromStringPtr converts a nullable string, as scanned from a nullable
// column, to a plain string. A nil pointer yields "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
76
// toStringPtr converts a string into a query argument for a nullable column.
// An empty string becomes nil (bound as NULL); any other value becomes a
// pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
83
84func (db *DB) ListUsers() ([]User, error) {
85 db.lock.RLock()
86 defer db.lock.RUnlock()
87
88 rows, err := db.db.Query("SELECT username, password FROM User")
89 if err != nil {
90 return nil, err
91 }
92 defer rows.Close()
93
94 var users []User
95 for rows.Next() {
96 var user User
97 var password *string
98 if err := rows.Scan(&user.Username, &password); err != nil {
99 return nil, err
100 }
101 user.Password = fromStringPtr(password)
102 users = append(users, user)
103 }
104 if err := rows.Err(); err != nil {
105 return nil, err
106 }
107
108 return users, nil
109}
110
111func (db *DB) GetUser(username string) (*User, error) {
112 db.lock.RLock()
113 defer db.lock.RUnlock()
114
115 user := &User{Username: username}
116
117 var password *string
118 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
119 if err := row.Scan(&password); err != nil {
120 return nil, err
121 }
122 user.Password = fromStringPtr(password)
123 return user, nil
124}
125
126func (db *DB) CreateUser(user *User) error {
127 db.lock.Lock()
128 defer db.lock.Unlock()
129
130 password := toStringPtr(user.Password)
131 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
132 return err
133}
134
135func (db *DB) UpdatePassword(user *User) error {
136 db.lock.Lock()
137 defer db.lock.Unlock()
138
139 password := toStringPtr(user.Password)
140 _, err := db.db.Exec(`UPDATE User
141 SET password = ?
142 WHERE username = ?`,
143 password, user.Username)
144 return err
145}
146
147func (db *DB) ListNetworks(username string) ([]Network, error) {
148 db.lock.RLock()
149 defer db.lock.RUnlock()
150
151 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
152 sasl_mechanism, sasl_plain_username, sasl_plain_password
153 FROM Network
154 WHERE user = ?`,
155 username)
156 if err != nil {
157 return nil, err
158 }
159 defer rows.Close()
160
161 var networks []Network
162 for rows.Next() {
163 var net Network
164 var name, username, realname, pass *string
165 var saslMechanism, saslPlainUsername, saslPlainPassword *string
166 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
167 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
168 if err != nil {
169 return nil, err
170 }
171 net.Name = fromStringPtr(name)
172 net.Username = fromStringPtr(username)
173 net.Realname = fromStringPtr(realname)
174 net.Pass = fromStringPtr(pass)
175 net.SASL.Mechanism = fromStringPtr(saslMechanism)
176 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
177 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
178 networks = append(networks, net)
179 }
180 if err := rows.Err(); err != nil {
181 return nil, err
182 }
183
184 return networks, nil
185}
186
187func (db *DB) StoreNetwork(username string, network *Network) error {
188 db.lock.Lock()
189 defer db.lock.Unlock()
190
191 netName := toStringPtr(network.Name)
192 netUsername := toStringPtr(network.Username)
193 realname := toStringPtr(network.Realname)
194 pass := toStringPtr(network.Pass)
195
196 var saslMechanism, saslPlainUsername, saslPlainPassword *string
197 if network.SASL.Mechanism != "" {
198 saslMechanism = &network.SASL.Mechanism
199 switch network.SASL.Mechanism {
200 case "PLAIN":
201 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
202 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
203 default:
204 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
205 }
206 }
207
208 var err error
209 if network.ID != 0 {
210 _, err = db.db.Exec(`UPDATE Network
211 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
212 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
213 WHERE id = ?`,
214 netName, network.Addr, network.Nick, netUsername, realname, pass,
215 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
216 } else {
217 var res sql.Result
218 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
219 realname, pass, sasl_mechanism, sasl_plain_username,
220 sasl_plain_password)
221 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
222 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
223 saslMechanism, saslPlainUsername, saslPlainPassword)
224 if err != nil {
225 return err
226 }
227 network.ID, err = res.LastInsertId()
228 }
229 return err
230}
231
232func (db *DB) DeleteNetwork(id int64) error {
233 db.lock.Lock()
234 defer db.lock.Unlock()
235
236 tx, err := db.db.Begin()
237 if err != nil {
238 return err
239 }
240 defer tx.Rollback()
241
242 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
243 if err != nil {
244 return err
245 }
246
247 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
248 if err != nil {
249 return err
250 }
251
252 return tx.Commit()
253}
254
255func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
256 db.lock.RLock()
257 defer db.lock.RUnlock()
258
259 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
260 if err != nil {
261 return nil, err
262 }
263 defer rows.Close()
264
265 var channels []Channel
266 for rows.Next() {
267 var ch Channel
268 var key *string
269 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
270 return nil, err
271 }
272 ch.Key = fromStringPtr(key)
273 channels = append(channels, ch)
274 }
275 if err := rows.Err(); err != nil {
276 return nil, err
277 }
278
279 return channels, nil
280}
281
282func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
283 db.lock.RLock()
284 defer db.lock.RUnlock()
285
286 ch := &Channel{Name: name}
287
288 var key *string
289 row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
290 if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
291 return nil, ErrNoSuchChannel
292 } else if err != nil {
293 return nil, err
294 }
295 ch.Key = fromStringPtr(key)
296 return ch, nil
297}
298
299func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
300 db.lock.Lock()
301 defer db.lock.Unlock()
302
303 key := toStringPtr(ch.Key)
304
305 var err error
306 if ch.ID != 0 {
307 _, err = db.db.Exec(`UPDATE Channel
308 SET network = ?, name = ?, key = ?
309 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
310 } else {
311 var res sql.Result
312 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
313 VALUES (?, ?, ?)`, networkID, ch.Name, key)
314 if err != nil {
315 return err
316 }
317 ch.ID, err = res.LastInsertId()
318 }
319 return err
320}
321
322func (db *DB) DeleteChannel(networkID int64, name string) error {
323 db.lock.Lock()
324 defer db.lock.Unlock()
325
326 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
327 return err
328}
Note: See TracBrowser for help on using the repository browser.