1 | // Copyright 2021 Google Inc. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | package uuid
|
---|
6 |
|
---|
7 | import (
|
---|
8 | "bytes"
|
---|
9 | "database/sql/driver"
|
---|
10 | "encoding/json"
|
---|
11 | "fmt"
|
---|
12 | )
|
---|
13 |
|
---|
14 | var jsonNull = []byte("null")
|
---|
15 |
|
---|
16 | // NullUUID represents a UUID that may be null.
|
---|
17 | // NullUUID implements the SQL driver.Scanner interface so
|
---|
18 | // it can be used as a scan destination:
|
---|
19 | //
|
---|
20 | // var u uuid.NullUUID
|
---|
21 | // err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u)
|
---|
22 | // ...
|
---|
23 | // if u.Valid {
|
---|
24 | // // use u.UUID
|
---|
25 | // } else {
|
---|
26 | // // NULL value
|
---|
27 | // }
|
---|
28 | //
|
---|
29 | type NullUUID struct {
|
---|
30 | UUID UUID
|
---|
31 | Valid bool // Valid is true if UUID is not NULL
|
---|
32 | }
|
---|
33 |
|
---|
34 | // Scan implements the SQL driver.Scanner interface.
|
---|
35 | func (nu *NullUUID) Scan(value interface{}) error {
|
---|
36 | if value == nil {
|
---|
37 | nu.UUID, nu.Valid = Nil, false
|
---|
38 | return nil
|
---|
39 | }
|
---|
40 |
|
---|
41 | err := nu.UUID.Scan(value)
|
---|
42 | if err != nil {
|
---|
43 | nu.Valid = false
|
---|
44 | return err
|
---|
45 | }
|
---|
46 |
|
---|
47 | nu.Valid = true
|
---|
48 | return nil
|
---|
49 | }
|
---|
50 |
|
---|
51 | // Value implements the driver Valuer interface.
|
---|
52 | func (nu NullUUID) Value() (driver.Value, error) {
|
---|
53 | if !nu.Valid {
|
---|
54 | return nil, nil
|
---|
55 | }
|
---|
56 | // Delegate to UUID Value function
|
---|
57 | return nu.UUID.Value()
|
---|
58 | }
|
---|
59 |
|
---|
60 | // MarshalBinary implements encoding.BinaryMarshaler.
|
---|
61 | func (nu NullUUID) MarshalBinary() ([]byte, error) {
|
---|
62 | if nu.Valid {
|
---|
63 | return nu.UUID[:], nil
|
---|
64 | }
|
---|
65 |
|
---|
66 | return []byte(nil), nil
|
---|
67 | }
|
---|
68 |
|
---|
69 | // UnmarshalBinary implements encoding.BinaryUnmarshaler.
|
---|
70 | func (nu *NullUUID) UnmarshalBinary(data []byte) error {
|
---|
71 | if len(data) != 16 {
|
---|
72 | return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
|
---|
73 | }
|
---|
74 | copy(nu.UUID[:], data)
|
---|
75 | nu.Valid = true
|
---|
76 | return nil
|
---|
77 | }
|
---|
78 |
|
---|
79 | // MarshalText implements encoding.TextMarshaler.
|
---|
80 | func (nu NullUUID) MarshalText() ([]byte, error) {
|
---|
81 | if nu.Valid {
|
---|
82 | return nu.UUID.MarshalText()
|
---|
83 | }
|
---|
84 |
|
---|
85 | return jsonNull, nil
|
---|
86 | }
|
---|
87 |
|
---|
88 | // UnmarshalText implements encoding.TextUnmarshaler.
|
---|
89 | func (nu *NullUUID) UnmarshalText(data []byte) error {
|
---|
90 | id, err := ParseBytes(data)
|
---|
91 | if err != nil {
|
---|
92 | nu.Valid = false
|
---|
93 | return err
|
---|
94 | }
|
---|
95 | nu.UUID = id
|
---|
96 | nu.Valid = true
|
---|
97 | return nil
|
---|
98 | }
|
---|
99 |
|
---|
100 | // MarshalJSON implements json.Marshaler.
|
---|
101 | func (nu NullUUID) MarshalJSON() ([]byte, error) {
|
---|
102 | if nu.Valid {
|
---|
103 | return json.Marshal(nu.UUID)
|
---|
104 | }
|
---|
105 |
|
---|
106 | return jsonNull, nil
|
---|
107 | }
|
---|
108 |
|
---|
109 | // UnmarshalJSON implements json.Unmarshaler.
|
---|
110 | func (nu *NullUUID) UnmarshalJSON(data []byte) error {
|
---|
111 | if bytes.Equal(data, jsonNull) {
|
---|
112 | *nu = NullUUID{}
|
---|
113 | return nil // valid null UUID
|
---|
114 | }
|
---|
115 | err := json.Unmarshal(data, &nu.UUID)
|
---|
116 | nu.Valid = err == nil
|
---|
117 | return err
|
---|
118 | }
|
---|