source: code/trunk/db.go@ 170

Last change on this file since 170 was 149, checked in by contact, 5 years ago

Correctly set Channel.ID in DB.StoreChannel

File size: 5.7 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
[148]5 "fmt"
[77]6 "sync"
7
8 _ "github.com/mattn/go-sqlite3"
9)
10
// User is an account, as stored in the User table.
type User struct {
	// Username identifies the account.
	Username string
	// Password is the hashed password; empty is stored as NULL.
	Password string // hashed
}
15
// SASL holds the SASL authentication settings of a Network. An empty
// Mechanism means SASL is not configured; StoreNetwork accepts only "PLAIN"
// otherwise, in which case the Plain credentials are persisted.
type SASL struct {
	Mechanism string

	// Plain carries the credentials used by the PLAIN mechanism.
	Plain struct {
		Username, Password string
	}
}
24
[77]25type Network struct {
26 ID int64
[118]27 Name string
[77]28 Addr string
29 Nick string
30 Username string
31 Realname string
[93]32 Pass string
[95]33 SASL SASL
[77]34}
35
[149]36func (net *Network) GetName() string {
37 if net.Name != "" {
38 return net.Name
39 }
40 return net.Addr
41}
42
// Channel is a channel saved for a network, as stored in the Channel table.
type Channel struct {
	ID   int64  // row ID; zero means the channel has not been stored yet
	Name string
	Key  string // empty is stored as NULL
}
48
// DB wraps a *sql.DB. Every DB method takes lock around its statements: a
// read lock for queries and the write lock for modifications.
type DB struct {
	lock sync.RWMutex // guards all use of db
	db   *sql.DB
}
53
54func OpenSQLDB(driver, source string) (*DB, error) {
55 db, err := sql.Open(driver, source)
56 if err != nil {
57 return nil, err
58 }
59 return &DB{db: db}, nil
60}
61
62func (db *DB) Close() error {
63 db.lock.Lock()
64 defer db.lock.Unlock()
65 return db.Close()
66}
67
// fromStringPtr converts a nullable string, as scanned from a NULL-able
// column, into a plain string. nil becomes "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
74
// toStringPtr is the inverse of fromStringPtr: it returns nil for "" (so the
// value is stored as NULL) and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	var ptr *string
	if s != "" {
		ptr = &s
	}
	return ptr
}
81
[77]82func (db *DB) ListUsers() ([]User, error) {
[81]83 db.lock.RLock()
84 defer db.lock.RUnlock()
[77]85
86 rows, err := db.db.Query("SELECT username, password FROM User")
87 if err != nil {
88 return nil, err
89 }
90 defer rows.Close()
91
92 var users []User
93 for rows.Next() {
94 var user User
95 var password *string
96 if err := rows.Scan(&user.Username, &password); err != nil {
97 return nil, err
98 }
[95]99 user.Password = fromStringPtr(password)
[77]100 users = append(users, user)
101 }
102 if err := rows.Err(); err != nil {
103 return nil, err
104 }
105
106 return users, nil
107}
108
[84]109func (db *DB) CreateUser(user *User) error {
110 db.lock.Lock()
111 defer db.lock.Unlock()
112
[95]113 password := toStringPtr(user.Password)
[89]114 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
115 return err
[84]116}
117
[77]118func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]119 db.lock.RLock()
120 defer db.lock.RUnlock()
[77]121
[118]122 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]123 sasl_mechanism, sasl_plain_username, sasl_plain_password
124 FROM Network
125 WHERE user = ?`,
126 username)
[77]127 if err != nil {
128 return nil, err
129 }
130 defer rows.Close()
131
132 var networks []Network
133 for rows.Next() {
134 var net Network
[118]135 var name, username, realname, pass *string
[95]136 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]137 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]138 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
139 if err != nil {
[77]140 return nil, err
141 }
[118]142 net.Name = fromStringPtr(name)
[95]143 net.Username = fromStringPtr(username)
144 net.Realname = fromStringPtr(realname)
145 net.Pass = fromStringPtr(pass)
146 net.SASL.Mechanism = fromStringPtr(saslMechanism)
147 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
148 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]149 networks = append(networks, net)
150 }
151 if err := rows.Err(); err != nil {
152 return nil, err
153 }
154
155 return networks, nil
156}
157
[90]158func (db *DB) StoreNetwork(username string, network *Network) error {
159 db.lock.Lock()
160 defer db.lock.Unlock()
161
[118]162 netName := toStringPtr(network.Name)
[95]163 netUsername := toStringPtr(network.Username)
164 realname := toStringPtr(network.Realname)
165 pass := toStringPtr(network.Pass)
166
167 var saslMechanism, saslPlainUsername, saslPlainPassword *string
168 if network.SASL.Mechanism != "" {
169 saslMechanism = &network.SASL.Mechanism
170 switch network.SASL.Mechanism {
171 case "PLAIN":
172 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
173 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
[148]174 default:
175 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
[95]176 }
[90]177 }
178
179 var err error
180 if network.ID != 0 {
[93]181 _, err = db.db.Exec(`UPDATE Network
[118]182 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]183 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]184 WHERE id = ?`,
[118]185 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]186 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]187 } else {
188 var res sql.Result
[118]189 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]190 realname, pass, sasl_mechanism, sasl_plain_username,
191 sasl_plain_password)
[118]192 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
193 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]194 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]195 if err != nil {
196 return err
197 }
198 network.ID, err = res.LastInsertId()
199 }
200 return err
201}
202
[77]203func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]204 db.lock.RLock()
205 defer db.lock.RUnlock()
[77]206
[146]207 rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
[77]208 if err != nil {
209 return nil, err
210 }
211 defer rows.Close()
212
213 var channels []Channel
214 for rows.Next() {
215 var ch Channel
[146]216 var key *string
217 if err := rows.Scan(&ch.ID, &ch.Name, &key); err != nil {
[77]218 return nil, err
219 }
[146]220 ch.Key = fromStringPtr(key)
[77]221 channels = append(channels, ch)
222 }
223 if err := rows.Err(); err != nil {
224 return nil, err
225 }
226
227 return channels, nil
228}
[89]229
230func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
231 db.lock.Lock()
232 defer db.lock.Unlock()
233
[146]234 key := toStringPtr(ch.Key)
[149]235
236 var err error
237 if ch.ID != 0 {
238 _, err = db.db.Exec(`UPDATE Channel
239 SET network = ?, name = ?, key = ?
240 WHERE id = ?`, networkID, ch.Name, key, ch.ID)
241 } else {
242 var res sql.Result
243 res, err = db.db.Exec(`INSERT INTO Channel(network, name, key)
244 VALUES (?, ?, ?)`, networkID, ch.Name, key)
245 if err != nil {
246 return err
247 }
248 ch.ID, err = res.LastInsertId()
249 }
[89]250 return err
251}
252
253func (db *DB) DeleteChannel(networkID int64, name string) error {
254 db.lock.Lock()
255 defer db.lock.Unlock()
256
257 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
258 return err
259}
Note: See TracBrowser for help on using the repository browser.