source: code/trunk/db_postgres.go@ 779

Last change on this file since 779 was 774, checked in by contact, 3 years ago

db_postgres: use enum for sasl_mechanism

Ensures only supported mechanisms get stored in the DB.

File size: 14.3 KB
Line 
1package soju
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13 "github.com/prometheus/client_golang/prometheus"
14 promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
15)
16
// postgresQueryTimeout is the deadline applied, via context.WithTimeout, to
// the queries issued by the context-aware PostgresDB methods in this file.
const postgresQueryTimeout = 5 * time.Second
18
// postgresConfigSchema creates the "Config" table if it does not exist yet.
// The table holds the schema version; the CHECK constraint only admits a row
// with id = 1, so at most one row can exist.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
26
// postgresSchema is the complete schema, executed by upgrade when no schema
// version has been recorded yet (version 0). A freshly initialized database
// is stamped with the latest version without running postgresMigrations, so
// this schema must already include the effect of every migration.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism sasl_mechanism,
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
81
// postgresMigrations lists the incremental schema migrations. The entry at
// index i upgrades a database from schema version i to version i+1; upgrade
// runs every entry from the stored version to the end of the slice, so the
// current schema version is len(postgresMigrations). Entries must only ever
// be appended.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// Version 1 -> 2: allow "Network".nick to be NULL.
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
	// Version 2 -> 3: introduce the sasl_mechanism enum and convert the
	// existing "Network".sasl_mechanism column to it.
	`
		CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
		ALTER TABLE "Network"
			ALTER COLUMN sasl_mechanism
			TYPE sasl_mechanism
			USING sasl_mechanism::sasl_mechanism;
	`,
}
93
// PostgresDB is a database backend built on a database/sql handle. Instances
// are created by OpenPostgresDB, which returns them as a Database.
type PostgresDB struct {
	db *sql.DB // underlying connection pool, opened with the "postgres" driver
}
97
98func OpenPostgresDB(source string) (Database, error) {
99 sqlPostgresDB, err := sql.Open("postgres", source)
100 if err != nil {
101 return nil, err
102 }
103
104 db := &PostgresDB{db: sqlPostgresDB}
105 if err := db.upgrade(); err != nil {
106 sqlPostgresDB.Close()
107 return nil, err
108 }
109
110 return db, nil
111}
112
113func (db *PostgresDB) upgrade() error {
114 tx, err := db.db.Begin()
115 if err != nil {
116 return err
117 }
118 defer tx.Rollback()
119
120 if _, err := tx.Exec(postgresConfigSchema); err != nil {
121 return fmt.Errorf("failed to create Config table: %s", err)
122 }
123
124 var version int
125 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
126 if err != nil && !errors.Is(err, sql.ErrNoRows) {
127 return fmt.Errorf("failed to query schema version: %s", err)
128 }
129
130 if version == len(postgresMigrations) {
131 return nil
132 }
133 if version > len(postgresMigrations) {
134 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
135 }
136
137 if version == 0 {
138 if _, err := tx.Exec(postgresSchema); err != nil {
139 return fmt.Errorf("failed to initialize schema: %s", err)
140 }
141 } else {
142 for i := version; i < len(postgresMigrations); i++ {
143 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
144 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
145 }
146 }
147 }
148
149 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
150 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
151 if err != nil {
152 return fmt.Errorf("failed to bump schema version: %v", err)
153 }
154
155 return tx.Commit()
156}
157
158func (db *PostgresDB) Close() error {
159 return db.db.Close()
160}
161
162func (db *PostgresDB) MetricsCollector() prometheus.Collector {
163 return promcollectors.NewDBStatsCollector(db.db, "main")
164}
165
166func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
167 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
168 defer cancel()
169
170 var stats DatabaseStats
171 row := db.db.QueryRowContext(ctx, `SELECT
172 (SELECT COUNT(*) FROM "User") AS users,
173 (SELECT COUNT(*) FROM "Network") AS networks,
174 (SELECT COUNT(*) FROM "Channel") AS channels`)
175 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
176 return nil, err
177 }
178
179 return &stats, nil
180}
181
182func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
183 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
184 defer cancel()
185
186 rows, err := db.db.QueryContext(ctx,
187 `SELECT id, username, password, admin, realname FROM "User"`)
188 if err != nil {
189 return nil, err
190 }
191 defer rows.Close()
192
193 var users []User
194 for rows.Next() {
195 var user User
196 var password, realname sql.NullString
197 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
198 return nil, err
199 }
200 user.Password = password.String
201 user.Realname = realname.String
202 users = append(users, user)
203 }
204 if err := rows.Err(); err != nil {
205 return nil, err
206 }
207
208 return users, nil
209}
210
211func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
212 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
213 defer cancel()
214
215 user := &User{Username: username}
216
217 var password, realname sql.NullString
218 row := db.db.QueryRowContext(ctx,
219 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
220 username)
221 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
222 return nil, err
223 }
224 user.Password = password.String
225 user.Realname = realname.String
226 return user, nil
227}
228
229func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
230 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
231 defer cancel()
232
233 password := toNullString(user.Password)
234 realname := toNullString(user.Realname)
235
236 var err error
237 if user.ID == 0 {
238 err = db.db.QueryRowContext(ctx, `
239 INSERT INTO "User" (username, password, admin, realname)
240 VALUES ($1, $2, $3, $4)
241 RETURNING id`,
242 user.Username, password, user.Admin, realname).Scan(&user.ID)
243 } else {
244 _, err = db.db.ExecContext(ctx, `
245 UPDATE "User"
246 SET password = $1, admin = $2, realname = $3
247 WHERE id = $4`,
248 password, user.Admin, realname, user.ID)
249 }
250 return err
251}
252
253func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
254 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
255 defer cancel()
256
257 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
258 return err
259}
260
261func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
262 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
263 defer cancel()
264
265 rows, err := db.db.QueryContext(ctx, `
266 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
267 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
268 FROM "Network"
269 WHERE "user" = $1`, userID)
270 if err != nil {
271 return nil, err
272 }
273 defer rows.Close()
274
275 var networks []Network
276 for rows.Next() {
277 var net Network
278 var name, nick, username, realname, pass, connectCommands sql.NullString
279 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
280 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
281 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
282 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
283 if err != nil {
284 return nil, err
285 }
286 net.Name = name.String
287 net.Nick = nick.String
288 net.Username = username.String
289 net.Realname = realname.String
290 net.Pass = pass.String
291 if connectCommands.Valid {
292 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
293 }
294 net.SASL.Mechanism = saslMechanism.String
295 net.SASL.Plain.Username = saslPlainUsername.String
296 net.SASL.Plain.Password = saslPlainPassword.String
297 networks = append(networks, net)
298 }
299 if err := rows.Err(); err != nil {
300 return nil, err
301 }
302
303 return networks, nil
304}
305
306func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
307 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
308 defer cancel()
309
310 netName := toNullString(network.Name)
311 nick := toNullString(network.Nick)
312 netUsername := toNullString(network.Username)
313 realname := toNullString(network.Realname)
314 pass := toNullString(network.Pass)
315 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
316
317 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
318 if network.SASL.Mechanism != "" {
319 saslMechanism = toNullString(network.SASL.Mechanism)
320 switch network.SASL.Mechanism {
321 case "PLAIN":
322 saslPlainUsername = toNullString(network.SASL.Plain.Username)
323 saslPlainPassword = toNullString(network.SASL.Plain.Password)
324 network.SASL.External.CertBlob = nil
325 network.SASL.External.PrivKeyBlob = nil
326 case "EXTERNAL":
327 // keep saslPlain* nil
328 default:
329 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
330 }
331 }
332
333 var err error
334 if network.ID == 0 {
335 err = db.db.QueryRowContext(ctx, `
336 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
337 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
338 sasl_external_key, enabled)
339 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
340 RETURNING id`,
341 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
342 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
343 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
344 } else {
345 _, err = db.db.ExecContext(ctx, `
346 UPDATE "Network"
347 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
348 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
349 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
350 enabled = $14
351 WHERE id = $1`,
352 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
353 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
354 network.SASL.External.PrivKeyBlob, network.Enabled)
355 }
356 return err
357}
358
359func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
360 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
361 defer cancel()
362
363 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
364 return err
365}
366
367func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
368 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
369 defer cancel()
370
371 rows, err := db.db.QueryContext(ctx, `
372 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
373 detach_on
374 FROM "Channel"
375 WHERE network = $1`, networkID)
376 if err != nil {
377 return nil, err
378 }
379 defer rows.Close()
380
381 var channels []Channel
382 for rows.Next() {
383 var ch Channel
384 var key, detachedInternalMsgID sql.NullString
385 var detachAfter int64
386 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
387 return nil, err
388 }
389 ch.Key = key.String
390 ch.DetachedInternalMsgID = detachedInternalMsgID.String
391 ch.DetachAfter = time.Duration(detachAfter) * time.Second
392 channels = append(channels, ch)
393 }
394 if err := rows.Err(); err != nil {
395 return nil, err
396 }
397
398 return channels, nil
399}
400
401func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
402 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
403 defer cancel()
404
405 key := toNullString(ch.Key)
406 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
407
408 var err error
409 if ch.ID == 0 {
410 err = db.db.QueryRowContext(ctx, `
411 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
412 detach_after, detach_on)
413 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
414 RETURNING id`,
415 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
416 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
417 } else {
418 _, err = db.db.ExecContext(ctx, `
419 UPDATE "Channel"
420 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
421 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
422 WHERE id = $1`,
423 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
424 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
425 }
426 return err
427}
428
429func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
430 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
431 defer cancel()
432
433 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
434 return err
435}
436
437func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
438 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
439 defer cancel()
440
441 rows, err := db.db.QueryContext(ctx, `
442 SELECT id, target, client, internal_msgid
443 FROM "DeliveryReceipt"
444 WHERE network = $1`, networkID)
445 if err != nil {
446 return nil, err
447 }
448 defer rows.Close()
449
450 var receipts []DeliveryReceipt
451 for rows.Next() {
452 var rcpt DeliveryReceipt
453 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
454 return nil, err
455 }
456 receipts = append(receipts, rcpt)
457 }
458 if err := rows.Err(); err != nil {
459 return nil, err
460 }
461
462 return receipts, nil
463}
464
465func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
466 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
467 defer cancel()
468
469 tx, err := db.db.Begin()
470 if err != nil {
471 return err
472 }
473 defer tx.Rollback()
474
475 _, err = tx.ExecContext(ctx,
476 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
477 networkID, client)
478 if err != nil {
479 return err
480 }
481
482 stmt, err := tx.PrepareContext(ctx, `
483 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
484 VALUES ($1, $2, $3, $4)
485 RETURNING id`)
486 if err != nil {
487 return err
488 }
489 defer stmt.Close()
490
491 for i := range receipts {
492 rcpt := &receipts[i]
493 err := stmt.
494 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
495 Scan(&rcpt.ID)
496 if err != nil {
497 return err
498 }
499 }
500
501 return tx.Commit()
502}
Note: See TracBrowser for help on using the repository browser.