source: code/trunk/db_postgres.go@ 662

Last change on this file since 662 was 652, checked in by contact, 4 years ago

Add context args to Database interface

This is a mechanical change, which just lifts up the context.TODO()
calls from inside the DB implementations to the callers.

Future work involves properly wiring up the contexts when it makes
sense.

File size: 13.7 KB
Line 
1package soju
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strings"
10 "time"
11
12 _ "github.com/lib/pq"
13)
14
const (
	// postgresQueryTimeout bounds every query-issuing PostgresDB method: each
	// one derives a child context with this timeout before touching the DB.
	postgresQueryTimeout = 5 * time.Second

	// postgresConfigSchema creates the single-row "Config" table that records
	// the schema version. It is safe to run repeatedly (IF NOT EXISTS), and
	// the CHECK constraint only admits the row with id = 1.
	postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`

	// postgresSchema is the full schema applied to a freshly initialized
	// database (schema version 0). Deleting a user cascades to its networks,
	// and deleting a network cascades to its channels and delivery receipts.
	postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
)

// postgresMigrations holds the incremental schema upgrades. Entry i is the
// SQL that moves a database from version i to version i+1, so
// len(postgresMigrations) is the schema version this build expects.
var postgresMigrations = []string{
	"", // migration #0 is reserved for schema initialization
}
81
// PostgresDB stores soju's state in a PostgreSQL database. Instances are
// created by OpenPostgresDB.
type PostgresDB struct {
	// db is the handle opened with the "postgres" driver; every method of
	// PostgresDB issues its queries through it.
	db *sql.DB
}
85
86func OpenPostgresDB(source string) (Database, error) {
87 sqlPostgresDB, err := sql.Open("postgres", source)
88 if err != nil {
89 return nil, err
90 }
91
92 db := &PostgresDB{db: sqlPostgresDB}
93 if err := db.upgrade(); err != nil {
94 sqlPostgresDB.Close()
95 return nil, err
96 }
97
98 return db, nil
99}
100
101func (db *PostgresDB) upgrade() error {
102 tx, err := db.db.Begin()
103 if err != nil {
104 return err
105 }
106 defer tx.Rollback()
107
108 if _, err := tx.Exec(postgresConfigSchema); err != nil {
109 return fmt.Errorf("failed to create Config table: %s", err)
110 }
111
112 var version int
113 err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
114 if err != nil && !errors.Is(err, sql.ErrNoRows) {
115 return fmt.Errorf("failed to query schema version: %s", err)
116 }
117
118 if version == len(postgresMigrations) {
119 return nil
120 }
121 if version > len(postgresMigrations) {
122 return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
123 }
124
125 if version == 0 {
126 if _, err := tx.Exec(postgresSchema); err != nil {
127 return fmt.Errorf("failed to initialize schema: %s", err)
128 }
129 } else {
130 for i := version; i < len(postgresMigrations); i++ {
131 if _, err := tx.Exec(postgresMigrations[i]); err != nil {
132 return fmt.Errorf("failed to execute migration #%v: %v", i, err)
133 }
134 }
135 }
136
137 _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
138 ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
139 if err != nil {
140 return fmt.Errorf("failed to bump schema version: %v", err)
141 }
142
143 return tx.Commit()
144}
145
146func (db *PostgresDB) Close() error {
147 return db.db.Close()
148}
149
150func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
151 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
152 defer cancel()
153
154 var stats DatabaseStats
155 row := db.db.QueryRowContext(ctx, `SELECT
156 (SELECT COUNT(*) FROM "User") AS users,
157 (SELECT COUNT(*) FROM "Network") AS networks,
158 (SELECT COUNT(*) FROM "Channel") AS channels`)
159 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
160 return nil, err
161 }
162
163 return &stats, nil
164}
165
166func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
167 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
168 defer cancel()
169
170 rows, err := db.db.QueryContext(ctx,
171 `SELECT id, username, password, admin, realname FROM "User"`)
172 if err != nil {
173 return nil, err
174 }
175 defer rows.Close()
176
177 var users []User
178 for rows.Next() {
179 var user User
180 var password, realname sql.NullString
181 if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
182 return nil, err
183 }
184 user.Password = password.String
185 user.Realname = realname.String
186 users = append(users, user)
187 }
188 if err := rows.Err(); err != nil {
189 return nil, err
190 }
191
192 return users, nil
193}
194
195func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
196 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
197 defer cancel()
198
199 user := &User{Username: username}
200
201 var password, realname sql.NullString
202 row := db.db.QueryRowContext(ctx,
203 `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
204 username)
205 if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
206 return nil, err
207 }
208 user.Password = password.String
209 user.Realname = realname.String
210 return user, nil
211}
212
213func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
214 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
215 defer cancel()
216
217 password := toNullString(user.Password)
218 realname := toNullString(user.Realname)
219
220 var err error
221 if user.ID == 0 {
222 err = db.db.QueryRowContext(ctx, `
223 INSERT INTO "User" (username, password, admin, realname)
224 VALUES ($1, $2, $3, $4)
225 RETURNING id`,
226 user.Username, password, user.Admin, realname).Scan(&user.ID)
227 } else {
228 _, err = db.db.ExecContext(ctx, `
229 UPDATE "User"
230 SET password = $1, admin = $2, realname = $3
231 WHERE id = $4`,
232 password, user.Admin, realname, user.ID)
233 }
234 return err
235}
236
237func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
238 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
239 defer cancel()
240
241 _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
242 return err
243}
244
245func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
246 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
247 defer cancel()
248
249 rows, err := db.db.QueryContext(ctx, `
250 SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
251 sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
252 FROM "Network"
253 WHERE "user" = $1`, userID)
254 if err != nil {
255 return nil, err
256 }
257 defer rows.Close()
258
259 var networks []Network
260 for rows.Next() {
261 var net Network
262 var name, username, realname, pass, connectCommands sql.NullString
263 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
264 err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
265 &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
266 &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
267 if err != nil {
268 return nil, err
269 }
270 net.Name = name.String
271 net.Username = username.String
272 net.Realname = realname.String
273 net.Pass = pass.String
274 if connectCommands.Valid {
275 net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
276 }
277 net.SASL.Mechanism = saslMechanism.String
278 net.SASL.Plain.Username = saslPlainUsername.String
279 net.SASL.Plain.Password = saslPlainPassword.String
280 networks = append(networks, net)
281 }
282 if err := rows.Err(); err != nil {
283 return nil, err
284 }
285
286 return networks, nil
287}
288
289func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
290 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
291 defer cancel()
292
293 netName := toNullString(network.Name)
294 netUsername := toNullString(network.Username)
295 realname := toNullString(network.Realname)
296 pass := toNullString(network.Pass)
297 connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
298
299 var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
300 if network.SASL.Mechanism != "" {
301 saslMechanism = toNullString(network.SASL.Mechanism)
302 switch network.SASL.Mechanism {
303 case "PLAIN":
304 saslPlainUsername = toNullString(network.SASL.Plain.Username)
305 saslPlainPassword = toNullString(network.SASL.Plain.Password)
306 network.SASL.External.CertBlob = nil
307 network.SASL.External.PrivKeyBlob = nil
308 case "EXTERNAL":
309 // keep saslPlain* nil
310 default:
311 return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
312 }
313 }
314
315 var err error
316 if network.ID == 0 {
317 err = db.db.QueryRowContext(ctx, `
318 INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
319 sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
320 sasl_external_key, enabled)
321 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
322 RETURNING id`,
323 userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
324 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
325 network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
326 } else {
327 _, err = db.db.ExecContext(ctx, `
328 UPDATE "Network"
329 SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
330 connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
331 sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
332 enabled = $14
333 WHERE id = $1`,
334 network.ID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
335 saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
336 network.SASL.External.PrivKeyBlob, network.Enabled)
337 }
338 return err
339}
340
341func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
342 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
343 defer cancel()
344
345 _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
346 return err
347}
348
349func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
350 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
351 defer cancel()
352
353 rows, err := db.db.QueryContext(ctx, `
354 SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
355 detach_on
356 FROM "Channel"
357 WHERE network = $1`, networkID)
358 if err != nil {
359 return nil, err
360 }
361 defer rows.Close()
362
363 var channels []Channel
364 for rows.Next() {
365 var ch Channel
366 var key, detachedInternalMsgID sql.NullString
367 var detachAfter int64
368 if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
369 return nil, err
370 }
371 ch.Key = key.String
372 ch.DetachedInternalMsgID = detachedInternalMsgID.String
373 ch.DetachAfter = time.Duration(detachAfter) * time.Second
374 channels = append(channels, ch)
375 }
376 if err := rows.Err(); err != nil {
377 return nil, err
378 }
379
380 return channels, nil
381}
382
383func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
384 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
385 defer cancel()
386
387 key := toNullString(ch.Key)
388 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
389
390 var err error
391 if ch.ID == 0 {
392 err = db.db.QueryRowContext(ctx, `
393 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
394 detach_after, detach_on)
395 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
396 RETURNING id`,
397 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
398 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
399 } else {
400 _, err = db.db.ExecContext(ctx, `
401 UPDATE "Channel"
402 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
403 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
404 WHERE id = $1`,
405 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
406 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
407 }
408 return err
409}
410
411func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
412 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
413 defer cancel()
414
415 _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
416 return err
417}
418
419func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
420 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
421 defer cancel()
422
423 rows, err := db.db.QueryContext(ctx, `
424 SELECT id, target, client, internal_msgid
425 FROM "DeliveryReceipt"
426 WHERE network = $1`, networkID)
427 if err != nil {
428 return nil, err
429 }
430 defer rows.Close()
431
432 var receipts []DeliveryReceipt
433 for rows.Next() {
434 var rcpt DeliveryReceipt
435 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
436 return nil, err
437 }
438 receipts = append(receipts, rcpt)
439 }
440 if err := rows.Err(); err != nil {
441 return nil, err
442 }
443
444 return receipts, nil
445}
446
447func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
448 ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
449 defer cancel()
450
451 tx, err := db.db.Begin()
452 if err != nil {
453 return err
454 }
455 defer tx.Rollback()
456
457 _, err = tx.ExecContext(ctx,
458 `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
459 networkID, client)
460 if err != nil {
461 return err
462 }
463
464 stmt, err := tx.PrepareContext(ctx, `
465 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
466 VALUES ($1, $2, $3, $4)
467 RETURNING id`)
468 if err != nil {
469 return err
470 }
471 defer stmt.Close()
472
473 for i := range receipts {
474 rcpt := &receipts[i]
475 err := stmt.
476 QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
477 Scan(&rcpt.ID)
478 if err != nil {
479 return err
480 }
481 }
482
483 return tx.Commit()
484}
Note: See TracBrowser for help on using the repository browser.