aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/values.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/values.go')
-rw-r--r--vendor/github.com/jackc/pgx/values.go3439
1 files changed, 3439 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/values.go b/vendor/github.com/jackc/pgx/values.go
new file mode 100644
index 0000000..a189e18
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/values.go
@@ -0,0 +1,3439 @@
+package pgx
+
+import (
+ "bytes"
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "reflect"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// PostgreSQL oids for common types
+const (
+ BoolOid = 16
+ ByteaOid = 17
+ CharOid = 18
+ NameOid = 19
+ Int8Oid = 20
+ Int2Oid = 21
+ Int4Oid = 23
+ TextOid = 25
+ OidOid = 26
+ TidOid = 27
+ XidOid = 28
+ CidOid = 29
+ JsonOid = 114
+ CidrOid = 650
+ CidrArrayOid = 651
+ Float4Oid = 700
+ Float8Oid = 701
+ UnknownOid = 705
+ InetOid = 869
+ BoolArrayOid = 1000
+ Int2ArrayOid = 1005
+ Int4ArrayOid = 1007
+ TextArrayOid = 1009
+ ByteaArrayOid = 1001
+ VarcharArrayOid = 1015
+ Int8ArrayOid = 1016
+ Float4ArrayOid = 1021
+ Float8ArrayOid = 1022
+ AclItemOid = 1033
+ AclItemArrayOid = 1034
+ InetArrayOid = 1041
+ VarcharOid = 1043
+ DateOid = 1082
+ TimestampOid = 1114
+ TimestampArrayOid = 1115
+ TimestampTzOid = 1184
+ TimestampTzArrayOid = 1185
+ RecordOid = 2249
+ UuidOid = 2950
+ JsonbOid = 3802
+)
+
+// PostgreSQL format codes
+const (
+ TextFormatCode = 0
+ BinaryFormatCode = 1
+)
+
+const maxUint = ^uint(0)
+const maxInt = int(maxUint >> 1)
+const minInt = -maxInt - 1
+
+// DefaultTypeFormats maps type names to their default requested format (text
+// or binary). In theory the Scanner interface should be the one to determine
+// the format of the returned values. However, the query has already been
+// executed by the time Scan is called so it has no chance to set the format.
+// So for types that should always be returned in binary the format should be
+// set here.
+var DefaultTypeFormats map[string]int16
+
+func init() {
+ DefaultTypeFormats = map[string]int16{
+ "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
+ "_bool": BinaryFormatCode,
+ "_bytea": BinaryFormatCode,
+ "_cidr": BinaryFormatCode,
+ "_float4": BinaryFormatCode,
+ "_float8": BinaryFormatCode,
+ "_inet": BinaryFormatCode,
+ "_int2": BinaryFormatCode,
+ "_int4": BinaryFormatCode,
+ "_int8": BinaryFormatCode,
+ "_text": BinaryFormatCode,
+ "_timestamp": BinaryFormatCode,
+ "_timestamptz": BinaryFormatCode,
+ "_varchar": BinaryFormatCode,
+ "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
+ "bool": BinaryFormatCode,
+ "bytea": BinaryFormatCode,
+ "char": BinaryFormatCode,
+ "cid": BinaryFormatCode,
+ "cidr": BinaryFormatCode,
+ "date": BinaryFormatCode,
+ "float4": BinaryFormatCode,
+ "float8": BinaryFormatCode,
+ "json": BinaryFormatCode,
+ "jsonb": BinaryFormatCode,
+ "inet": BinaryFormatCode,
+ "int2": BinaryFormatCode,
+ "int4": BinaryFormatCode,
+ "int8": BinaryFormatCode,
+ "name": BinaryFormatCode,
+ "oid": BinaryFormatCode,
+ "record": BinaryFormatCode,
+ "text": BinaryFormatCode,
+ "tid": BinaryFormatCode,
+ "timestamp": BinaryFormatCode,
+ "timestamptz": BinaryFormatCode,
+ "varchar": BinaryFormatCode,
+ "xid": BinaryFormatCode,
+ }
+}
+
+// SerializationError occurs on failure to encode or decode a value
+type SerializationError string
+
+func (e SerializationError) Error() string {
+ return string(e)
+}
+
+// Deprecated: Scanner is an interface used to decode values from the PostgreSQL
+// server. To allow types to support pgx and database/sql.Scan this interface
+// has been deprecated in favor of PgxScanner.
+type Scanner interface {
+ // Scan MUST check r.Type().DataType (to check by OID) or
+ // r.Type().DataTypeName (to check by name) to ensure that it is scanning an
+ // expected column type. It also MUST check r.Type().FormatCode before
+ // decoding. It should not assume that it was called on a data type or format
+ // that it understands.
+ Scan(r *ValueReader) error
+}
+
+// PgxScanner is an interface used to decode values from the PostgreSQL server.
+// It is used exactly the same as the Scanner interface. It simply has renamed
+// the method.
+type PgxScanner interface {
+ // ScanPgx MUST check r.Type().DataType (to check by OID) or
+ // r.Type().DataTypeName (to check by name) to ensure that it is scanning an
+ // expected column type. It also MUST check r.Type().FormatCode before
+ // decoding. It should not assume that it was called on a data type or format
+ // that it understands.
+ ScanPgx(r *ValueReader) error
+}
+
+// Encoder is an interface used to encode values for transmission to the
+// PostgreSQL server.
+type Encoder interface {
+ // Encode writes the value to w.
+ //
+ // If the value is NULL an int32(-1) should be written.
+ //
+ // Encode MUST check oid to see if the parameter data type is compatible. If
+ // this is not done, the PostgreSQL server may detect the error if the
+ // expected data size or format of the encoded data does not match. But if
+ // the encoded data is a valid representation of the data type PostgreSQL
+ // expects such as date and int4, incorrect data may be stored.
+ Encode(w *WriteBuf, oid Oid) error
+
+ // FormatCode returns the format that the encoder writes the value. It must be
+ // either pgx.TextFormatCode or pgx.BinaryFormatCode.
+ FormatCode() int16
+}
+
+// NullFloat32 represents an float4 that may be null. NullFloat32 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullFloat32 struct {
+ Float32 float32
+ Valid bool // Valid is true if Float32 is not NULL
+}
+
+func (n *NullFloat32) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Float4Oid {
+ return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Float32, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Float32 = decodeFloat4(vr)
+ return vr.Err()
+}
+
+func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Float4Oid {
+ return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeFloat32(w, oid, n.Float32)
+}
+
+// NullFloat64 represents an float8 that may be null. NullFloat64 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullFloat64 struct {
+ Float64 float64
+ Valid bool // Valid is true if Float64 is not NULL
+}
+
+func (n *NullFloat64) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Float8Oid {
+ return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Float64, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Float64 = decodeFloat8(vr)
+ return vr.Err()
+}
+
+func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Float8Oid {
+ return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeFloat64(w, oid, n.Float64)
+}
+
+// NullString represents an string that may be null. NullString implements the
+// Scanner Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullString struct {
+ String string
+ Valid bool // Valid is true if String is not NULL
+}
+
+func (n *NullString) Scan(vr *ValueReader) error {
+ // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later
+
+ if vr.Len() == -1 {
+ n.String, n.Valid = "", false
+ return nil
+ }
+
+ n.Valid = true
+ n.String = decodeText(vr)
+ return vr.Err()
+}
+
+func (n NullString) FormatCode() int16 { return TextFormatCode }
+
+func (n NullString) Encode(w *WriteBuf, oid Oid) error {
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeString(w, oid, n.String)
+}
+
+// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem
+// might look like this:
+//
+// postgres=arwdDxt/postgres
+//
+// Note, however, that because the user/role name part of an aclitem is
+// an identifier, it follows all the usual formatting rules for SQL
+// identifiers: if it contains spaces and other special characters,
+// it should appear in double-quotes:
+//
+// postgres=arwdDxt/"role with spaces"
+//
+type AclItem string
+
+// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullAclItem struct {
+ AclItem AclItem
+ Valid bool // Valid is true if AclItem is not NULL
+}
+
+func (n *NullAclItem) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != AclItemOid {
+ return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.AclItem, n.Valid = "", false
+ return nil
+ }
+
+ n.Valid = true
+ n.AclItem = AclItem(decodeText(vr))
+ return vr.Err()
+}
+
+// Particularly important to return TextFormatCode, seeing as Postgres
+// only ever sends aclitem as text, not binary.
+func (n NullAclItem) FormatCode() int16 { return TextFormatCode }
+
+func (n NullAclItem) Encode(w *WriteBuf, oid Oid) error {
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeString(w, oid, string(n.AclItem))
+}
+
+// Name is a type used for PostgreSQL's special 63-byte
+// name data type, used for identifiers like table names.
+// The pg_class.relname column is a good example of where the
+// name data type is used.
+//
+// Note that the underlying Go data type of pgx.Name is string,
+// so there is no way to enforce the 63-byte length. Inputting
+// a longer name into PostgreSQL will result in silent truncation
+// to 63 bytes.
+//
+// Also, if you have custom-compiled PostgreSQL and set
+// NAMEDATALEN to a different value, obviously that number of
+// bytes applies, rather than the default 63.
+type Name string
+
+// NullName represents a pgx.Name that may be null. NullName implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullName struct {
+ Name Name
+ Valid bool // Valid is true if Name is not NULL
+}
+
+func (n *NullName) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != NameOid {
+ return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Name, n.Valid = "", false
+ return nil
+ }
+
+ n.Valid = true
+ n.Name = Name(decodeText(vr))
+ return vr.Err()
+}
+
+func (n NullName) FormatCode() int16 { return TextFormatCode }
+
+func (n NullName) Encode(w *WriteBuf, oid Oid) error {
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeString(w, oid, string(n.Name))
+}
+
+// The pgx.Char type is for PostgreSQL's special 8-bit-only
+// "char" type more akin to the C language's char type, or Go's byte type.
+// (Note that the name in PostgreSQL itself is "char", in double-quotes,
+// and not char.) It gets used a lot in PostgreSQL's system tables to hold
+// a single ASCII character value (eg pg_class.relkind).
+type Char byte
+
+// NullChar represents a pgx.Char that may be null. NullChar implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullChar struct {
+ Char Char
+ Valid bool // Valid is true if Char is not NULL
+}
+
+func (n *NullChar) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != CharOid {
+ return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Char, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Char = decodeChar(vr)
+ return vr.Err()
+}
+
+func (n NullChar) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullChar) Encode(w *WriteBuf, oid Oid) error {
+ if oid != CharOid {
+ return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeChar(w, oid, n.Char)
+}
+
+// NullInt16 represents a smallint that may be null. NullInt16 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullInt16 struct {
+ Int16 int16
+ Valid bool // Valid is true if Int16 is not NULL
+}
+
+func (n *NullInt16) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Int2Oid {
+ return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Int16, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Int16 = decodeInt2(vr)
+ return vr.Err()
+}
+
+func (n NullInt16) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullInt16) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Int2Oid {
+ return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeInt16(w, oid, n.Int16)
+}
+
+// NullInt32 represents an integer that may be null. NullInt32 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullInt32 struct {
+ Int32 int32
+ Valid bool // Valid is true if Int32 is not NULL
+}
+
+func (n *NullInt32) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Int4Oid {
+ return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Int32, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Int32 = decodeInt4(vr)
+ return vr.Err()
+}
+
+func (n NullInt32) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullInt32) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Int4Oid {
+ return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeInt32(w, oid, n.Int32)
+}
+
+// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html,
+// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented
+// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h
+// in the PostgreSQL sources.
+type Oid uint32
+
+// NullOid represents a Command Identifier (Oid) that may be null. NullOid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullOid struct {
+ Oid Oid
+ Valid bool // Valid is true if Oid is not NULL
+}
+
+func (n *NullOid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != OidOid {
+ return SerializationError(fmt.Sprintf("NullOid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Oid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Oid = decodeOid(vr)
+ return vr.Err()
+}
+
+func (n NullOid) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullOid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != OidOid {
+ return SerializationError(fmt.Sprintf("NullOid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeOid(w, oid, n.Oid)
+}
+
+// Xid is PostgreSQL's Transaction ID type.
+//
+// In later versions of PostgreSQL, it is the type used for the backend_xid
+// and backend_xmin columns of the pg_stat_activity system view.
+//
+// Also, when one does
+//
+// select xmin, xmax, * from some_table;
+//
+// it is the data type of the xmin and xmax hidden system columns.
+//
+// It is currently implemented as an unsigned four byte integer.
+// Its definition can be found in src/include/postgres_ext.h as TransactionId
+// in the PostgreSQL sources.
+type Xid uint32
+
+// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullXid struct {
+ Xid Xid
+ Valid bool // Valid is true if Xid is not NULL
+}
+
+func (n *NullXid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != XidOid {
+ return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Xid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Xid = decodeXid(vr)
+ return vr.Err()
+}
+
+func (n NullXid) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullXid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != XidOid {
+ return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeXid(w, oid, n.Xid)
+}
+
+// Cid is PostgreSQL's Command Identifier type.
+//
+// When one does
+//
+// select cmin, cmax, * from some_table;
+//
+// it is the data type of the cmin and cmax hidden system columns.
+//
+// It is currently implemented as an unsigned four byte integer.
+// Its definition can be found in src/include/c.h as CommandId
+// in the PostgreSQL sources.
+type Cid uint32
+
+// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullCid struct {
+ Cid Cid
+ Valid bool // Valid is true if Cid is not NULL
+}
+
+func (n *NullCid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != CidOid {
+ return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Cid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Cid = decodeCid(vr)
+ return vr.Err()
+}
+
+func (n NullCid) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullCid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != CidOid {
+ return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeCid(w, oid, n.Cid)
+}
+
+// Tid is PostgreSQL's Tuple Identifier type.
+//
+// When one does
+//
+// select ctid, * from some_table;
+//
+// it is the data type of the ctid hidden system column.
+//
+// It is currently implemented as a pair unsigned two byte integers.
+// Its conversion functions can be found in src/backend/utils/adt/tid.c
+// in the PostgreSQL sources.
+type Tid struct {
+ BlockNumber uint32
+ OffsetNumber uint16
+}
+
+// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullTid struct {
+ Tid Tid
+ Valid bool // Valid is true if Tid is not NULL
+}
+
+func (n *NullTid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != TidOid {
+ return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false
+ return nil
+ }
+ n.Valid = true
+ n.Tid = decodeTid(vr)
+ return vr.Err()
+}
+
+func (n NullTid) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullTid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != TidOid {
+ return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeTid(w, oid, n.Tid)
+}
+
+// NullInt64 represents an bigint that may be null. NullInt64 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullInt64 struct {
+ Int64 int64
+ Valid bool // Valid is true if Int64 is not NULL
+}
+
+func (n *NullInt64) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Int8Oid {
+ return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Int64, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Int64 = decodeInt8(vr)
+ return vr.Err()
+}
+
+func (n NullInt64) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullInt64) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Int8Oid {
+ return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeInt64(w, oid, n.Int64)
+}
+
+// NullBool represents an bool that may be null. NullBool implements the Scanner
+// and Encoder interfaces so it may be used both as an argument to Query[Row]
+// and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullBool struct {
+ Bool bool
+ Valid bool // Valid is true if Bool is not NULL
+}
+
+func (n *NullBool) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != BoolOid {
+ return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Bool, n.Valid = false, false
+ return nil
+ }
+ n.Valid = true
+ n.Bool = decodeBool(vr)
+ return vr.Err()
+}
+
+func (n NullBool) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullBool) Encode(w *WriteBuf, oid Oid) error {
+ if oid != BoolOid {
+ return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeBool(w, oid, n.Bool)
+}
+
+// NullTime represents an time.Time that may be null. NullTime implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL
+// types timestamptz, timestamp, and date.
+//
+// If Valid is false then the value is NULL.
+type NullTime struct {
+ Time time.Time
+ Valid bool // Valid is true if Time is not NULL
+}
+
+func (n *NullTime) Scan(vr *ValueReader) error {
+ oid := vr.Type().DataType
+ if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
+ return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Time, n.Valid = time.Time{}, false
+ return nil
+ }
+
+ n.Valid = true
+ switch oid {
+ case TimestampTzOid:
+ n.Time = decodeTimestampTz(vr)
+ case TimestampOid:
+ n.Time = decodeTimestamp(vr)
+ case DateOid:
+ n.Time = decodeDate(vr)
+ }
+
+ return vr.Err()
+}
+
+func (n NullTime) FormatCode() int16 { return BinaryFormatCode }
+
+func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
+ if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
+ return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeTime(w, oid, n.Time)
+}
+
+// Hstore represents an hstore column. It does not support a null column or null
+// key values (use NullHstore for this). Hstore implements the Scanner and
+// Encoder interfaces so it may be used both as an argument to Query[Row] and a
+// destination for Scan.
+type Hstore map[string]string
+
+func (h *Hstore) Scan(vr *ValueReader) error {
+ //oid for hstore not standardized, so we check its type name
+ if vr.Type().DataTypeName != "hstore" {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName)))
+ return nil
+ }
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null column into Hstore"))
+ return nil
+ }
+
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ m, err := parseHstoreToMap(vr.ReadString(vr.Len()))
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
+ return nil
+ }
+ hm := Hstore(m)
+ *h = hm
+ return nil
+ case BinaryFormatCode:
+ vr.Fatal(ProtocolError("Can't decode binary hstore"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+}
+
+func (h Hstore) FormatCode() int16 { return TextFormatCode }
+
+func (h Hstore) Encode(w *WriteBuf, oid Oid) error {
+ var buf bytes.Buffer
+
+ i := 0
+ for k, v := range h {
+ i++
+ ks := strings.Replace(k, `\`, `\\`, -1)
+ ks = strings.Replace(ks, `"`, `\"`, -1)
+ vs := strings.Replace(v, `\`, `\\`, -1)
+ vs = strings.Replace(vs, `"`, `\"`, -1)
+ buf.WriteString(`"`)
+ buf.WriteString(ks)
+ buf.WriteString(`"=>"`)
+ buf.WriteString(vs)
+ buf.WriteString(`"`)
+ if i < len(h) {
+ buf.WriteString(", ")
+ }
+ }
+ w.WriteInt32(int32(buf.Len()))
+ w.WriteBytes(buf.Bytes())
+ return nil
+}
+
+// NullHstore represents an hstore column that can be null or have null values
+// associated with its keys. NullHstore implements the Scanner and Encoder
+// interfaces so it may be used both as an argument to Query[Row] and a
+// destination for Scan.
+//
+// If Valid is false, then the value of the entire hstore column is NULL
+// If any of the NullString values in Store has Valid set to false, the key
+// appears in the hstore column, but its value is explicitly set to NULL.
+type NullHstore struct {
+ Hstore map[string]NullString
+ Valid bool
+}
+
+func (h *NullHstore) Scan(vr *ValueReader) error {
+ //oid for hstore not standardized, so we check its type name
+ if vr.Type().DataTypeName != "hstore" {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName)))
+ return nil
+ }
+
+ if vr.Len() == -1 {
+ h.Valid = false
+ return nil
+ }
+
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len()))
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
+ return nil
+ }
+ h.Valid = true
+ h.Hstore = store
+ return nil
+ case BinaryFormatCode:
+ vr.Fatal(ProtocolError("Can't decode binary hstore"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+}
+
+func (h NullHstore) FormatCode() int16 { return TextFormatCode }
+
+func (h NullHstore) Encode(w *WriteBuf, oid Oid) error {
+ var buf bytes.Buffer
+
+ if !h.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ i := 0
+ for k, v := range h.Hstore {
+ i++
+ ks := strings.Replace(k, `\`, `\\`, -1)
+ ks = strings.Replace(ks, `"`, `\"`, -1)
+ if v.Valid {
+ vs := strings.Replace(v.String, `\`, `\\`, -1)
+ vs = strings.Replace(vs, `"`, `\"`, -1)
+ buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
+ } else {
+ buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks))
+ }
+ if i < len(h.Hstore) {
+ buf.WriteString(", ")
+ }
+ }
+ w.WriteInt32(int32(buf.Len()))
+ w.WriteBytes(buf.Bytes())
+ return nil
+}
+
+// Encode encodes arg into wbuf as the type oid. This allows implementations
+// of the Encoder interface to delegate the actual work of encoding to the
+// built-in functionality.
+func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
+ if arg == nil {
+ wbuf.WriteInt32(-1)
+ return nil
+ }
+
+ switch arg := arg.(type) {
+ case Encoder:
+ return arg.Encode(wbuf, oid)
+ case driver.Valuer:
+ v, err := arg.Value()
+ if err != nil {
+ return err
+ }
+ return Encode(wbuf, oid, v)
+ case string:
+ return encodeString(wbuf, oid, arg)
+ case []AclItem:
+ return encodeAclItemSlice(wbuf, oid, arg)
+ case []byte:
+ return encodeByteSlice(wbuf, oid, arg)
+ case [][]byte:
+ return encodeByteSliceSlice(wbuf, oid, arg)
+ }
+
+ refVal := reflect.ValueOf(arg)
+
+ if refVal.Kind() == reflect.Ptr {
+ if refVal.IsNil() {
+ wbuf.WriteInt32(-1)
+ return nil
+ }
+ arg = refVal.Elem().Interface()
+ return Encode(wbuf, oid, arg)
+ }
+
+ if oid == JsonOid {
+ return encodeJSON(wbuf, oid, arg)
+ }
+ if oid == JsonbOid {
+ return encodeJSONB(wbuf, oid, arg)
+ }
+
+ switch arg := arg.(type) {
+ case []string:
+ return encodeStringSlice(wbuf, oid, arg)
+ case bool:
+ return encodeBool(wbuf, oid, arg)
+ case []bool:
+ return encodeBoolSlice(wbuf, oid, arg)
+ case int:
+ return encodeInt(wbuf, oid, arg)
+ case uint:
+ return encodeUInt(wbuf, oid, arg)
+ case Char:
+ return encodeChar(wbuf, oid, arg)
+ case AclItem:
+ // The aclitem data type goes over the wire using the same format as string,
+ // so just cast to string and use encodeString
+ return encodeString(wbuf, oid, string(arg))
+ case Name:
+ // The name data type goes over the wire using the same format as string,
+ // so just cast to string and use encodeString
+ return encodeString(wbuf, oid, string(arg))
+ case int8:
+ return encodeInt8(wbuf, oid, arg)
+ case uint8:
+ return encodeUInt8(wbuf, oid, arg)
+ case int16:
+ return encodeInt16(wbuf, oid, arg)
+ case []int16:
+ return encodeInt16Slice(wbuf, oid, arg)
+ case uint16:
+ return encodeUInt16(wbuf, oid, arg)
+ case []uint16:
+ return encodeUInt16Slice(wbuf, oid, arg)
+ case int32:
+ return encodeInt32(wbuf, oid, arg)
+ case []int32:
+ return encodeInt32Slice(wbuf, oid, arg)
+ case uint32:
+ return encodeUInt32(wbuf, oid, arg)
+ case []uint32:
+ return encodeUInt32Slice(wbuf, oid, arg)
+ case int64:
+ return encodeInt64(wbuf, oid, arg)
+ case []int64:
+ return encodeInt64Slice(wbuf, oid, arg)
+ case uint64:
+ return encodeUInt64(wbuf, oid, arg)
+ case []uint64:
+ return encodeUInt64Slice(wbuf, oid, arg)
+ case float32:
+ return encodeFloat32(wbuf, oid, arg)
+ case []float32:
+ return encodeFloat32Slice(wbuf, oid, arg)
+ case float64:
+ return encodeFloat64(wbuf, oid, arg)
+ case []float64:
+ return encodeFloat64Slice(wbuf, oid, arg)
+ case time.Time:
+ return encodeTime(wbuf, oid, arg)
+ case []time.Time:
+ return encodeTimeSlice(wbuf, oid, arg)
+ case net.IP:
+ return encodeIP(wbuf, oid, arg)
+ case []net.IP:
+ return encodeIPSlice(wbuf, oid, arg)
+ case net.IPNet:
+ return encodeIPNet(wbuf, oid, arg)
+ case []net.IPNet:
+ return encodeIPNetSlice(wbuf, oid, arg)
+ case Oid:
+ return encodeOid(wbuf, oid, arg)
+ case Xid:
+ return encodeXid(wbuf, oid, arg)
+ case Cid:
+ return encodeCid(wbuf, oid, arg)
+ default:
+ if strippedArg, ok := stripNamedType(&refVal); ok {
+ return Encode(wbuf, oid, strippedArg)
+ }
+ return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
+ }
+}
+
+func stripNamedType(val *reflect.Value) (interface{}, bool) {
+ switch val.Kind() {
+ case reflect.Int:
+ return int(val.Int()), true
+ case reflect.Int8:
+ return int8(val.Int()), true
+ case reflect.Int16:
+ return int16(val.Int()), true
+ case reflect.Int32:
+ return int32(val.Int()), true
+ case reflect.Int64:
+ return int64(val.Int()), true
+ case reflect.Uint:
+ return uint(val.Uint()), true
+ case reflect.Uint8:
+ return uint8(val.Uint()), true
+ case reflect.Uint16:
+ return uint16(val.Uint()), true
+ case reflect.Uint32:
+ return uint32(val.Uint()), true
+ case reflect.Uint64:
+ return uint64(val.Uint()), true
+ case reflect.String:
+ return val.String(), true
+ }
+
+ return nil, false
+}
+
+// Decode decodes from vr into d. d must be a pointer. This allows
+// implementations of the Decoder interface to delegate the actual work of
+// decoding to the built-in functionality.
+func Decode(vr *ValueReader, d interface{}) error {
+ switch v := d.(type) {
+ case *bool:
+ *v = decodeBool(vr)
+ case *int:
+ n := decodeInt(vr)
+ if n < int64(minInt) {
+ return fmt.Errorf("%d is less than minimum value for int", n)
+ } else if n > int64(maxInt) {
+ return fmt.Errorf("%d is greater than maximum value for int", n)
+ }
+ *v = int(n)
+ case *int8:
+ n := decodeInt(vr)
+ if n < math.MinInt8 {
+ return fmt.Errorf("%d is less than minimum value for int8", n)
+ } else if n > math.MaxInt8 {
+ return fmt.Errorf("%d is greater than maximum value for int8", n)
+ }
+ *v = int8(n)
+ case *int16:
+ n := decodeInt(vr)
+ if n < math.MinInt16 {
+ return fmt.Errorf("%d is less than minimum value for int16", n)
+ } else if n > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than maximum value for int16", n)
+ }
+ *v = int16(n)
+ case *int32:
+ n := decodeInt(vr)
+ if n < math.MinInt32 {
+ return fmt.Errorf("%d is less than minimum value for int32", n)
+ } else if n > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than maximum value for int32", n)
+ }
+ *v = int32(n)
+ case *int64:
+ n := decodeInt(vr)
+ if n < math.MinInt64 {
+ return fmt.Errorf("%d is less than minimum value for int64", n)
+ } else if n > math.MaxInt64 {
+ return fmt.Errorf("%d is greater than maximum value for int64", n)
+ }
+ *v = int64(n)
+ case *uint:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint8", n)
+ } else if maxInt == math.MaxInt32 && n > math.MaxUint32 {
+ return fmt.Errorf("%d is greater than maximum value for uint", n)
+ }
+ *v = uint(n)
+ case *uint8:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint8", n)
+ } else if n > math.MaxUint8 {
+ return fmt.Errorf("%d is greater than maximum value for uint8", n)
+ }
+ *v = uint8(n)
+ case *uint16:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint16", n)
+ } else if n > math.MaxUint16 {
+ return fmt.Errorf("%d is greater than maximum value for uint16", n)
+ }
+ *v = uint16(n)
+ case *uint32:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint32", n)
+ } else if n > math.MaxUint32 {
+ return fmt.Errorf("%d is greater than maximum value for uint32", n)
+ }
+ *v = uint32(n)
+ case *uint64:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint64", n)
+ }
+ *v = uint64(n)
+ case *Char:
+ *v = decodeChar(vr)
+ case *AclItem:
+ // aclitem goes over the wire just like text
+ *v = AclItem(decodeText(vr))
+ case *Name:
+ // name goes over the wire just like text
+ *v = Name(decodeText(vr))
+ case *Oid:
+ *v = decodeOid(vr)
+ case *Xid:
+ *v = decodeXid(vr)
+ case *Tid:
+ *v = decodeTid(vr)
+ case *Cid:
+ *v = decodeCid(vr)
+ case *string:
+ *v = decodeText(vr)
+ case *float32:
+ *v = decodeFloat4(vr)
+ case *float64:
+ *v = decodeFloat8(vr)
+ case *[]AclItem:
+ *v = decodeAclItemArray(vr)
+ case *[]bool:
+ *v = decodeBoolArray(vr)
+ case *[]int16:
+ *v = decodeInt2Array(vr)
+ case *[]uint16:
+ *v = decodeInt2ArrayToUInt(vr)
+ case *[]int32:
+ *v = decodeInt4Array(vr)
+ case *[]uint32:
+ *v = decodeInt4ArrayToUInt(vr)
+ case *[]int64:
+ *v = decodeInt8Array(vr)
+ case *[]uint64:
+ *v = decodeInt8ArrayToUInt(vr)
+ case *[]float32:
+ *v = decodeFloat4Array(vr)
+ case *[]float64:
+ *v = decodeFloat8Array(vr)
+ case *[]string:
+ *v = decodeTextArray(vr)
+ case *[]time.Time:
+ *v = decodeTimestampArray(vr)
+ case *[][]byte:
+ *v = decodeByteaArray(vr)
+ case *[]interface{}:
+ *v = decodeRecord(vr)
+ case *time.Time:
+ switch vr.Type().DataType {
+ case DateOid:
+ *v = decodeDate(vr)
+ case TimestampTzOid:
+ *v = decodeTimestampTz(vr)
+ case TimestampOid:
+ *v = decodeTimestamp(vr)
+ default:
+ return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)
+ }
+ case *net.IP:
+ ipnet := decodeInet(vr)
+ if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
+ return fmt.Errorf("Cannot decode netmask into *net.IP")
+ }
+ *v = ipnet.IP
+ case *[]net.IP:
+ ipnets := decodeInetArray(vr)
+ ips := make([]net.IP, len(ipnets))
+ for i, ipnet := range ipnets {
+ if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
+ return fmt.Errorf("Cannot decode netmask into *net.IP")
+ }
+ ips[i] = ipnet.IP
+ }
+ *v = ips
+ case *net.IPNet:
+ *v = decodeInet(vr)
+ case *[]net.IPNet:
+ *v = decodeInetArray(vr)
+ default:
+ if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
+ el := v.Elem()
+ switch el.Kind() {
+ // if d is a pointer to pointer, strip the pointer and try again
+ case reflect.Ptr:
+ // -1 is a null value
+ if vr.Len() == -1 {
+ if !el.IsNil() {
+ // if the destination pointer is not nil, nil it out
+ el.Set(reflect.Zero(el.Type()))
+ }
+ return nil
+ }
+ if el.IsNil() {
+ // allocate destination
+ el.Set(reflect.New(el.Type().Elem()))
+ }
+ d = el.Interface()
+ return Decode(vr, d)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n := decodeInt(vr)
+ if el.OverflowInt(n) {
+ return fmt.Errorf("Scan cannot decode %d into %T", n, d)
+ }
+ el.SetInt(n)
+ return nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for %T", n, d)
+ }
+ if el.OverflowUint(uint64(n)) {
+ return fmt.Errorf("Scan cannot decode %d into %T", n, d)
+ }
+ el.SetUint(uint64(n))
+ return nil
+ case reflect.String:
+ el.SetString(decodeText(vr))
+ return nil
+ }
+ }
+ return fmt.Errorf("Scan cannot decode into %T", d)
+ }
+
+ return nil
+}
+
+func decodeBool(vr *ValueReader) bool {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into bool"))
+ return false
+ }
+
+ if vr.Type().DataType != BoolOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
+ return false
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return false
+ }
+
+ if vr.Len() != 1 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len())))
+ return false
+ }
+
+ b := vr.ReadByte()
+ return b != 0
+}
+
+func encodeBool(w *WriteBuf, oid Oid, value bool) error {
+ if oid != BoolOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
+ }
+
+ w.WriteInt32(1)
+
+ var n byte
+ if value {
+ n = 1
+ }
+
+ w.WriteByte(n)
+
+ return nil
+}
+
+func decodeInt(vr *ValueReader) int64 {
+ switch vr.Type().DataType {
+ case Int2Oid:
+ return int64(decodeInt2(vr))
+ case Int4Oid:
+ return int64(decodeInt4(vr))
+ case Int8Oid:
+ return int64(decodeInt8(vr))
+ }
+
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType)))
+ return 0
+}
+
+func decodeInt8(vr *ValueReader) int64 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int64"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int8Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt64()
+}
+
+func decodeChar(vr *ValueReader) Char {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into char"))
+ return Char(0)
+ }
+
+ if vr.Type().DataType != CharOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType)))
+ return Char(0)
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Char(0)
+ }
+
+ if vr.Len() != 1 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len())))
+ return Char(0)
+ }
+
+ return Char(vr.ReadByte())
+}
+
+func decodeInt2(vr *ValueReader) int16 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int16"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int2Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 2 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt16()
+}
+
+func encodeInt(w *WriteBuf, oid Oid, value int) error {
+ switch oid {
+ case Int2Oid:
+ if value < math.MinInt16 {
+ return fmt.Errorf("%d is less than min pg:int2", value)
+ } else if value > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than max pg:int2", value)
+ }
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ if value < math.MinInt32 {
+ return fmt.Errorf("%d is less than min pg:int4", value)
+ } else if value > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than max pg:int4", value)
+ }
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ if int64(value) <= int64(math.MaxInt64) {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ } else {
+ return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64))
+ }
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int8", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt(w *WriteBuf, oid Oid, value uint) error {
+ switch oid {
+ case Int2Oid:
+ if value > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than max pg:int2", value)
+ }
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ if value > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than max pg:int4", value)
+ }
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64)
+ if int64(value) > int64(math.MaxInt64) {
+ return fmt.Errorf("%d is greater than max pg:int8", value)
+ }
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid)
+ }
+
+ return nil
+}
+
+func encodeChar(w *WriteBuf, oid Oid, value Char) error {
+ w.WriteInt32(1)
+ w.WriteByte(byte(value))
+ return nil
+}
+
+func encodeInt8(w *WriteBuf, oid Oid, value int8) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int8", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid)
+ }
+
+ return nil
+}
+
+func encodeInt16(w *WriteBuf, oid Oid, value int16) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(value)
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int16", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int16", oid)
+ }
+
+ return nil
+}
+
+func encodeInt32(w *WriteBuf, oid Oid, value int32) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(value)
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int32", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid)
+ }
+
+ return nil
+}
+
+func encodeInt64(w *WriteBuf, oid Oid, value int64) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(value)
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int64", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+
+ if value <= math.MaxInt64 {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64))
+ }
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid)
+ }
+
+ return nil
+}
+
+func decodeInt4(vr *ValueReader) int32 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int32"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int4Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt32()
+}
+
+func decodeOid(vr *ValueReader) Oid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Oid"))
+ return Oid(0)
+ }
+
+ if vr.Type().DataType != OidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType)))
+ return Oid(0)
+ }
+
+ // Oid needs to decode text format because it is used in loadPgTypes
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Oid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Oid(0)
+ }
+ return Oid(vr.ReadInt32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Oid(0)
+ }
+}
+
+func encodeOid(w *WriteBuf, oid Oid, value Oid) error {
+ if oid != OidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
+func decodeXid(vr *ValueReader) Xid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Xid"))
+ return Xid(0)
+ }
+
+ if vr.Type().DataType != XidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType)))
+ return Xid(0)
+ }
+
+ // Unlikely Xid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Xid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Xid(0)
+ }
+ return Xid(vr.ReadUint32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Xid(0)
+ }
+}
+
+func encodeXid(w *WriteBuf, oid Oid, value Xid) error {
+ if oid != XidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
+func decodeCid(vr *ValueReader) Cid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Cid"))
+ return Cid(0)
+ }
+
+ if vr.Type().DataType != CidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType)))
+ return Cid(0)
+ }
+
+ // Unlikely Cid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Cid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Cid(0)
+ }
+ return Cid(vr.ReadUint32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Cid(0)
+ }
+}
+
+func encodeCid(w *WriteBuf, oid Oid, value Cid) error {
+ if oid != CidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
+// Note that we do not match negative numbers, because neither the
+// BlockNumber nor OffsetNumber of a Tid can be negative.
+var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`)
+
+func decodeTid(vr *ValueReader) Tid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Tid"))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ if vr.Type().DataType != TidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ // Unlikely Tid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+
+ match := tidRegexp.FindStringSubmatch(s)
+ if match == nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ blockNumber, err := strconv.ParseUint(s, 10, 16)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s)))
+ }
+
+ offsetNumber, err := strconv.ParseUint(s, 10, 16)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s)))
+ }
+ return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)}
+ case BinaryFormatCode:
+ if vr.Len() != 6 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+ return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()}
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+}
+
+func encodeTid(w *WriteBuf, oid Oid, value Tid) error {
+ if oid != TidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid)
+ }
+
+ w.WriteInt32(6)
+ w.WriteUint32(value.BlockNumber)
+ w.WriteUint16(value.OffsetNumber)
+
+ return nil
+}
+
+func decodeFloat4(vr *ValueReader) float32 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into float32"))
+ return 0
+ }
+
+ if vr.Type().DataType != Float4Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
+ return 0
+ }
+
+ i := vr.ReadInt32()
+ return math.Float32frombits(uint32(i))
+}
+
+func encodeFloat32(w *WriteBuf, oid Oid, value float32) error {
+ switch oid {
+ case Float4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(math.Float32bits(value)))
+ case Float8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(float64(value))))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "float32", oid)
+ }
+
+ return nil
+}
+
+func decodeFloat8(vr *ValueReader) float64 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into float64"))
+ return 0
+ }
+
+ if vr.Type().DataType != Float8Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len())))
+ return 0
+ }
+
+ i := vr.ReadInt64()
+ return math.Float64frombits(uint64(i))
+}
+
+func encodeFloat64(w *WriteBuf, oid Oid, value float64) error {
+ switch oid {
+ case Float8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(value)))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "float64", oid)
+ }
+
+ return nil
+}
+
+func decodeText(vr *ValueReader) string {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into string"))
+ return ""
+ }
+
+ return vr.ReadString(vr.Len())
+}
+
+func encodeString(w *WriteBuf, oid Oid, value string) error {
+ w.WriteInt32(int32(len(value)))
+ w.WriteBytes([]byte(value))
+ return nil
+}
+
+func decodeBytea(vr *ValueReader) []byte {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != ByteaOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ return vr.ReadBytes(vr.Len())
+}
+
+func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error {
+ w.WriteInt32(int32(len(value)))
+ w.WriteBytes(value)
+
+ return nil
+}
+
+func decodeJSON(vr *ValueReader, d interface{}) error {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != JsonOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType)))
+ }
+
+ bytes := vr.ReadBytes(vr.Len())
+ err := json.Unmarshal(bytes, d)
+ if err != nil {
+ vr.Fatal(err)
+ }
+ return err
+}
+
+func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
+ if oid != JsonOid {
+ return fmt.Errorf("cannot encode JSON into oid %v", oid)
+ }
+
+ s, err := json.Marshal(value)
+ if err != nil {
+ return fmt.Errorf("Failed to encode json from type: %T", value)
+ }
+
+ w.WriteInt32(int32(len(s)))
+ w.WriteBytes(s)
+
+ return nil
+}
+
+func decodeJSONB(vr *ValueReader, d interface{}) error {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != JsonbOid {
+ err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))
+ vr.Fatal(err)
+ return err
+ }
+
+ bytes := vr.ReadBytes(vr.Len())
+ if vr.Type().FormatCode == BinaryFormatCode {
+ if bytes[0] != 1 {
+ err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0]))
+ vr.Fatal(err)
+ return err
+ }
+ bytes = bytes[1:]
+ }
+
+ err := json.Unmarshal(bytes, d)
+ if err != nil {
+ vr.Fatal(err)
+ }
+ return err
+}
+
+func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error {
+ if oid != JsonbOid {
+ return fmt.Errorf("cannot encode JSON into oid %v", oid)
+ }
+
+ s, err := json.Marshal(value)
+ if err != nil {
+ return fmt.Errorf("Failed to encode json from type: %T", value)
+ }
+
+ w.WriteInt32(int32(len(s) + 1))
+ w.WriteByte(1) // JSONB format header
+ w.WriteBytes(s)
+
+ return nil
+}
+
+func decodeDate(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != DateOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len())))
+ }
+ dayOffset := vr.ReadInt32()
+ return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
+}
+
+func encodeTime(w *WriteBuf, oid Oid, value time.Time) error {
+ switch oid {
+ case DateOid:
+ tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix()
+ dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix()
+
+ secSinceDateEpoch := tUnix - dateEpoch
+ daysSinceDateEpoch := secSinceDateEpoch / 86400
+
+ w.WriteInt32(4)
+ w.WriteInt32(int32(daysSinceDateEpoch))
+
+ return nil
+ case TimestampTzOid, TimestampOid:
+ microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000
+ microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
+
+ w.WriteInt32(8)
+ w.WriteInt64(microsecSinceY2K)
+
+ return nil
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid)
+ }
+}
+
+const microsecFromUnixEpochToY2K = 946684800 * 1000000
+
+func decodeTimestampTz(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != TimestampTzOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len())))
+ return zeroTime
+ }
+
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+}
+
+func decodeTimestamp(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != TimestampOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len())))
+ return zeroTime
+ }
+
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+}
+
+func decodeInet(vr *ValueReader) net.IPNet {
+ var zero net.IPNet
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
+ return zero
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zero
+ }
+
+ pgType := vr.Type()
+ if pgType.DataType != InetOid && pgType.DataType != CidrOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
+ return zero
+ }
+ if vr.Len() != 8 && vr.Len() != 20 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
+ return zero
+ }
+
+ vr.ReadByte() // ignore family
+ bits := vr.ReadByte()
+ vr.ReadByte() // ignore is_cidr
+ addressLength := vr.ReadByte()
+
+ var ipnet net.IPNet
+ ipnet.IP = vr.ReadBytes(int32(addressLength))
+ ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
+
+ return ipnet
+}
+
+func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error {
+ if oid != InetOid && oid != CidrOid {
+ return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid)
+ }
+
+ var size int32
+ var family byte
+ switch len(value.IP) {
+ case net.IPv4len:
+ size = 8
+ family = *w.conn.pgsqlAfInet
+ case net.IPv6len:
+ size = 20
+ family = *w.conn.pgsqlAfInet6
+ default:
+ return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
+ }
+
+ w.WriteInt32(size)
+ w.WriteByte(family)
+ ones, _ := value.Mask.Size()
+ w.WriteByte(byte(ones))
+ w.WriteByte(0) // is_cidr is ignored on server
+ w.WriteByte(byte(len(value.IP)))
+ w.WriteBytes(value.IP)
+
+ return nil
+}
+
+func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
+ if oid != InetOid && oid != CidrOid {
+ return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid)
+ }
+
+ var ipnet net.IPNet
+ ipnet.IP = value
+ bitCount := len(value) * 8
+ ipnet.Mask = net.CIDRMask(bitCount, bitCount)
+ return encodeIPNet(w, oid, ipnet)
+}
+
+func decodeRecord(vr *ValueReader) []interface{} {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ if vr.Type().DataType != RecordOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
+ return nil
+ }
+
+ valueCount := vr.ReadInt32()
+ record := make([]interface{}, 0, int(valueCount))
+
+ for i := int32(0); i < valueCount; i++ {
+ fd := FieldDescription{FormatCode: BinaryFormatCode}
+ fieldVR := ValueReader{mr: vr.mr, fd: &fd}
+ fd.DataType = vr.ReadOid()
+ fieldVR.valueBytesRemaining = vr.ReadInt32()
+ vr.valueBytesRemaining -= fieldVR.valueBytesRemaining
+
+ switch fd.DataType {
+ case BoolOid:
+ record = append(record, decodeBool(&fieldVR))
+ case ByteaOid:
+ record = append(record, decodeBytea(&fieldVR))
+ case Int8Oid:
+ record = append(record, decodeInt8(&fieldVR))
+ case Int2Oid:
+ record = append(record, decodeInt2(&fieldVR))
+ case Int4Oid:
+ record = append(record, decodeInt4(&fieldVR))
+ case OidOid:
+ record = append(record, decodeOid(&fieldVR))
+ case Float4Oid:
+ record = append(record, decodeFloat4(&fieldVR))
+ case Float8Oid:
+ record = append(record, decodeFloat8(&fieldVR))
+ case DateOid:
+ record = append(record, decodeDate(&fieldVR))
+ case TimestampTzOid:
+ record = append(record, decodeTimestampTz(&fieldVR))
+ case TimestampOid:
+ record = append(record, decodeTimestamp(&fieldVR))
+ case InetOid, CidrOid:
+ record = append(record, decodeInet(&fieldVR))
+ case TextOid, VarcharOid, UnknownOid:
+ record = append(record, decodeText(&fieldVR))
+ default:
+ vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
+ return nil
+ }
+
+ // Consume any remaining data
+ if fieldVR.Len() > 0 {
+ fieldVR.ReadBytes(fieldVR.Len())
+ }
+
+ if fieldVR.Err() != nil {
+ vr.Fatal(fieldVR.Err())
+ return nil
+ }
+ }
+
+ return record
+}
+
+func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
+ numDims := vr.ReadInt32()
+ if numDims > 1 {
+ return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims))
+ }
+
+ vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care
+ vr.ReadInt32() // element oid
+
+ if numDims == 0 {
+ return 0, nil
+ }
+
+ length = vr.ReadInt32()
+
+ idxFirstElem := vr.ReadInt32()
+ if idxFirstElem != 1 {
+ return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem))
+ }
+
+ return length, nil
+}
+
+func decodeBoolArray(vr *ValueReader) []bool {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != BoolArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]bool, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 1:
+ if vr.ReadByte() == 1 {
+ a[i] = true
+ }
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error {
+ if oid != BoolArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid)
+ }
+
+ encodeArrayHeader(w, BoolOid, len(slice), 5)
+ for _, v := range slice {
+ w.WriteInt32(1)
+ var b byte
+ if v {
+ b = 1
+ }
+ w.WriteByte(b)
+ }
+
+ return nil
+}
+
+func decodeByteaArray(vr *ValueReader) [][]byte {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != ByteaArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([][]byte, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ a[i] = vr.ReadBytes(elSize)
+ }
+ }
+
+ return a
+}
+
+func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error {
+ if oid != ByteaArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid)
+ }
+
+ size := 20 // array header size
+ for _, el := range value {
+ size += 4 + len(el)
+ }
+
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(ByteaOid)) // type of elements
+ w.WriteInt32(int32(len(value))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, el := range value {
+ encodeByteSlice(w, ByteaOid, el)
+ }
+
+ return nil
+}
+
+func decodeInt2Array(vr *ValueReader) []int16 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int2ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int16, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 2:
+ a[i] = vr.ReadInt16()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int2ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint16, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 2:
+ tmp := vr.ReadInt16()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp)))
+ return nil
+ }
+ a[i] = uint16(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error {
+ if oid != Int2ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
+ }
+
+ encodeArrayHeader(w, Int2Oid, len(slice), 6)
+ for _, v := range slice {
+ w.WriteInt32(2)
+ w.WriteInt16(v)
+ }
+
+ return nil
+}
+
+func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error {
+ if oid != Int2ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
+ }
+
+ encodeArrayHeader(w, Int2Oid, len(slice), 6)
+ for _, v := range slice {
+ if v <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(v))
+ } else {
+ return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
+ }
+ }
+
+ return nil
+}
+
+func decodeInt4Array(vr *ValueReader) []int32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ a[i] = vr.ReadInt32()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ tmp := vr.ReadInt32()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp)))
+ return nil
+ }
+ a[i] = uint32(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error {
+ if oid != Int4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid)
+ }
+
+ encodeArrayHeader(w, Int4Oid, len(slice), 8)
+ for _, v := range slice {
+ w.WriteInt32(4)
+ w.WriteInt32(v)
+ }
+
+ return nil
+}
+
+func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error {
+ if oid != Int4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid)
+ }
+
+ encodeArrayHeader(w, Int4Oid, len(slice), 8)
+ for _, v := range slice {
+ if v <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(v))
+ } else {
+ return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32)
+ }
+ }
+
+ return nil
+}
+
+func decodeInt8Array(vr *ValueReader) []int64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ a[i] = vr.ReadInt64()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ tmp := vr.ReadInt64()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp)))
+ return nil
+ }
+ a[i] = uint64(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error {
+ if oid != Int8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid)
+ }
+
+ encodeArrayHeader(w, Int8Oid, len(slice), 12)
+ for _, v := range slice {
+ w.WriteInt32(8)
+ w.WriteInt64(v)
+ }
+
+ return nil
+}
+
+func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error {
+ if oid != Int8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid)
+ }
+
+ encodeArrayHeader(w, Int8Oid, len(slice), 12)
+ for _, v := range slice {
+ if v <= math.MaxInt64 {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(v))
+ } else {
+ return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64))
+ }
+ }
+
+ return nil
+}
+
+func decodeFloat4Array(vr *ValueReader) []float32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Float4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]float32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ n := vr.ReadInt32()
+ a[i] = math.Float32frombits(uint32(n))
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error {
+ if oid != Float4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid)
+ }
+
+ encodeArrayHeader(w, Float4Oid, len(slice), 8)
+ for _, v := range slice {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(math.Float32bits(v)))
+ }
+
+ return nil
+}
+
+func decodeFloat8Array(vr *ValueReader) []float64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Float8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]float64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ n := vr.ReadInt64()
+ a[i] = math.Float64frombits(uint64(n))
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error {
+ if oid != Float8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid)
+ }
+
+ encodeArrayHeader(w, Float8Oid, len(slice), 12)
+ for _, v := range slice {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(v)))
+ }
+
+ return nil
+}
+
+func decodeTextArray(vr *ValueReader) []string {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]string, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ if elSize == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ }
+
+ a[i] = vr.ReadString(elSize)
+ }
+
+ return a
+}
+
+// escapeAclItem escapes an AclItem before it is added to
+// its aclitem[] string representation. The PostgreSQL aclitem
+// datatype itself can need escapes because it follows the
+// formatting rules of SQL identifiers. Think of this function
+// as escaping the escapes, so that PostgreSQL's array parser
+// will do the right thing.
+func escapeAclItem(acl string) (string, error) {
+ var escapedAclItem bytes.Buffer
+ reader := strings.NewReader(acl)
+ for {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error.
+ return escapedAclItem.String(), nil
+ }
+ // This error was not expected
+ return "", err
+ }
+ if needsEscape(rn) {
+ escapedAclItem.WriteRune('\\')
+ }
+ escapedAclItem.WriteRune(rn)
+ }
+}
+
+// needsEscape determines whether or not a rune needs escaping
+// before being placed in the textual representation of an
+// aclitem[] array.
+func needsEscape(rn rune) bool {
+ return rn == '\\' || rn == ',' || rn == '"' || rn == '}'
+}
+
+// encodeAclItemSlice encodes a slice of AclItems in
+// their textual represention for PostgreSQL.
+func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error {
+ strs := make([]string, len(aclitems))
+ var escapedAclItem string
+ var err error
+ for i := range strs {
+ escapedAclItem, err = escapeAclItem(string(aclitems[i]))
+ if err != nil {
+ return err
+ }
+ strs[i] = string(escapedAclItem)
+ }
+
+ var buf bytes.Buffer
+ buf.WriteRune('{')
+ buf.WriteString(strings.Join(strs, ","))
+ buf.WriteRune('}')
+ str := buf.String()
+ w.WriteInt32(int32(len(str)))
+ w.WriteBytes([]byte(str))
+ return nil
+}
+
+// parseAclItemArray parses the textual representation
+// of the aclitem[] type. The textual representation is chosen because
+// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin).
+// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
+// for formatting notes.
+func parseAclItemArray(arr string) ([]AclItem, error) {
+ reader := strings.NewReader(arr)
+ // Difficult to guess a performant initial capacity for a slice of
+ // aclitems, but let's go with 5.
+ aclItems := make([]AclItem, 0, 5)
+ // A single value
+ aclItem := AclItem("")
+ for {
+ // Grab the first/next/last rune to see if we are dealing with a
+ // quoted value, an unquoted value, or the end of the string.
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error.
+ return aclItems, nil
+ }
+ // This error was not expected
+ return nil, err
+ }
+
+ if rn == '"' {
+ // Discard the opening quote of the quoted value.
+ aclItem, err = parseQuotedAclItem(reader)
+ } else {
+ // We have just read the first rune of an unquoted (bare) value;
+ // put it back so that ParseBareValue can read it.
+ err := reader.UnreadRune()
+ if err != nil {
+ return nil, err
+ }
+ aclItem, err = parseBareAclItem(reader)
+ }
+
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error..
+ aclItems = append(aclItems, aclItem)
+ return aclItems, nil
+ }
+ // This error was not expected.
+ return nil, err
+ }
+ aclItems = append(aclItems, aclItem)
+ }
+}
+
+// parseBareAclItem parses a bare (unquoted) aclitem from reader
+func parseBareAclItem(reader *strings.Reader) (AclItem, error) {
+ var aclItem bytes.Buffer
+ for {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ // Return the read value in case the error is a harmless io.EOF.
+ // (io.EOF marks the end of a bare aclitem at the end of a string)
+ return AclItem(aclItem.String()), err
+ }
+ if rn == ',' {
+ // A comma marks the end of a bare aclitem.
+ return AclItem(aclItem.String()), nil
+ } else {
+ aclItem.WriteRune(rn)
+ }
+ }
+}
+
+// parseQuotedAclItem parses an aclitem which is in double quotes from reader
+func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) {
+ var aclItem bytes.Buffer
+ for {
+ rn, escaped, err := readPossiblyEscapedRune(reader)
+ if err != nil {
+ if err == io.EOF {
+ // Even when it is the last value, the final rune of
+ // a quoted aclitem should be the final closing quote, not io.EOF.
+ return AclItem(""), fmt.Errorf("unexpected end of quoted value")
+ }
+ // Return the read aclitem in case the error is a harmless io.EOF,
+ // which will be determined by the caller.
+ return AclItem(aclItem.String()), err
+ }
+ if !escaped && rn == '"' {
+ // An unescaped double quote marks the end of a quoted value.
+ // The next rune should either be a comma or the end of the string.
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ // Return the read value in case the error is a harmless io.EOF,
+ // which will be determined by the caller.
+ return AclItem(aclItem.String()), err
+ }
+ if rn != ',' {
+ return AclItem(""), fmt.Errorf("unexpected rune after quoted value")
+ }
+ return AclItem(aclItem.String()), nil
+ }
+ aclItem.WriteRune(rn)
+ }
+}
+
+// Returns the next rune from r, unless it is a backslash;
+// in that case, it returns the rune after the backslash. The second
+// return value tells us whether or not the rune was
+// preceeded by a backslash (escaped).
+func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ return 0, false, err
+ }
+ if rn == '\\' {
+ // Discard the backslash and read the next rune.
+ rn, _, err = reader.ReadRune()
+ if err != nil {
+ return 0, false, err
+ }
+ return rn, true, nil
+ }
+ return rn, false, nil
+}
+
+func decodeAclItemArray(vr *ValueReader) []AclItem {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into []AclItem"))
+ return nil
+ }
+
+ str := vr.ReadString(vr.Len())
+
+ // Short-circuit empty array.
+ if str == "{}" {
+ return []AclItem{}
+ }
+
+ // Remove the '{' at the front and the '}' at the end,
+ // so that parseAclItemArray doesn't have to deal with them.
+ str = str[1 : len(str)-1]
+ aclItems, err := parseAclItemArray(str)
+ if err != nil {
+ vr.Fatal(ProtocolError(err.Error()))
+ return nil
+ }
+ return aclItems
+}
+
+func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error {
+ var elOid Oid
+ switch oid {
+ case VarcharArrayOid:
+ elOid = VarcharOid
+ case TextArrayOid:
+ elOid = TextOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid)
+ }
+
+ var totalStringSize int
+ for _, v := range slice {
+ totalStringSize += len(v)
+ }
+
+ size := 20 + len(slice)*4 + totalStringSize
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, v := range slice {
+ w.WriteInt32(int32(len(v)))
+ w.WriteBytes([]byte(v))
+ }
+
+ return nil
+}
+
+func decodeTimestampArray(vr *ValueReader) []time.Time {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]time.Time, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error {
+ var elOid Oid
+ switch oid {
+ case TimestampArrayOid:
+ elOid = TimestampOid
+ case TimestampTzArrayOid:
+ elOid = TimestampTzOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid)
+ }
+
+ encodeArrayHeader(w, int(elOid), len(slice), 12)
+ for _, t := range slice {
+ w.WriteInt32(8)
+ microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
+ microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
+ w.WriteInt64(microsecSinceY2K)
+ }
+
+ return nil
+}
+
+func decodeInetArray(vr *ValueReader) []net.IPNet {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]net.IPNet, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ if elSize == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ }
+
+ vr.ReadByte() // ignore family
+ bits := vr.ReadByte()
+ vr.ReadByte() // ignore is_cidr
+ addressLength := vr.ReadByte()
+
+ var ipnet net.IPNet
+ ipnet.IP = vr.ReadBytes(int32(addressLength))
+ ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
+
+ a[i] = ipnet
+ }
+
+ return a
+}
+
+func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error {
+ var elOid Oid
+ switch oid {
+ case InetArrayOid:
+ elOid = InetOid
+ case CidrArrayOid:
+ elOid = CidrOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
+ }
+
+ size := int32(20) // array header size
+ for _, ipnet := range slice {
+ size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes
+ }
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, ipnet := range slice {
+ encodeIPNet(w, elOid, ipnet)
+ }
+
+ return nil
+}
+
+func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error {
+ var elOid Oid
+ switch oid {
+ case InetArrayOid:
+ elOid = InetOid
+ case CidrArrayOid:
+ elOid = CidrOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
+ }
+
+ size := int32(20) // array header size
+ for _, ip := range slice {
+ size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes
+ }
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, ip := range slice {
+ encodeIP(w, elOid, ip)
+ }
+
+ return nil
+}
+
+func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
+ w.WriteInt32(int32(20 + length*sizePerItem))
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(oid)) // type of elements
+ w.WriteInt32(int32(length)) // number of elements
+ w.WriteInt32(1) // index of first element
+}