source: code/trunk/db.go@ 235

Last change on this file since 235 was 207, checked in by contact, 5 years ago

Fix SQL error logged on JOIN

Closes: https://todo.sr.ht/~emersion/soju/40

File size: 6.9 KB
RevLine 
[98]1package soju
[77]2
import (
	"database/sql"
	"errors"
	"fmt"
	"sync"

	_ "github.com/mattn/go-sqlite3"
)
10
// User is a soju account as persisted in the User table.
type User struct {
	// Username uniquely names the account.
	Username string
	// Password holds the hashed password; an empty value is stored as NULL.
	Password string
}
15
// SASL holds the SASL authentication settings of a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name; StoreNetwork only accepts "" or "PLAIN".
	Mechanism string

	// Plain carries the credentials used with the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
24
[77]25type Network struct {
26 ID int64
[118]27 Name string
[77]28 Addr string
29 Nick string
30 Username string
31 Realname string
[93]32 Pass string
[95]33 SASL SASL
[77]34}
35
[149]36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network in the Channel table.
type Channel struct {
	// ID is the row ID; zero means the channel has not been stored yet.
	ID int64
	// Name is the channel name.
	Name string
	// Key is optional; an empty value is stored as NULL.
	Key string
}
48
// ErrNoSuchChannel is returned by GetChannel when no channel with the
// requested name is stored for the network.
var ErrNoSuchChannel = errors.New("soju: no such channel")
50
// DB wraps a *sql.DB. Read methods hold lock shared and write methods hold
// it exclusively, so writes never overlap other operations.
type DB struct {
	// lock serializes access to db.
	lock sync.RWMutex
	// db is the wrapped database handle.
	db *sql.DB
}
55
56func OpenSQLDB(driver, source string) (*DB, error) {
57 db, err := sql.Open(driver, source)
58 if err != nil {
59 return nil, err
60 }
61 return &DB{db: db}, nil
62}
63
64func (db *DB) Close() error {
65 db.lock.Lock()
66 defer db.lock.Unlock()
67 return db.Close()
68}
69
// fromStringPtr dereferences ptr, mapping a nil pointer (a NULL column) to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
76
// toStringPtr returns a pointer to a copy of s. It returns nil for the empty
// string, so callers pass NULL to the database for unset values.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
83
[77]84func (db *DB) ListUsers() ([]User, error) {
[81]85 db.lock.RLock()
86 defer db.lock.RUnlock()
[77]87
88 rows, err := db.db.Query("SELECT username, password FROM User")
89 if err != nil {
90 return nil, err
91 }
92 defer rows.Close()
93
94 var users []User
95 for rows.Next() {
96 var user User
97 var password *string
98 if err := rows.Scan(&user.Username, &password); err != nil {
99 return nil, err
100 }
[95]101 user.Password = fromStringPtr(password)
[77]102 users = append(users, user)
103 }
104 if err := rows.Err(); err != nil {
105 return nil, err
106 }
107
108 return users, nil
109}
110
[173]111func (db *DB) GetUser(username string) (*User, error) {
112 db.lock.RLock()
113 defer db.lock.RUnlock()
114
115 user := &User{Username: username}
116
117 var password *string
118 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
119 if err := row.Scan(&password); err != nil {
120 return nil, err
121 }
122 user.Password = fromStringPtr(password)
123 return user, nil
124}
125
[84]126func (db *DB) CreateUser(user *User) error {
127 db.lock.Lock()
128 defer db.lock.Unlock()
129
[95]130 password := toStringPtr(user.Password)
[89]131 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
132 return err
[84]133}
134
[77]135func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]136 db.lock.RLock()
137 defer db.lock.RUnlock()
[77]138
[118]139 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]140 sasl_mechanism, sasl_plain_username, sasl_plain_password
141 FROM Network
142 WHERE user = ?`,
143 username)
[77]144 if err != nil {
145 return nil, err
146 }
147 defer rows.Close()
148
149 var networks []Network
150 for rows.Next() {
151 var net Network
[118]152 var name, username, realname, pass *string
[95]153 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]154 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]155 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
156 if err != nil {
[77]157 return nil, err
158 }
[118]159 net.Name = fromStringPtr(name)
[95]160 net.Username = fromStringPtr(username)
161 net.Realname = fromStringPtr(realname)
162 net.Pass = fromStringPtr(pass)
163 net.SASL.Mechanism = fromStringPtr(saslMechanism)
164 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
165 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]166 networks = append(networks, net)
167 }
168 if err := rows.Err(); err != nil {
169 return nil, err
170 }
171
172 return networks, nil
173}
174
[90]175func (db *DB) StoreNetwork(username string, network *Network) error {
176 db.lock.Lock()
177 defer db.lock.Unlock()
178
[118]179 netName := toStringPtr(network.Name)
[95]180 netUsername := toStringPtr(network.Username)
181 realname := toStringPtr(network.Realname)
182 pass := toStringPtr(network.Pass)
183
184 var saslMechanism, saslPlainUsername, saslPlainPassword *string
185 if network.SASL.Mechanism != "" {
186 saslMechanism = &network.SASL.Mechanism
187 switch network.SASL.Mechanism {
188 case "PLAIN":
189 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
190 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]191 default:
192 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]193 }
[90]194 }
195
196 var err error
197 if network.ID != 0 {
[93]198 _, err = db.db.Exec(`UPDATE Network
[118]199 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]200 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]201 WHERE id = ?`,
[118]202 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]203 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]204 } else {
205 var res sql.Result
[118]206 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]207 realname, pass, sasl_mechanism, sasl_plain_username,
208 sasl_plain_password)
[118]209 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
210 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]211 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]212 if err != nil {
213 return err
214 }
215 network.ID, err = res.LastInsertId()
216 }
217 return err
218}
219
[202]220func (db *DB) DeleteNetwork(id int64) error {
221 db.lock.Lock()
222 defer db.lock.Unlock()
223
224 tx, err := db.db.Begin()
225 if err != nil {
226 return err
227 }
228 defer tx.Rollback()
229
230 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
231 if err != nil {
232 return err
233 }
234
235 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
236 if err != nil {
237 return err
238 }
239
240 return tx.Commit()
241}
242
[77]243func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]244 db.lock.RLock()
245 defer db.lock.RUnlock()
[77]246
[146]247 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]248 if err != nil {
249 return nil, err
250 }
251 defer rows.Close()
252
253 var channels []Channel
254 for rows.Next() {
255 var ch Channel
[146]256 var key *string
257 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]258 return nil, err
259 }
[146]260 ch.Key = fromStringPtr(key)
[77]261 channels = append(channels, ch)
262 }
263 if err := rows.Err(); err != nil {
264 return nil, err
265 }
266
267 return channels, nil
268}
[89]269
[207]270func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
271 db.lock.RLock()
272 defer db.lock.RUnlock()
273
274 ch := &Channel{Name: name}
275
276 var key *string
277 row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
278 if err := row.Scan(&ch.ID, &key); err == sql.ErrNoRows {
279 return nil, ErrNoSuchChannel
280 } else if err != nil {
281 return nil, err
282 }
283 ch.Key = fromStringPtr(key)
284 return ch, nil
285}
286
[89]287func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
288 db.lock.Lock()
289 defer db.lock.Unlock()
290
[146]291 key := toStringPtr(ch.Key)
[149]292
293 var err error
294 if ch.ID != 0 {
295 _, err = db.db.Exec(`UPDATE Channel
296 SET network = ?, name = ?, key = ?
297 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
298 } else {
299 var res sql.Result
300 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
301 VALUES (?, ?, ?)`, networkID, ch.Name, key)
302 if err != nil {
303 return err
304 }
305 ch.ID, err = res.LastInsertId()
306 }
[89]307 return err
308}
309
310func (db *DB) DeleteChannel(networkID int64, name string) error {
311 db.lock.Lock()
312 defer db.lock.Unlock()
313
314 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
315 return err
316}
Note: See TracBrowser for help on using the repository browser.