source: code/trunk/db.go@ 223

Last change on this file since 223 was 207, checked in by contact, 5 years ago

Fix SQL error logged on JOIN

Closes: https://todo.sr.ht/~emersion/soju/40

File size: 6.9 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account row of the User table.
type User struct {
	Username string
	Password string // hashed; an empty value is written to the database as NULL
}
15
// SASL holds the SASL settings of a Network. An empty Mechanism means SASL
// is not configured and all SASL columns are stored as NULL.
type SASL struct {
	// Mechanism names the SASL mechanism. Only "" and "PLAIN" can be
	// persisted by StoreNetwork.
	Mechanism string

	// Plain holds the credentials that are persisted only when Mechanism is
	// "PLAIN".
	Plain struct {
		Username string
		Password string
	}
}
24
// Network is a row of the Network table, which belongs to a user.
// Name, Username, Realname and Pass are written to the database as NULL when
// empty; Addr and Nick are always written as-is.
type Network struct {
	ID       int64 // row ID; zero means the network has not been stored yet
	Name     string
	Addr     string
	Nick     string
	Username string
	Realname string
	Pass     string
	SASL     SASL
}
35
36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a row of the Channel table, which belongs to a network.
type Channel struct {
	ID   int64  // row ID; zero means the channel has not been stored yet
	Name string
	Key  string // an empty key is written to the database as NULL
}
48
var (
	// ErrNoSuchChannel is returned by GetChannel when the network has no
	// channel with the requested name.
	ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
)
50
// DB wraps a database/sql handle. Every method in this file serializes its
// access through lock: queries take the read lock, statements that modify
// data (and Close) take the write lock.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
55
56func OpenSQLDB(driver, source string) (*DB, error) {
57 db, err := sql.Open(driver, source)
58 if err != nil {
59 return nil, err
60 }
61 return &DB{db: db}, nil
62}
63
64func (db *DB) Close() error {
65 db.lock.Lock()
66 defer db.lock.Unlock()
67 return db.Close()
68}
69
// fromStringPtr dereferences ptr, mapping a nil pointer (such as a NULL
// column scanned into a *string) to the empty string.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
76
// toStringPtr returns nil for the empty string, so that database/sql writes
// it as NULL, and otherwise a pointer to its own copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
83
84func (db *DB) ListUsers() ([]User, error) {
85 db.lock.RLock()
86 defer db.lock.RUnlock()
87
88 rows, err := db.db.Query("SELECT username, password FROM User")
89 if err != nil {
90 return nil, err
91 }
92 defer rows.Close()
93
94 var users []User
95 for rows.Next() {
96 var user User
97 var password *string
98 if err := rows.Scan(&user.Username, &password); err != nil {
99 return nil, err
100 }
101 user.Password = fromStringPtr(password)
102 users = append(users, user)
103 }
104 if err := rows.Err(); err != nil {
105 return nil, err
106 }
107
108 return users, nil
109}
110
111func (db *DB) GetUser(username string) (*User, error) {
112 db.lock.RLock()
113 defer db.lock.RUnlock()
114
115 user := &User{Username: username}
116
117 var password *string
118 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
119 if err := row.Scan(&password); err != nil {
120 return nil, err
121 }
122 user.Password = fromStringPtr(password)
123 return user, nil
124}
125
126func (db *DB) CreateUser(user *User) error {
127 db.lock.Lock()
128 defer db.lock.Unlock()
129
130 password := toStringPtr(user.Password)
131 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
132 return err
133}
134
135func (db *DB) ListNetworks(username string) ([]Network, error) {
136 db.lock.RLock()
137 defer db.lock.RUnlock()
138
139 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
140 sasl_mechanism, sasl_plain_username, sasl_plain_password
141 FROM Network
142 WHERE user = ?`,
143 username)
144 if err != nil {
145 return nil, err
146 }
147 defer rows.Close()
148
149 var networks []Network
150 for rows.Next() {
151 var net Network
152 var name, username, realname, pass *string
153 var saslMechanism, saslPlainUsername, saslPlainPassword *string
154 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
155 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
156 if err != nil {
157 return nil, err
158 }
159 net.Name = fromStringPtr(name)
160 net.Username = fromStringPtr(username)
161 net.Realname = fromStringPtr(realname)
162 net.Pass = fromStringPtr(pass)
163 net.SASL.Mechanism = fromStringPtr(saslMechanism)
164 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
165 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
166 networks = append(networks, net)
167 }
168 if err := rows.Err(); err != nil {
169 return nil, err
170 }
171
172 return networks, nil
173}
174
175func (db *DB) StoreNetwork(username string, network *Network) error {
176 db.lock.Lock()
177 defer db.lock.Unlock()
178
179 netName := toStringPtr(network.Name)
180 netUsername := toStringPtr(network.Username)
181 realname := toStringPtr(network.Realname)
182 pass := toStringPtr(network.Pass)
183
184 var saslMechanism, saslPlainUsername, saslPlainPassword *string
185 if network.SASL.Mechanism != "" {
186 saslMechanism = &network.SASL.Mechanism
187 switch network.SASL.Mechanism {
188 case "PLAIN":
189 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
190 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
191 default:
192 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
193 }
194 }
195
196 var err error
197 if network.ID != 0 {
198 _, err = db.db.Exec(`UPDATE Network
199 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
200 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
201 WHERE id = ?`,
202 netName, network.Addr, network.Nick, netUsername, realname, pass,
203 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
204 } else {
205 var res sql.Result
206 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
207 realname, pass, sasl_mechanism, sasl_plain_username,
208 sasl_plain_password)
209 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
210 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
211 saslMechanism, saslPlainUsername, saslPlainPassword)
212 if err != nil {
213 return err
214 }
215 network.ID, err = res.LastInsertId()
216 }
217 return err
218}
219
220func (db *DB) DeleteNetwork(id int64) error {
221 db.lock.Lock()
222 defer db.lock.Unlock()
223
224 tx, err := db.db.Begin()
225 if err != nil {
226 return err
227 }
228 defer tx.Rollback()
229
230 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
231 if err != nil {
232 return err
233 }
234
235 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
236 if err != nil {
237 return err
238 }
239
240 return tx.Commit()
241}
242
243func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
244 db.lock.RLock()
245 defer db.lock.RUnlock()
246
247 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
248 if err != nil {
249 return nil, err
250 }
251 defer rows.Close()
252
253 var channels []Channel
254 for rows.Next() {
255 var ch Channel
256 var key *string
257 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
258 return nil, err
259 }
260 ch.Key = fromStringPtr(key)
261 channels = append(channels, ch)
262 }
263 if err := rows.Err(); err != nil {
264 return nil, err
265 }
266
267 return channels, nil
268}
269
270func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
271 db.lock.RLock()
272 defer db.lock.RUnlock()
273
274 ch := &Channel{Name: name}
275
276 var key *string
277 row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
278 if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
279 return nil, ErrNoSuchChannel
280 } else if err != nil {
281 return nil, err
282 }
283 ch.Key = fromStringPtr(key)
284 return ch, nil
285}
286
287func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
288 db.lock.Lock()
289 defer db.lock.Unlock()
290
291 key := toStringPtr(ch.Key)
292
293 var err error
294 if ch.ID != 0 {
295 _, err = db.db.Exec(`UPDATE Channel
296 SET network = ?, name = ?, key = ?
297 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
298 } else {
299 var res sql.Result
300 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
301 VALUES (?, ?, ?)`, networkID, ch.Name, key)
302 if err != nil {
303 return err
304 }
305 ch.ID, err = res.LastInsertId()
306 }
307 return err
308}
309
310func (db *DB) DeleteChannel(networkID int64, name string) error {
311 db.lock.Lock()
312 defer db.lock.Unlock()
313
314 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
315 return err
316}
Note: See TracBrowser for help on using the repository browser.