source: code/trunk/db_postgres.go@ 664

Last change on this file since 664 was 664, checked in by contact, 4 years ago

Make Network.Nick optional

Make Network.Nick optional, default to the user's username. This
will allow adding a global setting to set the nickname in the
future, just like we have for the real name.

References: https://todo.sr.ht/~emersion/soju/110

File size: 13.8 KB
RevLine 
[620]1package soju
2
3import (
[645]4 "context"
[620]5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13)
14
const (
	// postgresQueryTimeout is the deadline that the PostgresDB query methods
	// attach (via context.WithTimeout) to the context they are given.
	postgresQueryTimeout = 5 * time.Second

	// postgresConfigSchema creates, if missing, the "Config" table holding the
	// schema version. The CHECK constraint restricts it to the single row with
	// id = 1, and IF NOT EXISTS makes it safe to run on every startup.
	postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
)
24
// postgresSchema is the complete, current database schema. upgrade executes
// it verbatim on a brand-new database (schema version 0) and then records the
// latest version without replaying postgresMigrations. Any migration appended
// to postgresMigrations must therefore also be reflected here, so that fresh
// and upgraded databases end up identical.
//
// NOTE(review): Network.nick is nullable. PostgreSQL treats NULLs as distinct
// in UNIQUE constraints (unless NULLS NOT DISTINCT is used, PostgreSQL 15+).
// As a result, UNIQUE("user", addr, nick) does not reject two networks with
// the same user and addr when both nicks are NULL. Confirm this is intended.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
77
// postgresMigrations holds the schema migrations, indexed by the version they
// upgrade from. A database whose stored version is v has had entries [0, v)
// applied. upgrade runs entries [v, len) and then stores
// len(postgresMigrations) as the new version. Because the stored version is
// an index into this slice, new migrations must be appended at the end;
// existing entries are never re-run on already-upgraded databases.
var postgresMigrations = []string{
	// #0: reserved. Fresh databases are created from postgresSchema instead.
	"",
	// #1: make Network.nick nullable (Network.Nick became optional).
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
}
82
// PostgresDB is the PostgreSQL-backed Database returned by OpenPostgresDB.
// It wraps a *sql.DB opened with the "postgres" (lib/pq) driver.
type PostgresDB struct {
	db *sql.DB // connection pool shared by all queries
}
86
87func OpenPostgresDB(source string) (Database, error) {
88 sqlPostgresDB, err := sql.Open("postgres", source)
89 if err != nil {
90 return nil, err
91 }
92
93 db := &PostgresDB{db: sqlPostgresDB}
94 if err := db.upgrade(); err != nil {
95 sqlPostgresDB.Close()
96 return nil, err
97 }
98
99 return db, nil
100}
101
102func (db *PostgresDB) upgrade() error {
103 tx, err := db.db.Begin()
104 if err != nil {
105 return err
106 }
107 defer tx.Rollback()
108
109 if _, err := tx.Exec(postgresConfigSchema); err != nil {
110 return fmt.Errorf("failed to create Config table: %s", err)
111 }
112
113 var version int
114 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
115 if err != nil && !errors.Is(err, sql.ErrNoRows) {
116 return fmt.Errorf("failed to query schema version: %s", err)
117 }
118
119 if version == len(postgresMigrations) {
120 return nil
121 }
122 if version > len(postgresMigrations) {
123 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
124 }
125
126 if version == 0 {
127 if _, err := tx.Exec(postgresSchema); err != nil {
128 return fmt.Errorf("failed to initialize schema: %s", err)
129 }
130 } else {
131 for i := version; i < len(postgresMigrations); i++ {
132 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
133 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
134 }
135 }
136 }
137
138 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
139 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
140 if err != nil {
141 return fmt.Errorf("failed to bump schema version: %v", err)
142 }
143
144 return tx.Commit()
145}
146
147func (db *PostgresDB) Close() error {
148 return db.db.Close()
149}
150
[652]151func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
152 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]153 defer cancel()
154
[620]155 var stats DatabaseStats
[645]156 row := db.db.QueryRowContext(ctx, `SELECT
[620]157 (SELECT COUNT(*) FROM "User") AS users,
158 (SELECT COUNT(*) FROM "Network") AS networks,
159 (SELECT COUNT(*) FROM "Channel") AS channels`)
160 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
161 return nil, err
162 }
163
164 return &stats, nil
165}
166
[652]167func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
168 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]169 defer cancel()
170
171 rows, err := db.db.QueryContext(ctx,
172 `SELECT id, username, password, admin, realname FROM "User"`)
[620]173 if err != nil {
174 return nil, err
175 }
176 defer rows.Close()
177
178 var users []User
179 for rows.Next() {
180 var user User
181 var password, realname sql.NullString
182 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
183 return nil, err
184 }
185 user.Password = password.String
186 user.Realname = realname.String
187 users = append(users, user)
188 }
189 if err := rows.Err(); err != nil {
190 return nil, err
191 }
192
193 return users, nil
194}
195
[652]196func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
197 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]198 defer cancel()
199
[620]200 user := &User{Username: username}
201
202 var password, realname sql.NullString
[645]203 row := db.db.QueryRowContext(ctx,
[620]204 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
205 username)
206 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
207 return nil, err
208 }
209 user.Password = password.String
210 user.Realname = realname.String
211 return user, nil
212}
213
[652]214func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
215 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]216 defer cancel()
217
[620]218 password := toNullString(user.Password)
219 realname := toNullString(user.Realname)
[635]220
221 var err error
222 if user.ID == 0 {
[645]223 err = db.db.QueryRowContext(ctx, `
[635]224 INSERT INTO "User" (username, password, admin, realname)
225 VALUES ($1, $2, $3, $4)
226 RETURNING id`,
227 user.Username, password, user.Admin, realname).Scan(&user.ID)
228 } else {
[645]229 _, err = db.db.ExecContext(ctx, `
[635]230 UPDATE "User"
231 SET password = $1, admin = $2, realname = $3
232 WHERE id = $4`,
233 password, user.Admin, realname, user.ID)
234 }
[620]235 return err
236}
237
[652]238func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
239 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]240 defer cancel()
241
242 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
[620]243 return err
244}
245
[652]246func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
247 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]248 defer cancel()
249
250 rows, err := db.db.QueryContext(ctx, `
[620]251 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
252 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
253 FROM "Network"
254 WHERE "user" = $1`, userID)
255 if err != nil {
256 return nil, err
257 }
258 defer rows.Close()
259
260 var networks []Network
261 for rows.Next() {
262 var net Network
[664]263 var name, nick, username, realname, pass, connectCommands sql.NullString
[620]264 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
[664]265 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
[620]266 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
267 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
268 if err != nil {
269 return nil, err
270 }
271 net.Name = name.String
[664]272 net.Nick = nick.String
[620]273 net.Username = username.String
274 net.Realname = realname.String
275 net.Pass = pass.String
276 if connectCommands.Valid {
277 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
278 }
279 net.SASL.Mechanism = saslMechanism.String
280 net.SASL.Plain.Username = saslPlainUsername.String
281 net.SASL.Plain.Password = saslPlainPassword.String
282 networks = append(networks, net)
283 }
284 if err := rows.Err(); err != nil {
285 return nil, err
286 }
287
288 return networks, nil
289}
290
[652]291func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
292 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]293 defer cancel()
294
[620]295 netName := toNullString(network.Name)
[664]296 nick := toNullString(network.Nick)
[620]297 netUsername := toNullString(network.Username)
298 realname := toNullString(network.Realname)
299 pass := toNullString(network.Pass)
300 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
301
302 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
303 if network.SASL.Mechanism != "" {
304 saslMechanism = toNullString(network.SASL.Mechanism)
305 switch network.SASL.Mechanism {
306 case "PLAIN":
307 saslPlainUsername = toNullString(network.SASL.Plain.Username)
308 saslPlainPassword = toNullString(network.SASL.Plain.Password)
309 network.SASL.External.CertBlob = nil
310 network.SASL.External.PrivKeyBlob = nil
311 case "EXTERNAL":
312 // keep saslPlain* nil
313 default:
314 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
315 }
316 }
317
[635]318 var err error
319 if network.ID == 0 {
[645]320 err = db.db.QueryRowContext(ctx, `
[635]321 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
322 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
323 sasl_external_key, enabled)
324 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
325 RETURNING id`,
[664]326 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]327 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
328 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
329 } else {
[645]330 _, err = db.db.ExecContext(ctx, `
[635]331 UPDATE "Network"
332 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
333 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
334 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
335 enabled = $14
336 WHERE id = $1`,
[664]337 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
[635]338 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
339 network.SASL.External.PrivKeyBlob, network.Enabled)
340 }
[620]341 return err
342}
343
[652]344func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
345 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]346 defer cancel()
347
348 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
[620]349 return err
350}
351
[652]352func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
353 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]354 defer cancel()
355
356 rows, err := db.db.QueryContext(ctx, `
[620]357 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
358 detach_on
359 FROM "Channel"
360 WHERE network = $1`, networkID)
361 if err != nil {
362 return nil, err
363 }
364 defer rows.Close()
365
366 var channels []Channel
367 for rows.Next() {
368 var ch Channel
369 var key, detachedInternalMsgID sql.NullString
370 var detachAfter int64
371 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
372 return nil, err
373 }
374 ch.Key = key.String
375 ch.DetachedInternalMsgID = detachedInternalMsgID.String
376 ch.DetachAfter = time.Duration(detachAfter) * time.Second
377 channels = append(channels, ch)
378 }
379 if err := rows.Err(); err != nil {
380 return nil, err
381 }
382
383 return channels, nil
384}
385
[652]386func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
387 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]388 defer cancel()
389
[620]390 key := toNullString(ch.Key)
391 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
[635]392
393 var err error
394 if ch.ID == 0 {
[645]395 err = db.db.QueryRowContext(ctx, `
[635]396 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
397 detach_after, detach_on)
398 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
399 RETURNING id`,
400 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
401 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
402 } else {
[645]403 _, err = db.db.ExecContext(ctx, `
[635]404 UPDATE "Channel"
405 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
406 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
407 WHERE id = $1`,
408 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
409 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
410 }
[620]411 return err
412}
413
[652]414func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
415 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]416 defer cancel()
417
418 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
[620]419 return err
420}
421
[652]422func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
423 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]424 defer cancel()
425
426 rows, err := db.db.QueryContext(ctx, `
[620]427 SELECT id, target, client, internal_msgid
428 FROM "DeliveryReceipt"
429 WHERE network = $1`, networkID)
430 if err != nil {
431 return nil, err
432 }
433 defer rows.Close()
434
435 var receipts []DeliveryReceipt
436 for rows.Next() {
437 var rcpt DeliveryReceipt
438 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
439 return nil, err
440 }
441 receipts = append(receipts, rcpt)
442 }
443 if err := rows.Err(); err != nil {
444 return nil, err
445 }
446
447 return receipts, nil
448}
449
[652]450func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
451 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
[645]452 defer cancel()
453
[635]454 tx, err := db.db.Begin()
455 if err != nil {
456 return err
457 }
458 defer tx.Rollback()
459
[645]460 _, err = tx.ExecContext(ctx,
461 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
[635]462 networkID, client)
463 if err != nil {
464 return err
465 }
466
[645]467 stmt, err := tx.PrepareContext(ctx, `
[620]468 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
469 VALUES ($1, $2, $3, $4)
470 RETURNING id`)
471 if err != nil {
472 return err
473 }
474 defer stmt.Close()
475
476 for i := range receipts {
477 rcpt := &receipts[i]
[645]478 err := stmt.
479 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
480 Scan(&rcpt.ID)
[620]481 if err != nil {
482 return err
483 }
484 }
[635]485
486 return tx.Commit()
[620]487}
Note: See TracBrowser for help on using the repository browser.