source: code/trunk/db.go@ 147

Last change on this file since 147 was 146, checked in by contact, 5 years ago

Add support for channel keys

File size: 5.3 KB
Rev Line
[98]1package soju
[77]2
3import (
4 "database/sql"
5 "sync"
6
7 _ "github.com/mattn/go-sqlite3"
8)
9
// User is an account row of the User table (see ListUsers and CreateUser).
type User struct {
	Username string

	// Password is the hashed password. An empty value is stored as NULL,
	// and a NULL column is read back as "".
	Password string
}
14
// SASL holds the SASL settings persisted alongside a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name. StoreNetwork stores an empty
	// Mechanism as NULL and then persists no credentials.
	Mechanism string

	// Plain holds the credentials that StoreNetwork persists only when
	// Mechanism is "PLAIN".
	Plain struct {
		Username, Password string
	}
}
23
[77]24type Network struct {
25 ID int64
[118]26 Name string
[77]27 Addr string
28 Nick string
29 Username string
30 Realname string
[93]31 Pass string
[95]32 SASL SASL
[77]33}
34
// Channel is a row of the Channel table, attached to a network.
type Channel struct {
	ID   int64
	Name string

	// Key is the channel key. An empty key is stored as NULL, and a NULL
	// column is read back as "".
	Key string
}
40
// DB wraps a database/sql handle. Its methods guard every access with lock:
// the List* methods take the read lock, and the methods that write (and
// Close) take the write lock.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
45
46func OpenSQLDB(driver, source string) (*DB, error) {
47 db, err := sql.Open(driver, source)
48 if err != nil {
49 return nil, err
50 }
51 return &DB{db: db}, nil
52}
53
54func (db *DB) Close() error {
55 db.lock.Lock()
56 defer db.lock.Unlock()
57 return db.Close()
58}
59
// fromStringPtr turns a nullable string (e.g. one scanned from a column that
// may be NULL) into a plain string; nil becomes "".
func fromStringPtr(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}
66
// toStringPtr is the counterpart of fromStringPtr. It maps "" to nil, so the
// value is written as NULL, and any other string to a pointer to a copy of it.
func toStringPtr(s string) *string {
	var ptr *string
	if s != "" {
		ptr = &s
	}
	return ptr
}
73
[77]74func (db *DB) ListUsers() ([]User, error) {
[81]75 db.lock.RLock()
76 defer db.lock.RUnlock()
[77]77
78 rows, err := db.db.Query("SELECT username, password FROM User")
79 if err != nil {
80 return nil, err
81 }
82 defer rows.Close()
83
84 var users []User
85 for rows.Next() {
86 var user User
87 var password *string
88 if err := rows.Scan(&user.Username, &password); err != nil {
89 return nil, err
90 }
[95]91 user.Password = fromStringPtr(password)
[77]92 users = append(users, user)
93 }
94 if err := rows.Err(); err != nil {
95 return nil, err
96 }
97
98 return users, nil
99}
100
[84]101func (db *DB) CreateUser(user *User) error {
102 db.lock.Lock()
103 defer db.lock.Unlock()
104
[95]105 password := toStringPtr(user.Password)
[89]106 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
107 return err
[84]108}
109
[77]110func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]111 db.lock.RLock()
112 defer db.lock.RUnlock()
[77]113
[118]114 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]115 sasl_mechanism, sasl_plain_username, sasl_plain_password
116 FROM Network
117 WHERE user = ?`,
118 username)
[77]119 if err != nil {
120 return nil, err
121 }
122 defer rows.Close()
123
124 var networks []Network
125 for rows.Next() {
126 var net Network
[118]127 var name, username, realname, pass *string
[95]128 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]129 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]130 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
131 if err != nil {
[77]132 return nil, err
133 }
[118]134 net.Name = fromStringPtr(name)
[95]135 net.Username = fromStringPtr(username)
136 net.Realname = fromStringPtr(realname)
137 net.Pass = fromStringPtr(pass)
138 net.SASL.Mechanism = fromStringPtr(saslMechanism)
139 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
140 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]141 networks = append(networks, net)
142 }
143 if err := rows.Err(); err != nil {
144 return nil, err
145 }
146
147 return networks, nil
148}
149
[90]150func (db *DB) StoreNetwork(username string, network *Network) error {
151 db.lock.Lock()
152 defer db.lock.Unlock()
153
[118]154 netName := toStringPtr(network.Name)
[95]155 netUsername := toStringPtr(network.Username)
156 realname := toStringPtr(network.Realname)
157 pass := toStringPtr(network.Pass)
158
159 var saslMechanism, saslPlainUsername, saslPlainPassword *string
160 if network.SASL.Mechanism != "" {
161 saslMechanism = &network.SASL.Mechanism
162 switch network.SASL.Mechanism {
163 case "PLAIN":
164 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
165 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
166 }
[90]167 }
168
169 var err error
170 if network.ID != 0 {
[93]171 _, err = db.db.Exec(`UPDATE Network
[118]172 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]173 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]174 WHERE id = ?`,
[118]175 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]176 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]177 } else {
178 var res sql.Result
[118]179 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]180 realname, pass, sasl_mechanism, sasl_plain_username,
181 sasl_plain_password)
[118]182 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
183 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]184 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]185 if err != nil {
186 return err
187 }
188 network.ID, err = res.LastInsertId()
189 }
190 return err
191}
192
[77]193func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]194 db.lock.RLock()
195 defer db.lock.RUnlock()
[77]196
[146]197 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]198 if err != nil {
199 return nil, err
200 }
201 defer rows.Close()
202
203 var channels []Channel
204 for rows.Next() {
205 var ch Channel
[146]206 var key *string
207 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]208 return nil, err
209 }
[146]210 ch.Key = fromStringPtr(key)
[77]211 channels = append(channels, ch)
212 }
213 if err := rows.Err(); err != nil {
214 return nil, err
215 }
216
217 return channels, nil
218}
[89]219
220func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
221 db.lock.Lock()
222 defer db.lock.Unlock()
223
[146]224 key := toStringPtr(ch.Key)
225 _, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key)
226 VALUES (?, ?, ?)`, networkID, ch.Name, key)
[89]227 return err
228}
229
230func (db *DB) DeleteChannel(networkID int64, name string) error {
231 db.lock.Lock()
232 defer db.lock.Unlock()
233
234 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
235 return err
236}
[118]237
238func (net *Network) GetName() string {
239 if net.Name != "" {
240 return net.Name
241 }
242 return net.Addr
243}
Note: See TracBrowser for help on using the repository browser.