source: code/trunk/db_postgres.go@ 765

Last change on this file since 765 was 712, checked in by contact, 4 years ago

Add Prometheus instrumentation for the database

File size: 14.1 KB
Line 
1package soju
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13 "github.com/prometheus/client_golang/prometheus"
14 promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
15)
16
// postgresQueryTimeout bounds every individual database operation performed
// by PostgresDB methods; each method derives a child context with this
// deadline from the caller's context.
const postgresQueryTimeout time.Duration = 5 * time.Second
18
// postgresConfigSchema creates the "Config" bookkeeping table that records
// the schema version of the database. CHECK(id = 1) together with the
// primary key restricts the table to a single row. IF NOT EXISTS lets
// upgrade execute it unconditionally every time the database is opened.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
26
// postgresSchema is the complete current schema. upgrade applies it only to
// a freshly created database (schema version 0) and then records
// len(postgresMigrations) as the version. Existing databases reach the same
// shape through postgresMigrations instead, so any change here needs a
// matching migration appended there (and vice versa).
//
// Ownership is expressed with ON DELETE CASCADE foreign keys: deleting a
// "User" row removes its networks, and deleting a "Network" row removes its
// channels and delivery receipts.
//
// Table names are double-quoted, which preserves their capitalization in
// PostgreSQL; the Network."user" column must be quoted because USER is a
// reserved word. Channel.detach_after holds whole seconds (see StoreChannel
// and ListChannels).
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
79
// postgresMigrations holds the schema migrations, indexed by the schema
// version each one upgrades *from*: entry i moves the database from version
// i to version i+1. Its length is therefore the schema version this code
// expects (see upgrade). Entry 0 is an empty placeholder because version 0
// means "uninitialized" and is handled by applying postgresSchema directly.
// Since the index doubles as the version, new migrations must only ever be
// appended.
var postgresMigrations = []string{
	0: "", // reserved: schema initialization (postgresSchema)
	1: `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
}
84
// PostgresDB is the PostgreSQL-backed Database returned by OpenPostgresDB.
// It wraps a database/sql connection pool; every method issues its queries
// through that pool.
type PostgresDB struct{ db *sql.DB }
88
89func OpenPostgresDB(source string) (Database, error) {
90 sqlPostgresDB, err := sql.Open("postgres", source)
91 if err != nil {
92 return nil, err
93 }
94
95 db := &PostgresDB{db: sqlPostgresDB}
96 if err := db.upgrade(); err != nil {
97 sqlPostgresDB.Close()
98 return nil, err
99 }
100
101 return db, nil
102}
103
104func (db *PostgresDB) upgrade() error {
105 tx, err := db.db.Begin()
106 if err != nil {
107 return err
108 }
109 defer tx.Rollback()
110
111 if _, err := tx.Exec(postgresConfigSchema); err != nil {
112 return fmt.Errorf("failed to create Config table: %s", err)
113 }
114
115 var version int
116 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
117 if err != nil && !errors.Is(err, sql.ErrNoRows) {
118 return fmt.Errorf("failed to query schema version: %s", err)
119 }
120
121 if version == len(postgresMigrations) {
122 return nil
123 }
124 if version > len(postgresMigrations) {
125 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
126 }
127
128 if version == 0 {
129 if _, err := tx.Exec(postgresSchema); err != nil {
130 return fmt.Errorf("failed to initialize schema: %s", err)
131 }
132 } else {
133 for i := version; i < len(postgresMigrations); i++ {
134 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
135 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
136 }
137 }
138 }
139
140 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
141 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
142 if err != nil {
143 return fmt.Errorf("failed to bump schema version: %v", err)
144 }
145
146 return tx.Commit()
147}
148
149func (db *PostgresDB) Close() error {
150 return db.db.Close()
151}
152
153func (db *PostgresDB) MetricsCollector() prometheus.Collector {
154 return promcollectors.NewDBStatsCollector(db.db, "main")
155}
156
157func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
158 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
159 defer cancel()
160
161 var stats DatabaseStats
162 row := db.db.QueryRowContext(ctx, `SELECT
163 (SELECT COUNT(*) FROM "User") AS users,
164 (SELECT COUNT(*) FROM "Network") AS networks,
165 (SELECT COUNT(*) FROM "Channel") AS channels`)
166 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
167 return nil, err
168 }
169
170 return &stats, nil
171}
172
173func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
174 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
175 defer cancel()
176
177 rows, err := db.db.QueryContext(ctx,
178 `SELECT id, username, password, admin, realname FROM "User"`)
179 if err != nil {
180 return nil, err
181 }
182 defer rows.Close()
183
184 var users []User
185 for rows.Next() {
186 var user User
187 var password, realname sql.NullString
188 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
189 return nil, err
190 }
191 user.Password = password.String
192 user.Realname = realname.String
193 users = append(users, user)
194 }
195 if err := rows.Err(); err != nil {
196 return nil, err
197 }
198
199 return users, nil
200}
201
202func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
203 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
204 defer cancel()
205
206 user := &User{Username: username}
207
208 var password, realname sql.NullString
209 row := db.db.QueryRowContext(ctx,
210 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
211 username)
212 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
213 return nil, err
214 }
215 user.Password = password.String
216 user.Realname = realname.String
217 return user, nil
218}
219
220func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
221 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
222 defer cancel()
223
224 password := toNullString(user.Password)
225 realname := toNullString(user.Realname)
226
227 var err error
228 if user.ID == 0 {
229 err = db.db.QueryRowContext(ctx, `
230 INSERT INTO "User" (username, password, admin, realname)
231 VALUES ($1, $2, $3, $4)
232 RETURNING id`,
233 user.Username, password, user.Admin, realname).Scan(&user.ID)
234 } else {
235 _, err = db.db.ExecContext(ctx, `
236 UPDATE "User"
237 SET password = $1, admin = $2, realname = $3
238 WHERE id = $4`,
239 password, user.Admin, realname, user.ID)
240 }
241 return err
242}
243
244func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
245 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
246 defer cancel()
247
248 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
249 return err
250}
251
252func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
253 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
254 defer cancel()
255
256 rows, err := db.db.QueryContext(ctx, `
257 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
258 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
259 FROM "Network"
260 WHERE "user" = $1`, userID)
261 if err != nil {
262 return nil, err
263 }
264 defer rows.Close()
265
266 var networks []Network
267 for rows.Next() {
268 var net Network
269 var name, nick, username, realname, pass, connectCommands sql.NullString
270 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
271 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
272 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
273 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
274 if err != nil {
275 return nil, err
276 }
277 net.Name = name.String
278 net.Nick = nick.String
279 net.Username = username.String
280 net.Realname = realname.String
281 net.Pass = pass.String
282 if connectCommands.Valid {
283 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
284 }
285 net.SASL.Mechanism = saslMechanism.String
286 net.SASL.Plain.Username = saslPlainUsername.String
287 net.SASL.Plain.Password = saslPlainPassword.String
288 networks = append(networks, net)
289 }
290 if err := rows.Err(); err != nil {
291 return nil, err
292 }
293
294 return networks, nil
295}
296
297func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
298 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
299 defer cancel()
300
301 netName := toNullString(network.Name)
302 nick := toNullString(network.Nick)
303 netUsername := toNullString(network.Username)
304 realname := toNullString(network.Realname)
305 pass := toNullString(network.Pass)
306 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
307
308 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
309 if network.SASL.Mechanism != "" {
310 saslMechanism = toNullString(network.SASL.Mechanism)
311 switch network.SASL.Mechanism {
312 case "PLAIN":
313 saslPlainUsername = toNullString(network.SASL.Plain.Username)
314 saslPlainPassword = toNullString(network.SASL.Plain.Password)
315 network.SASL.External.CertBlob = nil
316 network.SASL.External.PrivKeyBlob = nil
317 case "EXTERNAL":
318 // keep saslPlain* nil
319 default:
320 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
321 }
322 }
323
324 var err error
325 if network.ID == 0 {
326 err = db.db.QueryRowContext(ctx, `
327 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
328 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
329 sasl_external_key, enabled)
330 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
331 RETURNING id`,
332 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
333 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
334 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
335 } else {
336 _, err = db.db.ExecContext(ctx, `
337 UPDATE "Network"
338 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
339 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
340 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
341 enabled = $14
342 WHERE id = $1`,
343 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
344 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
345 network.SASL.External.PrivKeyBlob, network.Enabled)
346 }
347 return err
348}
349
350func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
351 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
352 defer cancel()
353
354 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
355 return err
356}
357
358func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
359 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
360 defer cancel()
361
362 rows, err := db.db.QueryContext(ctx, `
363 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
364 detach_on
365 FROM "Channel"
366 WHERE network = $1`, networkID)
367 if err != nil {
368 return nil, err
369 }
370 defer rows.Close()
371
372 var channels []Channel
373 for rows.Next() {
374 var ch Channel
375 var key, detachedInternalMsgID sql.NullString
376 var detachAfter int64
377 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
378 return nil, err
379 }
380 ch.Key = key.String
381 ch.DetachedInternalMsgID = detachedInternalMsgID.String
382 ch.DetachAfter = time.Duration(detachAfter) * time.Second
383 channels = append(channels, ch)
384 }
385 if err := rows.Err(); err != nil {
386 return nil, err
387 }
388
389 return channels, nil
390}
391
392func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
393 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
394 defer cancel()
395
396 key := toNullString(ch.Key)
397 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
398
399 var err error
400 if ch.ID == 0 {
401 err = db.db.QueryRowContext(ctx, `
402 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
403 detach_after, detach_on)
404 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
405 RETURNING id`,
406 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
407 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
408 } else {
409 _, err = db.db.ExecContext(ctx, `
410 UPDATE "Channel"
411 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
412 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
413 WHERE id = $1`,
414 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
415 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
416 }
417 return err
418}
419
420func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
421 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
422 defer cancel()
423
424 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
425 return err
426}
427
428func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
429 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
430 defer cancel()
431
432 rows, err := db.db.QueryContext(ctx, `
433 SELECT id, target, client, internal_msgid
434 FROM "DeliveryReceipt"
435 WHERE network = $1`, networkID)
436 if err != nil {
437 return nil, err
438 }
439 defer rows.Close()
440
441 var receipts []DeliveryReceipt
442 for rows.Next() {
443 var rcpt DeliveryReceipt
444 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
445 return nil, err
446 }
447 receipts = append(receipts, rcpt)
448 }
449 if err := rows.Err(); err != nil {
450 return nil, err
451 }
452
453 return receipts, nil
454}
455
456func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
457 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
458 defer cancel()
459
460 tx, err := db.db.Begin()
461 if err != nil {
462 return err
463 }
464 defer tx.Rollback()
465
466 _, err = tx.ExecContext(ctx,
467 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
468 networkID, client)
469 if err != nil {
470 return err
471 }
472
473 stmt, err := tx.PrepareContext(ctx, `
474 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
475 VALUES ($1, $2, $3, $4)
476 RETURNING id`)
477 if err != nil {
478 return err
479 }
480 defer stmt.Close()
481
482 for i := range receipts {
483 rcpt := &receipts[i]
484 err := stmt.
485 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
486 Scan(&rcpt.ID)
487 if err != nil {
488 return err
489 }
490 }
491
492 return tx.Commit()
493}
Note: See TracBrowser for help on using the repository browser.