source: code/trunk/db_postgres.go@ 644

Last change on this file since 644 was 640, checked in by contact, 4 years ago

db_postgres: remove unnecessary DEFAULT NULL in schema

File size: 12.2 KB
Line 
1package soju
2
3import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "math"
8 "strings"
9 "time"
10
11 _ "github.com/lib/pq"
12)
13
// postgresConfigSchema creates the single-row "Config" table that stores the
// current schema version. The CHECK constraint only allows the row with
// id = 1. IF NOT EXISTS lets upgrade execute this statement on every start,
// whether or not the table already exists.
const postgresConfigSchema = `
CREATE TABLE IF NOT EXISTS "Config" (
	id SMALLINT PRIMARY KEY,
	version INTEGER NOT NULL,
	CHECK(id = 1)
);
`
21
// postgresSchema is the complete schema for a freshly created database.
// upgrade executes it when the stored schema version is 0. Databases created
// at a later version are instead moved forward by postgresMigrations.
//
// Ownership is expressed with foreign keys using ON DELETE CASCADE:
// "Network" rows belong to a "User", and "Channel" and "DeliveryReceipt" rows
// belong to a "Network". Deleting a user therefore also removes its networks
// and, transitively, their channels and delivery receipts.
//
// Table names are double-quoted to preserve their capitalisation. The "user"
// column is quoted because USER is a reserved word in PostgreSQL.
// Channel.detach_after holds whole seconds (see StoreChannel/ListChannels).
const postgresSchema = `
CREATE TABLE "User" (
	id SERIAL PRIMARY KEY,
	username VARCHAR(255) NOT NULL UNIQUE,
	password VARCHAR(255),
	admin BOOLEAN NOT NULL DEFAULT FALSE,
	realname VARCHAR(255)
);

CREATE TABLE "Network" (
	id SERIAL PRIMARY KEY,
	name VARCHAR(255),
	"user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
	addr VARCHAR(255) NOT NULL,
	nick VARCHAR(255) NOT NULL,
	username VARCHAR(255),
	realname VARCHAR(255),
	pass VARCHAR(255),
	connect_commands VARCHAR(1023),
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BYTEA,
	sasl_external_key BYTEA,
	enabled BOOLEAN NOT NULL DEFAULT TRUE,
	UNIQUE("user", addr, nick),
	UNIQUE("user", name)
);

CREATE TABLE "Channel" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	name VARCHAR(255) NOT NULL,
	key VARCHAR(255),
	detached BOOLEAN NOT NULL DEFAULT FALSE,
	detached_internal_msgid VARCHAR(255),
	relay_detached INTEGER NOT NULL DEFAULT 0,
	reattach_on INTEGER NOT NULL DEFAULT 0,
	detach_after INTEGER NOT NULL DEFAULT 0,
	detach_on INTEGER NOT NULL DEFAULT 0,
	UNIQUE(network, name)
);

CREATE TABLE "DeliveryReceipt" (
	id SERIAL PRIMARY KEY,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255) NOT NULL DEFAULT '',
	internal_msgid VARCHAR(255) NOT NULL,
	UNIQUE(network, target, client)
);
`
74
// postgresMigrations lists the SQL executed by upgrade, indexed by the schema
// version it upgrades from: entry i moves a database from version i to i+1.
// Its length is the schema version this binary targets. Slot 0 is never
// executed, because a version-0 database receives postgresSchema instead.
var postgresMigrations = []string{
	0: "", // reserved for schema initialization
}
78
79type PostgresDB struct {
80 db *sql.DB
81}
82
// OpenPostgresDB opens the PostgreSQL database described by source and
// upgrades its schema to the version this binary expects.
//
// source is passed unchanged to sql.Open with the "postgres" driver, which is
// registered by the lib/pq import. sql.Open may not contact the server, so
// connection problems usually surface from the schema upgrade. If the upgrade
// fails, the handle is closed before the error is returned.
func OpenPostgresDB(source string) (Database, error) {
	conn, err := sql.Open("postgres", source)
	if err != nil {
		return nil, err
	}

	pg := &PostgresDB{db: conn}
	if upgradeErr := pg.upgrade(); upgradeErr != nil {
		conn.Close()
		return nil, upgradeErr
	}
	return pg, nil
}
97
// upgrade brings the database schema up to date inside a single transaction.
//
// It first ensures the "Config" table exists and reads the stored schema
// version. A missing row counts as version 0.
//
// A version-0 database receives the full postgresSchema. Otherwise, every
// pending entry of postgresMigrations is executed in order, with entry i
// moving the schema from version i to i+1. The resulting version,
// len(postgresMigrations), is then written to "Config" and the transaction is
// committed. If any step fails, the deferred Rollback discards all of it.
//
// An error is returned when the stored version is newer than
// len(postgresMigrations). Underlying driver errors are wrapped with %w so
// callers can still inspect them with errors.Is / errors.As.
func (db *PostgresDB) upgrade() error {
	tx, err := db.db.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin schema upgrade transaction: %w", err)
	}
	// After a successful Commit this Rollback only returns sql.ErrTxDone,
	// which is deliberately ignored.
	defer tx.Rollback()

	if _, err := tx.Exec(postgresConfigSchema); err != nil {
		return fmt.Errorf("failed to create Config table: %w", err)
	}

	// sql.ErrNoRows leaves version at 0, i.e. a brand-new database.
	var version int
	err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("failed to query schema version: %w", err)
	}

	switch {
	case version == len(postgresMigrations):
		return nil // already up to date
	case version > len(postgresMigrations):
		return fmt.Errorf("soju (version %d) older than schema (version %d)", len(postgresMigrations), version)
	}

	if version == 0 {
		if _, err := tx.Exec(postgresSchema); err != nil {
			return fmt.Errorf("failed to initialize schema: %w", err)
		}
	} else {
		for i := version; i < len(postgresMigrations); i++ {
			if _, err := tx.Exec(postgresMigrations[i]); err != nil {
				return fmt.Errorf("failed to execute migration #%v: %w", i, err)
			}
		}
	}

	_, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
		ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
	if err != nil {
		return fmt.Errorf("failed to bump schema version: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit schema upgrade: %w", err)
	}
	return nil
}
142
143func (db *PostgresDB) Close() error {
144 return db.db.Close()
145}
146
147func (db *PostgresDB) Stats() (*DatabaseStats, error) {
148 var stats DatabaseStats
149 row := db.db.QueryRow(`SELECT
150 (SELECT COUNT(*) FROM "User") AS users,
151 (SELECT COUNT(*) FROM "Network") AS networks,
152 (SELECT COUNT(*) FROM "Channel") AS channels`)
153 if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
154 return nil, err
155 }
156
157 return &stats, nil
158}
159
// ListUsers returns every row of the "User" table, in no particular order.
// NULL password and realname columns come back as empty strings. The slice is
// nil when the table is empty.
func (db *PostgresDB) ListUsers() ([]User, error) {
	rows, err := db.db.Query(`SELECT id, username, password, admin, realname FROM "User"`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []User
	for rows.Next() {
		var u User
		var pass, name sql.NullString
		if err := rows.Scan(&u.ID, &u.Username, &pass, &u.Admin, &name); err != nil {
			return nil, err
		}
		u.Password, u.Realname = pass.String, name.String
		result = append(result, u)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
184
// GetUser loads the user with the given username. The Scan error is returned
// unchanged, so a missing user yields sql.ErrNoRows. NULL password and
// realname columns become empty strings.
func (db *PostgresDB) GetUser(username string) (*User, error) {
	var pass, name sql.NullString
	u := &User{Username: username}

	err := db.db.QueryRow(
		`SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
		username,
	).Scan(&u.ID, &pass, &u.Admin, &name)
	if err != nil {
		return nil, err
	}

	u.Password = pass.String
	u.Realname = name.String
	return u, nil
}
199
200func (db *PostgresDB) StoreUser(user *User) error {
201 password := toNullString(user.Password)
202 realname := toNullString(user.Realname)
203
204 var err error
205 if user.ID == 0 {
206 err = db.db.QueryRow(`
207 INSERT INTO "User" (username, password, admin, realname)
208 VALUES ($1, $2, $3, $4)
209 RETURNING id`,
210 user.Username, password, user.Admin, realname).Scan(&user.ID)
211 } else {
212 _, err = db.db.Exec(`
213 UPDATE "User"
214 SET password = $1, admin = $2, realname = $3
215 WHERE id = $4`,
216 password, user.Admin, realname, user.ID)
217 }
218 return err
219}
220
221func (db *PostgresDB) DeleteUser(id int64) error {
222 _, err := db.db.Exec(`DELETE FROM "User" WHERE id = $1`, id)
223 return err
224}
225
// ListNetworks returns every network owned by userID.
//
// Nullable text columns are mapped to empty strings. connect_commands is
// split on "\r\n" only when it is non-NULL, so a NULL column leaves
// ConnectCommands nil. The SASL EXTERNAL certificate and key blobs are
// scanned directly into the Network.
func (db *PostgresDB) ListNetworks(userID int64) ([]Network, error) {
	rows, err := db.db.Query(`
		SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
			sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
		FROM "Network"
		WHERE "user" = $1`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var networks []Network
	for rows.Next() {
		var n Network
		var name, username, realname, pass, cmds sql.NullString
		var mech, plainUser, plainPass sql.NullString
		if err := rows.Scan(
			&n.ID, &name, &n.Addr, &n.Nick, &username, &realname, &pass, &cmds,
			&mech, &plainUser, &plainPass,
			&n.SASL.External.CertBlob, &n.SASL.External.PrivKeyBlob, &n.Enabled,
		); err != nil {
			return nil, err
		}

		n.Name, n.Username, n.Realname, n.Pass = name.String, username.String, realname.String, pass.String
		if cmds.Valid {
			n.ConnectCommands = strings.Split(cmds.String, "\r\n")
		}
		n.SASL.Mechanism = mech.String
		n.SASL.Plain.Username, n.SASL.Plain.Password = plainUser.String, plainPass.String

		networks = append(networks, n)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return networks, nil
}
266
// StoreNetwork persists network for the user userID.
//
// When network.ID is 0, a new row owned by userID is inserted and network.ID
// is set to the generated id. Otherwise, the row with that id is updated;
// userID is not used on that path. ConnectCommands are joined with "\r\n".
//
// SASL handling depends on network.SASL.Mechanism:
//   - "": the sasl_* text columns are NULL and the certificate blobs are
//     written unchanged.
//   - "PLAIN": the username/password are stored, and the EXTERNAL certificate
//     and key are cleared on the passed struct as well as in the row.
//   - "EXTERNAL": only the mechanism and the blobs are stored.
//   - anything else: an error is returned and nothing is written.
func (db *PostgresDB) StoreNetwork(userID int64, network *Network) error {
	netName := toNullString(network.Name)
	netUsername := toNullString(network.Username)
	realname := toNullString(network.Realname)
	pass := toNullString(network.Pass)
	connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))

	var mechanism, plainUser, plainPass sql.NullString
	switch network.SASL.Mechanism {
	case "":
		// No SASL configured: leave every sasl_* text column NULL.
	case "PLAIN":
		mechanism = toNullString(network.SASL.Mechanism)
		plainUser = toNullString(network.SASL.Plain.Username)
		plainPass = toNullString(network.SASL.Plain.Password)
		network.SASL.External.CertBlob = nil
		network.SASL.External.PrivKeyBlob = nil
	case "EXTERNAL":
		mechanism = toNullString(network.SASL.Mechanism)
	default:
		return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
	}

	// Values for placeholders $2..$14, shared by INSERT and UPDATE. They are
	// captured after the switch so the cleared PLAIN blobs are nil here.
	fields := []interface{}{
		netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
		mechanism, plainUser, plainPass, network.SASL.External.CertBlob,
		network.SASL.External.PrivKeyBlob, network.Enabled,
	}

	if network.ID != 0 {
		_, err := db.db.Exec(`
			UPDATE "Network"
			SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
				connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
				sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
				enabled = $14
			WHERE id = $1`,
			append([]interface{}{network.ID}, fields...)...)
		return err
	}

	return db.db.QueryRow(`
			INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
				sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
				sasl_external_key, enabled)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
			RETURNING id`,
		append([]interface{}{userID}, fields...)...).Scan(&network.ID)
}
315
316func (db *PostgresDB) DeleteNetwork(id int64) error {
317 _, err := db.db.Exec(`DELETE FROM "Network" WHERE id = $1`, id)
318 return err
319}
320
// ListChannels returns every channel stored for networkID. NULL key and
// detached_internal_msgid columns become empty strings. detach_after is
// stored in whole seconds and converted back to a time.Duration.
func (db *PostgresDB) ListChannels(networkID int64) ([]Channel, error) {
	rows, err := db.db.Query(`
		SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
			detach_on
		FROM "Channel"
		WHERE network = $1`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var channels []Channel
	for rows.Next() {
		var c Channel
		var key, internalMsgID sql.NullString
		var detachAfterSecs int64
		dest := []interface{}{
			&c.ID, &c.Name, &key, &c.Detached, &internalMsgID,
			&c.RelayDetached, &c.ReattachOn, &detachAfterSecs, &c.DetachOn,
		}
		if err := rows.Scan(dest...); err != nil {
			return nil, err
		}

		c.Key = key.String
		c.DetachedInternalMsgID = internalMsgID.String
		c.DetachAfter = time.Second * time.Duration(detachAfterSecs)
		channels = append(channels, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return channels, nil
}
351
352func (db *PostgresDB) StoreChannel(networkID int64, ch *Channel) error {
353 key := toNullString(ch.Key)
354 detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
355
356 var err error
357 if ch.ID == 0 {
358 err = db.db.QueryRow(`
359 INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
360 detach_after, detach_on)
361 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
362 RETURNING id`,
363 networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
364 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
365 } else {
366 _, err = db.db.Exec(`
367 UPDATE "Channel"
368 SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
369 relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
370 WHERE id = $1`,
371 ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
372 ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
373 }
374 return err
375}
376
377func (db *PostgresDB) DeleteChannel(id int64) error {
378 _, err := db.db.Exec(`DELETE FROM "Channel" WHERE id = $1`, id)
379 return err
380}
381
382func (db *PostgresDB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
383 rows, err := db.db.Query(`
384 SELECT id, target, client, internal_msgid
385 FROM "DeliveryReceipt"
386 WHERE network = $1`, networkID)
387 if err != nil {
388 return nil, err
389 }
390 defer rows.Close()
391
392 var receipts []DeliveryReceipt
393 for rows.Next() {
394 var rcpt DeliveryReceipt
395 if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
396 return nil, err
397 }
398 receipts = append(receipts, rcpt)
399 }
400 if err := rows.Err(); err != nil {
401 return nil, err
402 }
403
404 return receipts, nil
405}
406
407func (db *PostgresDB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
408 tx, err := db.db.Begin()
409 if err != nil {
410 return err
411 }
412 defer tx.Rollback()
413
414 _, err = tx.Exec(`DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
415 networkID, client)
416 if err != nil {
417 return err
418 }
419
420 stmt, err := tx.Prepare(`
421 INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
422 VALUES ($1, $2, $3, $4)
423 RETURNING id`)
424 if err != nil {
425 return err
426 }
427 defer stmt.Close()
428
429 for i := range receipts {
430 rcpt := &receipts[i]
431 err := stmt.QueryRow(networkID, rcpt.Target, client, rcpt.InternalMsgID).Scan(&rcpt.ID)
432 if err != nil {
433 return err
434 }
435 }
436
437 return tx.Commit()
438}
Note: See TracBrowser for help on using the repository browser.