aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/query.go')
-rw-r--r--vendor/github.com/jackc/pgx/query.go494
1 files changed, 494 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/query.go b/vendor/github.com/jackc/pgx/query.go
new file mode 100644
index 0000000..19b867e
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/query.go
@@ -0,0 +1,494 @@
+package pgx
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "time"
+)
+
// Row is a convenience wrapper over Rows that is returned by QueryRow.
// It is a distinct named type (not an embedding) so a *Row can be converted
// back to *Rows with a plain pointer conversion in (*Row).Scan.
type Row Rows
+
+// Scan works the same as (*Rows Scan) with the following exceptions. If no
+// rows were found it returns ErrNoRows. If multiple rows are returned it
+// ignores all but the first.
+func (r *Row) Scan(dest ...interface{}) (err error) {
+ rows := (*Rows)(r)
+
+ if rows.Err() != nil {
+ return rows.Err()
+ }
+
+ if !rows.Next() {
+ if rows.Err() == nil {
+ return ErrNoRows
+ }
+ return rows.Err()
+ }
+
+ rows.Scan(dest...)
+ rows.Close()
+ return rows.Err()
+}
+
// Rows is the result set returned from *Conn.Query. Rows must be closed before
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
type Rows struct {
	conn       *Conn
	mr         *msgReader         // reader positioned at the current DataRow message (set by Next)
	fields     []FieldDescription // column descriptions, taken from the prepared statement
	vr         ValueReader        // reader for the column most recently returned by nextColumn
	rowCount   int                // rows read so far; reported in the query log line
	columnIdx  int                // index of the next column to read within the current row
	err        error              // first fatal error; once set, rows is closed and unusable
	startTime  time.Time          // query start time, used to log elapsed duration in close()
	sql        string             // SQL text, used for logging
	args       []interface{}      // query arguments, used for logging
	afterClose func(*Rows)        // composed callback chain invoked by close()
	unlockConn bool               // whether close() must release the connection lock
	closed     bool               // guards against double close
}
+
// FieldDescriptions returns the column descriptions for the result set,
// as captured from the prepared statement backing this query.
func (rows *Rows) FieldDescriptions() []FieldDescription {
	return rows.fields
}
+
+func (rows *Rows) close() {
+ if rows.closed {
+ return
+ }
+
+ if rows.unlockConn {
+ rows.conn.unlock()
+ rows.unlockConn = false
+ }
+
+ rows.closed = true
+
+ if rows.err == nil {
+ if rows.conn.shouldLog(LogLevelInfo) {
+ endTime := time.Now()
+ rows.conn.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
+ }
+ } else if rows.conn.shouldLog(LogLevelError) {
+ rows.conn.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
+ }
+
+ if rows.afterClose != nil {
+ rows.afterClose(rows)
+ }
+}
+
+func (rows *Rows) readUntilReadyForQuery() {
+ for {
+ t, r, err := rows.conn.rxMsg()
+ if err != nil {
+ rows.close()
+ return
+ }
+
+ switch t {
+ case readyForQuery:
+ rows.conn.rxReadyForQuery(r)
+ rows.close()
+ return
+ case rowDescription:
+ case dataRow:
+ case commandComplete:
+ case bindComplete:
+ case errorResponse:
+ err = rows.conn.rxErrorResponse(r)
+ if rows.err == nil {
+ rows.err = err
+ }
+ default:
+ err = rows.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ rows.close()
+ return
+ }
+ }
+ }
+}
+
+// Close closes the rows, making the connection ready for use again. It is safe
+// to call Close after rows is already closed.
+func (rows *Rows) Close() {
+ if rows.closed {
+ return
+ }
+ rows.readUntilReadyForQuery()
+ rows.close()
+}
+
// Err returns the first fatal error encountered while reading rows, or nil.
func (rows *Rows) Err() error {
	return rows.err
}
+
+// abort signals that the query was not successfully sent to the server.
+// This differs from Fatal in that it is not necessary to readUntilReadyForQuery
+func (rows *Rows) abort(err error) {
+ if rows.err != nil {
+ return
+ }
+
+ rows.err = err
+ rows.close()
+}
+
+// Fatal signals an error occurred after the query was sent to the server. It
+// closes the rows automatically.
+func (rows *Rows) Fatal(err error) {
+ if rows.err != nil {
+ return
+ }
+
+ rows.err = err
+ rows.Close()
+}
+
+// Next prepares the next row for reading. It returns true if there is another
+// row and false if no more rows are available. It automatically closes rows
+// when all rows are read.
+func (rows *Rows) Next() bool {
+ if rows.closed {
+ return false
+ }
+
+ rows.rowCount++
+ rows.columnIdx = 0
+ rows.vr = ValueReader{}
+
+ for {
+ t, r, err := rows.conn.rxMsg()
+ if err != nil {
+ rows.Fatal(err)
+ return false
+ }
+
+ switch t {
+ case readyForQuery:
+ rows.conn.rxReadyForQuery(r)
+ rows.close()
+ return false
+ case dataRow:
+ fieldCount := r.readInt16()
+ if int(fieldCount) != len(rows.fields) {
+ rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
+ return false
+ }
+
+ rows.mr = r
+ return true
+ case commandComplete:
+ case bindComplete:
+ default:
+ err = rows.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ rows.Fatal(err)
+ return false
+ }
+ }
+ }
+}
+
// Conn returns the *Conn this *Rows is using. The connection is locked for
// the duration of the query and is unlocked when rows is closed.
func (rows *Rows) Conn() *Conn {
	return rows.conn
}
+
// nextColumn advances to the next column of the current data row and returns
// a ValueReader positioned at its value. It returns false when rows is closed
// or when every column of the row has already been consumed (the latter is a
// protocol error that also closes rows via Fatal).
func (rows *Rows) nextColumn() (*ValueReader, bool) {
	if rows.closed {
		return nil, false
	}
	if len(rows.fields) <= rows.columnIdx {
		rows.Fatal(ProtocolError("No next column available"))
		return nil, false
	}

	// Discard any bytes of the previous column's value the caller left unread,
	// so the message reader lands on the next column's length prefix.
	if rows.vr.Len() > 0 {
		rows.mr.readBytes(rows.vr.Len())
	}

	fd := &rows.fields[rows.columnIdx]
	rows.columnIdx++
	// Each column value is prefixed with its byte length (-1 for SQL NULL).
	size := rows.mr.readInt32()
	rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size}
	return &rows.vr, true
}
+
// scanArgError records which destination argument of Scan failed and why.
type scanArgError struct {
	col int   // zero-based index into Scan's dest arguments
	err error // underlying decode/scan error
}

// Error implements the error interface.
func (e scanArgError) Error() string {
	return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
}
+
// Scan reads the values from the current row into dest values positionally.
// dest can include pointers to core types, values implementing the Scanner
// interface, []byte, and nil. []byte will skip the decoding process and directly
// copy the raw bytes received from PostgreSQL. nil will skip the value entirely.
func (rows *Rows) Scan(dest ...interface{}) (err error) {
	if len(rows.fields) != len(dest) {
		err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
		rows.Fatal(err)
		return err
	}

	for i, d := range dest {
		// A failed nextColumn calls rows.Fatal, which the rows.Err() check at
		// the bottom of this loop turns into an early return.
		vr, _ := rows.nextColumn()

		if d == nil {
			// nil destination skips this column; nextColumn discards its
			// unread bytes on the next iteration.
			continue
		}

		// Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes
		if b, ok := d.(*[]byte); ok {
			// If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format)
			// Otherwise read the bytes directly regardless of what the actual type is.
			if vr.Type().DataType == ByteaOid {
				*b = decodeBytea(vr)
			} else {
				if vr.Len() != -1 {
					*b = vr.ReadBytes(vr.Len())
				} else {
					// Length of -1 signals SQL NULL.
					*b = nil
				}
			}
		} else if s, ok := d.(Scanner); ok {
			// pgx-native Scanner gets direct access to the ValueReader.
			err = s.Scan(vr)
			if err != nil {
				rows.Fatal(scanArgError{col: i, err: err})
			}
		} else if s, ok := d.(PgxScanner); ok {
			err = s.ScanPgx(vr)
			if err != nil {
				rows.Fatal(scanArgError{col: i, err: err})
			}
		} else if s, ok := d.(sql.Scanner); ok {
			// database/sql compatibility: decode into one of the core Go
			// types sql.Scanner implementations accept, falling back to raw
			// bytes for unrecognized OIDs. val stays nil for SQL NULL
			// (vr.Len() < 0).
			var val interface{}
			if 0 <= vr.Len() {
				switch vr.Type().DataType {
				case BoolOid:
					val = decodeBool(vr)
				case Int8Oid:
					val = int64(decodeInt8(vr))
				case Int2Oid:
					val = int64(decodeInt2(vr))
				case Int4Oid:
					val = int64(decodeInt4(vr))
				case TextOid, VarcharOid:
					val = decodeText(vr)
				case OidOid:
					val = int64(decodeOid(vr))
				case Float4Oid:
					val = float64(decodeFloat4(vr))
				case Float8Oid:
					val = decodeFloat8(vr)
				case DateOid:
					val = decodeDate(vr)
				case TimestampOid:
					val = decodeTimestamp(vr)
				case TimestampTzOid:
					val = decodeTimestampTz(vr)
				default:
					val = vr.ReadBytes(vr.Len())
				}
			}
			err = s.Scan(val)
			if err != nil {
				rows.Fatal(scanArgError{col: i, err: err})
			}
		} else if vr.Type().DataType == JsonOid {
			// Because the argument passed to decodeJSON will escape the heap.
			// This allows d to be stack allocated and only copied to the heap when
			// we actually are decoding JSON. This saves one memory allocation per
			// row.
			d2 := d
			decodeJSON(vr, &d2)
		} else if vr.Type().DataType == JsonbOid {
			// Same trick as above for getting stack allocation
			d2 := d
			decodeJSONB(vr, &d2)
		} else {
			// Generic decode path for everything else.
			if err := Decode(vr, d); err != nil {
				rows.Fatal(scanArgError{col: i, err: err})
			}
		}
		// Errors recorded on the ValueReader (e.g. reading past the value)
		// are promoted to fatal scan errors.
		if vr.Err() != nil {
			rows.Fatal(scanArgError{col: i, err: vr.Err()})
		}

		if rows.Err() != nil {
			return rows.Err()
		}
	}

	return nil
}
+
// Values returns an array of the row values, decoded into Go types by OID.
// SQL NULLs become nil. It returns an error if rows is already closed or if
// a value cannot be decoded.
func (rows *Rows) Values() ([]interface{}, error) {
	if rows.closed {
		return nil, errors.New("rows is closed")
	}

	values := make([]interface{}, 0, len(rows.fields))

	for range rows.fields {
		// nextColumn failure is surfaced through rows.Err() at loop bottom.
		vr, _ := rows.nextColumn()

		// Length of -1 signals SQL NULL.
		if vr.Len() == -1 {
			values = append(values, nil)
			continue
		}

		switch vr.Type().FormatCode {
		// All intrinsic types (except string) are encoded with binary
		// encoding so anything else should be treated as a string
		case TextFormatCode:
			values = append(values, vr.ReadString(vr.Len()))
		case BinaryFormatCode:
			// Dispatch on the column's type OID for binary decoding.
			switch vr.Type().DataType {
			case TextOid, VarcharOid:
				values = append(values, decodeText(vr))
			case BoolOid:
				values = append(values, decodeBool(vr))
			case ByteaOid:
				values = append(values, decodeBytea(vr))
			case Int8Oid:
				values = append(values, decodeInt8(vr))
			case Int2Oid:
				values = append(values, decodeInt2(vr))
			case Int4Oid:
				values = append(values, decodeInt4(vr))
			case OidOid:
				values = append(values, decodeOid(vr))
			case Float4Oid:
				values = append(values, decodeFloat4(vr))
			case Float8Oid:
				values = append(values, decodeFloat8(vr))
			case BoolArrayOid:
				values = append(values, decodeBoolArray(vr))
			case Int2ArrayOid:
				values = append(values, decodeInt2Array(vr))
			case Int4ArrayOid:
				values = append(values, decodeInt4Array(vr))
			case Int8ArrayOid:
				values = append(values, decodeInt8Array(vr))
			case Float4ArrayOid:
				values = append(values, decodeFloat4Array(vr))
			case Float8ArrayOid:
				values = append(values, decodeFloat8Array(vr))
			case TextArrayOid, VarcharArrayOid:
				values = append(values, decodeTextArray(vr))
			case TimestampArrayOid, TimestampTzArrayOid:
				values = append(values, decodeTimestampArray(vr))
			case DateOid:
				values = append(values, decodeDate(vr))
			case TimestampTzOid:
				values = append(values, decodeTimestampTz(vr))
			case TimestampOid:
				values = append(values, decodeTimestamp(vr))
			case InetOid, CidrOid:
				values = append(values, decodeInet(vr))
			case JsonOid:
				var d interface{}
				decodeJSON(vr, &d)
				values = append(values, d)
			case JsonbOid:
				var d interface{}
				decodeJSONB(vr, &d)
				values = append(values, d)
			default:
				rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
			}
		default:
			rows.Fatal(errors.New("Unknown format code"))
		}

		// Promote reader-level errors (e.g. truncated value) to fatal.
		if vr.Err() != nil {
			rows.Fatal(vr.Err())
		}

		if rows.Err() != nil {
			return nil, rows.Err()
		}
	}

	return values, rows.Err()
}
+
// AfterClose adds f to a stack of functions that will be called when rows is
// closed. Functions run in last-added-first-called (LIFO) order: f is invoked
// before any previously registered function. (The original comment said
// "LILO", but the composed closure below calls the newest function first.)
func (rows *Rows) AfterClose(f func(*Rows)) {
	if rows.afterClose == nil {
		rows.afterClose = f
	} else {
		// Wrap the existing chain so the new callback runs first.
		prevFn := rows.afterClose
		rows.afterClose = func(rows *Rows) {
			f(rows)
			prevFn(rows)
		}
	}
}
+
// Query executes sql with args. If there is an error the returned *Rows will
// be returned in an error state. So it is allowed to ignore the error returned
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
	c.lastActivityTime = time.Now()

	// rows comes from a preallocated pool; its startTime is taken from the
	// lastActivityTime set just above.
	rows := c.getRows(sql, args)

	if err := c.lock(); err != nil {
		rows.abort(err)
		return rows, err
	}
	// From here on, rows.close() is responsible for releasing the lock.
	rows.unlockConn = true

	// Reuse an existing prepared statement for this exact SQL if one exists;
	// otherwise prepare an unnamed (anonymous) statement.
	ps, ok := c.preparedStatements[sql]
	if !ok {
		var err error
		ps, err = c.Prepare("", sql)
		if err != nil {
			rows.abort(err)
			return rows, rows.err
		}
	}
	rows.sql = ps.SQL
	rows.fields = ps.FieldDescriptions
	err := c.sendPreparedQuery(ps, args...)
	if err != nil {
		rows.abort(err)
	}
	// rows.err is nil on success; on failure abort has recorded it.
	return rows, rows.err
}
+
+func (c *Conn) getRows(sql string, args []interface{}) *Rows {
+ if len(c.preallocatedRows) == 0 {
+ c.preallocatedRows = make([]Rows, 64)
+ }
+
+ r := &c.preallocatedRows[len(c.preallocatedRows)-1]
+ c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
+
+ r.conn = c
+ r.startTime = c.lastActivityTime
+ r.sql = sql
+ r.args = args
+
+ return r
+}
+
+// QueryRow is a convenience wrapper over Query. Any error that occurs while
+// querying is deferred until calling Scan on the returned *Row. That *Row will
+// error with ErrNoRows if no rows are returned.
+func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
+ rows, _ := c.Query(sql, args...)
+ return (*Row)(rows)
+}