source: code/trunk/db.go@ 152

Last change on this file since 152 was 149, checked in by contact, 5 years ago

Correctly set Channel.ID in DB.StoreChannel

File size: 5.7 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is a soju account as stored in the User table.
//
// Password holds the hashed password, not plaintext. CreateUser stores an
// empty Password as NULL, and ListUsers reads a NULL password back as "".
type User struct {
	Username, Password string
}
15
// SASL holds the SASL authentication settings of a Network.
//
// Mechanism is empty when SASL is not used; StoreNetwork accepts only "" and
// "PLAIN". Plain carries the credentials used with the PLAIN mechanism.
type SASL struct {
	Mechanism string

	Plain struct {
		Username, Password string
	}
}
24
25type Network struct {
26 ID int64
27 Name string
28 Addr string
29 Nick string
30 Username string
31 Realname string
32 Pass string
33 SASL SASL
34}
35
36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel row saved for a network.
type Channel struct {
	ID   int64  // row ID; zero until StoreChannel inserts the row
	Name string // channel name; DeleteChannel looks rows up by it
	Key  string // optional; stored as NULL when empty
}
48
// DB is soju's persistent store: a *sql.DB guarded by a read/write lock.
//
// The List* methods hold the read lock while querying. Methods that modify
// data, and Close, hold the write lock.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
53
54func OpenSQLDB(driver, source string) (*DB, error) {
55 db, err := sql.Open(driver, source)
56 if err != nil {
57 return nil, err
58 }
59 return &DB{db: db}, nil
60}
61
62func (db *DB) Close() error {
63 db.lock.Lock()
64 defer db.lock.Unlock()
65 return db.Close()
66}
67
// fromStringPtr converts a nullable string (as scanned from a nullable column)
// to a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
74
// toStringPtr converts a string into a nullable value for a query argument.
// It maps "" to nil, so the column receives NULL. Any other string yields a
// pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
81
82func (db *DB) ListUsers() ([]User, error) {
83 db.lock.RLock()
84 defer db.lock.RUnlock()
85
86 rows, err := db.db.Query("SELECT username, password FROM User")
87 if err != nil {
88 return nil, err
89 }
90 defer rows.Close()
91
92 var users []User
93 for rows.Next() {
94 var user User
95 var password *string
96 if err := rows.Scan(&user.Username, &password); err != nil {
97 return nil, err
98 }
99 user.Password = fromStringPtr(password)
100 users = append(users, user)
101 }
102 if err := rows.Err(); err != nil {
103 return nil, err
104 }
105
106 return users, nil
107}
108
109func (db *DB) CreateUser(user *User) error {
110 db.lock.Lock()
111 defer db.lock.Unlock()
112
113 password := toStringPtr(user.Password)
114 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
115 return err
116}
117
118func (db *DB) ListNetworks(username string) ([]Network, error) {
119 db.lock.RLock()
120 defer db.lock.RUnlock()
121
122 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
123 sasl_mechanism, sasl_plain_username, sasl_plain_password
124 FROM Network
125 WHERE user = ?`,
126 username)
127 if err != nil {
128 return nil, err
129 }
130 defer rows.Close()
131
132 var networks []Network
133 for rows.Next() {
134 var net Network
135 var name, username, realname, pass *string
136 var saslMechanism, saslPlainUsername, saslPlainPassword *string
137 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
138 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
139 if err != nil {
140 return nil, err
141 }
142 net.Name = fromStringPtr(name)
143 net.Username = fromStringPtr(username)
144 net.Realname = fromStringPtr(realname)
145 net.Pass = fromStringPtr(pass)
146 net.SASL.Mechanism = fromStringPtr(saslMechanism)
147 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
148 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
149 networks = append(networks, net)
150 }
151 if err := rows.Err(); err != nil {
152 return nil, err
153 }
154
155 return networks, nil
156}
157
158func (db *DB) StoreNetwork(username string, network *Network) error {
159 db.lock.Lock()
160 defer db.lock.Unlock()
161
162 netName := toStringPtr(network.Name)
163 netUsername := toStringPtr(network.Username)
164 realname := toStringPtr(network.Realname)
165 pass := toStringPtr(network.Pass)
166
167 var saslMechanism, saslPlainUsername, saslPlainPassword *string
168 if network.SASL.Mechanism != "" {
169 saslMechanism = &network.SASL.Mechanism
170 switch network.SASL.Mechanism {
171 case "PLAIN":
172 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
173 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
174 default:
175 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
176 }
177 }
178
179 var err error
180 if network.ID != 0 {
181 _, err = db.db.Exec(`UPDATE Network
182 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
183 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
184 WHERE id = ?`,
185 netName, network.Addr, network.Nick, netUsername, realname, pass,
186 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
187 } else {
188 var res sql.Result
189 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
190 realname, pass, sasl_mechanism, sasl_plain_username,
191 sasl_plain_password)
192 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
193 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
194 saslMechanism, saslPlainUsername, saslPlainPassword)
195 if err != nil {
196 return err
197 }
198 network.ID, err = res.LastInsertId()
199 }
200 return err
201}
202
203func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
204 db.lock.RLock()
205 defer db.lock.RUnlock()
206
207 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
208 if err != nil {
209 return nil, err
210 }
211 defer rows.Close()
212
213 var channels []Channel
214 for rows.Next() {
215 var ch Channel
216 var key *string
217 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
218 return nil, err
219 }
220 ch.Key = fromStringPtr(key)
221 channels = append(channels, ch)
222 }
223 if err := rows.Err(); err != nil {
224 return nil, err
225 }
226
227 return channels, nil
228}
229
230func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
231 db.lock.Lock()
232 defer db.lock.Unlock()
233
234 key := toStringPtr(ch.Key)
235
236 var err error
237 if ch.ID != 0 {
238 _, err = db.db.Exec(`UPDATE Channel
239 SET network = ?, name = ?, key = ?
240 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
241 } else {
242 var res sql.Result
243 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
244 VALUES (?, ?, ?)`, networkID, ch.Name, key)
245 if err != nil {
246 return err
247 }
248 ch.ID, err = res.LastInsertId()
249 }
250 return err
251}
252
253func (db *DB) DeleteChannel(networkID int64, name string) error {
254 db.lock.Lock()
255 defer db.lock.Unlock()
256
257 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
258 return err
259}
Note: See TracBrowser for help on using the repository browser.