source: code/trunk/vendor/github.com/lib/pq/conn.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 48.2 KB
Line 
1package pq
2
3import (
4 "bufio"
5 "context"
6 "crypto/md5"
7 "crypto/sha256"
8 "database/sql"
9 "database/sql/driver"
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 "net"
15 "os"
16 "os/user"
17 "path"
18 "path/filepath"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23 "unicode"
24
25 "github.com/lib/pq/oid"
26 "github.com/lib/pq/scram"
27)
28
// Common error types
var (
	// ErrNotSupported reports an unsupported command.
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction had
	// already failed; the transaction is rolled back before it is returned.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned when the server answers the SSL request
	// with anything other than 'S'.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyUnknownOwnership reports that the owner of the private key
	// file could not be determined.
	ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected")
	// ErrSSLKeyHasWorldPermissions reports a private key file whose
	// permissions are too permissive.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less")

	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is set by simpleExec when ReadyForQuery arrives
	// before any result or error was seen.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are returned by the noRows
	// result, which simpleExec uses for an empty statement.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
43
44// Compile time validation that our types implement the expected interfaces
45var (
46 _ driver.Driver = Driver{}
47)
48
49// Driver is the Postgres database driver.
50type Driver struct{}
51
52// Open opens a new connection to the database. name is a connection string.
53// Most users should only use it through database/sql package from the standard
54// library.
55func (d Driver) Open(name string) (driver.Conn, error) {
56 return Open(name)
57}
58
59func init() {
60 sql.Register("postgres", &Driver{})
61}
62
// parameterStatus holds per-connection values stored on conn.parameterStatus.
// NOTE(review): presumably filled in by processParameterStatus from 'S'
// messages; that function is not visible here, so confirm.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
72
73type transactionStatus byte
74
75const (
76 txnStatusIdle transactionStatus = 'I'
77 txnStatusIdleInTransaction transactionStatus = 'T'
78 txnStatusInFailedTransaction transactionStatus = 'E'
79)
80
81func (s transactionStatus) String() string {
82 switch s {
83 case txnStatusIdle:
84 return "idle"
85 case txnStatusIdleInTransaction:
86 return "idle in transaction"
87 case txnStatusInFailedTransaction:
88 return "in a failed transaction"
89 default:
90 errorf("unknown transactionStatus %d", s)
91 }
92
93 panic("not reached")
94}
95
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
type Dialer interface {
	// Dial is used when no connect_timeout applies and the dialer does not
	// also implement DialerContext.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is used when a connect_timeout applies and the dialer does
	// not also implement DialerContext.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is the context-aware dialer interface. When a Dialer also
// implements it, DialContext is preferred over Dial and DialTimeout.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
107
// defaultDialer is the dialer used when the caller supplies none. It
// delegates to a zero-value net.Dialer.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects without any timeout.
func (dd defaultDialer) Dial(network, address string) (net.Conn, error) {
	return dd.d.Dial(network, address)
}

// DialTimeout connects, giving up once timeout has elapsed.
func (dd defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	timed, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	return dd.DialContext(timed, network, address)
}

// DialContext connects, honouring cancellation of ctx.
func (dd defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return dd.d.DialContext(ctx, network, address)
}
123
// conn is a single connection to a Postgres server.
type conn struct {
	// Underlying network connection; replaced by the upgraded connection
	// when ssl() negotiates an upgrade.
	c net.Conn
	// Buffered reader over c, used by recvMessage.
	buf *bufio.Reader
	// Counter incremented by gname() to generate statement names.
	namei int
	// Scratch space shared by writeBuf() and recvMessage() for small
	// messages.
	scratch [512]byte
	// Transaction status consulted by isInTransaction().
	txnStatus transactionStatus
	// If not nil, called by closeTxn() when Commit or Rollback returns.
	txnFinish func()

	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values

	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int

	parameterStatus parameterStatus

	// Message stashed by saveMessage(); recvMessage() returns it on its next
	// call instead of reading from the network.
	saveMessageType   byte
	saveMessageBuffer []byte

	// If an error is set, this connection is bad and all public-facing
	// functions should return the appropriate error by calling get()
	// (ErrBadConn) or getForNext().
	err syncErr

	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool

	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool

	// If true this connection is in the middle of a COPY
	inCopy bool

	// If not nil, notices will be synchronously sent here
	noticeHandler func(*Error)

	// If not nil, notifications will be synchronously sent here
	notificationHandler func(*Notification)

	// GSSAPI context
	gss GSS
}
171
// syncErr is a mutex-guarded, write-once error slot.
type syncErr struct {
	err error
	sync.Mutex
}

// get reports driver.ErrBadConn once an error has been stored, nil otherwise.
func (e *syncErr) get() error {
	e.Lock()
	broken := e.err != nil
	e.Unlock()

	if broken {
		return driver.ErrBadConn
	}
	return nil
}

// getForNext returns the stored error itself. Currently only used by
// rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	stored := e.err
	e.Unlock()
	return stored
}

// set stores err unless an error is already stored. It panics on a nil err.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}

	e.Lock()
	if e.err == nil {
		e.err = err
	}
	e.Unlock()
}
205
206// Handle driver-side settings in parsed connection string.
207func (cn *conn) handleDriverSettings(o values) (err error) {
208 boolSetting := func(key string, val *bool) error {
209 if value, ok := o[key]; ok {
210 if value == "yes" {
211 *val = true
212 } else if value == "no" {
213 *val = false
214 } else {
215 return fmt.Errorf("unrecognized value %q for %s", value, key)
216 }
217 }
218 return nil
219 }
220
221 err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
222 if err != nil {
223 return err
224 }
225 return boolSetting("binary_parameters", &cn.binaryParameters)
226}
227
228func (cn *conn) handlePgpass(o values) {
229 // if a password was supplied, do not process .pgpass
230 if _, ok := o["password"]; ok {
231 return
232 }
233 filename := os.Getenv("PGPASSFILE")
234 if filename == "" {
235 // XXX this code doesn't work on Windows where the default filename is
236 // XXX %APPDATA%\postgresql\pgpass.conf
237 // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
238 userHome := os.Getenv("HOME")
239 if userHome == "" {
240 user, err := user.Current()
241 if err != nil {
242 return
243 }
244 userHome = user.HomeDir
245 }
246 filename = filepath.Join(userHome, ".pgpass")
247 }
248 fileinfo, err := os.Stat(filename)
249 if err != nil {
250 return
251 }
252 mode := fileinfo.Mode()
253 if mode&(0x77) != 0 {
254 // XXX should warn about incorrect .pgpass permissions as psql does
255 return
256 }
257 file, err := os.Open(filename)
258 if err != nil {
259 return
260 }
261 defer file.Close()
262 scanner := bufio.NewScanner(io.Reader(file))
263 hostname := o["host"]
264 ntw, _ := network(o)
265 port := o["port"]
266 db := o["dbname"]
267 username := o["user"]
268 // From: https://github.com/tg/pgpass/blob/master/reader.go
269 getFields := func(s string) []string {
270 fs := make([]string, 0, 5)
271 f := make([]rune, 0, len(s))
272
273 var esc bool
274 for _, c := range s {
275 switch {
276 case esc:
277 f = append(f, c)
278 esc = false
279 case c == '\\':
280 esc = true
281 case c == ':':
282 fs = append(fs, string(f))
283 f = f[:0]
284 default:
285 f = append(f, c)
286 }
287 }
288 return append(fs, string(f))
289 }
290 for scanner.Scan() {
291 line := scanner.Text()
292 if len(line) == 0 || line[0] == '#' {
293 continue
294 }
295 split := getFields(line)
296 if len(split) != 5 {
297 continue
298 }
299 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
300 o["password"] = split[4]
301 return
302 }
303 }
304}
305
306func (cn *conn) writeBuf(b byte) *writeBuf {
307 cn.scratch[0] = b
308 return &writeBuf{
309 buf: cn.scratch[:5],
310 pos: 1,
311 }
312}
313
314// Open opens a new connection to the database. dsn is a connection string.
315// Most users should only use it through database/sql package from the standard
316// library.
317func Open(dsn string) (_ driver.Conn, err error) {
318 return DialOpen(defaultDialer{}, dsn)
319}
320
321// DialOpen opens a new connection to the database using a dialer.
322func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
323 c, err := NewConnector(dsn)
324 if err != nil {
325 return nil, err
326 }
327 c.Dialer(d)
328 return c.open(context.Background())
329}
330
331func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
332 // Handle any panics during connection initialization. Note that we
333 // specifically do *not* want to use errRecover(), as that would turn any
334 // connection errors into ErrBadConns, hiding the real error message from
335 // the user.
336 defer errRecoverNoErrBadConn(&err)
337
338 // Create a new values map (copy). This makes it so maps in different
339 // connections do not reference the same underlying data structure, so it
340 // is safe for multiple connections to concurrently write to their opts.
341 o := make(values)
342 for k, v := range c.opts {
343 o[k] = v
344 }
345
346 cn = &conn{
347 opts: o,
348 dialer: c.dialer,
349 }
350 err = cn.handleDriverSettings(o)
351 if err != nil {
352 return nil, err
353 }
354 cn.handlePgpass(o)
355
356 cn.c, err = dial(ctx, c.dialer, o)
357 if err != nil {
358 return nil, err
359 }
360
361 err = cn.ssl(o)
362 if err != nil {
363 if cn.c != nil {
364 cn.c.Close()
365 }
366 return nil, err
367 }
368
369 // cn.startup panics on error. Make sure we don't leak cn.c.
370 panicking := true
371 defer func() {
372 if panicking {
373 cn.c.Close()
374 }
375 }()
376
377 cn.buf = bufio.NewReader(cn.c)
378 cn.startup(o)
379
380 // reset the deadline, in case one was set (see dial)
381 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
382 err = cn.c.SetDeadline(time.Time{})
383 }
384 panicking = false
385 return cn, err
386}
387
388func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
389 network, address := network(o)
390
391 // Zero or not specified means wait indefinitely.
392 if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
393 seconds, err := strconv.ParseInt(timeout, 10, 0)
394 if err != nil {
395 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
396 }
397 duration := time.Duration(seconds) * time.Second
398
399 // connect_timeout should apply to the entire connection establishment
400 // procedure, so we both use a timeout for the TCP connection
401 // establishment and set a deadline for doing the initial handshake.
402 // The deadline is then reset after startup() is done.
403 deadline := time.Now().Add(duration)
404 var conn net.Conn
405 if dctx, ok := d.(DialerContext); ok {
406 ctx, cancel := context.WithTimeout(ctx, duration)
407 defer cancel()
408 conn, err = dctx.DialContext(ctx, network, address)
409 } else {
410 conn, err = d.DialTimeout(network, address, duration)
411 }
412 if err != nil {
413 return nil, err
414 }
415 err = conn.SetDeadline(deadline)
416 return conn, err
417 }
418 if dctx, ok := d.(DialerContext); ok {
419 return dctx.DialContext(ctx, network, address)
420 }
421 return d.Dial(network, address)
422}
423
424func network(o values) (string, string) {
425 host := o["host"]
426
427 if strings.HasPrefix(host, "/") {
428 sockPath := path.Join(host, ".s.PGSQL."+o["port"])
429 return "unix", sockPath
430 }
431
432 return "tcp", net.JoinHostPort(host, o["port"])
433}
434
// values holds connection options keyed by option name, such as "host",
// "port" or "connect_timeout".
type values map[string]string
436
// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a new scanner positioned at the start of the option
// string s.
func newScanner(s string) *scanner {
	return &scanner{s: []rune(s)}
}

// Next returns the next rune and advances past it.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
	if s.i < len(s.s) {
		r := s.s[s.i]
		s.i++
		return r, true
	}
	return 0, false
}

// SkipSpaces consumes whitespace and returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
	for {
		r, ok := s.Next()
		if !ok || !unicode.IsSpace(r) {
			return r, ok
		}
	}
}
468
469// parseOpts parses the options from name and adds them to the values.
470//
471// The parsing code is based on conninfo_parse from libpq's fe-connect.c
472func parseOpts(name string, o values) error {
473 s := newScanner(name)
474
475 for {
476 var (
477 keyRunes, valRunes []rune
478 r rune
479 ok bool
480 )
481
482 if r, ok = s.SkipSpaces(); !ok {
483 break
484 }
485
486 // Scan the key
487 for !unicode.IsSpace(r) && r != '=' {
488 keyRunes = append(keyRunes, r)
489 if r, ok = s.Next(); !ok {
490 break
491 }
492 }
493
494 // Skip any whitespace if we're not at the = yet
495 if r != '=' {
496 r, ok = s.SkipSpaces()
497 }
498
499 // The current character should be =
500 if r != '=' || !ok {
501 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
502 }
503
504 // Skip any whitespace after the =
505 if r, ok = s.SkipSpaces(); !ok {
506 // If we reach the end here, the last value is just an empty string as per libpq.
507 o[string(keyRunes)] = ""
508 break
509 }
510
511 if r != '\'' {
512 for !unicode.IsSpace(r) {
513 if r == '\\' {
514 if r, ok = s.Next(); !ok {
515 return fmt.Errorf(`missing character after backslash`)
516 }
517 }
518 valRunes = append(valRunes, r)
519
520 if r, ok = s.Next(); !ok {
521 break
522 }
523 }
524 } else {
525 quote:
526 for {
527 if r, ok = s.Next(); !ok {
528 return fmt.Errorf(`unterminated quoted string literal in connection string`)
529 }
530 switch r {
531 case '\'':
532 break quote
533 case '\\':
534 r, _ = s.Next()
535 fallthrough
536 default:
537 valRunes = append(valRunes, r)
538 }
539 }
540 }
541
542 o[string(keyRunes)] = string(valRunes)
543 }
544
545 return nil
546}
547
548func (cn *conn) isInTransaction() bool {
549 return cn.txnStatus == txnStatusIdleInTransaction ||
550 cn.txnStatus == txnStatusInFailedTransaction
551}
552
553func (cn *conn) checkIsInTransaction(intxn bool) {
554 if cn.isInTransaction() != intxn {
555 cn.err.set(driver.ErrBadConn)
556 errorf("unexpected transaction status %v", cn.txnStatus)
557 }
558}
559
560func (cn *conn) Begin() (_ driver.Tx, err error) {
561 return cn.begin("")
562}
563
564func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
565 if err := cn.err.get(); err != nil {
566 return nil, err
567 }
568 defer cn.errRecover(&err)
569
570 cn.checkIsInTransaction(false)
571 _, commandTag, err := cn.simpleExec("BEGIN" + mode)
572 if err != nil {
573 return nil, err
574 }
575 if commandTag != "BEGIN" {
576 cn.err.set(driver.ErrBadConn)
577 return nil, fmt.Errorf("unexpected command tag %s", commandTag)
578 }
579 if cn.txnStatus != txnStatusIdleInTransaction {
580 cn.err.set(driver.ErrBadConn)
581 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
582 }
583 return cn, nil
584}
585
586func (cn *conn) closeTxn() {
587 if finish := cn.txnFinish; finish != nil {
588 finish()
589 }
590}
591
592func (cn *conn) Commit() (err error) {
593 defer cn.closeTxn()
594 if err := cn.err.get(); err != nil {
595 return err
596 }
597 defer cn.errRecover(&err)
598
599 cn.checkIsInTransaction(true)
600 // We don't want the client to think that everything is okay if it tries
601 // to commit a failed transaction. However, no matter what we return,
602 // database/sql will release this connection back into the free connection
603 // pool so we have to abort the current transaction here. Note that you
604 // would get the same behaviour if you issued a COMMIT in a failed
605 // transaction, so it's also the least surprising thing to do here.
606 if cn.txnStatus == txnStatusInFailedTransaction {
607 if err := cn.rollback(); err != nil {
608 return err
609 }
610 return ErrInFailedTransaction
611 }
612
613 _, commandTag, err := cn.simpleExec("COMMIT")
614 if err != nil {
615 if cn.isInTransaction() {
616 cn.err.set(driver.ErrBadConn)
617 }
618 return err
619 }
620 if commandTag != "COMMIT" {
621 cn.err.set(driver.ErrBadConn)
622 return fmt.Errorf("unexpected command tag %s", commandTag)
623 }
624 cn.checkIsInTransaction(false)
625 return nil
626}
627
628func (cn *conn) Rollback() (err error) {
629 defer cn.closeTxn()
630 if err := cn.err.get(); err != nil {
631 return err
632 }
633 defer cn.errRecover(&err)
634 return cn.rollback()
635}
636
637func (cn *conn) rollback() (err error) {
638 cn.checkIsInTransaction(true)
639 _, commandTag, err := cn.simpleExec("ROLLBACK")
640 if err != nil {
641 if cn.isInTransaction() {
642 cn.err.set(driver.ErrBadConn)
643 }
644 return err
645 }
646 if commandTag != "ROLLBACK" {
647 return fmt.Errorf("unexpected command tag %s", commandTag)
648 }
649 cn.checkIsInTransaction(false)
650 return nil
651}
652
653func (cn *conn) gname() string {
654 cn.namei++
655 return strconv.FormatInt(int64(cn.namei), 10)
656}
657
658func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
659 b := cn.writeBuf('Q')
660 b.string(q)
661 cn.send(b)
662
663 for {
664 t, r := cn.recv1()
665 switch t {
666 case 'C':
667 res, commandTag = cn.parseComplete(r.string())
668 case 'Z':
669 cn.processReadyForQuery(r)
670 if res == nil && err == nil {
671 err = errUnexpectedReady
672 }
673 // done
674 return
675 case 'E':
676 err = parseError(r)
677 case 'I':
678 res = emptyRows
679 case 'T', 'D':
680 // ignore any results
681 default:
682 cn.err.set(driver.ErrBadConn)
683 errorf("unknown response for simple query: %q", t)
684 }
685 }
686}
687
// simpleQuery sends q as a 'Q' message and reads responses until it can hand
// control to the caller. It returns as soon as the first DataRow arrives
// (the row is stashed with saveMessage), when a CommandComplete follows a
// row description, or when ReadyForQuery is received. An ErrorResponse
// clears res and is returned as err once ReadyForQuery arrives.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if err != nil {
				// A completion message after an error is a protocol violation.
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' {
				res.result, res.tag = cn.parseComplete(r.string())
				if res.colNames != nil {
					return
				}
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			res = nil
			err = parseError(r)
		case 'D':
			// A DataRow is only valid after a row description created res.
			if res == nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)

			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unknown response for simple query: %q", t)
		}
	}
}
751
752type noRows struct{}
753
754var emptyRows noRows
755
756var _ driver.Result = noRows{}
757
758func (noRows) LastInsertId() (int64, error) {
759 return 0, errNoLastInsertID
760}
761
762func (noRows) RowsAffected() (int64, error) {
763 return 0, errNoRowsAffected
764}
765
766// Decides which column formats to use for a prepared statement. The input is
767// an array of type oids, one element per result column.
768func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
769 if len(colTyps) == 0 {
770 return nil, colFmtDataAllText
771 }
772
773 colFmts = make([]format, len(colTyps))
774 if forceText {
775 return colFmts, colFmtDataAllText
776 }
777
778 allBinary := true
779 allText := true
780 for i, t := range colTyps {
781 switch t.OID {
782 // This is the list of types to use binary mode for when receiving them
783 // through a prepared statement. If a type appears in this list, it
784 // must also be implemented in binaryDecode in encode.go.
785 case oid.T_bytea:
786 fallthrough
787 case oid.T_int8:
788 fallthrough
789 case oid.T_int4:
790 fallthrough
791 case oid.T_int2:
792 fallthrough
793 case oid.T_uuid:
794 colFmts[i] = formatBinary
795 allText = false
796
797 default:
798 allBinary = false
799 }
800 }
801
802 if allBinary {
803 return colFmts, colFmtDataAllBinary
804 } else if allText {
805 return colFmts, colFmtDataAllText
806 } else {
807 colFmtData = make([]byte, 2+len(colFmts)*2)
808 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
809 for i, v := range colFmts {
810 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
811 }
812 return colFmts, colFmtData
813 }
814}
815
816func (cn *conn) prepareTo(q, stmtName string) *stmt {
817 st := &stmt{cn: cn, name: stmtName}
818
819 b := cn.writeBuf('P')
820 b.string(st.name)
821 b.string(q)
822 b.int16(0)
823
824 b.next('D')
825 b.byte('S')
826 b.string(st.name)
827
828 b.next('S')
829 cn.send(b)
830
831 cn.readParseResponse()
832 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
833 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
834 cn.readReadyForQuery()
835 return st
836}
837
838func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
839 if err := cn.err.get(); err != nil {
840 return nil, err
841 }
842 defer cn.errRecover(&err)
843
844 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
845 s, err := cn.prepareCopyIn(q)
846 if err == nil {
847 cn.inCopy = true
848 }
849 return s, err
850 }
851 return cn.prepareTo(q, cn.gname()), nil
852}
853
854func (cn *conn) Close() (err error) {
855 // Skip cn.bad return here because we always want to close a connection.
856 defer cn.errRecover(&err)
857
858 // Ensure that cn.c.Close is always run. Since error handling is done with
859 // panics and cn.errRecover, the Close must be in a defer.
860 defer func() {
861 cerr := cn.c.Close()
862 if err == nil {
863 err = cerr
864 }
865 }()
866
867 // Don't go through send(); ListenerConn relies on us not scribbling on the
868 // scratch buffer of this connection.
869 return cn.sendSimpleMessage('X')
870}
871
872// Implement the "Queryer" interface
873func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
874 return cn.query(query, args)
875}
876
877func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
878 if err := cn.err.get(); err != nil {
879 return nil, err
880 }
881 if cn.inCopy {
882 return nil, errCopyInProgress
883 }
884 defer cn.errRecover(&err)
885
886 // Check to see if we can use the "simpleQuery" interface, which is
887 // *much* faster than going through prepare/exec
888 if len(args) == 0 {
889 return cn.simpleQuery(query)
890 }
891
892 if cn.binaryParameters {
893 cn.sendBinaryModeQuery(query, args)
894
895 cn.readParseResponse()
896 cn.readBindResponse()
897 rows := &rows{cn: cn}
898 rows.rowsHeader = cn.readPortalDescribeResponse()
899 cn.postExecuteWorkaround()
900 return rows, nil
901 }
902 st := cn.prepareTo(query, "")
903 st.exec(args)
904 return &rows{
905 cn: cn,
906 rowsHeader: st.rowsHeader,
907 }, nil
908}
909
910// Implement the optional "Execer" interface for one-shot queries
911func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
912 if err := cn.err.get(); err != nil {
913 return nil, err
914 }
915 defer cn.errRecover(&err)
916
917 // Check to see if we can use the "simpleExec" interface, which is
918 // *much* faster than going through prepare/exec
919 if len(args) == 0 {
920 // ignore commandTag, our caller doesn't care
921 r, _, err := cn.simpleExec(query)
922 return r, err
923 }
924
925 if cn.binaryParameters {
926 cn.sendBinaryModeQuery(query, args)
927
928 cn.readParseResponse()
929 cn.readBindResponse()
930 cn.readPortalDescribeResponse()
931 cn.postExecuteWorkaround()
932 res, _, err = cn.readExecuteResponse("Execute")
933 return res, err
934 }
935 // Use the unnamed statement to defer planning until bind
936 // time, or else value-based selectivity estimates cannot be
937 // used.
938 st := cn.prepareTo(query, "")
939 r, err := st.Exec(args)
940 if err != nil {
941 panic(err)
942 }
943 return r, err
944}
945
// safeRetryError wraps a write error for which no bytes were written (see
// send).
type safeRetryError struct {
	Err error
}

// Error returns the message of the wrapped error.
func (se *safeRetryError) Error() string {
	inner := se.Err
	return inner.Error()
}
953
954func (cn *conn) send(m *writeBuf) {
955 n, err := cn.c.Write(m.wrap())
956 if err != nil {
957 if n == 0 {
958 err = &safeRetryError{Err: err}
959 }
960 panic(err)
961 }
962}
963
964func (cn *conn) sendStartupPacket(m *writeBuf) error {
965 _, err := cn.c.Write((m.wrap())[1:])
966 return err
967}
968
969// Send a message of type typ to the server on the other end of cn. The
970// message should have no payload. This method does not use the scratch
971// buffer.
972func (cn *conn) sendSimpleMessage(typ byte) (err error) {
973 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
974 return err
975}
976
977// saveMessage memorizes a message and its buffer in the conn struct.
978// recvMessage will then return these values on the next call to it. This
979// method is useful in cases where you have to see what the next message is
980// going to be (e.g. to see whether it's an error or not) but you can't handle
981// the message yourself.
982func (cn *conn) saveMessage(typ byte, buf *readBuf) {
983 if cn.saveMessageType != 0 {
984 cn.err.set(driver.ErrBadConn)
985 errorf("unexpected saveMessageType %d", cn.saveMessageType)
986 }
987 cn.saveMessageType = typ
988 cn.saveMessageBuffer = *buf
989}
990
991// recvMessage receives any message from the backend, or returns an error if
992// a problem occurred while reading the message.
993func (cn *conn) recvMessage(r *readBuf) (byte, error) {
994 // workaround for a QueryRow bug, see exec
995 if cn.saveMessageType != 0 {
996 t := cn.saveMessageType
997 *r = cn.saveMessageBuffer
998 cn.saveMessageType = 0
999 cn.saveMessageBuffer = nil
1000 return t, nil
1001 }
1002
1003 x := cn.scratch[:5]
1004 _, err := io.ReadFull(cn.buf, x)
1005 if err != nil {
1006 return 0, err
1007 }
1008
1009 // read the type and length of the message that follows
1010 t := x[0]
1011 n := int(binary.BigEndian.Uint32(x[1:])) - 4
1012 var y []byte
1013 if n <= len(cn.scratch) {
1014 y = cn.scratch[:n]
1015 } else {
1016 y = make([]byte, n)
1017 }
1018 _, err = io.ReadFull(cn.buf, y)
1019 if err != nil {
1020 return 0, err
1021 }
1022 *r = y
1023 return t, nil
1024}
1025
1026// recv receives a message from the backend, but if an error happened while
1027// reading the message or the received message was an ErrorResponse, it panics.
1028// NoticeResponses are ignored. This function should generally be used only
1029// during the startup sequence.
1030func (cn *conn) recv() (t byte, r *readBuf) {
1031 for {
1032 var err error
1033 r = &readBuf{}
1034 t, err = cn.recvMessage(r)
1035 if err != nil {
1036 panic(err)
1037 }
1038 switch t {
1039 case 'E':
1040 panic(parseError(r))
1041 case 'N':
1042 if n := cn.noticeHandler; n != nil {
1043 n(parseError(r))
1044 }
1045 case 'A':
1046 if n := cn.notificationHandler; n != nil {
1047 n(recvNotification(r))
1048 }
1049 default:
1050 return
1051 }
1052 }
1053}
1054
1055// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
1056// the caller to avoid an allocation.
1057func (cn *conn) recv1Buf(r *readBuf) byte {
1058 for {
1059 t, err := cn.recvMessage(r)
1060 if err != nil {
1061 panic(err)
1062 }
1063
1064 switch t {
1065 case 'A':
1066 if n := cn.notificationHandler; n != nil {
1067 n(recvNotification(r))
1068 }
1069 case 'N':
1070 if n := cn.noticeHandler; n != nil {
1071 n(parseError(r))
1072 }
1073 case 'S':
1074 cn.processParameterStatus(r)
1075 default:
1076 return t
1077 }
1078 }
1079}
1080
1081// recv1 receives a message from the backend, panicking if an error occurs
1082// while attempting to read it. All asynchronous messages are ignored, with
1083// the exception of ErrorResponse.
1084func (cn *conn) recv1() (t byte, r *readBuf) {
1085 r = &readBuf{}
1086 t = cn.recv1Buf(r)
1087 return t, r
1088}
1089
1090func (cn *conn) ssl(o values) error {
1091 upgrade, err := ssl(o)
1092 if err != nil {
1093 return err
1094 }
1095
1096 if upgrade == nil {
1097 // Nothing to do
1098 return nil
1099 }
1100
1101 w := cn.writeBuf(0)
1102 w.int32(80877103)
1103 if err = cn.sendStartupPacket(w); err != nil {
1104 return err
1105 }
1106
1107 b := cn.scratch[:1]
1108 _, err = io.ReadFull(cn.c, b)
1109 if err != nil {
1110 return err
1111 }
1112
1113 if b[0] != 'S' {
1114 return ErrSSLNotSupported
1115 }
1116
1117 cn.c, err = upgrade(cn.c)
1118 return err
1119}
1120
// isDriverSetting returns true iff a setting is purely for configuring the
// driver's options and should not be sent to the server in the connection
// startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port",
		"password",
		"sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni",
		"fallback_application_name",
		"connect_timeout",
		"disable_prepared_binary_result",
		"binary_parameters",
		"krbsrvname",
		"krbspn":
		return true
	}
	return false
}
1148
1149func (cn *conn) startup(o values) {
1150 w := cn.writeBuf(0)
1151 w.int32(196608)
1152 // Send the backend the name of the database we want to connect to, and the
1153 // user we want to connect as. Additionally, we send over any run-time
1154 // parameters potentially included in the connection string. If the server
1155 // doesn't recognize any of them, it will reply with an error.
1156 for k, v := range o {
1157 if isDriverSetting(k) {
1158 // skip options which can't be run-time parameters
1159 continue
1160 }
1161 // The protocol requires us to supply the database name as "database"
1162 // instead of "dbname".
1163 if k == "dbname" {
1164 k = "database"
1165 }
1166 w.string(k)
1167 w.string(v)
1168 }
1169 w.string("")
1170 if err := cn.sendStartupPacket(w); err != nil {
1171 panic(err)
1172 }
1173
1174 for {
1175 t, r := cn.recv()
1176 switch t {
1177 case 'K':
1178 cn.processBackendKeyData(r)
1179 case 'S':
1180 cn.processParameterStatus(r)
1181 case 'R':
1182 cn.auth(r, o)
1183 case 'Z':
1184 cn.processReadyForQuery(r)
1185 return
1186 default:
1187 errorf("unknown response for startup: %q", t)
1188 }
1189 }
1190}
1191
// auth handles a single authentication request message received during
// startup. It reads the int32 request code from the front of r and sends the
// response that code calls for. Failures are reported through errorf.
//
// o holds the connection options; the keys consulted here are "user",
// "password", "host", "krbspn" and "krbsrvname".
func (cn *conn) auth(r *readBuf, o values) {
	switch code := r.int32(); code {
	case 0:
		// OK
	case 3:
		// Password requested as-is: send it unmodified and expect an 'R'
		// message carrying code 0 in return.
		w := cn.writeBuf('p')
		w.string(o["password"])
		cn.send(w)

		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		// NOTE(review): the message below prints the message type t, not
		// the unexpected code that was read.
		if r.int32() != 0 {
			errorf("unexpected authentication response: %q", t)
		}
	case 5:
		// MD5 password: the next 4 bytes of the request are a salt. The
		// response is "md5" + md5(md5(password + user) + salt), hex encoded.
		s := string(r.next(4))
		w := cn.writeBuf('p')
		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
		cn.send(w)

		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 0 {
			errorf("unexpected authentication response: %q", t)
		}
	case 7: // GSSAPI, startup
		if newGss == nil {
			errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
		}
		cli, err := newGss()
		if err != nil {
			errorf("kerberos error: %s", err.Error())
		}

		var token []byte

		if spn, ok := o["krbspn"]; ok {
			// Use the supplied SPN if provided.
			token, err = cli.GetInitTokenFromSpn(spn)
		} else {
			// Allow the kerberos service name to be overridden
			service := "postgres"
			if val, ok := o["krbsrvname"]; ok {
				service = val
			}

			token, err = cli.GetInitToken(o["host"], service)
		}

		if err != nil {
			errorf("failed to get Kerberos ticket: %q", err)
		}

		w := cn.writeBuf('p')
		w.bytes(token)
		cn.send(w)

		// Store for GSSAPI continue message
		cn.gss = cli

	case 8: // GSSAPI continue

		// A continue request is only valid after a code 7 request has
		// stored the GSS client on the connection.
		if cn.gss == nil {
			errorf("GSSAPI protocol error")
		}

		// The remainder of the message is the server's token.
		b := []byte(*r)

		done, tokOut, err := cn.gss.Continue(b)
		if err == nil && !done {
			w := cn.writeBuf('p')
			w.bytes(tokOut)
			cn.send(w)
		}

		// Errors fall through and read the more detailed message
		// from the server.

	case 10:
		// SCRAM-SHA-256: two round trips. The first response names the
		// mechanism and carries the length-prefixed initial client message.
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
		sc.Step(nil)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}
		scOut := sc.Out()

		w := cn.writeBuf('p')
		w.string("SCRAM-SHA-256")
		w.int32(len(scOut))
		w.bytes(scOut)
		cn.send(w)

		// The server must answer with an 'R' message carrying code 11; the
		// rest of that message is fed to the SCRAM client.
		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 11 {
			errorf("unexpected authentication response: %q", t)
		}

		nextStep := r.next(len(*r))
		sc.Step(nextStep)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}

		// Second response: the client's next message, sent without a
		// mechanism name or length prefix.
		scOut = sc.Out()
		w = cn.writeBuf('p')
		w.bytes(scOut)
		cn.send(w)

		// The server must answer with an 'R' message carrying code 12; its
		// payload is passed to the SCRAM client as the final step.
		t, r = cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 12 {
			errorf("unexpected authentication response: %q", t)
		}

		nextStep = r.next(len(*r))
		sc.Step(nextStep)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}

	default:
		errorf("unknown authentication response: %d", code)
	}
}
1329
// format is a column format code.
type format int

// Format codes: 0 selects text, 1 selects binary.
const (
	formatText   format = 0
	formatBinary format = 1
)

var (
	// One result-column format code with the value 1 (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}

	// No result-column format codes (i.e. all text).
	colFmtDataAllText = []byte{0, 0}
)
1340
// stmt is a named statement tied to one connection.
type stmt struct {
	// cn is the connection all of the statement's messages go through.
	cn *conn
	// name is the statement name written into the Close and Bind messages.
	name string
	// rowsHeader is copied into the rows value returned by query.
	rowsHeader
	// colFmtData is appended verbatim to the Bind message after the
	// parameter values.
	colFmtData []byte
	// paramTyps holds the parameter type OIDs; its length is the number of
	// arguments the statement requires.
	paramTyps []oid.Oid
	// closed is set once Close has succeeded, making later calls no-ops.
	closed bool
}
1349
1350func (st *stmt) Close() (err error) {
1351 if st.closed {
1352 return nil
1353 }
1354 if err := st.cn.err.get(); err != nil {
1355 return err
1356 }
1357 defer st.cn.errRecover(&err)
1358
1359 w := st.cn.writeBuf('C')
1360 w.byte('S')
1361 w.string(st.name)
1362 st.cn.send(w)
1363
1364 st.cn.send(st.cn.writeBuf('S'))
1365
1366 t, _ := st.cn.recv1()
1367 if t != '3' {
1368 st.cn.err.set(driver.ErrBadConn)
1369 errorf("unexpected close response: %q", t)
1370 }
1371 st.closed = true
1372
1373 t, r := st.cn.recv1()
1374 if t != 'Z' {
1375 st.cn.err.set(driver.ErrBadConn)
1376 errorf("expected ready for query, but got: %q", t)
1377 }
1378 st.cn.processReadyForQuery(r)
1379
1380 return nil
1381}
1382
1383func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
1384 return st.query(v)
1385}
1386
1387func (st *stmt) query(v []driver.Value) (r *rows, err error) {
1388 if err := st.cn.err.get(); err != nil {
1389 return nil, err
1390 }
1391 defer st.cn.errRecover(&err)
1392
1393 st.exec(v)
1394 return &rows{
1395 cn: st.cn,
1396 rowsHeader: st.rowsHeader,
1397 }, nil
1398}
1399
1400func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
1401 if err := st.cn.err.get(); err != nil {
1402 return nil, err
1403 }
1404 defer st.cn.errRecover(&err)
1405
1406 st.exec(v)
1407 res, _, err = st.cn.readExecuteResponse("simple query")
1408 return res, err
1409}
1410
1411func (st *stmt) exec(v []driver.Value) {
1412 if len(v) >= 65536 {
1413 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
1414 }
1415 if len(v) != len(st.paramTyps) {
1416 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
1417 }
1418
1419 cn := st.cn
1420 w := cn.writeBuf('B')
1421 w.byte(0) // unnamed portal
1422 w.string(st.name)
1423
1424 if cn.binaryParameters {
1425 cn.sendBinaryParameters(w, v)
1426 } else {
1427 w.int16(0)
1428 w.int16(len(v))
1429 for i, x := range v {
1430 if x == nil {
1431 w.int32(-1)
1432 } else {
1433 b := encode(&cn.parameterStatus, x, st.paramTyps[i])
1434 w.int32(len(b))
1435 w.bytes(b)
1436 }
1437 }
1438 }
1439 w.bytes(st.colFmtData)
1440
1441 w.next('E')
1442 w.byte(0)
1443 w.int32(0)
1444
1445 w.next('S')
1446 cn.send(w)
1447
1448 cn.readBindResponse()
1449 cn.postExecuteWorkaround()
1450
1451}
1452
1453func (st *stmt) NumInput() int {
1454 return len(st.paramTyps)
1455}
1456
1457// parseComplete parses the "command tag" from a CommandComplete message, and
1458// returns the number of rows affected (if applicable) and a string
1459// identifying only the command that was executed, e.g. "ALTER TABLE". If the
1460// command tag could not be parsed, parseComplete panics.
1461func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
1462 commandsWithAffectedRows := []string{
1463 "SELECT ",
1464 // INSERT is handled below
1465 "UPDATE ",
1466 "DELETE ",
1467 "FETCH ",
1468 "MOVE ",
1469 "COPY ",
1470 }
1471
1472 var affectedRows *string
1473 for _, tag := range commandsWithAffectedRows {
1474 if strings.HasPrefix(commandTag, tag) {
1475 t := commandTag[len(tag):]
1476 affectedRows = &t
1477 commandTag = tag[:len(tag)-1]
1478 break
1479 }
1480 }
1481 // INSERT also includes the oid of the inserted row in its command tag.
1482 // Oids in user tables are deprecated, and the oid is only returned when
1483 // exactly one row is inserted, so it's unlikely to be of value to any
1484 // real-world application and we can ignore it.
1485 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
1486 parts := strings.Split(commandTag, " ")
1487 if len(parts) != 3 {
1488 cn.err.set(driver.ErrBadConn)
1489 errorf("unexpected INSERT command tag %s", commandTag)
1490 }
1491 affectedRows = &parts[len(parts)-1]
1492 commandTag = "INSERT"
1493 }
1494 // There should be no affected rows attached to the tag, just return it
1495 if affectedRows == nil {
1496 return driver.RowsAffected(0), commandTag
1497 }
1498 n, err := strconv.ParseInt(*affectedRows, 10, 64)
1499 if err != nil {
1500 cn.err.set(driver.ErrBadConn)
1501 errorf("could not parse commandTag: %s", err)
1502 }
1503 return driver.RowsAffected(n), commandTag
1504}
1505
// rowsHeader describes the columns of a result set. The three slices are
// parallel: index i in each refers to the same column.
type rowsHeader struct {
	// colNames holds the column names.
	colNames []string
	// colTyps holds each column's type OID, length and modifier.
	colTyps []fieldDesc
	// colFmts holds each column's format code.
	colFmts []format
}
1511
// rows iterates over the messages the server sends in response to an
// executed query.
type rows struct {
	// cn is the connection the result is read from.
	cn *conn
	// finish, when non-nil, is run (deferred) by Close.
	finish func()
	// rowsHeader describes the columns of the current result set.
	rowsHeader
	// done is set once ReadyForQuery has been received; Next then returns
	// io.EOF without touching the connection.
	done bool
	// rb is the buffer Next receives each message into.
	rb readBuf
	// result and tag are filled in from the last CommandComplete message.
	result driver.Result
	tag    string

	// next holds the header of a following result set once Next has seen
	// its row description; NextResultSet installs it.
	next *rowsHeader
}
1523
1524func (rs *rows) Close() error {
1525 if finish := rs.finish; finish != nil {
1526 defer finish()
1527 }
1528 // no need to look at cn.bad as Next() will
1529 for {
1530 err := rs.Next(nil)
1531 switch err {
1532 case nil:
1533 case io.EOF:
1534 // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
1535 // description, used with HasNextResultSet). We need to fetch messages until
1536 // we hit a 'Z', which is done by waiting for done to be set.
1537 if rs.done {
1538 return nil
1539 }
1540 default:
1541 return err
1542 }
1543 }
1544}
1545
1546func (rs *rows) Columns() []string {
1547 return rs.colNames
1548}
1549
1550func (rs *rows) Result() driver.Result {
1551 if rs.result == nil {
1552 return emptyRows
1553 }
1554 return rs.result
1555}
1556
1557func (rs *rows) Tag() string {
1558 return rs.tag
1559}
1560
// Next reads messages from the connection until it can deliver one data row
// into dest, or until the result set or the whole query ends.
//
// It returns nil after filling dest from a DataRow, io.EOF at the end of a
// result set (either ReadyForQuery or the row description of a following
// result set), or the error reported by the server. Once ReadyForQuery has
// been seen, every later call returns io.EOF immediately.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}

	conn := rs.cn
	if err := conn.err.getForNext(); err != nil {
		return err
	}
	defer conn.errRecover(&err)

	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Remember the server error and keep reading: it is returned
			// once the following ReadyForQuery arrives.
			err = parseError(&rs.rb)
		case 'C', 'I':
			// CommandComplete carries the command tag; 'I' carries nothing
			// to record.
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			// n is the number of column values in this row.
			n := rs.rb.int16()
			if err != nil {
				conn.err.set(driver.ErrBadConn)
				errorf("unexpected DataRow after error %s", err)
			}
			// Never decode more columns than the row contains.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				// Each value is length-prefixed; -1 stands for NULL.
				l := rs.rb.int32()
				if l == -1 {
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			return
		case 'T':
			// Row description of a following result set: stash it for
			// NextResultSet and end the current set.
			next := parsePortalRowDescribe(&rs.rb)
			rs.next = &next
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
1616
1617func (rs *rows) HasNextResultSet() bool {
1618 hasNext := rs.next != nil && !rs.done
1619 return hasNext
1620}
1621
1622func (rs *rows) NextResultSet() error {
1623 if rs.next == nil {
1624 return io.EOF
1625 }
1626 rs.rowsHeader = *rs.next
1627 rs.next = nil
1628 return nil
1629}
1630
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so
// that it can be embedded in an SQL statement. For example:
//
//	tblname := "my_table"
//	data := "my_data"
//	quoted := pq.QuoteIdentifier(tblname)
//	_, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Double quotes inside name are escaped by doubling them. The quoted
// identifier is case sensitive when used in a query. If name contains a zero
// byte, the result is truncated immediately before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
1649
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass a
// value to DDL and other statements that do not accept parameters) so that
// it can be embedded in an SQL statement. For example:
//
//	exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//	_, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Single quotes inside literal are escaped by doubling them. If literal
// contains a backslash, every backslash is doubled and the result is written
// as a C-style escape string: it is prefixed with a space and 'E'.
func QuoteLiteral(literal string) string {
	// This mirrors libpq's "PQEscapeStringInternal" (libpq/fe-exec.c):
	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
	escaped := strings.ReplaceAll(literal, `'`, `''`)
	if !strings.Contains(escaped, `\`) {
		// No backslashes: a plain pair of single quotes is enough.
		return `'` + escaped + `'`
	}
	// Backslashes present: double them and use the escape-string form. As
	// in "PQEscapeStringInternal", a space precedes the "E".
	return ` E'` + strings.ReplaceAll(escaped, `\`, `\\`) + `'`
}
1682
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	return fmt.Sprintf("%x", md5.Sum([]byte(s)))
}
1688
1689func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
1690 // Do one pass over the parameters to see if we're going to send any of
1691 // them over in binary. If we are, create a paramFormats array at the
1692 // same time.
1693 var paramFormats []int
1694 for i, x := range args {
1695 _, ok := x.([]byte)
1696 if ok {
1697 if paramFormats == nil {
1698 paramFormats = make([]int, len(args))
1699 }
1700 paramFormats[i] = 1
1701 }
1702 }
1703 if paramFormats == nil {
1704 b.int16(0)
1705 } else {
1706 b.int16(len(paramFormats))
1707 for _, x := range paramFormats {
1708 b.int16(x)
1709 }
1710 }
1711
1712 b.int16(len(args))
1713 for _, x := range args {
1714 if x == nil {
1715 b.int32(-1)
1716 } else {
1717 datum := binaryEncode(&cn.parameterStatus, x)
1718 b.int32(len(datum))
1719 b.bytes(datum)
1720 }
1721 }
1722}
1723
1724func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
1725 if len(args) >= 65536 {
1726 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
1727 }
1728
1729 b := cn.writeBuf('P')
1730 b.byte(0) // unnamed statement
1731 b.string(query)
1732 b.int16(0)
1733
1734 b.next('B')
1735 b.int16(0) // unnamed portal and statement
1736 cn.sendBinaryParameters(b, args)
1737 b.bytes(colFmtDataAllText)
1738
1739 b.next('D')
1740 b.byte('P')
1741 b.byte(0) // unnamed portal
1742
1743 b.next('E')
1744 b.byte(0)
1745 b.int32(0)
1746
1747 b.next('S')
1748 cn.send(b)
1749}
1750
1751func (cn *conn) processParameterStatus(r *readBuf) {
1752 var err error
1753
1754 param := r.string()
1755 switch param {
1756 case "server_version":
1757 var major1 int
1758 var major2 int
1759 _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2)
1760 if err == nil {
1761 cn.parameterStatus.serverVersion = major1*10000 + major2*100
1762 }
1763
1764 case "TimeZone":
1765 cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
1766 if err != nil {
1767 cn.parameterStatus.currentLocation = nil
1768 }
1769
1770 default:
1771 // ignore
1772 }
1773}
1774
1775func (cn *conn) processReadyForQuery(r *readBuf) {
1776 cn.txnStatus = transactionStatus(r.byte())
1777}
1778
1779func (cn *conn) readReadyForQuery() {
1780 t, r := cn.recv1()
1781 switch t {
1782 case 'Z':
1783 cn.processReadyForQuery(r)
1784 return
1785 default:
1786 cn.err.set(driver.ErrBadConn)
1787 errorf("unexpected message %q; expected ReadyForQuery", t)
1788 }
1789}
1790
1791func (cn *conn) processBackendKeyData(r *readBuf) {
1792 cn.processID = r.int32()
1793 cn.secretKey = r.int32()
1794}
1795
1796func (cn *conn) readParseResponse() {
1797 t, r := cn.recv1()
1798 switch t {
1799 case '1':
1800 return
1801 case 'E':
1802 err := parseError(r)
1803 cn.readReadyForQuery()
1804 panic(err)
1805 default:
1806 cn.err.set(driver.ErrBadConn)
1807 errorf("unexpected Parse response %q", t)
1808 }
1809}
1810
1811func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
1812 for {
1813 t, r := cn.recv1()
1814 switch t {
1815 case 't':
1816 nparams := r.int16()
1817 paramTyps = make([]oid.Oid, nparams)
1818 for i := range paramTyps {
1819 paramTyps[i] = r.oid()
1820 }
1821 case 'n':
1822 return paramTyps, nil, nil
1823 case 'T':
1824 colNames, colTyps = parseStatementRowDescribe(r)
1825 return paramTyps, colNames, colTyps
1826 case 'E':
1827 err := parseError(r)
1828 cn.readReadyForQuery()
1829 panic(err)
1830 default:
1831 cn.err.set(driver.ErrBadConn)
1832 errorf("unexpected Describe statement response %q", t)
1833 }
1834 }
1835}
1836
1837func (cn *conn) readPortalDescribeResponse() rowsHeader {
1838 t, r := cn.recv1()
1839 switch t {
1840 case 'T':
1841 return parsePortalRowDescribe(r)
1842 case 'n':
1843 return rowsHeader{}
1844 case 'E':
1845 err := parseError(r)
1846 cn.readReadyForQuery()
1847 panic(err)
1848 default:
1849 cn.err.set(driver.ErrBadConn)
1850 errorf("unexpected Describe response %q", t)
1851 }
1852 panic("not reached")
1853}
1854
1855func (cn *conn) readBindResponse() {
1856 t, r := cn.recv1()
1857 switch t {
1858 case '2':
1859 return
1860 case 'E':
1861 err := parseError(r)
1862 cn.readReadyForQuery()
1863 panic(err)
1864 default:
1865 cn.err.set(driver.ErrBadConn)
1866 errorf("unexpected Bind response %q", t)
1867 }
1868}
1869
1870func (cn *conn) postExecuteWorkaround() {
1871 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
1872 // any errors from rows.Next, which masks errors that happened during the
1873 // execution of the query. To avoid the problem in common cases, we wait
1874 // here for one more message from the database. If it's not an error the
1875 // query will likely succeed (or perhaps has already, if it's a
1876 // CommandComplete), so we push the message into the conn struct; recv1
1877 // will return it as the next message for rows.Next or rows.Close.
1878 // However, if it's an error, we wait until ReadyForQuery and then return
1879 // the error to our caller.
1880 for {
1881 t, r := cn.recv1()
1882 switch t {
1883 case 'E':
1884 err := parseError(r)
1885 cn.readReadyForQuery()
1886 panic(err)
1887 case 'C', 'D', 'I':
1888 // the query didn't fail, but we can't process this message
1889 cn.saveMessage(t, r)
1890 return
1891 default:
1892 cn.err.set(driver.ErrBadConn)
1893 errorf("unexpected message during extended query execution: %q", t)
1894 }
1895 }
1896}
1897
1898// Only for Exec(), since we ignore the returned data
1899func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
1900 for {
1901 t, r := cn.recv1()
1902 switch t {
1903 case 'C':
1904 if err != nil {
1905 cn.err.set(driver.ErrBadConn)
1906 errorf("unexpected CommandComplete after error %s", err)
1907 }
1908 res, commandTag = cn.parseComplete(r.string())
1909 case 'Z':
1910 cn.processReadyForQuery(r)
1911 if res == nil && err == nil {
1912 err = errUnexpectedReady
1913 }
1914 return res, commandTag, err
1915 case 'E':
1916 err = parseError(r)
1917 case 'T', 'D', 'I':
1918 if err != nil {
1919 cn.err.set(driver.ErrBadConn)
1920 errorf("unexpected %q after error %s", t, err)
1921 }
1922 if t == 'I' {
1923 res = emptyRows
1924 }
1925 // ignore any results
1926 default:
1927 cn.err.set(driver.ErrBadConn)
1928 errorf("unknown %s response: %q", protocolState, t)
1929 }
1930 }
1931}
1932
1933func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
1934 n := r.int16()
1935 colNames = make([]string, n)
1936 colTyps = make([]fieldDesc, n)
1937 for i := range colNames {
1938 colNames[i] = r.string()
1939 r.next(6)
1940 colTyps[i].OID = r.oid()
1941 colTyps[i].Len = r.int16()
1942 colTyps[i].Mod = r.int32()
1943 // format code not known when describing a statement; always 0
1944 r.next(2)
1945 }
1946 return
1947}
1948
1949func parsePortalRowDescribe(r *readBuf) rowsHeader {
1950 n := r.int16()
1951 colNames := make([]string, n)
1952 colFmts := make([]format, n)
1953 colTyps := make([]fieldDesc, n)
1954 for i := range colNames {
1955 colNames[i] = r.string()
1956 r.next(6)
1957 colTyps[i].OID = r.oid()
1958 colTyps[i].Len = r.int16()
1959 colTyps[i].Mod = r.int32()
1960 colFmts[i] = format(r.int16())
1961 }
1962 return rowsHeader{
1963 colNames: colNames,
1964 colFmts: colFmts,
1965 colTyps: colTyps,
1966 }
1967}
1968
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output: a slice of "NAME=value" strings. The
// result maps connection option names (e.g. "host", "dbname") to the values
// of the corresponding PG* variables; unrelated variables are ignored.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// Entries without an '=' separator carry no value and are skipped.
// parseEnviron panics when a well-defined but unsupported variable (such as
// PGHOSTADDR or PGSERVICE) is set.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	for _, v := range env {
		parts := strings.SplitN(v, "=", 2)
		if len(parts) != 2 {
			// Malformed entry with no '=': there is no value to read, and
			// indexing parts[1] below would be out of range.
			continue
		}

		accrue := func(keyname string) {
			out[keyname] = parts[1]
		}
		unsupported := func() {
			panic(fmt.Sprintf("setting %v not supported", parts[0]))
		}

		// The order of these is the same as is seen in the
		// PostgreSQL 9.1 manual. Unsupported but well-defined
		// keys cause a panic; these should be unset prior to
		// execution. Options which pq expects to be set to a
		// certain value are allowed, but must be set to that
		// value if present (they can, of course, be absent).
		switch parts[0] {
		case "PGHOST":
			accrue("host")
		case "PGHOSTADDR":
			unsupported()
		case "PGPORT":
			accrue("port")
		case "PGDATABASE":
			accrue("dbname")
		case "PGUSER":
			accrue("user")
		case "PGPASSWORD":
			accrue("password")
		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
			unsupported()
		case "PGOPTIONS":
			accrue("options")
		case "PGAPPNAME":
			accrue("application_name")
		case "PGSSLMODE":
			accrue("sslmode")
		case "PGSSLCERT":
			accrue("sslcert")
		case "PGSSLKEY":
			accrue("sslkey")
		case "PGSSLROOTCERT":
			accrue("sslrootcert")
		case "PGSSLSNI":
			accrue("sslsni")
		case "PGREQUIRESSL", "PGSSLCRL":
			unsupported()
		case "PGREQUIREPEER":
			unsupported()
		case "PGKRBSRVNAME", "PGGSSLIB":
			unsupported()
		case "PGCONNECT_TIMEOUT":
			accrue("connect_timeout")
		case "PGCLIENTENCODING":
			accrue("client_encoding")
		case "PGDATESTYLE":
			accrue("datestyle")
		case "PGTZ":
			accrue("timezone")
		case "PGGEQO":
			accrue("geqo")
		case "PGSYSCONFDIR", "PGLOCALEDIR":
			unsupported()
		}
	}

	return out
}
2048
2049// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
2050func isUTF8(name string) bool {
2051 // Recognize all sorts of silly things as "UTF-8", like Postgres does
2052 s := strings.Map(alnumLowerASCII, name)
2053 return s == "utf8" || s == "unicode"
2054}
2055
// alnumLowerASCII is a strings.Map mapping function: ASCII uppercase letters
// are lowercased, ASCII lowercase letters and digits pass through, and every
// other rune is dropped (mapped to -1).
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1 // discard
	}
}
Note: See TracBrowser for help on using the repository browser.