source: code/trunk/vendor/github.com/lib/pq/encode.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 16.4 KB
Line 
1package pq
2
3import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/binary"
7 "encoding/hex"
8 "errors"
9 "fmt"
10 "math"
11 "regexp"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/lib/pq/oid"
18)
19
// time2400Regex recognizes a "24:00" time of day: "24:00", "24:00:00", or
// "24:00:00" followed by an all-zero fraction, optionally followed by a zone
// suffix that starts with 'Z', '+' or '-'. Submatch 1 captures the
// time-of-day portion so callers can splice a replacement in its place.
var (
	time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)
)
21
22func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
23 switch v := x.(type) {
24 case []byte:
25 return v
26 default:
27 return encode(parameterStatus, x, oid.T_unknown)
28 }
29}
30
31func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
32 switch v := x.(type) {
33 case int64:
34 return strconv.AppendInt(nil, v, 10)
35 case float64:
36 return strconv.AppendFloat(nil, v, 'f', -1, 64)
37 case []byte:
38 if pgtypOid == oid.T_bytea {
39 return encodeBytea(parameterStatus.serverVersion, v)
40 }
41
42 return v
43 case string:
44 if pgtypOid == oid.T_bytea {
45 return encodeBytea(parameterStatus.serverVersion, []byte(v))
46 }
47
48 return []byte(v)
49 case bool:
50 return strconv.AppendBool(nil, v)
51 case time.Time:
52 return formatTs(v)
53
54 default:
55 errorf("encode: unknown type for %T", v)
56 }
57
58 panic("not reached")
59}
60
61func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
62 switch f {
63 case formatBinary:
64 return binaryDecode(parameterStatus, s, typ)
65 case formatText:
66 return textDecode(parameterStatus, s, typ)
67 default:
68 panic("not reached")
69 }
70}
71
72func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
73 switch typ {
74 case oid.T_bytea:
75 return s
76 case oid.T_int8:
77 return int64(binary.BigEndian.Uint64(s))
78 case oid.T_int4:
79 return int64(int32(binary.BigEndian.Uint32(s)))
80 case oid.T_int2:
81 return int64(int16(binary.BigEndian.Uint16(s)))
82 case oid.T_uuid:
83 b, err := decodeUUIDBinary(s)
84 if err != nil {
85 panic(err)
86 }
87 return b
88
89 default:
90 errorf("don't know how to decode binary parameter of type %d", uint32(typ))
91 }
92
93 panic("not reached")
94}
95
96func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
97 switch typ {
98 case oid.T_char, oid.T_varchar, oid.T_text:
99 return string(s)
100 case oid.T_bytea:
101 b, err := parseBytea(s)
102 if err != nil {
103 errorf("%s", err)
104 }
105 return b
106 case oid.T_timestamptz:
107 return parseTs(parameterStatus.currentLocation, string(s))
108 case oid.T_timestamp, oid.T_date:
109 return parseTs(nil, string(s))
110 case oid.T_time:
111 return mustParse("15:04:05", typ, s)
112 case oid.T_timetz:
113 return mustParse("15:04:05-07", typ, s)
114 case oid.T_bool:
115 return s[0] == 't'
116 case oid.T_int8, oid.T_int4, oid.T_int2:
117 i, err := strconv.ParseInt(string(s), 10, 64)
118 if err != nil {
119 errorf("%s", err)
120 }
121 return i
122 case oid.T_float4, oid.T_float8:
123 // We always use 64 bit parsing, regardless of whether the input text is for
124 // a float4 or float8, because clients expect float64s for all float datatypes
125 // and returning a 32-bit parsed float64 produces lossy results.
126 f, err := strconv.ParseFloat(string(s), 64)
127 if err != nil {
128 errorf("%s", err)
129 }
130 return f
131 }
132
133 return s
134}
135
136// appendEncodedText encodes item in text format as required by COPY
137// and appends to buf
138func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
139 switch v := x.(type) {
140 case int64:
141 return strconv.AppendInt(buf, v, 10)
142 case float64:
143 return strconv.AppendFloat(buf, v, 'f', -1, 64)
144 case []byte:
145 encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
146 return appendEscapedText(buf, string(encodedBytea))
147 case string:
148 return appendEscapedText(buf, v)
149 case bool:
150 return strconv.AppendBool(buf, v)
151 case time.Time:
152 return append(buf, formatTs(v)...)
153 case nil:
154 return append(buf, "\\N"...)
155 default:
156 errorf("encode: unknown type for %T", v)
157 }
158
159 panic("not reached")
160}
161
// appendEscapedText appends text to buf, escaping the four bytes that are
// special in COPY's text format: backslash, newline, carriage return and tab
// become \\, \n, \r and \t respectively. Every other byte is copied as-is.
func appendEscapedText(buf []byte, text string) []byte {
	// runStart marks the beginning of the current stretch of bytes that need
	// no escaping; such stretches are copied in one append.
	runStart := 0
	for i := 0; i < len(text); i++ {
		var esc byte
		switch text[i] {
		case '\\':
			esc = '\\'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\t':
			esc = 't'
		default:
			continue
		}
		buf = append(buf, text[runStart:i]...)
		buf = append(buf, '\\', esc)
		runStart = i + 1
	}
	return append(buf, text[runStart:]...)
}
199
200func mustParse(f string, typ oid.Oid, s []byte) time.Time {
201 str := string(s)
202
203 // Check for a minute and second offset in the timezone.
204 if typ == oid.T_timestamptz || typ == oid.T_timetz {
205 for i := 3; i <= 6; i += 3 {
206 if str[len(str)-i] == ':' {
207 f += ":00"
208 continue
209 }
210 break
211 }
212 }
213
214 // Special case for 24:00 time.
215 // Unfortunately, golang does not parse 24:00 as a proper time.
216 // In this case, we want to try "round to the next day", to differentiate.
217 // As such, we find if the 24:00 time matches at the beginning; if so,
218 // we default it back to 00:00 but add a day later.
219 var is2400Time bool
220 switch typ {
221 case oid.T_timetz, oid.T_time:
222 if matches := time2400Regex.FindStringSubmatch(str); matches != nil {
223 // Concatenate timezone information at the back.
224 str = "00:00:00" + str[len(matches[1]):]
225 is2400Time = true
226 }
227 }
228 t, err := time.Parse(f, str)
229 if err != nil {
230 errorf("decode: %s", err)
231 }
232 if is2400Time {
233 t = t.Add(24 * time.Hour)
234 }
235 return t
236}
237
// errInvalidTimestamp is recorded when a timestamp is too short for the
// position or range being inspected.
var errInvalidTimestamp = errors.New("invalid timestamp")

// timestampParser keeps the first error seen while scanning a timestamp
// string. Once err is set, every further expect/mustAtoi call is a no-op.
// This lets a caller chain many calls and check err once at the end.
type timestampParser struct {
	err error
}

// expect records an error unless str[pos] == char. An out-of-range pos
// records errInvalidTimestamp.
func (p *timestampParser) expect(str string, char byte, pos int) {
	if p.err != nil {
		return
	}
	if pos < 0 || pos >= len(str) {
		p.err = errInvalidTimestamp
		return
	}
	if c := str[pos]; c != char {
		// %c prints the bytes as characters; %v would print their numeric
		// values (e.g. '45' instead of '-').
		p.err = fmt.Errorf("expected '%c' at position %v; got '%c'", char, pos, c)
	}
}

// mustAtoi parses str[begin:end] as a decimal integer. On an invalid range it
// records errInvalidTimestamp; on a non-numeric substring it records an error
// quoting the whole input. After any error it returns 0.
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
	if p.err != nil {
		return 0
	}
	if begin < 0 || end < 0 || begin > end || end > len(str) {
		p.err = errInvalidTimestamp
		return 0
	}
	result, err := strconv.Atoi(str[begin:end])
	if err != nil {
		p.err = fmt.Errorf("expected number; got '%v'", str)
		return 0
	}
	return result
}
274
// locationCache memoizes the fixed-offset *time.Location created for each
// UTC offset (in seconds), so timestamps with the same offset share one
// Location value. The map is guarded by lock.
type locationCache struct {
	cache map[int]*time.Location
	lock  sync.Mutex
}

// globalLocationCache is shared by every connection. Benchmarking showed that
// a per-connection cache without the mutex would be roughly 5% faster. It
// would cost a little memory and noticeably more code complexity, so the
// shared cache was kept.
var globalLocationCache = newLocationCache()

// newLocationCache returns an empty, ready-to-use cache.
func newLocationCache() *locationCache {
	c := new(locationCache)
	c.cache = make(map[int]*time.Location)
	return c
}

// getLocation returns the unnamed fixed zone for offset, creating it on first
// use and returning the same pointer on every later call with that offset.
func (c *locationCache) getLocation(offset int) *time.Location {
	c.lock.Lock()
	defer c.lock.Unlock()

	if loc, found := c.cache[offset]; found {
		return loc
	}
	loc := time.FixedZone("", offset)
	c.cache[offset] = loc
	return loc
}
305
// Package-level infinity-timestamp configuration, set once by
// EnableInfinityTs.
var (
	infinityTsEnabled  bool
	infinityTsNegative time.Time
	infinityTsPositive time.Time
)

// Panic messages used by EnableInfinityTs.
const (
	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
)

// EnableInfinityTs controls the handling of Postgres' "-infinity" and
// "infinity" "timestamp"s.
//
// Without a call to EnableInfinityTs, "-infinity" and "infinity" are returned
// as []byte("-infinity") and []byte("infinity"). Scanning them into a
// time.Time can then fail with "sql: Scan error on column index 0:
// unsupported driver -> Scan pair: []uint8 -> *time.Time".
//
// After EnableInfinityTs has been called, all connections created with this
// driver decode Postgres' "-infinity" and "infinity" for "timestamp",
// "timestamp with time zone" and "date" to negative and positive. When
// encoding a time.Time, any value at or before negative is sent as
// "-infinity", and any value at or after positive is sent as "infinity".
//
// EnableInfinityTs panics if negative is not strictly before positive, or if
// it has already been called. Calling it after a connection has been
// established results in undefined behavior.
func EnableInfinityTs(negative time.Time, positive time.Time) {
	switch {
	case infinityTsEnabled:
		panic(infinityTsEnabledAlready)
	case !negative.Before(positive):
		panic(infinityTsNegativeMustBeSmaller)
	}
	infinityTsNegative, infinityTsPositive = negative, positive
	infinityTsEnabled = true
}

// disableInfinityTs lets tests switch infinity handling back off. It only
// clears the enabled flag; the stored bounds are left untouched.
func disableInfinityTs() {
	infinityTsEnabled = false
}
353
354// This is a time function specific to the Postgres default DateStyle
355// setting ("ISO, MDY"), the only one we currently support. This
356// accounts for the discrepancies between the parsing available with
357// time.Parse and the Postgres date formatting quirks.
358func parseTs(currentLocation *time.Location, str string) interface{} {
359 switch str {
360 case "-infinity":
361 if infinityTsEnabled {
362 return infinityTsNegative
363 }
364 return []byte(str)
365 case "infinity":
366 if infinityTsEnabled {
367 return infinityTsPositive
368 }
369 return []byte(str)
370 }
371 t, err := ParseTimestamp(currentLocation, str)
372 if err != nil {
373 panic(err)
374 }
375 return t
376}
377
378// ParseTimestamp parses Postgres' text format. It returns a time.Time in
379// currentLocation iff that time's offset agrees with the offset sent from the
380// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
381// fixed offset offset provided by the Postgres server.
382func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
383 p := timestampParser{}
384
385 monSep := strings.IndexRune(str, '-')
386 // this is Gregorian year, not ISO Year
387 // In Gregorian system, the year 1 BC is followed by AD 1
388 year := p.mustAtoi(str, 0, monSep)
389 daySep := monSep + 3
390 month := p.mustAtoi(str, monSep+1, daySep)
391 p.expect(str, '-', daySep)
392 timeSep := daySep + 3
393 day := p.mustAtoi(str, daySep+1, timeSep)
394
395 minLen := monSep + len("01-01") + 1
396
397 isBC := strings.HasSuffix(str, " BC")
398 if isBC {
399 minLen += 3
400 }
401
402 var hour, minute, second int
403 if len(str) > minLen {
404 p.expect(str, ' ', timeSep)
405 minSep := timeSep + 3
406 p.expect(str, ':', minSep)
407 hour = p.mustAtoi(str, timeSep+1, minSep)
408 secSep := minSep + 3
409 p.expect(str, ':', secSep)
410 minute = p.mustAtoi(str, minSep+1, secSep)
411 secEnd := secSep + 3
412 second = p.mustAtoi(str, secSep+1, secEnd)
413 }
414 remainderIdx := monSep + len("01-01 00:00:00") + 1
415 // Three optional (but ordered) sections follow: the
416 // fractional seconds, the time zone offset, and the BC
417 // designation. We set them up here and adjust the other
418 // offsets if the preceding sections exist.
419
420 nanoSec := 0
421 tzOff := 0
422
423 if remainderIdx < len(str) && str[remainderIdx] == '.' {
424 fracStart := remainderIdx + 1
425 fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
426 if fracOff < 0 {
427 fracOff = len(str) - fracStart
428 }
429 fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
430 nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
431
432 remainderIdx += fracOff + 1
433 }
434 if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
435 // time zone separator is always '-' or '+' or 'Z' (UTC is +00)
436 var tzSign int
437 switch c := str[tzStart]; c {
438 case '-':
439 tzSign = -1
440 case '+':
441 tzSign = +1
442 default:
443 return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
444 }
445 tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
446 remainderIdx += 3
447 var tzMin, tzSec int
448 if remainderIdx < len(str) && str[remainderIdx] == ':' {
449 tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
450 remainderIdx += 3
451 }
452 if remainderIdx < len(str) && str[remainderIdx] == ':' {
453 tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
454 remainderIdx += 3
455 }
456 tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
457 } else if tzStart < len(str) && str[tzStart] == 'Z' {
458 // time zone Z separator indicates UTC is +00
459 remainderIdx += 1
460 }
461
462 var isoYear int
463
464 if isBC {
465 isoYear = 1 - year
466 remainderIdx += 3
467 } else {
468 isoYear = year
469 }
470 if remainderIdx < len(str) {
471 return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
472 }
473 t := time.Date(isoYear, time.Month(month), day,
474 hour, minute, second, nanoSec,
475 globalLocationCache.getLocation(tzOff))
476
477 if currentLocation != nil {
478 // Set the location of the returned Time based on the session's
479 // TimeZone value, but only if the local time zone database agrees with
480 // the remote database on the offset.
481 lt := t.In(currentLocation)
482 _, newOff := lt.Zone()
483 if newOff == tzOff {
484 t = lt
485 }
486 }
487
488 return t, p.err
489}
490
491// formatTs formats t into a format postgres understands.
492func formatTs(t time.Time) []byte {
493 if infinityTsEnabled {
494 // t <= -infinity : ! (t > -infinity)
495 if !t.After(infinityTsNegative) {
496 return []byte("-infinity")
497 }
498 // t >= infinity : ! (!t < infinity)
499 if !t.Before(infinityTsPositive) {
500 return []byte("infinity")
501 }
502 }
503 return FormatTimestamp(t)
504}
505
// FormatTimestamp formats t into Postgres' text format for timestamps:
// "YYYY-MM-DD HH:MM:SS[.fraction]" followed by the zone offset ("Z" for UTC),
// with ":SS" appended when the offset has a seconds component.
//
// Dates before 0001 AD are sent with a " BC" suffix instead of Go's minus
// sign. ISO year 0 is 1 BC, ISO year -1 is 2 BC, and so on (BC year =
// 1 - ISO year).
func FormatTimestamp(t time.Time) []byte {
	var b []byte
	bc := t.Year() <= 0
	if bc {
		// Write the BC year by hand instead of shifting t with AddDate. The
		// shifted year can differ in leap-year status (ISO year 0 is a leap
		// year, AD 1 is not), so Feb 29 would be normalized to Mar 1.
		year := 1 - t.Year()
		for pad := 1000; pad > 1 && year < pad; pad /= 10 {
			b = append(b, '0') // zero-pad to at least four digits
		}
		b = strconv.AppendInt(b, int64(year), 10)
		b = t.AppendFormat(b, "-01-02 15:04:05.999999999Z07:00")
	} else {
		b = t.AppendFormat(nil, "2006-01-02 15:04:05.999999999Z07:00")
	}

	// "Z07:00" drops the seconds part of the offset; append it ourselves.
	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		// The sign was already written by "Z07:00".
		if offset < 0 {
			offset = -offset
		}

		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}
539
// parseBytea decodes a bytea value received from the server. Two formats are
// accepted:
//   - hex: a leading `\x` followed by hex digits;
//   - escape (legacy): `\\` stands for one backslash, `\NNN` for the byte
//     with octal value NNN, and every other byte is literal.
//
// Malformed input yields a nil result and an error.
func parseBytea(s []byte) (result []byte, err error) {
	if bytes.HasPrefix(s, []byte("\\x")) {
		// bytea_output = hex
		digits := s[2:]
		decoded := make([]byte, hex.DecodedLen(len(digits)))
		if _, decErr := hex.Decode(decoded, digits); decErr != nil {
			return nil, decErr
		}
		return decoded, nil
	}

	// bytea_output = escape
	for rest := s; len(rest) > 0; {
		if rest[0] != '\\' {
			// Copy the run of literal bytes up to the next backslash at once.
			n := bytes.IndexByte(rest, '\\')
			if n < 0 {
				result = append(result, rest...)
				break
			}
			result = append(result, rest[:n]...)
			rest = rest[n:]
			continue
		}

		if len(rest) >= 2 && rest[1] == '\\' {
			result = append(result, '\\')
			rest = rest[2:]
			continue
		}

		// A backslash not followed by another one must start a 3-digit octal.
		if len(rest) < 4 {
			return nil, fmt.Errorf("invalid bytea sequence %v", rest)
		}
		octal, parseErr := strconv.ParseUint(string(rest[1:4]), 8, 8)
		if parseErr != nil {
			return nil, fmt.Errorf("could not parse bytea value: %s", parseErr.Error())
		}
		result = append(result, byte(octal))
		rest = rest[4:]
	}

	return result, nil
}
588
// encodeBytea renders v as a bytea literal suited to the server version.
//
// Servers at version 90000 or newer get the hex format (`\x` followed by hex
// digits). Older servers get the escape format: a backslash is doubled, bytes
// outside printable ASCII (0x20-0x7e) become a three-digit octal escape, and
// everything else is copied as-is. An empty v in escape format yields nil.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion < 90000 {
		// .. resort to "escape"
		for _, c := range v {
			switch {
			case c == '\\':
				result = append(result, '\\', '\\')
			case c < 0x20 || c > 0x7e:
				result = append(result, fmt.Sprintf("\\%03o", c)...)
			default:
				result = append(result, c)
			}
		}
		return result
	}

	// The server supports hex: `\x` prefix, then two hex digits per byte.
	result = make([]byte, 2+hex.EncodedLen(len(v)))
	copy(result, `\x`)
	hex.Encode(result[2:], v)
	return result
}
611
// NullTime is a time.Time that may be SQL NULL. It implements sql.Scanner, so
// it can be used as a scan destination the same way sql.NullString is, and
// driver.Valuer, so it can be passed as a query argument.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface. Any value that is not a time.Time,
// including nil, leaves nt invalid with a zero Time. Scan never returns an
// error.
//
// NOTE(review): this also silently maps non-time values such as
// []byte("infinity") to NULL.
func (nt *NullTime) Scan(value interface{}) error {
	t, ok := value.(time.Time)
	nt.Time = t
	nt.Valid = ok
	return nil
}

// Value implements the driver Valuer interface. It returns nil for an invalid
// NullTime and the wrapped time.Time otherwise.
func (nt NullTime) Value() (driver.Value, error) {
	if nt.Valid {
		return nt.Time, nil
	}
	return nil, nil
}
Note: See TracBrowser for help on using the repository browser.