source: code/trunk/db_postgres.go@ 708

Last change on this file since 708 was 664, checked in by contact, 4 years ago

Make Network.Nick optional

Make Network.Nick optional, default to the user's username. This
will allow adding a global setting to set the nickname in the
future, just like we have for the real name.

References: https://todo.sr.ht/~emersion/soju/110

File size: 13.8 KB
Line 
1package soju
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13)
14
// postgresQueryTimeout bounds every individual database operation performed
// by PostgresDB methods; each method derives a child context with this
// deadline from the caller's context.
const postgresQueryTimeout time.Duration = 5 * time.Second
16
// postgresConfigSchema creates the "Config" table holding the schema version.
// The PRIMARY KEY on id combined with CHECK(id = 1) allows at most one row.
// IF NOT EXISTS makes it safe to execute on every startup (upgrade runs it
// before reading the version).
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
24
// postgresSchema is the complete, current schema. upgrade executes it only on
// a fresh database (stored version 0) and then records the latest version
// without replaying postgresMigrations, so this text must always equal the
// result of applying every migration in postgresMigrations.
//
// Child tables reference their parent with ON DELETE CASCADE: deleting a
// "User" removes its networks, and deleting a "Network" removes its channels
// and delivery receipts.
//
// NOTE(review): "Network".nick is nullable, and PostgreSQL treats NULLs as
// distinct in UNIQUE constraints by default, so UNIQUE("user", addr, nick)
// does not prevent duplicate (user, addr) rows whose nick is NULL. Confirm
// whether that is intended.
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255),
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
77
// postgresMigrations lists schema changes in the order they must be applied.
// A stored schema version N means entries [0, N) have already been applied.
// upgrade runs entries [version, len(postgresMigrations)) and then stores
// len(postgresMigrations) as the new version.
//
// Entry #0 is a placeholder: version 0 means "no schema yet", which upgrade
// handles by creating postgresSchema directly. Because stored versions index
// into this slice, new migrations must only be appended (and mirrored in
// postgresSchema); existing entries must not be edited or reordered.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
	// #1: make "Network".nick nullable (Network.Nick became optional and
	// defaults to the user's username).
	`ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
}
82
// PostgresDB is the PostgreSQL-backed database implementation returned by
// OpenPostgresDB. It wraps a database/sql handle opened with the "postgres"
// (lib/pq) driver.
type PostgresDB struct{ db *sql.DB }
86
87func OpenPostgresDB(source string) (Database, error) {
88 sqlPostgresDB, err := sql.Open("postgres", source)
89 if err != nil {
90 return nil, err
91 }
92
93 db := &PostgresDB{db: sqlPostgresDB}
94 if err := db.upgrade(); err != nil {
95 sqlPostgresDB.Close()
96 return nil, err
97 }
98
99 return db, nil
100}
101
102func (db *PostgresDB) upgrade() error {
103 tx, err := db.db.Begin()
104 if err != nil {
105 return err
106 }
107 defer tx.Rollback()
108
109 if _, err := tx.Exec(postgresConfigSchema); err != nil {
110 return fmt.Errorf("failed to create Config table: %s", err)
111 }
112
113 var version int
114 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
115 if err != nil && !errors.Is(err, sql.ErrNoRows) {
116 return fmt.Errorf("failed to query schema version: %s", err)
117 }
118
119 if version == len(postgresMigrations) {
120 return nil
121 }
122 if version > len(postgresMigrations) {
123 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
124 }
125
126 if version == 0 {
127 if _, err := tx.Exec(postgresSchema); err != nil {
128 return fmt.Errorf("failed to initialize schema: %s", err)
129 }
130 } else {
131 for i := version; i < len(postgresMigrations); i++ {
132 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
133 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
134 }
135 }
136 }
137
138 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
139 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
140 if err != nil {
141 return fmt.Errorf("failed to bump schema version: %v", err)
142 }
143
144 return tx.Commit()
145}
146
147func (db *PostgresDB) Close() error {
148 return db.db.Close()
149}
150
151func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
152 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
153 defer cancel()
154
155 var stats DatabaseStats
156 row := db.db.QueryRowContext(ctx, `SELECT
157 (SELECT COUNT(*) FROM "User") AS users,
158 (SELECT COUNT(*) FROM "Network") AS networks,
159 (SELECT COUNT(*) FROM "Channel") AS channels`)
160 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
161 return nil, err
162 }
163
164 return &stats, nil
165}
166
167func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
168 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
169 defer cancel()
170
171 rows, err := db.db.QueryContext(ctx,
172 `SELECT id, username, password, admin, realname FROM "User"`)
173 if err != nil {
174 return nil, err
175 }
176 defer rows.Close()
177
178 var users []User
179 for rows.Next() {
180 var user User
181 var password, realname sql.NullString
182 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
183 return nil, err
184 }
185 user.Password = password.String
186 user.Realname = realname.String
187 users = append(users, user)
188 }
189 if err := rows.Err(); err != nil {
190 return nil, err
191 }
192
193 return users, nil
194}
195
196func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
197 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
198 defer cancel()
199
200 user := &User{Username: username}
201
202 var password, realname sql.NullString
203 row := db.db.QueryRowContext(ctx,
204 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
205 username)
206 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
207 return nil, err
208 }
209 user.Password = password.String
210 user.Realname = realname.String
211 return user, nil
212}
213
214func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
215 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
216 defer cancel()
217
218 password := toNullString(user.Password)
219 realname := toNullString(user.Realname)
220
221 var err error
222 if user.ID == 0 {
223 err = db.db.QueryRowContext(ctx, `
224 INSERT INTO "User" (username, password, admin, realname)
225 VALUES ($1, $2, $3, $4)
226 RETURNING id`,
227 user.Username, password, user.Admin, realname).Scan(&user.ID)
228 } else {
229 _, err = db.db.ExecContext(ctx, `
230 UPDATE "User"
231 SET password = $1, admin = $2, realname = $3
232 WHERE id = $4`,
233 password, user.Admin, realname, user.ID)
234 }
235 return err
236}
237
238func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
239 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
240 defer cancel()
241
242 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
243 return err
244}
245
246func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
247 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
248 defer cancel()
249
250 rows, err := db.db.QueryContext(ctx, `
251 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
252 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
253 FROM "Network"
254 WHERE "user" = $1`, userID)
255 if err != nil {
256 return nil, err
257 }
258 defer rows.Close()
259
260 var networks []Network
261 for rows.Next() {
262 var net Network
263 var name, nick, username, realname, pass, connectCommands sql.NullString
264 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
265 err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
266 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
267 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
268 if err != nil {
269 return nil, err
270 }
271 net.Name = name.String
272 net.Nick = nick.String
273 net.Username = username.String
274 net.Realname = realname.String
275 net.Pass = pass.String
276 if connectCommands.Valid {
277 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
278 }
279 net.SASL.Mechanism = saslMechanism.String
280 net.SASL.Plain.Username = saslPlainUsername.String
281 net.SASL.Plain.Password = saslPlainPassword.String
282 networks = append(networks, net)
283 }
284 if err := rows.Err(); err != nil {
285 return nil, err
286 }
287
288 return networks, nil
289}
290
291func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
292 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
293 defer cancel()
294
295 netName := toNullString(network.Name)
296 nick := toNullString(network.Nick)
297 netUsername := toNullString(network.Username)
298 realname := toNullString(network.Realname)
299 pass := toNullString(network.Pass)
300 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
301
302 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
303 if network.SASL.Mechanism != "" {
304 saslMechanism = toNullString(network.SASL.Mechanism)
305 switch network.SASL.Mechanism {
306 case "PLAIN":
307 saslPlainUsername = toNullString(network.SASL.Plain.Username)
308 saslPlainPassword = toNullString(network.SASL.Plain.Password)
309 network.SASL.External.CertBlob = nil
310 network.SASL.External.PrivKeyBlob = nil
311 case "EXTERNAL":
312 // keep saslPlain* nil
313 default:
314 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
315 }
316 }
317
318 var err error
319 if network.ID == 0 {
320 err = db.db.QueryRowContext(ctx, `
321 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
322 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
323 sasl_external_key, enabled)
324 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
325 RETURNING id`,
326 userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
327 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
328 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
329 } else {
330 _, err = db.db.ExecContext(ctx, `
331 UPDATE "Network"
332 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
333 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
334 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
335 enabled = $14
336 WHERE id = $1`,
337 network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
338 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
339 network.SASL.External.PrivKeyBlob, network.Enabled)
340 }
341 return err
342}
343
344func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
345 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
346 defer cancel()
347
348 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
349 return err
350}
351
352func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
353 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
354 defer cancel()
355
356 rows, err := db.db.QueryContext(ctx, `
357 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
358 detach_on
359 FROM "Channel"
360 WHERE network = $1`, networkID)
361 if err != nil {
362 return nil, err
363 }
364 defer rows.Close()
365
366 var channels []Channel
367 for rows.Next() {
368 var ch Channel
369 var key, detachedInternalMsgID sql.NullString
370 var detachAfter int64
371 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
372 return nil, err
373 }
374 ch.Key = key.String
375 ch.DetachedInternalMsgID = detachedInternalMsgID.String
376 ch.DetachAfter = time.Duration(detachAfter) * time.Second
377 channels = append(channels, ch)
378 }
379 if err := rows.Err(); err != nil {
380 return nil, err
381 }
382
383 return channels, nil
384}
385
386func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
387 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
388 defer cancel()
389
390 key := toNullString(ch.Key)
391 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
392
393 var err error
394 if ch.ID == 0 {
395 err = db.db.QueryRowContext(ctx, `
396 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
397 detach_after, detach_on)
398 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
399 RETURNING id`,
400 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
401 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
402 } else {
403 _, err = db.db.ExecContext(ctx, `
404 UPDATE "Channel"
405 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
406 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
407 WHERE id = $1`,
408 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
409 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
410 }
411 return err
412}
413
414func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
415 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
416 defer cancel()
417
418 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
419 return err
420}
421
422func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
423 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
424 defer cancel()
425
426 rows, err := db.db.QueryContext(ctx, `
427 SELECT id, target, client, internal_msgid
428 FROM "DeliveryReceipt"
429 WHERE network = $1`, networkID)
430 if err != nil {
431 return nil, err
432 }
433 defer rows.Close()
434
435 var receipts []DeliveryReceipt
436 for rows.Next() {
437 var rcpt DeliveryReceipt
438 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
439 return nil, err
440 }
441 receipts = append(receipts, rcpt)
442 }
443 if err := rows.Err(); err != nil {
444 return nil, err
445 }
446
447 return receipts, nil
448}
449
450func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
451 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
452 defer cancel()
453
454 tx, err := db.db.Begin()
455 if err != nil {
456 return err
457 }
458 defer tx.Rollback()
459
460 _, err = tx.ExecContext(ctx,
461 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
462 networkID, client)
463 if err != nil {
464 return err
465 }
466
467 stmt, err := tx.PrepareContext(ctx, `
468 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
469 VALUES ($1, $2, $3, $4)
470 RETURNING id`)
471 if err != nil {
472 return err
473 }
474 defer stmt.Close()
475
476 for i := range receipts {
477 rcpt := &receipts[i]
478 err := stmt.
479 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
480 Scan(&rcpt.ID)
481 if err != nil {
482 return err
483 }
484 }
485
486 return tx.Commit()
487}
Note: See TracBrowser for help on using the repository browser.