1 | package pq
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "context"
|
---|
5 | "database/sql/driver"
|
---|
6 | "encoding/binary"
|
---|
7 | "errors"
|
---|
8 | "fmt"
|
---|
9 | "sync"
|
---|
10 | )
|
---|
11 |
|
---|
// errCopyInClosed is returned by Exec and CopyData once Close has run.
var errCopyInClosed = errors.New("pq: copyin statement has already been closed")

// errBinaryCopyNotSupported is returned by prepareCopyIn when the server
// answers with a non-text CopyInResponse format.
var errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")

// errCopyToNotSupported is returned by prepareCopyIn when the server answers
// with a CopyOutResponse, i.e. the statement was a COPY TO.
var errCopyToNotSupported = errors.New("pq: COPY TO is not supported")

// errCopyNotSupportedOutsideTxn is returned by prepareCopyIn when the
// connection is not inside a transaction.
var errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")

// errCopyInProgress is not referenced in this section.
// NOTE(review): presumably returned elsewhere while a COPY is active — confirm.
var errCopyInProgress = errors.New("pq: COPY in progress")
|
---|
19 |
|
---|
20 | // CopyIn creates a COPY FROM statement which can be prepared with
|
---|
21 | // Tx.Prepare(). The target table should be visible in search_path.
|
---|
22 | func CopyIn(table string, columns ...string) string {
|
---|
23 | stmt := "COPY " + QuoteIdentifier(table) + " ("
|
---|
24 | for i, col := range columns {
|
---|
25 | if i != 0 {
|
---|
26 | stmt += ", "
|
---|
27 | }
|
---|
28 | stmt += QuoteIdentifier(col)
|
---|
29 | }
|
---|
30 | stmt += ") FROM STDIN"
|
---|
31 | return stmt
|
---|
32 | }
|
---|
33 |
|
---|
34 | // CopyInSchema creates a COPY FROM statement which can be prepared with
|
---|
35 | // Tx.Prepare().
|
---|
36 | func CopyInSchema(schema, table string, columns ...string) string {
|
---|
37 | stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
|
---|
38 | for i, col := range columns {
|
---|
39 | if i != 0 {
|
---|
40 | stmt += ", "
|
---|
41 | }
|
---|
42 | stmt += QuoteIdentifier(col)
|
---|
43 | }
|
---|
44 | stmt += ") FROM STDIN"
|
---|
45 | return stmt
|
---|
46 | }
|
---|
47 |
|
---|
48 | type copyin struct {
|
---|
49 | cn *conn
|
---|
50 | buffer []byte
|
---|
51 | rowData chan []byte
|
---|
52 | done chan bool
|
---|
53 |
|
---|
54 | closed bool
|
---|
55 |
|
---|
56 | mu struct {
|
---|
57 | sync.Mutex
|
---|
58 | err error
|
---|
59 | driver.Result
|
---|
60 | }
|
---|
61 | }
|
---|
62 |
|
---|
const (
	// ciBufferSize is the initial capacity of a copyin's message buffer.
	ciBufferSize = 64 * 1024

	// ciBufferFlushSize is the buffered length above which Exec and CopyData
	// send the pending data. It sits below ciBufferSize so the buffer is
	// normally flushed before it fills up and has to be reallocated.
	ciBufferFlushSize = 63 * 1024
)
|
---|
67 |
|
---|
68 | func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
|
---|
69 | if !cn.isInTransaction() {
|
---|
70 | return nil, errCopyNotSupportedOutsideTxn
|
---|
71 | }
|
---|
72 |
|
---|
73 | ci := ©in{
|
---|
74 | cn: cn,
|
---|
75 | buffer: make([]byte, 0, ciBufferSize),
|
---|
76 | rowData: make(chan []byte),
|
---|
77 | done: make(chan bool, 1),
|
---|
78 | }
|
---|
79 | // add CopyData identifier + 4 bytes for message length
|
---|
80 | ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
|
---|
81 |
|
---|
82 | b := cn.writeBuf('Q')
|
---|
83 | b.string(q)
|
---|
84 | cn.send(b)
|
---|
85 |
|
---|
86 | awaitCopyInResponse:
|
---|
87 | for {
|
---|
88 | t, r := cn.recv1()
|
---|
89 | switch t {
|
---|
90 | case 'G':
|
---|
91 | if r.byte() != 0 {
|
---|
92 | err = errBinaryCopyNotSupported
|
---|
93 | break awaitCopyInResponse
|
---|
94 | }
|
---|
95 | go ci.resploop()
|
---|
96 | return ci, nil
|
---|
97 | case 'H':
|
---|
98 | err = errCopyToNotSupported
|
---|
99 | break awaitCopyInResponse
|
---|
100 | case 'E':
|
---|
101 | err = parseError(r)
|
---|
102 | case 'Z':
|
---|
103 | if err == nil {
|
---|
104 | ci.setBad(driver.ErrBadConn)
|
---|
105 | errorf("unexpected ReadyForQuery in response to COPY")
|
---|
106 | }
|
---|
107 | cn.processReadyForQuery(r)
|
---|
108 | return nil, err
|
---|
109 | default:
|
---|
110 | ci.setBad(driver.ErrBadConn)
|
---|
111 | errorf("unknown response for copy query: %q", t)
|
---|
112 | }
|
---|
113 | }
|
---|
114 |
|
---|
115 | // something went wrong, abort COPY before we return
|
---|
116 | b = cn.writeBuf('f')
|
---|
117 | b.string(err.Error())
|
---|
118 | cn.send(b)
|
---|
119 |
|
---|
120 | for {
|
---|
121 | t, r := cn.recv1()
|
---|
122 | switch t {
|
---|
123 | case 'c', 'C', 'E':
|
---|
124 | case 'Z':
|
---|
125 | // correctly aborted, we're done
|
---|
126 | cn.processReadyForQuery(r)
|
---|
127 | return nil, err
|
---|
128 | default:
|
---|
129 | ci.setBad(driver.ErrBadConn)
|
---|
130 | errorf("unknown response for CopyFail: %q", t)
|
---|
131 | }
|
---|
132 | }
|
---|
133 | }
|
---|
134 |
|
---|
135 | func (ci *copyin) flush(buf []byte) {
|
---|
136 | // set message length (without message identifier)
|
---|
137 | binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
|
---|
138 |
|
---|
139 | _, err := ci.cn.c.Write(buf)
|
---|
140 | if err != nil {
|
---|
141 | panic(err)
|
---|
142 | }
|
---|
143 | }
|
---|
144 |
|
---|
145 | func (ci *copyin) resploop() {
|
---|
146 | for {
|
---|
147 | var r readBuf
|
---|
148 | t, err := ci.cn.recvMessage(&r)
|
---|
149 | if err != nil {
|
---|
150 | ci.setBad(driver.ErrBadConn)
|
---|
151 | ci.setError(err)
|
---|
152 | ci.done <- true
|
---|
153 | return
|
---|
154 | }
|
---|
155 | switch t {
|
---|
156 | case 'C':
|
---|
157 | // complete
|
---|
158 | res, _ := ci.cn.parseComplete(r.string())
|
---|
159 | ci.setResult(res)
|
---|
160 | case 'N':
|
---|
161 | if n := ci.cn.noticeHandler; n != nil {
|
---|
162 | n(parseError(&r))
|
---|
163 | }
|
---|
164 | case 'Z':
|
---|
165 | ci.cn.processReadyForQuery(&r)
|
---|
166 | ci.done <- true
|
---|
167 | return
|
---|
168 | case 'E':
|
---|
169 | err := parseError(&r)
|
---|
170 | ci.setError(err)
|
---|
171 | default:
|
---|
172 | ci.setBad(driver.ErrBadConn)
|
---|
173 | ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
|
---|
174 | ci.done <- true
|
---|
175 | return
|
---|
176 | }
|
---|
177 | }
|
---|
178 | }
|
---|
179 |
|
---|
180 | func (ci *copyin) setBad(err error) {
|
---|
181 | ci.cn.err.set(err)
|
---|
182 | }
|
---|
183 |
|
---|
184 | func (ci *copyin) getBad() error {
|
---|
185 | return ci.cn.err.get()
|
---|
186 | }
|
---|
187 |
|
---|
188 | func (ci *copyin) err() error {
|
---|
189 | ci.mu.Lock()
|
---|
190 | err := ci.mu.err
|
---|
191 | ci.mu.Unlock()
|
---|
192 | return err
|
---|
193 | }
|
---|
194 |
|
---|
195 | // setError() sets ci.err if one has not been set already. Caller must not be
|
---|
196 | // holding ci.Mutex.
|
---|
197 | func (ci *copyin) setError(err error) {
|
---|
198 | ci.mu.Lock()
|
---|
199 | if ci.mu.err == nil {
|
---|
200 | ci.mu.err = err
|
---|
201 | }
|
---|
202 | ci.mu.Unlock()
|
---|
203 | }
|
---|
204 |
|
---|
205 | func (ci *copyin) setResult(result driver.Result) {
|
---|
206 | ci.mu.Lock()
|
---|
207 | ci.mu.Result = result
|
---|
208 | ci.mu.Unlock()
|
---|
209 | }
|
---|
210 |
|
---|
211 | func (ci *copyin) getResult() driver.Result {
|
---|
212 | ci.mu.Lock()
|
---|
213 | result := ci.mu.Result
|
---|
214 | ci.mu.Unlock()
|
---|
215 | if result == nil {
|
---|
216 | return driver.RowsAffected(0)
|
---|
217 | }
|
---|
218 | return result
|
---|
219 | }
|
---|
220 |
|
---|
221 | func (ci *copyin) NumInput() int {
|
---|
222 | return -1
|
---|
223 | }
|
---|
224 |
|
---|
225 | func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
|
---|
226 | return nil, ErrNotSupported
|
---|
227 | }
|
---|
228 |
|
---|
229 | // Exec inserts values into the COPY stream. The insert is asynchronous
|
---|
230 | // and Exec can return errors from previous Exec calls to the same
|
---|
231 | // COPY stmt.
|
---|
232 | //
|
---|
233 | // You need to call Exec(nil) to sync the COPY stream and to get any
|
---|
234 | // errors from pending data, since Stmt.Close() doesn't return errors
|
---|
235 | // to the user.
|
---|
236 | func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
|
---|
237 | if ci.closed {
|
---|
238 | return nil, errCopyInClosed
|
---|
239 | }
|
---|
240 |
|
---|
241 | if err := ci.getBad(); err != nil {
|
---|
242 | return nil, err
|
---|
243 | }
|
---|
244 | defer ci.cn.errRecover(&err)
|
---|
245 |
|
---|
246 | if err := ci.err(); err != nil {
|
---|
247 | return nil, err
|
---|
248 | }
|
---|
249 |
|
---|
250 | if len(v) == 0 {
|
---|
251 | if err := ci.Close(); err != nil {
|
---|
252 | return driver.RowsAffected(0), err
|
---|
253 | }
|
---|
254 |
|
---|
255 | return ci.getResult(), nil
|
---|
256 | }
|
---|
257 |
|
---|
258 | numValues := len(v)
|
---|
259 | for i, value := range v {
|
---|
260 | ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
|
---|
261 | if i < numValues-1 {
|
---|
262 | ci.buffer = append(ci.buffer, '\t')
|
---|
263 | }
|
---|
264 | }
|
---|
265 |
|
---|
266 | ci.buffer = append(ci.buffer, '\n')
|
---|
267 |
|
---|
268 | if len(ci.buffer) > ciBufferFlushSize {
|
---|
269 | ci.flush(ci.buffer)
|
---|
270 | // reset buffer, keep bytes for message identifier and length
|
---|
271 | ci.buffer = ci.buffer[:5]
|
---|
272 | }
|
---|
273 |
|
---|
274 | return driver.RowsAffected(0), nil
|
---|
275 | }
|
---|
276 |
|
---|
277 | // CopyData inserts a raw string into the COPY stream. The insert is
|
---|
278 | // asynchronous and CopyData can return errors from previous CopyData calls to
|
---|
279 | // the same COPY stmt.
|
---|
280 | //
|
---|
281 | // You need to call Exec(nil) to sync the COPY stream and to get any
|
---|
282 | // errors from pending data, since Stmt.Close() doesn't return errors
|
---|
283 | // to the user.
|
---|
284 | func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
|
---|
285 | if ci.closed {
|
---|
286 | return nil, errCopyInClosed
|
---|
287 | }
|
---|
288 |
|
---|
289 | if finish := ci.cn.watchCancel(ctx); finish != nil {
|
---|
290 | defer finish()
|
---|
291 | }
|
---|
292 |
|
---|
293 | if err := ci.getBad(); err != nil {
|
---|
294 | return nil, err
|
---|
295 | }
|
---|
296 | defer ci.cn.errRecover(&err)
|
---|
297 |
|
---|
298 | if err := ci.err(); err != nil {
|
---|
299 | return nil, err
|
---|
300 | }
|
---|
301 |
|
---|
302 | ci.buffer = append(ci.buffer, []byte(line)...)
|
---|
303 | ci.buffer = append(ci.buffer, '\n')
|
---|
304 |
|
---|
305 | if len(ci.buffer) > ciBufferFlushSize {
|
---|
306 | ci.flush(ci.buffer)
|
---|
307 | // reset buffer, keep bytes for message identifier and length
|
---|
308 | ci.buffer = ci.buffer[:5]
|
---|
309 | }
|
---|
310 |
|
---|
311 | return driver.RowsAffected(0), nil
|
---|
312 | }
|
---|
313 |
|
---|
314 | func (ci *copyin) Close() (err error) {
|
---|
315 | if ci.closed { // Don't do anything, we're already closed
|
---|
316 | return nil
|
---|
317 | }
|
---|
318 | ci.closed = true
|
---|
319 |
|
---|
320 | if err := ci.getBad(); err != nil {
|
---|
321 | return err
|
---|
322 | }
|
---|
323 | defer ci.cn.errRecover(&err)
|
---|
324 |
|
---|
325 | if len(ci.buffer) > 0 {
|
---|
326 | ci.flush(ci.buffer)
|
---|
327 | }
|
---|
328 | // Avoid touching the scratch buffer as resploop could be using it.
|
---|
329 | err = ci.cn.sendSimpleMessage('c')
|
---|
330 | if err != nil {
|
---|
331 | return err
|
---|
332 | }
|
---|
333 |
|
---|
334 | <-ci.done
|
---|
335 | ci.cn.inCopy = false
|
---|
336 |
|
---|
337 | if err := ci.err(); err != nil {
|
---|
338 | return err
|
---|
339 | }
|
---|
340 | return nil
|
---|
341 | }
|
---|