1 | package pq
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "context"
|
---|
5 | "database/sql"
|
---|
6 | "database/sql/driver"
|
---|
7 | "fmt"
|
---|
8 | "io"
|
---|
9 | "io/ioutil"
|
---|
10 | "time"
|
---|
11 | )
|
---|
12 |
|
---|
// watchCancelDialContextTimeout is the timeout of the fresh context that the
// watchCancel helpers create for a cancel request once the caller's context
// is already done; within cancel that context is passed to dial.
const watchCancelDialContextTimeout = 10 * time.Second
|
---|
16 |
|
---|
17 | // Implement the "QueryerContext" interface
|
---|
18 | func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
---|
19 | list := make([]driver.Value, len(args))
|
---|
20 | for i, nv := range args {
|
---|
21 | list[i] = nv.Value
|
---|
22 | }
|
---|
23 | finish := cn.watchCancel(ctx)
|
---|
24 | r, err := cn.query(query, list)
|
---|
25 | if err != nil {
|
---|
26 | if finish != nil {
|
---|
27 | finish()
|
---|
28 | }
|
---|
29 | return nil, err
|
---|
30 | }
|
---|
31 | r.finish = finish
|
---|
32 | return r, nil
|
---|
33 | }
|
---|
34 |
|
---|
35 | // Implement the "ExecerContext" interface
|
---|
36 | func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
---|
37 | list := make([]driver.Value, len(args))
|
---|
38 | for i, nv := range args {
|
---|
39 | list[i] = nv.Value
|
---|
40 | }
|
---|
41 |
|
---|
42 | if finish := cn.watchCancel(ctx); finish != nil {
|
---|
43 | defer finish()
|
---|
44 | }
|
---|
45 |
|
---|
46 | return cn.Exec(query, list)
|
---|
47 | }
|
---|
48 |
|
---|
49 | // Implement the "ConnPrepareContext" interface
|
---|
50 | func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
---|
51 | if finish := cn.watchCancel(ctx); finish != nil {
|
---|
52 | defer finish()
|
---|
53 | }
|
---|
54 | return cn.Prepare(query)
|
---|
55 | }
|
---|
56 |
|
---|
57 | // Implement the "ConnBeginTx" interface
|
---|
58 | func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
---|
59 | var mode string
|
---|
60 |
|
---|
61 | switch sql.IsolationLevel(opts.Isolation) {
|
---|
62 | case sql.LevelDefault:
|
---|
63 | // Don't touch mode: use the server's default
|
---|
64 | case sql.LevelReadUncommitted:
|
---|
65 | mode = " ISOLATION LEVEL READ UNCOMMITTED"
|
---|
66 | case sql.LevelReadCommitted:
|
---|
67 | mode = " ISOLATION LEVEL READ COMMITTED"
|
---|
68 | case sql.LevelRepeatableRead:
|
---|
69 | mode = " ISOLATION LEVEL REPEATABLE READ"
|
---|
70 | case sql.LevelSerializable:
|
---|
71 | mode = " ISOLATION LEVEL SERIALIZABLE"
|
---|
72 | default:
|
---|
73 | return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
|
---|
74 | }
|
---|
75 |
|
---|
76 | if opts.ReadOnly {
|
---|
77 | mode += " READ ONLY"
|
---|
78 | } else {
|
---|
79 | mode += " READ WRITE"
|
---|
80 | }
|
---|
81 |
|
---|
82 | tx, err := cn.begin(mode)
|
---|
83 | if err != nil {
|
---|
84 | return nil, err
|
---|
85 | }
|
---|
86 | cn.txnFinish = cn.watchCancel(ctx)
|
---|
87 | return tx, nil
|
---|
88 | }
|
---|
89 |
|
---|
90 | func (cn *conn) Ping(ctx context.Context) error {
|
---|
91 | if finish := cn.watchCancel(ctx); finish != nil {
|
---|
92 | defer finish()
|
---|
93 | }
|
---|
94 | rows, err := cn.simpleQuery(";")
|
---|
95 | if err != nil {
|
---|
96 | return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
|
---|
97 | }
|
---|
98 | rows.Close()
|
---|
99 | return nil
|
---|
100 | }
|
---|
101 |
|
---|
// watchCancel arranges for the operation about to run on cn to be canceled if
// ctx is done before that operation finishes.
//
// It returns nil when ctx.Done() is nil, meaning ctx can never be canceled.
// Otherwise it starts a watcher goroutine and returns a finish function. The
// caller must run finish once the operation is over, either directly, via
// defer, or by storing it on the returned rows or in cn.txnFinish.
//
// The two sides coordinate through finished, a channel with capacity 1.
// Whichever side fills its single slot first wins:
//
//   - Watcher wins (ctx done first): it marks cn bad with ctx.Err() so the
//     connection is not reused. It then sends a cancel request via cn.cancel,
//     using a fresh context bounded by watchCancelDialContextTimeout. A later
//     call to finish finds the buffered value, sets ctx.Err() again and closes
//     cn. It does not wait for the cancel request to complete.
//   - finish wins: the watcher either receives the value in its outer select
//     and exits, or, if it had already observed done, fails its non-blocking
//     send and returns without canceling. Handling the canceled context is
//     then left to the next use of the connection.
//
// Unless ctx becomes done, the watcher goroutine exits only after finish is
// called.
func (cn *conn) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		finished := make(chan struct{}, 1)
		go func() {
			select {
			case <-done:
				// Try to claim the single slot. If it is already full, finish
				// ran first and no cancel must be sent.
				select {
				case finished <- struct{}{}:
				default:
					// We raced with the finish func, let the next query handle this with the
					// context.
					return
				}

				// Set the connection state to bad so it does not get reused.
				cn.err.set(ctx.Err())

				// At this point the function level context is canceled,
				// so it must not be used for the additional network
				// request to cancel the query.
				// Create a new context to pass into the dial.
				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
				defer cancel()

				// Best effort: the result of the cancel request is discarded.
				_ = cn.cancel(ctxCancel)
			case <-finished:
				// finish claimed the slot first; nothing to do.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The watcher claimed the slot: ctx was canceled and a cancel
				// request was (or is being) sent. Mark the connection bad and
				// close it rather than waiting; the Close error is discarded.
				cn.err.set(ctx.Err())
				cn.Close()
			case finished <- struct{}{}:
				// finish claimed the slot first; the watcher will not cancel.
			}
		}
	}
	return nil
}
|
---|
141 |
|
---|
142 | func (cn *conn) cancel(ctx context.Context) error {
|
---|
143 | // Create a new values map (copy). This makes sure the connection created
|
---|
144 | // in this method cannot write to the same underlying data, which could
|
---|
145 | // cause a concurrent map write panic. This is necessary because cancel
|
---|
146 | // is called from a goroutine in watchCancel.
|
---|
147 | o := make(values)
|
---|
148 | for k, v := range cn.opts {
|
---|
149 | o[k] = v
|
---|
150 | }
|
---|
151 |
|
---|
152 | c, err := dial(ctx, cn.dialer, o)
|
---|
153 | if err != nil {
|
---|
154 | return err
|
---|
155 | }
|
---|
156 | defer c.Close()
|
---|
157 |
|
---|
158 | {
|
---|
159 | can := conn{
|
---|
160 | c: c,
|
---|
161 | }
|
---|
162 | err = can.ssl(o)
|
---|
163 | if err != nil {
|
---|
164 | return err
|
---|
165 | }
|
---|
166 |
|
---|
167 | w := can.writeBuf(0)
|
---|
168 | w.int32(80877102) // cancel request code
|
---|
169 | w.int32(cn.processID)
|
---|
170 | w.int32(cn.secretKey)
|
---|
171 |
|
---|
172 | if err := can.sendStartupPacket(w); err != nil {
|
---|
173 | return err
|
---|
174 | }
|
---|
175 | }
|
---|
176 |
|
---|
177 | // Read until EOF to ensure that the server received the cancel.
|
---|
178 | {
|
---|
179 | _, err := io.Copy(ioutil.Discard, c)
|
---|
180 | return err
|
---|
181 | }
|
---|
182 | }
|
---|
183 |
|
---|
184 | // Implement the "StmtQueryContext" interface
|
---|
185 | func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
---|
186 | list := make([]driver.Value, len(args))
|
---|
187 | for i, nv := range args {
|
---|
188 | list[i] = nv.Value
|
---|
189 | }
|
---|
190 | finish := st.watchCancel(ctx)
|
---|
191 | r, err := st.query(list)
|
---|
192 | if err != nil {
|
---|
193 | if finish != nil {
|
---|
194 | finish()
|
---|
195 | }
|
---|
196 | return nil, err
|
---|
197 | }
|
---|
198 | r.finish = finish
|
---|
199 | return r, nil
|
---|
200 | }
|
---|
201 |
|
---|
202 | // Implement the "StmtExecContext" interface
|
---|
203 | func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
---|
204 | list := make([]driver.Value, len(args))
|
---|
205 | for i, nv := range args {
|
---|
206 | list[i] = nv.Value
|
---|
207 | }
|
---|
208 |
|
---|
209 | if finish := st.watchCancel(ctx); finish != nil {
|
---|
210 | defer finish()
|
---|
211 | }
|
---|
212 |
|
---|
213 | return st.Exec(list)
|
---|
214 | }
|
---|
215 |
|
---|
216 | // watchCancel is implemented on stmt in order to not mark the parent conn as bad
|
---|
217 | func (st *stmt) watchCancel(ctx context.Context) func() {
|
---|
218 | if done := ctx.Done(); done != nil {
|
---|
219 | finished := make(chan struct{})
|
---|
220 | go func() {
|
---|
221 | select {
|
---|
222 | case <-done:
|
---|
223 | // At this point the function level context is canceled,
|
---|
224 | // so it must not be used for the additional network
|
---|
225 | // request to cancel the query.
|
---|
226 | // Create a new context to pass into the dial.
|
---|
227 | ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
|
---|
228 | defer cancel()
|
---|
229 |
|
---|
230 | _ = st.cancel(ctxCancel)
|
---|
231 | finished <- struct{}{}
|
---|
232 | case <-finished:
|
---|
233 | }
|
---|
234 | }()
|
---|
235 | return func() {
|
---|
236 | select {
|
---|
237 | case <-finished:
|
---|
238 | case finished <- struct{}{}:
|
---|
239 | }
|
---|
240 | }
|
---|
241 | }
|
---|
242 | return nil
|
---|
243 | }
|
---|
244 |
|
---|
245 | func (st *stmt) cancel(ctx context.Context) error {
|
---|
246 | return st.cn.cancel(ctx)
|
---|
247 | }
|
---|