source: code/trunk/db_postgres.go@ 626

Last change on this file since 626 was 620, checked in by hubert, 4 years ago

PostgreSQL support

File size: 11.5 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "math"
8 "strings"
9 "time"
10
11 _ "github.com/lib/pq"
12)
13
// postgresConfigSchema creates the "Config" table, which stores the schema
// version. The CHECK(id = 1) constraint together with the primary key limits
// the table to a single row. The statement uses IF NOT EXISTS so it can be
// executed on every startup.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
21
// postgresSchema is the full schema used to initialize an empty database. It
// creates the "User", "Network", "Channel" and "DeliveryReceipt" tables.
// Networks reference their user, and channels and delivery receipts reference
// their network, all with ON DELETE CASCADE so deleting a parent row removes
// its children.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA DEFAULT NULL,
	sasl_external_key BYTEA DEFAULT NULL,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
74
// postgresMigrations lists the SQL statements that upgrade an existing
// database. The entry at index i upgrades the schema from version i to
// version i+1, and len(postgresMigrations) is the current schema version.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
}
78
// PostgresDB is a database backend that stores its data in PostgreSQL through
// a database/sql handle.
type PostgresDB struct {
	db *sql.DB
}
82
83func OpenPostgresDB(source string) (Database, error) {
84 sqlPostgresDB, err := sql.Open("postgres", source)
85 if err != nil {
86 return nil, err
87 }
88
89 db := &PostgresDB{db: sqlPostgresDB}
90 if err := db.upgrade(); err != nil {
91 sqlPostgresDB.Close()
92 return nil, err
93 }
94
95 return db, nil
96}
97
98func (db *PostgresDB) upgrade() error {
99 tx, err := db.db.Begin()
100 if err != nil {
101 return err
102 }
103 defer tx.Rollback()
104
105 if _, err := tx.Exec(postgresConfigSchema); err != nil {
106 return fmt.Errorf("failed to create Config table: %s", err)
107 }
108
109 var version int
110 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
111 if err != nil && !errors.Is(err, sql.ErrNoRows) {
112 return fmt.Errorf("failed to query schema version: %s", err)
113 }
114
115 if version == len(postgresMigrations) {
116 return nil
117 }
118 if version > len(postgresMigrations) {
119 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
120 }
121
122 if version == 0 {
123 if _, err := tx.Exec(postgresSchema); err != nil {
124 return fmt.Errorf("failed to initialize schema: %s", err)
125 }
126 } else {
127 for i := version; i < len(postgresMigrations); i++ {
128 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
129 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
130 }
131 }
132 }
133
134 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
135 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
136 if err != nil {
137 return fmt.Errorf("failed to bump schema version: %v", err)
138 }
139
140 return tx.Commit()
141}
142
143func (db *PostgresDB) Close() error {
144 return db.db.Close()
145}
146
147func (db *PostgresDB) Stats() (*DatabaseStats, error) {
148 var stats DatabaseStats
149 row := db.db.QueryRow(`SELECT
150 (SELECT COUNT(*) FROM "User") AS users,
151 (SELECT COUNT(*) FROM "Network") AS networks,
152 (SELECT COUNT(*) FROM "Channel") AS channels`)
153 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
154 return nil, err
155 }
156
157 return &stats, nil
158}
159
160func (db *PostgresDB) ListUsers() ([]User, error) {
161 rows, err := db.db.Query(`SELECT id, username, password, admin, realname FROM "User"`)
162 if err != nil {
163 return nil, err
164 }
165 defer rows.Close()
166
167 var users []User
168 for rows.Next() {
169 var user User
170 var password, realname sql.NullString
171 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
172 return nil, err
173 }
174 user.Password = password.String
175 user.Realname = realname.String
176 users = append(users, user)
177 }
178 if err := rows.Err(); err != nil {
179 return nil, err
180 }
181
182 return users, nil
183}
184
185func (db *PostgresDB) GetUser(username string) (*User, error) {
186 user := &User{Username: username}
187
188 var password, realname sql.NullString
189 row := db.db.QueryRow(
190 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
191 username)
192 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
193 return nil, err
194 }
195 user.Password = password.String
196 user.Realname = realname.String
197 return user, nil
198}
199
200func (db *PostgresDB) StoreUser(user *User) error {
201 password := toNullString(user.Password)
202 realname := toNullString(user.Realname)
203 err := db.db.QueryRow(`
204 INSERT INTO "User" (username, password, admin, realname)
205 VALUES ($1, $2, $3, $4)
206 ON CONFLICT (username)
207 DO UPDATE SET password = $2, admin = $3, realname = $4
208 RETURNING id`,
209 user.Username, password, user.Admin, realname).Scan(&user.ID)
210 return err
211}
212
213func (db *PostgresDB) DeleteUser(id int64) error {
214 _, err := db.db.Exec(`DELETE FROM "User" WHERE id = $1`, id)
215 return err
216}
217
218func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
219 rows, err := db.db.Query(`
220 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
221 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
222 FROM "Network"
223 WHERE "user" = $1`, userID)
224 if err != nil {
225 return nil, err
226 }
227 defer rows.Close()
228
229 var networks []Network
230 for rows.Next() {
231 var net Network
232 var name, username, realname, pass, connectCommands sql.NullString
233 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
234 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
235 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
236 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
237 if err != nil {
238 return nil, err
239 }
240 net.Name = name.String
241 net.Username = username.String
242 net.Realname = realname.String
243 net.Pass = pass.String
244 if connectCommands.Valid {
245 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
246 }
247 net.SASL.Mechanism = saslMechanism.String
248 net.SASL.Plain.Username = saslPlainUsername.String
249 net.SASL.Plain.Password = saslPlainPassword.String
250 networks = append(networks, net)
251 }
252 if err := rows.Err(); err != nil {
253 return nil, err
254 }
255
256 return networks, nil
257}
258
259func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
260 netName := toNullString(network.Name)
261 netUsername := toNullString(network.Username)
262 realname := toNullString(network.Realname)
263 pass := toNullString(network.Pass)
264 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
265
266 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
267 if network.SASL.Mechanism != "" {
268 saslMechanism = toNullString(network.SASL.Mechanism)
269 switch network.SASL.Mechanism {
270 case "PLAIN":
271 saslPlainUsername = toNullString(network.SASL.Plain.Username)
272 saslPlainPassword = toNullString(network.SASL.Plain.Password)
273 network.SASL.External.CertBlob = nil
274 network.SASL.External.PrivKeyBlob = nil
275 case "EXTERNAL":
276 // keep saslPlain* nil
277 default:
278 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
279 }
280 }
281
282 err := db.db.QueryRow(`
283 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
284 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
285 sasl_external_key, enabled)
286 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
287 ON CONFLICT ("user", name)
288 DO UPDATE SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
289 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
290 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
291 enabled = $14
292 RETURNING id`,
293 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
294 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
295 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
296 return err
297}
298
299func (db *PostgresDB) DeleteNetwork(id int64) error {
300 _, err := db.db.Exec(`DELETE FROM "Network" WHERE id = $1`, id)
301 return err
302}
303
304func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
305 rows, err := db.db.Query(`
306 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
307 detach_on
308 FROM "Channel"
309 WHERE network = $1`, networkID)
310 if err != nil {
311 return nil, err
312 }
313 defer rows.Close()
314
315 var channels []Channel
316 for rows.Next() {
317 var ch Channel
318 var key, detachedInternalMsgID sql.NullString
319 var detachAfter int64
320 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
321 return nil, err
322 }
323 ch.Key = key.String
324 ch.DetachedInternalMsgID = detachedInternalMsgID.String
325 ch.DetachAfter = time.Duration(detachAfter) * time.Second
326 channels = append(channels, ch)
327 }
328 if err := rows.Err(); err != nil {
329 return nil, err
330 }
331
332 return channels, nil
333}
334
335func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
336 key := toNullString(ch.Key)
337 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
338 err := db.db.QueryRow(`
339 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
340 detach_after, detach_on)
341 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
342 ON CONFLICT (network, name)
343 DO UPDATE SET network = $1, name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
344 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
345 RETURNING id`,
346 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
347 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
348 return err
349}
350
351func (db *PostgresDB) DeleteChannel(id int64) error {
352 _, err := db.db.Exec(`DELETE FROM "Channel" WHERE id = $1`, id)
353 return err
354}
355
356func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
357 rows, err := db.db.Query(`
358 SELECT id, target, client, internal_msgid
359 FROM "DeliveryReceipt"
360 WHERE network = $1`, networkID)
361 if err != nil {
362 return nil, err
363 }
364 defer rows.Close()
365
366 var receipts []DeliveryReceipt
367 for rows.Next() {
368 var rcpt DeliveryReceipt
369 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
370 return nil, err
371 }
372 receipts = append(receipts, rcpt)
373 }
374 if err := rows.Err(); err != nil {
375 return nil, err
376 }
377
378 return receipts, nil
379}
380
381func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
382 stmt, err := db.db.Prepare(`
383 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
384 VALUES ($1, $2, $3, $4)
385 ON CONFLICT (network, target, client)
386 DO UPDATE SET internal_msgid = $4
387 RETURNING id`)
388 if err != nil {
389 return err
390 }
391 defer stmt.Close()
392
393 // No need for a transaction since all changes are atomic and don't break data coherence.
394 for i := range receipts {
395 rcpt := &receipts[i]
396 err := stmt.QueryRow(networkID, rcpt.Target, client, rcpt.InternalMsgID).Scan(&rcpt.ID)
397 if err != nil {
398 return err
399 }
400 }
401 return nil
402}
Note: See TracBrowser for help on using the repository browser.