source: code/trunk/db.go@ 202

Last change on this file since 202 was 202, checked in by contact, 5 years ago

Add "network delete" service command

And add all the infrastructure required to stop and delete networks.

References: https://todo.sr.ht/~emersion/soju/17

File size: 6.4 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[77]6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account row as stored in the User table.
type User struct {
	// Username identifies the account.
	Username string
	// Password is the hashed password; an empty value corresponds to a
	// NULL password column.
	Password string
}
15
// SASL holds the SASL settings persisted alongside a Network.
type SASL struct {
	// Mechanism names the SASL mechanism; empty means none is configured.
	Mechanism string

	// Plain carries the credentials used when Mechanism is "PLAIN".
	Plain struct {
		Username string
		Password string
	}
}
24
[77]25type Network struct {
26 ID int64
[118]27 Name string
[77]28 Addr string
29 Nick string
30 Username string
31 Realname string
[93]32 Pass string
[95]33 SASL SASL
[77]34}
35
[149]36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network in the Channel table.
type Channel struct {
	ID   int64  // row ID; zero means not yet stored
	Name string
	Key  string // empty when no key is set (stored as NULL)
}
48
49type DB struct {
[81]50 lock sync.RWMutex
[77]51 db *sql.DB
52}
53
54func OpenSQLDB(driver, source string) (*DB, error) {
55 db, err := sql.Open(driver, source)
56 if err != nil {
57 return nil, err
58 }
59 return &DB{db: db}, nil
60}
61
62func (db *DB) Close() error {
63 db.lock.Lock()
64 defer db.lock.Unlock()
65 return db.Close()
66}
67
// fromStringPtr converts a nullable string, as scanned from a NULL-able SQL
// column, into a plain string. A nil pointer (NULL) becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
74
// toStringPtr is the inverse of fromStringPtr: an empty string becomes nil
// (stored as NULL), and any other value becomes a pointer to a copy of it.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
81
[77]82func (db *DB) ListUsers() ([]User, error) {
[81]83 db.lock.RLock()
84 defer db.lock.RUnlock()
[77]85
86 rows, err := db.db.Query("SELECT username, password FROM User")
87 if err != nil {
88 return nil, err
89 }
90 defer rows.Close()
91
92 var users []User
93 for rows.Next() {
94 var user User
95 var password *string
96 if err := rows.Scan(&user.Username, &password); err != nil {
97 return nil, err
98 }
[95]99 user.Password = fromStringPtr(password)
[77]100 users = append(users, user)
101 }
102 if err := rows.Err(); err != nil {
103 return nil, err
104 }
105
106 return users, nil
107}
108
[173]109func (db *DB) GetUser(username string) (*User, error) {
110 db.lock.RLock()
111 defer db.lock.RUnlock()
112
113 user := &User{Username: username}
114
115 var password *string
116 row := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username)
117 if err := row.Scan(&password); err != nil {
118 return nil, err
119 }
120 user.Password = fromStringPtr(password)
121 return user, nil
122}
123
[84]124func (db *DB) CreateUser(user *User) error {
125 db.lock.Lock()
126 defer db.lock.Unlock()
127
[95]128 password := toStringPtr(user.Password)
[89]129 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
130 return err
[84]131}
132
[77]133func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]134 db.lock.RLock()
135 defer db.lock.RUnlock()
[77]136
[118]137 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]138 sasl_mechanism, sasl_plain_username, sasl_plain_password
139 FROM Network
140 WHERE user = ?`,
141 username)
[77]142 if err != nil {
143 return nil, err
144 }
145 defer rows.Close()
146
147 var networks []Network
148 for rows.Next() {
149 var net Network
[118]150 var name, username, realname, pass *string
[95]151 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]152 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]153 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
154 if err != nil {
[77]155 return nil, err
156 }
[118]157 net.Name = fromStringPtr(name)
[95]158 net.Username = fromStringPtr(username)
159 net.Realname = fromStringPtr(realname)
160 net.Pass = fromStringPtr(pass)
161 net.SASL.Mechanism = fromStringPtr(saslMechanism)
162 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
163 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]164 networks = append(networks, net)
165 }
166 if err := rows.Err(); err != nil {
167 return nil, err
168 }
169
170 return networks, nil
171}
172
[90]173func (db *DB) StoreNetwork(username string, network *Network) error {
174 db.lock.Lock()
175 defer db.lock.Unlock()
176
[118]177 netName := toStringPtr(network.Name)
[95]178 netUsername := toStringPtr(network.Username)
179 realname := toStringPtr(network.Realname)
180 pass := toStringPtr(network.Pass)
181
182 var saslMechanism, saslPlainUsername, saslPlainPassword *string
183 if network.SASL.Mechanism != "" {
184 saslMechanism = &network.SASL.Mechanism
185 switch network.SASL.Mechanism {
186 case "PLAIN":
187 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
188 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]189 default:
190 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]191 }
[90]192 }
193
194 var err error
195 if network.ID != 0 {
[93]196 _, err = db.db.Exec(`UPDATE Network
[118]197 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]198 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]199 WHERE id = ?`,
[118]200 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]201 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]202 } else {
203 var res sql.Result
[118]204 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]205 realname, pass, sasl_mechanism, sasl_plain_username,
206 sasl_plain_password)
[118]207 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
208 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]209 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]210 if err != nil {
211 return err
212 }
213 network.ID, err = res.LastInsertId()
214 }
215 return err
216}
217
[202]218func (db *DB) DeleteNetwork(id int64) error {
219 db.lock.Lock()
220 defer db.lock.Unlock()
221
222 tx, err := db.db.Begin()
223 if err != nil {
224 return err
225 }
226 defer tx.Rollback()
227
228 _, err = tx.Exec("DELETE FROM Network WHERE id = ?", id)
229 if err != nil {
230 return err
231 }
232
233 _, err = tx.Exec("DELETE FROM Channel WHERE network = ?", id)
234 if err != nil {
235 return err
236 }
237
238 return tx.Commit()
239}
240
[77]241func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]242 db.lock.RLock()
243 defer db.lock.RUnlock()
[77]244
[146]245 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]246 if err != nil {
247 return nil, err
248 }
249 defer rows.Close()
250
251 var channels []Channel
252 for rows.Next() {
253 var ch Channel
[146]254 var key *string
255 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]256 return nil, err
257 }
[146]258 ch.Key = fromStringPtr(key)
[77]259 channels = append(channels, ch)
260 }
261 if err := rows.Err(); err != nil {
262 return nil, err
263 }
264
265 return channels, nil
266}
[89]267
268func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
269 db.lock.Lock()
270 defer db.lock.Unlock()
271
[146]272 key := toStringPtr(ch.Key)
[149]273
274 var err error
275 if ch.ID != 0 {
276 _, err = db.db.Exec(`UPDATE Channel
277 SET network = ?, name = ?, key = ?
278 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
279 } else {
280 var res sql.Result
281 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
282 VALUES (?, ?, ?)`, networkID, ch.Name, key)
283 if err != nil {
284 return err
285 }
286 ch.ID, err = res.LastInsertId()
287 }
[89]288 return err
289}
290
291func (db *DB) DeleteChannel(networkID int64, name string) error {
292 db.lock.Lock()
293 defer db.lock.Unlock()
294
295 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
296 return err
297}
Note: See TracBrowser for help on using the repository browser.