source: code/trunk/db.go@ 148

Last change on this file since 148 was 148, checked in by contact, 5 years ago

Error out when storing unsupported SASL mechanism in DB

File size: 5.5 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "fmt"
6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is one row of the User table.
type User struct {
	// Username is the account name (the "username" column).
	Username string
	// Password holds the hashed password. A NULL column is read back as
	// the empty string, and an empty string is stored as NULL.
	Password string
}
15
// SASL holds the SASL authentication settings attached to a Network.
type SASL struct {
	// Mechanism names the SASL mechanism. The empty string means no
	// mechanism is configured; StoreNetwork only accepts "PLAIN".
	Mechanism string

	// Plain carries the credentials persisted for the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
24
25type Network struct {
26 ID int64
27 Name string
28 Addr string
29 Nick string
30 Username string
31 Realname string
32 Pass string
33 SASL SASL
34}
35
// Channel is one row of the Channel table.
type Channel struct {
	// ID is the database row id.
	ID int64
	// Name is the channel name; together with the network id it is what
	// DeleteChannel matches on.
	Name string
	// Key is the optional channel key; empty is stored as NULL.
	Key string
}
41
// DB wraps a *sql.DB together with a lock that serialises access to it.
type DB struct {
	// lock is taken for reading by the List* methods and for writing by
	// every method that modifies the database or closes it.
	lock sync.RWMutex
	// db is the underlying handle returned by sql.Open.
	db *sql.DB
}
46
47func OpenSQLDB(driver, source string) (*DB, error) {
48 db, err := sql.Open(driver, source)
49 if err != nil {
50 return nil, err
51 }
52 return &DB{db: db}, nil
53}
54
55func (db *DB) Close() error {
56 db.lock.Lock()
57 defer db.lock.Unlock()
58 return db.Close()
59}
60
// fromStringPtr dereferences ptr, mapping a nil pointer (a NULL column) to
// the empty string.
func fromStringPtr(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}
67
// toStringPtr is the inverse of fromStringPtr: it returns nil for the empty
// string (so it is stored as NULL) and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}
74
75func (db *DB) ListUsers() ([]User, error) {
76 db.lock.RLock()
77 defer db.lock.RUnlock()
78
79 rows, err := db.db.Query("SELECT username, password FROM User")
80 if err != nil {
81 return nil, err
82 }
83 defer rows.Close()
84
85 var users []User
86 for rows.Next() {
87 var user User
88 var password *string
89 if err := rows.Scan(&user.Username, &password); err != nil {
90 return nil, err
91 }
92 user.Password = fromStringPtr(password)
93 users = append(users, user)
94 }
95 if err := rows.Err(); err != nil {
96 return nil, err
97 }
98
99 return users, nil
100}
101
102func (db *DB) CreateUser(user *User) error {
103 db.lock.Lock()
104 defer db.lock.Unlock()
105
106 password := toStringPtr(user.Password)
107 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
108 return err
109}
110
111func (db *DB) ListNetworks(username string) ([]Network, error) {
112 db.lock.RLock()
113 defer db.lock.RUnlock()
114
115 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
116 sasl_mechanism, sasl_plain_username, sasl_plain_password
117 FROM Network
118 WHERE user = ?`,
119 username)
120 if err != nil {
121 return nil, err
122 }
123 defer rows.Close()
124
125 var networks []Network
126 for rows.Next() {
127 var net Network
128 var name, username, realname, pass *string
129 var saslMechanism, saslPlainUsername, saslPlainPassword *string
130 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
131 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
132 if err != nil {
133 return nil, err
134 }
135 net.Name = fromStringPtr(name)
136 net.Username = fromStringPtr(username)
137 net.Realname = fromStringPtr(realname)
138 net.Pass = fromStringPtr(pass)
139 net.SASL.Mechanism = fromStringPtr(saslMechanism)
140 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
141 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
142 networks = append(networks, net)
143 }
144 if err := rows.Err(); err != nil {
145 return nil, err
146 }
147
148 return networks, nil
149}
150
151func (db *DB) StoreNetwork(username string, network *Network) error {
152 db.lock.Lock()
153 defer db.lock.Unlock()
154
155 netName := toStringPtr(network.Name)
156 netUsername := toStringPtr(network.Username)
157 realname := toStringPtr(network.Realname)
158 pass := toStringPtr(network.Pass)
159
160 var saslMechanism, saslPlainUsername, saslPlainPassword *string
161 if network.SASL.Mechanism != "" {
162 saslMechanism = &network.SASL.Mechanism
163 switch network.SASL.Mechanism {
164 case "PLAIN":
165 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
166 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
167 default:
168 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
169 }
170 }
171
172 var err error
173 if network.ID != 0 {
174 _, err = db.db.Exec(`UPDATE Network
175 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
176 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
177 WHERE id = ?`,
178 netName, network.Addr, network.Nick, netUsername, realname, pass,
179 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
180 } else {
181 var res sql.Result
182 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
183 realname, pass, sasl_mechanism, sasl_plain_username,
184 sasl_plain_password)
185 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
186 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
187 saslMechanism, saslPlainUsername, saslPlainPassword)
188 if err != nil {
189 return err
190 }
191 network.ID, err = res.LastInsertId()
192 }
193 return err
194}
195
196func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
197 db.lock.RLock()
198 defer db.lock.RUnlock()
199
200 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
201 if err != nil {
202 return nil, err
203 }
204 defer rows.Close()
205
206 var channels []Channel
207 for rows.Next() {
208 var ch Channel
209 var key *string
210 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
211 return nil, err
212 }
213 ch.Key = fromStringPtr(key)
214 channels = append(channels, ch)
215 }
216 if err := rows.Err(); err != nil {
217 return nil, err
218 }
219
220 return channels, nil
221}
222
223func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
224 db.lock.Lock()
225 defer db.lock.Unlock()
226
227 key := toStringPtr(ch.Key)
228 _, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key)
229 VALUES (?, ?, ?)`, networkID, ch.Name, key)
230 return err
231}
232
233func (db *DB) DeleteChannel(networkID int64, name string) error {
234 db.lock.Lock()
235 defer db.lock.Unlock()
236
237 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
238 return err
239}
240
241func (net *Network) GetName() string {
242 if net.Name != "" {
243 return net.Name
244 }
245 return net.Addr
246}
Note: See TracBrowser for help on using the repository browser.