source: code/trunk/db.go@ 148

Last change on this file since 148 was 148, checked in by contact, 5 years ago

Error out when storing unsupported SASL mechanism in DB

File size: 5.5 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[77]6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account stored in the User table. Password holds a hash, not
// the plaintext, and is empty when the column is NULL.
type User struct {
	Username, Password string
}
15
// SASL holds a network's SASL settings. An empty Mechanism means no SASL is
// configured; "PLAIN" is the only mechanism StoreNetwork accepts, and its
// credentials live in Plain.
type SASL struct {
	Mechanism string

	Plain struct {
		Username, Password string
	}
}
24
[77]25type Network struct {
26 ID int64
[118]27 Name string
[77]28 Addr string
29 Nick string
30 Username string
31 Realname string
[93]32 Pass string
[95]33 SASL SASL
[77]34}
35
// Channel is a row of the Channel table. Channels are stored per network and
// listed by network ID (see ListChannels). Key is empty when the stored key
// is NULL.
type Channel struct {
	ID        int64
	Name, Key string
}
41
42type DB struct {
[81]43 lock sync.RWMutex
[77]44 db *sql.DB
45}
46
// OpenSQLDB opens the database described by driver and source and wraps the
// resulting handle in a DB. Like sql.Open, it may only validate its
// arguments; a connection is not necessarily established here.
func OpenSQLDB(driver, source string) (*DB, error) {
	sqlDB, err := sql.Open(driver, source)
	if err != nil {
		return nil, err
	}
	wrapped := &DB{db: sqlDB}
	return wrapped, nil
}
54
55func (db *DB) Close() error {
56 db.lock.Lock()
57 defer db.lock.Unlock()
58 return db.Close()
59}
60
// fromStringPtr converts a nullable string (as scanned from a NULL-able
// column) into a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
67
// toStringPtr is the inverse of fromStringPtr: it returns nil for "" (so the
// value is written as NULL) and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
74
[77]75func (db *DB) ListUsers() ([]User, error) {
[81]76 db.lock.RLock()
77 defer db.lock.RUnlock()
[77]78
79 rows, err := db.db.Query("SELECT username, password FROM User")
80 if err != nil {
81 return nil, err
82 }
83 defer rows.Close()
84
85 var users []User
86 for rows.Next() {
87 var user User
88 var password *string
89 if err := rows.Scan(&user.Username, &password); err != nil {
90 return nil, err
91 }
[95]92 user.Password = fromStringPtr(password)
[77]93 users = append(users, user)
94 }
95 if err := rows.Err(); err != nil {
96 return nil, err
97 }
98
99 return users, nil
100}
101
[84]102func (db *DB) CreateUser(user *User) error {
103 db.lock.Lock()
104 defer db.lock.Unlock()
105
[95]106 password := toStringPtr(user.Password)
[89]107 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
108 return err
[84]109}
110
[77]111func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]112 db.lock.RLock()
113 defer db.lock.RUnlock()
[77]114
[118]115 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]116 sasl_mechanism, sasl_plain_username, sasl_plain_password
117 FROM Network
118 WHERE user = ?`,
119 username)
[77]120 if err != nil {
121 return nil, err
122 }
123 defer rows.Close()
124
125 var networks []Network
126 for rows.Next() {
127 var net Network
[118]128 var name, username, realname, pass *string
[95]129 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]130 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]131 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
132 if err != nil {
[77]133 return nil, err
134 }
[118]135 net.Name = fromStringPtr(name)
[95]136 net.Username = fromStringPtr(username)
137 net.Realname = fromStringPtr(realname)
138 net.Pass = fromStringPtr(pass)
139 net.SASL.Mechanism = fromStringPtr(saslMechanism)
140 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
141 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]142 networks = append(networks, net)
143 }
144 if err := rows.Err(); err != nil {
145 return nil, err
146 }
147
148 return networks, nil
149}
150
// StoreNetwork saves network as belonging to the user named username.
//
// A network whose ID is 0 is inserted as a new row and network.ID is set to
// that row's ID. Any other ID updates the existing row in place; username is
// not used in that case, so the row's owner is left unchanged.
//
// SASL settings are stored only when Mechanism is empty (no SASL) or "PLAIN".
// Any other mechanism makes StoreNetwork return an error before any query is
// executed.
func (db *DB) StoreNetwork(username string, network *Network) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	// Resolve the SASL columns first so an unsupported mechanism is rejected
	// without touching the database.
	var saslMechanism, saslPlainUsername, saslPlainPassword *string
	switch network.SASL.Mechanism {
	case "":
		// No SASL configured: all three SASL columns stay NULL.
	case "PLAIN":
		saslMechanism = &network.SASL.Mechanism
		saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
		saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
	default:
		return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
	}

	// Optional text columns are written as NULL when empty.
	netName := toStringPtr(network.Name)
	netUsername := toStringPtr(network.Username)
	realname := toStringPtr(network.Realname)
	pass := toStringPtr(network.Pass)

	if network.ID != 0 {
		_, err := db.db.Exec(`UPDATE Network
			SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
			WHERE id = ?`,
			netName, network.Addr, network.Nick, netUsername, realname, pass,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
		return err
	}

	res, err := db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
				realname, pass, sasl_mechanism, sasl_plain_username,
				sasl_plain_password)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		username, netName, network.Addr, network.Nick, netUsername, realname, pass,
		saslMechanism, saslPlainUsername, saslPlainPassword)
	if err != nil {
		return err
	}
	network.ID, err = res.LastInsertId()
	return err
}
195
[77]196func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]197 db.lock.RLock()
198 defer db.lock.RUnlock()
[77]199
[146]200 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]201 if err != nil {
202 return nil, err
203 }
204 defer rows.Close()
205
206 var channels []Channel
207 for rows.Next() {
208 var ch Channel
[146]209 var key *string
210 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]211 return nil, err
212 }
[146]213 ch.Key = fromStringPtr(key)
[77]214 channels = append(channels, ch)
215 }
216 if err := rows.Err(); err != nil {
217 return nil, err
218 }
219
220 return channels, nil
221}
[89]222
223func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
224 db.lock.Lock()
225 defer db.lock.Unlock()
226
[146]227 key := toStringPtr(ch.Key)
228 _, err := db.db.Exec(`INSERT OR REPLACE INTO Channel(network, name, key)
229 VALUES (?, ?, ?)`, networkID, ch.Name, key)
[89]230 return err
231}
232
233func (db *DB) DeleteChannel(networkID int64, name string) error {
234 db.lock.Lock()
235 defer db.lock.Unlock()
236
237 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
238 return err
239}
[118]240
241func (net *Network) GetName() string {
242 if net.Name != "" {
243 return net.Name
244 }
245 return net.Addr
246}
Note: See TracBrowser for help on using the repository browser.