source: code/trunk/db_postgres.go@ 811

Last change on this file since 811 was 804, checked in by koizumi.aoi, 2 years ago

Drunk as I like

Signed-off-by: Aoi K <koizumi.aoi@…>

File size: 15.9 KB
Line 
1package suika
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13 "github.com/prometheus/client_golang/prometheus"
14 promcollectors "github.com/prometheus/client_golang/prometheus/collectors"
15)
16
// postgresQueryTimeout is the deadline each context-taking PostgresDB method
// layers on top of the caller's ctx via context.WithTimeout.
const postgresQueryTimeout time.Duration = 5 * time.Second
18
// postgresConfigSchema creates the "Config" table that stores the schema
// version read and written by upgrade. The primary key together with
// CHECK(id = 1) restricts the table to at most one row. IF NOT EXISTS makes
// the statement safe to execute on every open.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
26
// postgresSchema is the complete schema used to initialize a fresh database
// (schema version 0). After running it, upgrade records
// len(postgresMigrations) as the version without running any migration, so
// this schema must already include the effect of every entry in
// postgresMigrations.
//
// Child tables reference their parent with ON DELETE CASCADE: deleting a User
// removes its Networks, and deleting a Network removes its Channels,
// DeliveryReceipts and ReadReceipts.
//
// NOTE(review): Network.nick is nullable, and Postgres treats NULLs as
// distinct in UNIQUE constraints, so UNIQUE("user", addr, nick) does not stop
// two rows with the same user/addr and a NULL nick. Confirm that is intended.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism sasl_mechanism,
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);

CREATE TABLE "ReadReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
	UNIQUE(network, target)
);
`
89
// postgresMigrations holds the schema changes for databases created by an
// older release. The stored schema version is the number of entries already
// applied. upgrade runs entries [version, len) in order, so entry i takes a
// database from version i to i+1. New migrations must only be appended, and
// postgresSchema must be updated to match.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: allow networks without an explicit nick.
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
	// #2: introduce the sasl_mechanism enum and convert
	// Network.sasl_mechanism to it.
	`
		CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
		ALTER TABLE "Network"
			ALTER COLUMN sasl_mechanism
			TYPE sasl_mechanism
			USING sasl_mechanism::sasl_mechanism;
	`,
	// #3: add the ReadReceipt table.
	`
		CREATE TABLE "ReadReceipt" (
			id SERIAL PRIMARY KEY,
			network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
			target VARCHAR(255) NOT NULL,
			timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
			UNIQUE(network, target)
		);
	`,
}
110
// PostgresDB is the PostgreSQL-backed Database returned by OpenPostgresDB.
type PostgresDB struct {
	// db is the database/sql handle opened with the "postgres" (lib/pq)
	// driver; every method issues its queries through it.
	db *sql.DB
}
114
115func OpenPostgresDB(source string) (Database, error) {
116 sqlPostgresDB, err := sql.Open("postgres", source)
117 if err != nil {
118 return nil, err
119 }
120
121 db := &PostgresDB{db: sqlPostgresDB}
122 if err := db.upgrade(); err != nil {
123 sqlPostgresDB.Close()
124 return nil, err
125 }
126
127 return db, nil
128}
129
130func (db *PostgresDB) upgrade() error {
131 tx, err := db.db.Begin()
132 if err != nil {
133 return err
134 }
135 defer tx.Rollback()
136
137 if _, err := tx.Exec(postgresConfigSchema); err != nil {
138 return fmt.Errorf("failed to create Config table: %s", err)
139 }
140
141 var version int
142 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
143 if err != nil && !errors.Is(err, sql.ErrNoRows) {
144 return fmt.Errorf("failed to query schema version: %s", err)
145 }
146
147 if version == len(postgresMigrations) {
148 return nil
149 }
150 if version > len(postgresMigrations) {
151 return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version)
152 }
153
154 if version == 0 {
155 if _, err := tx.Exec(postgresSchema); err != nil {
156 return fmt.Errorf("failed to initialize schema: %s", err)
157 }
158 } else {
159 for i := version; i < len(postgresMigrations); i++ {
160 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
161 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
162 }
163 }
164 }
165
166 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
167 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
168 if err != nil {
169 return fmt.Errorf("failed to bump schema version: %v", err)
170 }
171
172 return tx.Commit()
173}
174
175func (db *PostgresDB) Close() error {
176 return db.db.Close()
177}
178
179func (db *PostgresDB) MetricsCollector() prometheus.Collector {
180 return promcollectors.NewDBStatsCollector(db.db, "main")
181}
182
183func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
184 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
185 defer cancel()
186
187 var stats DatabaseStats
188 row := db.db.QueryRowContext(ctx, `SELECT
189 (SELECT COUNT(*) FROM "User") AS users,
190 (SELECT COUNT(*) FROM "Network") AS networks,
191 (SELECT COUNT(*) FROM "Channel") AS channels`)
192 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
193 return nil, err
194 }
195
196 return &stats, nil
197}
198
199func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
200 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
201 defer cancel()
202
203 rows, err := db.db.QueryContext(ctx,
204 `SELECT id, username, password, admin, realname FROM "User"`)
205 if err != nil {
206 return nil, err
207 }
208 defer rows.Close()
209
210 var users []User
211 for rows.Next() {
212 var user User
213 var password, realname sql.NullString
214 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
215 return nil, err
216 }
217 user.Password = password.String
218 user.Realname = realname.String
219 users = append(users, user)
220 }
221 if err := rows.Err(); err != nil {
222 return nil, err
223 }
224
225 return users, nil
226}
227
228func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
229 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
230 defer cancel()
231
232 user := &User{Username: username}
233
234 var password, realname sql.NullString
235 row := db.db.QueryRowContext(ctx,
236 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
237 username)
238 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
239 return nil, err
240 }
241 user.Password = password.String
242 user.Realname = realname.String
243 return user, nil
244}
245
246func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
247 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
248 defer cancel()
249
250 password := toNullString(user.Password)
251 realname := toNullString(user.Realname)
252
253 var err error
254 if user.ID == 0 {
255 err = db.db.QueryRowContext(ctx, `
256 INSERT INTO "User" (username, password, admin, realname)
257 VALUES ($1, $2, $3, $4)
258 RETURNING id`,
259 user.Username, password, user.Admin, realname).Scan(&user.ID)
260 } else {
261 _, err = db.db.ExecContext(ctx, `
262 UPDATE "User"
263 SET password = $1, admin = $2, realname = $3
264 WHERE id = $4`,
265 password, user.Admin, realname, user.ID)
266 }
267 return err
268}
269
270func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
271 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
272 defer cancel()
273
274 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
275 return err
276}
277
278func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
279 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
280 defer cancel()
281
282 rows, err := db.db.QueryContext(ctx, `
283 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
284 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
285 FROM "Network"
286 WHERE "user" = $1`, userID)
287 if err != nil {
288 return nil, err
289 }
290 defer rows.Close()
291
292 var networks []Network
293 for rows.Next() {
294 var net Network
295 var name, nick, username, realname, pass, connectCommands sql.NullString
296 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
297 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
298 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
299 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
300 if err != nil {
301 return nil, err
302 }
303 net.Name = name.String
304 net.Nick = nick.String
305 net.Username = username.String
306 net.Realname = realname.String
307 net.Pass = pass.String
308 if connectCommands.Valid {
309 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
310 }
311 net.SASL.Mechanism = saslMechanism.String
312 net.SASL.Plain.Username = saslPlainUsername.String
313 net.SASL.Plain.Password = saslPlainPassword.String
314 networks = append(networks, net)
315 }
316 if err := rows.Err(); err != nil {
317 return nil, err
318 }
319
320 return networks, nil
321}
322
323func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
324 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
325 defer cancel()
326
327 netName := toNullString(network.Name)
328 nick := toNullString(network.Nick)
329 netUsername := toNullString(network.Username)
330 realname := toNullString(network.Realname)
331 pass := toNullString(network.Pass)
332 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
333
334 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
335 if network.SASL.Mechanism != "" {
336 saslMechanism = toNullString(network.SASL.Mechanism)
337 switch network.SASL.Mechanism {
338 case "PLAIN":
339 saslPlainUsername = toNullString(network.SASL.Plain.Username)
340 saslPlainPassword = toNullString(network.SASL.Plain.Password)
341 network.SASL.External.CertBlob = nil
342 network.SASL.External.PrivKeyBlob = nil
343 case "EXTERNAL":
344 // keep saslPlain* nil
345 default:
346 return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
347 }
348 }
349
350 var err error
351 if network.ID == 0 {
352 err = db.db.QueryRowContext(ctx, `
353 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
354 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
355 sasl_external_key, enabled)
356 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
357 RETURNING id`,
358 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
359 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
360 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
361 } else {
362 _, err = db.db.ExecContext(ctx, `
363 UPDATE "Network"
364 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
365 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
366 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
367 enabled = $14
368 WHERE id = $1`,
369 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
370 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
371 network.SASL.External.PrivKeyBlob, network.Enabled)
372 }
373 return err
374}
375
376func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
377 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
378 defer cancel()
379
380 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
381 return err
382}
383
384func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
385 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
386 defer cancel()
387
388 rows, err := db.db.QueryContext(ctx, `
389 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
390 detach_on
391 FROM "Channel"
392 WHERE network = $1`, networkID)
393 if err != nil {
394 return nil, err
395 }
396 defer rows.Close()
397
398 var channels []Channel
399 for rows.Next() {
400 var ch Channel
401 var key, detachedInternalMsgID sql.NullString
402 var detachAfter int64
403 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
404 return nil, err
405 }
406 ch.Key = key.String
407 ch.DetachedInternalMsgID = detachedInternalMsgID.String
408 ch.DetachAfter = time.Duration(detachAfter) * time.Second
409 channels = append(channels, ch)
410 }
411 if err := rows.Err(); err != nil {
412 return nil, err
413 }
414
415 return channels, nil
416}
417
418func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
419 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
420 defer cancel()
421
422 key := toNullString(ch.Key)
423 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
424
425 var err error
426 if ch.ID == 0 {
427 err = db.db.QueryRowContext(ctx, `
428 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
429 detach_after, detach_on)
430 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
431 RETURNING id`,
432 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
433 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
434 } else {
435 _, err = db.db.ExecContext(ctx, `
436 UPDATE "Channel"
437 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
438 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
439 WHERE id = $1`,
440 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
441 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
442 }
443 return err
444}
445
446func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
447 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
448 defer cancel()
449
450 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
451 return err
452}
453
454func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
455 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
456 defer cancel()
457
458 rows, err := db.db.QueryContext(ctx, `
459 SELECT id, target, client, internal_msgid
460 FROM "DeliveryReceipt"
461 WHERE network = $1`, networkID)
462 if err != nil {
463 return nil, err
464 }
465 defer rows.Close()
466
467 var receipts []DeliveryReceipt
468 for rows.Next() {
469 var rcpt DeliveryReceipt
470 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
471 return nil, err
472 }
473 receipts = append(receipts, rcpt)
474 }
475 if err := rows.Err(); err != nil {
476 return nil, err
477 }
478
479 return receipts, nil
480}
481
482func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
483 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
484 defer cancel()
485
486 tx, err := db.db.Begin()
487 if err != nil {
488 return err
489 }
490 defer tx.Rollback()
491
492 _, err = tx.ExecContext(ctx,
493 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
494 networkID, client)
495 if err != nil {
496 return err
497 }
498
499 stmt, err := tx.PrepareContext(ctx, `
500 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
501 VALUES ($1, $2, $3, $4)
502 RETURNING id`)
503 if err != nil {
504 return err
505 }
506 defer stmt.Close()
507
508 for i := range receipts {
509 rcpt := &receipts[i]
510 err := stmt.
511 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
512 Scan(&rcpt.ID)
513 if err != nil {
514 return err
515 }
516 }
517
518 return tx.Commit()
519}
520
521func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
522 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
523 defer cancel()
524
525 receipt := &ReadReceipt{
526 Target: name,
527 }
528
529 row := db.db.QueryRowContext(ctx,
530 `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
531 networkID, name)
532 if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
533 if err == sql.ErrNoRows {
534 return nil, nil
535 }
536 return nil, err
537 }
538 return receipt, nil
539}
540
541func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
542 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
543 defer cancel()
544
545 var err error
546 if receipt.ID != 0 {
547 _, err = db.db.ExecContext(ctx, `
548 UPDATE "ReadReceipt"
549 SET timestamp = $1
550 WHERE id = $2`,
551 receipt.Timestamp, receipt.ID)
552 } else {
553 err = db.db.QueryRowContext(ctx, `
554 INSERT INTO "ReadReceipt" (network, target, timestamp)
555 VALUES ($1, $2, $3)
556 RETURNING id`,
557 networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID)
558 }
559 return err
560}
Note: See TracBrowser for help on using the repository browser.