1 | // Copyright 2017 The Sqlite Authors. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | //go:generate go run generator.go -full-path-comments
|
---|
6 |
|
---|
7 | package sqlite // import "modernc.org/sqlite"
|
---|
8 |
|
---|
9 | import (
|
---|
10 | "context"
|
---|
11 | "database/sql"
|
---|
12 | "database/sql/driver"
|
---|
13 | "fmt"
|
---|
14 | "io"
|
---|
15 | "math"
|
---|
16 | "net/url"
|
---|
17 | "reflect"
|
---|
18 | "strconv"
|
---|
19 | "strings"
|
---|
20 | "sync"
|
---|
21 | "sync/atomic"
|
---|
22 | "time"
|
---|
23 | "unsafe"
|
---|
24 |
|
---|
25 | "modernc.org/libc"
|
---|
26 | "modernc.org/libc/sys/types"
|
---|
27 | sqlite3 "modernc.org/sqlite/lib"
|
---|
28 | )
|
---|
29 |
|
---|
30 | var (
|
---|
31 | _ driver.Conn = (*conn)(nil)
|
---|
32 | _ driver.Driver = (*Driver)(nil)
|
---|
33 | //lint:ignore SA1019 TODO implement ExecerContext
|
---|
34 | _ driver.Execer = (*conn)(nil)
|
---|
35 | //lint:ignore SA1019 TODO implement QueryerContext
|
---|
36 | _ driver.Queryer = (*conn)(nil)
|
---|
37 | _ driver.Result = (*result)(nil)
|
---|
38 | _ driver.Rows = (*rows)(nil)
|
---|
39 | _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
|
---|
40 | _ driver.RowsColumnTypeLength = (*rows)(nil)
|
---|
41 | _ driver.RowsColumnTypeNullable = (*rows)(nil)
|
---|
42 | _ driver.RowsColumnTypePrecisionScale = (*rows)(nil)
|
---|
43 | _ driver.RowsColumnTypeScanType = (*rows)(nil)
|
---|
44 | _ driver.Stmt = (*stmt)(nil)
|
---|
45 | _ driver.Tx = (*tx)(nil)
|
---|
46 | _ error = (*Error)(nil)
|
---|
47 | )
|
---|
48 |
|
---|
49 | const (
|
---|
50 | driverName = "sqlite"
|
---|
51 | ptrSize = unsafe.Sizeof(uintptr(0))
|
---|
52 | sqliteLockedSharedcache = sqlite3.SQLITE_LOCKED | (1 << 8)
|
---|
53 | )
|
---|
54 |
|
---|
// Error is an error reported by the SQLite library. It carries the message
// text together with the numeric SQLite result code.
type Error struct {
	msg  string // text returned by Error
	code int    // SQLite result code returned by Code
}

// Error implements the error interface by returning the message text.
func (e *Error) Error() string {
	return e.msg
}

// Code returns the SQLite result code for this error. See ErrorCodeString
// for descriptions of known codes.
func (e *Error) Code() int {
	return e.code
}
|
---|
66 |
|
---|
67 | var (
|
---|
68 | // ErrorCodeString maps Error.Code() to its string representation.
|
---|
69 | ErrorCodeString = map[int]string{
|
---|
70 | sqlite3.SQLITE_ABORT: "Callback routine requested an abort (SQLITE_ABORT)",
|
---|
71 | sqlite3.SQLITE_AUTH: "Authorization denied (SQLITE_AUTH)",
|
---|
72 | sqlite3.SQLITE_BUSY: "The database file is locked (SQLITE_BUSY)",
|
---|
73 | sqlite3.SQLITE_CANTOPEN: "Unable to open the database file (SQLITE_CANTOPEN)",
|
---|
74 | sqlite3.SQLITE_CONSTRAINT: "Abort due to constraint violation (SQLITE_CONSTRAINT)",
|
---|
75 | sqlite3.SQLITE_CORRUPT: "The database disk image is malformed (SQLITE_CORRUPT)",
|
---|
76 | sqlite3.SQLITE_DONE: "sqlite3_step() has finished executing (SQLITE_DONE)",
|
---|
77 | sqlite3.SQLITE_EMPTY: "Internal use only (SQLITE_EMPTY)",
|
---|
78 | sqlite3.SQLITE_ERROR: "Generic error (SQLITE_ERROR)",
|
---|
79 | sqlite3.SQLITE_FORMAT: "Not used (SQLITE_FORMAT)",
|
---|
80 | sqlite3.SQLITE_FULL: "Insertion failed because database is full (SQLITE_FULL)",
|
---|
81 | sqlite3.SQLITE_INTERNAL: "Internal logic error in SQLite (SQLITE_INTERNAL)",
|
---|
82 | sqlite3.SQLITE_INTERRUPT: "Operation terminated by sqlite3_interrupt()(SQLITE_INTERRUPT)",
|
---|
83 | sqlite3.SQLITE_IOERR | (1 << 8): "(SQLITE_IOERR_READ)",
|
---|
84 | sqlite3.SQLITE_IOERR | (10 << 8): "(SQLITE_IOERR_DELETE)",
|
---|
85 | sqlite3.SQLITE_IOERR | (11 << 8): "(SQLITE_IOERR_BLOCKED)",
|
---|
86 | sqlite3.SQLITE_IOERR | (12 << 8): "(SQLITE_IOERR_NOMEM)",
|
---|
87 | sqlite3.SQLITE_IOERR | (13 << 8): "(SQLITE_IOERR_ACCESS)",
|
---|
88 | sqlite3.SQLITE_IOERR | (14 << 8): "(SQLITE_IOERR_CHECKRESERVEDLOCK)",
|
---|
89 | sqlite3.SQLITE_IOERR | (15 << 8): "(SQLITE_IOERR_LOCK)",
|
---|
90 | sqlite3.SQLITE_IOERR | (16 << 8): "(SQLITE_IOERR_CLOSE)",
|
---|
91 | sqlite3.SQLITE_IOERR | (17 << 8): "(SQLITE_IOERR_DIR_CLOSE)",
|
---|
92 | sqlite3.SQLITE_IOERR | (2 << 8): "(SQLITE_IOERR_SHORT_READ)",
|
---|
93 | sqlite3.SQLITE_IOERR | (3 << 8): "(SQLITE_IOERR_WRITE)",
|
---|
94 | sqlite3.SQLITE_IOERR | (4 << 8): "(SQLITE_IOERR_FSYNC)",
|
---|
95 | sqlite3.SQLITE_IOERR | (5 << 8): "(SQLITE_IOERR_DIR_FSYNC)",
|
---|
96 | sqlite3.SQLITE_IOERR | (6 << 8): "(SQLITE_IOERR_TRUNCATE)",
|
---|
97 | sqlite3.SQLITE_IOERR | (7 << 8): "(SQLITE_IOERR_FSTAT)",
|
---|
98 | sqlite3.SQLITE_IOERR | (8 << 8): "(SQLITE_IOERR_UNLOCK)",
|
---|
99 | sqlite3.SQLITE_IOERR | (9 << 8): "(SQLITE_IOERR_RDLOCK)",
|
---|
100 | sqlite3.SQLITE_IOERR: "Some kind of disk I/O error occurred (SQLITE_IOERR)",
|
---|
101 | sqlite3.SQLITE_LOCKED | (1 << 8): "(SQLITE_LOCKED_SHAREDCACHE)",
|
---|
102 | sqlite3.SQLITE_LOCKED: "A table in the database is locked (SQLITE_LOCKED)",
|
---|
103 | sqlite3.SQLITE_MISMATCH: "Data type mismatch (SQLITE_MISMATCH)",
|
---|
104 | sqlite3.SQLITE_MISUSE: "Library used incorrectly (SQLITE_MISUSE)",
|
---|
105 | sqlite3.SQLITE_NOLFS: "Uses OS features not supported on host (SQLITE_NOLFS)",
|
---|
106 | sqlite3.SQLITE_NOMEM: "A malloc() failed (SQLITE_NOMEM)",
|
---|
107 | sqlite3.SQLITE_NOTADB: "File opened that is not a database file (SQLITE_NOTADB)",
|
---|
108 | sqlite3.SQLITE_NOTFOUND: "Unknown opcode in sqlite3_file_control() (SQLITE_NOTFOUND)",
|
---|
109 | sqlite3.SQLITE_NOTICE: "Notifications from sqlite3_log() (SQLITE_NOTICE)",
|
---|
110 | sqlite3.SQLITE_PERM: "Access permission denied (SQLITE_PERM)",
|
---|
111 | sqlite3.SQLITE_PROTOCOL: "Database lock protocol error (SQLITE_PROTOCOL)",
|
---|
112 | sqlite3.SQLITE_RANGE: "2nd parameter to sqlite3_bind out of range (SQLITE_RANGE)",
|
---|
113 | sqlite3.SQLITE_READONLY: "Attempt to write a readonly database (SQLITE_READONLY)",
|
---|
114 | sqlite3.SQLITE_ROW: "sqlite3_step() has another row ready (SQLITE_ROW)",
|
---|
115 | sqlite3.SQLITE_SCHEMA: "The database schema changed (SQLITE_SCHEMA)",
|
---|
116 | sqlite3.SQLITE_TOOBIG: "String or BLOB exceeds size limit (SQLITE_TOOBIG)",
|
---|
117 | sqlite3.SQLITE_WARNING: "Warnings from sqlite3_log() (SQLITE_WARNING)",
|
---|
118 | }
|
---|
119 | )
|
---|
120 |
|
---|
121 | func init() {
|
---|
122 | sql.Register(driverName, newDriver())
|
---|
123 | }
|
---|
124 |
|
---|
125 | type result struct {
|
---|
126 | lastInsertID int64
|
---|
127 | rowsAffected int
|
---|
128 | }
|
---|
129 |
|
---|
130 | func newResult(c *conn) (_ *result, err error) {
|
---|
131 | r := &result{}
|
---|
132 | if r.rowsAffected, err = c.changes(); err != nil {
|
---|
133 | return nil, err
|
---|
134 | }
|
---|
135 |
|
---|
136 | if r.lastInsertID, err = c.lastInsertRowID(); err != nil {
|
---|
137 | return nil, err
|
---|
138 | }
|
---|
139 |
|
---|
140 | return r, nil
|
---|
141 | }
|
---|
142 |
|
---|
143 | // LastInsertId returns the database's auto-generated ID after, for example, an
|
---|
144 | // INSERT into a table with primary key.
|
---|
145 | func (r *result) LastInsertId() (int64, error) {
|
---|
146 | if r == nil {
|
---|
147 | return 0, nil
|
---|
148 | }
|
---|
149 |
|
---|
150 | return r.lastInsertID, nil
|
---|
151 | }
|
---|
152 |
|
---|
153 | // RowsAffected returns the number of rows affected by the query.
|
---|
154 | func (r *result) RowsAffected() (int64, error) {
|
---|
155 | if r == nil {
|
---|
156 | return 0, nil
|
---|
157 | }
|
---|
158 |
|
---|
159 | return int64(r.rowsAffected), nil
|
---|
160 | }
|
---|
161 |
|
---|
162 | type rows struct {
|
---|
163 | allocs []uintptr
|
---|
164 | c *conn
|
---|
165 | columns []string
|
---|
166 | pstmt uintptr
|
---|
167 |
|
---|
168 | doStep bool
|
---|
169 | empty bool
|
---|
170 | }
|
---|
171 |
|
---|
172 | func newRows(c *conn, pstmt uintptr, allocs []uintptr, empty bool) (r *rows, err error) {
|
---|
173 | r = &rows{c: c, pstmt: pstmt, allocs: allocs, empty: empty}
|
---|
174 |
|
---|
175 | defer func() {
|
---|
176 | if err != nil {
|
---|
177 | r.Close()
|
---|
178 | r = nil
|
---|
179 | }
|
---|
180 | }()
|
---|
181 |
|
---|
182 | n, err := c.columnCount(pstmt)
|
---|
183 | if err != nil {
|
---|
184 | return nil, err
|
---|
185 | }
|
---|
186 |
|
---|
187 | r.columns = make([]string, n)
|
---|
188 | for i := range r.columns {
|
---|
189 | if r.columns[i], err = r.c.columnName(pstmt, i); err != nil {
|
---|
190 | return nil, err
|
---|
191 | }
|
---|
192 | }
|
---|
193 |
|
---|
194 | return r, nil
|
---|
195 | }
|
---|
196 |
|
---|
197 | // Close closes the rows iterator.
|
---|
198 | func (r *rows) Close() (err error) {
|
---|
199 | for _, v := range r.allocs {
|
---|
200 | r.c.free(v)
|
---|
201 | }
|
---|
202 | r.allocs = nil
|
---|
203 | return r.c.finalize(r.pstmt)
|
---|
204 | }
|
---|
205 |
|
---|
206 | // Columns returns the names of the columns. The number of columns of the
|
---|
207 | // result is inferred from the length of the slice. If a particular column name
|
---|
208 | // isn't known, an empty string should be returned for that entry.
|
---|
209 | func (r *rows) Columns() (c []string) {
|
---|
210 | return r.columns
|
---|
211 | }
|
---|
212 |
|
---|
213 | // Next is called to populate the next row of data into the provided slice. The
|
---|
214 | // provided slice will be the same size as the Columns() are wide.
|
---|
215 | //
|
---|
216 | // Next should return io.EOF when there are no more rows.
|
---|
217 | func (r *rows) Next(dest []driver.Value) (err error) {
|
---|
218 | if r.empty {
|
---|
219 | return io.EOF
|
---|
220 | }
|
---|
221 |
|
---|
222 | rc := sqlite3.SQLITE_ROW
|
---|
223 | if r.doStep {
|
---|
224 | if rc, err = r.c.step(r.pstmt); err != nil {
|
---|
225 | return err
|
---|
226 | }
|
---|
227 | }
|
---|
228 |
|
---|
229 | r.doStep = true
|
---|
230 | switch rc {
|
---|
231 | case sqlite3.SQLITE_ROW:
|
---|
232 | if g, e := len(dest), len(r.columns); g != e {
|
---|
233 | return fmt.Errorf("sqlite: Next: have %v destination values, expected %v", g, e)
|
---|
234 | }
|
---|
235 |
|
---|
236 | for i := range dest {
|
---|
237 | ct, err := r.c.columnType(r.pstmt, i)
|
---|
238 | if err != nil {
|
---|
239 | return err
|
---|
240 | }
|
---|
241 |
|
---|
242 | switch ct {
|
---|
243 | case sqlite3.SQLITE_INTEGER:
|
---|
244 | v, err := r.c.columnInt64(r.pstmt, i)
|
---|
245 | if err != nil {
|
---|
246 | return err
|
---|
247 | }
|
---|
248 |
|
---|
249 | dest[i] = v
|
---|
250 | case sqlite3.SQLITE_FLOAT:
|
---|
251 | v, err := r.c.columnDouble(r.pstmt, i)
|
---|
252 | if err != nil {
|
---|
253 | return err
|
---|
254 | }
|
---|
255 |
|
---|
256 | dest[i] = v
|
---|
257 | case sqlite3.SQLITE_TEXT:
|
---|
258 | v, err := r.c.columnText(r.pstmt, i)
|
---|
259 | if err != nil {
|
---|
260 | return err
|
---|
261 | }
|
---|
262 |
|
---|
263 | switch r.ColumnTypeDatabaseTypeName(i) {
|
---|
264 | case "DATE", "DATETIME", "TIMESTAMP":
|
---|
265 | dest[i], _ = r.c.parseTime(v)
|
---|
266 | default:
|
---|
267 | dest[i] = v
|
---|
268 | }
|
---|
269 | case sqlite3.SQLITE_BLOB:
|
---|
270 | v, err := r.c.columnBlob(r.pstmt, i)
|
---|
271 | if err != nil {
|
---|
272 | return err
|
---|
273 | }
|
---|
274 |
|
---|
275 | dest[i] = v
|
---|
276 | case sqlite3.SQLITE_NULL:
|
---|
277 | dest[i] = nil
|
---|
278 | default:
|
---|
279 | return fmt.Errorf("internal error: rc %d", rc)
|
---|
280 | }
|
---|
281 | }
|
---|
282 | return nil
|
---|
283 | case sqlite3.SQLITE_DONE:
|
---|
284 | return io.EOF
|
---|
285 | default:
|
---|
286 | return r.c.errstr(int32(rc))
|
---|
287 | }
|
---|
288 | }
|
---|
289 |
|
---|
290 | // Inspired by mattn/go-sqlite3: https://github.com/mattn/go-sqlite3/blob/ab91e934/sqlite3.go#L210-L226
|
---|
291 | //
|
---|
292 | // These time.Parse formats handle formats 1 through 7 listed at https://www.sqlite.org/lang_datefunc.html.
|
---|
293 | var parseTimeFormats = []string{
|
---|
294 | "2006-01-02 15:04:05.999999999-07:00",
|
---|
295 | "2006-01-02T15:04:05.999999999-07:00",
|
---|
296 | "2006-01-02 15:04:05.999999999",
|
---|
297 | "2006-01-02T15:04:05.999999999",
|
---|
298 | "2006-01-02 15:04",
|
---|
299 | "2006-01-02T15:04",
|
---|
300 | "2006-01-02",
|
---|
301 | }
|
---|
302 |
|
---|
303 | // Attempt to parse s as a time. Return (s, false) if s is not
|
---|
304 | // recognized as a valid time encoding.
|
---|
305 | func (c *conn) parseTime(s string) (interface{}, bool) {
|
---|
306 | if v, ok := c.parseTimeString(s, strings.Index(s, "m=")); ok {
|
---|
307 | return v, true
|
---|
308 | }
|
---|
309 |
|
---|
310 | ts := strings.TrimSuffix(s, "Z")
|
---|
311 |
|
---|
312 | for _, f := range parseTimeFormats {
|
---|
313 | t, err := time.Parse(f, ts)
|
---|
314 | if err == nil {
|
---|
315 | return t, true
|
---|
316 | }
|
---|
317 | }
|
---|
318 |
|
---|
319 | return s, false
|
---|
320 | }
|
---|
321 |
|
---|
322 | // Attempt to parse s as a time string produced by t.String(). If x > 0 it's
|
---|
323 | // the index of substring "m=" within s. Return (s, false) if s is
|
---|
324 | // not recognized as a valid time encoding.
|
---|
325 | func (c *conn) parseTimeString(s0 string, x int) (interface{}, bool) {
|
---|
326 | s := s0
|
---|
327 | if x > 0 {
|
---|
328 | s = s[:x] // "2006-01-02 15:04:05.999999999 -0700 MST m=+9999" -> "2006-01-02 15:04:05.999999999 -0700 MST "
|
---|
329 | }
|
---|
330 | s = strings.TrimSpace(s)
|
---|
331 | if t, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", s); err == nil {
|
---|
332 | return t, true
|
---|
333 | }
|
---|
334 |
|
---|
335 | return s0, false
|
---|
336 | }
|
---|
337 |
|
---|
338 | // writeTimeFormats are the names and formats supported
|
---|
339 | // by the `_time_format` DSN query param.
|
---|
340 | var writeTimeFormats = map[string]string{
|
---|
341 | "sqlite": parseTimeFormats[0],
|
---|
342 | }
|
---|
343 |
|
---|
344 | func (c *conn) formatTime(t time.Time) string {
|
---|
345 | // Before configurable write time formats were supported,
|
---|
346 | // time.Time.String was used. Maintain that default to
|
---|
347 | // keep existing driver users formatting times the same.
|
---|
348 | if c.writeTimeFormat == "" {
|
---|
349 | return t.String()
|
---|
350 | }
|
---|
351 | return t.Format(c.writeTimeFormat)
|
---|
352 | }
|
---|
353 |
|
---|
354 | // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return
|
---|
355 | // the database system type name without the length. Type names should be
|
---|
356 | // uppercase. Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2",
|
---|
357 | // "CHAR", "TEXT", "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT",
|
---|
358 | // "JSONB", "XML", "TIMESTAMP".
|
---|
359 | func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
|
---|
360 | return strings.ToUpper(r.c.columnDeclType(r.pstmt, index))
|
---|
361 | }
|
---|
362 |
|
---|
363 | // RowsColumnTypeLength may be implemented by Rows. It should return the length
|
---|
364 | // of the column type if the column is a variable length type. If the column is
|
---|
365 | // not a variable length type ok should return false. If length is not limited
|
---|
366 | // other than system limits, it should return math.MaxInt64. The following are
|
---|
367 | // examples of returned values for various types:
|
---|
368 | //
|
---|
369 | // TEXT (math.MaxInt64, true)
|
---|
370 | // varchar(10) (10, true)
|
---|
371 | // nvarchar(10) (10, true)
|
---|
372 | // decimal (0, false)
|
---|
373 | // int (0, false)
|
---|
374 | // bytea(30) (30, true)
|
---|
375 | func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) {
|
---|
376 | t, err := r.c.columnType(r.pstmt, index)
|
---|
377 | if err != nil {
|
---|
378 | return 0, false
|
---|
379 | }
|
---|
380 |
|
---|
381 | switch t {
|
---|
382 | case sqlite3.SQLITE_INTEGER:
|
---|
383 | return 0, false
|
---|
384 | case sqlite3.SQLITE_FLOAT:
|
---|
385 | return 0, false
|
---|
386 | case sqlite3.SQLITE_TEXT:
|
---|
387 | return math.MaxInt64, true
|
---|
388 | case sqlite3.SQLITE_BLOB:
|
---|
389 | return math.MaxInt64, true
|
---|
390 | case sqlite3.SQLITE_NULL:
|
---|
391 | return 0, false
|
---|
392 | default:
|
---|
393 | return 0, false
|
---|
394 | }
|
---|
395 | }
|
---|
396 |
|
---|
397 | // RowsColumnTypeNullable may be implemented by Rows. The nullable value should
|
---|
398 | // be true if it is known the column may be null, or false if the column is
|
---|
399 | // known to be not nullable. If the column nullability is unknown, ok should be
|
---|
400 | // false.
|
---|
401 | func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
---|
402 | return true, true
|
---|
403 | }
|
---|
404 |
|
---|
405 | // RowsColumnTypePrecisionScale may be implemented by Rows. It should return
|
---|
406 | // the precision and scale for decimal types. If not applicable, ok should be
|
---|
407 | // false. The following are examples of returned values for various types:
|
---|
408 | //
|
---|
409 | // decimal(38, 4) (38, 4, true)
|
---|
410 | // int (0, 0, false)
|
---|
411 | // decimal (math.MaxInt64, math.MaxInt64, true)
|
---|
412 | func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
---|
413 | return 0, 0, false
|
---|
414 | }
|
---|
415 |
|
---|
416 | // RowsColumnTypeScanType may be implemented by Rows. It should return the
|
---|
417 | // value type that can be used to scan types into. For example, the database
|
---|
418 | // column type "bigint" this should return "reflect.TypeOf(int64(0))".
|
---|
419 | func (r *rows) ColumnTypeScanType(index int) reflect.Type {
|
---|
420 | t, err := r.c.columnType(r.pstmt, index)
|
---|
421 | if err != nil {
|
---|
422 | return reflect.TypeOf("")
|
---|
423 | }
|
---|
424 |
|
---|
425 | switch t {
|
---|
426 | case sqlite3.SQLITE_INTEGER:
|
---|
427 | switch strings.ToLower(r.c.columnDeclType(r.pstmt, index)) {
|
---|
428 | case "boolean":
|
---|
429 | return reflect.TypeOf(false)
|
---|
430 | case "date", "datetime", "time", "timestamp":
|
---|
431 | return reflect.TypeOf(time.Time{})
|
---|
432 | default:
|
---|
433 | return reflect.TypeOf(int64(0))
|
---|
434 | }
|
---|
435 | case sqlite3.SQLITE_FLOAT:
|
---|
436 | return reflect.TypeOf(float64(0))
|
---|
437 | case sqlite3.SQLITE_TEXT:
|
---|
438 | return reflect.TypeOf("")
|
---|
439 | case sqlite3.SQLITE_BLOB:
|
---|
440 | return reflect.SliceOf(reflect.TypeOf([]byte{}))
|
---|
441 | case sqlite3.SQLITE_NULL:
|
---|
442 | return reflect.TypeOf(nil)
|
---|
443 | default:
|
---|
444 | return reflect.TypeOf("")
|
---|
445 | }
|
---|
446 | }
|
---|
447 |
|
---|
448 | type stmt struct {
|
---|
449 | c *conn
|
---|
450 | psql uintptr
|
---|
451 | }
|
---|
452 |
|
---|
453 | func newStmt(c *conn, sql string) (*stmt, error) {
|
---|
454 | p, err := libc.CString(sql)
|
---|
455 | if err != nil {
|
---|
456 | return nil, err
|
---|
457 | }
|
---|
458 | stm := stmt{c: c, psql: p}
|
---|
459 |
|
---|
460 | return &stm, nil
|
---|
461 | }
|
---|
462 |
|
---|
463 | // Close closes the statement.
|
---|
464 | //
|
---|
465 | // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
|
---|
466 | func (s *stmt) Close() (err error) {
|
---|
467 | s.c.free(s.psql)
|
---|
468 | s.psql = 0
|
---|
469 | return nil
|
---|
470 | }
|
---|
471 |
|
---|
472 | // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
|
---|
473 | //
|
---|
474 | //
|
---|
475 | // Deprecated: Drivers should implement StmtExecContext instead (or
|
---|
476 | // additionally).
|
---|
477 | func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { //TODO StmtExecContext
|
---|
478 | return s.exec(context.Background(), toNamedValues(args))
|
---|
479 | }
|
---|
480 |
|
---|
// toNamedValues converts positional arguments to driver.NamedValue. Ordinals
// are 1-based and Names are left empty. The result is never nil, even for
// empty input.
func toNamedValues(vals []driver.Value) (r []driver.NamedValue) {
	r = make([]driver.NamedValue, 0, len(vals))
	for i := 0; i < len(vals); i++ {
		r = append(r, driver.NamedValue{Ordinal: i + 1, Value: vals[i]})
	}
	return r
}
|
---|
489 |
|
---|
490 | func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
|
---|
491 | var pstmt uintptr
|
---|
492 | var done int32
|
---|
493 | if ctx != nil && ctx.Done() != nil {
|
---|
494 | defer interruptOnDone(ctx, s.c, &done)()
|
---|
495 | }
|
---|
496 |
|
---|
497 | for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
|
---|
498 | if pstmt, err = s.c.prepareV2(&psql); err != nil {
|
---|
499 | return nil, err
|
---|
500 | }
|
---|
501 |
|
---|
502 | if pstmt == 0 {
|
---|
503 | continue
|
---|
504 | }
|
---|
505 | err = func() (err error) {
|
---|
506 | n, err := s.c.bindParameterCount(pstmt)
|
---|
507 | if err != nil {
|
---|
508 | return err
|
---|
509 | }
|
---|
510 |
|
---|
511 | if n != 0 {
|
---|
512 | allocs, err := s.c.bind(pstmt, n, args)
|
---|
513 | if err != nil {
|
---|
514 | return err
|
---|
515 | }
|
---|
516 |
|
---|
517 | if len(allocs) != 0 {
|
---|
518 | defer func() {
|
---|
519 | for _, v := range allocs {
|
---|
520 | s.c.free(v)
|
---|
521 | }
|
---|
522 | }()
|
---|
523 | }
|
---|
524 | }
|
---|
525 |
|
---|
526 | rc, err := s.c.step(pstmt)
|
---|
527 | if err != nil {
|
---|
528 | return err
|
---|
529 | }
|
---|
530 |
|
---|
531 | switch rc & 0xff {
|
---|
532 | case sqlite3.SQLITE_DONE, sqlite3.SQLITE_ROW:
|
---|
533 | // nop
|
---|
534 | default:
|
---|
535 | return s.c.errstr(int32(rc))
|
---|
536 | }
|
---|
537 |
|
---|
538 | return nil
|
---|
539 | }()
|
---|
540 |
|
---|
541 | if e := s.c.finalize(pstmt); e != nil && err == nil {
|
---|
542 | err = e
|
---|
543 | }
|
---|
544 |
|
---|
545 | if err != nil {
|
---|
546 | return nil, err
|
---|
547 | }
|
---|
548 | }
|
---|
549 | return newResult(s.c)
|
---|
550 | }
|
---|
551 |
|
---|
552 | // NumInput returns the number of placeholder parameters.
|
---|
553 | //
|
---|
554 | // If NumInput returns >= 0, the sql package will sanity check argument counts
|
---|
555 | // from callers and return errors to the caller before the statement's Exec or
|
---|
556 | // Query methods are called.
|
---|
557 | //
|
---|
558 | // NumInput may also return -1, if the driver doesn't know its number of
|
---|
559 | // placeholders. In that case, the sql package will not sanity check Exec or
|
---|
560 | // Query argument counts.
|
---|
561 | func (s *stmt) NumInput() (n int) {
|
---|
562 | return -1
|
---|
563 | }
|
---|
564 |
|
---|
565 | // Query executes a query that may return rows, such as a
|
---|
566 | // SELECT.
|
---|
567 | //
|
---|
568 | // Deprecated: Drivers should implement StmtQueryContext instead (or
|
---|
569 | // additionally).
|
---|
570 | func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { //TODO StmtQueryContext
|
---|
571 | return s.query(context.Background(), toNamedValues(args))
|
---|
572 | }
|
---|
573 |
|
---|
574 | func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
|
---|
575 | var pstmt uintptr
|
---|
576 | var done int32
|
---|
577 | if ctx != nil && ctx.Done() != nil {
|
---|
578 | defer interruptOnDone(ctx, s.c, &done)()
|
---|
579 | }
|
---|
580 |
|
---|
581 | var allocs []uintptr
|
---|
582 | for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; {
|
---|
583 | if pstmt, err = s.c.prepareV2(&psql); err != nil {
|
---|
584 | return nil, err
|
---|
585 | }
|
---|
586 |
|
---|
587 | if pstmt == 0 {
|
---|
588 | continue
|
---|
589 | }
|
---|
590 |
|
---|
591 | err = func() (err error) {
|
---|
592 | n, err := s.c.bindParameterCount(pstmt)
|
---|
593 | if err != nil {
|
---|
594 | return err
|
---|
595 | }
|
---|
596 |
|
---|
597 | if n != 0 {
|
---|
598 | if allocs, err = s.c.bind(pstmt, n, args); err != nil {
|
---|
599 | return err
|
---|
600 | }
|
---|
601 | }
|
---|
602 |
|
---|
603 | rc, err := s.c.step(pstmt)
|
---|
604 | if err != nil {
|
---|
605 | return err
|
---|
606 | }
|
---|
607 |
|
---|
608 | switch rc & 0xff {
|
---|
609 | case sqlite3.SQLITE_ROW:
|
---|
610 | if r != nil {
|
---|
611 | r.Close()
|
---|
612 | }
|
---|
613 | if r, err = newRows(s.c, pstmt, allocs, false); err != nil {
|
---|
614 | return err
|
---|
615 | }
|
---|
616 |
|
---|
617 | pstmt = 0
|
---|
618 | return nil
|
---|
619 | case sqlite3.SQLITE_DONE:
|
---|
620 | if r == nil {
|
---|
621 | if r, err = newRows(s.c, pstmt, allocs, true); err != nil {
|
---|
622 | return err
|
---|
623 | }
|
---|
624 | pstmt = 0
|
---|
625 | return nil
|
---|
626 | }
|
---|
627 |
|
---|
628 | // nop
|
---|
629 | default:
|
---|
630 | return s.c.errstr(int32(rc))
|
---|
631 | }
|
---|
632 |
|
---|
633 | if *(*byte)(unsafe.Pointer(psql)) == 0 {
|
---|
634 | if r != nil {
|
---|
635 | r.Close()
|
---|
636 | }
|
---|
637 | if r, err = newRows(s.c, pstmt, allocs, true); err != nil {
|
---|
638 | return err
|
---|
639 | }
|
---|
640 |
|
---|
641 | pstmt = 0
|
---|
642 | }
|
---|
643 | return nil
|
---|
644 | }()
|
---|
645 | if e := s.c.finalize(pstmt); e != nil && err == nil {
|
---|
646 | err = e
|
---|
647 | }
|
---|
648 |
|
---|
649 | if err != nil {
|
---|
650 | return nil, err
|
---|
651 | }
|
---|
652 | }
|
---|
653 | return r, err
|
---|
654 | }
|
---|
655 |
|
---|
656 | type tx struct {
|
---|
657 | c *conn
|
---|
658 | }
|
---|
659 |
|
---|
660 | func newTx(c *conn, opts driver.TxOptions) (*tx, error) {
|
---|
661 | r := &tx{c: c}
|
---|
662 |
|
---|
663 | sql := "begin"
|
---|
664 | if !opts.ReadOnly && c.beginMode != "" {
|
---|
665 | sql = "begin " + c.beginMode
|
---|
666 | }
|
---|
667 |
|
---|
668 | if err := r.exec(context.Background(), sql); err != nil {
|
---|
669 | return nil, err
|
---|
670 | }
|
---|
671 |
|
---|
672 | return r, nil
|
---|
673 | }
|
---|
674 |
|
---|
675 | // Commit implements driver.Tx.
|
---|
676 | func (t *tx) Commit() (err error) {
|
---|
677 | return t.exec(context.Background(), "commit")
|
---|
678 | }
|
---|
679 |
|
---|
680 | // Rollback implements driver.Tx.
|
---|
681 | func (t *tx) Rollback() (err error) {
|
---|
682 | return t.exec(context.Background(), "rollback")
|
---|
683 | }
|
---|
684 |
|
---|
685 | func (t *tx) exec(ctx context.Context, sql string) (err error) {
|
---|
686 | psql, err := libc.CString(sql)
|
---|
687 | if err != nil {
|
---|
688 | return err
|
---|
689 | }
|
---|
690 |
|
---|
691 | defer t.c.free(psql)
|
---|
692 | //TODO use t.conn.ExecContext() instead
|
---|
693 |
|
---|
694 | if ctx != nil && ctx.Done() != nil {
|
---|
695 | defer interruptOnDone(ctx, t.c, nil)()
|
---|
696 | }
|
---|
697 |
|
---|
698 | if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK {
|
---|
699 | return t.c.errstr(rc)
|
---|
700 | }
|
---|
701 |
|
---|
702 | return nil
|
---|
703 | }
|
---|
704 |
|
---|
705 | // interruptOnDone sets up a goroutine to interrupt the provided db when the
|
---|
706 | // context is canceled, and returns a function the caller must defer so it
|
---|
707 | // doesn't interrupt after the caller finishes.
|
---|
708 | func interruptOnDone(
|
---|
709 | ctx context.Context,
|
---|
710 | c *conn,
|
---|
711 | done *int32,
|
---|
712 | ) func() {
|
---|
713 | if done == nil {
|
---|
714 | var d int32
|
---|
715 | done = &d
|
---|
716 | }
|
---|
717 |
|
---|
718 | donech := make(chan struct{})
|
---|
719 |
|
---|
720 | go func() {
|
---|
721 | select {
|
---|
722 | case <-ctx.Done():
|
---|
723 | // don't call interrupt if we were already done: it indicates that this
|
---|
724 | // call to exec is no longer running and we would be interrupting
|
---|
725 | // nothing, or even possibly an unrelated later call to exec.
|
---|
726 | if atomic.AddInt32(done, 1) == 1 {
|
---|
727 | c.interrupt(c.db)
|
---|
728 | }
|
---|
729 | case <-donech:
|
---|
730 | }
|
---|
731 | }()
|
---|
732 |
|
---|
733 | // the caller is expected to defer this function
|
---|
734 | return func() {
|
---|
735 | // set the done flag so that a context cancellation right after the caller
|
---|
736 | // returns doesn't trigger a call to interrupt for some other statement.
|
---|
737 | atomic.AddInt32(done, 1)
|
---|
738 | close(donech)
|
---|
739 | }
|
---|
740 | }
|
---|
741 |
|
---|
742 | type conn struct {
|
---|
743 | db uintptr // *sqlite3.Xsqlite3
|
---|
744 | tls *libc.TLS
|
---|
745 |
|
---|
746 | // Context handling can cause conn.Close and conn.interrupt to be invoked
|
---|
747 | // concurrently.
|
---|
748 | sync.Mutex
|
---|
749 |
|
---|
750 | writeTimeFormat string
|
---|
751 | beginMode string
|
---|
752 | }
|
---|
753 |
|
---|
754 | func newConn(dsn string) (*conn, error) {
|
---|
755 | var query, vfsName string
|
---|
756 |
|
---|
757 | // Parse the query parameters from the dsn and them from the dsn if not prefixed by file:
|
---|
758 | // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1046
|
---|
759 | // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1383
|
---|
760 | pos := strings.IndexRune(dsn, '?')
|
---|
761 | if pos >= 1 {
|
---|
762 | query = dsn[pos+1:]
|
---|
763 | var err error
|
---|
764 | vfsName, err = getVFSName(query)
|
---|
765 | if err != nil {
|
---|
766 | return nil, err
|
---|
767 | }
|
---|
768 |
|
---|
769 | if !strings.HasPrefix(dsn, "file:") {
|
---|
770 | dsn = dsn[:pos]
|
---|
771 | }
|
---|
772 | }
|
---|
773 |
|
---|
774 | c := &conn{tls: libc.NewTLS()}
|
---|
775 | db, err := c.openV2(
|
---|
776 | dsn,
|
---|
777 | vfsName,
|
---|
778 | sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE|
|
---|
779 | sqlite3.SQLITE_OPEN_FULLMUTEX|
|
---|
780 | sqlite3.SQLITE_OPEN_URI,
|
---|
781 | )
|
---|
782 | if err != nil {
|
---|
783 | return nil, err
|
---|
784 | }
|
---|
785 |
|
---|
786 | c.db = db
|
---|
787 | if err = c.extendedResultCodes(true); err != nil {
|
---|
788 | c.Close()
|
---|
789 | return nil, err
|
---|
790 | }
|
---|
791 |
|
---|
792 | if err = applyQueryParams(c, query); err != nil {
|
---|
793 | c.Close()
|
---|
794 | return nil, err
|
---|
795 | }
|
---|
796 |
|
---|
797 | return c, nil
|
---|
798 | }
|
---|
799 |
|
---|
// getVFSName returns the value of the "vfs" parameter in the URL-encoded
// query, or "" if it is absent.
//
// The parameter may be repeated, but all non-empty occurrences after the
// first non-empty one must match it; otherwise an error listing every
// occurrence is returned. A malformed query is reported as an error.
func getVFSName(query string) (r string, err error) {
	q, err := url.ParseQuery(query)
	if err != nil {
		return "", err
	}

	vals := q["vfs"]
	for _, v := range vals {
		switch {
		case r == "":
			r = v
		case v != r:
			return "", fmt.Errorf("conflicting vfs query parameters: %v", vals)
		}
	}
	return r, nil
}
|
---|
816 |
|
---|
817 | func applyQueryParams(c *conn, query string) error {
|
---|
818 | q, err := url.ParseQuery(query)
|
---|
819 | if err != nil {
|
---|
820 | return err
|
---|
821 | }
|
---|
822 |
|
---|
823 | for _, v := range q["_pragma"] {
|
---|
824 | cmd := "pragma " + v
|
---|
825 | _, err := c.exec(context.Background(), cmd, nil)
|
---|
826 | if err != nil {
|
---|
827 | return err
|
---|
828 | }
|
---|
829 | }
|
---|
830 |
|
---|
831 | if v := q.Get("_time_format"); v != "" {
|
---|
832 | f, ok := writeTimeFormats[v]
|
---|
833 | if !ok {
|
---|
834 | return fmt.Errorf("unknown _time_format %q", v)
|
---|
835 | }
|
---|
836 | c.writeTimeFormat = f
|
---|
837 | }
|
---|
838 |
|
---|
839 | if v := q.Get("_txlock"); v != "" {
|
---|
840 | lower := strings.ToLower(v)
|
---|
841 | if lower != "deferred" && lower != "immediate" && lower != "exclusive" {
|
---|
842 | return fmt.Errorf("unknown _txlock %q", v)
|
---|
843 | }
|
---|
844 | c.beginMode = v
|
---|
845 | }
|
---|
846 |
|
---|
847 | return nil
|
---|
848 | }
|
---|
849 |
|
---|
850 | // const void *sqlite3_column_blob(sqlite3_stmt*, int iCol);
|
---|
851 | func (c *conn) columnBlob(pstmt uintptr, iCol int) (v []byte, err error) {
|
---|
852 | p := sqlite3.Xsqlite3_column_blob(c.tls, pstmt, int32(iCol))
|
---|
853 | len, err := c.columnBytes(pstmt, iCol)
|
---|
854 | if err != nil {
|
---|
855 | return nil, err
|
---|
856 | }
|
---|
857 |
|
---|
858 | if p == 0 || len == 0 {
|
---|
859 | return nil, nil
|
---|
860 | }
|
---|
861 |
|
---|
862 | v = make([]byte, len)
|
---|
863 | copy(v, (*libc.RawMem)(unsafe.Pointer(p))[:len:len])
|
---|
864 | return v, nil
|
---|
865 | }
|
---|
866 |
|
---|
867 | // int sqlite3_column_bytes(sqlite3_stmt*, int iCol);
|
---|
868 | func (c *conn) columnBytes(pstmt uintptr, iCol int) (_ int, err error) {
|
---|
869 | v := sqlite3.Xsqlite3_column_bytes(c.tls, pstmt, int32(iCol))
|
---|
870 | return int(v), nil
|
---|
871 | }
|
---|
872 |
|
---|
873 | // const unsigned char *sqlite3_column_text(sqlite3_stmt*, int iCol);
|
---|
874 | func (c *conn) columnText(pstmt uintptr, iCol int) (v string, err error) {
|
---|
875 | p := sqlite3.Xsqlite3_column_text(c.tls, pstmt, int32(iCol))
|
---|
876 | len, err := c.columnBytes(pstmt, iCol)
|
---|
877 | if err != nil {
|
---|
878 | return "", err
|
---|
879 | }
|
---|
880 |
|
---|
881 | if p == 0 || len == 0 {
|
---|
882 | return "", nil
|
---|
883 | }
|
---|
884 |
|
---|
885 | b := make([]byte, len)
|
---|
886 | copy(b, (*libc.RawMem)(unsafe.Pointer(p))[:len:len])
|
---|
887 | return string(b), nil
|
---|
888 | }
|
---|
889 |
|
---|
890 | // double sqlite3_column_double(sqlite3_stmt*, int iCol);
|
---|
891 | func (c *conn) columnDouble(pstmt uintptr, iCol int) (v float64, err error) {
|
---|
892 | v = sqlite3.Xsqlite3_column_double(c.tls, pstmt, int32(iCol))
|
---|
893 | return v, nil
|
---|
894 | }
|
---|
895 |
|
---|
896 | // sqlite3_int64 sqlite3_column_int64(sqlite3_stmt*, int iCol);
|
---|
897 | func (c *conn) columnInt64(pstmt uintptr, iCol int) (v int64, err error) {
|
---|
898 | v = sqlite3.Xsqlite3_column_int64(c.tls, pstmt, int32(iCol))
|
---|
899 | return v, nil
|
---|
900 | }
|
---|
901 |
|
---|
902 | // int sqlite3_column_type(sqlite3_stmt*, int iCol);
|
---|
903 | func (c *conn) columnType(pstmt uintptr, iCol int) (_ int, err error) {
|
---|
904 | v := sqlite3.Xsqlite3_column_type(c.tls, pstmt, int32(iCol))
|
---|
905 | return int(v), nil
|
---|
906 | }
|
---|
907 |
|
---|
908 | // const char *sqlite3_column_decltype(sqlite3_stmt*,int);
|
---|
909 | func (c *conn) columnDeclType(pstmt uintptr, iCol int) string {
|
---|
910 | return libc.GoString(sqlite3.Xsqlite3_column_decltype(c.tls, pstmt, int32(iCol)))
|
---|
911 | }
|
---|
912 |
|
---|
913 | // const char *sqlite3_column_name(sqlite3_stmt*, int N);
|
---|
914 | func (c *conn) columnName(pstmt uintptr, n int) (string, error) {
|
---|
915 | p := sqlite3.Xsqlite3_column_name(c.tls, pstmt, int32(n))
|
---|
916 | return libc.GoString(p), nil
|
---|
917 | }
|
---|
918 |
|
---|
919 | // int sqlite3_column_count(sqlite3_stmt *pStmt);
|
---|
920 | func (c *conn) columnCount(pstmt uintptr) (_ int, err error) {
|
---|
921 | v := sqlite3.Xsqlite3_column_count(c.tls, pstmt)
|
---|
922 | return int(v), nil
|
---|
923 | }
|
---|
924 |
|
---|
925 | // sqlite3_int64 sqlite3_last_insert_rowid(sqlite3*);
|
---|
926 | func (c *conn) lastInsertRowID() (v int64, _ error) {
|
---|
927 | return sqlite3.Xsqlite3_last_insert_rowid(c.tls, c.db), nil
|
---|
928 | }
|
---|
929 |
|
---|
930 | // int sqlite3_changes(sqlite3*);
|
---|
931 | func (c *conn) changes() (int, error) {
|
---|
932 | v := sqlite3.Xsqlite3_changes(c.tls, c.db)
|
---|
933 | return int(v), nil
|
---|
934 | }
|
---|
935 |
|
---|
936 | // int sqlite3_step(sqlite3_stmt*);
|
---|
937 | func (c *conn) step(pstmt uintptr) (int, error) {
|
---|
938 | for {
|
---|
939 | switch rc := sqlite3.Xsqlite3_step(c.tls, pstmt); rc {
|
---|
940 | case sqliteLockedSharedcache:
|
---|
941 | if err := c.retry(pstmt); err != nil {
|
---|
942 | return sqlite3.SQLITE_LOCKED, err
|
---|
943 | }
|
---|
944 | case
|
---|
945 | sqlite3.SQLITE_DONE,
|
---|
946 | sqlite3.SQLITE_ROW:
|
---|
947 |
|
---|
948 | return int(rc), nil
|
---|
949 | default:
|
---|
950 | return int(rc), c.errstr(rc)
|
---|
951 | }
|
---|
952 | }
|
---|
953 | }
|
---|
954 |
|
---|
955 | func (c *conn) retry(pstmt uintptr) error {
|
---|
956 | mu := mutexAlloc(c.tls)
|
---|
957 | (*mutex)(unsafe.Pointer(mu)).Lock()
|
---|
958 | rc := sqlite3.Xsqlite3_unlock_notify(
|
---|
959 | c.tls,
|
---|
960 | c.db,
|
---|
961 | *(*uintptr)(unsafe.Pointer(&struct {
|
---|
962 | f func(*libc.TLS, uintptr, int32)
|
---|
963 | }{unlockNotify})),
|
---|
964 | mu,
|
---|
965 | )
|
---|
966 | if rc == sqlite3.SQLITE_LOCKED { // Deadlock, see https://www.sqlite.org/c3ref/unlock_notify.html
|
---|
967 | (*mutex)(unsafe.Pointer(mu)).Unlock()
|
---|
968 | mutexFree(c.tls, mu)
|
---|
969 | return c.errstr(rc)
|
---|
970 | }
|
---|
971 |
|
---|
972 | (*mutex)(unsafe.Pointer(mu)).Lock()
|
---|
973 | (*mutex)(unsafe.Pointer(mu)).Unlock()
|
---|
974 | mutexFree(c.tls, mu)
|
---|
975 | if pstmt != 0 {
|
---|
976 | sqlite3.Xsqlite3_reset(c.tls, pstmt)
|
---|
977 | }
|
---|
978 | return nil
|
---|
979 | }
|
---|
980 |
|
---|
981 | func unlockNotify(t *libc.TLS, ppArg uintptr, nArg int32) {
|
---|
982 | for i := int32(0); i < nArg; i++ {
|
---|
983 | mu := *(*uintptr)(unsafe.Pointer(ppArg))
|
---|
984 | (*mutex)(unsafe.Pointer(mu)).Unlock()
|
---|
985 | ppArg += ptrSize
|
---|
986 | }
|
---|
987 | }
|
---|
988 |
|
---|
989 | func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []uintptr, err error) {
|
---|
990 | defer func() {
|
---|
991 | if err == nil {
|
---|
992 | return
|
---|
993 | }
|
---|
994 |
|
---|
995 | for _, v := range allocs {
|
---|
996 | c.free(v)
|
---|
997 | }
|
---|
998 | allocs = nil
|
---|
999 | }()
|
---|
1000 |
|
---|
1001 | for i := 1; i <= n; i++ {
|
---|
1002 | name, err := c.bindParameterName(pstmt, i)
|
---|
1003 | if err != nil {
|
---|
1004 | return allocs, err
|
---|
1005 | }
|
---|
1006 |
|
---|
1007 | var found bool
|
---|
1008 | var v driver.NamedValue
|
---|
1009 | for _, v = range args {
|
---|
1010 | if name != "" {
|
---|
1011 | // For ?NNN and $NNN params, match if NNN == v.Ordinal.
|
---|
1012 | //
|
---|
1013 | // Supporting this for $NNN is a special case that makes eg
|
---|
1014 | // `select $1, $2, $3 ...` work without needing to use
|
---|
1015 | // sql.Named.
|
---|
1016 | if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) {
|
---|
1017 | found = true
|
---|
1018 | break
|
---|
1019 | }
|
---|
1020 |
|
---|
1021 | // sqlite supports '$', '@' and ':' prefixes for string
|
---|
1022 | // identifiers and '?' for numeric, so we cannot
|
---|
1023 | // combine different prefixes with the same name
|
---|
1024 | // because `database/sql` requires variable names
|
---|
1025 | // to start with a letter
|
---|
1026 | if name[1:] == v.Name[:] {
|
---|
1027 | found = true
|
---|
1028 | break
|
---|
1029 | }
|
---|
1030 | } else {
|
---|
1031 | if v.Ordinal == i {
|
---|
1032 | found = true
|
---|
1033 | break
|
---|
1034 | }
|
---|
1035 | }
|
---|
1036 | }
|
---|
1037 |
|
---|
1038 | if !found {
|
---|
1039 | if name != "" {
|
---|
1040 | return allocs, fmt.Errorf("missing named argument %q", name[1:])
|
---|
1041 | }
|
---|
1042 |
|
---|
1043 | return allocs, fmt.Errorf("missing argument with index %d", i)
|
---|
1044 | }
|
---|
1045 |
|
---|
1046 | var p uintptr
|
---|
1047 | switch x := v.Value.(type) {
|
---|
1048 | case int64:
|
---|
1049 | if err := c.bindInt64(pstmt, i, x); err != nil {
|
---|
1050 | return allocs, err
|
---|
1051 | }
|
---|
1052 | case float64:
|
---|
1053 | if err := c.bindDouble(pstmt, i, x); err != nil {
|
---|
1054 | return allocs, err
|
---|
1055 | }
|
---|
1056 | case bool:
|
---|
1057 | v := 0
|
---|
1058 | if x {
|
---|
1059 | v = 1
|
---|
1060 | }
|
---|
1061 | if err := c.bindInt(pstmt, i, v); err != nil {
|
---|
1062 | return allocs, err
|
---|
1063 | }
|
---|
1064 | case []byte:
|
---|
1065 | if p, err = c.bindBlob(pstmt, i, x); err != nil {
|
---|
1066 | return allocs, err
|
---|
1067 | }
|
---|
1068 | case string:
|
---|
1069 | if p, err = c.bindText(pstmt, i, x); err != nil {
|
---|
1070 | return allocs, err
|
---|
1071 | }
|
---|
1072 | case time.Time:
|
---|
1073 | if p, err = c.bindText(pstmt, i, c.formatTime(x)); err != nil {
|
---|
1074 | return allocs, err
|
---|
1075 | }
|
---|
1076 | case nil:
|
---|
1077 | if p, err = c.bindNull(pstmt, i); err != nil {
|
---|
1078 | return allocs, err
|
---|
1079 | }
|
---|
1080 | default:
|
---|
1081 | return allocs, fmt.Errorf("sqlite: invalid driver.Value type %T", x)
|
---|
1082 | }
|
---|
1083 | if p != 0 {
|
---|
1084 | allocs = append(allocs, p)
|
---|
1085 | }
|
---|
1086 | }
|
---|
1087 | return allocs, nil
|
---|
1088 | }
|
---|
1089 |
|
---|
1090 | // int sqlite3_bind_null(sqlite3_stmt*, int);
|
---|
1091 | func (c *conn) bindNull(pstmt uintptr, idx1 int) (uintptr, error) {
|
---|
1092 | if rc := sqlite3.Xsqlite3_bind_null(c.tls, pstmt, int32(idx1)); rc != sqlite3.SQLITE_OK {
|
---|
1093 | return 0, c.errstr(rc)
|
---|
1094 | }
|
---|
1095 |
|
---|
1096 | return 0, nil
|
---|
1097 | }
|
---|
1098 |
|
---|
1099 | // int sqlite3_bind_text(sqlite3_stmt*,int,const char*,int,void(*)(void*));
|
---|
1100 | func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error) {
|
---|
1101 | p, err := libc.CString(value)
|
---|
1102 | if err != nil {
|
---|
1103 | return 0, err
|
---|
1104 | }
|
---|
1105 |
|
---|
1106 | if rc := sqlite3.Xsqlite3_bind_text(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK {
|
---|
1107 | c.free(p)
|
---|
1108 | return 0, c.errstr(rc)
|
---|
1109 | }
|
---|
1110 |
|
---|
1111 | return p, nil
|
---|
1112 | }
|
---|
1113 |
|
---|
1114 | // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
|
---|
1115 | func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) {
|
---|
1116 | if value != nil && len(value) == 0 {
|
---|
1117 | if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK {
|
---|
1118 | return 0, c.errstr(rc)
|
---|
1119 | }
|
---|
1120 | return 0, nil
|
---|
1121 | }
|
---|
1122 |
|
---|
1123 | p, err := c.malloc(len(value))
|
---|
1124 | if err != nil {
|
---|
1125 | return 0, err
|
---|
1126 | }
|
---|
1127 | if len(value) != 0 {
|
---|
1128 | copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value)
|
---|
1129 | }
|
---|
1130 | if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK {
|
---|
1131 | c.free(p)
|
---|
1132 | return 0, c.errstr(rc)
|
---|
1133 | }
|
---|
1134 |
|
---|
1135 | return p, nil
|
---|
1136 | }
|
---|
1137 |
|
---|
1138 | // int sqlite3_bind_int(sqlite3_stmt*, int, int);
|
---|
1139 | func (c *conn) bindInt(pstmt uintptr, idx1, value int) (err error) {
|
---|
1140 | if rc := sqlite3.Xsqlite3_bind_int(c.tls, pstmt, int32(idx1), int32(value)); rc != sqlite3.SQLITE_OK {
|
---|
1141 | return c.errstr(rc)
|
---|
1142 | }
|
---|
1143 |
|
---|
1144 | return nil
|
---|
1145 | }
|
---|
1146 |
|
---|
1147 | // int sqlite3_bind_double(sqlite3_stmt*, int, double);
|
---|
1148 | func (c *conn) bindDouble(pstmt uintptr, idx1 int, value float64) (err error) {
|
---|
1149 | if rc := sqlite3.Xsqlite3_bind_double(c.tls, pstmt, int32(idx1), value); rc != 0 {
|
---|
1150 | return c.errstr(rc)
|
---|
1151 | }
|
---|
1152 |
|
---|
1153 | return nil
|
---|
1154 | }
|
---|
1155 |
|
---|
1156 | // int sqlite3_bind_int64(sqlite3_stmt*, int, sqlite3_int64);
|
---|
1157 | func (c *conn) bindInt64(pstmt uintptr, idx1 int, value int64) (err error) {
|
---|
1158 | if rc := sqlite3.Xsqlite3_bind_int64(c.tls, pstmt, int32(idx1), value); rc != sqlite3.SQLITE_OK {
|
---|
1159 | return c.errstr(rc)
|
---|
1160 | }
|
---|
1161 |
|
---|
1162 | return nil
|
---|
1163 | }
|
---|
1164 |
|
---|
1165 | // const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int);
|
---|
1166 | func (c *conn) bindParameterName(pstmt uintptr, i int) (string, error) {
|
---|
1167 | p := sqlite3.Xsqlite3_bind_parameter_name(c.tls, pstmt, int32(i))
|
---|
1168 | return libc.GoString(p), nil
|
---|
1169 | }
|
---|
1170 |
|
---|
1171 | // int sqlite3_bind_parameter_count(sqlite3_stmt*);
|
---|
1172 | func (c *conn) bindParameterCount(pstmt uintptr) (_ int, err error) {
|
---|
1173 | r := sqlite3.Xsqlite3_bind_parameter_count(c.tls, pstmt)
|
---|
1174 | return int(r), nil
|
---|
1175 | }
|
---|
1176 |
|
---|
1177 | // int sqlite3_finalize(sqlite3_stmt *pStmt);
|
---|
1178 | func (c *conn) finalize(pstmt uintptr) error {
|
---|
1179 | if rc := sqlite3.Xsqlite3_finalize(c.tls, pstmt); rc != sqlite3.SQLITE_OK {
|
---|
1180 | return c.errstr(rc)
|
---|
1181 | }
|
---|
1182 |
|
---|
1183 | return nil
|
---|
1184 | }
|
---|
1185 |
|
---|
1186 | // int sqlite3_prepare_v2(
|
---|
1187 | // sqlite3 *db, /* Database handle */
|
---|
1188 | // const char *zSql, /* SQL statement, UTF-8 encoded */
|
---|
1189 | // int nByte, /* Maximum length of zSql in bytes. */
|
---|
1190 | // sqlite3_stmt **ppStmt, /* OUT: Statement handle */
|
---|
1191 | // const char **pzTail /* OUT: Pointer to unused portion of zSql */
|
---|
1192 | // );
|
---|
1193 | func (c *conn) prepareV2(zSQL *uintptr) (pstmt uintptr, err error) {
|
---|
1194 | var ppstmt, pptail uintptr
|
---|
1195 |
|
---|
1196 | defer func() {
|
---|
1197 | c.free(ppstmt)
|
---|
1198 | c.free(pptail)
|
---|
1199 | }()
|
---|
1200 |
|
---|
1201 | if ppstmt, err = c.malloc(int(ptrSize)); err != nil {
|
---|
1202 | return 0, err
|
---|
1203 | }
|
---|
1204 |
|
---|
1205 | if pptail, err = c.malloc(int(ptrSize)); err != nil {
|
---|
1206 | return 0, err
|
---|
1207 | }
|
---|
1208 |
|
---|
1209 | for {
|
---|
1210 | switch rc := sqlite3.Xsqlite3_prepare_v2(c.tls, c.db, *zSQL, -1, ppstmt, pptail); rc {
|
---|
1211 | case sqlite3.SQLITE_OK:
|
---|
1212 | *zSQL = *(*uintptr)(unsafe.Pointer(pptail))
|
---|
1213 | return *(*uintptr)(unsafe.Pointer(ppstmt)), nil
|
---|
1214 | case sqliteLockedSharedcache:
|
---|
1215 | if err := c.retry(0); err != nil {
|
---|
1216 | return 0, err
|
---|
1217 | }
|
---|
1218 | default:
|
---|
1219 | return 0, c.errstr(rc)
|
---|
1220 | }
|
---|
1221 | }
|
---|
1222 | }
|
---|
1223 |
|
---|
1224 | // void sqlite3_interrupt(sqlite3*);
|
---|
1225 | func (c *conn) interrupt(pdb uintptr) (err error) {
|
---|
1226 | c.Lock() // Defend against race with .Close invoked by context handling.
|
---|
1227 |
|
---|
1228 | defer c.Unlock()
|
---|
1229 |
|
---|
1230 | if c.tls != nil {
|
---|
1231 | sqlite3.Xsqlite3_interrupt(c.tls, pdb)
|
---|
1232 | }
|
---|
1233 | return nil
|
---|
1234 | }
|
---|
1235 |
|
---|
1236 | // int sqlite3_extended_result_codes(sqlite3*, int onoff);
|
---|
1237 | func (c *conn) extendedResultCodes(on bool) error {
|
---|
1238 | if rc := sqlite3.Xsqlite3_extended_result_codes(c.tls, c.db, libc.Bool32(on)); rc != sqlite3.SQLITE_OK {
|
---|
1239 | return c.errstr(rc)
|
---|
1240 | }
|
---|
1241 |
|
---|
1242 | return nil
|
---|
1243 | }
|
---|
1244 |
|
---|
1245 | // int sqlite3_open_v2(
|
---|
1246 | // const char *filename, /* Database filename (UTF-8) */
|
---|
1247 | // sqlite3 **ppDb, /* OUT: SQLite db handle */
|
---|
1248 | // int flags, /* Flags */
|
---|
1249 | // const char *zVfs /* Name of VFS module to use */
|
---|
1250 | // );
|
---|
1251 | func (c *conn) openV2(name, vfsName string, flags int32) (uintptr, error) {
|
---|
1252 | var p, s, vfs uintptr
|
---|
1253 |
|
---|
1254 | defer func() {
|
---|
1255 | if p != 0 {
|
---|
1256 | c.free(p)
|
---|
1257 | }
|
---|
1258 | if s != 0 {
|
---|
1259 | c.free(s)
|
---|
1260 | }
|
---|
1261 | if vfs != 0 {
|
---|
1262 | c.free(vfs)
|
---|
1263 | }
|
---|
1264 | }()
|
---|
1265 |
|
---|
1266 | p, err := c.malloc(int(ptrSize))
|
---|
1267 | if err != nil {
|
---|
1268 | return 0, err
|
---|
1269 | }
|
---|
1270 |
|
---|
1271 | if s, err = libc.CString(name); err != nil {
|
---|
1272 | return 0, err
|
---|
1273 | }
|
---|
1274 |
|
---|
1275 | if vfsName != "" {
|
---|
1276 | if vfs, err = libc.CString(vfsName); err != nil {
|
---|
1277 | return 0, err
|
---|
1278 | }
|
---|
1279 | }
|
---|
1280 |
|
---|
1281 | if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, vfs); rc != sqlite3.SQLITE_OK {
|
---|
1282 | return 0, c.errstr(rc)
|
---|
1283 | }
|
---|
1284 |
|
---|
1285 | return *(*uintptr)(unsafe.Pointer(p)), nil
|
---|
1286 | }
|
---|
1287 |
|
---|
1288 | func (c *conn) malloc(n int) (uintptr, error) {
|
---|
1289 | if p := libc.Xmalloc(c.tls, types.Size_t(n)); p != 0 || n == 0 {
|
---|
1290 | return p, nil
|
---|
1291 | }
|
---|
1292 |
|
---|
1293 | return 0, fmt.Errorf("sqlite: cannot allocate %d bytes of memory", n)
|
---|
1294 | }
|
---|
1295 |
|
---|
1296 | func (c *conn) free(p uintptr) {
|
---|
1297 | if p != 0 {
|
---|
1298 | libc.Xfree(c.tls, p)
|
---|
1299 | }
|
---|
1300 | }
|
---|
1301 |
|
---|
1302 | // const char *sqlite3_errstr(int);
|
---|
1303 | func (c *conn) errstr(rc int32) error {
|
---|
1304 | p := sqlite3.Xsqlite3_errstr(c.tls, rc)
|
---|
1305 | str := libc.GoString(p)
|
---|
1306 | p = sqlite3.Xsqlite3_errmsg(c.tls, c.db)
|
---|
1307 | var s string
|
---|
1308 | if rc == sqlite3.SQLITE_BUSY {
|
---|
1309 | s = " (SQLITE_BUSY)"
|
---|
1310 | }
|
---|
1311 | switch msg := libc.GoString(p); {
|
---|
1312 | case msg == str:
|
---|
1313 | return &Error{msg: fmt.Sprintf("%s (%v)%s", str, rc, s), code: int(rc)}
|
---|
1314 | default:
|
---|
1315 | return &Error{msg: fmt.Sprintf("%s: %s (%v)%s", str, msg, rc, s), code: int(rc)}
|
---|
1316 | }
|
---|
1317 | }
|
---|
1318 |
|
---|
1319 | // Begin starts a transaction.
|
---|
1320 | //
|
---|
1321 | // Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
|
---|
1322 | func (c *conn) Begin() (driver.Tx, error) {
|
---|
1323 | return c.begin(context.Background(), driver.TxOptions{})
|
---|
1324 | }
|
---|
1325 |
|
---|
1326 | func (c *conn) begin(ctx context.Context, opts driver.TxOptions) (t driver.Tx, err error) {
|
---|
1327 | return newTx(c, opts)
|
---|
1328 | }
|
---|
1329 |
|
---|
1330 | // Close invalidates and potentially stops any current prepared statements and
|
---|
1331 | // transactions, marking this connection as no longer in use.
|
---|
1332 | //
|
---|
1333 | // Because the sql package maintains a free pool of connections and only calls
|
---|
1334 | // Close when there's a surplus of idle connections, it shouldn't be necessary
|
---|
1335 | // for drivers to do their own connection caching.
|
---|
1336 | func (c *conn) Close() error {
|
---|
1337 | c.Lock() // Defend against race with .interrupt invoked by context handling.
|
---|
1338 |
|
---|
1339 | defer c.Unlock()
|
---|
1340 |
|
---|
1341 | if c.db != 0 {
|
---|
1342 | if err := c.closeV2(c.db); err != nil {
|
---|
1343 | return err
|
---|
1344 | }
|
---|
1345 |
|
---|
1346 | c.db = 0
|
---|
1347 | }
|
---|
1348 |
|
---|
1349 | if c.tls != nil {
|
---|
1350 | c.tls.Close()
|
---|
1351 | c.tls = nil
|
---|
1352 | }
|
---|
1353 | return nil
|
---|
1354 | }
|
---|
1355 |
|
---|
1356 | // int sqlite3_close_v2(sqlite3*);
|
---|
1357 | func (c *conn) closeV2(db uintptr) error {
|
---|
1358 | if rc := sqlite3.Xsqlite3_close_v2(c.tls, db); rc != sqlite3.SQLITE_OK {
|
---|
1359 | return c.errstr(rc)
|
---|
1360 | }
|
---|
1361 |
|
---|
1362 | return nil
|
---|
1363 | }
|
---|
1364 |
|
---|
// userDefinedFunction describes a SQL function registered on the Driver. Each
// new connection installs it via createFunctionInternal.
type userDefinedFunction struct {
	// zFuncName is a C (NUL-terminated) copy of the SQL function name.
	zFuncName uintptr
	// nArg is the argument count passed to sqlite3_create_function; -1 means
	// the function is variadic.
	nArg int32
	// eTextRep holds the text encoding and flags (e.g. SQLITE_UTF8, optionally
	// OR'ed with SQLITE_DETERMINISTIC) passed to sqlite3_create_function.
	eTextRep int32
	// xFunc is the C-ABI callback SQLite invokes for each call of the function.
	xFunc func(*libc.TLS, uintptr, int32, uintptr)

	// NOTE(review): freeOnce is not used by the code visible here. Presumably
	// it guards a one-time release of zFuncName elsewhere; confirm.
	freeOnce sync.Once
}
1373 |
|
---|
1374 | func (c *conn) createFunctionInternal(fun *userDefinedFunction) error {
|
---|
1375 | if rc := sqlite3.Xsqlite3_create_function(
|
---|
1376 | c.tls,
|
---|
1377 | c.db,
|
---|
1378 | fun.zFuncName,
|
---|
1379 | fun.nArg,
|
---|
1380 | fun.eTextRep,
|
---|
1381 | 0,
|
---|
1382 | *(*uintptr)(unsafe.Pointer(&fun.xFunc)),
|
---|
1383 | 0,
|
---|
1384 | 0,
|
---|
1385 | ); rc != sqlite3.SQLITE_OK {
|
---|
1386 | return c.errstr(rc)
|
---|
1387 | }
|
---|
1388 | return nil
|
---|
1389 | }
|
---|
1390 |
|
---|
1391 | // Execer is an optional interface that may be implemented by a Conn.
|
---|
1392 | //
|
---|
1393 | // If a Conn does not implement Execer, the sql package's DB.Exec will first
|
---|
1394 | // prepare a query, execute the statement, and then close the statement.
|
---|
1395 | //
|
---|
1396 | // Exec may return ErrSkip.
|
---|
1397 | //
|
---|
1398 | // Deprecated: Drivers should implement ExecerContext instead.
|
---|
1399 | func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
---|
1400 | return c.exec(context.Background(), query, toNamedValues(args))
|
---|
1401 | }
|
---|
1402 |
|
---|
1403 | func (c *conn) exec(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
|
---|
1404 | s, err := c.prepare(ctx, query)
|
---|
1405 | if err != nil {
|
---|
1406 | return nil, err
|
---|
1407 | }
|
---|
1408 |
|
---|
1409 | defer func() {
|
---|
1410 | if err2 := s.Close(); err2 != nil && err == nil {
|
---|
1411 | err = err2
|
---|
1412 | }
|
---|
1413 | }()
|
---|
1414 |
|
---|
1415 | return s.(*stmt).exec(ctx, args)
|
---|
1416 | }
|
---|
1417 |
|
---|
1418 | // Prepare returns a prepared statement, bound to this connection.
|
---|
1419 | func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
---|
1420 | return c.prepare(context.Background(), query)
|
---|
1421 | }
|
---|
1422 |
|
---|
1423 | func (c *conn) prepare(ctx context.Context, query string) (s driver.Stmt, err error) {
|
---|
1424 | //TODO use ctx
|
---|
1425 | return newStmt(c, query)
|
---|
1426 | }
|
---|
1427 |
|
---|
1428 | // Queryer is an optional interface that may be implemented by a Conn.
|
---|
1429 | //
|
---|
1430 | // If a Conn does not implement Queryer, the sql package's DB.Query will first
|
---|
1431 | // prepare a query, execute the statement, and then close the statement.
|
---|
1432 | //
|
---|
1433 | // Query may return ErrSkip.
|
---|
1434 | //
|
---|
1435 | // Deprecated: Drivers should implement QueryerContext instead.
|
---|
1436 | func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
---|
1437 | return c.query(context.Background(), query, toNamedValues(args))
|
---|
1438 | }
|
---|
1439 |
|
---|
1440 | func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) {
|
---|
1441 | s, err := c.prepare(ctx, query)
|
---|
1442 | if err != nil {
|
---|
1443 | return nil, err
|
---|
1444 | }
|
---|
1445 |
|
---|
1446 | defer func() {
|
---|
1447 | if err2 := s.Close(); err2 != nil && err == nil {
|
---|
1448 | err = err2
|
---|
1449 | }
|
---|
1450 | }()
|
---|
1451 |
|
---|
1452 | return s.(*stmt).query(ctx, args)
|
---|
1453 | }
|
---|
1454 |
|
---|
// Driver implements database/sql/driver.Driver.
type Driver struct {
	// udfs holds the user defined functions, keyed by SQL function name, that
	// Open adds to every new connection.
	udfs map[string]*userDefinedFunction
}
1460 |
|
---|
// d is the package-wide Driver; the Register*ScalarFunction helpers record
// functions in its udfs map.
var d = &Driver{udfs: make(map[string]*userDefinedFunction)}

// newDriver returns the shared package-level driver d rather than a new
// instance, so every caller sees the same registered functions.
func newDriver() *Driver { return d }
1464 |
|
---|
1465 | // Open returns a new connection to the database. The name is a string in a
|
---|
1466 | // driver-specific format.
|
---|
1467 | //
|
---|
1468 | // Open may return a cached connection (one previously closed), but doing so is
|
---|
1469 | // unnecessary; the sql package maintains a pool of idle connections for
|
---|
1470 | // efficient re-use.
|
---|
1471 | //
|
---|
1472 | // The returned connection is only used by one goroutine at a time.
|
---|
1473 | //
|
---|
1474 | // If name contains a '?', what follows is treated as a query string. This
|
---|
1475 | // driver supports the following query parameters:
|
---|
1476 | //
|
---|
1477 | // _pragma: Each value will be run as a "PRAGMA ..." statement (with the PRAGMA
|
---|
1478 | // keyword added for you). May be specified more than once. Example:
|
---|
1479 | // "_pragma=foreign_keys(1)" will enable foreign key enforcement. More
|
---|
1480 | // information on supported PRAGMAs is available from the SQLite documentation:
|
---|
1481 | // https://www.sqlite.org/pragma.html
|
---|
1482 | //
|
---|
1483 | // _time_format: The name of a format to use when writing time values to the
|
---|
1484 | // database. Currently the only supported value is "sqlite", which corresponds
|
---|
1485 | // to format 7 from https://www.sqlite.org/lang_datefunc.html#time_values,
|
---|
1486 | // including the timezone specifier. If this parameter is not specified, then
|
---|
1487 | // the default String() format will be used.
|
---|
1488 | //
|
---|
1489 | // _txlock: The locking behavior to use when beginning a transaction. May be
|
---|
1490 | // "deferred", "immediate", or "exclusive" (case insensitive). The default is to
|
---|
1491 | // not specify one, which SQLite maps to "deferred". More information is
|
---|
1492 | // available at
|
---|
1493 | // https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
|
---|
1494 | func (d *Driver) Open(name string) (driver.Conn, error) {
|
---|
1495 | c, err := newConn(name)
|
---|
1496 | if err != nil {
|
---|
1497 | return nil, err
|
---|
1498 | }
|
---|
1499 |
|
---|
1500 | for _, udf := range d.udfs {
|
---|
1501 | if err = c.createFunctionInternal(udf); err != nil {
|
---|
1502 | c.Close()
|
---|
1503 | return nil, err
|
---|
1504 | }
|
---|
1505 | }
|
---|
1506 | return c, nil
|
---|
1507 | }
|
---|
1508 |
|
---|
// FunctionContext represents the context user defined functions execute in.
// Fields and/or methods of this type may get added in the future.
type FunctionContext struct{}
1512 |
|
---|
// sqliteValPtrSize is the size of one sqlite3_value pointer. It is the stride
// used to walk the argv array passed to user defined function callbacks.
const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
1514 |
|
---|
1515 | // RegisterScalarFunction registers a scalar function named zFuncName with nArg
|
---|
1516 | // arguments. Passing -1 for nArg indicates the function is variadic.
|
---|
1517 | //
|
---|
1518 | // The new function will be available to all new connections opened after
|
---|
1519 | // executing RegisterScalarFunction.
|
---|
1520 | func RegisterScalarFunction(
|
---|
1521 | zFuncName string,
|
---|
1522 | nArg int32,
|
---|
1523 | xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
|
---|
1524 | ) error {
|
---|
1525 | return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc)
|
---|
1526 | }
|
---|
1527 |
|
---|
1528 | // MustRegisterScalarFunction is like RegisterScalarFunction but panics on
|
---|
1529 | // error.
|
---|
1530 | func MustRegisterScalarFunction(
|
---|
1531 | zFuncName string,
|
---|
1532 | nArg int32,
|
---|
1533 | xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
|
---|
1534 | ) {
|
---|
1535 | if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil {
|
---|
1536 | panic(err)
|
---|
1537 | }
|
---|
1538 | }
|
---|
1539 |
|
---|
1540 | // MustRegisterDeterministicScalarFunction is like
|
---|
1541 | // RegisterDeterministicScalarFunction but panics on error.
|
---|
1542 | func MustRegisterDeterministicScalarFunction(
|
---|
1543 | zFuncName string,
|
---|
1544 | nArg int32,
|
---|
1545 | xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
|
---|
1546 | ) {
|
---|
1547 | if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil {
|
---|
1548 | panic(err)
|
---|
1549 | }
|
---|
1550 | }
|
---|
1551 |
|
---|
1552 | // RegisterDeterministicScalarFunction registers a deterministic scalar
|
---|
1553 | // function named zFuncName with nArg arguments. Passing -1 for nArg indicates
|
---|
1554 | // the function is variadic. A deterministic function means that the function
|
---|
1555 | // always gives the same output when the input parameters are the same.
|
---|
1556 | //
|
---|
1557 | // The new function will be available to all new connections opened after
|
---|
1558 | // executing RegisterDeterministicScalarFunction.
|
---|
1559 | func RegisterDeterministicScalarFunction(
|
---|
1560 | zFuncName string,
|
---|
1561 | nArg int32,
|
---|
1562 | xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
|
---|
1563 | ) error {
|
---|
1564 | return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc)
|
---|
1565 | }
|
---|
1566 |
|
---|
1567 | func registerScalarFunction(
|
---|
1568 | zFuncName string,
|
---|
1569 | nArg int32,
|
---|
1570 | eTextRep int32,
|
---|
1571 | xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
|
---|
1572 | ) error {
|
---|
1573 |
|
---|
1574 | if _, ok := d.udfs[zFuncName]; ok {
|
---|
1575 | return fmt.Errorf("a function named %q is already registered", zFuncName)
|
---|
1576 | }
|
---|
1577 |
|
---|
1578 | // dont free, functions registered on the driver live as long as the program
|
---|
1579 | name, err := libc.CString(zFuncName)
|
---|
1580 | if err != nil {
|
---|
1581 | return err
|
---|
1582 | }
|
---|
1583 |
|
---|
1584 | udf := &userDefinedFunction{
|
---|
1585 | zFuncName: name,
|
---|
1586 | nArg: nArg,
|
---|
1587 | eTextRep: eTextRep,
|
---|
1588 | xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
|
---|
1589 | setErrorResult := func(res error) {
|
---|
1590 | errmsg, cerr := libc.CString(res.Error())
|
---|
1591 | if cerr != nil {
|
---|
1592 | panic(cerr)
|
---|
1593 | }
|
---|
1594 | defer libc.Xfree(tls, errmsg)
|
---|
1595 | sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
|
---|
1596 | sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
|
---|
1597 | }
|
---|
1598 |
|
---|
1599 | args := make([]driver.Value, argc)
|
---|
1600 | for i := int32(0); i < argc; i++ {
|
---|
1601 | valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
|
---|
1602 |
|
---|
1603 | switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType {
|
---|
1604 | case sqlite3.SQLITE_TEXT:
|
---|
1605 | args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr))
|
---|
1606 | case sqlite3.SQLITE_INTEGER:
|
---|
1607 | args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr)
|
---|
1608 | case sqlite3.SQLITE_FLOAT:
|
---|
1609 | args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr)
|
---|
1610 | case sqlite3.SQLITE_NULL:
|
---|
1611 | args[i] = nil
|
---|
1612 | case sqlite3.SQLITE_BLOB:
|
---|
1613 | size := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
|
---|
1614 | blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr)
|
---|
1615 | v := make([]byte, size)
|
---|
1616 | copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
|
---|
1617 | args[i] = v
|
---|
1618 | default:
|
---|
1619 | panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType))
|
---|
1620 | }
|
---|
1621 | }
|
---|
1622 |
|
---|
1623 | res, err := xFunc(&FunctionContext{}, args)
|
---|
1624 | if err != nil {
|
---|
1625 | setErrorResult(err)
|
---|
1626 | return
|
---|
1627 | }
|
---|
1628 |
|
---|
1629 | switch resTyped := res.(type) {
|
---|
1630 | case nil:
|
---|
1631 | sqlite3.Xsqlite3_result_null(tls, ctx)
|
---|
1632 | case int64:
|
---|
1633 | sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped)
|
---|
1634 | case float64:
|
---|
1635 | sqlite3.Xsqlite3_result_double(tls, ctx, resTyped)
|
---|
1636 | case bool:
|
---|
1637 | sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped))
|
---|
1638 | case time.Time:
|
---|
1639 | sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix())
|
---|
1640 | case string:
|
---|
1641 | size := int32(len(resTyped))
|
---|
1642 | cstr, err := libc.CString(resTyped)
|
---|
1643 | if err != nil {
|
---|
1644 | panic(err)
|
---|
1645 | }
|
---|
1646 | defer libc.Xfree(tls, cstr)
|
---|
1647 | sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT)
|
---|
1648 | case []byte:
|
---|
1649 | size := int32(len(resTyped))
|
---|
1650 | if size == 0 {
|
---|
1651 | sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0)
|
---|
1652 | return
|
---|
1653 | }
|
---|
1654 | p := libc.Xmalloc(tls, types.Size_t(size))
|
---|
1655 | if p == 0 {
|
---|
1656 | panic(fmt.Sprintf("unable to allocate space for blob: %d", size))
|
---|
1657 | }
|
---|
1658 | defer libc.Xfree(tls, p)
|
---|
1659 | copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped)
|
---|
1660 |
|
---|
1661 | sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT)
|
---|
1662 | default:
|
---|
1663 | setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped))
|
---|
1664 | return
|
---|
1665 | }
|
---|
1666 | },
|
---|
1667 | }
|
---|
1668 | d.udfs[zFuncName] = udf
|
---|
1669 |
|
---|
1670 | return nil
|
---|
1671 | }
|
---|