source: code/trunk/db.go@ 99

Last change on this file since 99 was 98, checked in by contact, 5 years ago

Rename project to soju

File size: 5.0 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "sync"
6
7 _ "github.com/mattn/go-sqlite3"
8)
9
// User is an account stored in the User table.
type User struct {
	Username string
	// Password is the hashed password. It is persisted verbatim; this
	// package does not hash it.
	Password string
}
14
// SASL holds the SASL authentication settings of a Network.
type SASL struct {
	// Mechanism is the SASL mechanism name; an empty string means no SASL
	// mechanism is configured.
	Mechanism string

	// Plain carries the credentials used with the "PLAIN" mechanism.
	Plain struct {
		Username, Password string
	}
}
23
24type Network struct {
25 ID int64
26 Addr string
27 Nick string
28 Username string
29 Realname string
30 Pass string
31 SASL SASL
32}
33
// Channel is a channel saved for a network in the Channel table.
type Channel struct {
	// ID is the database row ID.
	ID int64
	// Name is the channel name.
	Name string
}
38
// DB wraps a *sql.DB and serializes access to it.
type DB struct {
	// lock guards db: read-only queries hold it shared, while writes and
	// Close hold it exclusively.
	lock sync.RWMutex
	db   *sql.DB
}
43
44func OpenSQLDB(driver, source string) (*DB, error) {
45 db, err := sql.Open(driver, source)
46 if err != nil {
47 return nil, err
48 }
49 return &DB{db: db}, nil
50}
51
52func (db *DB) Close() error {
53 db.lock.Lock()
54 defer db.lock.Unlock()
55 return db.Close()
56}
57
// fromStringPtr turns a value scanned from a nullable column back into a
// plain string: nil (SQL NULL) becomes "", anything else is dereferenced.
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
64
// toStringPtr is the counterpart of fromStringPtr. It maps "" to nil, so
// database/sql writes SQL NULL, and any other value to a pointer to a copy.
func toStringPtr(s string) *string {
	var ptr *string
	if s != "" {
		ptr = &s
	}
	return ptr
}
71
72func (db *DB) ListUsers() ([]User, error) {
73 db.lock.RLock()
74 defer db.lock.RUnlock()
75
76 rows, err := db.db.Query("SELECT username, password FROM User")
77 if err != nil {
78 return nil, err
79 }
80 defer rows.Close()
81
82 var users []User
83 for rows.Next() {
84 var user User
85 var password *string
86 if err := rows.Scan(&user.Username, &password); err != nil {
87 return nil, err
88 }
89 user.Password = fromStringPtr(password)
90 users = append(users, user)
91 }
92 if err := rows.Err(); err != nil {
93 return nil, err
94 }
95
96 return users, nil
97}
98
99func (db *DB) CreateUser(user *User) error {
100 db.lock.Lock()
101 defer db.lock.Unlock()
102
103 password := toStringPtr(user.Password)
104 _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
105 return err
106}
107
108func (db *DB) ListNetworks(username string) ([]Network, error) {
109 db.lock.RLock()
110 defer db.lock.RUnlock()
111
112 rows, err := db.db.Query(`SELECT id, addr, nick, username, realname, pass,
113 sasl_mechanism, sasl_plain_username, sasl_plain_password
114 FROM Network
115 WHERE user = ?`,
116 username)
117 if err != nil {
118 return nil, err
119 }
120 defer rows.Close()
121
122 var networks []Network
123 for rows.Next() {
124 var net Network
125 var username, realname, pass *string
126 var saslMechanism, saslPlainUsername, saslPlainPassword *string
127 err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname,
128 &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
129 if err != nil {
130 return nil, err
131 }
132 net.Username = fromStringPtr(username)
133 net.Realname = fromStringPtr(realname)
134 net.Pass = fromStringPtr(pass)
135 net.SASL.Mechanism = fromStringPtr(saslMechanism)
136 net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
137 net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
138 networks = append(networks, net)
139 }
140 if err := rows.Err(); err != nil {
141 return nil, err
142 }
143
144 return networks, nil
145}
146
147func (db *DB) StoreNetwork(username string, network *Network) error {
148 db.lock.Lock()
149 defer db.lock.Unlock()
150
151 netUsername := toStringPtr(network.Username)
152 realname := toStringPtr(network.Realname)
153 pass := toStringPtr(network.Pass)
154
155 var saslMechanism, saslPlainUsername, saslPlainPassword *string
156 if network.SASL.Mechanism != "" {
157 saslMechanism = &network.SASL.Mechanism
158 switch network.SASL.Mechanism {
159 case "PLAIN":
160 saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
161 saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
162 }
163 }
164
165 var err error
166 if network.ID != 0 {
167 _, err = db.db.Exec(`UPDATE Network
168 SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
169 sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
170 WHERE id = ?`,
171 network.Addr, network.Nick, netUsername, realname, pass,
172 saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
173 } else {
174 var res sql.Result
175 res, err = db.db.Exec(`INSERT INTO Network(user, addr, nick, username,
176 realname, pass, sasl_mechanism, sasl_plain_username,
177 sasl_plain_password)
178 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
179 username, network.Addr, network.Nick, netUsername, realname, pass,
180 saslMechanism, saslPlainUsername, saslPlainPassword)
181 if err != nil {
182 return err
183 }
184 network.ID, err = res.LastInsertId()
185 }
186 return err
187}
188
189func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
190 db.lock.RLock()
191 defer db.lock.RUnlock()
192
193 rows, err := db.db.Query("SELECT id, name FROM Channel WHERE network = ?", networkID)
194 if err != nil {
195 return nil, err
196 }
197 defer rows.Close()
198
199 var channels []Channel
200 for rows.Next() {
201 var ch Channel
202 if err := rows.Scan(&ch.ID, &ch.Name); err != nil {
203 return nil, err
204 }
205 channels = append(channels, ch)
206 }
207 if err := rows.Err(); err != nil {
208 return nil, err
209 }
210
211 return channels, nil
212}
213
214func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
215 db.lock.Lock()
216 defer db.lock.Unlock()
217
218 _, err := db.db.Exec("INSERT OR REPLACE INTO Channel(network, name) VALUES (?, ?)", networkID, ch.Name)
219 return err
220}
221
222func (db *DB) DeleteChannel(networkID int64, name string) error {
223 db.lock.Lock()
224 defer db.lock.Unlock()
225
226 _, err := db.db.Exec("DELETE FROM Channel WHERE network = ? AND name = ?", networkID, name)
227 return err
228}
Note: See TracBrowser for help on using the repository browser.