source: code/trunk/db_postgres.go@ 777

Last change on this file since 777 was 774, checked in by contact, 3 years ago

db_postgres: use enum for sasl_mechanism

Ensures only supported mechanisms get stored in the DB.

File size: 14.3 KB
RevLine 
[620]1package soju
2
3import (
[645]4 "context"
[620]5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
[712]13 "github.com/prometheus/client_golang/prometheus"
14 promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
[620]15)
16
const (
	// postgresQueryTimeout bounds every context-taking PostgresDB method:
	// each one derives its query context with this deadline.
	postgresQueryTimeout = 5 * time.Second

	// postgresConfigSchema creates the "Config" table that records the schema
	// version. It is idempotent (IF NOT EXISTS) and is executed at the start of
	// every upgrade. PRIMARY KEY plus CHECK(id = 1) allows at most one row.
	postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`

	// postgresSchema is the complete, current schema used to initialize an
	// empty database. It already contains the effect of every entry in
	// postgresMigrations (nullable Network.nick, the sasl_mechanism enum), so
	// a freshly initialized database does not run any migration.
	postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism sasl_mechanism,
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
)
81
// postgresMigrations lists the schema migrations in application order.
//
// The version stored in "Config" is the number of entries whose effects are
// already present, so upgrade runs postgresMigrations[version:] against an
// existing database and then stores len(postgresMigrations). Because only that
// count is recorded, existing entries must never be edited or reordered: append
// new migrations at the end and update postgresSchema to match.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: make Network.nick nullable.
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
	// #2: convert Network.sasl_mechanism to the sasl_mechanism enum so that only
	// supported mechanisms can be stored. The USING cast fails, aborting the
	// upgrade transaction, if an existing row holds any other value.
	`
	CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
	ALTER TABLE "Network"
		ALTER COLUMN sasl_mechanism
		TYPE sasl_mechanism
		USING sasl_mechanism::sasl_mechanism;
	`,
}
93
94type PostgresDB struct {
95 db *sql.DB
96}
97
98func OpenPostgresDB(source string) (Database, error) {
99 sqlPostgresDB, err := sql.Open("postgres", source)
100 if err != nil {
101 return nil, err
102 }
103
104 db := &PostgresDB{db: sqlPostgresDB}
105 if err := db.upgrade(); err != nil {
106 sqlPostgresDB.Close()
107 return nil, err
108 }
109
110 return db, nil
111}
112
113func (db *PostgresDB) upgrade() error {
114 tx, err := db.db.Begin()
115 if err != nil {
116 return err
117 }
118 defer tx.Rollback()
119
120 if _, err := tx.Exec(postgresConfigSchema); err != nil {
121 return fmt.Errorf("failed to create Config table: %s", err)
122 }
123
124 var version int
125 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
126 if err != nil && !errors.Is(err, sql.ErrNoRows) {
127 return fmt.Errorf("failed to query schema version: %s", err)
128 }
129
130 if version == len(postgresMigrations) {
131 return nil
132 }
133 if version > len(postgresMigrations) {
134 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
135 }
136
137 if version == 0 {
138 if _, err := tx.Exec(postgresSchema); err != nil {
139 return fmt.Errorf("failed to initialize schema: %s", err)
140 }
141 } else {
142 for i := version; i < len(postgresMigrations); i++ {
143 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
144 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
145 }
146 }
147 }
148
149 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
150 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
151 if err != nil {
152 return fmt.Errorf("failed to bump schema version: %v", err)
153 }
154
155 return tx.Commit()
156}
157
158func (db *PostgresDB) Close() error {
159 return db.db.Close()
160}
161
[712]162func (db *PostgresDB) MetricsCollector() prometheus.Collector {
163 return promcollectors.NewDBStatsCollector(db.db, "main")
164}
165
[652]166func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
167 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]168 defer cancel()
169
[620]170 var stats DatabaseStats
[645]171 row := db.db.QueryRowContext(ctx, `SELECT
[620]172 (SELECT COUNT(*) FROM "User") AS users,
173 (SELECT COUNT(*) FROM "Network") AS networks,
174 (SELECT COUNT(*) FROM "Channel") AS channels`)
175 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
176 return nil, err
177 }
178
179 return &stats, nil
180}
181
[652]182func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
183 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]184 defer cancel()
185
186 rows, err := db.db.QueryContext(ctx,
187 `SELECT id, username, password, admin, realname FROM "User"`)
[620]188 if err != nil {
189 return nil, err
190 }
191 defer rows.Close()
192
193 var users []User
194 for rows.Next() {
195 var user User
196 var password, realname sql.NullString
197 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
198 return nil, err
199 }
200 user.Password = password.String
201 user.Realname = realname.String
202 users = append(users, user)
203 }
204 if err := rows.Err(); err != nil {
205 return nil, err
206 }
207
208 return users, nil
209}
210
[652]211func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
212 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]213 defer cancel()
214
[620]215 user := &User{Username: username}
216
217 var password, realname sql.NullString
[645]218 row := db.db.QueryRowContext(ctx,
[620]219 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
220 username)
221 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
222 return nil, err
223 }
224 user.Password = password.String
225 user.Realname = realname.String
226 return user, nil
227}
228
[652]229func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
230 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]231 defer cancel()
232
[620]233 password := toNullString(user.Password)
234 realname := toNullString(user.Realname)
[635]235
236 var err error
237 if user.ID == 0 {
[645]238 err = db.db.QueryRowContext(ctx, `
[635]239 INSERT INTO "User" (username, password, admin, realname)
240 VALUES ($1, $2, $3, $4)
241 RETURNING id`,
242 user.Username, password, user.Admin, realname).Scan(&user.ID)
243 } else {
[645]244 _, err = db.db.ExecContext(ctx, `
[635]245 UPDATE "User"
246 SET password = $1, admin = $2, realname = $3
247 WHERE id = $4`,
248 password, user.Admin, realname, user.ID)
249 }
[620]250 return err
251}
252
[652]253func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
254 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]255 defer cancel()
256
257 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
[620]258 return err
259}
260
[652]261func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
262 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]263 defer cancel()
264
265 rows, err := db.db.QueryContext(ctx, `
[620]266 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
267 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
268 FROM "Network"
269 WHERE "user" = $1`, userID)
270 if err != nil {
271 return nil, err
272 }
273 defer rows.Close()
274
275 var networks []Network
276 for rows.Next() {
277 var net Network
[664]278 var name, nick, username, realname, pass, connectCommands sql.NullString
[620]279 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
[664]280 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
[620]281 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
282 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
283 if err != nil {
284 return nil, err
285 }
286 net.Name = name.String
[664]287 net.Nick = nick.String
[620]288 net.Username = username.String
289 net.Realname = realname.String
290 net.Pass = pass.String
291 if connectCommands.Valid {
292 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
293 }
294 net.SASL.Mechanism = saslMechanism.String
295 net.SASL.Plain.Username = saslPlainUsername.String
296 net.SASL.Plain.Password = saslPlainPassword.String
297 networks = append(networks, net)
298 }
299 if err := rows.Err(); err != nil {
300 return nil, err
301 }
302
303 return networks, nil
304}
305
[652]306func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
307 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]308 defer cancel()
309
[620]310 netName := toNullString(network.Name)
[664]311 nick := toNullString(network.Nick)
[620]312 netUsername := toNullString(network.Username)
313 realname := toNullString(network.Realname)
314 pass := toNullString(network.Pass)
315 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
316
317 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
318 if network.SASL.Mechanism != "" {
319 saslMechanism = toNullString(network.SASL.Mechanism)
320 switch network.SASL.Mechanism {
321 case "PLAIN":
322 saslPlainUsername = toNullString(network.SASL.Plain.Username)
323 saslPlainPassword = toNullString(network.SASL.Plain.Password)
324 network.SASL.External.CertBlob = nil
325 network.SASL.External.PrivKeyBlob = nil
326 case "EXTERNAL":
327 // keep saslPlain* nil
328 default:
329 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
330 }
331 }
332
[635]333 var err error
334 if network.ID == 0 {
[645]335 err = db.db.QueryRowContext(ctx, `
[635]336 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
337 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
338 sasl_external_key, enabled)
339 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
340 RETURNING id`,
[664]341 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]342 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
343 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
344 } else {
[645]345 _, err = db.db.ExecContext(ctx, `
[635]346 UPDATE "Network"
347 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
348 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
349 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
350 enabled = $14
351 WHERE id = $1`,
[664]352 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]353 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
354 network.SASL.External.PrivKeyBlob, network.Enabled)
355 }
[620]356 return err
357}
358
[652]359func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
360 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]361 defer cancel()
362
363 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
[620]364 return err
365}
366
[652]367func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
368 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]369 defer cancel()
370
371 rows, err := db.db.QueryContext(ctx, `
[620]372 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
373 detach_on
374 FROM "Channel"
375 WHERE network = $1`, networkID)
376 if err != nil {
377 return nil, err
378 }
379 defer rows.Close()
380
381 var channels []Channel
382 for rows.Next() {
383 var ch Channel
384 var key, detachedInternalMsgID sql.NullString
385 var detachAfter int64
386 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
387 return nil, err
388 }
389 ch.Key = key.String
390 ch.DetachedInternalMsgID = detachedInternalMsgID.String
391 ch.DetachAfter = time.Duration(detachAfter) * time.Second
392 channels = append(channels, ch)
393 }
394 if err := rows.Err(); err != nil {
395 return nil, err
396 }
397
398 return channels, nil
399}
400
[652]401func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
402 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]403 defer cancel()
404
[620]405 key := toNullString(ch.Key)
406 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
[635]407
408 var err error
409 if ch.ID == 0 {
[645]410 err = db.db.QueryRowContext(ctx, `
[635]411 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
412 detach_after, detach_on)
413 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
414 RETURNING id`,
415 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
416 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
417 } else {
[645]418 _, err = db.db.ExecContext(ctx, `
[635]419 UPDATE "Channel"
420 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
421 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
422 WHERE id = $1`,
423 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
424 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
425 }
[620]426 return err
427}
428
[652]429func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
430 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]431 defer cancel()
432
433 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
[620]434 return err
435}
436
[652]437func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
438 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]439 defer cancel()
440
441 rows, err := db.db.QueryContext(ctx, `
[620]442 SELECT id, target, client, internal_msgid
443 FROM "DeliveryReceipt"
444 WHERE network = $1`, networkID)
445 if err != nil {
446 return nil, err
447 }
448 defer rows.Close()
449
450 var receipts []DeliveryReceipt
451 for rows.Next() {
452 var rcpt DeliveryReceipt
453 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
454 return nil, err
455 }
456 receipts = append(receipts, rcpt)
457 }
458 if err := rows.Err(); err != nil {
459 return nil, err
460 }
461
462 return receipts, nil
463}
464
[652]465func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
466 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]467 defer cancel()
468
[635]469 tx, err := db.db.Begin()
470 if err != nil {
471 return err
472 }
473 defer tx.Rollback()
474
[645]475 _, err = tx.ExecContext(ctx,
476 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
[635]477 networkID, client)
478 if err != nil {
479 return err
480 }
481
[645]482 stmt, err := tx.PrepareContext(ctx, `
[620]483 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
484 VALUES ($1, $2, $3, $4)
485 RETURNING id`)
486 if err != nil {
487 return err
488 }
489 defer stmt.Close()
490
491 for i := range receipts {
492 rcpt := &receipts[i]
[645]493 err := stmt.
494 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
495 Scan(&rcpt.ID)
[620]496 if err != nil {
497 return err
498 }
499 }
[635]500
501 return tx.Commit()
[620]502}
Note: See TracBrowser for help on using the repository browser.