source: code/trunk/db.go@ 139

Last change on this file since 139 was 118, checked in by delthas, 5 years ago

schema: add Network.name

File size: 5.2 KB
RevLine 
[98]1package soju
[77]2
3import (
4 "database/sql"
5 "sync"
6
7 _ "github.com/mattn/go-sqlite3"
8)
9
// User is an account as stored in the User table.
type User struct {
	// Username is the account name.
	Username string

	// Password is the hashed password. An empty value is written as NULL
	// and a NULL column is read back as "".
	Password string
}
14
// SASL holds the SASL authentication settings attached to a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name, e.g. "PLAIN". An empty value is
	// written as NULL.
	Mechanism string

	// Plain holds the credentials for the "PLAIN" mechanism; StoreNetwork
	// only persists them when Mechanism is "PLAIN".
	Plain struct {
		Username, Password string
	}
}
23
[77]24type Network struct {
25 ID int64
[118]26 Name string
[77]27 Addr string
28 Nick string
29 Username string
30 Realname string
[93]31 Pass string
[95]32 SASL SASL
[77]33}
34
// Channel is a channel entry of the Channel table, which is keyed to a
// network by its network column.
type Channel struct {
	ID   int64  // Channel row ID
	Name string // channel name
}
39
// DB wraps a *sql.DB. Its methods guard every database access with lock:
// read-only queries take the shared (read) lock, while writes and Close take
// the exclusive lock.
type DB struct {
	lock sync.RWMutex
	db   *sql.DB
}
44
45func OpenSQLDB(driver, source string) (*DB, error) {
46 db, err := sql.Open(driver, source)
47 if err != nil {
48 return nil, err
49 }
50 return &DB{db: db}, nil
51}
52
53func (db *DB) Close() error {
54 db.lock.Lock()
55 defer db.lock.Unlock()
56 return db.Close()
57}
58
// fromStringPtr converts a nullable string (as scanned from a NULL-able
// column) to a plain string, mapping nil to "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
65
// toStringPtr is the inverse of fromStringPtr: it returns nil for "" (so the
// value is written as NULL) and a pointer to a copy of s otherwise.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
72
[77]73func (db *DB) ListUsers() ([]User, error) {
[81]74 db.lock.RLock()
75 defer db.lock.RUnlock()
[77]76
77 rows, err := db.db.Query("SELECT username, password FROM User")
78 if err != nil {
79 return nil, err
80 }
81 defer rows.Close()
82
83 var users []User
84 for rows.Next() {
85 var user User
86 var password *string
87 if err := rows.Scan(&user.Username, &password); err != nil {
88 return nil, err
89 }
[95]90 user.Password = fromStringPtr(password)
[77]91 users = append(users, user)
92 }
93 if err := rows.Err(); err != nil {
94 return nil, err
95 }
96
97 return users, nil
98}
99
[84]100func (db *DB) CreateUser(user *User) error {
101 db.lock.Lock()
102 defer db.lock.Unlock()
103
[95]104 password := toStringPtr(user.Password)
[89]105 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
106 return err
[84]107}
108
[77]109func (db *DB) ListNetworks(username string) ([]Network, error) {
[81]110 db.lock.RLock()
111 defer db.lock.RUnlock()
[77]112
[118]113 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
[95]114 sasl_mechanism, sasl_plain_username, sasl_plain_password
115 FROM Network
116 WHERE user = ?`,
117 username)
[77]118 if err != nil {
119 return nil, err
120 }
121 defer rows.Close()
122
123 var networks []Network
124 for rows.Next() {
125 var net Network
[118]126 var name, username, realname, pass *string
[95]127 var saslMechanism, saslPlainUsername, saslPlainPassword *string
[118]128 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
[95]129 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
130 if err != nil {
[77]131 return nil, err
132 }
[118]133 net.Name = fromStringPtr(name)
[95]134 net.Username = fromStringPtr(username)
135 net.Realname = fromStringPtr(realname)
136 net.Pass = fromStringPtr(pass)
137 net.SASL.Mechanism = fromStringPtr(saslMechanism)
138 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
139 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
[77]140 networks = append(networks, net)
141 }
142 if err := rows.Err(); err != nil {
143 return nil, err
144 }
145
146 return networks, nil
147}
148
[90]149func (db *DB) StoreNetwork(username string, network *Network) error {
150 db.lock.Lock()
151 defer db.lock.Unlock()
152
[118]153 netName := toStringPtr(network.Name)
[95]154 netUsername := toStringPtr(network.Username)
155 realname := toStringPtr(network.Realname)
156 pass := toStringPtr(network.Pass)
157
158 var saslMechanism, saslPlainUsername, saslPlainPassword *string
159 if network.SASL.Mechanism != "" {
160 saslMechanism = &network.SASL.Mechanism
161 switch network.SASL.Mechanism {
162 case "PLAIN":
163 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
164 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
165 }
[90]166 }
167
168 var err error
169 if network.ID != 0 {
[93]170 _, err = db.db.Exec(`UPDATE Network
[118]171 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
[95]172 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
[93]173 WHERE id = ?`,
[118]174 netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]175 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
[90]176 } else {
177 var res sql.Result
[118]178 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
[95]179 realname, pass, sasl_mechanism, sasl_plain_username,
180 sasl_plain_password)
[118]181 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
182 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
[95]183 saslMechanism, saslPlainUsername, saslPlainPassword)
[90]184 if err != nil {
185 return err
186 }
187 network.ID, err = res.LastInsertId()
188 }
189 return err
190}
191
[77]192func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
[81]193 db.lock.RLock()
194 defer db.lock.RUnlock()
[77]195
196 rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID)
197 if err != nil {
198 return nil, err
199 }
200 defer rows.Close()
201
202 var channels []Channel
203 for rows.Next() {
204 var ch Channel
205 if err := rows.Scan(&ch.ID, &ch.Name); err != nil {
206 return nil, err
207 }
208 channels = append(channels, ch)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return channels, nil
215}
[89]216
217func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
218 db.lock.Lock()
219 defer db.lock.Unlock()
220
221 _, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name)
222 return err
223}
224
225func (db *DB) DeleteChannel(networkID int64, name string) error {
226 db.lock.Lock()
227 defer db.lock.Unlock()
228
229 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
230 return err
231}
[118]232
233func (net *Network) GetName() string {
234 if net.Name != "" {
235 return net.Name
236 }
237 return net.Addr
238}
Note: See TracBrowser for help on using the repository browser.