source: code/trunk/db_postgres.go@ 736

Last change on this file since 736 was 712, checked in by contact, 4 years ago

Add Prometheus instrumentation for the database

File size: 14.1 KB
Rev Line
[620]1package soju
2
3import (
[645]4 "context"
[620]5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
[712]13 "github.com/prometheus/client_golang/prometheus"
14 promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
[620]15)
16
// postgresQueryTimeout is the per-call deadline that the PostgresDB query
// methods apply to their context via context.WithTimeout.
const postgresQueryTimeout time.Duration = 5 * time.Second

// postgresConfigSchema creates the bookkeeping table holding the schema
// version. It is executed every time the database is opened (by upgrade), so
// it must stay idempotent, hence IF NOT EXISTS. The CHECK constraint limits
// the table to the single row with id 1.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`

// postgresSchema is the complete, current schema. upgrade applies it to a
// fresh database (schema version 0) and then records the version as
// len(postgresMigrations). It must therefore always equal the result of
// running every entry of postgresMigrations. Any change here needs a matching
// migration.
//
// NOTE(review): nick is nullable, and PostgreSQL treats NULLs as distinct in
// UNIQUE constraints by default. UNIQUE("user", addr, nick) therefore does not
// prevent duplicate rows whose nick is NULL — confirm this is intended.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`

// postgresMigrations holds the ordered schema upgrades replayed by upgrade.
// The version stored in "Config" is the number of entries already applied.
// Entries must only ever be appended — never edited, removed, or reordered.
var postgresMigrations = []string{
	"", // #0: reserved; fresh databases receive postgresSchema instead
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`, // #1: make nick optional
}

// PostgresDB is the PostgreSQL-backed Database implementation returned by
// OpenPostgresDB. It wraps a database/sql handle opened with the "postgres"
// driver.
type PostgresDB struct {
	db *sql.DB // shared handle used by every method
}

89func OpenPostgresDB(source string) (Database, error) {
90 sqlPostgresDB, err := sql.Open("postgres", source)
91 if err != nil {
92 return nil, err
93 }
94
95 db := &PostgresDB{db: sqlPostgresDB}
96 if err := db.upgrade(); err != nil {
97 sqlPostgresDB.Close()
98 return nil, err
99 }
100
101 return db, nil
102}
103
104func (db *PostgresDB) upgrade() error {
105 tx, err := db.db.Begin()
106 if err != nil {
107 return err
108 }
109 defer tx.Rollback()
110
111 if _, err := tx.Exec(postgresConfigSchema); err != nil {
112 return fmt.Errorf("failed to create Config table: %s", err)
113 }
114
115 var version int
116 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
117 if err != nil && !errors.Is(err, sql.ErrNoRows) {
118 return fmt.Errorf("failed to query schema version: %s", err)
119 }
120
121 if version == len(postgresMigrations) {
122 return nil
123 }
124 if version > len(postgresMigrations) {
125 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
126 }
127
128 if version == 0 {
129 if _, err := tx.Exec(postgresSchema); err != nil {
130 return fmt.Errorf("failed to initialize schema: %s", err)
131 }
132 } else {
133 for i := version; i < len(postgresMigrations); i++ {
134 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
135 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
136 }
137 }
138 }
139
140 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
141 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
142 if err != nil {
143 return fmt.Errorf("failed to bump schema version: %v", err)
144 }
145
146 return tx.Commit()
147}
148
149func (db *PostgresDB) Close() error {
150 return db.db.Close()
151}
152
[712]153func (db *PostgresDB) MetricsCollector() prometheus.Collector {
154 return promcollectors.NewDBStatsCollector(db.db, "main")
155}
156
[652]157func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
158 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]159 defer cancel()
160
[620]161 var stats DatabaseStats
[645]162 row := db.db.QueryRowContext(ctx, `SELECT
[620]163 (SELECT COUNT(*) FROM "User") AS users,
164 (SELECT COUNT(*) FROM "Network") AS networks,
165 (SELECT COUNT(*) FROM "Channel") AS channels`)
166 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
167 return nil, err
168 }
169
170 return &stats, nil
171}
172
[652]173func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
174 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]175 defer cancel()
176
177 rows, err := db.db.QueryContext(ctx,
178 `SELECT id, username, password, admin, realname FROM "User"`)
[620]179 if err != nil {
180 return nil, err
181 }
182 defer rows.Close()
183
184 var users []User
185 for rows.Next() {
186 var user User
187 var password, realname sql.NullString
188 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
189 return nil, err
190 }
191 user.Password = password.String
192 user.Realname = realname.String
193 users = append(users, user)
194 }
195 if err := rows.Err(); err != nil {
196 return nil, err
197 }
198
199 return users, nil
200}
201
[652]202func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
203 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]204 defer cancel()
205
[620]206 user := &User{Username: username}
207
208 var password, realname sql.NullString
[645]209 row := db.db.QueryRowContext(ctx,
[620]210 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
211 username)
212 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
213 return nil, err
214 }
215 user.Password = password.String
216 user.Realname = realname.String
217 return user, nil
218}
219
[652]220func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
221 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]222 defer cancel()
223
[620]224 password := toNullString(user.Password)
225 realname := toNullString(user.Realname)
[635]226
227 var err error
228 if user.ID == 0 {
[645]229 err = db.db.QueryRowContext(ctx, `
[635]230 INSERT INTO "User" (username, password, admin, realname)
231 VALUES ($1, $2, $3, $4)
232 RETURNING id`,
233 user.Username, password, user.Admin, realname).Scan(&user.ID)
234 } else {
[645]235 _, err = db.db.ExecContext(ctx, `
[635]236 UPDATE "User"
237 SET password = $1, admin = $2, realname = $3
238 WHERE id = $4`,
239 password, user.Admin, realname, user.ID)
240 }
[620]241 return err
242}
243
[652]244func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
245 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]246 defer cancel()
247
248 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
[620]249 return err
250}
251
[652]252func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
253 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]254 defer cancel()
255
256 rows, err := db.db.QueryContext(ctx, `
[620]257 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
258 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
259 FROM "Network"
260 WHERE "user" = $1`, userID)
261 if err != nil {
262 return nil, err
263 }
264 defer rows.Close()
265
266 var networks []Network
267 for rows.Next() {
268 var net Network
[664]269 var name, nick, username, realname, pass, connectCommands sql.NullString
[620]270 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
[664]271 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
[620]272 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
273 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
274 if err != nil {
275 return nil, err
276 }
277 net.Name = name.String
[664]278 net.Nick = nick.String
[620]279 net.Username = username.String
280 net.Realname = realname.String
281 net.Pass = pass.String
282 if connectCommands.Valid {
283 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
284 }
285 net.SASL.Mechanism = saslMechanism.String
286 net.SASL.Plain.Username = saslPlainUsername.String
287 net.SASL.Plain.Password = saslPlainPassword.String
288 networks = append(networks, net)
289 }
290 if err := rows.Err(); err != nil {
291 return nil, err
292 }
293
294 return networks, nil
295}
296
[652]297func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
298 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]299 defer cancel()
300
[620]301 netName := toNullString(network.Name)
[664]302 nick := toNullString(network.Nick)
[620]303 netUsername := toNullString(network.Username)
304 realname := toNullString(network.Realname)
305 pass := toNullString(network.Pass)
306 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
307
308 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
309 if network.SASL.Mechanism != "" {
310 saslMechanism = toNullString(network.SASL.Mechanism)
311 switch network.SASL.Mechanism {
312 case "PLAIN":
313 saslPlainUsername = toNullString(network.SASL.Plain.Username)
314 saslPlainPassword = toNullString(network.SASL.Plain.Password)
315 network.SASL.External.CertBlob = nil
316 network.SASL.External.PrivKeyBlob = nil
317 case "EXTERNAL":
318 // keep saslPlain* nil
319 default:
320 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
321 }
322 }
323
[635]324 var err error
325 if network.ID == 0 {
[645]326 err = db.db.QueryRowContext(ctx, `
[635]327 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
328 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
329 sasl_external_key, enabled)
330 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
331 RETURNING id`,
[664]332 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]333 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
334 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
335 } else {
[645]336 _, err = db.db.ExecContext(ctx, `
[635]337 UPDATE "Network"
338 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
339 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
340 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
341 enabled = $14
342 WHERE id = $1`,
[664]343 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]344 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
345 network.SASL.External.PrivKeyBlob, network.Enabled)
346 }
[620]347 return err
348}
349
[652]350func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
351 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]352 defer cancel()
353
354 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
[620]355 return err
356}
357
[652]358func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
359 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]360 defer cancel()
361
362 rows, err := db.db.QueryContext(ctx, `
[620]363 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
364 detach_on
365 FROM "Channel"
366 WHERE network = $1`, networkID)
367 if err != nil {
368 return nil, err
369 }
370 defer rows.Close()
371
372 var channels []Channel
373 for rows.Next() {
374 var ch Channel
375 var key, detachedInternalMsgID sql.NullString
376 var detachAfter int64
377 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
378 return nil, err
379 }
380 ch.Key = key.String
381 ch.DetachedInternalMsgID = detachedInternalMsgID.String
382 ch.DetachAfter = time.Duration(detachAfter) * time.Second
383 channels = append(channels, ch)
384 }
385 if err := rows.Err(); err != nil {
386 return nil, err
387 }
388
389 return channels, nil
390}
391
[652]392func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
393 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]394 defer cancel()
395
[620]396 key := toNullString(ch.Key)
397 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
[635]398
399 var err error
400 if ch.ID == 0 {
[645]401 err = db.db.QueryRowContext(ctx, `
[635]402 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
403 detach_after, detach_on)
404 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
405 RETURNING id`,
406 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
407 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
408 } else {
[645]409 _, err = db.db.ExecContext(ctx, `
[635]410 UPDATE "Channel"
411 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
412 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
413 WHERE id = $1`,
414 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
415 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
416 }
[620]417 return err
418}
419
[652]420func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
421 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]422 defer cancel()
423
424 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
[620]425 return err
426}
427
[652]428func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
429 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]430 defer cancel()
431
432 rows, err := db.db.QueryContext(ctx, `
[620]433 SELECT id, target, client, internal_msgid
434 FROM "DeliveryReceipt"
435 WHERE network = $1`, networkID)
436 if err != nil {
437 return nil, err
438 }
439 defer rows.Close()
440
441 var receipts []DeliveryReceipt
442 for rows.Next() {
443 var rcpt DeliveryReceipt
444 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
445 return nil, err
446 }
447 receipts = append(receipts, rcpt)
448 }
449 if err := rows.Err(); err != nil {
450 return nil, err
451 }
452
453 return receipts, nil
454}
455
[652]456func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
457 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]458 defer cancel()
459
[635]460 tx, err := db.db.Begin()
461 if err != nil {
462 return err
463 }
464 defer tx.Rollback()
465
[645]466 _, err = tx.ExecContext(ctx,
467 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
[635]468 networkID, client)
469 if err != nil {
470 return err
471 }
472
[645]473 stmt, err := tx.PrepareContext(ctx, `
[620]474 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
475 VALUES ($1, $2, $3, $4)
476 RETURNING id`)
477 if err != nil {
478 return err
479 }
480 defer stmt.Close()
481
482 for i := range receipts {
483 rcpt := &receipts[i]
[645]484 err := stmt.
485 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
486 Scan(&rcpt.ID)
[620]487 if err != nil {
488 return err
489 }
490 }
[635]491
492 return tx.Commit()
[620]493}
Note: See TracBrowser for help on using the repository browser.