source: code/trunk/db.go@ 141

Last change on this file since 141 was 118, checked in by delthas, 5 years ago

schema: add Network.name

File size: 5.2 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "sync"
6
7 _ "github.com/mattn/go-sqlite3"
8)
9
type (
	// User is an account as stored in the User table.
	User struct {
		Username string
		Password string // hashed
	}

	// SASL holds the SASL authentication settings stored alongside a
	// Network.
	SASL struct {
		// Mechanism is the SASL mechanism name; empty means SASL is not
		// configured.
		Mechanism string

		// Plain carries the credentials persisted when Mechanism is
		// "PLAIN".
		Plain struct {
			Username string
			Password string
		}
	}

	// Network is a per-user network configuration row from the Network
	// table. Name, Username, Realname, Pass and the SASL fields are
	// optional: empty strings are stored as NULL by StoreNetwork.
	Network struct {
		ID       int64 // row id; 0 means "not stored yet"
		Name     string
		Addr     string
		Nick     string
		Username string
		Realname string
		Pass     string
		SASL     SASL
	}

	// Channel is a channel row from the Channel table.
	Channel struct {
		ID   int64
		Name string
	}
)
39
40type DB struct {
41 lock sync.RWMutex
42 db *sql.DB
43}
44
45func OpenSQLDB(driver, source string) (*DB, error) {
46 db, err := sql.Open(driver, source)
47 if err != nil {
48 return nil, err
49 }
50 return &DB{db: db}, nil
51}
52
53func (db *DB) Close() error {
54 db.lock.Lock()
55 defer db.lock.Unlock()
56 return db.Close()
57}
58
// fromStringPtr dereferences ptr, mapping a nil pointer (such as a NULL
// column scanned into a *string) to the empty string.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
65
// toStringPtr is the inverse of fromStringPtr. It returns nil for the empty
// string, so that the value is written as NULL, and otherwise returns a
// pointer to a copy of s.
func toStringPtr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
72
73func (db *DB) ListUsers() ([]User, error) {
74 db.lock.RLock()
75 defer db.lock.RUnlock()
76
77 rows, err := db.db.Query("SELECT username, password FROM User")
78 if err != nil {
79 return nil, err
80 }
81 defer rows.Close()
82
83 var users []User
84 for rows.Next() {
85 var user User
86 var password *string
87 if err := rows.Scan(&user.Username, &password); err != nil {
88 return nil, err
89 }
90 user.Password = fromStringPtr(password)
91 users = append(users, user)
92 }
93 if err := rows.Err(); err != nil {
94 return nil, err
95 }
96
97 return users, nil
98}
99
100func (db *DB) CreateUser(user *User) error {
101 db.lock.Lock()
102 defer db.lock.Unlock()
103
104 password := toStringPtr(user.Password)
105 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
106 return err
107}
108
109func (db *DB) ListNetworks(username string) ([]Network, error) {
110 db.lock.RLock()
111 defer db.lock.RUnlock()
112
113 rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
114 sasl_mechanism, sasl_plain_username, sasl_plain_password
115 FROM Network
116 WHERE user = ?`,
117 username)
118 if err != nil {
119 return nil, err
120 }
121 defer rows.Close()
122
123 var networks []Network
124 for rows.Next() {
125 var net Network
126 var name, username, realname, pass *string
127 var saslMechanism, saslPlainUsername, saslPlainPassword *string
128 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
129 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
130 if err != nil {
131 return nil, err
132 }
133 net.Name = fromStringPtr(name)
134 net.Username = fromStringPtr(username)
135 net.Realname = fromStringPtr(realname)
136 net.Pass = fromStringPtr(pass)
137 net.SASL.Mechanism = fromStringPtr(saslMechanism)
138 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
139 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
140 networks = append(networks, net)
141 }
142 if err := rows.Err(); err != nil {
143 return nil, err
144 }
145
146 return networks, nil
147}
148
149func (db *DB) StoreNetwork(username string, network *Network) error {
150 db.lock.Lock()
151 defer db.lock.Unlock()
152
153 netName := toStringPtr(network.Name)
154 netUsername := toStringPtr(network.Username)
155 realname := toStringPtr(network.Realname)
156 pass := toStringPtr(network.Pass)
157
158 var saslMechanism, saslPlainUsername, saslPlainPassword *string
159 if network.SASL.Mechanism != "" {
160 saslMechanism = &network.SASL.Mechanism
161 switch network.SASL.Mechanism {
162 case "PLAIN":
163 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
164 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
165 }
166 }
167
168 var err error
169 if network.ID != 0 {
170 _, err = db.db.Exec(`UPDATE Network
171 SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
172 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
173 WHERE id = ?`,
174 netName, network.Addr, network.Nick, netUsername, realname, pass,
175 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
176 } else {
177 var res sql.Result
178 res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
179 realname, pass, sasl_mechanism, sasl_plain_username,
180 sasl_plain_password)
181 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
182 username, netName, network.Addr, network.Nick, netUsername, realname, pass,
183 saslMechanism, saslPlainUsername, saslPlainPassword)
184 if err != nil {
185 return err
186 }
187 network.ID, err = res.LastInsertId()
188 }
189 return err
190}
191
192func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
193 db.lock.RLock()
194 defer db.lock.RUnlock()
195
196 rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID)
197 if err != nil {
198 return nil, err
199 }
200 defer rows.Close()
201
202 var channels []Channel
203 for rows.Next() {
204 var ch Channel
205 if err := rows.Scan(&ch.ID, &ch.Name); err != nil {
206 return nil, err
207 }
208 channels = append(channels, ch)
209 }
210 if err := rows.Err(); err != nil {
211 return nil, err
212 }
213
214 return channels, nil
215}
216
217func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
218 db.lock.Lock()
219 defer db.lock.Unlock()
220
221 _, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name)
222 return err
223}
224
225func (db *DB) DeleteChannel(networkID int64, name string) error {
226 db.lock.Lock()
227 defer db.lock.Unlock()
228
229 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
230 return err
231}
232
233func (net *Network) GetName() string {
234 if net.Name != "" {
235 return net.Name
236 }
237 return net.Addr
238}
Note: See TracBrowser for help on using the repository browser.