source: code/trunk/db.go@ 205

Last change on this file since 205 was 202, checked in by contact, 5 years ago

Add "network delete" service command

And add all the infrastructure required to stop and delete networks.

References: https://todo.sr.ht/~emersion/soju/17

File size: 6.4 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account record as stored in the User table.
//
// Password holds the hashed password; an empty string corresponds to a NULL
// password column.
type User struct {
	Username, Password string
}
15
// SASL holds the SASL authentication settings of a network. Mechanism selects
// the mechanism ("" means none); only the sub-struct matching Mechanism is
// meaningful.
type SASL struct {
	Mechanism string

	// Plain holds the credentials used when Mechanism is "PLAIN".
	Plain struct {
		Username, Password string
	}
}
24
25type Network struct {
26 ID int64
27 Name string
28 Addr string
29 Nick string
30 Username string
31 Realname string
32 Pass string
33 SASL SASL
34}
35
36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network, as stored in the Channel table.
// ID is 0 until the channel has been stored; Key is "" when the key column is
// NULL.
type Channel struct {
	ID        int64
	Name, Key string
}
48
49type DB struct {
50 lock sync.RWMutex
51 db *sql.DB
52}
53
// OpenSQLDB opens a database with the given database/sql driver name and data
// source, and wraps it in a DB. sql.Open may only validate its arguments, so a
// nil error here does not guarantee the database is reachable.
func OpenSQLDB(driver, source string) (*DB, error) {
	sqlDB, err := sql.Open(driver, source)
	if err != nil {
		return nil, err
	}
	store := &DB{db: sqlDB}
	return store, nil
}
61
62func (db *DB) Close() error {
63 db.lock.Lock()
64 defer db.lock.Unlock()
65 return db.Close()
66}
67
// fromStringPtr converts a nullable string, as scanned from a NULL-able
// column, into a plain string. A nil pointer becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
74
// toStringPtr converts a string into a nullable value for binding to SQL
// parameters. "" becomes nil, which database/sql binds as NULL; any other
// value becomes a pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
81
82func (db *DB) ListUsers() ([]User, error) {
83 db.lock.RLock()
84 defer db.lock.RUnlock()
85
86 rows, err := db.db.Query("SELECT username, password FROM User")
87 if err != nil {
88 return nil, err
89 }
90 defer rows.Close()
91
92 var users []User
93 for rows.Next() {
94 var user User
95 var password *string
96 if err := rows.Scan(&user.Username, &password); err != nil {
97 return nil, err
98 }
99 user.Password = fromStringPtr(password)
100 users = append(users, user)
101 }
102 if err := rows.Err(); err != nil {
103 return nil, err
104 }
105
106 return users, nil
107}
108
109func (db *DB) GetUser(username string) (*User, error) {
110 db.lock.RLock()
111 defer db.lock.RUnlock()
112
113 user := &User{Username: username}
114
115 var password *string
116 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
117 if err := row.Scan(&password); err != nil {
118 return nil, err
119 }
120 user.Password = fromStringPtr(password)
121 return user, nil
122}
123
124func (db *DB) CreateUser(user *User) error {
125 db.lock.Lock()
126 defer db.lock.Unlock()
127
128 password := toStringPtr(user.Password)
129 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
130 return err
131}
132
133func (db *DB) ListNetworks(username string) ([]Network, error) {
134 db.lock.RLock()
135 defer db.lock.RUnlock()
136
137 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
138 sasl_mechanism, sasl_plain_username, sasl_plain_password
139 FROM Network
140 WHERE user = ?`,
141 username)
142 if err != nil {
143 return nil, err
144 }
145 defer rows.Close()
146
147 var networks []Network
148 for rows.Next() {
149 var net Network
150 var name, username, realname, pass *string
151 var saslMechanism, saslPlainUsername, saslPlainPassword *string
152 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
153 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
154 if err != nil {
155 return nil, err
156 }
157 net.Name = fromStringPtr(name)
158 net.Username = fromStringPtr(username)
159 net.Realname = fromStringPtr(realname)
160 net.Pass = fromStringPtr(pass)
161 net.SASL.Mechanism = fromStringPtr(saslMechanism)
162 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
163 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
164 networks = append(networks, net)
165 }
166 if err := rows.Err(); err != nil {
167 return nil, err
168 }
169
170 return networks, nil
171}
172
// StoreNetwork saves network for the given user. A network with ID 0 is
// inserted, and network.ID is set from the new row. Otherwise the existing row
// with that ID is updated. Empty optional fields are stored as NULL. Only no
// SASL ("") or "PLAIN" SASL is accepted; any other mechanism returns an error
// and writes nothing.
//
// NOTE(review): the UPDATE path matches on id only and does not use username.
func (db *DB) StoreNetwork(username string, network *Network) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	var saslMechanism, saslPlainUsername, saslPlainPassword *string
	switch network.SASL.Mechanism {
	case "":
		// No SASL: all three SASL columns are stored as NULL.
	case "PLAIN":
		saslMechanism = &network.SASL.Mechanism
		saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
		saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
	default:
		return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
	}

	netName := toStringPtr(network.Name)
	netUsername := toStringPtr(network.Username)
	realname := toStringPtr(network.Realname)
	pass := toStringPtr(network.Pass)

	if network.ID != 0 {
		_, err := db.db.Exec(`UPDATE Network
			SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
			WHERE id = ?`,
			netName, network.Addr, network.Nick, netUsername, realname, pass,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
		return err
	}

	res, err := db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
			realname, pass, sasl_mechanism, sasl_plain_username,
			sasl_plain_password)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		username, netName, network.Addr, network.Nick, netUsername, realname, pass,
		saslMechanism, saslPlainUsername, saslPlainPassword)
	if err != nil {
		return err
	}
	network.ID, err = res.LastInsertId()
	return err
}
217
218func (db *DB) DeleteNetwork(id int64) error {
219 db.lock.Lock()
220 defer db.lock.Unlock()
221
222 tx, err := db.db.Begin()
223 if err != nil {
224 return err
225 }
226 defer tx.Rollback()
227
228 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
229 if err != nil {
230 return err
231 }
232
233 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
234 if err != nil {
235 return err
236 }
237
238 return tx.Commit()
239}
240
241func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
242 db.lock.RLock()
243 defer db.lock.RUnlock()
244
245 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
246 if err != nil {
247 return nil, err
248 }
249 defer rows.Close()
250
251 var channels []Channel
252 for rows.Next() {
253 var ch Channel
254 var key *string
255 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
256 return nil, err
257 }
258 ch.Key = fromStringPtr(key)
259 channels = append(channels, ch)
260 }
261 if err := rows.Err(); err != nil {
262 return nil, err
263 }
264
265 return channels, nil
266}
267
268func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
269 db.lock.Lock()
270 defer db.lock.Unlock()
271
272 key := toStringPtr(ch.Key)
273
274 var err error
275 if ch.ID != 0 {
276 _, err = db.db.Exec(`UPDATE Channel
277 SET network = ?, name = ?, key = ?
278 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
279 } else {
280 var res sql.Result
281 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
282 VALUES (?, ?, ?)`, networkID, ch.Name, key)
283 if err != nil {
284 return err
285 }
286 ch.ID, err = res.LastInsertId()
287 }
288 return err
289}
290
291func (db *DB) DeleteChannel(networkID int64, name string) error {
292 db.lock.Lock()
293 defer db.lock.Unlock()
294
295 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
296 return err
297}
Note: See TracBrowser for help on using the repository browser.