source: code/trunk/db.go@ 146

Last change on this file since 146 was 146, checked in by contact, 5 years ago

Add support for channel keys

File size: 5.3 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "sync"
6
7 _ "github.com/mattn/go-sqlite3"
8)
9
// User is an account row of the User table.
//
// Password holds a hashed password, never the plaintext; it is empty when
// the stored column is NULL.
type User struct {
	Username, Password string // Password is hashed
}
14
// SASL holds the optional SASL authentication settings of a network.
type SASL struct {
	// Mechanism names the SASL mechanism to use; it is empty when SASL is
	// not configured.
	Mechanism string

	// Plain carries the credentials for the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
23
24type Network struct {
25 ID int64
26 Name string
27 Addr string
28 Nick string
29 Username string
30 Realname string
31 Pass string
32 SASL SASL
33}
34
// Channel is a channel saved for a network.
type Channel struct {
	ID        int64
	Name, Key string // Key is empty when the channel has no key
}
40
// DB wraps a *sql.DB. Every method takes lock for its whole duration:
// queries take the read lock and mutations take the write lock.
type DB struct {
	lock sync.RWMutex // guards db
	db   *sql.DB
}
45
46func OpenSQLDB(driver, source string) (*DB, error) {
47 db, err := sql.Open(driver, source)
48 if err != nil {
49 return nil, err
50 }
51 return &DB{db: db}, nil
52}
53
54func (db *DB) Close() error {
55 db.lock.Lock()
56 defer db.lock.Unlock()
57 return db.Close()
58}
59
// fromStringPtr dereferences ptr, mapping a nil pointer to "".
//
// It converts nullable columns scanned into *string back into plain
// strings.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
66
// toStringPtr returns a pointer to a copy of s, or nil when s is empty.
//
// A nil *string is passed to the database as NULL, so empty optional fields
// are stored as NULL.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
73
74func (db *DB) ListUsers() ([]User, error) {
75 db.lock.RLock()
76 defer db.lock.RUnlock()
77
78 rows, err := db.db.Query("SELECT username, password FROM User")
79 if err != nil {
80 return nil, err
81 }
82 defer rows.Close()
83
84 var users []User
85 for rows.Next() {
86 var user User
87 var password *string
88 if err := rows.Scan(&user.Username, &password); err != nil {
89 return nil, err
90 }
91 user.Password = fromStringPtr(password)
92 users = append(users, user)
93 }
94 if err := rows.Err(); err != nil {
95 return nil, err
96 }
97
98 return users, nil
99}
100
101func (db *DB) CreateUser(user *User) error {
102 db.lock.Lock()
103 defer db.lock.Unlock()
104
105 password := toStringPtr(user.Password)
106 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
107 return err
108}
109
110func (db *DB) ListNetworks(username string) ([]Network, error) {
111 db.lock.RLock()
112 defer db.lock.RUnlock()
113
114 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
115 sasl_mechanism, sasl_plain_username, sasl_plain_password
116 FROM Network
117 WHERE user = ?`,
118 username)
119 if err != nil {
120 return nil, err
121 }
122 defer rows.Close()
123
124 var networks []Network
125 for rows.Next() {
126 var net Network
127 var name, username, realname, pass *string
128 var saslMechanism, saslPlainUsername, saslPlainPassword *string
129 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
130 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
131 if err != nil {
132 return nil, err
133 }
134 net.Name = fromStringPtr(name)
135 net.Username = fromStringPtr(username)
136 net.Realname = fromStringPtr(realname)
137 net.Pass = fromStringPtr(pass)
138 net.SASL.Mechanism = fromStringPtr(saslMechanism)
139 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
140 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
141 networks = append(networks, net)
142 }
143 if err := rows.Err(); err != nil {
144 return nil, err
145 }
146
147 return networks, nil
148}
149
150func (db *DB) StoreNetwork(username string, network *Network) error {
151 db.lock.Lock()
152 defer db.lock.Unlock()
153
154 netName := toStringPtr(network.Name)
155 netUsername := toStringPtr(network.Username)
156 realname := toStringPtr(network.Realname)
157 pass := toStringPtr(network.Pass)
158
159 var saslMechanism, saslPlainUsername, saslPlainPassword *string
160 if network.SASL.Mechanism != "" {
161 saslMechanism = &network.SASL.Mechanism
162 switch network.SASL.Mechanism {
163 case "PLAIN":
164 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
165 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
166 }
167 }
168
169 var err error
170 if network.ID != 0 {
171 _, err = db.db.Exec(`UPDATE Network
172 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
173 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
174 WHERE id = ?`,
175 netName, network.Addr, network.Nick, netUsername, realname, pass,
176 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
177 } else {
178 var res sql.Result
179 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
180 realname, pass, sasl_mechanism, sasl_plain_username,
181 sasl_plain_password)
182 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
183 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
184 saslMechanism, saslPlainUsername, saslPlainPassword)
185 if err != nil {
186 return err
187 }
188 network.ID, err = res.LastInsertId()
189 }
190 return err
191}
192
193func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
194 db.lock.RLock()
195 defer db.lock.RUnlock()
196
197 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
198 if err != nil {
199 return nil, err
200 }
201 defer rows.Close()
202
203 var channels []Channel
204 for rows.Next() {
205 var ch Channel
206 var key *string
207 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
208 return nil, err
209 }
210 ch.Key = fromStringPtr(key)
211 channels = append(channels, ch)
212 }
213 if err := rows.Err(); err != nil {
214 return nil, err
215 }
216
217 return channels, nil
218}
219
220func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
221 db.lock.Lock()
222 defer db.lock.Unlock()
223
224 key := toStringPtr(ch.Key)
225 _, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key)
226 VALUES (?, ?, ?)`, networkID, ch.Name, key)
227 return err
228}
229
230func (db *DB) DeleteChannel(networkID int64, name string) error {
231 db.lock.Lock()
232 defer db.lock.Unlock()
233
234 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
235 return err
236}
237
238func (net *Network) GetName() string {
239 if net.Name != "" {
240 return net.Name
241 }
242 return net.Addr
243}
Note: See TracBrowser for help on using the repository browser.