aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc')
-rw-r--r--vendor/github.com/jackc/pgx/.gitignore24
-rw-r--r--vendor/github.com/jackc/pgx/.travis.yml60
-rw-r--r--vendor/github.com/jackc/pgx/CHANGELOG.md190
-rw-r--r--vendor/github.com/jackc/pgx/LICENSE22
-rw-r--r--vendor/github.com/jackc/pgx/README.md137
-rw-r--r--vendor/github.com/jackc/pgx/aclitem_parse_test.go126
-rw-r--r--vendor/github.com/jackc/pgx/bench_test.go765
-rw-r--r--vendor/github.com/jackc/pgx/conn.go1322
-rw-r--r--vendor/github.com/jackc/pgx/conn_config_test.go.example25
-rw-r--r--vendor/github.com/jackc/pgx/conn_config_test.go.travis30
-rw-r--r--vendor/github.com/jackc/pgx/conn_pool.go529
-rw-r--r--vendor/github.com/jackc/pgx/conn_pool_private_test.go44
-rw-r--r--vendor/github.com/jackc/pgx/conn_pool_test.go982
-rw-r--r--vendor/github.com/jackc/pgx/conn_test.go1744
-rw-r--r--vendor/github.com/jackc/pgx/copy_from.go241
-rw-r--r--vendor/github.com/jackc/pgx/copy_from_test.go428
-rw-r--r--vendor/github.com/jackc/pgx/copy_to.go222
-rw-r--r--vendor/github.com/jackc/pgx/copy_to_test.go367
-rw-r--r--vendor/github.com/jackc/pgx/doc.go265
-rw-r--r--vendor/github.com/jackc/pgx/example_custom_type_test.go104
-rw-r--r--vendor/github.com/jackc/pgx/example_json_test.go43
-rw-r--r--vendor/github.com/jackc/pgx/examples/README.md7
-rw-r--r--vendor/github.com/jackc/pgx/examples/chat/README.md25
-rw-r--r--vendor/github.com/jackc/pgx/examples/chat/main.go95
-rw-r--r--vendor/github.com/jackc/pgx/examples/todo/README.md72
-rw-r--r--vendor/github.com/jackc/pgx/examples/todo/main.go140
-rw-r--r--vendor/github.com/jackc/pgx/examples/todo/structure.sql4
-rw-r--r--vendor/github.com/jackc/pgx/examples/url_shortener/README.md25
-rw-r--r--vendor/github.com/jackc/pgx/examples/url_shortener/main.go128
-rw-r--r--vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql4
-rw-r--r--vendor/github.com/jackc/pgx/fastpath.go108
-rw-r--r--vendor/github.com/jackc/pgx/helper_test.go74
-rw-r--r--vendor/github.com/jackc/pgx/hstore.go222
-rw-r--r--vendor/github.com/jackc/pgx/hstore_test.go181
-rw-r--r--vendor/github.com/jackc/pgx/large_objects.go147
-rw-r--r--vendor/github.com/jackc/pgx/large_objects_test.go121
-rw-r--r--vendor/github.com/jackc/pgx/logger.go81
-rw-r--r--vendor/github.com/jackc/pgx/messages.go159
-rw-r--r--vendor/github.com/jackc/pgx/msg_reader.go316
-rw-r--r--vendor/github.com/jackc/pgx/pgpass.go85
-rw-r--r--vendor/github.com/jackc/pgx/pgpass_test.go57
-rw-r--r--vendor/github.com/jackc/pgx/query.go494
-rw-r--r--vendor/github.com/jackc/pgx/query_test.go1414
-rw-r--r--vendor/github.com/jackc/pgx/replication.go429
-rw-r--r--vendor/github.com/jackc/pgx/replication_test.go329
-rw-r--r--vendor/github.com/jackc/pgx/sql.go29
-rw-r--r--vendor/github.com/jackc/pgx/sql_test.go36
-rw-r--r--vendor/github.com/jackc/pgx/stdlib/sql.go348
-rw-r--r--vendor/github.com/jackc/pgx/stdlib/sql_test.go691
-rw-r--r--vendor/github.com/jackc/pgx/stress_test.go346
-rw-r--r--vendor/github.com/jackc/pgx/tx.go207
-rw-r--r--vendor/github.com/jackc/pgx/tx_test.go297
-rw-r--r--vendor/github.com/jackc/pgx/value_reader.go156
-rw-r--r--vendor/github.com/jackc/pgx/values.go3439
-rw-r--r--vendor/github.com/jackc/pgx/values_test.go1183
55 files changed, 19119 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/.gitignore b/vendor/github.com/jackc/pgx/.gitignore
new file mode 100644
index 0000000..cb0cd90
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/.gitignore
@@ -0,0 +1,24 @@
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+
+conn_config_test.go
diff --git a/vendor/github.com/jackc/pgx/.travis.yml b/vendor/github.com/jackc/pgx/.travis.yml
new file mode 100644
index 0000000..d9ea43b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/.travis.yml
@@ -0,0 +1,60 @@
+language: go
+
+go:
+ - 1.7.4
+ - 1.6.4
+ - tip
+
+# Derived from https://github.com/lib/pq/blob/master/.travis.yml
+before_install:
+ - sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common
+ - sudo rm -rf /var/lib/postgresql
+ - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
+ - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
+ - sudo apt-get update -qq
+ - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
+ - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
+ - sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
+ - "[[ $PGVERSION < 9.6 ]] || echo \"wal_level='logical'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
+ - "[[ $PGVERSION < 9.6 ]] || echo \"max_wal_senders=5\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
+ - "[[ $PGVERSION < 9.6 ]] || echo \"max_replication_slots=5\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
+ - sudo /etc/init.d/postgresql restart
+
+env:
+ matrix:
+ - PGVERSION=9.6
+ - PGVERSION=9.5
+ - PGVERSION=9.4
+ - PGVERSION=9.3
+ - PGVERSION=9.2
+
+# The tricky test user, below, has to actually exist so that it can be used in a test
+# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
+before_script:
+ - mv conn_config_test.go.travis conn_config_test.go
+ - psql -U postgres -c 'create database pgx_test'
+ - "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'"
+ - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
+ - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
+ - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
+ - psql -U postgres -c "create user pgx_replication with replication password 'secret'"
+ - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
+
+install:
+ - go get -u github.com/shopspring/decimal
+ - go get -u gopkg.in/inconshreveable/log15.v2
+ - go get -u github.com/jackc/fake
+
+script:
+ - go test -v -race -short ./...
+
+matrix:
+ allow_failures:
+ - go: tip
diff --git a/vendor/github.com/jackc/pgx/CHANGELOG.md b/vendor/github.com/jackc/pgx/CHANGELOG.md
new file mode 100644
index 0000000..88c663b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/CHANGELOG.md
@@ -0,0 +1,190 @@
+# Unreleased
+
+## Fixes
+
+* database/sql driver created through stdlib.OpenFromConnPool closes connections when requested by database/sql rather than release to underlying connection pool.
+
+# 2.11.0 (June 5, 2017)
+
+## Fixes
+
+* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock)
+
+## Features
+
+* .pgpass support (j7b)
+* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen)
+* Add ParseConnectionString (James Lawrence)
+
+## Performance
+
+* Optimize HStore encoding (René Kroon)
+
+# 2.10.0 (March 17, 2017)
+
+## Fixes
+
+* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood)
+* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC
+
+## Features
+
+* Add xid type support (Manni Wood)
+* Add cid type support (Manni Wood)
+* Add tid type support (Manni Wood)
+* Add "char" type support (Manni Wood)
+* Add NullOid type (Manni Wood)
+* Add json/jsonb binary support to allow use with CopyTo
+* Add named error ErrAcquireTimeout (Alexander Staubo)
+* Add logical replication decoding (Kris Wehner)
+* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen)
+* Add CopyFrom with schema support (Jack Christensen)
+
+## Compatibility
+
+* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work.
+* CopyTo is now deprecated but will continue to work.
+
+# 2.9.0 (August 26, 2016)
+
+## Fixes
+
+* Fix *ConnPool.Deallocate() not deleting prepared statement from map
+* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś)
+* Fix Rows.Values() with varchar binary format
+* Concurrent ConnPool.Acquire calls with Dialer timeouts now timeout in the expected amount of time (Konstantin Dzreev)
+
+## Features
+
+* Add CopyTo
+* Add PrepareEx
+* Add basic record to []interface{} decoding
+* Encode and decode between all Go and PostgreSQL integer types with bounds checking
+* Decode inet/cidr to net.IP
+* Encode/decode [][]byte to/from bytea[]
+* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
+
+## Performance
+
+* Substantial reduction in memory allocations
+
+# 2.8.1 (March 24, 2016)
+
+## Features
+
+* Scan accepts nil argument to ignore a column
+
+## Fixes
+
+* Fix compilation on 32-bit architecture
+* Fix Tx.status not being set on error on Commit
+* Fix Listen/Unlisten with special characters
+
+# 2.8.0 (March 18, 2016)
+
+## Fixes
+
+* Fix unrecognized commit failure
+* Fix msgReader.rxMsg bug when msgReader already has error
+* Go float64 can no longer be encoded to a PostgreSQL float4
+* Fix connection corruption when query with error is closed early
+
+## Features
+
+This release adds multiple extension points helpful when wrapping pgx with
+custom application behavior. pgx can now use custom types designed for the
+standard database/sql package such as
+[github.com/shopspring/decimal](https://github.com/shopspring/decimal).
+
+* Add *Tx.AfterClose() hook
+* Add *Tx.Conn()
+* Add *Tx.Status()
+* Add *Tx.Err()
+* Add *Rows.AfterClose() hook
+* Add *Rows.Conn()
+* Add *Conn.SetLogger() to allow changing logger
+* Add *Conn.SetLogLevel() to allow changing log level
+* Add ConnPool.Reset method
+* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
+* Rows.Scan errors now include which argument caused error
+* Add Encode() to allow custom Encoders to reuse internal encoding functionality
+* Add Decode() to allow custom Decoders to reuse internal decoding functionality
+* Add ConnPool.Prepare method
+* Add ConnPool.Deallocate method
+* Add Scan to uint32 and uint64 (utrack)
+* Add encode and decode to []uint16, []uint32, and []uint64 (Max Musatov)
+
+## Performance
+
+* []byte skips encoding/decoding
+
+# 2.7.1 (October 26, 2015)
+
+* Disable SSL renegotiation
+
+# 2.7.0 (October 16, 2015)
+
+* Add RuntimeParams to ConnConfig
+* ParseURI extracts RuntimeParams
+* ParseDSN extracts RuntimeParams
+* ParseEnvLibpq extracts PGAPPNAME
+* Prepare is now idempotent
+* Rows.Values now supports oid type
+* ConnPool.Release automatically unlistens connections (Joseph Glanville)
+* Add trace log level
+* Add more efficient log leveling
+* Retry automatically on ConnPool.Begin (Joseph Glanville)
+* Encode from net.IP to inet and cidr
+* Generalize encoding pointer to string to any PostgreSQL type
+* Add UUID encoding from pointer to string (Joseph Glanville)
+* Add null mapping to pointer to pointer (Jonathan Rudenberg)
+* Add JSON and JSONB type support (Joseph Glanville)
+
+# 2.6.0 (September 3, 2015)
+
+* Add inet and cidr type support
+* Add binary decoding to TimestampOid in stdlib driver (Samuel Stauffer)
+* Add support for specifying sslmode in connection strings (Rick Snyder)
+* Allow ConnPool to have MaxConnections of 1
+* Add basic PGSSLMODE support to ParseEnvLibpq
+* Add fallback TLS config
+* Expose specific error for TLS refused
+* More error details exposed in PgError
+* Support custom dialer (Lewis Marshall)
+
+# 2.5.0 (April 15, 2015)
+
+* Fix stdlib nil support (Blaž Hrastnik)
+* Support custom Scanner not reading entire value
+* Fix empty array scanning (Laurent Debacker)
+* Add ParseDSN (deoxxa)
+* Add timestamp support to NullTime
+* Remove unused text format scanners
+* Return error when too many parameters on Prepare
+* Add Travis CI integration (Jonathan Rudenberg)
+* Large object support (Jonathan Rudenberg)
+* Fix reading null byte arrays (Karl Seguin)
+* Add timestamptz[] support
+* Add timestamp[] support (Karl Seguin)
+* Add bool[] support (Karl Seguin)
+* Allow writing []byte into text and varchar columns without type conversion (Hari Bhaskaran)
+* Fix ConnPool Close panic
+* Add Listen / notify example
+* Reduce memory allocations (Karl Seguin)
+
+# 2.4.0 (October 3, 2014)
+
+* Add per connection oid to name map
+* Add Hstore support (Andy Walker)
+* Move introductory docs to godoc from readme
+* Fix documentation references to TextEncoder and BinaryEncoder
+* Add keep-alive to TCP connections (Andy Walker)
+* Add support for EmptyQueryResponse / Allow no-op Exec (Andy Walker)
+* Allow reading any type into []byte
+* WaitForNotification detects lost connections quicker
+
+# 2.3.0 (September 16, 2014)
+
+* Truncate logged strings and byte slices
+* Extract more error information from PostgreSQL
+* Fix data race with Rows and ConnPool
diff --git a/vendor/github.com/jackc/pgx/LICENSE b/vendor/github.com/jackc/pgx/LICENSE
new file mode 100644
index 0000000..7dee3da
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/LICENSE
@@ -0,0 +1,22 @@
+Copyright (c) 2013 Jack Christensen
+
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file
diff --git a/vendor/github.com/jackc/pgx/README.md b/vendor/github.com/jackc/pgx/README.md
new file mode 100644
index 0000000..5550f6b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/README.md
@@ -0,0 +1,137 @@
+[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx)
+
+# Pgx
+
+## Master Branch
+
+This is the `master` branch which tracks the stable release of the current
+version. At the moment this is `v2`. The `v3` branch is currently in beta.
+General release is planned for July. `v3` is considered to be stable in the
+sense of lack of known bugs, but the API is not considered stable until general
+release. No further changes are planned, but the beta process may surface
+desirable changes. If possible API changes are acceptable, then `v3` is the
+recommended branch for new development. Regardless, please lock to the `v2` or
+`v3` branch as when `v3` is released breaking changes will be applied to the
+master branch.
+
+Pgx is a pure Go database connection library designed specifically for
+PostgreSQL. Pgx is different from other drivers such as
+[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a
+database/sql compatible driver, pgx is primarily intended to be used directly.
+It offers a native interface similar to database/sql that offers better
+performance and more features.
+
+## Features
+
+Pgx supports many additional features beyond what is available through database/sql.
+
+* Listen / notify
+* Transaction isolation level control
+* Full TLS connection control
+* Binary format support for custom types (can be much faster)
+* Copy from protocol support for faster bulk data loads
+* Logging support
+* Configurable connection pool with after connect hooks to do arbitrary connection setup
+* PostgreSQL array to Go slice mapping for integers, floats, and strings
+* Hstore support
+* JSON and JSONB support
+* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
+* Large object support
+* Null mapping to Null* struct or pointer to pointer.
+* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
+* Logical replication connections, including receiving WAL and sending standby status updates
+
+## Performance
+
+Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and
+[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single
+row, but it is substantially faster when selecting multiple entire rows (6893
+queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster).
+
+See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the
+underlying benchmark results or checkout
+[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself.
+
+## database/sql
+
+Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for
+database/sql. It is possible to retrieve a pgx connection from database/sql on
+demand. This allows using the database/sql interface in most places, but using
+pgx directly when more performance or PostgreSQL specific features are needed.
+
+## Documentation
+
+pgx includes extensive documentation in the godoc format. It is viewable online at [godoc.org](https://godoc.org/github.com/jackc/pgx).
+
+## Testing
+
+pgx supports multiple connection and authentication types. Setting up a test
+environment that can test all of them can be cumbersome. In particular,
+Windows cannot test Unix domain socket connections. Because of this pgx will
+skip tests for connection types that are not configured.
+
+### Normal Test Environment
+
+To setup the normal test environment, first install these dependencies:
+
+ go get github.com/jackc/fake
+ go get github.com/shopspring/decimal
+ go get gopkg.in/inconshreveable/log15.v2
+
+Then run the following SQL:
+
+ create user pgx_md5 password 'secret';
+ create user " tricky, ' } "" \ test user " password 'secret';
+ create database pgx_test;
+ create user pgx_replication with replication password 'secret';
+
+Connect to database pgx_test and run:
+
+ create extension hstore;
+
+Next open conn_config_test.go.example and make a copy without the
+.example. If your PostgreSQL server is accepting connections on 127.0.0.1,
+then you are done.
+
+### Connection and Authentication Test Environment
+
+Complete the normal test environment setup and also do the following.
+
+Run the following SQL:
+
+ create user pgx_none;
+ create user pgx_pw password 'secret';
+
+Add the following to your pg_hba.conf:
+
+If you are developing on Unix with domain socket connections:
+
+ local pgx_test pgx_none trust
+ local pgx_test pgx_pw password
+ local pgx_test pgx_md5 md5
+
+If you are developing on Windows with TCP connections:
+
+ host pgx_test pgx_none 127.0.0.1/32 trust
+ host pgx_test pgx_pw 127.0.0.1/32 password
+ host pgx_test pgx_md5 127.0.0.1/32 md5
+
+### Replication Test Environment
+
+Add a replication user:
+
+ create user pgx_replication with replication password 'secret';
+
+Add a replication line to your pg_hba.conf:
+
+ host replication pgx_replication 127.0.0.1/32 md5
+
+Change the following settings in your postgresql.conf:
+
+ wal_level=logical
+ max_wal_senders=5
+ max_replication_slots=5
+
+## Version Policy
+
+pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely).
diff --git a/vendor/github.com/jackc/pgx/aclitem_parse_test.go b/vendor/github.com/jackc/pgx/aclitem_parse_test.go
new file mode 100644
index 0000000..5c7c748
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/aclitem_parse_test.go
@@ -0,0 +1,126 @@
+package pgx
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestEscapeAclItem(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {
+ "foo",
+ "foo",
+ },
+ {
+ `foo, "\}`,
+ `foo\, \"\\\}`,
+ },
+ }
+
+ for i, tt := range tests {
+ actual, err := escapeAclItem(tt.input)
+
+ if err != nil {
+ t.Errorf("%d. Unexpected error %v", i, err)
+ }
+
+ if actual != tt.expected {
+ t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual)
+ }
+ }
+}
+
+func TestParseAclItemArray(t *testing.T) {
+ tests := []struct {
+ input string
+ expected []AclItem
+ errMsg string
+ }{
+ {
+ "",
+ []AclItem{},
+ "",
+ },
+ {
+ "one",
+ []AclItem{"one"},
+ "",
+ },
+ {
+ `"one"`,
+ []AclItem{"one"},
+ "",
+ },
+ {
+ "one,two,three",
+ []AclItem{"one", "two", "three"},
+ "",
+ },
+ {
+ `"one","two","three"`,
+ []AclItem{"one", "two", "three"},
+ "",
+ },
+ {
+ `"one",two,"three"`,
+ []AclItem{"one", "two", "three"},
+ "",
+ },
+ {
+ `one,two,"three"`,
+ []AclItem{"one", "two", "three"},
+ "",
+ },
+ {
+ `"one","two",three`,
+ []AclItem{"one", "two", "three"},
+ "",
+ },
+ {
+ `"one","t w o",three`,
+ []AclItem{"one", "t w o", "three"},
+ "",
+ },
+ {
+ `"one","t, w o\"\}\\",three`,
+ []AclItem{"one", `t, w o"}\`, "three"},
+ "",
+ },
+ {
+ `"one","two",three"`,
+ []AclItem{"one", "two", `three"`},
+ "",
+ },
+ {
+ `"one","two,"three"`,
+ nil,
+ "unexpected rune after quoted value",
+ },
+ {
+ `"one","two","three`,
+ nil,
+ "unexpected end of quoted value",
+ },
+ }
+
+ for i, tt := range tests {
+ actual, err := parseAclItemArray(tt.input)
+
+ if err != nil {
+ if tt.errMsg == "" {
+ t.Errorf("%d. Unexpected error %v", i, err)
+ } else if err.Error() != tt.errMsg {
+ t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error())
+ }
+ } else if tt.errMsg != "" {
+ t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg)
+ }
+
+ if !reflect.DeepEqual(actual, tt.expected) {
+ t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual)
+ }
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/bench_test.go b/vendor/github.com/jackc/pgx/bench_test.go
new file mode 100644
index 0000000..30e31e2
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/bench_test.go
@@ -0,0 +1,765 @@
+package pgx_test
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+ log "gopkg.in/inconshreveable/log15.v2"
+)
+
+func BenchmarkConnPool(b *testing.B) {
+ config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ b.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var conn *pgx.Conn
+ if conn, err = pool.Acquire(); err != nil {
+ b.Fatalf("Unable to acquire connection: %v", err)
+ }
+ pool.Release(conn)
+ }
+}
+
+func BenchmarkConnPoolQueryRow(b *testing.B) {
+ config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ b.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ num := float64(-1)
+ if err := pool.QueryRow("select random()").Scan(&num); err != nil {
+ b.Fatal(err)
+ }
+
+ if num < 0 {
+ b.Fatalf("expected `select random()` to return between 0 and 1 but it was: %v", num)
+ }
+ }
+}
+
+func BenchmarkNullXWithNullValues(b *testing.B) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var record struct {
+ id int32
+ userName string
+ email pgx.NullString
+ name pgx.NullString
+ sex pgx.NullString
+ birthDate pgx.NullTime
+ lastLoginTime pgx.NullTime
+ }
+
+ err = conn.QueryRow("selectNulls").Scan(
+ &record.id,
+ &record.userName,
+ &record.email,
+ &record.name,
+ &record.sex,
+ &record.birthDate,
+ &record.lastLoginTime,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ // These checks both ensure that the correct data was returned
+ // and provide a benchmark of accessing the returned values.
+ if record.id != 1 {
+ b.Fatalf("bad value for id: %v", record.id)
+ }
+ if record.userName != "johnsmith" {
+ b.Fatalf("bad value for userName: %v", record.userName)
+ }
+ if record.email.Valid {
+ b.Fatalf("bad value for email: %v", record.email)
+ }
+ if record.name.Valid {
+ b.Fatalf("bad value for name: %v", record.name)
+ }
+ if record.sex.Valid {
+ b.Fatalf("bad value for sex: %v", record.sex)
+ }
+ if record.birthDate.Valid {
+ b.Fatalf("bad value for birthDate: %v", record.birthDate)
+ }
+ if record.lastLoginTime.Valid {
+ b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
+ }
+ }
+}
+
+func BenchmarkNullXWithPresentValues(b *testing.B) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var record struct {
+ id int32
+ userName string
+ email pgx.NullString
+ name pgx.NullString
+ sex pgx.NullString
+ birthDate pgx.NullTime
+ lastLoginTime pgx.NullTime
+ }
+
+ err = conn.QueryRow("selectNulls").Scan(
+ &record.id,
+ &record.userName,
+ &record.email,
+ &record.name,
+ &record.sex,
+ &record.birthDate,
+ &record.lastLoginTime,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ // These checks both ensure that the correct data was returned
+ // and provide a benchmark of accessing the returned values.
+ if record.id != 1 {
+ b.Fatalf("bad value for id: %v", record.id)
+ }
+ if record.userName != "johnsmith" {
+ b.Fatalf("bad value for userName: %v", record.userName)
+ }
+ if !record.email.Valid || record.email.String != "johnsmith@example.com" {
+ b.Fatalf("bad value for email: %v", record.email)
+ }
+ if !record.name.Valid || record.name.String != "John Smith" {
+ b.Fatalf("bad value for name: %v", record.name)
+ }
+ if !record.sex.Valid || record.sex.String != "male" {
+ b.Fatalf("bad value for sex: %v", record.sex)
+ }
+ if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
+ b.Fatalf("bad value for birthDate: %v", record.birthDate)
+ }
+ if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
+ b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
+ }
+ }
+}
+
+func BenchmarkPointerPointerWithNullValues(b *testing.B) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var record struct {
+ id int32
+ userName string
+ email *string
+ name *string
+ sex *string
+ birthDate *time.Time
+ lastLoginTime *time.Time
+ }
+
+ err = conn.QueryRow("selectNulls").Scan(
+ &record.id,
+ &record.userName,
+ &record.email,
+ &record.name,
+ &record.sex,
+ &record.birthDate,
+ &record.lastLoginTime,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ // These checks both ensure that the correct data was returned
+ // and provide a benchmark of accessing the returned values.
+ if record.id != 1 {
+ b.Fatalf("bad value for id: %v", record.id)
+ }
+ if record.userName != "johnsmith" {
+ b.Fatalf("bad value for userName: %v", record.userName)
+ }
+ if record.email != nil {
+ b.Fatalf("bad value for email: %v", record.email)
+ }
+ if record.name != nil {
+ b.Fatalf("bad value for name: %v", record.name)
+ }
+ if record.sex != nil {
+ b.Fatalf("bad value for sex: %v", record.sex)
+ }
+ if record.birthDate != nil {
+ b.Fatalf("bad value for birthDate: %v", record.birthDate)
+ }
+ if record.lastLoginTime != nil {
+ b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
+ }
+ }
+}
+
+func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var record struct {
+ id int32
+ userName string
+ email *string
+ name *string
+ sex *string
+ birthDate *time.Time
+ lastLoginTime *time.Time
+ }
+
+ err = conn.QueryRow("selectNulls").Scan(
+ &record.id,
+ &record.userName,
+ &record.email,
+ &record.name,
+ &record.sex,
+ &record.birthDate,
+ &record.lastLoginTime,
+ )
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ // These checks both ensure that the correct data was returned
+ // and provide a benchmark of accessing the returned values.
+ if record.id != 1 {
+ b.Fatalf("bad value for id: %v", record.id)
+ }
+ if record.userName != "johnsmith" {
+ b.Fatalf("bad value for userName: %v", record.userName)
+ }
+ if record.email == nil || *record.email != "johnsmith@example.com" {
+ b.Fatalf("bad value for email: %v", record.email)
+ }
+ if record.name == nil || *record.name != "John Smith" {
+ b.Fatalf("bad value for name: %v", record.name)
+ }
+ if record.sex == nil || *record.sex != "male" {
+ b.Fatalf("bad value for sex: %v", record.sex)
+ }
+ if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
+ b.Fatalf("bad value for birthDate: %v", record.birthDate)
+ }
+ if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
+ b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
+ }
+ }
+}
+
+func BenchmarkSelectWithoutLogging(b *testing.B) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ benchmarkSelectWithLog(b, conn)
+}
+
+func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) {
+ connConfig := *defaultConnConfig
+
+ logger := log.New()
+ lvl, err := log.LvlFromString("debug")
+ if err != nil {
+ b.Fatal(err)
+ }
+ logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
+ connConfig.Logger = logger
+ connConfig.LogLevel = pgx.LogLevelTrace
+ conn := mustConnect(b, connConfig)
+ defer closeConn(b, conn)
+
+ benchmarkSelectWithLog(b, conn)
+}
+
+func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) {
+ connConfig := *defaultConnConfig
+
+ logger := log.New()
+ lvl, err := log.LvlFromString("debug")
+ if err != nil {
+ b.Fatal(err)
+ }
+ logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
+ connConfig.Logger = logger
+ connConfig.LogLevel = pgx.LogLevelDebug
+ conn := mustConnect(b, connConfig)
+ defer closeConn(b, conn)
+
+ benchmarkSelectWithLog(b, conn)
+}
+
+func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) {
+ connConfig := *defaultConnConfig
+
+ logger := log.New()
+ lvl, err := log.LvlFromString("info")
+ if err != nil {
+ b.Fatal(err)
+ }
+ logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
+ connConfig.Logger = logger
+ connConfig.LogLevel = pgx.LogLevelInfo
+ conn := mustConnect(b, connConfig)
+ defer closeConn(b, conn)
+
+ benchmarkSelectWithLog(b, conn)
+}
+
+func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) {
+ connConfig := *defaultConnConfig
+
+ logger := log.New()
+ lvl, err := log.LvlFromString("error")
+ if err != nil {
+ b.Fatal(err)
+ }
+ logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))
+ connConfig.Logger = logger
+ connConfig.LogLevel = pgx.LogLevelError
+ conn := mustConnect(b, connConfig)
+ defer closeConn(b, conn)
+
+ benchmarkSelectWithLog(b, conn)
+}
+
// benchmarkSelectWithLog prepares a seven-column select of constant
// values, then repeatedly executes and scans it, verifying every column
// on each iteration. It is shared by the logging benchmarks above so the
// query work is identical across log configurations.
func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
	_, err := conn.Prepare("test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
	if err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Non-pointer fields (unlike the nullable variant earlier in this
		// file): all selected values are non-null constants.
		var record struct {
			id            int32
			userName      string
			email         string
			name          string
			sex           string
			birthDate     time.Time
			lastLoginTime time.Time
		}

		// Executes the statement prepared above by name.
		err = conn.QueryRow("test").Scan(
			&record.id,
			&record.userName,
			&record.email,
			&record.name,
			&record.sex,
			&record.birthDate,
			&record.lastLoginTime,
		)
		if err != nil {
			b.Fatal(err)
		}

		// These checks both ensure that the correct data was returned
		// and provide a benchmark of accessing the returned values.
		if record.id != 1 {
			b.Fatalf("bad value for id: %v", record.id)
		}
		if record.userName != "johnsmith" {
			b.Fatalf("bad value for userName: %v", record.userName)
		}
		if record.email != "johnsmith@example.com" {
			b.Fatalf("bad value for email: %v", record.email)
		}
		if record.name != "John Smith" {
			b.Fatalf("bad value for name: %v", record.name)
		}
		if record.sex != "male" {
			b.Fatalf("bad value for sex: %v", record.sex)
		}
		if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
			b.Fatalf("bad value for birthDate: %v", record.birthDate)
		}
		if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
			b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
		}
	}
}
+
// BenchmarkLog15Discard measures the cost of a log15 Debug call that is
// filtered out by an error-level discard handler — i.e. the floor cost
// of the logging library itself, independent of any database work.
func BenchmarkLog15Discard(b *testing.B) {
	logger := log.New()
	lvl, err := log.LvlFromString("error")
	if err != nil {
		b.Fatal(err)
	}
	logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler()))

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		logger.Debug("benchmark", "i", i, "b.N", b.N)
	}
}
+
// benchmarkWriteTableCreateSQL (re)creates the write-benchmark target
// table t with a mix of not-null and nullable columns across varchar,
// date, int4, timestamptz, and bool types.
const benchmarkWriteTableCreateSQL = `drop table if exists t;

create table t(
	varchar_1 varchar not null,
	varchar_2 varchar not null,
	varchar_null_1 varchar,
	date_1 date not null,
	date_null_1 date,
	int4_1 int4 not null,
	int4_2 int4 not null,
	int4_null_1 int4,
	tstz_1 timestamptz not null,
	tstz_2 timestamptz,
	bool_1 bool not null,
	bool_2 bool not null,
	bool_3 bool not null
);
`

// benchmarkWriteTableInsertSQL inserts a single row into t. The explicit
// casts pin the type of each bound parameter.
const benchmarkWriteTableInsertSQL = `insert into t(
	varchar_1,
	varchar_2,
	varchar_null_1,
	date_1,
	date_null_1,
	int4_1,
	int4_2,
	int4_null_1,
	tstz_1,
	tstz_2,
	bool_1,
	bool_2,
	bool_3
) values (
	$1::varchar,
	$2::varchar,
	$3::varchar,
	$4::date,
	$5::date,
	$6::int4,
	$7::int4,
	$8::int4,
	$9::timestamptz,
	$10::timestamptz,
	$11::bool,
	$12::bool,
	$13::bool
)`
+
// benchmarkWriteTableCopyFromSrc is a CopyFromSource that yields the same
// fixed row `count` times. It drives the insert, multi-insert, and COPY
// write benchmarks.
type benchmarkWriteTableCopyFromSrc struct {
	count int           // total number of rows to produce
	idx   int           // number of rows produced so far (zero value: none yet)
	row   []interface{} // the fixed row returned by Values
}

// Next advances to the next row, reporting whether one is available.
//
// Fix: the previous comparison was `s.idx < s.count`, which produced only
// count-1 rows because idx starts at 0 and is incremented before the
// comparison — so e.g. BenchmarkWrite5RowsViaInsert actually wrote 4 rows.
// Using <= yields exactly count rows.
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
	s.idx++
	return s.idx <= s.count
}

// Values returns the fixed row; it never fails.
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) {
	return s.row, nil
}

// Err always reports success; this source cannot fail.
func (s *benchmarkWriteTableCopyFromSrc) Err() error {
	return nil
}
+
// newBenchmarkWriteTableCopyFromSrc returns a CopyFromSource producing
// `count` copies of one fixed row whose values line up with the columns
// of the benchmark table t; the nullable columns use pgx Null* zero
// values, i.e. SQL NULL.
func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
	return &benchmarkWriteTableCopyFromSrc{
		count: count,
		row: []interface{}{
			"varchar_1",
			"varchar_2",
			pgx.NullString{},
			time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
			pgx.NullTime{},
			1,
			2,
			pgx.NullInt32{},
			time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
			time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
			true,
			false,
			true,
		},
	}
}
+
+func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
+ conn := mustConnect(b, *defaultConnConfig)
+ defer closeConn(b, conn)
+
+ mustExec(b, conn, benchmarkWriteTableCreateSQL)
+ _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ src := newBenchmarkWriteTableCopyFromSrc(n)
+
+ tx, err := conn.Begin()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for src.Next() {
+ values, _ := src.Values()
+ if _, err = tx.Exec("insert_t", values...); err != nil {
+ b.Fatalf("Exec unexpectedly failed with: %v", err)
+ }
+ }
+
+ err = tx.Commit()
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// note this function is only used for benchmarks -- it doesn't escape tableName
+// or columnNames
+func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) {
+ maxRowsPerInsert := 65535 / len(columnNames)
+ rowsThisInsert := 0
+ rowCount := 0
+
+ sqlBuf := &bytes.Buffer{}
+ args := make(pgx.QueryArgs, 0)
+
+ resetQuery := func() {
+ sqlBuf.Reset()
+ fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", "))
+
+ args = args[0:0]
+
+ rowsThisInsert = 0
+ }
+ resetQuery()
+
+ tx, err := conn.Begin()
+ if err != nil {
+ return 0, err
+ }
+ defer tx.Rollback()
+
+ for rowSrc.Next() {
+ if rowsThisInsert > 0 {
+ sqlBuf.WriteByte(',')
+ }
+
+ sqlBuf.WriteByte('(')
+
+ values, err := rowSrc.Values()
+ if err != nil {
+ return 0, err
+ }
+
+ for i, val := range values {
+ if i > 0 {
+ sqlBuf.WriteByte(',')
+ }
+ sqlBuf.WriteString(args.Append(val))
+ }
+
+ sqlBuf.WriteByte(')')
+
+ rowsThisInsert++
+
+ if rowsThisInsert == maxRowsPerInsert {
+ _, err := tx.Exec(sqlBuf.String(), args...)
+ if err != nil {
+ return 0, err
+ }
+
+ rowCount += rowsThisInsert
+ resetQuery()
+ }
+ }
+
+ if rowsThisInsert > 0 {
+ _, err := tx.Exec(sqlBuf.String(), args...)
+ if err != nil {
+ return 0, err
+ }
+
+ rowCount += rowsThisInsert
+ }
+
+ if err := tx.Commit(); err != nil {
+ return 0, nil
+ }
+
+ return rowCount, nil
+
+}
+
// benchmarkWriteNRowsViaMultiInsert measures writing n rows per iteration
// using batched multi-row insert statements built by multiInsert.
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	mustExec(b, conn, benchmarkWriteTableCreateSQL)
	// NOTE(review): this prepared statement appears unused — multiInsert
	// builds its own SQL and never references "insert_t". Confirm and
	// consider removing.
	_, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL)
	if err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		src := newBenchmarkWriteTableCopyFromSrc(n)

		_, err := multiInsert(conn, "t",
			[]string{"varchar_1",
				"varchar_2",
				"varchar_null_1",
				"date_1",
				"date_null_1",
				"int4_1",
				"int4_2",
				"int4_null_1",
				"tstz_1",
				"tstz_2",
				"bool_1",
				"bool_2",
				"bool_3"},
			src)
		if err != nil {
			b.Fatal(err)
		}
	}
}
+
// benchmarkWriteNRowsViaCopy measures writing n rows per iteration using
// the COPY protocol via Conn.CopyFrom.
func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
	conn := mustConnect(b, *defaultConnConfig)
	defer closeConn(b, conn)

	mustExec(b, conn, benchmarkWriteTableCreateSQL)

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		src := newBenchmarkWriteTableCopyFromSrc(n)

		_, err := conn.CopyFrom(pgx.Identifier{"t"},
			[]string{"varchar_1",
				"varchar_2",
				"varchar_null_1",
				"date_1",
				"date_null_1",
				"int4_1",
				"int4_2",
				"int4_null_1",
				"tstz_1",
				"tstz_2",
				"bool_1",
				"bool_2",
				"bool_3"},
			src)
		if err != nil {
			b.Fatal(err)
		}
	}
}
+
// The BenchmarkWriteNRowsVia{Insert,MultiInsert,Copy} family compares the
// three write strategies at increasing row counts (5 to 10000 rows).
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 5)
}

func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 5)
}

func BenchmarkWrite5RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 5)
}

func BenchmarkWrite10RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 10)
}

func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10)
}

func BenchmarkWrite10RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10)
}

func BenchmarkWrite100RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 100)
}

func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 100)
}

func BenchmarkWrite100RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 100)
}

func BenchmarkWrite1000RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 1000)
}

func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 1000)
}

func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 1000)
}

func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 10000)
}

func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10000)
}

func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10000)
}
diff --git a/vendor/github.com/jackc/pgx/conn.go b/vendor/github.com/jackc/pgx/conn.go
new file mode 100644
index 0000000..a2d60e7
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn.go
@@ -0,0 +1,1322 @@
+package pgx
+
+import (
+ "bufio"
+ "crypto/md5"
+ "crypto/tls"
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/url"
+ "os"
+ "os/user"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
// DialFunc is a function that can be used to connect to a PostgreSQL server
type DialFunc func(network, addr string) (net.Conn, error)

// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
	Host              string      // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
	Port              uint16      // default: 5432
	Database          string      // database to connect to; empty means the server default
	User              string      // default: OS user name
	Password          string      // password; ParseURI/ParseDSN/ParseEnvLibpq fall back to pgpass when empty
	TLSConfig         *tls.Config // config for TLS connection -- nil disables TLS
	UseFallbackTLS    bool        // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
	FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
	Logger            Logger      // receives connection log output, filtered by LogLevel
	LogLevel          int         // a LogLevel* constant; 0 defaults to LogLevelDebug (see connect)
	Dial              DialFunc    // custom dial function; nil uses net.Dialer with a 5 minute keep-alive
	RuntimeParams     map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
}
+
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
	conn               net.Conn      // the underlying TCP or unix domain socket connection
	lastActivityTime   time.Time     // the last time the connection was used
	reader             *bufio.Reader // buffered reader to improve read performance
	wbuf               [1024]byte    // scratch storage; presumably backs writeBuf (see newWriteBuf) -- confirm
	writeBuf           WriteBuf      // reusable buffer for outgoing protocol messages
	Pid                int32         // backend pid
	SecretKey          int32         // key to use to send a cancel query message to the server
	RuntimeParams      map[string]string // parameters that have been reported by the server
	PgTypes            map[Oid]PgType    // oids to PgTypes
	config             ConnConfig        // config used when establishing this connection
	TxStatus           byte              // transaction status byte; set from ReadyForQuery (see rxReadyForQuery)
	preparedStatements map[string]*PreparedStatement // statements prepared on this connection, keyed by name
	channels           map[string]struct{}           // channels currently LISTENed to (see Listen/Unlisten)
	notifications      []*Notification               // received but not yet delivered notifications
	alive              bool                          // false once the connection has died
	causeOfDeath       error                         // the error that killed the connection
	logger             Logger                        // effective logger (from config)
	logLevel           int                           // effective log level (from config; see connect)
	mr                 msgReader                     // reusable protocol message reader
	fp                 *fastpath                     // fastpath call support; usage not visible in this chunk
	pgsqlAfInet        *byte                         // server AF_INET family byte for binary inet/cidr (see loadInetConstants)
	pgsqlAfInet6       *byte                         // server AF_INET6 family byte (see loadInetConstants)
	busy               bool                          // presumably true while a query is in progress (cf. ErrConnBusy) -- confirm
	poolResetCount     int                           // bookkeeping for ConnPool; usage not visible in this chunk
	preallocatedRows   []Rows                        // Rows values reused to reduce allocations
}
+
// PreparedStatement is a description of a prepared statement
type PreparedStatement struct {
	Name              string
	SQL               string
	FieldDescriptions []FieldDescription // result columns, as reported by the server during Prepare
	ParameterOids     []Oid              // parameter types, as reported by the server during Prepare
}

// PrepareExOptions is an option struct that can be passed to PrepareEx
type PrepareExOptions struct {
	ParameterOids []Oid // explicit parameter type OIDs; at most 65535
}

// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
type Notification struct {
	Pid     int32  // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string
}

// PgType is information about PostgreSQL type and how to encode and decode it
type PgType struct {
	Name          string // name of type e.g. int4, text, date
	DefaultFormat int16  // default format (text or binary) this type will be requested in
}
+
// CommandTag is the result of an Exec function
type CommandTag string

// RowsAffected returns the number of rows affected. If the CommandTag was not
// for a row affecting command (such as "CREATE TABLE") then it returns 0
func (ct CommandTag) RowsAffected() int64 {
	tag := string(ct)
	idx := strings.LastIndex(tag, " ")
	if idx < 0 {
		return 0
	}
	// A non-numeric trailing word (e.g. "CREATE TABLE") parses as 0.
	count, _ := strconv.ParseInt(tag[idx+1:], 10, 64)
	return count
}
+
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
// multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns a sanitized string safe for SQL interpolation.
func (ident Identifier) Sanitize() string {
	quoted := make([]string, len(ident))
	for i, part := range ident {
		// Double-quote each part, doubling any embedded double quotes.
		quoted[i] = `"` + strings.Replace(part, `"`, `""`, -1) + `"`
	}
	return strings.Join(quoted, ".")
}
+
// ErrNoRows occurs when rows are expected but none are returned.
var ErrNoRows = errors.New("no rows in result set")

// ErrNotificationTimeout occurs when WaitForNotification times out.
var ErrNotificationTimeout = errors.New("notification timeout")

// ErrDeadConn occurs on an attempt to use a dead connection
var ErrDeadConn = errors.New("conn is dead")

// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")

// ErrConnBusy occurs when the connection is busy (for example, in the middle of
// reading query results) and another action is attempted.
var ErrConnBusy = errors.New("conn is busy")

// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
var ErrInvalidLogLevel = errors.New("invalid log level")

// ProtocolError occurs when unexpected data is received from PostgreSQL
type ProtocolError string

// Error returns the protocol error message.
func (e ProtocolError) Error() string {
	return string(e)
}
+
// Connect establishes a connection with a PostgreSQL server using config.
// config.Host must be specified. config.User will default to the OS user name.
// Other config fields are optional. Type information and inet constants are
// loaded fresh from the server (see connect for the sharing variant).
func Connect(config ConnConfig) (c *Conn, err error) {
	return connect(config, nil, nil, nil)
}
+
// connect builds a *Conn from config. The optional pgTypes and
// pgsqlAfInet/pgsqlAfInet6 arguments let a caller (e.g. a connection pool)
// share type information already discovered on another connection, skipping
// the bootstrap queries; pass nil to load them fresh from the server.
func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) {
	c = new(Conn)

	c.config = config

	// Copy the shared type map so this connection can mutate its own.
	if pgTypes != nil {
		c.PgTypes = make(map[Oid]PgType, len(pgTypes))
		for k, v := range pgTypes {
			c.PgTypes[k] = v
		}
	}

	if pgsqlAfInet != nil {
		c.pgsqlAfInet = new(byte)
		*c.pgsqlAfInet = *pgsqlAfInet
	}
	if pgsqlAfInet6 != nil {
		c.pgsqlAfInet6 = new(byte)
		*c.pgsqlAfInet6 = *pgsqlAfInet6
	}

	if c.config.LogLevel != 0 {
		c.logLevel = c.config.LogLevel
	} else {
		// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
		c.logLevel = LogLevelDebug
	}
	c.logger = c.config.Logger
	c.mr.log = c.log
	c.mr.shouldLog = c.shouldLog

	if c.config.User == "" {
		// Note: this `user` shadows the os/user package within the block.
		user, err := user.Current()
		if err != nil {
			return nil, err
		}
		c.config.User = user.Username
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "Using default connection config", "User", c.config.User)
		}
	}

	if c.config.Port == 0 {
		c.config.Port = 5432
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port)
		}
	}

	network := "tcp"
	address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
	// See if host is a valid path, if yes connect with a socket
	if _, err := os.Stat(c.config.Host); err == nil {
		// For backward compatibility accept socket file paths -- but directories are now preferred
		network = "unix"
		address = c.config.Host
		if !strings.Contains(address, "/.s.PGSQL.") {
			address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
		}
	}
	if c.config.Dial == nil {
		c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
	}

	if c.shouldLog(LogLevelInfo) {
		c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
	}
	// First attempt uses TLSConfig; if that fails and fallback is enabled,
	// retry with FallbackTLSConfig (which may be nil, i.e. unencrypted).
	err = c.connect(config, network, address, config.TLSConfig)
	if err != nil && config.UseFallbackTLS {
		if c.shouldLog(LogLevelInfo) {
			c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err))
		}
		err = c.connect(config, network, address, config.FallbackTLSConfig)
	}

	if err != nil {
		if c.shouldLog(LogLevelError) {
			c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err))
		}
		return nil, err
	}

	return c, nil
}
+
// connect dials the server, optionally negotiates TLS, sends the startup
// message, and runs the authentication/startup loop until the server
// reports ReadyForQuery. On error the socket is closed and the connection
// is marked not alive.
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
	c.conn, err = c.config.Dial(network, address)
	if err != nil {
		return err
	}
	// Ensure the socket is closed if any later step of the handshake fails.
	defer func() {
		if c != nil && err != nil {
			c.conn.Close()
			c.alive = false
		}
	}()

	c.RuntimeParams = make(map[string]string)
	c.preparedStatements = make(map[string]*PreparedStatement)
	c.channels = make(map[string]struct{})
	c.alive = true
	c.lastActivityTime = time.Now()

	if tlsConfig != nil {
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "Starting TLS handshake")
		}
		if err := c.startTLS(tlsConfig); err != nil {
			return err
		}
	}

	c.reader = bufio.NewReader(c.conn)
	c.mr.reader = c.reader

	msg := newStartupMessage()

	// Default to disabling TLS renegotiation.
	//
	// Go does not support (https://github.com/golang/go/issues/5742)
	// PostgreSQL recommends disabling (http://www.postgresql.org/docs/9.4/static/runtime-config-connection.html#GUC-SSL-RENEGOTIATION-LIMIT)
	if tlsConfig != nil {
		msg.options["ssl_renegotiation_limit"] = "0"
	}

	// Copy default run-time params
	for k, v := range config.RuntimeParams {
		msg.options[k] = v
	}

	// user/database are set after RuntimeParams so they cannot be overridden.
	msg.options["user"] = c.config.User
	if c.config.Database != "" {
		msg.options["database"] = c.config.Database
	}

	if err = c.txStartupMessage(msg); err != nil {
		return err
	}

	// Startup loop: consume backend messages until ReadyForQuery.
	for {
		var t byte
		var r *msgReader
		t, r, err = c.rxMsg()
		if err != nil {
			return err
		}

		switch t {
		case backendKeyData:
			c.rxBackendKeyData(r)
		case authenticationX:
			if err = c.rxAuthenticationX(r); err != nil {
				return err
			}
		case readyForQuery:
			c.rxReadyForQuery(r)
			if c.shouldLog(LogLevelInfo) {
				c.log(LogLevelInfo, "Connection established")
			}

			// Replication connections can't execute the queries to
			// populate the c.PgTypes and c.pgsqlAfInet
			if _, ok := msg.options["replication"]; ok {
				return nil
			}

			if c.PgTypes == nil {
				err = c.loadPgTypes()
				if err != nil {
					return err
				}
			}

			if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil {
				err = c.loadInetConstants()
				if err != nil {
					return err
				}
			}

			return nil
		default:
			if err = c.processContextFreeMsg(t, r); err != nil {
				return err
			}
		}
	}
}
+
// loadPgTypes populates c.PgTypes with the server's base types (and
// 'record'), mapping each oid to its name and default wire format.
func (c *Conn) loadPgTypes() error {
	rows, err := c.Query(`select t.oid, t.typname
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
where (
	  t.typtype='b'
	  and (base_type.oid is null or base_type.typtype='b')
	)
	or t.typname in('record');`)
	if err != nil {
		return err
	}

	c.PgTypes = make(map[Oid]PgType, 128)

	for rows.Next() {
		var oid Oid
		var t PgType

		// Scan's error return is ignored here; presumably a scan failure
		// surfaces via rows.Err() below -- confirm.
		rows.Scan(&oid, &t.Name)

		// The zero value is text format so we ignore any types without a default type format
		t.DefaultFormat, _ = DefaultTypeFormats[t.Name]

		c.PgTypes[oid] = t
	}

	return rows.Err()
}
+
// Family is needed for binary encoding of inet/cidr. The constant is based on
// the server's definition of AF_INET. In theory, this could differ between
// platforms, so request an IPv4 and an IPv6 inet and get the family from that.
func (c *Conn) loadInetConstants() error {
	var ipv4, ipv6 []byte

	err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6)
	if err != nil {
		return err
	}

	// The first byte of the binary inet representation is the family.
	c.pgsqlAfInet = &ipv4[0]
	c.pgsqlAfInet6 = &ipv6[0]

	return nil
}
+
// Close closes a connection. It is safe to call Close on a already closed
// connection. It returns any error from writing the terminate message; the
// connection is marked dead regardless.
func (c *Conn) Close() (err error) {
	if !c.IsAlive() {
		return nil
	}

	// 'X' is the frontend Terminate message.
	wbuf := newWriteBuf(c, 'X')
	wbuf.closeMsg()

	_, err = c.conn.Write(wbuf.buf)

	c.die(errors.New("Closed"))
	if c.shouldLog(LogLevelInfo) {
		c.log(LogLevelInfo, "Closed connection")
	}
	return err
}
+
// ParseURI parses a database URI into ConnConfig
//
// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams.
func ParseURI(uri string) (ConnConfig, error) {
	var cp ConnConfig

	url, err := url.Parse(uri)
	if err != nil {
		return cp, err
	}

	if url.User != nil {
		cp.User = url.User.Username()
		cp.Password, _ = url.User.Password()
	}

	// NOTE(review): splitting Host on ":" does not handle bracketed IPv6
	// literals such as [::1]:5432 -- confirm callers never pass those.
	parts := strings.SplitN(url.Host, ":", 2)
	cp.Host = parts[0]
	if len(parts) == 2 {
		p, err := strconv.ParseUint(parts[1], 10, 16)
		if err != nil {
			return cp, err
		}
		cp.Port = uint16(p)
	}
	cp.Database = strings.TrimLeft(url.Path, "/")

	err = configSSL(url.Query().Get("sslmode"), &cp)
	if err != nil {
		return cp, err
	}

	// Query parameters consumed above, which must not leak into RuntimeParams.
	ignoreKeys := map[string]struct{}{
		"sslmode": {},
	}

	cp.RuntimeParams = make(map[string]string)

	for k, v := range url.Query() {
		if _, ok := ignoreKeys[k]; ok {
			continue
		}

		cp.RuntimeParams[k] = v[0]
	}
	// Fall back to a pgpass lookup when the URI carried no password.
	if cp.Password == "" {
		pgpass(&cp)
	}
	return cp, nil
}
+
+var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
+
+// ParseDSN parses a database DSN (data source name) into a ConnConfig
+//
+// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable")
+//
+// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams.
+//
+// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb")
+//
+// ParseDSN tries to match libpq behavior with regard to sslmode. See comments
+// for ParseEnvLibpq for more information on the security implications of
+// sslmode options.
+func ParseDSN(s string) (ConnConfig, error) {
+ var cp ConnConfig
+
+ m := dsnRegexp.FindAllStringSubmatch(s, -1)
+
+ var sslmode string
+
+ cp.RuntimeParams = make(map[string]string)
+
+ for _, b := range m {
+ switch b[1] {
+ case "user":
+ cp.User = b[2]
+ case "password":
+ cp.Password = b[2]
+ case "host":
+ cp.Host = b[2]
+ case "port":
+ p, err := strconv.ParseUint(b[2], 10, 16)
+ if err != nil {
+ return cp, err
+ }
+ cp.Port = uint16(p)
+ case "dbname":
+ cp.Database = b[2]
+ case "sslmode":
+ sslmode = b[2]
+ default:
+ cp.RuntimeParams[b[1]] = b[2]
+ }
+ }
+
+ err := configSSL(sslmode, &cp)
+ if err != nil {
+ return cp, err
+ }
+ if cp.Password == "" {
+ pgpass(&cp)
+ }
+ return cp, nil
+}
+
// ParseConnectionString parses either a URI or a DSN connection string.
// see ParseURI and ParseDSN for details. Strings beginning with
// postgres:// or postgresql:// are treated as URIs; anything else as a DSN.
func ParseConnectionString(s string) (ConnConfig, error) {
	if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") {
		return ParseURI(s)
	}
	return ParseDSN(s)
}
+
// ParseEnvLibpq parses the environment like libpq does into a ConnConfig
//
// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details
// on the meaning of environment variables.
//
// ParseEnvLibpq currently recognizes the following environment variables:
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGSSLMODE
// PGAPPNAME
//
// Important TLS Security Notes:
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
// includes defaulting to "prefer" behavior if no environment variable is set.
//
// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
// for details on what level of security each sslmode provides.
//
// "require" and "verify-ca" modes currently are treated as "verify-full". e.g.
// They have stronger security guarantees than they would with libpq. Do not
// rely on this behavior as it may be possible to match libpq in the future. If
// you need full security use "verify-full".
//
// Several of the PGSSLMODE options (including the default behavior of "prefer")
// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or
// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is
// later set from a different source that UseFallbackTLS MUST be set false to
// avoid the possibility of falling back to weaker or disabled security.
func ParseEnvLibpq() (ConnConfig, error) {
	var cc ConnConfig

	cc.Host = os.Getenv("PGHOST")

	if pgport := os.Getenv("PGPORT"); pgport != "" {
		if port, err := strconv.ParseUint(pgport, 10, 16); err == nil {
			cc.Port = uint16(port)
		} else {
			return cc, err
		}
	}

	cc.Database = os.Getenv("PGDATABASE")
	cc.User = os.Getenv("PGUSER")
	cc.Password = os.Getenv("PGPASSWORD")

	sslmode := os.Getenv("PGSSLMODE")

	err := configSSL(sslmode, &cc)
	if err != nil {
		return cc, err
	}

	cc.RuntimeParams = make(map[string]string)
	if appname := os.Getenv("PGAPPNAME"); appname != "" {
		cc.RuntimeParams["application_name"] = appname
	}
	// Fall back to a pgpass lookup when PGPASSWORD was not set.
	if cc.Password == "" {
		pgpass(&cc)
	}
	return cc, nil
}
+
// configSSL maps a libpq-style sslmode string onto cc's TLS-related
// fields (TLSConfig, UseFallbackTLS, FallbackTLSConfig). An empty string
// means "prefer", matching libpq. An unrecognized mode is an error.
func configSSL(sslmode string, cc *ConnConfig) error {
	// Match libpq default behavior
	if sslmode == "" {
		sslmode = "prefer"
	}

	switch sslmode {
	case "disable":
	case "allow":
		// Try unencrypted first, fall back to TLS without verification.
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
	case "prefer":
		// Try TLS without verification first, fall back to unencrypted.
		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = nil
	case "require", "verify-ca", "verify-full":
		// All three verify the server certificate against cc.Host
		// (i.e. treated as verify-full; see ParseEnvLibpq docs).
		cc.TLSConfig = &tls.Config{
			ServerName: cc.Host,
		}
	default:
		return errors.New("sslmode is invalid")
	}

	return nil
}
+
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
//
// Prepare delegates to PrepareEx with no options.
func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
	return c.PrepareEx(name, sql, nil)
}
+
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
	// Idempotence: return the cached statement if name+sql already prepared.
	if name != "" {
		if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
			return ps, nil
		}
	}

	if c.shouldLog(LogLevelError) {
		defer func() {
			if err != nil {
				c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
			}
		}()
	}

	// parse
	wbuf := newWriteBuf(c, 'P')
	wbuf.WriteCString(name)
	wbuf.WriteCString(sql)

	if opts != nil {
		// 65535 is the protocol's int16 limit on declared parameter OIDs.
		if len(opts.ParameterOids) > 65535 {
			return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))
		}
		wbuf.WriteInt16(int16(len(opts.ParameterOids)))
		for _, oid := range opts.ParameterOids {
			wbuf.WriteInt32(int32(oid))
		}
	} else {
		wbuf.WriteInt16(0)
	}

	// describe
	wbuf.startMsg('D')
	wbuf.WriteByte('S')
	wbuf.WriteCString(name)

	// sync
	wbuf.startMsg('S')
	wbuf.closeMsg()

	_, err = c.conn.Write(wbuf.buf)
	if err != nil {
		c.die(err)
		return nil, err
	}

	ps = &PreparedStatement{Name: name, SQL: sql}

	// softErr records the first non-fatal problem; the loop must still be
	// drained through readyForQuery before returning it.
	var softErr error

	for {
		var t byte
		var r *msgReader
		t, r, err := c.rxMsg()
		if err != nil {
			return nil, err
		}

		switch t {
		case parseComplete:
		case parameterDescription:
			ps.ParameterOids = c.rxParameterDescription(r)

			if len(ps.ParameterOids) > 65535 && softErr == nil {
				softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
			}
		case rowDescription:
			ps.FieldDescriptions = c.rxRowDescription(r)
			// Fill in the type name and default wire format for each column.
			for i := range ps.FieldDescriptions {
				t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType]
				ps.FieldDescriptions[i].DataTypeName = t.Name
				ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
			}
		case noData:
		case readyForQuery:
			c.rxReadyForQuery(r)

			// Only cache the statement if no soft error occurred.
			if softErr == nil {
				c.preparedStatements[name] = ps
			}

			return ps, softErr
		default:
			if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
				softErr = e
			}
		}
	}
}
+
// Deallocate releases a prepared statement: it is removed from the local
// cache, then a Close('S') plus Flush is sent and the loop waits for the
// server's CloseComplete.
func (c *Conn) Deallocate(name string) (err error) {
	delete(c.preparedStatements, name)

	// close
	wbuf := newWriteBuf(c, 'C')
	wbuf.WriteByte('S')
	wbuf.WriteCString(name)

	// flush
	wbuf.startMsg('H')
	wbuf.closeMsg()

	_, err = c.conn.Write(wbuf.buf)
	if err != nil {
		c.die(err)
		return err
	}

	for {
		var t byte
		var r *msgReader
		t, r, err := c.rxMsg()
		if err != nil {
			return err
		}

		switch t {
		case closeComplete:
			return nil
		default:
			err = c.processContextFreeMsg(t, r)
			if err != nil {
				return err
			}
		}
	}
}
+
// Listen establishes a PostgreSQL listen/notify to channel. The channel
// name is quoted via quoteIdentifier, and the subscription is tracked in
// c.channels.
func (c *Conn) Listen(channel string) error {
	_, err := c.Exec("listen " + quoteIdentifier(channel))
	if err != nil {
		return err
	}

	c.channels[channel] = struct{}{}

	return nil
}
+
// Unlisten unsubscribes from a listen channel and removes it from the
// connection's tracked channel set.
func (c *Conn) Unlisten(channel string) error {
	_, err := c.Exec("unlisten " + quoteIdentifier(channel))
	if err != nil {
		return err
	}

	delete(c.channels, channel)
	return nil
}
+
// WaitForNotification waits for a PostgreSQL notification for up to timeout.
// If the timeout occurs it returns pgx.ErrNotificationTimeout
func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) {
	// Return already received notification immediately
	if len(c.notifications) > 0 {
		notification := c.notifications[0]
		c.notifications = c.notifications[1:]
		return notification, nil
	}

	stopTime := time.Now().Add(timeout)

	for {
		now := time.Now()

		if now.After(stopTime) {
			return nil, ErrNotificationTimeout
		}

		// If there has been no activity on this connection for a while send a nop message just to ensure
		// the connection is alive
		nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second)
		if nextEnsureAliveTime.Before(now) {
			// If the server can't respond to a nop in 15 seconds, assume it's dead
			err := c.conn.SetReadDeadline(now.Add(15 * time.Second))
			if err != nil {
				return nil, err
			}

			// "--;" is a SQL comment: a no-op round trip to the server.
			_, err = c.Exec("--;")
			if err != nil {
				return nil, err
			}

			c.lastActivityTime = now
		}

		// Wake at whichever comes first: the caller's timeout or the next
		// keep-alive check.
		var deadline time.Time
		if stopTime.Before(nextEnsureAliveTime) {
			deadline = stopTime
		} else {
			deadline = nextEnsureAliveTime
		}

		notification, err := c.waitForNotification(deadline)
		if err != ErrNotificationTimeout {
			return notification, err
		}
	}
}
+
// waitForNotification blocks until a notification arrives or deadline
// passes, in which case it returns ErrNotificationTimeout.
func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
	var zeroTime time.Time

	for {
		// Use SetReadDeadline to implement the timeout. SetReadDeadline will
		// cause operations to fail with a *net.OpError that has a Timeout()
		// of true. Because the normal pgx rxMsg path considers any error to
		// have potentially corrupted the state of the connection, it dies
		// on any errors. So to avoid timeout errors in rxMsg we set the
		// deadline and peek into the reader. If a timeout error occurs there
		// we don't break the pgx connection. If the Peek returns that data
		// is available then we turn off the read deadline before the rxMsg.
		err := c.conn.SetReadDeadline(deadline)
		if err != nil {
			return nil, err
		}

		// Wait until there is a byte available before continuing onto the normal msg reading path
		_, err = c.reader.Peek(1)
		if err != nil {
			c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possible error from SetReadDeadline
			if err, ok := err.(*net.OpError); ok && err.Timeout() {
				return nil, ErrNotificationTimeout
			}
			return nil, err
		}

		// Data is available: clear the deadline so rxMsg cannot time out.
		err = c.conn.SetReadDeadline(zeroTime)
		if err != nil {
			return nil, err
		}

		var t byte
		var r *msgReader
		if t, r, err = c.rxMsg(); err == nil {
			if err = c.processContextFreeMsg(t, r); err != nil {
				return nil, err
			}
		} else {
			return nil, err
		}

		// processContextFreeMsg may have queued a notification; deliver it.
		if len(c.notifications) > 0 {
			notification := c.notifications[0]
			c.notifications = c.notifications[1:]
			return notification, nil
		}
	}
}
+
// IsAlive reports whether the connection is still usable, i.e. it has not
// been killed by die (fatal error, read/write failure, or Close).
func (c *Conn) IsAlive() bool {
	return c.alive
}
+
// CauseOfDeath returns the error that killed the connection, or nil if the
// connection is still alive.
func (c *Conn) CauseOfDeath() error {
	return c.causeOfDeath
}
+
+func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
+ if ps, present := c.preparedStatements[sql]; present {
+ return c.sendPreparedQuery(ps, arguments...)
+ }
+ return c.sendSimpleQuery(sql, arguments...)
+}
+
+func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
+
+ if len(args) == 0 {
+ wbuf := newWriteBuf(c, 'Q')
+ wbuf.WriteCString(sql)
+ wbuf.closeMsg()
+
+ _, err := c.conn.Write(wbuf.buf)
+ if err != nil {
+ c.die(err)
+ return err
+ }
+
+ return nil
+ }
+
+ ps, err := c.Prepare("", sql)
+ if err != nil {
+ return err
+ }
+
+ return c.sendPreparedQuery(ps, args...)
+}
+
// sendPreparedQuery writes the extended-protocol Bind/Execute/Sync sequence
// for ps with the given arguments. It returns an error (without sending
// anything) when the argument count does not match the statement's
// parameters.
func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
	if len(ps.ParameterOids) != len(arguments) {
		return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
	}

	// bind ('B'): bind the unnamed portal (leading zero byte) to ps.Name.
	wbuf := newWriteBuf(c, 'B')
	wbuf.WriteByte(0)
	wbuf.WriteCString(ps.Name)

	// Parameter format codes, one per parameter.
	wbuf.WriteInt16(int16(len(ps.ParameterOids)))
	for i, oid := range ps.ParameterOids {
		switch arg := arguments[i].(type) {
		case Encoder:
			// Custom encoders choose their own wire format.
			wbuf.WriteInt16(arg.FormatCode())
		case string, *string:
			wbuf.WriteInt16(TextFormatCode)
		default:
			// Types with a known binary codec are sent in binary; everything
			// else falls back to text.
			switch oid {
			case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid:
				wbuf.WriteInt16(BinaryFormatCode)
			default:
				wbuf.WriteInt16(TextFormatCode)
			}
		}
	}

	// Parameter values.
	wbuf.WriteInt16(int16(len(arguments)))
	for i, oid := range ps.ParameterOids {
		if err := Encode(wbuf, oid, arguments[i]); err != nil {
			return err
		}
	}

	// Result-column format codes, as captured at Prepare time.
	wbuf.WriteInt16(int16(len(ps.FieldDescriptions)))
	for _, fd := range ps.FieldDescriptions {
		wbuf.WriteInt16(fd.FormatCode)
	}

	// execute ('E'): unnamed portal, row limit 0 (fetch all rows).
	wbuf.startMsg('E')
	wbuf.WriteByte(0)
	wbuf.WriteInt32(0)

	// sync ('S'): ends this extended-protocol unit; the server will answer
	// with ReadyForQuery when processing completes.
	wbuf.startMsg('S')
	wbuf.closeMsg()

	// A write failure leaves the protocol stream in an unknown state, so the
	// connection is killed.
	_, err = c.conn.Write(wbuf.buf)
	if err != nil {
		c.die(err)
	}

	return err
}
+
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
	// Take the busy flag so interleaved use of the connection is rejected
	// with ErrConnBusy.
	if err = c.lock(); err != nil {
		return commandTag, err
	}

	startTime := time.Now()
	c.lastActivityTime = startTime

	defer func() {
		// Log the outcome and release the busy flag; an unlock failure is
		// only surfaced when no earlier error occurred.
		if err == nil {
			if c.shouldLog(LogLevelInfo) {
				endTime := time.Now()
				c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
			}
		} else {
			if c.shouldLog(LogLevelError) {
				c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
			}
		}

		if unlockErr := c.unlock(); unlockErr != nil && err == nil {
			err = unlockErr
		}
	}()

	if err = c.sendQuery(sql, arguments...); err != nil {
		return
	}

	// Errors that do not corrupt the connection (e.g. an ErrorResponse from
	// the server) are held in softErr until ReadyForQuery arrives, so the
	// message stream is fully drained before returning.
	var softErr error

	for {
		var t byte
		var r *msgReader
		t, r, err = c.rxMsg()
		if err != nil {
			return commandTag, err
		}

		switch t {
		case readyForQuery:
			c.rxReadyForQuery(r)
			return commandTag, softErr
		case rowDescription:
		case dataRow:
		case bindComplete:
			// Row data is irrelevant to Exec; these messages are skipped.
		case commandComplete:
			commandTag = CommandTag(r.readCString())
		default:
			if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
				softErr = e
			}
		}
	}
}
+
+// Processes messages that are not exclusive to one context such as
+// authentication or query response. The response to these messages
+// is the same regardless of when they occur.
+func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
+ switch t {
+ case 'S':
+ c.rxParameterStatus(r)
+ return nil
+ case errorResponse:
+ return c.rxErrorResponse(r)
+ case noticeResponse:
+ return nil
+ case emptyQueryResponse:
+ return nil
+ case notificationResponse:
+ c.rxNotificationResponse(r)
+ return nil
+ default:
+ return fmt.Errorf("Received unknown message type: %c", t)
+ }
+}
+
// rxMsg reads the next backend message header via the shared msgReader and
// returns the message type byte plus the reader positioned at the body. A
// read error kills the connection, since the protocol stream can no longer
// be trusted.
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
	if !c.alive {
		return 0, nil, ErrDeadConn
	}

	t, err = c.mr.rxMsg()
	if err != nil {
		c.die(err)
	}

	// Feed the keep-alive logic in WaitForNotification.
	c.lastActivityTime = time.Now()

	if c.shouldLog(LogLevelTrace) {
		c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
	}

	return t, &c.mr, err
}
+
// rxAuthenticationX handles the Authentication family of startup messages,
// answering with a password message when the server requests one.
func (c *Conn) rxAuthenticationX(r *msgReader) (err error) {
	switch r.readInt32() {
	case 0: // AuthenticationOk
	case 3: // AuthenticationCleartextPassword
		err = c.txPasswordMessage(c.config.Password)
	case 5: // AuthenticationMD5Password
		// PostgreSQL MD5 scheme: "md5" + md5(md5(password+user)+salt), where
		// salt is the 4 bytes supplied in this message.
		salt := r.readString(4)
		digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt)
		err = c.txPasswordMessage(digestedPassword)
	default:
		// Any other mechanism (e.g. SCRAM, GSSAPI) is unsupported here.
		err = errors.New("Received unknown authentication message")
	}

	return
}
+
// hexMD5 returns the lowercase hexadecimal encoding of the MD5 digest of s.
func hexMD5(s string) string {
	sum := md5.Sum([]byte(s))
	return hex.EncodeToString(sum[:])
}
+
+func (c *Conn) rxParameterStatus(r *msgReader) {
+ key := r.readCString()
+ value := r.readCString()
+ c.RuntimeParams[key] = value
+}
+
// rxErrorResponse decodes an ErrorResponse message into a PgError. Each field
// is a single-byte code followed by a C string; a zero byte terminates the
// message. A FATAL-severity error also kills the connection, since the
// server will close it.
func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) {
	for {
		switch r.readByte() {
		case 'S':
			err.Severity = r.readCString()
		case 'C':
			err.Code = r.readCString()
		case 'M':
			err.Message = r.readCString()
		case 'D':
			err.Detail = r.readCString()
		case 'H':
			err.Hint = r.readCString()
		case 'P':
			// Position fields arrive as decimal strings; parse failures
			// deliberately leave the value zero.
			s := r.readCString()
			n, _ := strconv.ParseInt(s, 10, 32)
			err.Position = int32(n)
		case 'p':
			s := r.readCString()
			n, _ := strconv.ParseInt(s, 10, 32)
			err.InternalPosition = int32(n)
		case 'q':
			err.InternalQuery = r.readCString()
		case 'W':
			err.Where = r.readCString()
		case 's':
			err.SchemaName = r.readCString()
		case 't':
			err.TableName = r.readCString()
		case 'c':
			err.ColumnName = r.readCString()
		case 'd':
			err.DataTypeName = r.readCString()
		case 'n':
			err.ConstraintName = r.readCString()
		case 'F':
			err.File = r.readCString()
		case 'L':
			s := r.readCString()
			n, _ := strconv.ParseInt(s, 10, 32)
			err.Line = int32(n)
		case 'R':
			err.Routine = r.readCString()

		case 0: // End of error message
			if err.Severity == "FATAL" {
				c.die(err)
			}
			return
		default: // Ignore other error fields
			// Unknown field codes must still be consumed to stay in sync
			// with the message stream.
			r.readCString()
		}
	}
}
+
// rxBackendKeyData records the backend process ID and secret key from a
// BackendKeyData message; both are needed to issue cancel requests.
func (c *Conn) rxBackendKeyData(r *msgReader) {
	c.Pid = r.readInt32()
	c.SecretKey = r.readInt32()
}
+
// rxReadyForQuery records the transaction status byte ('I' idle, 'T' in
// transaction, 'E' failed transaction) from a ReadyForQuery message.
func (c *Conn) rxReadyForQuery(r *msgReader) {
	c.TxStatus = r.readByte()
}
+
+func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
+ fieldCount := r.readInt16()
+ fields = make([]FieldDescription, fieldCount)
+ for i := int16(0); i < fieldCount; i++ {
+ f := &fields[i]
+ f.Name = r.readCString()
+ f.Table = r.readOid()
+ f.AttributeNumber = r.readInt16()
+ f.DataType = r.readOid()
+ f.DataTypeSize = r.readInt16()
+ f.Modifier = r.readInt32()
+ f.FormatCode = r.readInt16()
+ }
+ return
+}
+
+func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) {
+ // Internally, PostgreSQL supports greater than 64k parameters to a prepared
+ // statement. But the parameter description uses a 16-bit integer for the
+ // count of parameters. If there are more than 64K parameters, this count is
+ // wrong. So read the count, ignore it, and compute the proper value from
+ // the size of the message.
+ r.readInt16()
+ parameterCount := r.msgBytesRemaining / 4
+
+ parameters = make([]Oid, 0, parameterCount)
+
+ for i := int32(0); i < parameterCount; i++ {
+ parameters = append(parameters, r.readOid())
+ }
+ return
+}
+
+func (c *Conn) rxNotificationResponse(r *msgReader) {
+ n := new(Notification)
+ n.Pid = r.readInt32()
+ n.Channel = r.readCString()
+ n.Payload = r.readCString()
+ c.notifications = append(c.notifications, n)
+}
+
// startTLS performs the PostgreSQL SSLRequest handshake and, when the server
// agrees, wraps the connection in a TLS client session.
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
	// SSLRequest packet: length 8 followed by the request code 80877103.
	err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103})
	if err != nil {
		return
	}

	// The server answers with a single byte: 'S' to proceed with TLS;
	// anything else is treated as a refusal.
	response := make([]byte, 1)
	if _, err = io.ReadFull(c.conn, response); err != nil {
		return
	}

	if response[0] != 'S' {
		return ErrTLSRefused
	}

	c.conn = tls.Client(c.conn, tlsConfig)

	return nil
}
+
// txStartupMessage writes the serialized startup packet that begins the
// protocol exchange.
func (c *Conn) txStartupMessage(msg *startupMessage) error {
	_, err := c.conn.Write(msg.Bytes())
	return err
}
+
+func (c *Conn) txPasswordMessage(password string) (err error) {
+ wbuf := newWriteBuf(c, 'p')
+ wbuf.WriteCString(password)
+ wbuf.closeMsg()
+
+ _, err = c.conn.Write(wbuf.buf)
+
+ return err
+}
+
// die marks the connection as dead, records the cause, and closes the
// underlying socket. Subsequent message reads return ErrDeadConn (see rxMsg).
func (c *Conn) die(err error) {
	c.alive = false
	c.causeOfDeath = err
	c.conn.Close()
}
+
+func (c *Conn) lock() error {
+ if c.busy {
+ return ErrConnBusy
+ }
+ c.busy = true
+ return nil
+}
+
+func (c *Conn) unlock() error {
+ if !c.busy {
+ return errors.New("unlock conn that is not busy")
+ }
+ c.busy = false
+ return nil
+}
+
+func (c *Conn) shouldLog(lvl int) bool {
+ return c.logger != nil && c.logLevel >= lvl
+}
+
+func (c *Conn) log(lvl int, msg string, ctx ...interface{}) {
+ if c.Pid != 0 {
+ ctx = append(ctx, "pid", c.Pid)
+ }
+
+ switch lvl {
+ case LogLevelTrace:
+ c.logger.Debug(msg, ctx...)
+ case LogLevelDebug:
+ c.logger.Debug(msg, ctx...)
+ case LogLevelInfo:
+ c.logger.Info(msg, ctx...)
+ case LogLevelWarn:
+ c.logger.Warn(msg, ctx...)
+ case LogLevelError:
+ c.logger.Error(msg, ctx...)
+ }
+}
+
+// SetLogger replaces the current logger and returns the previous logger.
+func (c *Conn) SetLogger(logger Logger) Logger {
+ oldLogger := c.logger
+ c.logger = logger
+ return oldLogger
+}
+
+// SetLogLevel replaces the current log level and returns the previous log
+// level.
+func (c *Conn) SetLogLevel(lvl int) (int, error) {
+ oldLvl := c.logLevel
+
+ if lvl < LogLevelNone || lvl > LogLevelTrace {
+ return oldLvl, ErrInvalidLogLevel
+ }
+
+ c.logLevel = lvl
+ return lvl, nil
+}
+
// quoteIdentifier escapes s for use as a double-quoted SQL identifier:
// embedded double quotes are doubled per the SQL standard.
func quoteIdentifier(s string) string {
	escaped := strings.Replace(s, `"`, `""`, -1)
	return `"` + escaped + `"`
}
diff --git a/vendor/github.com/jackc/pgx/conn_config_test.go.example b/vendor/github.com/jackc/pgx/conn_config_test.go.example
new file mode 100644
index 0000000..cac798b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_config_test.go.example
@@ -0,0 +1,25 @@
// Template for conn_config_test.go: copy this file, drop the .example suffix,
// and fill in configurations for the environments the test suite should
// exercise (sample values are shown commented out at the bottom).
package pgx_test

import (
	"github.com/jackc/pgx"
)

var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}

// To skip tests for specific connection / authentication types set that connection param to nil
var tcpConnConfig *pgx.ConnConfig = nil
var unixSocketConnConfig *pgx.ConnConfig = nil
var md5ConnConfig *pgx.ConnConfig = nil
var plainPasswordConnConfig *pgx.ConnConfig = nil
var invalidUserConnConfig *pgx.ConnConfig = nil
var tlsConnConfig *pgx.ConnConfig = nil
var customDialerConnConfig *pgx.ConnConfig = nil
var replicationConnConfig *pgx.ConnConfig = nil

// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
diff --git a/vendor/github.com/jackc/pgx/conn_config_test.go.travis b/vendor/github.com/jackc/pgx/conn_config_test.go.travis
new file mode 100644
index 0000000..75714bf
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_config_test.go.travis
@@ -0,0 +1,30 @@
package pgx_test

import (
	"crypto/tls"
	"github.com/jackc/pgx"
	"os"
	"strconv"
)

// Connection configurations used when the test suite runs under Travis CI.
// Each targets a server role provisioned by the CI setup scripts.
var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"}
var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
var replicationConnConfig *pgx.ConnConfig = nil // set by init when PGVERSION >= 9.6
+
+func init() {
+ version := os.Getenv("PGVERSION")
+
+ if len(version) > 0 {
+ v, err := strconv.ParseFloat(version,64)
+ if err == nil && v >= 9.6 {
+ replicationConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"}
+ }
+ }
+}
+
diff --git a/vendor/github.com/jackc/pgx/conn_pool.go b/vendor/github.com/jackc/pgx/conn_pool.go
new file mode 100644
index 0000000..1913699
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_pool.go
@@ -0,0 +1,529 @@
+package pgx
+
+import (
+ "errors"
+ "sync"
+ "time"
+)
+
// ConnPoolConfig configures a ConnPool. The embedded ConnConfig is passed
// through to Connect for every new connection.
type ConnPoolConfig struct {
	ConnConfig
	// NOTE(review): this comment says "at least 2" but NewConnPool only
	// rejects values below 1 — confirm which minimum is intended.
	MaxConnections int               // max simultaneous connections to use, default 5, must be at least 2
	AfterConnect   func(*Conn) error // function to call on every new connection
	AcquireTimeout time.Duration     // max wait time when all connections are busy (0 means no timeout)
}
+
// ConnPool is a pool of *Conn. All shared state is guarded by cond.L; cond is
// signaled when a connection is released or a pool-state change may unblock
// waiters.
type ConnPool struct {
	allConnections       []*Conn // every live connection owned by the pool
	availableConnections []*Conn // subset of allConnections not currently acquired
	cond                 *sync.Cond
	config               ConnConfig // config used when establishing connection
	inProgressConnects   int        // connects started by createConnectionUnlocked but not yet finished
	maxConnections       int
	resetCount           int // bumped by Reset/invalidateAcquired; stale connections are closed on Release
	afterConnect         func(*Conn) error
	logger               Logger
	logLevel             int
	closed               bool
	preparedStatements   map[string]*PreparedStatement // statements mirrored onto every connection
	acquireTimeout       time.Duration
	pgTypes              map[Oid]PgType // type data cached from the first connection
	pgsqlAfInet          *byte
	pgsqlAfInet6         *byte
	txAfterClose         func(tx *Tx)     // releases the conn when a pool-issued Tx closes
	rowsAfterClose       func(rows *Rows) // releases the conn when pool-issued Rows close
}
+
// ConnPoolStat is a point-in-time snapshot of pool usage as reported by Stat.
type ConnPoolStat struct {
	MaxConnections       int // max simultaneous connections to use
	CurrentConnections   int // current live connections
	AvailableConnections int // unused live connections
}

// ErrAcquireTimeout occurs when an attempt to acquire a connection times out.
var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool")
+
// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to
// Connect directly.
func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) {
	p = new(ConnPool)
	p.config = config.ConnConfig
	p.maxConnections = config.MaxConnections
	if p.maxConnections == 0 {
		p.maxConnections = 5
	}
	// NOTE(review): ConnPoolConfig documents MaxConnections as "at least 2",
	// but only values below 1 are rejected here — confirm which is intended.
	if p.maxConnections < 1 {
		return nil, errors.New("MaxConnections must be at least 1")
	}
	p.acquireTimeout = config.AcquireTimeout
	if p.acquireTimeout < 0 {
		return nil, errors.New("AcquireTimeout must be equal to or greater than 0")
	}

	p.afterConnect = config.AfterConnect

	if config.LogLevel != 0 {
		p.logLevel = config.LogLevel
	} else {
		// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
		p.logLevel = LogLevelDebug
	}
	p.logger = config.Logger
	if p.logger == nil {
		// Without a logger, disable logging entirely.
		p.logLevel = LogLevelNone
	}

	// Closing a pool-issued Tx or Rows returns its connection to the pool.
	p.txAfterClose = func(tx *Tx) {
		p.Release(tx.Conn())
	}

	p.rowsAfterClose = func(rows *Rows) {
		p.Release(rows.Conn())
	}

	p.allConnections = make([]*Conn, 0, p.maxConnections)
	p.availableConnections = make([]*Conn, 0, p.maxConnections)
	p.preparedStatements = make(map[string]*PreparedStatement)
	p.cond = sync.NewCond(new(sync.Mutex))

	// Initially establish one connection
	var c *Conn
	c, err = p.createConnection()
	if err != nil {
		return
	}
	p.allConnections = append(p.allConnections, c)
	p.availableConnections = append(p.availableConnections, c)

	return
}
+
+// Acquire takes exclusive use of a connection until it is released.
+func (p *ConnPool) Acquire() (*Conn, error) {
+ p.cond.L.Lock()
+ c, err := p.acquire(nil)
+ p.cond.L.Unlock()
+ return c, err
+}
+
+// deadlinePassed returns true if the given deadline has passed.
+func (p *ConnPool) deadlinePassed(deadline *time.Time) bool {
+ return deadline != nil && time.Now().After(*deadline)
+}
+
// acquire performs acquision assuming pool is already locked. deadline, when
// non-nil, bounds the total wait across recursive retries; nil means no
// explicit limit (the pool's acquireTimeout may still set one below).
func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
	if p.closed {
		return nil, errors.New("cannot acquire from closed pool")
	}

	// A connection is available
	if len(p.availableConnections) > 0 {
		c := p.availableConnections[len(p.availableConnections)-1]
		// Stamp with the current reset count so Release keeps it unless the
		// pool is reset while it is checked out.
		c.poolResetCount = p.resetCount
		p.availableConnections = p.availableConnections[:len(p.availableConnections)-1]
		return c, nil
	}

	// Set initial timeout/deadline value. If the method (acquire) happens to
	// recursively call itself the deadline should retain its value.
	if deadline == nil && p.acquireTimeout > 0 {
		tmp := time.Now().Add(p.acquireTimeout)
		deadline = &tmp
	}

	// Make sure the deadline (if it is) has not passed yet
	if p.deadlinePassed(deadline) {
		return nil, ErrAcquireTimeout
	}

	// If there is a deadline then start a timeout timer. Broadcast (not
	// Signal) so every waiter wakes and re-checks its own deadline.
	var timer *time.Timer
	if deadline != nil {
		timer = time.AfterFunc(deadline.Sub(time.Now()), func() {
			p.cond.Broadcast()
		})
		defer timer.Stop()
	}

	// No connections are available, but we can create more
	if len(p.allConnections)+p.inProgressConnects < p.maxConnections {
		// Create a new connection.
		// Careful here: createConnectionUnlocked() removes the current lock,
		// creates a connection and then locks it back.
		c, err := p.createConnectionUnlocked()
		if err != nil {
			return nil, err
		}
		c.poolResetCount = p.resetCount
		p.allConnections = append(p.allConnections, c)
		return c, nil
	}
	// All connections are in use and we cannot create more
	if p.logLevel >= LogLevelWarn {
		p.logger.Warn("All connections in pool are busy - waiting...")
	}

	// Wait until there is an available connection OR room to create a new connection
	for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
		if p.deadlinePassed(deadline) {
			return nil, ErrAcquireTimeout
		}
		// Release notifies this condition variable whenever it returns a
		// connection or capacity changes.
		p.cond.Wait()
	}

	// Stop the timer so that we do not spawn it on every acquire call.
	if timer != nil {
		timer.Stop()
	}
	// Retry with the same deadline; the wakeup may have been consumed by
	// another goroutine.
	return p.acquire(deadline)
}
+
// Release gives up use of a connection.
func (p *ConnPool) Release(conn *Conn) {
	// Roll back any transaction left open by the previous user so the next
	// acquirer gets a clean session ('I' = idle, not in a transaction).
	if conn.TxStatus != 'I' {
		conn.Exec("rollback")
	}

	// Unsubscribe from all listen/notify channels so the next user of this
	// connection does not receive another caller's notifications.
	if len(conn.channels) > 0 {
		if err := conn.Unlisten("*"); err != nil {
			conn.die(err)
		}
		conn.channels = make(map[string]struct{})
	}
	conn.notifications = nil

	p.cond.L.Lock()

	// The pool was reset while this connection was checked out: discard it.
	if conn.poolResetCount != p.resetCount {
		conn.Close()
		p.cond.L.Unlock()
		p.cond.Signal()
		return
	}

	if conn.IsAlive() {
		p.availableConnections = append(p.availableConnections, conn)
	} else {
		p.removeFromAllConnections(conn)
	}
	p.cond.L.Unlock()
	// Wake one waiter in acquire/Close.
	p.cond.Signal()
}
+
+// removeFromAllConnections Removes the given connection from the list.
+// It returns true if the connection was found and removed or false otherwise.
+func (p *ConnPool) removeFromAllConnections(conn *Conn) bool {
+ for i, c := range p.allConnections {
+ if conn == c {
+ p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...)
+ return true
+ }
+ }
+ return false
+}
+
+// Close ends the use of a connection pool. It prevents any new connections
+// from being acquired, waits until all acquired connections are released,
+// then closes all underlying connections.
+func (p *ConnPool) Close() {
+ p.cond.L.Lock()
+ defer p.cond.L.Unlock()
+
+ p.closed = true
+
+ // Wait until all connections are released
+ if len(p.availableConnections) != len(p.allConnections) {
+ for len(p.availableConnections) != len(p.allConnections) {
+ p.cond.Wait()
+ }
+ }
+
+ for _, c := range p.allConnections {
+ _ = c.Close()
+ }
+}
+
+// Reset closes all open connections, but leaves the pool open. It is intended
+// for use when an error is detected that would disrupt all connections (such as
+// a network interruption or a server state change).
+//
+// It is safe to reset a pool while connections are checked out. Those
+// connections will be closed when they are returned to the pool.
+func (p *ConnPool) Reset() {
+ p.cond.L.Lock()
+ defer p.cond.L.Unlock()
+
+ p.resetCount++
+ p.allConnections = p.allConnections[0:0]
+
+ for _, conn := range p.availableConnections {
+ conn.Close()
+ }
+
+ p.availableConnections = p.availableConnections[0:0]
+}
+
// invalidateAcquired causes all acquired connections to be closed when released.
// The pool must already be locked.
func (p *ConnPool) invalidateAcquired() {
	p.resetCount++

	// Connections currently idle in the pool remain valid: stamp them with
	// the new reset count so Release keeps them.
	for _, c := range p.availableConnections {
		c.poolResetCount = p.resetCount
	}

	// Shrink allConnections down to just the available ones; acquired
	// connections (still stamped with the old count) are dropped from the
	// list and will be closed by Release.
	p.allConnections = p.allConnections[:len(p.availableConnections)]
	copy(p.allConnections, p.availableConnections)
}
+
+// Stat returns connection pool statistics
+func (p *ConnPool) Stat() (s ConnPoolStat) {
+ p.cond.L.Lock()
+ defer p.cond.L.Unlock()
+
+ s.MaxConnections = p.maxConnections
+ s.CurrentConnections = len(p.allConnections)
+ s.AvailableConnections = len(p.availableConnections)
+ return
+}
+
+func (p *ConnPool) createConnection() (*Conn, error) {
+ c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
+ if err != nil {
+ return nil, err
+ }
+ return p.afterConnectionCreated(c)
+}
+
// createConnectionUnlocked Removes the current lock, creates a new connection, and
// then locks it back.
// Here is the point: lets say our pool dialer's OpenTimeout is set to 3 seconds.
// And we have a pool with 20 connections in it, and we try to acquire them all at
// startup.
// If it happens that the remote server is not accessible, then the first connection
// in the pool blocks all the others for 3 secs, before it gets the timeout. Then
// connection #2 holds the lock and locks everything for the next 3 secs until it
// gets OpenTimeout err, etc. And the very last 20th connection will fail only after
// 3 * 20 = 60 secs.
// To avoid this we put Connect(p.config) outside of the lock (it is thread safe)
// what would allow us to make all the 20 connection in parallel (more or less).
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
	// Count this in-flight connect so acquire's capacity math stays correct
	// while the lock is dropped.
	p.inProgressConnects++
	p.cond.L.Unlock()
	// NOTE(review): unlike createConnection, this calls the exported Connect
	// and does not pass the pool's cached pgTypes — confirm the duplicated
	// type lookup per connection is acceptable here.
	c, err := Connect(p.config)
	p.cond.L.Lock()
	p.inProgressConnects--

	if err != nil {
		return nil, err
	}
	return p.afterConnectionCreated(c)
}
+
// afterConnectionCreated executes (if it is) afterConnect() callback and prepares
// all the known statements for the new connection. It also caches the
// connection's type data for reuse by subsequent connects.
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
	p.pgTypes = c.PgTypes
	p.pgsqlAfInet = c.pgsqlAfInet
	p.pgsqlAfInet6 = c.pgsqlAfInet6

	if p.afterConnect != nil {
		err := p.afterConnect(c)
		if err != nil {
			// A failed callback invalidates the connection entirely.
			c.die(err)
			return nil, err
		}
	}

	// Mirror every statement prepared through the pool onto this connection
	// so prepared names work regardless of which connection gets acquired.
	for _, ps := range p.preparedStatements {
		if _, err := c.Prepare(ps.Name, ps.SQL); err != nil {
			c.die(err)
			return nil, err
		}
	}

	return c, nil
}
+
+// Exec acquires a connection, delegates the call to that connection, and releases the connection
+func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
+ var c *Conn
+ if c, err = p.Acquire(); err != nil {
+ return
+ }
+ defer p.Release(c)
+
+ return c.Exec(sql, arguments...)
+}
+
// Query acquires a connection and delegates the call to that connection. When
// *Rows are closed, the connection is released automatically.
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
	c, err := p.Acquire()
	if err != nil {
		// Because checking for errors can be deferred to the *Rows, build one with the error
		return &Rows{closed: true, err: err}, err
	}

	rows, err := c.Query(sql, args...)
	if err != nil {
		p.Release(c)
		return rows, err
	}

	// Return the connection to the pool when the caller closes the rows.
	rows.AfterClose(p.rowsAfterClose)

	return rows, nil
}
+
// QueryRow acquires a connection and delegates the call to that connection. The
// connection is released automatically after Scan is called on the returned
// *Row.
func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
	// Any acquire/query error is carried inside the returned *Rows and
	// surfaces when the caller invokes Scan, so it is deliberately ignored
	// here.
	rows, _ := p.Query(sql, args...)
	return (*Row)(rows)
}
+
// Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) {
	// Empty isolation mode means the server's default isolation level.
	return p.BeginIso("")
}
+
// Prepare creates a prepared statement on a connection in the pool to test the
// statement is valid. If it succeeds all connections accessed through the pool
// will have the statement available.
//
// Prepare creates a prepared statement with name and sql. sql can contain
// placeholders for bound parameters. These placeholders are referenced
// positionally as $1, $2, etc.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with
// the same name and sql arguments. This allows a code path to Prepare and
// Query/Exec/PrepareEx without concern for if the statement has already been prepared.
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
	return p.PrepareEx(name, sql, nil)
}
+
// PrepareEx creates a prepared statement on a connection in the pool to test the
// statement is valid. If it succeeds all connections accessed through the pool
// will have the statement available.
//
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positionally as $1, $2, etc.
// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
// concern for if the statement has already been prepared.
func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
	p.cond.L.Lock()
	defer p.cond.L.Unlock()

	// Already prepared with identical SQL: nothing to do.
	if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
		return ps, nil
	}

	c, err := p.acquire(nil)
	if err != nil {
		return nil, err
	}

	// Return the connection to the available list immediately; it is only
	// needed to validate the statement, and the loop below visits every
	// available connection anyway.
	p.availableConnections = append(p.availableConnections, c)

	// Double check that the statement was not prepared by someone else
	// while we were acquiring the connection (since acquire is not fully
	// blocking now, see createConnectionUnlocked())
	if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
		return ps, nil
	}

	ps, err := c.PrepareEx(name, sql, opts)
	if err != nil {
		return nil, err
	}

	// NOTE(review): this range re-prepares on c as well, since c was just
	// appended above — relies on Conn.PrepareEx being idempotent for the
	// same name/sql; confirm.
	for _, c := range p.availableConnections {
		_, err := c.PrepareEx(name, sql, opts)
		if err != nil {
			return nil, err
		}
	}

	// Currently-acquired connections cannot be prepared here; invalidate
	// them so they are replaced (with the statement) when released.
	p.invalidateAcquired()
	p.preparedStatements[name] = ps

	return ps, err
}
+
// Deallocate releases a prepared statement from all connections in the pool.
func (p *ConnPool) Deallocate(name string) (err error) {
	p.cond.L.Lock()
	defer p.cond.L.Unlock()

	// Only idle connections can be touched directly; acquired connections
	// are handled by invalidateAcquired below (closed on release).
	for _, c := range p.availableConnections {
		if err := c.Deallocate(name); err != nil {
			return err
		}
	}

	p.invalidateAcquired()
	delete(p.preparedStatements, name)

	return nil
}
+
// BeginIso acquires a connection and begins a transaction in isolation mode iso
// on it. When the transaction is closed the connection will be automatically
// released.
func (p *ConnPool) BeginIso(iso string) (*Tx, error) {
	// Loop so a dead connection pulled from the pool is discarded and a
	// fresh one is tried.
	for {
		c, err := p.Acquire()
		if err != nil {
			return nil, err
		}

		tx, err := c.BeginIso(iso)
		if err != nil {
			alive := c.IsAlive()
			p.Release(c)

			// If connection is still alive then the error is not something trying
			// again on a new connection would fix, so just return the error. But
			// if the connection is dead try to acquire a new connection and try
			// again.
			if alive {
				return nil, err
			}
			continue
		}

		// Release the connection back to the pool when the transaction ends.
		tx.AfterClose(p.txAfterClose)
		return tx, nil
	}
}
+
+// Deprecated. Use CopyFrom instead. CopyTo acquires a connection, delegates the
+// call to that connection, and releases the connection.
+func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
+ c, err := p.Acquire()
+ if err != nil {
+ return 0, err
+ }
+ defer p.Release(c)
+
+ return c.CopyTo(tableName, columnNames, rowSrc)
+}
+
+// CopyFrom acquires a connection, delegates the call to that connection, and
+// releases the connection.
+func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
+ c, err := p.Acquire()
+ if err != nil {
+ return 0, err
+ }
+ defer p.Release(c)
+
+ return c.CopyFrom(tableName, columnNames, rowSrc)
+}
diff --git a/vendor/github.com/jackc/pgx/conn_pool_private_test.go b/vendor/github.com/jackc/pgx/conn_pool_private_test.go
new file mode 100644
index 0000000..ef0ec1d
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_pool_private_test.go
@@ -0,0 +1,44 @@
+package pgx
+
+import (
+ "testing"
+)
+
+func compareConnSlices(slice1, slice2 []*Conn) bool {
+ if len(slice1) != len(slice2) {
+ return false
+ }
+ for i, c := range slice1 {
+ if c != slice2[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func TestConnPoolRemoveFromAllConnections(t *testing.T) {
+ t.Parallel()
+ pool := ConnPool{}
+ conn1 := &Conn{}
+ conn2 := &Conn{}
+ conn3 := &Conn{}
+
+ // First element
+ pool.allConnections = []*Conn{conn1, conn2, conn3}
+ pool.removeFromAllConnections(conn1)
+ if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) {
+ t.Fatal("First element test failed")
+ }
+ // Element somewhere in the middle
+ pool.allConnections = []*Conn{conn1, conn2, conn3}
+ pool.removeFromAllConnections(conn2)
+ if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) {
+ t.Fatal("Middle element test failed")
+ }
+ // Last element
+ pool.allConnections = []*Conn{conn1, conn2, conn3}
+ pool.removeFromAllConnections(conn3)
+ if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) {
+ t.Fatal("Last element test failed")
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/conn_pool_test.go b/vendor/github.com/jackc/pgx/conn_pool_test.go
new file mode 100644
index 0000000..ab76bfb
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_pool_test.go
@@ -0,0 +1,982 @@
+package pgx_test
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+)
+
+func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool {
+ config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: maxConnections}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ return pool
+}
+
+func acquireAllConnections(t *testing.T, pool *pgx.ConnPool, maxConnections int) []*pgx.Conn {
+ connections := make([]*pgx.Conn, maxConnections)
+ for i := 0; i < maxConnections; i++ {
+ var err error
+ if connections[i], err = pool.Acquire(); err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ }
+ return connections
+}
+
+func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) {
+ for _, c := range connections {
+ pool.Release(c)
+ }
+}
+
+func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) {
+ startTime := time.Now()
+ c, err := pool.Acquire()
+ return c, time.Since(startTime), err
+}
+
+func TestNewConnPool(t *testing.T) {
+ t.Parallel()
+
+ var numCallbacks int
+ afterConnect := func(c *pgx.Conn) error {
+ numCallbacks++
+ return nil
+ }
+
+ config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatal("Unable to establish connection pool")
+ }
+ defer pool.Close()
+
+ // It initially connects once
+ stat := pool.Stat()
+ if stat.CurrentConnections != 1 {
+ t.Errorf("Expected 1 connection to be established immediately, but %v were", numCallbacks)
+ }
+
+ // Pool creation returns an error if any AfterConnect callback does
+ errAfterConnect := errors.New("Some error")
+ afterConnect = func(c *pgx.Conn) error {
+ return errAfterConnect
+ }
+
+ config = pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect}
+ pool, err = pgx.NewConnPool(config)
+ if err != errAfterConnect {
+ t.Errorf("Expected errAfterConnect but received unexpected: %v", err)
+ }
+}
+
+func TestNewConnPoolDefaultsTo5MaxConnections(t *testing.T) {
+ t.Parallel()
+
+ config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatal("Unable to establish connection pool")
+ }
+ defer pool.Close()
+
+ if n := pool.Stat().MaxConnections; n != 5 {
+ t.Fatalf("Expected pool to default to 5 max connections, but it was %d", n)
+ }
+}
+
+func TestPoolAcquireAndReleaseCycle(t *testing.T) {
+ t.Parallel()
+
+ maxConnections := 2
+ incrementCount := int32(100)
+ completeSync := make(chan int)
+ pool := createConnPool(t, maxConnections)
+ defer pool.Close()
+
+ allConnections := acquireAllConnections(t, pool, maxConnections)
+
+ for _, c := range allConnections {
+ mustExec(t, c, "create temporary table t(counter integer not null)")
+ mustExec(t, c, "insert into t(counter) values(0);")
+ }
+
+ releaseAllConnections(pool, allConnections)
+
+ f := func() {
+ conn, err := pool.Acquire()
+ if err != nil {
+ t.Fatal("Unable to acquire connection")
+ }
+ defer pool.Release(conn)
+
+ // Increment counter...
+ mustExec(t, conn, "update t set counter = counter + 1")
+ completeSync <- 0
+ }
+
+ for i := int32(0); i < incrementCount; i++ {
+ go f()
+ }
+
+ // Wait for all f() to complete
+ for i := int32(0); i < incrementCount; i++ {
+ <-completeSync
+ }
+
+ // Check that temp table in each connection has been incremented some number of times
+ actualCount := int32(0)
+ allConnections = acquireAllConnections(t, pool, maxConnections)
+
+ for _, c := range allConnections {
+ var n int32
+ c.QueryRow("select counter from t").Scan(&n)
+ if n == 0 {
+ t.Error("A connection was never used")
+ }
+
+ actualCount += n
+ }
+
+ if actualCount != incrementCount {
+ fmt.Println(actualCount)
+ t.Error("Wrong number of increments")
+ }
+
+ releaseAllConnections(pool, allConnections)
+}
+
+func TestPoolNonBlockingConnections(t *testing.T) {
+ t.Parallel()
+
+ var dialCountLock sync.Mutex
+ dialCount := 0
+ openTimeout := 1 * time.Second
+ testDialer := func(network, address string) (net.Conn, error) {
+ var firstDial bool
+ dialCountLock.Lock()
+ dialCount++
+ firstDial = dialCount == 1
+ dialCountLock.Unlock()
+
+ if firstDial {
+ return net.Dial(network, address)
+ } else {
+ time.Sleep(openTimeout)
+ return nil, &net.OpError{Op: "dial", Net: "tcp"}
+ }
+ }
+
+ maxConnections := 3
+ config := pgx.ConnPoolConfig{
+ ConnConfig: *defaultConnConfig,
+ MaxConnections: maxConnections,
+ }
+ config.ConnConfig.Dial = testDialer
+
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Expected NewConnPool not to fail, instead it failed with: %v", err)
+ }
+ defer pool.Close()
+
+ // NewConnPool establishes an initial connection
+ // so we need to close that for the rest of the test
+ if conn, err := pool.Acquire(); err == nil {
+ conn.Close()
+ pool.Release(conn)
+ } else {
+ t.Fatalf("pool.Acquire unexpectedly failed: %v", err)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(maxConnections)
+
+ startedAt := time.Now()
+ for i := 0; i < maxConnections; i++ {
+ go func() {
+ _, err := pool.Acquire()
+ wg.Done()
+ if err == nil {
+ t.Fatal("Acquire() expected to fail but it did not")
+ }
+ }()
+ }
+ wg.Wait()
+
+ // Prior to createConnectionUnlocked() use the test took
+ // maxConnections * openTimeout seconds to complete.
+ // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds.
+ timeTaken := time.Since(startedAt)
+ if timeTaken > openTimeout+1*time.Second {
+ t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken)
+ }
+
+}
+
+func TestAcquireTimeoutSanity(t *testing.T) {
+ t.Parallel()
+
+ config := pgx.ConnPoolConfig{
+ ConnConfig: *defaultConnConfig,
+ MaxConnections: 1,
+ }
+
+ // case 1: default 0 value
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Expected NewConnPool with default config.AcquireTimeout not to fail, instead it failed with '%v'", err)
+ }
+ pool.Close()
+
+ // case 2: negative value
+ config.AcquireTimeout = -1 * time.Second
+ _, err = pgx.NewConnPool(config)
+ if err == nil {
+ t.Fatal("Expected NewConnPool with negative config.AcquireTimeout to fail, instead it did not")
+ }
+
+ // case 3: positive value
+ config.AcquireTimeout = 1 * time.Second
+ pool, err = pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Expected NewConnPool with positive config.AcquireTimeout not to fail, instead it failed with '%v'", err)
+ }
+ defer pool.Close()
+}
+
+func TestPoolWithAcquireTimeoutSet(t *testing.T) {
+ t.Parallel()
+
+ connAllocTimeout := 2 * time.Second
+ config := pgx.ConnPoolConfig{
+ ConnConfig: *defaultConnConfig,
+ MaxConnections: 1,
+ AcquireTimeout: connAllocTimeout,
+ }
+
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ // Consume all connections ...
+ allConnections := acquireAllConnections(t, pool, config.MaxConnections)
+ defer releaseAllConnections(pool, allConnections)
+
+ // ... then try to consume 1 more. It should fail after a short timeout.
+ _, timeTaken, err := acquireWithTimeTaken(pool)
+
+ if err == nil || err != pgx.ErrAcquireTimeout {
+ t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
+ }
+ if timeTaken < connAllocTimeout {
+ t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken)
+ }
+}
+
+func TestPoolWithoutAcquireTimeoutSet(t *testing.T) {
+ t.Parallel()
+
+ maxConnections := 1
+ pool := createConnPool(t, maxConnections)
+ defer pool.Close()
+
+ // Consume all connections ...
+ allConnections := acquireAllConnections(t, pool, maxConnections)
+
+ // ... then try to consume 1 more. It should hang forever.
+ // To unblock it we release the previously taken connection in a goroutine.
+ stopDeadWaitTimeout := 5 * time.Second
+ timer := time.AfterFunc(stopDeadWaitTimeout, func() {
+ releaseAllConnections(pool, allConnections)
+ })
+ defer timer.Stop()
+
+ conn, timeTaken, err := acquireWithTimeTaken(pool)
+ if err == nil {
+ pool.Release(conn)
+ } else {
+ t.Fatalf("Expected error to be nil, instead it was '%v'", err)
+ }
+ if timeTaken < stopDeadWaitTimeout {
+ t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken)
+ }
+}
+
+func TestPoolReleaseWithTransactions(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ conn, err := pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ mustExec(t, conn, "begin")
+ if _, err = conn.Exec("selct"); err == nil {
+ t.Fatal("Did not receive expected error")
+ }
+
+ if conn.TxStatus != 'E' {
+ t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus)
+ }
+
+ pool.Release(conn)
+
+ if conn.TxStatus != 'I' {
+ t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus)
+ }
+
+ conn, err = pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ mustExec(t, conn, "begin")
+ if conn.TxStatus != 'T' {
+ t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus)
+ }
+
+ pool.Release(conn)
+
+ if conn.TxStatus != 'I' {
+ t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus)
+ }
+}
+
+func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
+ t.Parallel()
+
+ maxConnections := 3
+ pool := createConnPool(t, maxConnections)
+ defer pool.Close()
+
+ doSomething := func() {
+ c, err := pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to Acquire: %v", err)
+ }
+ rows, _ := c.Query("select 1, pg_sleep(0.02)")
+ rows.Close()
+ pool.Release(c)
+ }
+
+ for i := 0; i < 10; i++ {
+ doSomething()
+ }
+
+ stat := pool.Stat()
+ if stat.CurrentConnections != 1 {
+ t.Fatalf("Pool shouldn't have established more connections when no contention: %v", stat.CurrentConnections)
+ }
+
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doSomething()
+ }()
+ }
+ wg.Wait()
+
+ stat = pool.Stat()
+ if stat.CurrentConnections != stat.MaxConnections {
+ t.Fatalf("Pool should have used all possible connections: %v", stat.CurrentConnections)
+ }
+}
+
+func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
+ t.Parallel()
+
+ // Run timing sensitive test many times
+ for i := 0; i < 50; i++ {
+ func() {
+ maxConnections := 3
+ pool := createConnPool(t, maxConnections)
+ defer pool.Close()
+
+ var c1, c2 *pgx.Conn
+ var err error
+ var stat pgx.ConnPoolStat
+
+ if c1, err = pool.Acquire(); err != nil {
+ t.Fatalf("Unexpected error acquiring connection: %v", err)
+ }
+ defer func() {
+ if c1 != nil {
+ pool.Release(c1)
+ }
+ }()
+
+ if c2, err = pool.Acquire(); err != nil {
+ t.Fatalf("Unexpected error acquiring connection: %v", err)
+ }
+ defer func() {
+ if c2 != nil {
+ pool.Release(c2)
+ }
+ }()
+
+ if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil {
+ t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
+ }
+
+ // do something with the connection so it knows it's dead
+ rows, _ := c1.Query("select 1")
+ rows.Close()
+ if rows.Err() == nil {
+ t.Fatal("Expected error but none occurred")
+ }
+
+ if c1.IsAlive() {
+ t.Fatal("Expected connection to be dead but it wasn't")
+ }
+
+ stat = pool.Stat()
+ if stat.CurrentConnections != 2 {
+ t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections)
+ }
+ if stat.AvailableConnections != 0 {
+ t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections)
+ }
+
+ pool.Release(c1)
+ c1 = nil // so it doesn't get released again by the defer
+
+ stat = pool.Stat()
+ if stat.CurrentConnections != 1 {
+ t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections)
+ }
+ if stat.AvailableConnections != 0 {
+ t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections)
+ }
+ }()
+ }
+}
+
+func TestConnPoolResetClosesCheckedOutConnectionsOnRelease(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 5)
+ defer pool.Close()
+
+ inProgressRows := []*pgx.Rows{}
+ var inProgressPIDs []int32
+
+ // Start some queries and reset pool while they are in progress
+ for i := 0; i < 10; i++ {
+ rows, err := pool.Query("select pg_backend_pid() union all select 1 union all select 2")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rows.Next()
+ var pid int32
+ rows.Scan(&pid)
+ inProgressPIDs = append(inProgressPIDs, pid)
+
+ inProgressRows = append(inProgressRows, rows)
+ pool.Reset()
+ }
+
+ // Check that the queries are completed
+ for _, rows := range inProgressRows {
+ var expectedN int32
+
+ for rows.Next() {
+ expectedN++
+ var n int32
+ err := rows.Scan(&n)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if expectedN != n {
+ t.Fatalf("Expected n to be %d, but it was %d", expectedN, n)
+ }
+ }
+
+ if err := rows.Err(); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // pool should be in fresh state due to previous reset
+ stats := pool.Stat()
+ if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ var connCount int
+ err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if connCount != 0 {
+ t.Fatalf("%d connections not closed", connCount)
+ }
+}
+
+func TestConnPoolResetClosesCheckedInConnections(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 5)
+ defer pool.Close()
+
+ inProgressRows := []*pgx.Rows{}
+ var inProgressPIDs []int32
+
+ // Start some queries and reset pool while they are in progress
+ for i := 0; i < 5; i++ {
+ rows, err := pool.Query("select pg_backend_pid()")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ inProgressRows = append(inProgressRows, rows)
+ }
+
+ // Check that the queries are completed
+ for _, rows := range inProgressRows {
+ for rows.Next() {
+ var pid int32
+ err := rows.Scan(&pid)
+ if err != nil {
+ t.Fatal(err)
+ }
+ inProgressPIDs = append(inProgressPIDs, pid)
+
+ }
+
+ if err := rows.Err(); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Ensure pool is fully connected and available
+ stats := pool.Stat()
+ if stats.CurrentConnections != 5 || stats.AvailableConnections != 5 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ pool.Reset()
+
+ // Pool should be empty after reset
+ stats = pool.Stat()
+ if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ var connCount int
+ err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if connCount != 0 {
+ t.Fatalf("%d connections not closed", connCount)
+ }
+}
+
+func TestConnPoolTransaction(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ stats := pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ tx, err := pool.Begin()
+ if err != nil {
+ t.Fatalf("pool.Begin failed: %v", err)
+ }
+ defer tx.Rollback()
+
+ var n int32
+ err = tx.QueryRow("select 40+$1", 2).Scan(&n)
+ if err != nil {
+ t.Fatalf("tx.QueryRow Scan failed: %v", err)
+ }
+ if n != 42 {
+ t.Errorf("Expected 42, got %d", n)
+ }
+
+ stats = pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback failed: %v", err)
+ }
+
+ stats = pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+}
+
+func TestConnPoolTransactionIso(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ tx, err := pool.BeginIso(pgx.Serializable)
+ if err != nil {
+ t.Fatalf("pool.Begin failed: %v", err)
+ }
+ defer tx.Rollback()
+
+ var level string
+ err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&level)
+ if err != nil {
+ t.Fatalf("tx.QueryRow failed: %v", level)
+ }
+
+ if level != "serializable" {
+ t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level)
+ }
+}
+
+func TestConnPoolBeginRetry(t *testing.T) {
+ t.Parallel()
+
+ // Run timing sensitive test many times
+ for i := 0; i < 50; i++ {
+ func() {
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ killerConn, err := pool.Acquire()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer pool.Release(killerConn)
+
+ victimConn, err := pool.Acquire()
+ if err != nil {
+ t.Fatal(err)
+ }
+ pool.Release(victimConn)
+
+ // Terminate connection that was released to pool
+ if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil {
+ t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
+ }
+
+ // Since victimConn is the only available connection in the pool, pool.Begin should
+ // try to use it, fail, and allocate another connection
+ tx, err := pool.Begin()
+ if err != nil {
+ t.Fatalf("pool.Begin failed: %v", err)
+ }
+ defer tx.Rollback()
+
+ var txPid int32
+ err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid)
+ if err != nil {
+ t.Fatalf("tx.QueryRow Scan failed: %v", err)
+ }
+ if txPid == victimConn.Pid {
+ t.Error("Expected txPid to defer from killed conn pid, but it didn't")
+ }
+ }()
+ }
+}
+
+func TestConnPoolQuery(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ var sum, rowCount int32
+
+ rows, err := pool.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("pool.Query failed: %v", err)
+ }
+
+ stats := pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+
+ for rows.Next() {
+ var n int32
+ rows.Scan(&n)
+ sum += n
+ rowCount++
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ if rowCount != 10 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+ if sum != 55 {
+ t.Error("Wrong values returned")
+ }
+
+ stats = pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+}
+
+func TestConnPoolQueryConcurrentLoad(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 10)
+ defer pool.Close()
+
+ n := 100
+ done := make(chan bool)
+
+ for i := 0; i < n; i++ {
+ go func() {
+ defer func() { done <- true }()
+ var rowCount int32
+
+ rows, err := pool.Query("select generate_series(1,$1)", 1000)
+ if err != nil {
+ t.Fatalf("pool.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var n int32
+ err = rows.Scan(&n)
+ if err != nil {
+ t.Fatalf("rows.Scan failed: %v", err)
+ }
+ if n != rowCount+1 {
+ t.Fatalf("Expected n to be %d, but it was %d", rowCount+1, n)
+ }
+ rowCount++
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("conn.Query failed: %v", rows.Err())
+ }
+
+ if rowCount != 1000 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+
+ _, err = pool.Exec("--;")
+ if err != nil {
+ t.Fatalf("pool.Exec failed: %v", err)
+ }
+ }()
+ }
+
+ for i := 0; i < n; i++ {
+ <-done
+ }
+}
+
+func TestConnPoolQueryRow(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ var n int32
+ err := pool.QueryRow("select 40+$1", 2).Scan(&n)
+ if err != nil {
+ t.Fatalf("pool.QueryRow Scan failed: %v", err)
+ }
+
+ if n != 42 {
+ t.Errorf("Expected 42, got %d", n)
+ }
+
+ stats := pool.Stat()
+ if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 {
+ t.Fatalf("Unexpected connection pool stats: %v", stats)
+ }
+}
+
+func TestConnPoolExec(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ results, err := pool.Exec("create temporary table foo(id integer primary key);")
+ if err != nil {
+ t.Fatalf("Unexpected error from pool.Exec: %v", err)
+ }
+ if results != "CREATE TABLE" {
+ t.Errorf("Unexpected results from Exec: %v", results)
+ }
+
+ results, err = pool.Exec("insert into foo(id) values($1)", 1)
+ if err != nil {
+ t.Fatalf("Unexpected error from pool.Exec: %v", err)
+ }
+ if results != "INSERT 0 1" {
+ t.Errorf("Unexpected results from Exec: %v", results)
+ }
+
+ results, err = pool.Exec("drop table foo;")
+ if err != nil {
+ t.Fatalf("Unexpected error from pool.Exec: %v", err)
+ }
+ if results != "DROP TABLE" {
+ t.Errorf("Unexpected results from Exec: %v", results)
+ }
+}
+
+func TestConnPoolPrepare(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ _, err := pool.Prepare("test", "select $1::varchar")
+ if err != nil {
+ t.Fatalf("Unable to prepare statement: %v", err)
+ }
+
+ var s string
+ err = pool.QueryRow("test", "hello").Scan(&s)
+ if err != nil {
+ t.Errorf("Executing prepared statement failed: %v", err)
+ }
+
+ if s != "hello" {
+ t.Errorf("Prepared statement did not return expected value: %v", s)
+ }
+
+ err = pool.Deallocate("test")
+ if err != nil {
+ t.Errorf("Deallocate failed: %v", err)
+ }
+
+ err = pool.QueryRow("test", "hello").Scan(&s)
+ if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") {
+ t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
+ }
+}
+
+func TestConnPoolPrepareDeallocatePrepare(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ _, err := pool.Prepare("test", "select $1::varchar")
+ if err != nil {
+ t.Fatalf("Unable to prepare statement: %v", err)
+ }
+ err = pool.Deallocate("test")
+ if err != nil {
+ t.Fatalf("Unable to deallocate statement: %v", err)
+ }
+ _, err = pool.Prepare("test", "select $1::varchar")
+ if err != nil {
+ t.Fatalf("Unable to prepare statement: %v", err)
+ }
+
+ var s string
+ err = pool.QueryRow("test", "hello").Scan(&s)
+ if err != nil {
+ t.Fatalf("Executing prepared statement failed: %v", err)
+ }
+
+ if s != "hello" {
+ t.Errorf("Prepared statement did not return expected value: %v", s)
+ }
+}
+
+func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) {
+ t.Parallel()
+
+ pool := createConnPool(t, 2)
+ defer pool.Close()
+
+ testPreparedStatement := func(db queryRower, desc string) {
+ var s string
+ err := db.QueryRow("test", "hello").Scan(&s)
+ if err != nil {
+ t.Fatalf("%s. Executing prepared statement failed: %v", desc, err)
+ }
+
+ if s != "hello" {
+ t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s)
+ }
+ }
+
+ newReleaseOnce := func(c *pgx.Conn) func() {
+ var once sync.Once
+ return func() {
+ once.Do(func() { pool.Release(c) })
+ }
+ }
+
+ c1, err := pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ c1Release := newReleaseOnce(c1)
+ defer c1Release()
+
+ _, err = pool.Prepare("test", "select $1::varchar")
+ if err != nil {
+ t.Fatalf("Unable to prepare statement: %v", err)
+ }
+
+ testPreparedStatement(pool, "pool")
+
+ c1Release()
+
+ c2, err := pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ c2Release := newReleaseOnce(c2)
+ defer c2Release()
+
+ // This conn will not be available and will be connection at this point
+ c3, err := pool.Acquire()
+ if err != nil {
+ t.Fatalf("Unable to acquire connection: %v", err)
+ }
+ c3Release := newReleaseOnce(c3)
+ defer c3Release()
+
+ testPreparedStatement(c2, "c2")
+ testPreparedStatement(c3, "c3")
+
+ c2Release()
+ c3Release()
+
+ err = pool.Deallocate("test")
+ if err != nil {
+ t.Errorf("Deallocate failed: %v", err)
+ }
+
+ var s string
+ err = pool.QueryRow("test", "hello").Scan(&s)
+ if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") {
+ t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err)
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/conn_test.go b/vendor/github.com/jackc/pgx/conn_test.go
new file mode 100644
index 0000000..cfb9956
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/conn_test.go
@@ -0,0 +1,1744 @@
+package pgx_test
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+)
+
+func TestConnect(t *testing.T) {
+ t.Parallel()
+
+ conn, err := pgx.Connect(*defaultConnConfig)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+
+ if _, present := conn.RuntimeParams["server_version"]; !present {
+ t.Error("Runtime parameters not stored")
+ }
+
+ if conn.Pid == 0 {
+ t.Error("Backend PID not stored")
+ }
+
+ if conn.SecretKey == 0 {
+ t.Error("Backend secret key not stored")
+ }
+
+ var currentDB string
+ err = conn.QueryRow("select current_database()").Scan(&currentDB)
+ if err != nil {
+ t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
+ }
+ if currentDB != defaultConnConfig.Database {
+ t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
+ }
+
+ var user string
+ err = conn.QueryRow("select current_user").Scan(&user)
+ if err != nil {
+ t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
+ }
+ if user != defaultConnConfig.User {
+ t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithUnixSocketDirectory(t *testing.T) {
+ t.Parallel()
+
+ // /.s.PGSQL.5432
+ if unixSocketConnConfig == nil {
+ t.Skip("Skipping due to undefined unixSocketConnConfig")
+ }
+
+ conn, err := pgx.Connect(*unixSocketConnConfig)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithUnixSocketFile(t *testing.T) {
+ t.Parallel()
+
+ if unixSocketConnConfig == nil {
+ t.Skip("Skipping due to undefined unixSocketConnConfig")
+ }
+
+ connParams := *unixSocketConnConfig
+ connParams.Host = connParams.Host + "/.s.PGSQL.5432"
+ conn, err := pgx.Connect(connParams)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithTcp(t *testing.T) {
+ t.Parallel()
+
+ if tcpConnConfig == nil {
+ t.Skip("Skipping due to undefined tcpConnConfig")
+ }
+
+ conn, err := pgx.Connect(*tcpConnConfig)
+ if err != nil {
+ t.Fatal("Unable to establish connection: " + err.Error())
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithTLS(t *testing.T) {
+ t.Parallel()
+
+ if tlsConnConfig == nil {
+ t.Skip("Skipping due to undefined tlsConnConfig")
+ }
+
+ conn, err := pgx.Connect(*tlsConnConfig)
+ if err != nil {
+ t.Fatal("Unable to establish connection: " + err.Error())
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithInvalidUser(t *testing.T) {
+ t.Parallel()
+
+ if invalidUserConnConfig == nil {
+ t.Skip("Skipping due to undefined invalidUserConnConfig")
+ }
+
+ _, err := pgx.Connect(*invalidUserConnConfig)
+ pgErr, ok := err.(pgx.PgError)
+ if !ok {
+ t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err)
+ }
+ if pgErr.Code != "28000" && pgErr.Code != "28P01" {
+ t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
+ }
+}
+
+func TestConnectWithPlainTextPassword(t *testing.T) {
+ t.Parallel()
+
+ if plainPasswordConnConfig == nil {
+ t.Skip("Skipping due to undefined plainPasswordConnConfig")
+ }
+
+ conn, err := pgx.Connect(*plainPasswordConnConfig)
+ if err != nil {
+ t.Fatal("Unable to establish connection: " + err.Error())
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithMD5Password(t *testing.T) {
+ t.Parallel()
+
+ if md5ConnConfig == nil {
+ t.Skip("Skipping due to undefined md5ConnConfig")
+ }
+
+ conn, err := pgx.Connect(*md5ConnConfig)
+ if err != nil {
+ t.Fatal("Unable to establish connection: " + err.Error())
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithTLSFallback(t *testing.T) {
+ t.Parallel()
+
+ if tlsConnConfig == nil {
+ t.Skip("Skipping due to undefined tlsConnConfig")
+ }
+
+ connConfig := *tlsConnConfig
+ connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
+
+ conn, err := pgx.Connect(connConfig)
+ if err == nil {
+ t.Fatal("Expected failed connection, but succeeded")
+ }
+
+ connConfig.UseFallbackTLS = true
+ connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
+
+ conn, err = pgx.Connect(connConfig)
+ if err != nil {
+ t.Fatal("Unable to establish connection: " + err.Error())
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithConnectionRefused(t *testing.T) {
+ t.Parallel()
+
+ // Presumably nothing is listening on 127.0.0.1:1
+ bad := *defaultConnConfig
+ bad.Host = "127.0.0.1"
+ bad.Port = 1
+
+ _, err := pgx.Connect(bad)
+ if err == nil {
+ t.Fatal("Expected error establishing connection to bad port")
+ }
+}
+
+func TestConnectCustomDialer(t *testing.T) {
+ t.Parallel()
+
+ if customDialerConnConfig == nil {
+ t.Skip("Skipping due to undefined customDialerConnConfig")
+ }
+
+ dialled := false
+ conf := *customDialerConnConfig
+ conf.Dial = func(network, address string) (net.Conn, error) {
+ dialled = true
+ return net.Dial(network, address)
+ }
+
+ conn, err := pgx.Connect(conf)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %s", err)
+ }
+ if !dialled {
+ t.Fatal("Connect did not use custom dialer")
+ }
+
+ err = conn.Close()
+ if err != nil {
+ t.Fatal("Unable to close connection")
+ }
+}
+
+func TestConnectWithRuntimeParams(t *testing.T) {
+ t.Parallel()
+
+ connConfig := *defaultConnConfig
+ connConfig.RuntimeParams = map[string]string{
+ "application_name": "pgxtest",
+ "search_path": "myschema",
+ }
+
+ conn, err := pgx.Connect(connConfig)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+ defer conn.Close()
+
+ var s string
+ err = conn.QueryRow("show application_name").Scan(&s)
+ if err != nil {
+ t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
+ }
+ if s != "pgxtest" {
+ t.Errorf("Expected application_name to be %s, but it was %s", "pgxtest", s)
+ }
+
+ err = conn.QueryRow("show search_path").Scan(&s)
+ if err != nil {
+ t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
+ }
+ if s != "myschema" {
+ t.Errorf("Expected search_path to be %s, but it was %s", "myschema", s)
+ }
+}
+
+// TestParseURI exercises pgx.ParseURI with a table of connection URLs,
+// checking the resulting ConnConfig (including TLS fallback behavior per
+// sslmode and query-string runtime params) field-for-field.
+func TestParseURI(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		url        string
+		connParams pgx.ConnConfig
+	}{
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
+			connParams: pgx.ConnConfig{
+				User:              "jack",
+				Password:          "secret",
+				Host:              "localhost",
+				Port:              5432,
+				Database:          "mydb",
+				TLSConfig:         nil,
+				UseFallbackTLS:    false,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgresql://jack:secret@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams: map[string]string{
+					"application_name": "pgxtest",
+					"search_path":      "myschema",
+				},
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		connParams, err := pgx.ParseURI(tt.url)
+		if err != nil {
+			// Fixed: message previously named pgx.ParseURL, which is not
+			// the function under test.
+			t.Errorf("%d. Unexpected error from pgx.ParseURI(%q) => %v", i, tt.url, err)
+			continue
+		}
+
+		if !reflect.DeepEqual(connParams, tt.connParams) {
+			t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
+		}
+	}
+}
+
+// TestParseDSN exercises pgx.ParseDSN with libpq-style key=value
+// connection strings, checking the full resulting ConnConfig including
+// the TLS configuration implied by each sslmode and any extra key=value
+// pairs collected into RuntimeParams.
+func TestParseDSN(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		url        string
+		connParams pgx.ConnConfig
+	}{
+		{
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
+			connParams: pgx.ConnConfig{
+				User:          "jack",
+				Password:      "secret",
+				Host:          "localhost",
+				Port:          5432,
+				Database:      "mydb",
+				RuntimeParams: map[string]string{},
+			},
+		},
+		{
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			// No sslmode given: expected to behave like sslmode=prefer.
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack host=localhost port=5432 dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack host=localhost dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			// Unknown keys should land in RuntimeParams.
+			url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams: map[string]string{
+					"application_name": "pgxtest",
+					"search_path":      "myschema",
+				},
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		connParams, err := pgx.ParseDSN(tt.url)
+		if err != nil {
+			t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
+			continue
+		}
+
+		if !reflect.DeepEqual(connParams, tt.connParams) {
+			t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
+		}
+	}
+}
+
+// TestParseConnectionString exercises pgx.ParseConnectionString, which
+// must accept both URL-style (postgres:// / postgresql://) and
+// libpq-style key=value connection strings. The expectations mirror the
+// dedicated TestParseURI and TestParseDSN tables.
+func TestParseConnectionString(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		url        string
+		connParams pgx.ConnConfig
+	}{
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
+			connParams: pgx.ConnConfig{
+				User:              "jack",
+				Password:          "secret",
+				Host:              "localhost",
+				Port:              5432,
+				Database:          "mydb",
+				TLSConfig:         nil,
+				UseFallbackTLS:    false,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack:secret@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgresql://jack:secret@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost:5432/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost/mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams: map[string]string{
+					"application_name": "pgxtest",
+					"search_path":      "myschema",
+				},
+			},
+		},
+		{
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
+			connParams: pgx.ConnConfig{
+				User:          "jack",
+				Password:      "secret",
+				Host:          "localhost",
+				Port:          5432,
+				Database:      "mydb",
+				RuntimeParams: map[string]string{},
+			},
+		},
+		{
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack password=secret host=localhost port=5432 dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Password: "secret",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack host=localhost port=5432 dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Port:     5432,
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack host=localhost dbname=mydb",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema",
+			connParams: pgx.ConnConfig{
+				User:     "jack",
+				Host:     "localhost",
+				Database: "mydb",
+				TLSConfig: &tls.Config{
+					InsecureSkipVerify: true,
+				},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams: map[string]string{
+					"application_name": "pgxtest",
+					"search_path":      "myschema",
+				},
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		connParams, err := pgx.ParseConnectionString(tt.url)
+		if err != nil {
+			// Fixed: message previously named pgx.ParseDSN, copied from the
+			// TestParseDSN loop, but this test calls ParseConnectionString.
+			t.Errorf("%d. Unexpected error from pgx.ParseConnectionString(%q) => %v", i, tt.url, err)
+			continue
+		}
+
+		if !reflect.DeepEqual(connParams, tt.connParams) {
+			t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
+		}
+	}
+}
+
+// TestParseEnvLibpq exercises pgx.ParseEnvLibpq against combinations of
+// libpq environment variables (PGHOST, PGPORT, PGDATABASE, PGUSER,
+// PGPASSWORD, PGAPPNAME, PGSSLMODE). It deliberately does NOT call
+// t.Parallel(): it mutates the process environment, which would race
+// with any concurrently running test.
+func TestParseEnvLibpq(t *testing.T) {
+	pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"}
+
+	// Save and restore the environment so this test is side-effect free
+	// for the rest of the suite.
+	savedEnv := make(map[string]string)
+	for _, n := range pgEnvvars {
+		savedEnv[n] = os.Getenv(n)
+	}
+	defer func() {
+		for k, v := range savedEnv {
+			err := os.Setenv(k, v)
+			if err != nil {
+				t.Fatalf("Unable to restore environment: %v", err)
+			}
+		}
+	}()
+
+	tests := []struct {
+		name    string
+		envvars map[string]string
+		config  pgx.ConnConfig
+	}{
+		{
+			name:    "No environment",
+			envvars: map[string]string{},
+			config: pgx.ConnConfig{
+				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			name: "Normal PG vars",
+			envvars: map[string]string{
+				"PGHOST":     "123.123.123.123",
+				"PGPORT":     "7777",
+				"PGDATABASE": "foo",
+				"PGUSER":     "bar",
+				"PGPASSWORD": "baz",
+			},
+			config: pgx.ConnConfig{
+				Host:              "123.123.123.123",
+				Port:              7777,
+				Database:          "foo",
+				User:              "bar",
+				Password:          "baz",
+				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			name: "application_name",
+			envvars: map[string]string{
+				"PGAPPNAME": "pgxtest",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{"application_name": "pgxtest"},
+			},
+		},
+		{
+			name: "sslmode=disable",
+			envvars: map[string]string{
+				"PGSSLMODE": "disable",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:      nil,
+				UseFallbackTLS: false,
+				RuntimeParams:  map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=allow",
+			envvars: map[string]string{
+				"PGSSLMODE": "allow",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:         nil,
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true},
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=prefer",
+			envvars: map[string]string{
+				"PGSSLMODE": "prefer",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
+				UseFallbackTLS:    true,
+				FallbackTLSConfig: nil,
+				RuntimeParams:     map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=require",
+			envvars: map[string]string{
+				"PGSSLMODE": "require",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:      &tls.Config{},
+				UseFallbackTLS: false,
+				RuntimeParams:  map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=verify-ca",
+			envvars: map[string]string{
+				"PGSSLMODE": "verify-ca",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:      &tls.Config{},
+				UseFallbackTLS: false,
+				RuntimeParams:  map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=verify-full",
+			envvars: map[string]string{
+				"PGSSLMODE": "verify-full",
+			},
+			config: pgx.ConnConfig{
+				TLSConfig:      &tls.Config{},
+				UseFallbackTLS: false,
+				RuntimeParams:  map[string]string{},
+			},
+		},
+		{
+			name: "sslmode=verify-full with host",
+			envvars: map[string]string{
+				"PGHOST":    "pgx.example",
+				"PGSSLMODE": "verify-full",
+			},
+			config: pgx.ConnConfig{
+				Host: "pgx.example",
+				TLSConfig: &tls.Config{
+					ServerName: "pgx.example",
+				},
+				UseFallbackTLS: false,
+				RuntimeParams:  map[string]string{},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		for _, n := range pgEnvvars {
+			err := os.Unsetenv(n)
+			if err != nil {
+				t.Fatalf("%s: Unable to clear environment: %v", tt.name, err)
+			}
+		}
+
+		for k, v := range tt.envvars {
+			err := os.Setenv(k, v)
+			if err != nil {
+				t.Fatalf("%s: Unable to set environment: %v", tt.name, err)
+			}
+		}
+
+		config, err := pgx.ParseEnvLibpq()
+		if err != nil {
+			// Fixed: message previously said pgx.ParseLibpq(), which is not
+			// the function under test.
+			t.Errorf("%s: Unexpected error from pgx.ParseEnvLibpq() => %v", tt.name, err)
+			continue
+		}
+
+		if config.Host != tt.config.Host {
+			t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
+		}
+		if config.Port != tt.config.Port {
+			t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
+		}
+		// Fixed: this comparison duplicated the Port check above, so the
+		// Database field parsed from PGDATABASE was never verified.
+		if config.Database != tt.config.Database {
+			t.Errorf("%s: expected Database to be %v got %v", tt.name, tt.config.Database, config.Database)
+		}
+		if config.User != tt.config.User {
+			t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
+		}
+		if config.Password != tt.config.Password {
+			t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
+		}
+
+		if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
+			t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
+		}
+
+		// TLS configs cannot be compared with DeepEqual (they contain
+		// function fields), so compare only the fields this test controls.
+		tlsTests := []struct {
+			name     string
+			expected *tls.Config
+			actual   *tls.Config
+		}{
+			{
+				name:     "TLSConfig",
+				expected: tt.config.TLSConfig,
+				actual:   config.TLSConfig,
+			},
+			{
+				name:     "FallbackTLSConfig",
+				expected: tt.config.FallbackTLSConfig,
+				actual:   config.FallbackTLSConfig,
+			},
+		}
+		for _, tlsTest := range tlsTests {
+			name := tlsTest.name
+			expected := tlsTest.expected
+			actual := tlsTest.actual
+
+			if expected == nil && actual != nil {
+				t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
+			} else if expected != nil && actual == nil {
+				t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
+			} else if expected != nil && actual != nil {
+				if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
+					t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
+				}
+
+				if actual.ServerName != expected.ServerName {
+					t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
+				}
+			}
+		}
+
+		if config.UseFallbackTLS != tt.config.UseFallbackTLS {
+			t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
+		}
+	}
+}
+
+// TestExec covers Conn.Exec command-tag behavior: DDL and parameterized
+// DML tags, multi-statement batches (last tag wins), SQL longer than the
+// connection's shared buffer, and statements producing no tag at all.
+func TestExec(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" {
+		t.Error("Unexpected results from Exec")
+	}
+
+	// Accept parameters
+	if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" {
+		t.Errorf("Unexpected results from Exec: %v", results)
+	}
+
+	if results := mustExec(t, conn, "drop table foo;"); results != "DROP TABLE" {
+		t.Error("Unexpected results from Exec")
+	}
+
+	// Multiple statements can be executed -- last command tag is returned
+	if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" {
+		t.Error("Unexpected results from Exec")
+	}
+
+	// Can execute longer SQL strings than sharedBufferSize
+	if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" {
+		t.Errorf("Unexpected results from Exec: %v", results)
+	}
+
+	// Exec no-op which does not return a command tag
+	if results := mustExec(t, conn, "--;"); results != "" {
+		t.Errorf("Unexpected results from Exec: %v", results)
+	}
+}
+
+// TestExecFailure checks that a syntax error from Exec is surfaced and,
+// crucially, that the connection remains usable for a follow-up query.
+func TestExecFailure(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	// "selct" is an intentional typo to force a SQL syntax error.
+	if _, err := conn.Exec("selct;"); err == nil {
+		t.Fatal("Expected SQL syntax error")
+	}
+
+	rows, _ := conn.Query("select 1")
+	rows.Close()
+	if rows.Err() != nil {
+		t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
+	}
+}
+
+// TestPrepare covers the Prepare / execute / Deallocate life cycle, and
+// verifies that a statement name can be reused with different SQL after
+// deallocation.
+func TestPrepare(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	_, err := conn.Prepare("test", "select $1::varchar")
+	if err != nil {
+		t.Errorf("Unable to prepare statement: %v", err)
+		return
+	}
+
+	var s string
+	err = conn.QueryRow("test", "hello").Scan(&s)
+	if err != nil {
+		t.Errorf("Executing prepared statement failed: %v", err)
+	}
+
+	if s != "hello" {
+		t.Errorf("Prepared statement did not return expected value: %v", s)
+	}
+
+	err = conn.Deallocate("test")
+	if err != nil {
+		t.Errorf("conn.Deallocate failed: %v", err)
+	}
+
+	// Create another prepared statement to ensure Deallocate left the connection
+	// in a working state and that we can reuse the prepared statement name.
+
+	_, err = conn.Prepare("test", "select $1::integer")
+	if err != nil {
+		t.Errorf("Unable to prepare statement: %v", err)
+		return
+	}
+
+	var n int32
+	err = conn.QueryRow("test", int32(1)).Scan(&n)
+	if err != nil {
+		t.Errorf("Executing prepared statement failed: %v", err)
+	}
+
+	if n != 1 {
+		// Fixed: this message previously printed the stale string s from
+		// the first statement instead of the int32 n being checked.
+		t.Errorf("Prepared statement did not return expected value: %v", n)
+	}
+
+	err = conn.Deallocate("test")
+	if err != nil {
+		t.Errorf("conn.Deallocate failed: %v", err)
+	}
+}
+
+// TestPrepareBadSQLFailure checks that preparing invalid SQL returns an
+// error and does not leave the connection in a broken state.
+func TestPrepareBadSQLFailure(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	_, err := conn.Prepare("badSQL", "select foo")
+	if err == nil {
+		t.Fatal("Prepare should have failed with syntax error")
+	}
+
+	// The failed Prepare must not poison the connection.
+	ensureConnValid(t, conn)
+}
+
+// TestPrepareQueryManyParameters probes the wire-protocol limit on bind
+// parameters: the parameter count is a 16-bit field, so 65535 must work
+// and 65536+ must fail at Prepare time.
+func TestPrepareQueryManyParameters(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		count   int
+		succeed bool
+	}{
+		{
+			count:   65534,
+			succeed: true,
+		},
+		{
+			count:   65535,
+			succeed: true,
+		},
+		{
+			count:   65536,
+			succeed: false,
+		},
+		{
+			count:   65537,
+			succeed: false,
+		},
+	}
+
+	for i, tt := range tests {
+		params := make([]string, 0, tt.count)
+		args := make([]interface{}, 0, tt.count)
+		for j := 0; j < tt.count; j++ {
+			params = append(params, fmt.Sprintf("($%d::text)", j+1))
+			args = append(args, strconv.Itoa(j))
+		}
+
+		sql := "values" + strings.Join(params, ", ")
+
+		psName := fmt.Sprintf("manyParams%d", i)
+		_, err := conn.Prepare(psName, sql)
+		if err != nil {
+			if tt.succeed {
+				t.Errorf("%d. %v", i, err)
+			}
+			continue
+		}
+		if !tt.succeed {
+			t.Errorf("%d. Expected error but succeeded", i)
+			continue
+		}
+
+		rows, err := conn.Query(psName, args...)
+		if err != nil {
+			t.Errorf("conn.Query failed: %v", err)
+			continue
+		}
+
+		for rows.Next() {
+			var s string
+			rows.Scan(&s)
+		}
+
+		if rows.Err() != nil {
+			// Fixed: previously logged err, which is nil here since Query
+			// succeeded; the row-read failure lives in rows.Err().
+			t.Errorf("Reading query result failed: %v", rows.Err())
+		}
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestPrepareIdempotency verifies that preparing the same name with the
+// same SQL twice succeeds (idempotent), while re-preparing the same name
+// with DIFFERENT SQL is rejected.
+func TestPrepareIdempotency(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	// Preparing identical SQL under the same name twice must succeed both
+	// times and keep working.
+	for i := 0; i < 2; i++ {
+		_, err := conn.Prepare("test", "select 42::integer")
+		if err != nil {
+			t.Fatalf("%d. Unable to prepare statement: %v", i, err)
+		}
+
+		var n int32
+		err = conn.QueryRow("test").Scan(&n)
+		if err != nil {
+			t.Errorf("%d. Executing prepared statement failed: %v", i, err)
+		}
+
+		if n != int32(42) {
+			t.Errorf("%d. Prepared statement did not return expected value: %v", i, n)
+		}
+	}
+
+	_, err := conn.Prepare("test", "select 'fail'::varchar")
+	if err == nil {
+		t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't")
+		return
+	}
+}
+
+// TestPrepareEx verifies PrepareEx with explicitly supplied parameter
+// OIDs: "select $1" is untyped SQL, so the text OID passed via
+// PrepareExOptions determines how the argument is bound.
+func TestPrepareEx(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	_, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}})
+	if err != nil {
+		t.Errorf("Unable to prepare statement: %v", err)
+		return
+	}
+
+	var s string
+	err = conn.QueryRow("test", "hello").Scan(&s)
+	if err != nil {
+		t.Errorf("Executing prepared statement failed: %v", err)
+	}
+
+	if s != "hello" {
+		t.Errorf("Prepared statement did not return expected value: %v", s)
+	}
+
+	err = conn.Deallocate("test")
+	if err != nil {
+		t.Errorf("conn.Deallocate failed: %v", err)
+	}
+}
+
+// TestListenNotify covers LISTEN/NOTIFY delivery through
+// WaitForNotification in three situations: a notification pending on the
+// socket, one already consumed as a side effect of a prior query, and a
+// timeout with no notification followed by successful re-use.
+func TestListenNotify(t *testing.T) {
+	t.Parallel()
+
+	listener := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, listener)
+
+	if err := listener.Listen("chat"); err != nil {
+		t.Fatalf("Unable to start listening: %v", err)
+	}
+
+	// A second connection is required: a session cannot NOTIFY itself
+	// here without muddying the wait semantics under test.
+	notifier := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, notifier)
+
+	mustExec(t, notifier, "notify chat")
+
+	// when notification is waiting on the socket to be read
+	notification, err := listener.WaitForNotification(time.Second)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "chat" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+
+	// when notification has already been read during previous query
+	mustExec(t, notifier, "notify chat")
+	rows, _ := listener.Query("select 1")
+	rows.Close()
+	if rows.Err() != nil {
+		t.Fatalf("Unexpected error on Query: %v", rows.Err())
+	}
+	// Timeout 0: the notification must already be buffered client-side.
+	notification, err = listener.WaitForNotification(0)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "chat" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+
+	// when timeout occurs
+	notification, err = listener.WaitForNotification(time.Millisecond)
+	if err != pgx.ErrNotificationTimeout {
+		t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
+	}
+	if notification != nil {
+		t.Errorf("WaitForNotification returned an unexpected notification: %v", notification)
+	}
+
+	// listener can listen again after a timeout
+	mustExec(t, notifier, "notify chat")
+	notification, err = listener.WaitForNotification(time.Second)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "chat" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+}
+
+// TestUnlistenSpecificChannel verifies that after Unlisten the session
+// stops receiving notifications on that channel: a post-Unlisten NOTIFY
+// must result in a timeout rather than a delivered notification.
+func TestUnlistenSpecificChannel(t *testing.T) {
+	t.Parallel()
+
+	listener := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, listener)
+
+	if err := listener.Listen("unlisten_test"); err != nil {
+		t.Fatalf("Unable to start listening: %v", err)
+	}
+
+	notifier := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, notifier)
+
+	mustExec(t, notifier, "notify unlisten_test")
+
+	// when notification is waiting on the socket to be read
+	notification, err := listener.WaitForNotification(time.Second)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "unlisten_test" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+
+	err = listener.Unlisten("unlisten_test")
+	if err != nil {
+		t.Fatalf("Unexpected error on Unlisten: %v", err)
+	}
+
+	// when notification has already been read during previous query
+	mustExec(t, notifier, "notify unlisten_test")
+	rows, _ := listener.Query("select 1")
+	rows.Close()
+	if rows.Err() != nil {
+		t.Fatalf("Unexpected error on Query: %v", rows.Err())
+	}
+	// After Unlisten, the server should no longer route this channel to
+	// us, so waiting must time out.
+	notification, err = listener.WaitForNotification(100 * time.Millisecond)
+	if err != pgx.ErrNotificationTimeout {
+		t.Errorf("WaitForNotification returned the wrong kind of error: %v", err)
+	}
+}
+
+// TestListenNotifyWhileBusyIsSafe is a stress test: one connection runs
+// many queries while listening, another floods it with notifications, to
+// shake out protocol-interleaving bugs between async notifications and
+// normal query traffic.
+//
+// NOTE(review): t.Fatalf is called from goroutines other than the test
+// goroutine, which the testing package does not support (it can hang or
+// mis-report) — worth restructuring to signal errors over a channel.
+// Also, only listenerDone is awaited; the notifier goroutine may still
+// be running when the test returns. Left as-is to avoid changing timing
+// behavior of a deliberately racy stress test.
+func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
+	t.Parallel()
+
+	listenerDone := make(chan bool)
+	go func() {
+		conn := mustConnect(t, *defaultConnConfig)
+		defer closeConn(t, conn)
+		defer func() {
+			listenerDone <- true
+		}()
+
+		if err := conn.Listen("busysafe"); err != nil {
+			t.Fatalf("Unable to start listening: %v", err)
+		}
+
+		for i := 0; i < 5000; i++ {
+			var sum int32
+			var rowCount int32
+
+			rows, err := conn.Query("select generate_series(1,$1)", 100)
+			if err != nil {
+				t.Fatalf("conn.Query failed: %v", err)
+			}
+
+			for rows.Next() {
+				var n int32
+				rows.Scan(&n)
+				sum += n
+				rowCount++
+			}
+
+			if rows.Err() != nil {
+				t.Fatalf("conn.Query failed: %v", err)
+			}
+
+			// sum(1..100) == 5050 confirms no rows were lost or corrupted
+			// by interleaved notification traffic.
+			if sum != 5050 {
+				t.Fatalf("Wrong rows sum: %v", sum)
+			}
+
+			if rowCount != 100 {
+				t.Fatalf("Wrong number of rows: %v", rowCount)
+			}
+
+			time.Sleep(1 * time.Microsecond)
+		}
+	}()
+
+	notifierDone := make(chan bool)
+	go func() {
+		conn := mustConnect(t, *defaultConnConfig)
+		defer closeConn(t, conn)
+		defer func() {
+			notifierDone <- true
+		}()
+
+		for i := 0; i < 100000; i++ {
+			mustExec(t, conn, "notify busysafe, 'hello'")
+			time.Sleep(1 * time.Microsecond)
+		}
+	}()
+
+	<-listenerDone
+}
+
+// TestListenNotifySelfNotification verifies a session receives its own
+// notifications, both when waiting immediately after NOTIFY and when
+// another query was executed in between (which drains the socket).
+func TestListenNotifySelfNotification(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	if err := conn.Listen("self"); err != nil {
+		t.Fatalf("Unable to start listening: %v", err)
+	}
+
+	// Notify self and WaitForNotification immediately
+	mustExec(t, conn, "notify self")
+
+	notification, err := conn.WaitForNotification(time.Second)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "self" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+
+	// Notify self and do something else before WaitForNotification
+	mustExec(t, conn, "notify self")
+
+	rows, _ := conn.Query("select 1")
+	rows.Close()
+	if rows.Err() != nil {
+		t.Fatalf("Unexpected error on Query: %v", rows.Err())
+	}
+
+	notification, err = conn.WaitForNotification(time.Second)
+	if err != nil {
+		t.Fatalf("Unexpected error on WaitForNotification: %v", err)
+	}
+	if notification.Channel != "self" {
+		t.Errorf("Did not receive notification on expected channel: %v", notification.Channel)
+	}
+}
+
+func TestListenUnlistenSpecialCharacters(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ chanName := "special characters !@#{$%^&*()}"
+ if err := conn.Listen(chanName); err != nil {
+ t.Fatalf("Unable to start listening: %v", err)
+ }
+
+ if err := conn.Unlisten(chanName); err != nil {
+ t.Fatalf("Unable to stop listening: %v", err)
+ }
+}
+
+// TestFatalRxError kills the backend (pg_terminate_backend) while a slow
+// query is mid-flight on another connection and asserts the victim sees
+// either a FATAL PgError or ErrDeadConn, and is no longer alive.
+func TestFatalRxError(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		var n int32
+		var s string
+		// pg_sleep(10) keeps the query in flight long enough for the
+		// other connection to terminate this backend.
+		err := conn.QueryRow("select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s)
+		if err == pgx.ErrDeadConn {
+		} else if pgErr, ok := err.(pgx.PgError); ok && pgErr.Severity == "FATAL" {
+		} else {
+			t.Fatalf("Expected QueryRow Scan to return fatal PgError or ErrDeadConn, but instead received %v", err)
+		}
+	}()
+
+	otherConn, err := pgx.Connect(*defaultConnConfig)
+	if err != nil {
+		t.Fatalf("Unable to establish connection: %v", err)
+	}
+	defer otherConn.Close()
+
+	if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil {
+		t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
+	}
+
+	wg.Wait()
+
+	if conn.IsAlive() {
+		t.Fatal("Connection should not be live but was")
+	}
+}
+
+// TestFatalTxError terminates a connection's backend and verifies the
+// next query fails and the connection reports itself dead. The scenario
+// is timing sensitive, so it is repeated many times.
+func TestFatalTxError(t *testing.T) {
+	t.Parallel()
+
+	// Run timing sensitive test many times
+	for i := 0; i < 50; i++ {
+		// Closure scope so the defers fire at the end of each iteration,
+		// not once at test exit.
+		func() {
+			conn := mustConnect(t, *defaultConnConfig)
+			defer closeConn(t, conn)
+
+			otherConn, err := pgx.Connect(*defaultConnConfig)
+			if err != nil {
+				t.Fatalf("Unable to establish connection: %v", err)
+			}
+			defer otherConn.Close()
+
+			_, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid)
+			if err != nil {
+				t.Fatalf("Unable to kill backend PostgreSQL process: %v", err)
+			}
+
+			_, err = conn.Query("select 1")
+			if err == nil {
+				t.Fatal("Expected error but none occurred")
+			}
+
+			if conn.IsAlive() {
+				t.Fatalf("Connection should not be live but was. Previous Query err: %v", err)
+			}
+		}()
+	}
+}
+
+func TestCommandTag(t *testing.T) {
+ t.Parallel()
+
+ var tests = []struct {
+ commandTag pgx.CommandTag
+ rowsAffected int64
+ }{
+ {commandTag: "INSERT 0 5", rowsAffected: 5},
+ {commandTag: "UPDATE 0", rowsAffected: 0},
+ {commandTag: "UPDATE 1", rowsAffected: 1},
+ {commandTag: "DELETE 0", rowsAffected: 0},
+ {commandTag: "DELETE 1", rowsAffected: 1},
+ {commandTag: "CREATE TABLE", rowsAffected: 0},
+ {commandTag: "ALTER TABLE", rowsAffected: 0},
+ {commandTag: "DROP TABLE", rowsAffected: 0},
+ }
+
+ for i, tt := range tests {
+ actual := tt.commandTag.RowsAffected()
+ if tt.rowsAffected != actual {
+ t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual)
+ }
+ }
+}
+
+// TestInsertBoolArray verifies a Go []bool binds as a PostgreSQL bool[]
+// parameter on insert.
+func TestInsertBoolArray(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" {
+		t.Error("Unexpected results from Exec")
+	}
+
+	// Accept parameters
+	if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" {
+		t.Errorf("Unexpected results from Exec: %v", results)
+	}
+}
+
+// TestInsertTimestampArray verifies a Go []time.Time binds as a
+// PostgreSQL timestamp[] parameter on insert.
+func TestInsertTimestampArray(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" {
+		t.Error("Unexpected results from Exec")
+	}
+
+	// Accept parameters
+	if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" {
+		t.Errorf("Unexpected results from Exec: %v", results)
+	}
+}
+
+// TestCatchSimultaneousConnectionQueries verifies that starting a second
+// Query while a first result set is still open fails fast with
+// pgx.ErrConnBusy instead of corrupting the protocol stream.
+func TestCatchSimultaneousConnectionQueries(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	rows1, err := conn.Query("select generate_series(1,$1)", 10)
+	if err != nil {
+		t.Fatalf("conn.Query failed: %v", err)
+	}
+	defer rows1.Close()
+
+	// rows1 is deliberately left open so the connection is busy.
+	_, err = conn.Query("select generate_series(1,$1)", 10)
+	if err != pgx.ErrConnBusy {
+		t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err)
+	}
+}
+
+// TestCatchSimultaneousConnectionQueryAndExec verifies that Exec on a
+// connection with an open result set fails with pgx.ErrConnBusy.
+func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	rows, err := conn.Query("select generate_series(1,$1)", 10)
+	if err != nil {
+		t.Fatalf("conn.Query failed: %v", err)
+	}
+	defer rows.Close()
+
+	// rows is deliberately left open so the connection is busy.
+	_, err = conn.Exec("create temporary table foo(spice timestamp[])")
+	if err != pgx.ErrConnBusy {
+		t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
+	}
+}
+
+// testLog records a single logging call: its level, message, and the
+// variadic context arguments that were passed.
+type testLog struct {
+	lvl int
+	msg string
+	ctx []interface{}
+}
+
+// testLogger is an in-memory logger used by the SetLogger/SetLogLevel
+// tests; it appends every call to logs so tests can assert on them.
+type testLogger struct {
+	logs []testLog
+}
+
+// Debug, Info, Warn, and Error record the call with the corresponding
+// pgx log level rather than writing anywhere.
+func (l *testLogger) Debug(msg string, ctx ...interface{}) {
+	l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Info(msg string, ctx ...interface{}) {
+	l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Warn(msg string, ctx ...interface{}) {
+	l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Error(msg string, ctx ...interface{}) {
+	l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
+}
+
+// TestSetLogger verifies SetLogger returns the previously installed
+// logger and that subsequent connection activity is routed to the newly
+// installed one.
+func TestSetLogger(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	l1 := &testLogger{}
+	oldLogger := conn.SetLogger(l1)
+	if oldLogger != nil {
+		t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
+	}
+
+	// Listen is used only as an operation known to produce log output.
+	if err := conn.Listen("foo"); err != nil {
+		t.Fatal(err)
+	}
+
+	if len(l1.logs) == 0 {
+		t.Fatal("Expected new logger l1 to be called, but it wasn't")
+	}
+
+	l2 := &testLogger{}
+	oldLogger = conn.SetLogger(l2)
+	if oldLogger != l1 {
+		t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger)
+	}
+
+	if err := conn.Listen("bar"); err != nil {
+		t.Fatal(err)
+	}
+
+	if len(l2.logs) == 0 {
+		t.Fatal("Expected new logger l2 to be called, but it wasn't")
+	}
+}
+
+// TestSetLogLevel verifies SetLogLevel rejects invalid levels, that
+// LogLevelNone suppresses all output, and that LogLevelTrace re-enables
+// it.
+func TestSetLogLevel(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	logger := &testLogger{}
+	conn.SetLogger(logger)
+
+	// 0 is below every defined log level and must be rejected.
+	if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel {
+		t.Fatal("SetLogLevel with invalid level did not return error")
+	}
+
+	if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := conn.Listen("foo"); err != nil {
+		t.Fatal(err)
+	}
+
+	if len(logger.logs) != 0 {
+		t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs)
+	}
+
+	if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := conn.Listen("bar"); err != nil {
+		t.Fatal(err)
+	}
+
+	if len(logger.logs) == 0 {
+		t.Fatal("Expected logger to be called, but it wasn't")
+	}
+}
+
+// TestIdentifierSanitize checks Identifier.Sanitize quoting: plain and
+// reserved-word identifiers get double quotes, multi-part identifiers
+// are dot-joined, and embedded double quotes are doubled.
+func TestIdentifierSanitize(t *testing.T) {
+	t.Parallel()
+
+	cases := []struct {
+		ident    pgx.Identifier
+		expected string
+	}{
+		{
+			ident:    pgx.Identifier{`foo`},
+			expected: `"foo"`,
+		},
+		{
+			ident:    pgx.Identifier{`select`},
+			expected: `"select"`,
+		},
+		{
+			ident:    pgx.Identifier{`foo`, `bar`},
+			expected: `"foo"."bar"`,
+		},
+		{
+			ident:    pgx.Identifier{`you should " not do this`},
+			expected: `"you should "" not do this"`,
+		},
+		{
+			ident:    pgx.Identifier{`you should " not do this`, `please don't`},
+			expected: `"you should "" not do this"."please don't"`,
+		},
+	}
+
+	for i, tc := range cases {
+		got := tc.ident.Sanitize()
+		if got != tc.expected {
+			t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tc.ident, tc.expected, got)
+		}
+	}
+}
diff --git a/vendor/github.com/jackc/pgx/copy_from.go b/vendor/github.com/jackc/pgx/copy_from.go
new file mode 100644
index 0000000..1f8a230
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/copy_from.go
@@ -0,0 +1,241 @@
+package pgx
+
+import (
+ "bytes"
+ "fmt"
+)
+
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
// making it usable by *Conn.CopyFrom.
func CopyFromRows(rows [][]interface{}) CopyFromSource {
	return &copyFromRows{rows: rows, idx: -1}
}

// copyFromRows adapts an in-memory slice of rows to the CopyFromSource
// interface. idx starts at -1 so the first Next call positions on row 0.
type copyFromRows struct {
	rows [][]interface{}
	idx  int
}

// Next advances to the following row, reporting whether one exists.
func (s *copyFromRows) Next() bool {
	s.idx++
	return s.idx < len(s.rows)
}

// Values returns the current row. An in-memory source cannot fail, so the
// error is always nil.
func (s *copyFromRows) Values() ([]interface{}, error) {
	return s.rows[s.idx], nil
}

// Err always reports nil: iterating a slice cannot encounter an error.
func (s *copyFromRows) Err() error {
	return nil
}

// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values(). When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]interface{}, error)

	// Err returns any error that has been encountered by the CopyFromSource. If
	// this is not nil *Conn.CopyFrom will abort the copy.
	Err() error
}
+
+// copyFrom holds the state of a single CopyFrom (COPY ... FROM STDIN binary)
+// operation: the target table and columns, the row source, and the channel
+// the background reader goroutine uses to report server-side errors.
+type copyFrom struct {
+	conn          *Conn
+	tableName     Identifier
+	columnNames   []string
+	rowSrc        CopyFromSource
+	readerErrChan chan error
+}
+
+func (ct *copyFrom) readUntilReadyForQuery() {
+ for {
+ t, r, err := ct.conn.rxMsg()
+ if err != nil {
+ ct.readerErrChan <- err
+ close(ct.readerErrChan)
+ return
+ }
+
+ switch t {
+ case readyForQuery:
+ ct.conn.rxReadyForQuery(r)
+ close(ct.readerErrChan)
+ return
+ case commandComplete:
+ case errorResponse:
+ ct.readerErrChan <- ct.conn.rxErrorResponse(r)
+ default:
+ err = ct.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
+ }
+ }
+ }
+}
+
+func (ct *copyFrom) waitForReaderDone() error {
+ var err error
+ for err = range ct.readerErrChan {
+ }
+ return err
+}
+
+// run executes the full copy conversation: it prepares an anonymous statement
+// against the target table to learn the column type OIDs, issues the
+// COPY ... FROM STDIN BINARY command, streams binary-encoded rows from
+// rowSrc, and completes (or cancels) the copy. It returns the number of rows
+// sent to the server; on any error it returns 0.
+func (ct *copyFrom) run() (int, error) {
+	quotedTableName := ct.tableName.Sanitize()
+	buf := &bytes.Buffer{}
+	for i, cn := range ct.columnNames {
+		if i != 0 {
+			buf.WriteString(", ")
+		}
+		buf.WriteString(quoteIdentifier(cn))
+	}
+	quotedColumnNames := buf.String()
+
+	// Prepare a throwaway select to obtain the column types; they are needed
+	// to binary-encode each value in the loop below.
+	ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
+	if err != nil {
+		return 0, err
+	}
+
+	err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
+	if err != nil {
+		return 0, err
+	}
+
+	err = ct.conn.readUntilCopyInResponse()
+	if err != nil {
+		return 0, err
+	}
+
+	// Read server responses concurrently so a mid-copy server error can abort
+	// the send loop instead of deadlocking on a full socket buffer.
+	go ct.readUntilReadyForQuery()
+	defer ct.waitForReaderDone()
+
+	wbuf := newWriteBuf(ct.conn, copyData)
+
+	// Binary COPY header: signature, flags field, header extension length.
+	wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
+	wbuf.WriteInt32(0)
+	wbuf.WriteInt32(0)
+
+	var sentCount int
+
+	for ct.rowSrc.Next() {
+		// Bail out early if the reader goroutine already saw an error.
+		select {
+		case err = <-ct.readerErrChan:
+			return 0, err
+		default:
+		}
+
+		// Flush to the socket once the accumulated message exceeds 64 KB.
+		if len(wbuf.buf) > 65536 {
+			wbuf.closeMsg()
+			_, err = ct.conn.conn.Write(wbuf.buf)
+			if err != nil {
+				ct.conn.die(err)
+				return 0, err
+			}
+
+			// Directly manipulate wbuf to reset to reuse the same buffer
+			wbuf.buf = wbuf.buf[0:5]
+			wbuf.buf[0] = copyData
+			wbuf.sizeIdx = 1
+		}
+
+		sentCount++
+
+		values, err := ct.rowSrc.Values()
+		if err != nil {
+			ct.cancelCopyIn()
+			return 0, err
+		}
+		if len(values) != len(ct.columnNames) {
+			ct.cancelCopyIn()
+			return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
+		}
+
+		// Each tuple: a field count followed by the binary-encoded fields.
+		wbuf.WriteInt16(int16(len(ct.columnNames)))
+		for i, val := range values {
+			err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
+			if err != nil {
+				ct.cancelCopyIn()
+				return 0, err
+			}
+
+		}
+	}
+
+	if ct.rowSrc.Err() != nil {
+		ct.cancelCopyIn()
+		return 0, ct.rowSrc.Err()
+	}
+
+	wbuf.WriteInt16(-1) // terminate the copy stream
+
+	wbuf.startMsg(copyDone)
+	wbuf.closeMsg()
+	_, err = ct.conn.conn.Write(wbuf.buf)
+	if err != nil {
+		ct.conn.die(err)
+		return 0, err
+	}
+
+	// Wait for the reader to see ReadyForQuery; any error it reported (for
+	// example a constraint violation) surfaces here.
+	err = ct.waitForReaderDone()
+	if err != nil {
+		return 0, err
+	}
+	return sentCount, nil
+}
+
+func (c *Conn) readUntilCopyInResponse() error {
+ for {
+ var t byte
+ var r *msgReader
+ t, r, err := c.rxMsg()
+ if err != nil {
+ return err
+ }
+
+ switch t {
+ case copyInResponse:
+ return nil
+ default:
+ err = c.processContextFreeMsg(t, r)
+ if err != nil {
+ return err
+ }
+ }
+ }
+}
+
+// cancelCopyIn aborts an in-progress copy by sending a CopyFail message to
+// the server. The server responds with an error, which the background reader
+// goroutine forwards via readerErrChan. A write failure kills the connection.
+func (ct *copyFrom) cancelCopyIn() error {
+	wbuf := newWriteBuf(ct.conn, copyFail)
+	wbuf.WriteCString("client error: abort")
+	wbuf.closeMsg()
+	_, err := ct.conn.conn.Write(wbuf.buf)
+	if err != nil {
+		ct.conn.die(err)
+		return err
+	}
+
+	return nil
+}
+
+// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
+// It returns the number of rows copied and an error.
+//
+// CopyFrom requires all values use the binary format. Almost all types
+// implemented by pgx use the binary format by default. Types implementing
+// Encoder can only be used if they encode to the binary format.
+func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
+ ct := &copyFrom{
+ conn: c,
+ tableName: tableName,
+ columnNames: columnNames,
+ rowSrc: rowSrc,
+ readerErrChan: make(chan error),
+ }
+
+ return ct.run()
+}
diff --git a/vendor/github.com/jackc/pgx/copy_from_test.go b/vendor/github.com/jackc/pgx/copy_from_test.go
new file mode 100644
index 0000000..54da698
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/copy_from_test.go
@@ -0,0 +1,428 @@
+package pgx_test
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+)
+
+// TestConnCopyFromSmall round-trips two rows (one all-NULL) through CopyFrom
+// and verifies the table contents match the input exactly.
+func TestConnCopyFromSmall(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int2,
+		b int4,
+		c int8,
+		d varchar,
+		e text,
+		f date,
+		g timestamptz
+	)`)
+
+	inputRows := [][]interface{}{
+		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
+		{nil, nil, nil, nil, nil, nil, nil},
+	}
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyFrom: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromLarge streams 10,000 identical rows (including bytea) to
+// exercise the internal 64 KB buffer flushing in CopyFrom.
+func TestConnCopyFromLarge(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int2,
+		b int4,
+		c int8,
+		d varchar,
+		e text,
+		f date,
+		g timestamptz,
+		h bytea
+	)`)
+
+	inputRows := [][]interface{}{}
+
+	for i := 0; i < 10000; i++ {
+		inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
+	}
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyFrom: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal")
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromJSON round-trips json and jsonb columns through CopyFrom,
+// skipping silently when the server predates those types.
+func TestConnCopyFromJSON(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
+		if _, ok := conn.PgTypes[oid]; !ok {
+			return // No JSON/JSONB type -- must be running against old PostgreSQL
+		}
+	}
+
+	mustExec(t, conn, `create temporary table foo(
+		a json,
+		b jsonb
+	)`)
+
+	inputRows := [][]interface{}{
+		{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
+		{nil, nil},
+	}
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyFrom: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyFromFailServerSideMidway inserts a NULL into a NOT NULL column
+// mid-copy; CopyFrom must return a PgError, copy nothing, and leave the
+// connection usable.
+func TestConnCopyFromFailServerSideMidway(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b varchar not null
+	)`)
+
+	inputRows := [][]interface{}{
+		{int32(1), "abc"},
+		{int32(2), nil}, // this row should trigger a failure
+		{int32(3), "def"},
+	}
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	if err == nil {
+		t.Errorf("Expected CopyFrom return error, but it did not")
+	}
+	if _, ok := err.(pgx.PgError); !ok {
+		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
// failSource is a slow CopyFromSource that produces up to 99 rows and emits
// a single nil value on row 3 to provoke a server-side failure (the tests
// using it target a NOT NULL bytea column).
type failSource struct {
	count int
}

// Next pauses to simulate a slow producer, then advances the row counter.
func (fs *failSource) Next() bool {
	time.Sleep(100 * time.Millisecond)
	fs.count++
	return fs.count < 100
}

// Values yields a nil value on the third row and a 100 KB payload otherwise.
func (fs *failSource) Values() ([]interface{}, error) {
	if fs.count == 3 {
		return []interface{}{nil}, nil
	}
	return []interface{}{make([]byte, 100000)}, nil
}

// Err never reports a client-side error.
func (fs *failSource) Err() error {
	return nil
}
+
+// TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting verifies that a
+// server-side failure aborts the copy promptly: the concurrent reader
+// goroutine must stop the (deliberately slow) failSource well before it
+// would finish on its own, and no rows may be committed.
+func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	startTime := time.Now()
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
+	if err == nil {
+		t.Errorf("Expected CopyFrom return error, but it did not")
+	}
+	if _, ok := err.(pgx.PgError); !ok {
+		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
+	}
+
+	endTime := time.Now()
+	copyTime := endTime.Sub(startTime)
+	if copyTime > time.Second {
+		t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
// clientFailSource simulates a CopyFromSource whose Values call fails
// partway through, exercising client-side abort handling in CopyFrom.
type clientFailSource struct {
	count int
	err   error
}

// Next advances the counter; rows keep flowing until the 100th request.
func (src *clientFailSource) Next() bool {
	src.count++
	return src.count < 100
}

// Values fails on the third row, recording the error for Err to report;
// otherwise it returns a 100 KB payload.
func (src *clientFailSource) Values() ([]interface{}, error) {
	if src.count == 3 {
		src.err = fmt.Errorf("client error")
		return nil, src.err
	}
	return []interface{}{make([]byte, 100000)}, nil
}

// Err reports the error captured by Values, if any.
func (src *clientFailSource) Err() error {
	return src.err
}
+
+// TestConnCopyFromCopyFromSourceErrorMidway verifies that a client-side
+// Values error aborts the copy (via CopyFail), commits no rows, and leaves
+// the connection usable.
+func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
+	if err == nil {
+		t.Errorf("Expected CopyFrom return error, but it did not")
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
// clientFinalErrSource yields four valid rows but always reports a non-nil
// Err, exercising CopyFrom's end-of-stream error check.
type clientFinalErrSource struct {
	count int
}

// Next advances the counter; rows 1 through 4 are produced, then iteration stops.
func (src *clientFinalErrSource) Next() bool {
	src.count++
	return src.count < 5
}

// Values always succeeds with a 100 KB payload.
func (src *clientFinalErrSource) Values() ([]interface{}, error) {
	return []interface{}{make([]byte, 100000)}, nil
}

// Err unconditionally reports an error, as if one was discovered after the
// final row.
func (src *clientFinalErrSource) Err() error {
	return fmt.Errorf("final error")
}
+
+// TestConnCopyFromCopyFromSourceErrorEnd verifies that an error reported by
+// the source's Err method after all rows were produced still aborts the copy
+// with zero rows committed.
+func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
+	if err == nil {
+		t.Errorf("Expected CopyFrom return error, but it did not")
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
diff --git a/vendor/github.com/jackc/pgx/copy_to.go b/vendor/github.com/jackc/pgx/copy_to.go
new file mode 100644
index 0000000..229e9a4
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/copy_to.go
@@ -0,0 +1,222 @@
+package pgx
+
+import (
+ "bytes"
+ "fmt"
+)
+
// CopyToRows returns a CopyToSource interface over the provided rows slice
// making it usable by *Conn.CopyTo.
//
// Deprecated: Use CopyFromRows instead.
func CopyToRows(rows [][]interface{}) CopyToSource {
	return &copyToRows{rows: rows, idx: -1}
}

// copyToRows adapts an in-memory slice of rows to the CopyToSource
// interface. idx starts at -1 so the first Next call positions on row 0.
type copyToRows struct {
	rows [][]interface{}
	idx  int
}

// Next advances to the following row, reporting whether one exists.
func (s *copyToRows) Next() bool {
	s.idx++
	return s.idx < len(s.rows)
}

// Values returns the current row. An in-memory source cannot fail, so the
// error is always nil.
func (s *copyToRows) Values() ([]interface{}, error) {
	return s.rows[s.idx], nil
}

// Err always reports nil: iterating a slice cannot encounter an error.
func (s *copyToRows) Err() error {
	return nil
}

// CopyToSource is the interface used by *Conn.CopyTo as the source for copy
// data.
//
// Deprecated: Use CopyFromSource instead.
type CopyToSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values(). When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]interface{}, error)

	// Err returns any error that has been encountered by the CopyToSource. If
	// this is not nil *Conn.CopyTo will abort the copy.
	Err() error
}
+
+// copyTo holds the state of a single deprecated CopyTo (COPY ... FROM STDIN
+// binary) operation: the target table and columns, the row source, and the
+// channel the background reader goroutine uses to report server-side errors.
+type copyTo struct {
+	conn          *Conn
+	tableName     string
+	columnNames   []string
+	rowSrc        CopyToSource
+	readerErrChan chan error
+}
+
+func (ct *copyTo) readUntilReadyForQuery() {
+ for {
+ t, r, err := ct.conn.rxMsg()
+ if err != nil {
+ ct.readerErrChan <- err
+ close(ct.readerErrChan)
+ return
+ }
+
+ switch t {
+ case readyForQuery:
+ ct.conn.rxReadyForQuery(r)
+ close(ct.readerErrChan)
+ return
+ case commandComplete:
+ case errorResponse:
+ ct.readerErrChan <- ct.conn.rxErrorResponse(r)
+ default:
+ err = ct.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
+ }
+ }
+ }
+}
+
+func (ct *copyTo) waitForReaderDone() error {
+ var err error
+ for err = range ct.readerErrChan {
+ }
+ return err
+}
+
+// run executes the full copy conversation: it prepares an anonymous statement
+// against the target table to learn the column type OIDs, issues the
+// COPY ... FROM STDIN BINARY command, streams binary-encoded rows from
+// rowSrc, and completes (or cancels) the copy. It returns the number of rows
+// sent to the server; on any error it returns 0.
+func (ct *copyTo) run() (int, error) {
+	quotedTableName := quoteIdentifier(ct.tableName)
+	buf := &bytes.Buffer{}
+	for i, cn := range ct.columnNames {
+		if i != 0 {
+			buf.WriteString(", ")
+		}
+		buf.WriteString(quoteIdentifier(cn))
+	}
+	quotedColumnNames := buf.String()
+
+	// Prepare a throwaway select to obtain the column types; they are needed
+	// to binary-encode each value in the loop below.
+	ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
+	if err != nil {
+		return 0, err
+	}
+
+	err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
+	if err != nil {
+		return 0, err
+	}
+
+	err = ct.conn.readUntilCopyInResponse()
+	if err != nil {
+		return 0, err
+	}
+
+	// Read server responses concurrently so a mid-copy server error can abort
+	// the send loop instead of deadlocking on a full socket buffer.
+	go ct.readUntilReadyForQuery()
+	defer ct.waitForReaderDone()
+
+	wbuf := newWriteBuf(ct.conn, copyData)
+
+	// Binary COPY header: signature, flags field, header extension length.
+	wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
+	wbuf.WriteInt32(0)
+	wbuf.WriteInt32(0)
+
+	var sentCount int
+
+	for ct.rowSrc.Next() {
+		// Bail out early if the reader goroutine already saw an error.
+		select {
+		case err = <-ct.readerErrChan:
+			return 0, err
+		default:
+		}
+
+		// Flush to the socket once the accumulated message exceeds 64 KB.
+		if len(wbuf.buf) > 65536 {
+			wbuf.closeMsg()
+			_, err = ct.conn.conn.Write(wbuf.buf)
+			if err != nil {
+				ct.conn.die(err)
+				return 0, err
+			}
+
+			// Directly manipulate wbuf to reset to reuse the same buffer
+			wbuf.buf = wbuf.buf[0:5]
+			wbuf.buf[0] = copyData
+			wbuf.sizeIdx = 1
+		}
+
+		sentCount++
+
+		values, err := ct.rowSrc.Values()
+		if err != nil {
+			ct.cancelCopyIn()
+			return 0, err
+		}
+		if len(values) != len(ct.columnNames) {
+			ct.cancelCopyIn()
+			return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
+		}
+
+		// Each tuple: a field count followed by the binary-encoded fields.
+		wbuf.WriteInt16(int16(len(ct.columnNames)))
+		for i, val := range values {
+			err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
+			if err != nil {
+				ct.cancelCopyIn()
+				return 0, err
+			}
+
+		}
+	}
+
+	if ct.rowSrc.Err() != nil {
+		ct.cancelCopyIn()
+		return 0, ct.rowSrc.Err()
+	}
+
+	wbuf.WriteInt16(-1) // terminate the copy stream
+
+	wbuf.startMsg(copyDone)
+	wbuf.closeMsg()
+	_, err = ct.conn.conn.Write(wbuf.buf)
+	if err != nil {
+		ct.conn.die(err)
+		return 0, err
+	}
+
+	// Wait for the reader to see ReadyForQuery; any error it reported (for
+	// example a constraint violation) surfaces here.
+	err = ct.waitForReaderDone()
+	if err != nil {
+		return 0, err
+	}
+	return sentCount, nil
+}
+
+// cancelCopyIn aborts an in-progress copy by sending a CopyFail message to
+// the server. The server responds with an error, which the background reader
+// goroutine forwards via readerErrChan. A write failure kills the connection.
+func (ct *copyTo) cancelCopyIn() error {
+	wbuf := newWriteBuf(ct.conn, copyFail)
+	wbuf.WriteCString("client error: abort")
+	wbuf.closeMsg()
+	_, err := ct.conn.conn.Write(wbuf.buf)
+	if err != nil {
+		ct.conn.die(err)
+		return err
+	}
+
+	return nil
+}
+
+// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to
+// perform bulk data insertion. It returns the number of rows copied and an
+// error.
+//
+// CopyTo requires all values use the binary format. Almost all types
+// implemented by pgx use the binary format by default. Types implementing
+// Encoder can only be used if they encode to the binary format.
+func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
+ ct := &copyTo{
+ conn: c,
+ tableName: tableName,
+ columnNames: columnNames,
+ rowSrc: rowSrc,
+ readerErrChan: make(chan error),
+ }
+
+ return ct.run()
+}
diff --git a/vendor/github.com/jackc/pgx/copy_to_test.go b/vendor/github.com/jackc/pgx/copy_to_test.go
new file mode 100644
index 0000000..ac27042
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/copy_to_test.go
@@ -0,0 +1,367 @@
+package pgx_test
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+)
+
+// TestConnCopyToSmall round-trips two rows (one all-NULL) through the
+// deprecated CopyTo and verifies the table contents match the input.
+func TestConnCopyToSmall(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int2,
+		b int4,
+		c int8,
+		d varchar,
+		e text,
+		f date,
+		g timestamptz
+	)`)
+
+	inputRows := [][]interface{}{
+		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
+		{nil, nil, nil, nil, nil, nil, nil},
+	}
+
+	copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyTo: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToLarge streams 10,000 identical rows (including bytea) to
+// exercise the internal 64 KB buffer flushing in CopyTo.
+func TestConnCopyToLarge(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int2,
+		b int4,
+		c int8,
+		d varchar,
+		e text,
+		f date,
+		g timestamptz,
+		h bytea
+	)`)
+
+	inputRows := [][]interface{}{}
+
+	for i := 0; i < 10000; i++ {
+		inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
+	}
+
+	copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyTo: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal")
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToJSON round-trips json and jsonb columns through CopyTo,
+// skipping silently when the server predates those types.
+func TestConnCopyToJSON(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
+		if _, ok := conn.PgTypes[oid]; !ok {
+			return // No JSON/JSONB type -- must be running against old PostgreSQL
+		}
+	}
+
+	mustExec(t, conn, `create temporary table foo(
+		a json,
+		b jsonb
+	)`)
+
+	inputRows := [][]interface{}{
+		{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
+		{nil, nil},
+	}
+
+	copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
+	if err != nil {
+		t.Errorf("Unexpected error for CopyTo: %v", err)
+	}
+	if copyCount != len(inputRows) {
+		t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if !reflect.DeepEqual(inputRows, outputRows) {
+		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToFailServerSideMidway inserts a NULL into a NOT NULL column
+// mid-copy; CopyTo must return a PgError, copy nothing, and leave the
+// connection usable.
+func TestConnCopyToFailServerSideMidway(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a int4,
+		b varchar not null
+	)`)
+
+	inputRows := [][]interface{}{
+		{int32(1), "abc"},
+		{int32(2), nil}, // this row should trigger a failure
+		{int32(3), "def"},
+	}
+
+	copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
+	if err == nil {
+		t.Errorf("Expected CopyTo return error, but it did not")
+	}
+	if _, ok := err.(pgx.PgError); !ok {
+		t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting verifies that a
+// server-side failure aborts the copy promptly: the concurrent reader
+// goroutine must stop the (deliberately slow) failSource well before it
+// would finish on its own, and no rows may be committed.
+func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	startTime := time.Now()
+
+	copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{})
+	if err == nil {
+		t.Errorf("Expected CopyTo return error, but it did not")
+	}
+	if _, ok := err.(pgx.PgError); !ok {
+		t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err)
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
+	}
+
+	endTime := time.Now()
+	copyTime := endTime.Sub(startTime)
+	if copyTime > time.Second {
+		t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToCopyToSourceErrorMidway verifies that a client-side Values
+// error aborts the copy (via CopyFail), commits no rows, and leaves the
+// connection usable.
+func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{})
+	if err == nil {
+		t.Errorf("Expected CopyTo return error, but it did not")
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestConnCopyToCopyToSourceErrorEnd verifies that an error reported by the
+// source's Err method after all rows were produced still aborts the copy
+// with zero rows committed.
+func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, `create temporary table foo(
+		a bytea not null
+	)`)
+
+	copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{})
+	if err == nil {
+		t.Errorf("Expected CopyTo return error, but it did not")
+	}
+	if copyCount != 0 {
+		t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount)
+	}
+
+	rows, err := conn.Query("select * from foo")
+	if err != nil {
+		t.Errorf("Unexpected error for Query: %v", err)
+	}
+
+	var outputRows [][]interface{}
+	for rows.Next() {
+		row, err := rows.Values()
+		if err != nil {
+			t.Errorf("Unexpected error for rows.Values(): %v", err)
+		}
+		outputRows = append(outputRows, row)
+	}
+
+	if rows.Err() != nil {
+		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
+	}
+
+	if len(outputRows) != 0 {
+		t.Errorf("Expected 0 rows, but got %v", outputRows)
+	}
+
+	ensureConnValid(t, conn)
+}
diff --git a/vendor/github.com/jackc/pgx/doc.go b/vendor/github.com/jackc/pgx/doc.go
new file mode 100644
index 0000000..f527d11
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/doc.go
@@ -0,0 +1,265 @@
+// Package pgx is a PostgreSQL database driver.
+/*
+pgx provides lower level access to PostgreSQL than the standard database/sql
+It remains as similar to the database/sql interface as possible while
+providing better speed and access to PostgreSQL specific features. Import
+github.com/jackc/pgx/stdlib to use pgx as a database/sql compatible driver.
+
+Query Interface
+
+pgx implements Query and Scan in the familiar database/sql style.
+
+ var sum int32
+
+ // Send the query to the server. The returned rows MUST be closed
+ // before conn can be used again.
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ return err
+ }
+
+ // rows.Close is called by rows.Next when all rows are read
+ // or an error occurs in Next or Scan. So it may optionally be
+ // omitted if nothing in the rows.Next loop can panic. It is
+ // safe to close rows multiple times.
+ defer rows.Close()
+
+ // Iterate through the result set
+ for rows.Next() {
+ var n int32
+ err = rows.Scan(&n)
+ if err != nil {
+ return err
+ }
+ sum += n
+ }
+
+    // Any errors encountered by rows.Next or rows.Scan will be returned here
+    if rows.Err() != nil {
+        return rows.Err()
+    }
+
+ // No errors found - do something with sum
+
+pgx also implements QueryRow in the same style as database/sql.
+
+ var name string
+ var weight int64
+ err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
+ if err != nil {
+ return err
+ }
+
+Use Exec to execute a query that does not return a result set.
+
+ commandTag, err := conn.Exec("delete from widgets where id=$1", 42)
+ if err != nil {
+ return err
+ }
+ if commandTag.RowsAffected() != 1 {
+ return errors.New("No row found to delete")
+ }
+
+Connection Pool
+
+Connection pool usage is explicit and configurable. In pgx, a connection can
+be created and managed directly, or a connection pool with a configurable
+maximum connections can be used. Also, the connection pool offers an after
+connect hook that allows every connection to be automatically setup before
+being made available in the connection pool. This is especially useful to
+ensure all connections have the same prepared statements available or to
+change any other connection settings.
+
+It delegates Query, QueryRow, Exec, and Begin functions to an automatically
+checked out and released connection so you can avoid manually acquiring and
+releasing connections when you do not need that level of control.
+
+ var name string
+ var weight int64
+ err := pool.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
+ if err != nil {
+ return err
+ }
+
+Base Type Mapping
+
+pgx maps between all common base types directly between Go and PostgreSQL. In
+particular:
+
+ Go PostgreSQL
+ -----------------------
+ string varchar
+ text
+
+ // Integers are automatically be converted to any other integer type if
+ // it can be done without overflow or underflow.
+ int8
+ int16 smallint
+ int32 int
+ int64 bigint
+ int
+ uint8
+ uint16
+ uint32
+ uint64
+ uint
+
+ // Floats are strict and do not automatically convert like integers.
+ float32 float4
+ float64 float8
+
+ time.Time date
+ timestamp
+ timestamptz
+
+ []byte bytea
+
+
+Null Mapping
+
+pgx can map nulls in two ways. The first is Null* types that have a data field
+and a valid field. They work in a similar fashion to database/sql. The second
+is to use a pointer to a pointer.
+
+    var foo pgx.NullString
+    var bar *string
+    err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar)
+ if err != nil {
+ return err
+ }
+
+Array Mapping
+
+pgx maps between int16, int32, int64, float32, float64, and string Go slices
+and the equivalent PostgreSQL array type. Go slices of native types do not
+support nulls, so if a PostgreSQL array that contains a null is read into a
+native Go slice an error will occur.
+
+Hstore Mapping
+
+pgx includes an Hstore type and a NullHstore type. Hstore is simply a
+map[string]string and is preferred when the hstore contains no nulls. NullHstore
+follows the Null* pattern and supports null values.
+
+JSON and JSONB Mapping
+
+pgx includes built-in support to marshal and unmarshal between Go types and
+the PostgreSQL JSON and JSONB.
+
+Inet and Cidr Mapping
+
+pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In
+addition, as a convenience pgx will encode from a net.IP; it will assume a /32
+netmask for IPv4 and a /128 for IPv6.
+
+Custom Type Support
+
+pgx includes support for the common data types like integers, floats, strings,
+dates, and times that have direct mappings between Go and SQL. Support can be
+added for additional types like point, hstore, numeric, etc. that do not have
+direct mappings in Go by the types implementing ScannerPgx and Encoder.
+
+Custom types can support text or binary formats. Binary format can provide a
+large performance increase. The natural place for deciding the format for a
+value would be in ScannerPgx as it is responsible for decoding the returned
+data. However, that is impossible as the query has already been sent by the time
+the ScannerPgx is invoked. The solution to this is the global
+DefaultTypeFormats. If a custom type prefers binary format it should register it
+there.
+
+ pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode
+
+Note that the type is referred to by name, not by OID. This is because custom
+PostgreSQL types like hstore will have different OIDs on different servers. When
+pgx establishes a connection it queries the pg_type table for all types. It then
+matches the names in DefaultTypeFormats with the returned OIDs and stores it in
+Conn.PgTypes.
+
+See example_custom_type_test.go for an example of a custom type for the
+PostgreSQL point type.
+
+pgx also includes support for custom types implementing the database/sql.Scanner
+and database/sql/driver.Valuer interfaces.
+
+Raw Bytes Mapping
+
+[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified
+to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with
+the raw bytes returned by PostgreSQL. This can be especially useful for reading
+varchar, text, json, and jsonb values directly into a []byte and avoiding the
+type conversion from string.
+
+Transactions
+
+Transactions are started by calling Begin or BeginIso. The BeginIso variant
+creates a transaction with a specified isolation level.
+
+ tx, err := conn.Begin()
+ if err != nil {
+ return err
+ }
+ // Rollback is safe to call even if the tx is already closed, so if
+ // the tx commits successfully, this is a no-op
+ defer tx.Rollback()
+
+ _, err = tx.Exec("insert into foo(id) values (1)")
+ if err != nil {
+ return err
+ }
+
+ err = tx.Commit()
+ if err != nil {
+ return err
+ }
+
+Copy Protocol
+
+Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL
+copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already
+in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or
+implement CopyFromSource to avoid buffering the entire data set in memory.
+
+ rows := [][]interface{}{
+ {"John", "Smith", int32(36)},
+ {"Jane", "Doe", int32(29)},
+ }
+
+ copyCount, err := conn.CopyFrom(
+ "people",
+ []string{"first_name", "last_name", "age"},
+ pgx.CopyFromRows(rows),
+ )
+
+CopyFrom can be faster than an insert with as few as 5 rows.
+
+Listen and Notify
+
+pgx can listen to the PostgreSQL notification system with the
+WaitForNotification function. It takes a maximum time to wait for a
+notification.
+
+    err := conn.Listen("channelname")
+    if err != nil {
+        return err
+    }
+
+    if notification, err := conn.WaitForNotification(time.Second); err == nil {
+        // do something with notification
+    }
+
+TLS
+
+The pgx ConnConfig struct has a TLSConfig field. If this field is
+nil, then TLS will be disabled. If it is present, then it will be used to
+configure the TLS connection. This allows total configuration of the TLS
+connection.
+
+Logging
+
+pgx defines a simple logger interface. Connections optionally accept a logger
+that satisfies this interface. The log15 package
+(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is
+simple to define adapters for other loggers. Set LogLevel to control logging
+verbosity.
+*/
+package pgx
diff --git a/vendor/github.com/jackc/pgx/example_custom_type_test.go b/vendor/github.com/jackc/pgx/example_custom_type_test.go
new file mode 100644
index 0000000..34cc316
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/example_custom_type_test.go
@@ -0,0 +1,104 @@
+package pgx_test
+
+import (
+ "errors"
+ "fmt"
+ "github.com/jackc/pgx"
+ "regexp"
+ "strconv"
+)
+
+var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
+
+// NullPoint represents a point that may be null.
+//
+// If Valid is false then the value is NULL.
+type NullPoint struct {
+ X, Y float64 // Coordinates of point
+ Valid bool // Valid is true if not NULL
+}
+
+func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error {
+ if vr.Type().DataTypeName != "point" {
+ return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ p.X, p.Y, p.Valid = 0, 0, false
+ return nil
+ }
+
+ switch vr.Type().FormatCode {
+ case pgx.TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ match := pointRegexp.FindStringSubmatch(s)
+ if match == nil {
+ return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
+ }
+
+ var err error
+ p.X, err = strconv.ParseFloat(match[1], 64)
+ if err != nil {
+ return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
+ }
+ p.Y, err = strconv.ParseFloat(match[2], 64)
+ if err != nil {
+ return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
+ }
+ case pgx.BinaryFormatCode:
+ return errors.New("binary format not implemented")
+ default:
+ return fmt.Errorf("unknown format %v", vr.Type().FormatCode)
+ }
+
+ p.Valid = true
+ return vr.Err()
+}
+
+func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode }
+
+func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
+ if !p.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ s := fmt.Sprintf("point(%v,%v)", p.X, p.Y)
+ w.WriteInt32(int32(len(s)))
+ w.WriteBytes([]byte(s))
+
+ return nil
+}
+
+func (p NullPoint) String() string {
+ if p.Valid {
+ return fmt.Sprintf("%v, %v", p.X, p.Y)
+ }
+ return "null point"
+}
+
+func Example_CustomType() {
+ conn, err := pgx.Connect(*defaultConnConfig)
+ if err != nil {
+ fmt.Printf("Unable to establish connection: %v", err)
+ return
+ }
+
+ var p NullPoint
+ err = conn.QueryRow("select null::point").Scan(&p)
+ if err != nil {
+ fmt.Println(err)
+ return
+ }
+ fmt.Println(p)
+
+ err = conn.QueryRow("select point(1.5,2.5)").Scan(&p)
+ if err != nil {
+ fmt.Println(err)
+ return
+ }
+ fmt.Println(p)
+ // Output:
+ // null point
+ // 1.5, 2.5
+}
diff --git a/vendor/github.com/jackc/pgx/example_json_test.go b/vendor/github.com/jackc/pgx/example_json_test.go
new file mode 100644
index 0000000..c153415
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/example_json_test.go
@@ -0,0 +1,43 @@
+package pgx_test
+
+import (
+ "fmt"
+ "github.com/jackc/pgx"
+)
+
+func Example_JSON() {
+ conn, err := pgx.Connect(*defaultConnConfig)
+ if err != nil {
+ fmt.Printf("Unable to establish connection: %v", err)
+ return
+ }
+
+ if _, ok := conn.PgTypes[pgx.JsonOid]; !ok {
+ // No JSON type -- must be running against very old PostgreSQL
+ // Pretend it works
+ fmt.Println("John", 42)
+ return
+ }
+
+ type person struct {
+ Name string `json:"name"`
+ Age int `json:"age"`
+ }
+
+ input := person{
+ Name: "John",
+ Age: 42,
+ }
+
+ var output person
+
+ err = conn.QueryRow("select $1::json", input).Scan(&output)
+ if err != nil {
+ fmt.Println(err)
+ return
+ }
+
+ fmt.Println(output.Name, output.Age)
+ // Output:
+ // John 42
+}
diff --git a/vendor/github.com/jackc/pgx/examples/README.md b/vendor/github.com/jackc/pgx/examples/README.md
new file mode 100644
index 0000000..6a97bc0
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/README.md
@@ -0,0 +1,7 @@
+# Examples
+
+* chat is a command line chat program using listen/notify.
+* todo is a command line todo list that demonstrates basic CRUD actions.
+* url_shortener contains a simple example of using pgx in a web context.
+* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx (uses v1 of pgx).
+* [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx.
diff --git a/vendor/github.com/jackc/pgx/examples/chat/README.md b/vendor/github.com/jackc/pgx/examples/chat/README.md
new file mode 100644
index 0000000..a093525
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/chat/README.md
@@ -0,0 +1,25 @@
+# Description
+
+This is a sample chat program implemented using PostgreSQL's listen/notify
+functionality with pgx.
+
+Start multiple instances of this program connected to the same database to chat
+between them.
+
+## Connection configuration
+
+The database connection is configured via environment variables.
+
+* CHAT_DB_HOST - defaults to localhost
+* CHAT_DB_USER - defaults to current OS user
+* CHAT_DB_PASSWORD - defaults to empty string
+* CHAT_DB_DATABASE - defaults to postgres
+
+You can either export them then run chat:
+
+ export CHAT_DB_HOST=/private/tmp
+ ./chat
+
+Or you can prefix the chat execution with the environment variables:
+
+ CHAT_DB_HOST=/private/tmp ./chat
diff --git a/vendor/github.com/jackc/pgx/examples/chat/main.go b/vendor/github.com/jackc/pgx/examples/chat/main.go
new file mode 100644
index 0000000..517508c
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/chat/main.go
@@ -0,0 +1,95 @@
+package main
+
+import (
+ "bufio"
+ "fmt"
+ "github.com/jackc/pgx"
+ "os"
+ "time"
+)
+
+var pool *pgx.ConnPool
+
+func main() {
+ var err error
+ pool, err = pgx.NewConnPool(extractConfig())
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
+ os.Exit(1)
+ }
+
+ go listen()
+
+ fmt.Println(`Type a message and press enter.
+
+This message should appear in any other chat instances connected to the same
+database.
+
+Type "exit" to quit.
+`)
+
+ scanner := bufio.NewScanner(os.Stdin)
+ for scanner.Scan() {
+ msg := scanner.Text()
+ if msg == "exit" {
+ os.Exit(0)
+ }
+
+ _, err = pool.Exec("select pg_notify('chat', $1)", msg)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error sending notification:", err)
+ os.Exit(1)
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err)
+ os.Exit(1)
+ }
+}
+
+func listen() {
+ conn, err := pool.Acquire()
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error acquiring connection:", err)
+ os.Exit(1)
+ }
+ defer pool.Release(conn)
+
+ conn.Listen("chat")
+
+ for {
+ notification, err := conn.WaitForNotification(time.Second)
+ if err == pgx.ErrNotificationTimeout {
+ continue
+ }
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error waiting for notification:", err)
+ os.Exit(1)
+ }
+
+ fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload)
+ }
+}
+
+func extractConfig() pgx.ConnPoolConfig {
+ var config pgx.ConnPoolConfig
+
+ config.Host = os.Getenv("CHAT_DB_HOST")
+ if config.Host == "" {
+ config.Host = "localhost"
+ }
+
+ config.User = os.Getenv("CHAT_DB_USER")
+ if config.User == "" {
+ config.User = os.Getenv("USER")
+ }
+
+ config.Password = os.Getenv("CHAT_DB_PASSWORD")
+
+ config.Database = os.Getenv("CHAT_DB_DATABASE")
+ if config.Database == "" {
+ config.Database = "postgres"
+ }
+
+ return config
+}
diff --git a/vendor/github.com/jackc/pgx/examples/todo/README.md b/vendor/github.com/jackc/pgx/examples/todo/README.md
new file mode 100644
index 0000000..eb3d95b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/todo/README.md
@@ -0,0 +1,72 @@
+# Description
+
+This is a sample todo list implemented using pgx as the connector to a
+PostgreSQL data store.
+
+# Usage
+
+Create a PostgreSQL database and run structure.sql into it to create the
+necessary data schema.
+
+Example:
+
+ createdb todo
+ psql todo < structure.sql
+
+Build todo:
+
+ go build
+
+## Connection configuration
+
+The database connection is configured via environment variables.
+
+* TODO_DB_HOST - defaults to localhost
+* TODO_DB_USER - defaults to current OS user
+* TODO_DB_PASSWORD - defaults to empty string
+* TODO_DB_DATABASE - defaults to todo
+
+You can either export them then run todo:
+
+ export TODO_DB_HOST=/private/tmp
+ ./todo list
+
+Or you can prefix the todo execution with the environment variables:
+
+ TODO_DB_HOST=/private/tmp ./todo list
+
+## Add a todo item
+
+ ./todo add 'Learn go'
+
+## List tasks
+
+ ./todo list
+
+## Update a task
+
+    ./todo update 1 'Learn more go'
+
+## Delete a task
+
+ ./todo remove 1
+
+# Example Setup and Execution
+
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ createdb todo
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ psql todo < structure.sql
+ Expanded display is used automatically.
+ Timing is on.
+ CREATE TABLE
+ Time: 6.363 ms
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ go build
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ export TODO_DB_HOST=/private/tmp
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo add 'Learn Go'
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list
+ 1. Learn Go
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo update 1 'Learn more Go'
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list
+ 1. Learn more Go
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo remove 1
+ jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list
diff --git a/vendor/github.com/jackc/pgx/examples/todo/main.go b/vendor/github.com/jackc/pgx/examples/todo/main.go
new file mode 100644
index 0000000..bacdf24
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/todo/main.go
@@ -0,0 +1,140 @@
+package main
+
+import (
+ "fmt"
+ "github.com/jackc/pgx"
+ "os"
+ "strconv"
+)
+
+var conn *pgx.Conn
+
+func main() {
+ var err error
+ conn, err = pgx.Connect(extractConfig())
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err)
+ os.Exit(1)
+ }
+
+ if len(os.Args) == 1 {
+ printHelp()
+ os.Exit(0)
+ }
+
+ switch os.Args[1] {
+ case "list":
+ err = listTasks()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable to list tasks: %v\n", err)
+ os.Exit(1)
+ }
+
+ case "add":
+ err = addTask(os.Args[2])
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable to add task: %v\n", err)
+ os.Exit(1)
+ }
+
+ case "update":
+ n, err := strconv.ParseInt(os.Args[2], 10, 32)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err)
+ os.Exit(1)
+ }
+ err = updateTask(int32(n), os.Args[3])
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable to update task: %v\n", err)
+ os.Exit(1)
+ }
+
+ case "remove":
+ n, err := strconv.ParseInt(os.Args[2], 10, 32)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err)
+ os.Exit(1)
+ }
+ err = removeTask(int32(n))
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Unable to remove task: %v\n", err)
+ os.Exit(1)
+ }
+
+ default:
+ fmt.Fprintln(os.Stderr, "Invalid command")
+ printHelp()
+ os.Exit(1)
+ }
+}
+
+func listTasks() error {
+ rows, _ := conn.Query("select * from tasks")
+
+ for rows.Next() {
+ var id int32
+ var description string
+ err := rows.Scan(&id, &description)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("%d. %s\n", id, description)
+ }
+
+ return rows.Err()
+}
+
+func addTask(description string) error {
+ _, err := conn.Exec("insert into tasks(description) values($1)", description)
+ return err
+}
+
+func updateTask(itemNum int32, description string) error {
+ _, err := conn.Exec("update tasks set description=$1 where id=$2", description, itemNum)
+ return err
+}
+
+func removeTask(itemNum int32) error {
+ _, err := conn.Exec("delete from tasks where id=$1", itemNum)
+ return err
+}
+
+func printHelp() {
+ fmt.Print(`Todo pgx demo
+
+Usage:
+
+ todo list
+ todo add task
+ todo update task_num item
+ todo remove task_num
+
+Example:
+
+ todo add 'Learn Go'
+ todo list
+`)
+}
+
+func extractConfig() pgx.ConnConfig {
+ var config pgx.ConnConfig
+
+ config.Host = os.Getenv("TODO_DB_HOST")
+ if config.Host == "" {
+ config.Host = "localhost"
+ }
+
+ config.User = os.Getenv("TODO_DB_USER")
+ if config.User == "" {
+ config.User = os.Getenv("USER")
+ }
+
+ config.Password = os.Getenv("TODO_DB_PASSWORD")
+
+ config.Database = os.Getenv("TODO_DB_DATABASE")
+ if config.Database == "" {
+ config.Database = "todo"
+ }
+
+ return config
+}
diff --git a/vendor/github.com/jackc/pgx/examples/todo/structure.sql b/vendor/github.com/jackc/pgx/examples/todo/structure.sql
new file mode 100644
index 0000000..567685d
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/todo/structure.sql
@@ -0,0 +1,4 @@
+create table tasks (
+ id serial primary key,
+ description text not null
+);
diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/README.md b/vendor/github.com/jackc/pgx/examples/url_shortener/README.md
new file mode 100644
index 0000000..cc04d60
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/url_shortener/README.md
@@ -0,0 +1,25 @@
+# Description
+
+This is a sample REST URL shortener service implemented using pgx as the connector to a PostgreSQL data store.
+
+# Usage
+
+Create a PostgreSQL database and run structure.sql into it to create the necessary data schema.
+
+Edit connectionOptions in main.go with the location and credentials for your database.
+
+Run main.go:
+
+ go run main.go
+
+## Create or Update a Shortened URL
+
+ curl -X PUT -d 'http://www.google.com' http://localhost:8080/google
+
+## Get a Shortened URL
+
+ curl http://localhost:8080/google
+
+## Delete a Shortened URL
+
+ curl -X DELETE http://localhost:8080/google
diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/main.go b/vendor/github.com/jackc/pgx/examples/url_shortener/main.go
new file mode 100644
index 0000000..f6a22c3
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/url_shortener/main.go
@@ -0,0 +1,128 @@
+package main
+
+import (
+ "github.com/jackc/pgx"
+ log "gopkg.in/inconshreveable/log15.v2"
+ "io/ioutil"
+ "net/http"
+ "os"
+)
+
+var pool *pgx.ConnPool
+
+// afterConnect creates the prepared statements that this application uses
+func afterConnect(conn *pgx.Conn) (err error) {
+ _, err = conn.Prepare("getUrl", `
+ select url from shortened_urls where id=$1
+ `)
+ if err != nil {
+ return
+ }
+
+ _, err = conn.Prepare("deleteUrl", `
+ delete from shortened_urls where id=$1
+ `)
+ if err != nil {
+ return
+ }
+
+ // There technically is a small race condition in doing an upsert with a CTE
+ // where one of two simultaneous requests to the shortened URL would fail
+ // with a unique index violation. As the point of this demo is pgx usage and
+ // not how to perfectly upsert in PostgreSQL it is deemed acceptable.
+ _, err = conn.Prepare("putUrl", `
+ with upsert as (
+ update shortened_urls
+ set url=$2
+ where id=$1
+ returning *
+ )
+ insert into shortened_urls(id, url)
+ select $1, $2 where not exists(select 1 from upsert)
+ `)
+ return
+}
+
+func getUrlHandler(w http.ResponseWriter, req *http.Request) {
+ var url string
+ err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url)
+ switch err {
+ case nil:
+ http.Redirect(w, req, url, http.StatusSeeOther)
+ case pgx.ErrNoRows:
+ http.NotFound(w, req)
+ default:
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+func putUrlHandler(w http.ResponseWriter, req *http.Request) {
+ id := req.URL.Path
+ var url string
+ if body, err := ioutil.ReadAll(req.Body); err == nil {
+ url = string(body)
+ } else {
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ if _, err := pool.Exec("putUrl", id, url); err == nil {
+ w.WriteHeader(http.StatusOK)
+ } else {
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+func deleteUrlHandler(w http.ResponseWriter, req *http.Request) {
+ if _, err := pool.Exec("deleteUrl", req.URL.Path); err == nil {
+ w.WriteHeader(http.StatusOK)
+ } else {
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+func urlHandler(w http.ResponseWriter, req *http.Request) {
+ switch req.Method {
+ case "GET":
+ getUrlHandler(w, req)
+
+ case "PUT":
+ putUrlHandler(w, req)
+
+ case "DELETE":
+ deleteUrlHandler(w, req)
+
+ default:
+ w.Header().Add("Allow", "GET, PUT, DELETE")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ }
+}
+
+func main() {
+ var err error
+ connPoolConfig := pgx.ConnPoolConfig{
+ ConnConfig: pgx.ConnConfig{
+ Host: "127.0.0.1",
+ User: "jack",
+ Password: "jack",
+ Database: "url_shortener",
+ Logger: log.New("module", "pgx"),
+ },
+ MaxConnections: 5,
+ AfterConnect: afterConnect,
+ }
+ pool, err = pgx.NewConnPool(connPoolConfig)
+ if err != nil {
+ log.Crit("Unable to create connection pool", "error", err)
+ os.Exit(1)
+ }
+
+ http.HandleFunc("/", urlHandler)
+
+ log.Info("Starting URL shortener on localhost:8080")
+ err = http.ListenAndServe("localhost:8080", nil)
+ if err != nil {
+ log.Crit("Unable to start web server", "error", err)
+ os.Exit(1)
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql b/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql
new file mode 100644
index 0000000..0b9de25
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql
@@ -0,0 +1,4 @@
+create table shortened_urls (
+ id text primary key,
+ url text not null
+); \ No newline at end of file
diff --git a/vendor/github.com/jackc/pgx/fastpath.go b/vendor/github.com/jackc/pgx/fastpath.go
new file mode 100644
index 0000000..19b9878
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/fastpath.go
@@ -0,0 +1,108 @@
+package pgx
+
+import (
+ "encoding/binary"
+)
+
+func newFastpath(cn *Conn) *fastpath {
+ return &fastpath{cn: cn, fns: make(map[string]Oid)}
+}
+
+type fastpath struct {
+ cn *Conn
+ fns map[string]Oid
+}
+
+func (f *fastpath) functionOID(name string) Oid {
+ return f.fns[name]
+}
+
+func (f *fastpath) addFunction(name string, oid Oid) {
+ f.fns[name] = oid
+}
+
+func (f *fastpath) addFunctions(rows *Rows) error {
+ for rows.Next() {
+ var name string
+ var oid Oid
+ if err := rows.Scan(&name, &oid); err != nil {
+ return err
+ }
+ f.addFunction(name, oid)
+ }
+ return rows.Err()
+}
+
+type fpArg []byte
+
+func fpIntArg(n int32) fpArg {
+ res := make([]byte, 4)
+ binary.BigEndian.PutUint32(res, uint32(n))
+ return res
+}
+
+func fpInt64Arg(n int64) fpArg {
+ res := make([]byte, 8)
+ binary.BigEndian.PutUint64(res, uint64(n))
+ return res
+}
+
+func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
+ wbuf := newWriteBuf(f.cn, 'F') // function call
+ wbuf.WriteInt32(int32(oid)) // function object id
+ wbuf.WriteInt16(1) // # of argument format codes
+ wbuf.WriteInt16(1) // format code: binary
+ wbuf.WriteInt16(int16(len(args))) // # of arguments
+ for _, arg := range args {
+ wbuf.WriteInt32(int32(len(arg))) // length of argument
+ wbuf.WriteBytes(arg) // argument value
+ }
+ wbuf.WriteInt16(1) // response format code (binary)
+ wbuf.closeMsg()
+
+ if _, err := f.cn.conn.Write(wbuf.buf); err != nil {
+ return nil, err
+ }
+
+ for {
+ var t byte
+ var r *msgReader
+ t, r, err = f.cn.rxMsg()
+ if err != nil {
+ return nil, err
+ }
+ switch t {
+ case 'V': // FunctionCallResponse
+ data := r.readBytes(r.readInt32())
+ res = make([]byte, len(data))
+ copy(res, data)
+ case 'Z': // Ready for query
+ f.cn.rxReadyForQuery(r)
+ // done
+ return
+ default:
+ if err := f.cn.processContextFreeMsg(t, r); err != nil {
+ return nil, err
+ }
+ }
+ }
+}
+
+func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) {
+ return f.Call(f.functionOID(fn), args)
+}
+
+func fpInt32(data []byte, err error) (int32, error) {
+ if err != nil {
+ return 0, err
+ }
+ n := int32(binary.BigEndian.Uint32(data))
+ return n, nil
+}
+
+func fpInt64(data []byte, err error) (int64, error) {
+ if err != nil {
+ return 0, err
+ }
+ return int64(binary.BigEndian.Uint64(data)), nil
+}
diff --git a/vendor/github.com/jackc/pgx/helper_test.go b/vendor/github.com/jackc/pgx/helper_test.go
new file mode 100644
index 0000000..eff731e
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/helper_test.go
@@ -0,0 +1,74 @@
+package pgx_test
+
+import (
+ "github.com/jackc/pgx"
+ "testing"
+)
+
+func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
+ conn, err := pgx.Connect(config)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+ return conn
+}
+
+func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn {
+ conn, err := pgx.ReplicationConnect(config)
+ if err != nil {
+ t.Fatalf("Unable to establish connection: %v", err)
+ }
+ return conn
+}
+
+
+func closeConn(t testing.TB, conn *pgx.Conn) {
+ err := conn.Close()
+ if err != nil {
+ t.Fatalf("conn.Close unexpectedly failed: %v", err)
+ }
+}
+
+func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) {
+ err := conn.Close()
+ if err != nil {
+ t.Fatalf("conn.Close unexpectedly failed: %v", err)
+ }
+}
+
+func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) {
+ var err error
+ if commandTag, err = conn.Exec(sql, arguments...); err != nil {
+ t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
+ }
+ return
+}
+
+// Do a simple query to ensure the connection is still usable
+func ensureConnValid(t *testing.T, conn *pgx.Conn) {
+ var sum, rowCount int32
+
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var n int32
+ rows.Scan(&n)
+ sum += n
+ rowCount++
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ if rowCount != 10 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+ if sum != 55 {
+ t.Error("Wrong values returned")
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/hstore.go b/vendor/github.com/jackc/pgx/hstore.go
new file mode 100644
index 0000000..0ab9f77
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/hstore.go
@@ -0,0 +1,222 @@
+package pgx
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "unicode"
+ "unicode/utf8"
+)
+
+const (
+ hsPre = iota
+ hsKey
+ hsSep
+ hsVal
+ hsNul
+ hsNext
+)
+
+type hstoreParser struct {
+ str string
+ pos int
+}
+
+func newHSP(in string) *hstoreParser {
+ return &hstoreParser{
+ pos: 0,
+ str: in,
+ }
+}
+
+func (p *hstoreParser) Consume() (r rune, end bool) {
+ if p.pos >= len(p.str) {
+ end = true
+ return
+ }
+ r, w := utf8.DecodeRuneInString(p.str[p.pos:])
+ p.pos += w
+ return
+}
+
+func (p *hstoreParser) Peek() (r rune, end bool) {
+ if p.pos >= len(p.str) {
+ end = true
+ return
+ }
+ r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
+ return
+}
+
+func parseHstoreToMap(s string) (m map[string]string, err error) {
+ keys, values, err := ParseHstore(s)
+ if err != nil {
+ return
+ }
+ m = make(map[string]string, len(keys))
+ for i, key := range keys {
+ if !values[i].Valid {
+ err = fmt.Errorf("key '%s' has NULL value", key)
+ m = nil
+ return
+ }
+ m[key] = values[i].String
+ }
+ return
+}
+
+func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) {
+ keys, values, err := ParseHstore(s)
+ if err != nil {
+ return
+ }
+
+ store = make(map[string]NullString, len(keys))
+
+ for i, key := range keys {
+ store[key] = values[i]
+ }
+ return
+}
+
+// ParseHstore parses the string representation of an hstore column (the same
+// you would get from an ordinary SELECT) into two slices of keys and values. it
+// is used internally in the default parsing of hstores, but is exported for use
+// in handling custom data structures backed by an hstore column without the
+// overhead of creating a map[string]string
+func ParseHstore(s string) (k []string, v []NullString, err error) {
+ if s == "" {
+ return
+ }
+
+ buf := bytes.Buffer{}
+ keys := []string{}
+ values := []NullString{}
+ p := newHSP(s)
+
+ r, end := p.Consume()
+ state := hsPre
+
+ for !end {
+ switch state {
+ case hsPre:
+ if r == '"' {
+ state = hsKey
+ } else {
+ err = errors.New("String does not begin with \"")
+ }
+ case hsKey:
+ switch r {
+ case '"': //End of the key
+ if buf.Len() == 0 {
+ err = errors.New("Empty Key is invalid")
+ } else {
+ keys = append(keys, buf.String())
+ buf = bytes.Buffer{}
+ state = hsSep
+ }
+ case '\\': //Potential escaped character
+ n, end := p.Consume()
+ switch {
+ case end:
+ err = errors.New("Found EOS in key, expecting character or \"")
+ case n == '"', n == '\\':
+ buf.WriteRune(n)
+ default:
+ buf.WriteRune(r)
+ buf.WriteRune(n)
+ }
+ default: //Any other character
+ buf.WriteRune(r)
+ }
+ case hsSep:
+ if r == '=' {
+ r, end = p.Consume()
+ switch {
+ case end:
+ err = errors.New("Found EOS after '=', expecting '>'")
+ case r == '>':
+ r, end = p.Consume()
+ switch {
+ case end:
+ err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
+ case r == '"':
+ state = hsVal
+ case r == 'N':
+ state = hsNul
+ default:
+ err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
+ }
+ default:
+ err = fmt.Errorf("Invalid character after '=', expecting '>'")
+ }
+ } else {
+ err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
+ }
+ case hsVal:
+ switch r {
+ case '"': //End of the value
+ values = append(values, NullString{String: buf.String(), Valid: true})
+ buf = bytes.Buffer{}
+ state = hsNext
+ case '\\': //Potential escaped character
+ n, end := p.Consume()
+ switch {
+ case end:
+ err = errors.New("Found EOS in key, expecting character or \"")
+ case n == '"', n == '\\':
+ buf.WriteRune(n)
+ default:
+ buf.WriteRune(r)
+ buf.WriteRune(n)
+ }
+ default: //Any other character
+ buf.WriteRune(r)
+ }
+ case hsNul:
+ nulBuf := make([]rune, 3)
+ nulBuf[0] = r
+ for i := 1; i < 3; i++ {
+ r, end = p.Consume()
+ if end {
+ err = errors.New("Found EOS in NULL value")
+ return
+ }
+ nulBuf[i] = r
+ }
+ if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
+ values = append(values, NullString{String: "", Valid: false})
+ state = hsNext
+ } else {
+ err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
+ }
+ case hsNext:
+ if r == ',' {
+ r, end = p.Consume()
+ switch {
+ case end:
+ err = errors.New("Found EOS after ',', expcting space")
+ case (unicode.IsSpace(r)):
+ r, end = p.Consume()
+ state = hsKey
+ default:
+ err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
+ }
+ } else {
+ err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
+ }
+ }
+
+ if err != nil {
+ return
+ }
+ r, end = p.Consume()
+ }
+ if state != hsNext {
+ err = errors.New("Improperly formatted hstore")
+ return
+ }
+ k = keys
+ v = values
+ return
+}
diff --git a/vendor/github.com/jackc/pgx/hstore_test.go b/vendor/github.com/jackc/pgx/hstore_test.go
new file mode 100644
index 0000000..c948f0c
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/hstore_test.go
@@ -0,0 +1,181 @@
+package pgx_test
+
+import (
+ "github.com/jackc/pgx"
+ "testing"
+)
+
+func TestHstoreTranscode(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ type test struct {
+ hstore pgx.Hstore
+ description string
+ }
+
+ tests := []test{
+ {pgx.Hstore{}, "empty"},
+ {pgx.Hstore{"foo": "bar"}, "single key/value"},
+ {pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"},
+ {pgx.Hstore{"NULL": "bar"}, `string "NULL" key`},
+ {pgx.Hstore{"foo": "NULL"}, `string "NULL" value`},
+ }
+
+ specialStringTests := []struct {
+ input string
+ description string
+ }{
+ {`"`, `double quote (")`},
+ {`'`, `single quote (')`},
+ {`\`, `backslash (\)`},
+ {`\\`, `multiple backslashes (\\)`},
+ {`=>`, `separator (=>)`},
+ {` `, `space`},
+ {`\ / / \\ => " ' " '`, `multiple special characters`},
+ }
+ for _, sst := range specialStringTests {
+ tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"})
+ tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"})
+ tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"})
+ tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description})
+
+ tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"})
+ tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"})
+ tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"})
+ tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description})
+ }
+
+ for _, tt := range tests {
+ var result pgx.Hstore
+ err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result)
+ if err != nil {
+ t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
+ }
+
+ for key, inValue := range tt.hstore {
+ outValue, ok := result[key]
+ if ok {
+ if inValue != outValue {
+ t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue)
+ }
+ } else {
+ t.Errorf(`%s: Missing key %s`, tt.description, key)
+ }
+ }
+
+ ensureConnValid(t, conn)
+ }
+}
+
+func TestNullHstoreTranscode(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ type test struct {
+ nullHstore pgx.NullHstore
+ description string
+ }
+
+ tests := []test{
+ {pgx.NullHstore{}, "null"},
+ {pgx.NullHstore{Valid: true}, "empty"},
+ {pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}},
+ Valid: true},
+ "single key/value"},
+ {pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}},
+ Valid: true},
+ "multiple key/values"},
+ {pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}},
+ Valid: true},
+ `string "NULL" key`},
+ {pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}},
+ Valid: true},
+ `string "NULL" value`},
+ {pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}},
+ Valid: true},
+ `NULL value`},
+ }
+
+ specialStringTests := []struct {
+ input string
+ description string
+ }{
+ {`"`, `double quote (")`},
+ {`'`, `single quote (')`},
+ {`\`, `backslash (\)`},
+ {`\\`, `multiple backslashes (\\)`},
+ {`=>`, `separator (=>)`},
+ {` `, `space`},
+ {`\ / / \\ => " ' " '`, `multiple special characters`},
+ }
+ for _, sst := range specialStringTests {
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}},
+ Valid: true},
+ "key with " + sst.description + " at beginning"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}},
+ Valid: true},
+ "key with " + sst.description + " in middle"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}},
+ Valid: true},
+ "key with " + sst.description + " at end"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}},
+ Valid: true},
+ "key is " + sst.description})
+
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}},
+ Valid: true},
+ "value with " + sst.description + " at beginning"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}},
+ Valid: true},
+ "value with " + sst.description + " in middle"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}},
+ Valid: true},
+ "value with " + sst.description + " at end"})
+ tests = append(tests, test{pgx.NullHstore{
+ Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}},
+ Valid: true},
+ "value is " + sst.description})
+ }
+
+ for _, tt := range tests {
+ var result pgx.NullHstore
+ err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result)
+ if err != nil {
+ t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err)
+ }
+
+ if result.Valid != tt.nullHstore.Valid {
+ t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid)
+ }
+
+ for key, inValue := range tt.nullHstore.Hstore {
+ outValue, ok := result.Hstore[key]
+ if ok {
+ if inValue != outValue {
+ t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue)
+ }
+ } else {
+ t.Errorf(`%s: Missing key %s`, tt.description, key)
+ }
+ }
+
+ ensureConnValid(t, conn)
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/large_objects.go b/vendor/github.com/jackc/pgx/large_objects.go
new file mode 100644
index 0000000..a4922ef
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/large_objects.go
@@ -0,0 +1,147 @@
+package pgx
+
+import (
+ "io"
+)
+
+// LargeObjects is a structure used to access the large objects API. It is only
+// valid within the transaction where it was created.
+//
+// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html
+type LargeObjects struct {
+ // Has64 is true if the server is capable of working with 64-bit numbers
+ Has64 bool
+ fp *fastpath
+}
+
+const largeObjectFns = `select proname, oid from pg_catalog.pg_proc
+where proname in (
+'lo_open',
+'lo_close',
+'lo_create',
+'lo_unlink',
+'lo_lseek',
+'lo_lseek64',
+'lo_tell',
+'lo_tell64',
+'lo_truncate',
+'lo_truncate64',
+'loread',
+'lowrite')
+and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')`
+
+// LargeObjects returns a LargeObjects instance for the transaction.
+func (tx *Tx) LargeObjects() (*LargeObjects, error) {
+ if tx.conn.fp == nil {
+ tx.conn.fp = newFastpath(tx.conn)
+ }
+ if _, exists := tx.conn.fp.fns["lo_open"]; !exists {
+ res, err := tx.Query(largeObjectFns)
+ if err != nil {
+ return nil, err
+ }
+ if err := tx.conn.fp.addFunctions(res); err != nil {
+ return nil, err
+ }
+ }
+
+ lo := &LargeObjects{fp: tx.conn.fp}
+ _, lo.Has64 = lo.fp.fns["lo_lseek64"]
+
+ return lo, nil
+}
+
+type LargeObjectMode int32
+
+const (
+ LargeObjectModeWrite LargeObjectMode = 0x20000
+ LargeObjectModeRead LargeObjectMode = 0x40000
+)
+
+// Create creates a new large object. If id is zero, the server assigns an
+// unused OID.
+func (o *LargeObjects) Create(id Oid) (Oid, error) {
+ newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))}))
+ return Oid(newOid), err
+}
+
+// Open opens an existing large object with the given mode.
+func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) {
+ fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))}))
+ return &LargeObject{fd: fd, lo: o}, err
+}
+
+// Unlink removes a large object from the database.
+func (o *LargeObjects) Unlink(oid Oid) error {
+ _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))})
+ return err
+}
+
+// A LargeObject is a large object stored on the server. It is only valid within
+// the transaction that it was initialized in. It implements these interfaces:
+//
+// io.Writer
+// io.Reader
+// io.Seeker
+// io.Closer
+type LargeObject struct {
+ fd int32
+ lo *LargeObjects
+}
+
+// Write writes p to the large object and returns the number of bytes written
+// and an error if not all of p was written.
+func (o *LargeObject) Write(p []byte) (int, error) {
+ n, err := fpInt32(o.lo.fp.CallFn("lowrite", []fpArg{fpIntArg(o.fd), p}))
+ return int(n), err
+}
+
+// Read reads up to len(p) bytes into p returning the number of bytes read.
+func (o *LargeObject) Read(p []byte) (int, error) {
+ res, err := o.lo.fp.CallFn("loread", []fpArg{fpIntArg(o.fd), fpIntArg(int32(len(p)))})
+ if len(res) < len(p) {
+ err = io.EOF
+ }
+ return copy(p, res), err
+}
+
+// Seek moves the current location pointer to the new location specified by offset.
+func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) {
+ if o.lo.Has64 {
+ n, err = fpInt64(o.lo.fp.CallFn("lo_lseek64", []fpArg{fpIntArg(o.fd), fpInt64Arg(offset), fpIntArg(int32(whence))}))
+ } else {
+ var n32 int32
+ n32, err = fpInt32(o.lo.fp.CallFn("lo_lseek", []fpArg{fpIntArg(o.fd), fpIntArg(int32(offset)), fpIntArg(int32(whence))}))
+ n = int64(n32)
+ }
+ return
+}
+
+// Tell returns the current read or write location of the large object
+// descriptor.
+func (o *LargeObject) Tell() (n int64, err error) {
+ if o.lo.Has64 {
+ n, err = fpInt64(o.lo.fp.CallFn("lo_tell64", []fpArg{fpIntArg(o.fd)}))
+ } else {
+ var n32 int32
+ n32, err = fpInt32(o.lo.fp.CallFn("lo_tell", []fpArg{fpIntArg(o.fd)}))
+ n = int64(n32)
+ }
+ return
+}
+
+// Truncate truncates the large object to size.
+func (o *LargeObject) Truncate(size int64) (err error) {
+ if o.lo.Has64 {
+ _, err = o.lo.fp.CallFn("lo_truncate64", []fpArg{fpIntArg(o.fd), fpInt64Arg(size)})
+ } else {
+ _, err = o.lo.fp.CallFn("lo_truncate", []fpArg{fpIntArg(o.fd), fpIntArg(int32(size))})
+ }
+ return
+}
+
+// Close closes the large object descriptor.
+func (o *LargeObject) Close() error {
+ _, err := o.lo.fp.CallFn("lo_close", []fpArg{fpIntArg(o.fd)})
+ return err
+}
diff --git a/vendor/github.com/jackc/pgx/large_objects_test.go b/vendor/github.com/jackc/pgx/large_objects_test.go
new file mode 100644
index 0000000..a19c851
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/large_objects_test.go
@@ -0,0 +1,121 @@
+package pgx_test
+
+import (
+ "io"
+ "testing"
+
+ "github.com/jackc/pgx"
+)
+
+func TestLargeObjects(t *testing.T) {
+ t.Parallel()
+
+ conn, err := pgx.Connect(*defaultConnConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tx, err := conn.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ lo, err := tx.LargeObjects()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ id, err := lo.Create(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ obj, err := lo.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ n, err := obj.Write([]byte("testing"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 7 {
+ t.Errorf("Expected n to be 7, got %d", n)
+ }
+
+ pos, err := obj.Seek(1, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if pos != 1 {
+ t.Errorf("Expected pos to be 1, got %d", pos)
+ }
+
+ res := make([]byte, 6)
+ n, err = obj.Read(res)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(res) != "esting" {
+ t.Errorf(`Expected res to be "esting", got %q`, res)
+ }
+ if n != 6 {
+ t.Errorf("Expected n to be 6, got %d", n)
+ }
+
+ n, err = obj.Read(res)
+ if err != io.EOF {
+ t.Error("Expected io.EOF, go nil")
+ }
+ if n != 0 {
+ t.Errorf("Expected n to be 0, got %d", n)
+ }
+
+ pos, err = obj.Tell()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if pos != 7 {
+ t.Errorf("Expected pos to be 7, got %d", pos)
+ }
+
+ err = obj.Truncate(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ pos, err = obj.Seek(-1, 2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if pos != 0 {
+ t.Errorf("Expected pos to be 0, got %d", pos)
+ }
+
+ res = make([]byte, 2)
+ n, err = obj.Read(res)
+ if err != io.EOF {
+ t.Errorf("Expected err to be io.EOF, got %v", err)
+ }
+ if n != 1 {
+ t.Errorf("Expected n to be 1, got %d", n)
+ }
+ if res[0] != 't' {
+ t.Errorf("Expected res[0] to be 't', got %v", res[0])
+ }
+
+ err = obj.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = lo.Unlink(id)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = lo.Open(id, pgx.LargeObjectModeRead)
+ if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" {
+ t.Errorf("Expected undefined_object error (42704), got %#v", err)
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/logger.go b/vendor/github.com/jackc/pgx/logger.go
new file mode 100644
index 0000000..4423325
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/logger.go
@@ -0,0 +1,81 @@
+package pgx
+
+import (
+ "encoding/hex"
+ "errors"
+ "fmt"
+)
+
+// The values for log levels are chosen such that the zero value means that no
+// log level was specified and we can default to LogLevelDebug to preserve
+// the behavior that existed prior to log level introduction.
+const (
+ LogLevelTrace = 6
+ LogLevelDebug = 5
+ LogLevelInfo = 4
+ LogLevelWarn = 3
+ LogLevelError = 2
+ LogLevelNone = 1
+)
+
+// Logger is the interface used to get logging from pgx internals.
+// https://github.com/inconshreveable/log15 is the recommended logging package.
+// This logging interface was extracted from there. However, it should be simple
+// to adapt any logger to this interface.
+type Logger interface {
+ // Log a message at the given level with context key/value pairs
+ Debug(msg string, ctx ...interface{})
+ Info(msg string, ctx ...interface{})
+ Warn(msg string, ctx ...interface{})
+ Error(msg string, ctx ...interface{})
+}
+
+// LogLevelFromString converts log level string to constant
+//
+// Valid levels:
+// trace
+// debug
+// info
+// warn
+// error
+// none
+func LogLevelFromString(s string) (int, error) {
+ switch s {
+ case "trace":
+ return LogLevelTrace, nil
+ case "debug":
+ return LogLevelDebug, nil
+ case "info":
+ return LogLevelInfo, nil
+ case "warn":
+ return LogLevelWarn, nil
+ case "error":
+ return LogLevelError, nil
+ case "none":
+ return LogLevelNone, nil
+ default:
+ return 0, errors.New("invalid log level")
+ }
+}
+
+func logQueryArgs(args []interface{}) []interface{} {
+ logArgs := make([]interface{}, 0, len(args))
+
+ for _, a := range args {
+ switch v := a.(type) {
+ case []byte:
+ if len(v) < 64 {
+ a = hex.EncodeToString(v)
+ } else {
+ a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
+ }
+ case string:
+ if len(v) > 64 {
+ a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
+ }
+ }
+ logArgs = append(logArgs, a)
+ }
+
+ return logArgs
+}
diff --git a/vendor/github.com/jackc/pgx/messages.go b/vendor/github.com/jackc/pgx/messages.go
new file mode 100644
index 0000000..317ba27
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/messages.go
@@ -0,0 +1,159 @@
+package pgx
+
+import (
+ "encoding/binary"
+)
+
+const (
+ protocolVersionNumber = 196608 // 3.0
+)
+
+const (
+ backendKeyData = 'K'
+ authenticationX = 'R'
+ readyForQuery = 'Z'
+ rowDescription = 'T'
+ dataRow = 'D'
+ commandComplete = 'C'
+ errorResponse = 'E'
+ noticeResponse = 'N'
+ parseComplete = '1'
+ parameterDescription = 't'
+ bindComplete = '2'
+ notificationResponse = 'A'
+ emptyQueryResponse = 'I'
+ noData = 'n'
+ closeComplete = '3'
+ flush = 'H'
+ copyInResponse = 'G'
+ copyData = 'd'
+ copyFail = 'f'
+ copyDone = 'c'
+)
+
+type startupMessage struct {
+ options map[string]string
+}
+
+func newStartupMessage() *startupMessage {
+ return &startupMessage{map[string]string{}}
+}
+
+func (s *startupMessage) Bytes() (buf []byte) {
+ buf = make([]byte, 8, 128)
+ binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
+ for key, value := range s.options {
+ buf = append(buf, key...)
+ buf = append(buf, 0)
+ buf = append(buf, value...)
+ buf = append(buf, 0)
+ }
+ buf = append(buf, ("\000")...)
+ binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
+ return buf
+}
+
+type FieldDescription struct {
+ Name string
+ Table Oid
+ AttributeNumber int16
+ DataType Oid
+ DataTypeSize int16
+ DataTypeName string
+ Modifier int32
+ FormatCode int16
+}
+
+// PgError represents an error reported by the PostgreSQL server. See
+// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
+// detailed field description.
+type PgError struct {
+ Severity string
+ Code string
+ Message string
+ Detail string
+ Hint string
+ Position int32
+ InternalPosition int32
+ InternalQuery string
+ Where string
+ SchemaName string
+ TableName string
+ ColumnName string
+ DataTypeName string
+ ConstraintName string
+ File string
+ Line int32
+ Routine string
+}
+
+func (pe PgError) Error() string {
+ return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
+}
+
+func newWriteBuf(c *Conn, t byte) *WriteBuf {
+ buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
+ c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
+ return &c.writeBuf
+}
+
+// WriteBuf is used to build messages to send to the PostgreSQL server. It is used
+// by the Encoder interface when implementing custom encoders.
+type WriteBuf struct {
+ buf []byte
+ sizeIdx int
+ conn *Conn
+}
+
+func (wb *WriteBuf) startMsg(t byte) {
+ wb.closeMsg()
+ wb.buf = append(wb.buf, t, 0, 0, 0, 0)
+ wb.sizeIdx = len(wb.buf) - 4
+}
+
+func (wb *WriteBuf) closeMsg() {
+ binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
+}
+
+func (wb *WriteBuf) WriteByte(b byte) {
+ wb.buf = append(wb.buf, b)
+}
+
+func (wb *WriteBuf) WriteCString(s string) {
+ wb.buf = append(wb.buf, []byte(s)...)
+ wb.buf = append(wb.buf, 0)
+}
+
+func (wb *WriteBuf) WriteInt16(n int16) {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, uint16(n))
+ wb.buf = append(wb.buf, b...)
+}
+
+func (wb *WriteBuf) WriteUint16(n uint16) {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, n)
+ wb.buf = append(wb.buf, b...)
+}
+
+func (wb *WriteBuf) WriteInt32(n int32) {
+ b := make([]byte, 4)
+ binary.BigEndian.PutUint32(b, uint32(n))
+ wb.buf = append(wb.buf, b...)
+}
+
+func (wb *WriteBuf) WriteUint32(n uint32) {
+ b := make([]byte, 4)
+ binary.BigEndian.PutUint32(b, n)
+ wb.buf = append(wb.buf, b...)
+}
+
+func (wb *WriteBuf) WriteInt64(n int64) {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint64(b, uint64(n))
+ wb.buf = append(wb.buf, b...)
+}
+
+func (wb *WriteBuf) WriteBytes(b []byte) {
+ wb.buf = append(wb.buf, b...)
+}
diff --git a/vendor/github.com/jackc/pgx/msg_reader.go b/vendor/github.com/jackc/pgx/msg_reader.go
new file mode 100644
index 0000000..21db5d2
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/msg_reader.go
@@ -0,0 +1,316 @@
+package pgx
+
+import (
+ "bufio"
+ "encoding/binary"
+ "errors"
+ "io"
+)
+
+// msgReader is a helper that reads values from a PostgreSQL message.
+type msgReader struct {
+ reader *bufio.Reader
+ msgBytesRemaining int32
+ err error
+ log func(lvl int, msg string, ctx ...interface{})
+ shouldLog func(lvl int) bool
+}
+
+// Err returns any error that the msgReader has experienced
+func (r *msgReader) Err() error {
+ return r.err
+}
+
+// fatal records that a fatal error has occurred on r
+func (r *msgReader) fatal(err error) {
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+ r.err = err
+}
+
+// rxMsg reads the type and size of the next message.
+func (r *msgReader) rxMsg() (byte, error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+
+ if r.msgBytesRemaining > 0 {
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ _, err := r.reader.Discard(int(r.msgBytesRemaining))
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ b, err := r.reader.Peek(5)
+ if err != nil {
+ r.fatal(err)
+ return 0, err
+ }
+ msgType := b[0]
+ r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
+ r.reader.Discard(5)
+ return msgType, nil
+}
+
+func (r *msgReader) readByte() byte {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining--
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.ReadByte()
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return b
+}
+
+func (r *msgReader) readInt16() int16 {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining -= 2
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.Peek(2)
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ n := int16(binary.BigEndian.Uint16(b))
+
+ r.reader.Discard(2)
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return n
+}
+
+func (r *msgReader) readInt32() int32 {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining -= 4
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.Peek(4)
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ n := int32(binary.BigEndian.Uint32(b))
+
+ r.reader.Discard(4)
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return n
+}
+
+func (r *msgReader) readUint16() uint16 {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining -= 2
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.Peek(2)
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ n := uint16(binary.BigEndian.Uint16(b))
+
+ r.reader.Discard(2)
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return n
+}
+
+func (r *msgReader) readUint32() uint32 {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining -= 4
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.Peek(4)
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ n := uint32(binary.BigEndian.Uint32(b))
+
+ r.reader.Discard(4)
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return n
+}
+
+func (r *msgReader) readInt64() int64 {
+ if r.err != nil {
+ return 0
+ }
+
+ r.msgBytesRemaining -= 8
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return 0
+ }
+
+ b, err := r.reader.Peek(8)
+ if err != nil {
+ r.fatal(err)
+ return 0
+ }
+
+ n := int64(binary.BigEndian.Uint64(b))
+
+ r.reader.Discard(8)
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return n
+}
+
+func (r *msgReader) readOid() Oid {
+ return Oid(r.readInt32())
+}
+
+// readCString reads a null-terminated string.
+func (r *msgReader) readCString() string {
+ if r.err != nil {
+ return ""
+ }
+
+ b, err := r.reader.ReadBytes(0)
+ if err != nil {
+ r.fatal(err)
+ return ""
+ }
+
+ r.msgBytesRemaining -= int32(len(b))
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return ""
+ }
+
+ s := string(b[0 : len(b)-1])
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return s
+}
+
+// readString reads count bytes and returns as string
+func (r *msgReader) readString(countI32 int32) string {
+ if r.err != nil {
+ return ""
+ }
+
+ r.msgBytesRemaining -= countI32
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return ""
+ }
+
+ count := int(countI32)
+ var s string
+
+ if r.reader.Buffered() >= count {
+ buf, _ := r.reader.Peek(count)
+ s = string(buf)
+ r.reader.Discard(count)
+ } else {
+ buf := make([]byte, count)
+ _, err := io.ReadFull(r.reader, buf)
+ if err != nil {
+ r.fatal(err)
+ return ""
+ }
+ s = string(buf)
+ }
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return s
+}
+
+// readBytes reads count bytes and returns as []byte
+func (r *msgReader) readBytes(count int32) []byte {
+ if r.err != nil {
+ return nil
+ }
+
+ r.msgBytesRemaining -= count
+ if r.msgBytesRemaining < 0 {
+ r.fatal(errors.New("read past end of message"))
+ return nil
+ }
+
+ b := make([]byte, int(count))
+
+ _, err := io.ReadFull(r.reader, b)
+ if err != nil {
+ r.fatal(err)
+ return nil
+ }
+
+ if r.shouldLog(LogLevelTrace) {
+ r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
+ }
+
+ return b
+}
diff --git a/vendor/github.com/jackc/pgx/pgpass.go b/vendor/github.com/jackc/pgx/pgpass.go
new file mode 100644
index 0000000..b6f028d
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/pgpass.go
@@ -0,0 +1,85 @@
+package pgx
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "os/user"
+ "path/filepath"
+ "strings"
+)
+
+func parsepgpass(cfg *ConnConfig, line string) *string {
+ const (
+ backslash = "\r"
+ colon = "\n"
+ )
+ const (
+ host int = iota
+ port
+ database
+ username
+ pw
+ )
+ line = strings.Replace(line, `\:`, colon, -1)
+ line = strings.Replace(line, `\\`, backslash, -1)
+ parts := strings.Split(line, `:`)
+ if len(parts) != 5 {
+ return nil
+ }
+ for i := range parts {
+ if parts[i] == `*` {
+ continue
+ }
+ parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1)
+ switch i {
+ case host:
+ if parts[i] != cfg.Host {
+ return nil
+ }
+ case port:
+ portstr := fmt.Sprintf(`%v`, cfg.Port)
+ if portstr == "0" {
+ portstr = "5432"
+ }
+ if parts[i] != portstr {
+ return nil
+ }
+ case database:
+ if parts[i] != cfg.Database {
+ return nil
+ }
+ case username:
+ if parts[i] != cfg.User {
+ return nil
+ }
+ }
+ }
+ return &parts[4]
+}
+
+func pgpass(cfg *ConnConfig) (found bool) {
+ passfile := os.Getenv("PGPASSFILE")
+ if passfile == "" {
+ u, err := user.Current()
+ if err != nil {
+ return
+ }
+ passfile = filepath.Join(u.HomeDir, ".pgpass")
+ }
+ f, err := os.Open(passfile)
+ if err != nil {
+ return
+ }
+ defer f.Close()
+ scanner := bufio.NewScanner(f)
+ var pw *string
+ for scanner.Scan() {
+ pw = parsepgpass(cfg, scanner.Text())
+ if pw != nil {
+ cfg.Password = *pw
+ return true
+ }
+ }
+ return false
+}
diff --git a/vendor/github.com/jackc/pgx/pgpass_test.go b/vendor/github.com/jackc/pgx/pgpass_test.go
new file mode 100644
index 0000000..f6094c8
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/pgpass_test.go
@@ -0,0 +1,57 @@
+package pgx
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "testing"
+)
+
+func unescape(s string) string {
+ s = strings.Replace(s, `\:`, `:`, -1)
+ s = strings.Replace(s, `\\`, `\`, -1)
+ return s
+}
+
+var passfile = [][]string{
+ []string{"test1", "5432", "larrydb", "larry", "whatstheidea"},
+ []string{"test1", "5432", "moedb", "moe", "imbecile"},
+ []string{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
+ []string{"test2", "5432", "*", "shemp", "heymoe"},
+ []string{"test2", "5432", "*", "*", `test\\ing\:`},
+}
+
+func TestPGPass(t *testing.T) {
+ tf, err := ioutil.TempFile("", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tf.Close()
+ defer os.Remove(tf.Name())
+ os.Setenv("PGPASSFILE", tf.Name())
+ for _, l := range passfile {
+ _, err := fmt.Fprintln(tf, strings.Join(l, `:`))
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err = tf.Close(); err != nil {
+ t.Fatal(err)
+ }
+ for i, l := range passfile {
+ cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]}
+ found := pgpass(&cfg)
+ if !found {
+ t.Fatalf("Entry %v not found", i)
+ }
+ if cfg.Password != unescape(l[4]) {
+ t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password)
+ }
+ }
+ cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"}
+ found := pgpass(&cfg)
+ if found {
+ t.Fatal("bad found")
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/query.go b/vendor/github.com/jackc/pgx/query.go
new file mode 100644
index 0000000..19b867e
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/query.go
@@ -0,0 +1,494 @@
+package pgx
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "time"
+)
+
+// Row is a convenience wrapper over Rows that is returned by QueryRow.
+type Row Rows
+
+// Scan works the same as (*Rows).Scan with the following exceptions. If no
+// rows were found it returns ErrNoRows. If multiple rows are returned it
+// ignores all but the first.
+func (r *Row) Scan(dest ...interface{}) (err error) {
+ rows := (*Rows)(r)
+
+ if rows.Err() != nil {
+ return rows.Err()
+ }
+
+ if !rows.Next() {
+ if rows.Err() == nil {
+ return ErrNoRows
+ }
+ return rows.Err()
+ }
+
+ rows.Scan(dest...)
+ rows.Close()
+ return rows.Err()
+}
+
+// Rows is the result set returned from *Conn.Query. Rows must be closed before
+// the *Conn can be used again. Rows are closed by explicitly calling Close(),
+// calling Next() until it returns false, or when a fatal error occurs.
+type Rows struct {
+ conn *Conn
+ mr *msgReader
+ fields []FieldDescription
+ vr ValueReader
+ rowCount int
+ columnIdx int
+ err error
+ startTime time.Time
+ sql string
+ args []interface{}
+ afterClose func(*Rows)
+ unlockConn bool
+ closed bool
+}
+
+func (rows *Rows) FieldDescriptions() []FieldDescription {
+ return rows.fields
+}
+
+func (rows *Rows) close() {
+ if rows.closed {
+ return
+ }
+
+ if rows.unlockConn {
+ rows.conn.unlock()
+ rows.unlockConn = false
+ }
+
+ rows.closed = true
+
+ if rows.err == nil {
+ if rows.conn.shouldLog(LogLevelInfo) {
+ endTime := time.Now()
+ rows.conn.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
+ }
+ } else if rows.conn.shouldLog(LogLevelError) {
+ rows.conn.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
+ }
+
+ if rows.afterClose != nil {
+ rows.afterClose(rows)
+ }
+}
+
+func (rows *Rows) readUntilReadyForQuery() {
+ for {
+ t, r, err := rows.conn.rxMsg()
+ if err != nil {
+ rows.close()
+ return
+ }
+
+ switch t {
+ case readyForQuery:
+ rows.conn.rxReadyForQuery(r)
+ rows.close()
+ return
+ case rowDescription:
+ case dataRow:
+ case commandComplete:
+ case bindComplete:
+ case errorResponse:
+ err = rows.conn.rxErrorResponse(r)
+ if rows.err == nil {
+ rows.err = err
+ }
+ default:
+ err = rows.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ rows.close()
+ return
+ }
+ }
+ }
+}
+
+// Close closes the rows, making the connection ready for use again. It is safe
+// to call Close after rows is already closed.
+func (rows *Rows) Close() {
+ if rows.closed {
+ return
+ }
+ rows.readUntilReadyForQuery()
+ rows.close()
+}
+
+func (rows *Rows) Err() error {
+ return rows.err
+}
+
+// abort signals that the query was not successfully sent to the server.
+// This differs from Fatal in that it is not necessary to call readUntilReadyForQuery.
+func (rows *Rows) abort(err error) {
+ if rows.err != nil {
+ return
+ }
+
+ rows.err = err
+ rows.close()
+}
+
+// Fatal signals an error occurred after the query was sent to the server. It
+// closes the rows automatically.
+func (rows *Rows) Fatal(err error) {
+ if rows.err != nil {
+ return
+ }
+
+ rows.err = err
+ rows.Close()
+}
+
+// Next prepares the next row for reading. It returns true if there is another
+// row and false if no more rows are available. It automatically closes rows
+// when all rows are read.
+func (rows *Rows) Next() bool {
+ if rows.closed {
+ return false
+ }
+
+ rows.rowCount++
+ rows.columnIdx = 0
+ rows.vr = ValueReader{}
+
+ for {
+ t, r, err := rows.conn.rxMsg()
+ if err != nil {
+ rows.Fatal(err)
+ return false
+ }
+
+ switch t {
+ case readyForQuery:
+ rows.conn.rxReadyForQuery(r)
+ rows.close()
+ return false
+ case dataRow:
+ fieldCount := r.readInt16()
+ if int(fieldCount) != len(rows.fields) {
+ rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
+ return false
+ }
+
+ rows.mr = r
+ return true
+ case commandComplete:
+ case bindComplete:
+ default:
+ err = rows.conn.processContextFreeMsg(t, r)
+ if err != nil {
+ rows.Fatal(err)
+ return false
+ }
+ }
+ }
+}
+
+// Conn returns the *Conn this *Rows is using.
+func (rows *Rows) Conn() *Conn {
+ return rows.conn
+}
+
+func (rows *Rows) nextColumn() (*ValueReader, bool) {
+ if rows.closed {
+ return nil, false
+ }
+ if len(rows.fields) <= rows.columnIdx {
+ rows.Fatal(ProtocolError("No next column available"))
+ return nil, false
+ }
+
+ if rows.vr.Len() > 0 {
+ rows.mr.readBytes(rows.vr.Len())
+ }
+
+ fd := &rows.fields[rows.columnIdx]
+ rows.columnIdx++
+ size := rows.mr.readInt32()
+ rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size}
+ return &rows.vr, true
+}
+
+type scanArgError struct {
+ col int
+ err error
+}
+
+func (e scanArgError) Error() string {
+ return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
+}
+
+// Scan reads the values from the current row into dest values positionally.
+// dest can include pointers to core types, values implementing the Scanner
+// interface, []byte, and nil. []byte will skip the decoding process and directly
+// copy the raw bytes received from PostgreSQL. nil will skip the value entirely.
+func (rows *Rows) Scan(dest ...interface{}) (err error) {
+ if len(rows.fields) != len(dest) {
+ err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
+ rows.Fatal(err)
+ return err
+ }
+
+ for i, d := range dest {
+ vr, _ := rows.nextColumn()
+
+ if d == nil {
+ continue
+ }
+
+ // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes
+ if b, ok := d.(*[]byte); ok {
+ // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format)
+ // Otherwise read the bytes directly regardless of what the actual type is.
+ if vr.Type().DataType == ByteaOid {
+ *b = decodeBytea(vr)
+ } else {
+ if vr.Len() != -1 {
+ *b = vr.ReadBytes(vr.Len())
+ } else {
+ *b = nil
+ }
+ }
+ } else if s, ok := d.(Scanner); ok {
+ err = s.Scan(vr)
+ if err != nil {
+ rows.Fatal(scanArgError{col: i, err: err})
+ }
+ } else if s, ok := d.(PgxScanner); ok {
+ err = s.ScanPgx(vr)
+ if err != nil {
+ rows.Fatal(scanArgError{col: i, err: err})
+ }
+ } else if s, ok := d.(sql.Scanner); ok {
+ var val interface{}
+ if 0 <= vr.Len() {
+ switch vr.Type().DataType {
+ case BoolOid:
+ val = decodeBool(vr)
+ case Int8Oid:
+ val = int64(decodeInt8(vr))
+ case Int2Oid:
+ val = int64(decodeInt2(vr))
+ case Int4Oid:
+ val = int64(decodeInt4(vr))
+ case TextOid, VarcharOid:
+ val = decodeText(vr)
+ case OidOid:
+ val = int64(decodeOid(vr))
+ case Float4Oid:
+ val = float64(decodeFloat4(vr))
+ case Float8Oid:
+ val = decodeFloat8(vr)
+ case DateOid:
+ val = decodeDate(vr)
+ case TimestampOid:
+ val = decodeTimestamp(vr)
+ case TimestampTzOid:
+ val = decodeTimestampTz(vr)
+ default:
+ val = vr.ReadBytes(vr.Len())
+ }
+ }
+ err = s.Scan(val)
+ if err != nil {
+ rows.Fatal(scanArgError{col: i, err: err})
+ }
+ } else if vr.Type().DataType == JsonOid {
+ // Because the argument passed to decodeJSON will escape the heap.
+ // This allows d to be stack allocated and only copied to the heap when
+ // we actually are decoding JSON. This saves one memory allocation per
+ // row.
+ d2 := d
+ decodeJSON(vr, &d2)
+ } else if vr.Type().DataType == JsonbOid {
+ // Same trick as above for getting stack allocation
+ d2 := d
+ decodeJSONB(vr, &d2)
+ } else {
+ if err := Decode(vr, d); err != nil {
+ rows.Fatal(scanArgError{col: i, err: err})
+ }
+ }
+ if vr.Err() != nil {
+ rows.Fatal(scanArgError{col: i, err: vr.Err()})
+ }
+
+ if rows.Err() != nil {
+ return rows.Err()
+ }
+ }
+
+ return nil
+}
+
+// Values returns an array of the row values
+func (rows *Rows) Values() ([]interface{}, error) {
+ if rows.closed {
+ return nil, errors.New("rows is closed")
+ }
+
+ values := make([]interface{}, 0, len(rows.fields))
+
+ for range rows.fields {
+ vr, _ := rows.nextColumn()
+
+ if vr.Len() == -1 {
+ values = append(values, nil)
+ continue
+ }
+
+ switch vr.Type().FormatCode {
+ // All intrinsic types (except string) are encoded with binary
+ // encoding so anything else should be treated as a string
+ case TextFormatCode:
+ values = append(values, vr.ReadString(vr.Len()))
+ case BinaryFormatCode:
+ switch vr.Type().DataType {
+ case TextOid, VarcharOid:
+ values = append(values, decodeText(vr))
+ case BoolOid:
+ values = append(values, decodeBool(vr))
+ case ByteaOid:
+ values = append(values, decodeBytea(vr))
+ case Int8Oid:
+ values = append(values, decodeInt8(vr))
+ case Int2Oid:
+ values = append(values, decodeInt2(vr))
+ case Int4Oid:
+ values = append(values, decodeInt4(vr))
+ case OidOid:
+ values = append(values, decodeOid(vr))
+ case Float4Oid:
+ values = append(values, decodeFloat4(vr))
+ case Float8Oid:
+ values = append(values, decodeFloat8(vr))
+ case BoolArrayOid:
+ values = append(values, decodeBoolArray(vr))
+ case Int2ArrayOid:
+ values = append(values, decodeInt2Array(vr))
+ case Int4ArrayOid:
+ values = append(values, decodeInt4Array(vr))
+ case Int8ArrayOid:
+ values = append(values, decodeInt8Array(vr))
+ case Float4ArrayOid:
+ values = append(values, decodeFloat4Array(vr))
+ case Float8ArrayOid:
+ values = append(values, decodeFloat8Array(vr))
+ case TextArrayOid, VarcharArrayOid:
+ values = append(values, decodeTextArray(vr))
+ case TimestampArrayOid, TimestampTzArrayOid:
+ values = append(values, decodeTimestampArray(vr))
+ case DateOid:
+ values = append(values, decodeDate(vr))
+ case TimestampTzOid:
+ values = append(values, decodeTimestampTz(vr))
+ case TimestampOid:
+ values = append(values, decodeTimestamp(vr))
+ case InetOid, CidrOid:
+ values = append(values, decodeInet(vr))
+ case JsonOid:
+ var d interface{}
+ decodeJSON(vr, &d)
+ values = append(values, d)
+ case JsonbOid:
+ var d interface{}
+ decodeJSONB(vr, &d)
+ values = append(values, d)
+ default:
+ rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
+ }
+ default:
+ rows.Fatal(errors.New("Unknown format code"))
+ }
+
+ if vr.Err() != nil {
+ rows.Fatal(vr.Err())
+ }
+
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
+ }
+
+ return values, rows.Err()
+}
+
+// AfterClose adds f to a LIFO queue of functions that will be called when
+// rows is closed.
+func (rows *Rows) AfterClose(f func(*Rows)) {
+ if rows.afterClose == nil {
+ rows.afterClose = f
+ } else {
+ prevFn := rows.afterClose
+ rows.afterClose = func(rows *Rows) {
+ f(rows)
+ prevFn(rows)
+ }
+ }
+}
+
+// Query executes sql with args. If there is an error the returned *Rows will
+// be returned in an error state. So it is allowed to ignore the error returned
+// from Query and handle it in *Rows.
+func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
+ c.lastActivityTime = time.Now()
+
+ rows := c.getRows(sql, args)
+
+ if err := c.lock(); err != nil {
+ rows.abort(err)
+ return rows, err
+ }
+ rows.unlockConn = true
+
+ ps, ok := c.preparedStatements[sql]
+ if !ok {
+ var err error
+ ps, err = c.Prepare("", sql)
+ if err != nil {
+ rows.abort(err)
+ return rows, rows.err
+ }
+ }
+ rows.sql = ps.SQL
+ rows.fields = ps.FieldDescriptions
+ err := c.sendPreparedQuery(ps, args...)
+ if err != nil {
+ rows.abort(err)
+ }
+ return rows, rows.err
+}
+
+func (c *Conn) getRows(sql string, args []interface{}) *Rows {
+ if len(c.preallocatedRows) == 0 {
+ c.preallocatedRows = make([]Rows, 64)
+ }
+
+ r := &c.preallocatedRows[len(c.preallocatedRows)-1]
+ c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
+
+ r.conn = c
+ r.startTime = c.lastActivityTime
+ r.sql = sql
+ r.args = args
+
+ return r
+}
+
+// QueryRow is a convenience wrapper over Query. Any error that occurs while
+// querying is deferred until calling Scan on the returned *Row. That *Row will
+// error with ErrNoRows if no rows are returned.
+func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
+ rows, _ := c.Query(sql, args...)
+ return (*Row)(rows)
+}
diff --git a/vendor/github.com/jackc/pgx/query_test.go b/vendor/github.com/jackc/pgx/query_test.go
new file mode 100644
index 0000000..f08887b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/query_test.go
@@ -0,0 +1,1414 @@
+package pgx_test
+
+import (
+ "bytes"
+ "database/sql"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+
+ "github.com/shopspring/decimal"
+)
+
+func TestConnQueryScan(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ var sum, rowCount int32
+
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var n int32
+ rows.Scan(&n)
+ sum += n
+ rowCount++
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ if rowCount != 10 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+ if sum != 55 {
+ t.Error("Wrong values returned")
+ }
+}
+
+func TestConnQueryValues(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ var rowCount int32
+
+ rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n::oid from generate_series(1,$1) n", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ rowCount++
+
+ values, err := rows.Values()
+ if err != nil {
+ t.Fatalf("rows.Values failed: %v", err)
+ }
+ if len(values) != 5 {
+ t.Errorf("Expected rows.Values to return 5 values, but it returned %d", len(values))
+ }
+ if values[0] != "foo" {
+ t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
+ }
+ if values[1] != "bar" {
+ t.Errorf(`Expected values[1] to be "bar", but it was %v`, values[1])
+ }
+
+ if values[2] != rowCount {
+ t.Errorf(`Expected values[2] to be %d, but it was %d`, rowCount, values[2])
+ }
+
+ if values[3] != nil {
+ t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3])
+ }
+
+ if values[4] != pgx.Oid(rowCount) {
+ t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4])
+ }
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ if rowCount != 10 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+}
+
+// Test that a connection stays valid when query results are closed early
+func TestConnQueryCloseEarly(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ // Immediately close query without reading any rows
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ rows.Close()
+
+ ensureConnValid(t, conn)
+
+ // Read partial response then close
+ rows, err = conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ ok := rows.Next()
+ if !ok {
+ t.Fatal("rows.Next terminated early")
+ }
+
+ var n int32
+ rows.Scan(&n)
+ if n != 1 {
+ t.Fatalf("Expected 1 from first row, but got %v", n)
+ }
+
+ rows.Close()
+
+ ensureConnValid(t, conn)
+}
+
+func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ rows, err := conn.Query("select 1/(10-n) from generate_series(1,10) n")
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ rows.Close()
+
+ ensureConnValid(t, conn)
+}
+
+// Test that a connection stays valid when query results read incorrectly
+func TestConnQueryReadWrongTypeError(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ // Read a single value incorrectly
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ rowsRead := 0
+
+ for rows.Next() {
+ var t time.Time
+ rows.Scan(&t)
+ rowsRead++
+ }
+
+ if rowsRead != 1 {
+ t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
+ }
+
+ if rows.Err() == nil {
+ t.Fatal("Expected Rows to have an error after an improper read but it didn't")
+ }
+
+ if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" {
+ t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
+ }
+
+ ensureConnValid(t, conn)
+}
+
+// Test that a connection stays valid when query results read incorrectly
+func TestConnQueryReadTooManyValues(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ // Read too many values
+ rows, err := conn.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ rowsRead := 0
+
+ for rows.Next() {
+ var n, m int32
+ rows.Scan(&n, &m)
+ rowsRead++
+ }
+
+ if rowsRead != 1 {
+ t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
+ }
+
+ if rows.Err() == nil {
+ t.Fatal("Expected Rows to have an error after an improper read but it didn't")
+ }
+
+ ensureConnValid(t, conn)
+}
+
+func TestConnQueryScanIgnoreColumn(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ rows, err := conn.Query("select 1::int8, 2::int8, 3::int8")
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ ok := rows.Next()
+ if !ok {
+ t.Fatal("rows.Next terminated early")
+ }
+
+ var n, m int64
+ err = rows.Scan(&n, nil, &m)
+ if err != nil {
+ t.Fatalf("rows.Scan failed: %v", err)
+ }
+ rows.Close()
+
+ if n != 1 {
+ t.Errorf("Expected n to equal 1, but it was %d", n)
+ }
+
+ if m != 3 {
+ t.Errorf("Expected n to equal 3, but it was %d", m)
+ }
+
+ ensureConnValid(t, conn)
+}
+
+func TestConnQueryScanner(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ rows, err := conn.Query("select null::int8, 1::int8")
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ ok := rows.Next()
+ if !ok {
+ t.Fatal("rows.Next terminated early")
+ }
+
+ var n, m pgx.NullInt64
+ err = rows.Scan(&n, &m)
+ if err != nil {
+ t.Fatalf("rows.Scan failed: %v", err)
+ }
+ rows.Close()
+
+ if n.Valid {
+ t.Error("Null should not be valid, but it was")
+ }
+
+ if !m.Valid {
+ t.Error("1 should be valid, but it wasn't")
+ }
+
+ if m.Int64 != 1 {
+ t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
+ }
+
+ ensureConnValid(t, conn)
+}
+
+type pgxNullInt64 struct {
+ Int64 int64
+ Valid bool // Valid is true if Int64 is not NULL
+}
+
+func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error {
+ if vr.Type().DataType != pgx.Int8Oid {
+ return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Int64, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+
+ err := pgx.Decode(vr, &n.Int64)
+ if err != nil {
+ return err
+ }
+ return vr.Err()
+}
+
+func TestConnQueryPgxScanner(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ rows, err := conn.Query("select null::int8, 1::int8")
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ ok := rows.Next()
+ if !ok {
+ t.Fatal("rows.Next terminated early")
+ }
+
+ var n, m pgxNullInt64
+ err = rows.Scan(&n, &m)
+ if err != nil {
+ t.Fatalf("rows.Scan failed: %v", err)
+ }
+ rows.Close()
+
+ if n.Valid {
+ t.Error("Null should not be valid, but it was")
+ }
+
+ if !m.Valid {
+ t.Error("1 should be valid, but it wasn't")
+ }
+
+ if m.Int64 != 1 {
+ t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
+ }
+
+ ensureConnValid(t, conn)
+}
+
+func TestConnQueryErrorWhileReturningRows(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ for i := 0; i < 100; i++ {
+ func() {
+ sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
+
+ rows, err := conn.Query(sql)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var n int32
+ rows.Scan(&n)
+ }
+
+ if err, ok := rows.Err().(pgx.PgError); !ok {
+ t.Fatalf("Expected pgx.PgError, got %v", err)
+ }
+
+ ensureConnValid(t, conn)
+ }()
+ }
+
+}
+
+func TestConnQueryEncoder(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ n := pgx.NullInt64{Int64: 1, Valid: true}
+
+ rows, err := conn.Query("select $1::int8", &n)
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+
+ ok := rows.Next()
+ if !ok {
+ t.Fatal("rows.Next terminated early")
+ }
+
+ var m pgx.NullInt64
+ err = rows.Scan(&m)
+ if err != nil {
+ t.Fatalf("rows.Scan failed: %v", err)
+ }
+ rows.Close()
+
+ if !m.Valid {
+ t.Error("m should be valid, but it wasn't")
+ }
+
+ if m.Int64 != 1 {
+ t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
+ }
+
+ ensureConnValid(t, conn)
+}
+
+func TestQueryEncodeError(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ rows, err := conn.Query("select $1::integer", "wrong")
+ if err != nil {
+ t.Errorf("conn.Query failure: %v", err)
+ }
+ defer rows.Close()
+
+ rows.Next()
+
+ if rows.Err() == nil {
+ t.Error("Expected rows.Err() to return error, but it didn't")
+ }
+ if rows.Err().Error() != `ERROR: invalid input syntax for integer: "wrong" (SQLSTATE 22P02)` {
+ t.Error("Expected rows.Err() to return different error:", rows.Err())
+ }
+}
+
+// Ensure that an argument that implements Encoder works when the parameter type
+// is a core type.
+type coreEncoder struct{}
+
+func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode }
+
+func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
+ w.WriteInt32(int32(2))
+ w.WriteBytes([]byte("42"))
+ return nil
+}
+
+func TestQueryEncodeCoreTextFormatError(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ var n int32
+ err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n)
+ if err != nil {
+ t.Fatalf("Unexpected conn.QueryRow error: %v", err)
+ }
+
+ if n != 42 {
+ t.Errorf("Expected 42, got %v", n)
+ }
+}
+
+func TestQueryRowCoreTypes(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ type allTypes struct {
+ s string
+ f32 float32
+ f64 float64
+ b bool
+ t time.Time
+ oid pgx.Oid
+ }
+
+ var actual, zero allTypes
+
+ tests := []struct {
+ sql string
+ queryArgs []interface{}
+ scanArgs []interface{}
+ expected allTypes
+ }{
+ {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}},
+ {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}},
+ {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}},
+ {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}},
+ {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}},
+ {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}},
+ {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}},
+ {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}},
+ }
+
+ for i, tt := range tests {
+ actual = zero
+
+ err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
+ if err != nil {
+ t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
+ }
+
+ if actual != tt.expected {
+ t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
+ }
+
+ ensureConnValid(t, conn)
+
+ // Check that Scan errors when a core type is null
+ err = conn.QueryRow(tt.sql, nil).Scan(tt.scanArgs...)
+ if err == nil {
+ t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql)
+ }
+ if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
+ t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql)
+ }
+
+ ensureConnValid(t, conn)
+ }
+}
+
+func TestQueryRowCoreIntegerEncoding(t *testing.T) {
+ t.Parallel()
+
+ conn := mustConnect(t, *defaultConnConfig)
+ defer closeConn(t, conn)
+
+ type allTypes struct {
+ ui uint
+ ui8 uint8
+ ui16 uint16
+ ui32 uint32
+ ui64 uint64
+ i int
+ i8 int8
+ i16 int16
+ i32 int32
+ i64 int64
+ }
+
+ var actual, zero allTypes
+
+ successfulEncodeTests := []struct {
+ sql string
+ queryArg interface{}
+ scanArg interface{}
+ expected allTypes
+ }{
+ // Check any integer type where value is within int2 range can be encoded
+ {"select $1::int2", int(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", int8(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", int16(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", int32(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", int64(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", uint(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", uint8(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", uint16(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", uint32(42), &actual.i16, allTypes{i16: 42}},
+ {"select $1::int2", uint64(42), &actual.i16, allTypes{i16: 42}},
+
+ // Check any integer type where value is within int4 range can be encoded
+ {"select $1::int4", int(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", int8(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", int16(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", int32(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", int64(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", uint(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", uint8(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", uint16(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", uint32(42), &actual.i32, allTypes{i32: 42}},
+ {"select $1::int4", uint64(42), &actual.i32, allTypes{i32: 42}},
+
+ // Check any integer type where value is within int8 range can be encoded
+ {"select $1::int8", int(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", int8(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", int16(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", int32(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", int64(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", uint(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", uint8(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", uint16(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", uint32(42), &actual.i64, allTypes{i64: 42}},
+ {"select $1::int8", uint64(42), &actual.i64, allTypes{i64: 42}},
+ }
+
+ for i, tt := range successfulEncodeTests {
+ actual = zero
+
+ err := conn.QueryRow(tt.sql, tt.queryArg).Scan(tt.scanArg)
+ if err != nil {
+ t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
+ continue
+ }
+
+ if actual != tt.expected {
+ t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArg -> %v)", i, tt.expected, actual, tt.sql, tt.queryArg)
+ }
+
+ ensureConnValid(t, conn)
+ }
+
+ failedEncodeTests := []struct {
+ sql string
+ queryArg interface{}
+ }{
+ // Check any integer type where value is outside pg:int2 range cannot be encoded
+ {"select $1::int2", int(32769)},
+ {"select $1::int2", int32(32769)},
+ {"select $1::int2", int32(32769)},
+ {"select $1::int2", int64(32769)},
+ {"select $1::int2", uint(32769)},
+ {"select $1::int2", uint16(32769)},
+ {"select $1::int2", uint32(32769)},
+ {"select $1::int2", uint64(32769)},
+
+ // Check any integer type where value is outside pg:int4 range cannot be encoded
+ {"select $1::int4", int64(2147483649)},
+ {"select $1::int4", uint32(2147483649)},
+ {"select $1::int4", uint64(2147483649)},
+
+ // Check any integer type where value is outside pg:int8 range cannot be encoded
+ {"select $1::int8", uint64(9223372036854775809)},
+ }
+
+ for i, tt := range failedEncodeTests {
+ err := conn.QueryRow(tt.sql, tt.queryArg).Scan(nil)
+ if err == nil {
+ t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
+ } else if !strings.Contains(err.Error(), "is greater than") {
+ t.Errorf("%d. Expected failure to encode, but got: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
+ }
+
+ ensureConnValid(t, conn)
+ }
+}
+
// TestQueryRowCoreIntegerDecoding verifies that each PostgreSQL integer type
// (int2/int4/int8) can be scanned into every core Go integer type when the
// value fits, and that out-of-range values fail with a range error instead of
// silently truncating.
func TestQueryRowCoreIntegerDecoding(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	// allTypes bundles one field per core integer type so a single struct
	// comparison checks that only the targeted field was written.
	type allTypes struct {
		ui   uint
		ui8  uint8
		ui16 uint16
		ui32 uint32
		ui64 uint64
		i    int
		i8   int8
		i16  int16
		i32  int32
		i64  int64
	}

	var actual, zero allTypes

	successfulDecodeTests := []struct {
		sql      string
		scanArg  interface{}
		expected allTypes
	}{
		// Check any integer type where value is within Go:int range can be decoded
		{"select 42::int2", &actual.i, allTypes{i: 42}},
		{"select 42::int4", &actual.i, allTypes{i: 42}},
		{"select 42::int8", &actual.i, allTypes{i: 42}},
		{"select -42::int2", &actual.i, allTypes{i: -42}},
		{"select -42::int4", &actual.i, allTypes{i: -42}},
		{"select -42::int8", &actual.i, allTypes{i: -42}},

		// Check any integer type where value is within Go:int8 range can be decoded
		{"select 42::int2", &actual.i8, allTypes{i8: 42}},
		{"select 42::int4", &actual.i8, allTypes{i8: 42}},
		{"select 42::int8", &actual.i8, allTypes{i8: 42}},
		{"select -42::int2", &actual.i8, allTypes{i8: -42}},
		{"select -42::int4", &actual.i8, allTypes{i8: -42}},
		{"select -42::int8", &actual.i8, allTypes{i8: -42}},

		// Check any integer type where value is within Go:int16 range can be decoded
		{"select 42::int2", &actual.i16, allTypes{i16: 42}},
		{"select 42::int4", &actual.i16, allTypes{i16: 42}},
		{"select 42::int8", &actual.i16, allTypes{i16: 42}},
		{"select -42::int2", &actual.i16, allTypes{i16: -42}},
		{"select -42::int4", &actual.i16, allTypes{i16: -42}},
		{"select -42::int8", &actual.i16, allTypes{i16: -42}},

		// Check any integer type where value is within Go:int32 range can be decoded
		{"select 42::int2", &actual.i32, allTypes{i32: 42}},
		{"select 42::int4", &actual.i32, allTypes{i32: 42}},
		{"select 42::int8", &actual.i32, allTypes{i32: 42}},
		{"select -42::int2", &actual.i32, allTypes{i32: -42}},
		{"select -42::int4", &actual.i32, allTypes{i32: -42}},
		{"select -42::int8", &actual.i32, allTypes{i32: -42}},

		// Check any integer type where value is within Go:int64 range can be decoded
		{"select 42::int2", &actual.i64, allTypes{i64: 42}},
		{"select 42::int4", &actual.i64, allTypes{i64: 42}},
		{"select 42::int8", &actual.i64, allTypes{i64: 42}},
		{"select -42::int2", &actual.i64, allTypes{i64: -42}},
		{"select -42::int4", &actual.i64, allTypes{i64: -42}},
		{"select -42::int8", &actual.i64, allTypes{i64: -42}},

		// Check any integer type where value is within Go:uint range can be decoded
		{"select 128::int2", &actual.ui, allTypes{ui: 128}},
		{"select 128::int4", &actual.ui, allTypes{ui: 128}},
		{"select 128::int8", &actual.ui, allTypes{ui: 128}},

		// Check any integer type where value is within Go:uint8 range can be decoded
		{"select 128::int2", &actual.ui8, allTypes{ui8: 128}},
		{"select 128::int4", &actual.ui8, allTypes{ui8: 128}},
		{"select 128::int8", &actual.ui8, allTypes{ui8: 128}},

		// Check any integer type where value is within Go:uint16 range can be decoded
		{"select 42::int2", &actual.ui16, allTypes{ui16: 42}},
		{"select 32768::int4", &actual.ui16, allTypes{ui16: 32768}},
		{"select 32768::int8", &actual.ui16, allTypes{ui16: 32768}},

		// Check any integer type where value is within Go:uint32 range can be decoded
		{"select 42::int2", &actual.ui32, allTypes{ui32: 42}},
		{"select 42::int4", &actual.ui32, allTypes{ui32: 42}},
		{"select 2147483648::int8", &actual.ui32, allTypes{ui32: 2147483648}},

		// Check any integer type where value is within Go:uint64 range can be decoded
		{"select 42::int2", &actual.ui64, allTypes{ui64: 42}},
		{"select 42::int4", &actual.ui64, allTypes{ui64: 42}},
		{"select 42::int8", &actual.ui64, allTypes{ui64: 42}},
	}

	for i, tt := range successfulDecodeTests {
		// Reset all fields so a stale value from a previous case cannot
		// mask a decode that silently did nothing.
		actual = zero

		err := conn.QueryRow(tt.sql).Scan(tt.scanArg)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
			continue
		}

		if actual != tt.expected {
			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
		}

		ensureConnValid(t, conn)
	}

	failedDecodeTests := []struct {
		sql         string
		scanArg     interface{}
		expectedErr string
	}{
		// Check any integer type where value is outside Go:int8 range cannot be decoded
		{"select 128::int2", &actual.i8, "is greater than"},
		{"select 128::int4", &actual.i8, "is greater than"},
		{"select 128::int8", &actual.i8, "is greater than"},
		{"select -129::int2", &actual.i8, "is less than"},
		{"select -129::int4", &actual.i8, "is less than"},
		{"select -129::int8", &actual.i8, "is less than"},

		// Check any integer type where value is outside Go:int16 range cannot be decoded
		{"select 32768::int4", &actual.i16, "is greater than"},
		{"select 32768::int8", &actual.i16, "is greater than"},
		{"select -32769::int4", &actual.i16, "is less than"},
		{"select -32769::int8", &actual.i16, "is less than"},

		// Check any integer type where value is outside Go:int32 range cannot be decoded
		{"select 2147483648::int8", &actual.i32, "is greater than"},
		{"select -2147483649::int8", &actual.i32, "is less than"},

		// Check any integer type where value is outside Go:uint range cannot be decoded
		{"select -1::int2", &actual.ui, "is less than"},
		{"select -1::int4", &actual.ui, "is less than"},
		{"select -1::int8", &actual.ui, "is less than"},

		// Check any integer type where value is outside Go:uint8 range cannot be decoded
		{"select 256::int2", &actual.ui8, "is greater than"},
		{"select 256::int4", &actual.ui8, "is greater than"},
		{"select 256::int8", &actual.ui8, "is greater than"},
		{"select -1::int2", &actual.ui8, "is less than"},
		{"select -1::int4", &actual.ui8, "is less than"},
		{"select -1::int8", &actual.ui8, "is less than"},

		// Check any integer type where value is outside Go:uint16 cannot be decoded
		{"select 65536::int4", &actual.ui16, "is greater than"},
		{"select 65536::int8", &actual.ui16, "is greater than"},
		{"select -1::int2", &actual.ui16, "is less than"},
		{"select -1::int4", &actual.ui16, "is less than"},
		{"select -1::int8", &actual.ui16, "is less than"},

		// Check any integer type where value is outside Go:uint32 range cannot be decoded
		{"select 4294967296::int8", &actual.ui32, "is greater than"},
		{"select -1::int2", &actual.ui32, "is less than"},
		{"select -1::int4", &actual.ui32, "is less than"},
		{"select -1::int8", &actual.ui32, "is less than"},

		// Check any integer type where value is outside Go:uint64 range cannot be decoded
		{"select -1::int2", &actual.ui64, "is less than"},
		{"select -1::int4", &actual.ui64, "is less than"},
		{"select -1::int8", &actual.ui64, "is less than"},
	}

	for i, tt := range failedDecodeTests {
		err := conn.QueryRow(tt.sql).Scan(tt.scanArg)
		if err == nil {
			t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql)
		} else if !strings.Contains(err.Error(), tt.expectedErr) {
			t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql)
		}

		ensureConnValid(t, conn)
	}
}
+
// TestQueryRowCoreByteSlice verifies that values of several PostgreSQL types
// can be scanned into a []byte destination, including the raw binary wire
// encoding for non-text types (e.g. int4).
func TestQueryRowCoreByteSlice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	tests := []struct {
		sql      string
		queryArg interface{}
		expected []byte
	}{
		{"select $1::text", "Jack", []byte("Jack")},
		{"select $1::text", []byte("Jack"), []byte("Jack")},
		// For int4 the scan yields the 4-byte big-endian binary encoding.
		{"select $1::int4", int32(239023409), []byte{14, 63, 53, 49}},
		{"select $1::varchar", []byte("Jack"), []byte("Jack")},
		{"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}},
	}

	for i, tt := range tests {
		var actual []byte

		err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
		}

		if !bytes.Equal(actual, tt.expected) {
			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
		}

		ensureConnValid(t, conn)
	}
}

// TestQueryRowByteSliceArgument verifies the reverse direction: a []byte
// argument carrying the binary encoding of an int4 is accepted and decodes
// back to the expected int32 value.
func TestQueryRowByteSliceArgument(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	sql := "select $1::int4"
	queryArg := []byte{14, 63, 53, 49} // big-endian encoding of 239023409
	expected := int32(239023409)

	var actual int32

	err := conn.QueryRow(sql, queryArg).Scan(&actual)
	if err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
	}

	if expected != actual {
		t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowUnknownType verifies that a type pgx has no codec for (point)
// round-trips through its text representation.
func TestQueryRowUnknownType(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	sql := "select $1::point"
	expected := "(1,0)"
	var actual string

	err := conn.QueryRow(sql, expected).Scan(&actual)
	if err != nil {
		t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
	}

	if actual != expected {
		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)

	}

	ensureConnValid(t, conn)
}
+
// TestQueryRowErrors verifies the error messages produced for a range of
// failure modes: undeterminable parameter types, unknown SQL types, syntax
// errors, and encode/decode type mismatches.
func TestQueryRowErrors(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	type allTypes struct {
		i16 int16
		i   int
		s   string
	}

	var actual, zero allTypes

	tests := []struct {
		sql       string
		queryArgs []interface{}
		scanArgs  []interface{}
		err       string // substring expected in the returned error
	}{
		{"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"},
		{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
		{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
		{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into any integer type"},
		{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode int8 into oid 600"},
	}

	for i, tt := range tests {
		actual = zero

		err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
		if err == nil {
			t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
		}
		if err != nil && !strings.Contains(err.Error(), tt.err) {
			t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
		}

		ensureConnValid(t, conn)
	}
}

// TestQueryRowNoResults verifies that scanning an empty result set returns
// the pgx.ErrNoRows sentinel.
func TestQueryRowNoResults(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var n int32
	err := conn.QueryRow("select 1 where 1=0").Scan(&n)
	if err != pgx.ErrNoRows {
		t.Errorf("Expected pgx.ErrNoRows, got %v", err)
	}

	ensureConnValid(t, conn)
}
+
// TestQueryRowCoreInt16Slice verifies round-tripping of []int16 through
// int2[] (including the empty array) and that a NULL element in the array
// fails rather than being silently decoded into a non-nullable element type.
func TestQueryRowCoreInt16Slice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []int16

	tests := []struct {
		sql      string
		expected []int16
	}{
		{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
		{"select $1::int2[]", []int16{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowCoreInt32Slice is the []int32 / int4[] variant of the slice
// round-trip and NULL-element checks above.
func TestQueryRowCoreInt32Slice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []int32

	tests := []struct {
		sql      string
		expected []int32
	}{
		{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
		{"select $1::int4[]", []int32{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowCoreInt64Slice is the []int64 / int8[] variant of the slice
// round-trip and NULL-element checks above.
func TestQueryRowCoreInt64Slice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []int64

	tests := []struct {
		sql      string
		expected []int64
	}{
		{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
		{"select $1::int8[]", []int64{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowCoreFloat32Slice is the []float32 / float4[] variant of the
// slice round-trip and NULL-element checks above.
func TestQueryRowCoreFloat32Slice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []float32

	tests := []struct {
		sql      string
		expected []float32
	}{
		{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
		{"select $1::float4[]", []float32{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowCoreFloat64Slice is the []float64 / float8[] variant of the
// slice round-trip and NULL-element checks above.
func TestQueryRowCoreFloat64Slice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []float64

	tests := []struct {
		sql      string
		expected []float64
	}{
		{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
		{"select $1::float8[]", []float64{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}

// TestQueryRowCoreStringSlice is the []string variant, exercised against
// both text[] and varchar[] and including multi-byte UTF-8 content.
func TestQueryRowCoreStringSlice(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var actual []string

	tests := []struct {
		sql      string
		expected []string
	}{
		{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
		{"select $1::text[]", []string{}},
		{"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
		{"select $1::varchar[]", []string{}},
	}

	for i, tt := range tests {
		err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
		if err != nil {
			t.Errorf("%d. Unexpected failure: %v", i, err)
		}

		if len(actual) != len(tt.expected) {
			t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
		}

		for j := 0; j < len(actual); j++ {
			if actual[j] != tt.expected[j] {
				t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
			}
		}

		ensureConnValid(t, conn)
	}

	// Check that Scan errors when an array with a null is scanned into a core slice type
	err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
	if err == nil {
		t.Error("Expected null to cause error when scanned into slice, but it didn't")
	}
	if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
		t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
	}

	ensureConnValid(t, conn)
}
+
// TestReadingValueAfterEmptyArray is a regression test: decoding an empty
// array must not disturb the read position for subsequent columns.
func TestReadingValueAfterEmptyArray(t *testing.T) {
	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var a []string
	var b int32
	err := conn.QueryRow("select '{}'::text[], 42::integer").Scan(&a, &b)
	if err != nil {
		t.Fatalf("conn.QueryRow failed: %v", err)
	}

	if len(a) != 0 {
		t.Errorf("Expected 'a' to have length 0, but it was: %d", len(a))
	}

	if b != 42 {
		t.Errorf("Expected 'b' to 42, but it was: %d", b)
	}
}

// TestReadingNullByteArray verifies that a SQL NULL scanned into []byte
// yields a nil slice.
func TestReadingNullByteArray(t *testing.T) {
	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var a []byte
	err := conn.QueryRow("select null::text").Scan(&a)
	if err != nil {
		t.Fatalf("conn.QueryRow failed: %v", err)
	}

	if a != nil {
		t.Errorf("Expected 'a' to be nil, but it was: %v", a)
	}
}

// TestReadingNullByteArrays verifies the same NULL -> nil []byte behavior
// across multiple rows of a Query result set.
func TestReadingNullByteArrays(t *testing.T) {
	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	rows, err := conn.Query("select null::text union all select null::text")
	if err != nil {
		t.Fatalf("conn.Query failed: %v", err)
	}

	count := 0
	for rows.Next() {
		count++
		var a []byte
		if err := rows.Scan(&a); err != nil {
			t.Fatalf("failed to scan row: %v", err)
		}
		if a != nil {
			t.Errorf("Expected 'a' to be nil, but it was: %v", a)
		}
	}
	if count != 2 {
		t.Errorf("Expected to read 2 rows, read: %d", count)
	}
}
+
+// Use github.com/shopspring/decimal as real-world database/sql custom type
+// to test against.
// TestConnQueryDatabaseSQLScanner checks that a type implementing the
// database/sql Scanner interface (shopspring/decimal) can be used as a
// Scan destination.
func TestConnQueryDatabaseSQLScanner(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	var num decimal.Decimal

	err := conn.QueryRow("select '1234.567'::decimal").Scan(&num)
	if err != nil {
		t.Fatalf("Scan failed: %v", err)
	}

	expected, err := decimal.NewFromString("1234.567")
	if err != nil {
		t.Fatal(err)
	}

	if !num.Equals(expected) {
		t.Errorf("Expected num to be %v, but it was %v", expected, num)
	}

	ensureConnValid(t, conn)
}

// TestConnQueryDatabaseSQLDriverValuer checks the encode direction: a type
// implementing database/sql/driver.Valuer (shopspring/decimal) can be used
// as a query argument.
func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	expected, err := decimal.NewFromString("1234.567")
	if err != nil {
		t.Fatal(err)
	}
	var num decimal.Decimal

	err = conn.QueryRow("select $1::decimal", &expected).Scan(&num)
	if err != nil {
		t.Fatalf("Scan failed: %v", err)
	}

	if !num.Equals(expected) {
		t.Errorf("Expected num to be %v, but it was %v", expected, num)
	}

	ensureConnValid(t, conn)
}

// TestConnQueryDatabaseSQLNullX round-trips the database/sql Null* wrapper
// types (NullBool, NullInt64, NullFloat64, NullString) in both valid and
// NULL states.
func TestConnQueryDatabaseSQLNullX(t *testing.T) {
	t.Parallel()

	conn := mustConnect(t, *defaultConnConfig)
	defer closeConn(t, conn)

	type row struct {
		boolValid    sql.NullBool
		boolNull     sql.NullBool
		int64Valid   sql.NullInt64
		int64Null    sql.NullInt64
		float64Valid sql.NullFloat64
		float64Null  sql.NullFloat64
		stringValid  sql.NullString
		stringNull   sql.NullString
	}

	// The *Null fields are left at their zero value (Valid == false) so
	// they encode as SQL NULL and must come back as NULL.
	expected := row{
		boolValid:    sql.NullBool{Bool: true, Valid: true},
		int64Valid:   sql.NullInt64{Int64: 123, Valid: true},
		float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
		stringValid:  sql.NullString{String: "pgx", Valid: true},
	}

	var actual row

	err := conn.QueryRow(
		"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
		expected.boolValid,
		expected.boolNull,
		expected.int64Valid,
		expected.int64Null,
		expected.float64Valid,
		expected.float64Null,
		expected.stringValid,
		expected.stringNull,
	).Scan(
		&actual.boolValid,
		&actual.boolNull,
		&actual.int64Valid,
		&actual.int64Null,
		&actual.float64Valid,
		&actual.float64Null,
		&actual.stringValid,
		&actual.stringNull,
	)
	if err != nil {
		t.Fatalf("Scan failed: %v", err)
	}

	if expected != actual {
		t.Errorf("Expected %v, but got %v", expected, actual)
	}

	ensureConnValid(t, conn)
}
diff --git a/vendor/github.com/jackc/pgx/replication.go b/vendor/github.com/jackc/pgx/replication.go
new file mode 100644
index 0000000..7b28d6b
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/replication.go
@@ -0,0 +1,429 @@
+package pgx
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "time"
+)
+
const (
	// Message type bytes used on the streaming-replication wire protocol.
	copyBothResponse  = 'W' // server acknowledges start of copy-both mode
	walData           = 'w' // CopyData payload carrying a WAL entry
	senderKeepalive   = 'k' // CopyData payload carrying a server heartbeat
	standbyStatusUpdate = 'r' // CopyData payload carrying a client status update
	// How long StartReplication waits for the server's initial response.
	initialReplicationResponseTimeout = 5 * time.Second
)
+
// epochNano is the PostgreSQL timestamp epoch (2000-01-01 00:00:00 UTC)
// expressed in Unix nanoseconds. Replication wire timestamps are
// microseconds relative to this epoch.
//
// Initialized directly instead of via init(): the value is a pure
// computation, so a package-level initializer is simpler and avoids a
// mutable declare-then-assign in init.
var epochNano = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixNano()
+
// FormatLSN formats a 64-bit LSN value into the XXX/XXX form reported by
// postgres: the upper and lower 32 bits rendered as hex, joined by a slash.
func FormatLSN(lsn uint64) string {
	high := uint32(lsn >> 32)
	low := uint32(lsn)
	return fmt.Sprintf("%X/%X", high, low)
}
+
// ParseLSN parses an LSN in the XXX/XXX form reported by postgres into the
// 64-bit integer representation used internally by the wire protocols.
//
// Fixes: errors.New(fmt.Sprintf(...)) replaced with fmt.Errorf (go vet
// S1028), and the "Failed to parsed" typo in the error message corrected.
func ParseLSN(lsn string) (outputLsn uint64, err error) {
	var upperHalf uint64
	var lowerHalf uint64
	nparsed, err := fmt.Sscanf(lsn, "%X/%X", &upperHalf, &lowerHalf)
	if err != nil {
		return 0, err
	}

	// Sscanf may return nil error while matching fewer items than asked.
	if nparsed != 2 {
		return 0, fmt.Errorf("failed to parse LSN: %s", lsn)
	}

	return (upperHalf << 32) + lowerHalf, nil
}
+
// WalMessage contains a WAL payload entry received from the server.
type WalMessage struct {
	// The WAL start position of this data. This
	// is the WAL position we need to track.
	WalStart uint64
	// The server wal end and server time are
	// documented to track the end position and current
	// time of the server, both of which appear to be
	// unimplemented in pg 9.5.
	ServerWalEnd uint64
	ServerTime   uint64
	// The WAL data is the raw unparsed binary WAL entry.
	// The contents of this are determined by the output
	// logical encoding plugin.
	WalData []byte
}

// Time converts ServerTime (microseconds since the 2000-01-01 postgres
// epoch) to a time.Time.
func (w *WalMessage) Time() time.Time {
	return time.Unix(0, (int64(w.ServerTime)*1000)+epochNano)
}

// ByteLag reports how far this entry trails the server's WAL end position.
func (w *WalMessage) ByteLag() uint64 {
	return (w.ServerWalEnd - w.WalStart)
}

// String renders the WAL position, server time, and byte lag for logging.
func (w *WalMessage) String() string {
	return fmt.Sprintf("Wal: %s Time: %s Lag: %d", FormatLSN(w.WalStart), w.Time(), w.ByteLag())
}
+
// ServerHeartbeat is sent periodically from the server, carrying server
// status and a reply-request flag.
type ServerHeartbeat struct {
	// The current max wal position on the server,
	// used for lag tracking
	ServerWalEnd uint64
	// The server time, in microseconds since jan 1 2000
	ServerTime uint64
	// If 1, the server is requesting a standby status message
	// to be sent immediately.
	ReplyRequested byte
}

// Time converts ServerTime (microseconds since the 2000-01-01 postgres
// epoch) to a time.Time.
func (s *ServerHeartbeat) Time() time.Time {
	return time.Unix(0, (int64(s.ServerTime)*1000)+epochNano)
}

// String renders the heartbeat fields for logging.
func (s *ServerHeartbeat) String() string {
	return fmt.Sprintf("WalEnd: %s ReplyRequested: %d T: %s", FormatLSN(s.ServerWalEnd), s.ReplyRequested, s.Time())
}
+
// ReplicationMessage wraps all possible messages received from the server
// during replication. At most one of WalMessage or ServerHeartbeat is
// non-nil.
type ReplicationMessage struct {
	WalMessage      *WalMessage
	ServerHeartbeat *ServerHeartbeat
}
+
// StandbyStatus is the client-side heartbeat sent to the postgresql server
// to report the client's WAL positions. For practical purposes, all wal
// positions are typically set to the same value.
type StandbyStatus struct {
	// The WAL position that's been locally written
	WalWritePosition uint64
	// The WAL position that's been locally flushed
	WalFlushPosition uint64
	// The WAL position that's been locally applied
	WalApplyPosition uint64
	// The client time in microseconds since jan 1 2000
	ClientTime uint64
	// If 1, requests the server to immediately send a
	// server heartbeat
	ReplyRequested byte
}
+
+// Create a standby status struct, which sets all the WAL positions
+// to the given wal position, and the client time to the current time.
+// The wal positions are, in order:
+// WalFlushPosition
+// WalApplyPosition
+// WalWritePosition
+//
+// If only one position is provided, it will be used as the value for all 3
+// status fields. Note you must provide either 1 wal position, or all 3
+// in order to initialize the standby status.
+func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) {
+ if len(walPositions) == 1 {
+ status = new(StandbyStatus)
+ status.WalFlushPosition = walPositions[0]
+ status.WalApplyPosition = walPositions[0]
+ status.WalWritePosition = walPositions[0]
+ } else if len(walPositions) == 3 {
+ status = new(StandbyStatus)
+ status.WalFlushPosition = walPositions[0]
+ status.WalApplyPosition = walPositions[1]
+ status.WalWritePosition = walPositions[2]
+ } else {
+ err = errors.New(fmt.Sprintf("Invalid number of wal positions provided, need 1 or 3, got %d", len(walPositions)))
+ return
+ }
+ status.ClientTime = uint64((time.Now().UnixNano() - epochNano) / 1000)
+ return
+}
+
+func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
+ if config.RuntimeParams == nil {
+ config.RuntimeParams = make(map[string]string)
+ }
+ config.RuntimeParams["replication"] = "database"
+
+ c, err := Connect(config)
+ if err != nil {
+ return
+ }
+ return &ReplicationConn{c: c}, nil
+}
+
// ReplicationConn wraps a Conn opened in replication mode (see
// ReplicationConnect) and exposes the streaming-replication commands.
type ReplicationConn struct {
	c *Conn
}
+
// SendStandbyStatus sends a standby status to the server, which both acts
// as a keepalive message and carries the WAL position of the client, which
// then updates the server's replication slot position.
//
// The field order written below (write, flush, apply, time, reply flag) is
// the wire format of the 'r' standby status update message and must not be
// reordered.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
	writeBuf := newWriteBuf(rc.c, copyData)
	writeBuf.WriteByte(standbyStatusUpdate)
	writeBuf.WriteInt64(int64(k.WalWritePosition))
	writeBuf.WriteInt64(int64(k.WalFlushPosition))
	writeBuf.WriteInt64(int64(k.WalApplyPosition))
	writeBuf.WriteInt64(int64(k.ClientTime))
	writeBuf.WriteByte(k.ReplyRequested)

	writeBuf.closeMsg()

	// A failed write kills the connection: partial protocol messages
	// leave the stream in an unrecoverable state.
	_, err = rc.c.conn.Write(writeBuf.buf)
	if err != nil {
		rc.c.die(err)
	}

	return
}
+
// Close closes the underlying connection.
func (rc *ReplicationConn) Close() error {
	return rc.c.Close()
}

// IsAlive reports whether the underlying connection is usable.
func (rc *ReplicationConn) IsAlive() bool {
	return rc.c.IsAlive()
}

// CauseOfDeath returns the error that killed the underlying connection,
// if any.
func (rc *ReplicationConn) CauseOfDeath() error {
	return rc.c.CauseOfDeath()
}
+
+func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
+ var t byte
+ var reader *msgReader
+ t, reader, err = rc.c.rxMsg()
+ if err != nil {
+ return
+ }
+
+ switch t {
+ case noticeResponse:
+ pgError := rc.c.rxErrorResponse(reader)
+ if rc.c.shouldLog(LogLevelInfo) {
+ rc.c.log(LogLevelInfo, pgError.Error())
+ }
+ case errorResponse:
+ err = rc.c.rxErrorResponse(reader)
+ if rc.c.shouldLog(LogLevelError) {
+ rc.c.log(LogLevelError, err.Error())
+ }
+ return
+ case copyBothResponse:
+ // This is the tail end of the replication process start,
+ // and can be safely ignored
+ return
+ case copyData:
+ var msgType byte
+ msgType = reader.readByte()
+ switch msgType {
+ case walData:
+ walStart := reader.readInt64()
+ serverWalEnd := reader.readInt64()
+ serverTime := reader.readInt64()
+ walData := reader.readBytes(reader.msgBytesRemaining)
+ walMessage := WalMessage{WalStart: uint64(walStart),
+ ServerWalEnd: uint64(serverWalEnd),
+ ServerTime: uint64(serverTime),
+ WalData: walData,
+ }
+
+ return &ReplicationMessage{WalMessage: &walMessage}, nil
+ case senderKeepalive:
+ serverWalEnd := reader.readInt64()
+ serverTime := reader.readInt64()
+ replyNow := reader.readByte()
+ h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow}
+ return &ReplicationMessage{ServerHeartbeat: h}, nil
+ default:
+ if rc.c.shouldLog(LogLevelError) {
+ rc.c.log(LogLevelError, "Unexpected data playload message type %v", t)
+ }
+ }
+ default:
+ if rc.c.shouldLog(LogLevelError) {
+ rc.c.log(LogLevelError, "Unexpected replication message type %v", t)
+ }
+ }
+ return
+}
+
// WaitForReplicationMessage waits for a single replication message up to
// timeout time.
//
// Properly using this requires some knowledge of the postgres replication
// mechanisms, as the client can receive both WAL data (the ultimate payload)
// and server heartbeat updates. The caller also must send standby status
// updates in order to keep the connection alive and working.
//
// This returns pgx.ErrNotificationTimeout when there is no replication
// message by the specified duration.
func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) {
	var zeroTime time.Time

	deadline := time.Now().Add(timeout)

	// Use SetReadDeadline to implement the timeout. SetReadDeadline will
	// cause operations to fail with a *net.OpError that has a Timeout()
	// of true. Because the normal pgx rxMsg path considers any error to
	// have potentially corrupted the state of the connection, it dies
	// on any errors. So to avoid timeout errors in rxMsg we set the
	// deadline and peek into the reader. If a timeout error occurs there
	// we don't break the pgx connection. If the Peek returns that data
	// is available then we turn off the read deadline before the rxMsg.
	err = rc.c.conn.SetReadDeadline(deadline)
	if err != nil {
		return nil, err
	}

	// Wait until there is a byte available before continuing onto the normal msg reading path
	_, err = rc.c.reader.Peek(1)
	if err != nil {
		rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possible error from SetReadDeadline
		if err, ok := err.(*net.OpError); ok && err.Timeout() {
			return nil, ErrNotificationTimeout
		}
		return nil, err
	}

	// Data is available: clear the deadline so rxMsg cannot hit a
	// spurious timeout mid-message.
	err = rc.c.conn.SetReadDeadline(zeroTime)
	if err != nil {
		return nil, err
	}

	return rc.readReplicationMessage()
}
+
+func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
+ rc.c.lastActivityTime = time.Now()
+
+ rows := rc.c.getRows(sql, nil)
+
+ if err := rc.c.lock(); err != nil {
+ rows.abort(err)
+ return rows, err
+ }
+ rows.unlockConn = true
+
+ err := rc.c.sendSimpleQuery(sql)
+ if err != nil {
+ rows.abort(err)
+ }
+
+ var t byte
+ var r *msgReader
+ t, r, err = rc.c.rxMsg()
+ if err != nil {
+ return nil, err
+ }
+
+ switch t {
+ case rowDescription:
+ rows.fields = rc.c.rxRowDescription(r)
+ // We don't have c.PgTypes here because we're a replication
+ // connection. This means the field descriptions will have
+ // only Oids. Not much we can do about this.
+ default:
+ if e := rc.c.processContextFreeMsg(t, r); e != nil {
+ rows.abort(e)
+ return rows, e
+ }
+ }
+
+ return rows, rows.err
+}
+
// IdentifySystem executes the "IDENTIFY_SYSTEM" command as documented here:
// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
//
// This will return (if successful) a result set that has a single row
// that contains the systemid, current timeline, xlogpos and database
// name.
//
// NOTE: Because this is a replication mode connection, we don't have
// type names, so the field descriptions in the result will have only
// Oids and no DataTypeName values
func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) {
	return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM")
}

// TimelineHistory executes the "TIMELINE_HISTORY" command as documented here:
// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
//
// This will return (if successful) a result set that has a single row
// that contains the filename of the history file and the content
// of the history file. If called for timeline 1, typically this will
// generate an error that the timeline history file does not exist.
//
// NOTE: Because this is a replication mode connection, we don't have
// type names, so the field descriptions in the result will have only
// Oids and no DataTypeName values
func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) {
	return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline))
}
+
+// Start a replication connection, sending WAL data to the given replication
+// receiver. This function wraps a START_REPLICATION command as documented
+// here:
+// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
+//
+// Once started, the client needs to invoke WaitForReplicationMessage() in order
+// to fetch the WAL and standby status. Also, it is the responsibility of the caller
+// to periodically send StandbyStatus messages to update the replication slot position.
+//
+// This function assumes that slotName has already been created. In order to omit the timeline argument
+// pass a -1 for the timeline to get the server default behavior.
+func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) {
+ var queryString string
+ if timeline >= 0 {
+ queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s TIMELINE %d", slotName, FormatLSN(startLsn), timeline)
+ } else {
+ queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s", slotName, FormatLSN(startLsn))
+ }
+
+ for _, arg := range pluginArguments {
+ queryString += fmt.Sprintf(" %s", arg)
+ }
+
+ if err = rc.c.sendQuery(queryString); err != nil {
+ return
+ }
+
+ // The first replication message that comes back here will be (in a success case)
+ // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has
+ // started. This call will either return nil, nil or if it returns an error
+ // that indicates the start replication command failed
+ var r *ReplicationMessage
+ r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout)
+ if err != nil && r != nil {
+ if rc.c.shouldLog(LogLevelError) {
+ rc.c.log(LogLevelError, "Unxpected replication message %v", r)
+ }
+ }
+
+ return
+}
+
// CreateReplicationSlot creates a logical replication slot with the given
// name and logical-decoding output plugin. Fails if the slot already exists.
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
	_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin))
	return
}

// DropReplicationSlot drops the replication slot with the given name.
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
	_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
	return
}
diff --git a/vendor/github.com/jackc/pgx/replication_test.go b/vendor/github.com/jackc/pgx/replication_test.go
new file mode 100644
index 0000000..4f810c7
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/replication_test.go
@@ -0,0 +1,329 @@
+package pgx_test
+
+import (
+ "fmt"
+ "github.com/jackc/pgx"
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// This function uses a PostgreSQL 9.6-specific column
+func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string {
+ // Fetch the restart LSN of the slot, to establish a starting point
+ rows, err := conn.Query(fmt.Sprintf("select confirmed_flush_lsn from pg_replication_slots where slot_name='%s'", slot))
+ if err != nil {
+ t.Fatalf("conn.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ var restartLsn string
+ for rows.Next() {
+ rows.Scan(&restartLsn)
+ }
+ return restartLsn
+}
+
+// This battleship test (at least somewhat by necessity) does
+// several things all at once in a single run. It:
+// - Establishes a replication connection & slot
+// - Does a series of operations to create some known WAL entries
+// - Replicates the entries down, and checks that the rows it
+// created come down in order
+// - Sends a standby status message to update the server with the
+// wal position of the slot
+// - Checks the wal position of the slot on the server to make sure
+// the update succeeded
+func TestSimpleReplicationConnection(t *testing.T) {
+ t.Parallel()
+
+ var err error
+
+ if replicationConnConfig == nil {
+ t.Skip("Skipping due to undefined replicationConnConfig")
+ }
+
+ conn := mustConnect(t, *replicationConnConfig)
+ defer closeConn(t, conn)
+
+ replicationConn := mustReplicationConnect(t, *replicationConnConfig)
+ defer closeReplicationConn(t, replicationConn)
+
+ err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding")
+ if err != nil {
+ t.Logf("replication slot create failed: %v", err)
+ }
+
+ // Do a simple change so we can get some wal data
+ _, err = conn.Exec("create table if not exists replication_test (a integer)")
+ if err != nil {
+ t.Fatalf("Failed to create table: %v", err)
+ }
+
+ err = replicationConn.StartReplication("pgx_test", 0, -1)
+ if err != nil {
+ t.Fatalf("Failed to start replication: %v", err)
+ }
+
+ var i int32
+ var insertedTimes []int64
+ for i < 5 {
+ var ct pgx.CommandTag
+ currentTime := time.Now().Unix()
+ insertedTimes = append(insertedTimes, currentTime)
+ ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime)
+ if err != nil {
+ t.Fatalf("Insert failed: %v", err)
+ }
+ t.Logf("Inserted %d rows", ct.RowsAffected())
+ i++
+ }
+
+ i = 0
+ var foundTimes []int64
+ var foundCount int
+ var maxWal uint64
+ for {
+ var message *pgx.ReplicationMessage
+
+ message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second))
+ if err != nil {
+ if err != pgx.ErrNotificationTimeout {
+ t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err))
+ }
+ }
+ if message != nil {
+ if message.WalMessage != nil {
+ // The waldata payload with the test_decoding plugin looks like:
+ // public.replication_test: INSERT: a[integer]:2
+ // What we want to do here is check that once we find one of our inserted times,
+ // that they occur in the wal stream in the order we executed them.
+ walString := string(message.WalMessage.WalData)
+ if strings.Contains(walString, "public.replication_test: INSERT") {
+ stringParts := strings.Split(walString, ":")
+ offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64)
+ if err != nil {
+ t.Fatalf("Failed to parse walString %s", walString)
+ }
+ if foundCount > 0 || offset == insertedTimes[0] {
+ foundTimes = append(foundTimes, offset)
+ foundCount++
+ }
+ }
+ if message.WalMessage.WalStart > maxWal {
+ maxWal = message.WalMessage.WalStart
+ }
+
+ }
+ if message.ServerHeartbeat != nil {
+ t.Logf("Got heartbeat: %s", message.ServerHeartbeat)
+ }
+ } else {
+ t.Log("Timed out waiting for wal message")
+ i++
+ }
+ if i > 3 {
+ t.Log("Actual timeout")
+ break
+ }
+ }
+
+ if foundCount != len(insertedTimes) {
+ t.Fatalf("Failed to find all inserted time values in WAL stream (found %d expected %d)", foundCount, len(insertedTimes))
+ }
+
+ for i := range insertedTimes {
+ if foundTimes[i] != insertedTimes[i] {
+ t.Fatalf("Found %d expected %d", foundTimes[i], insertedTimes[i])
+ }
+ }
+
+ t.Logf("Found %d times, as expected", len(foundTimes))
+
+ // Before closing our connection, let's send a standby status to update our wal
+ // position, which should then be reflected if we fetch out our current wal position
+ // for the slot
+ status, err := pgx.NewStandbyStatus(maxWal)
+ if err != nil {
+ t.Errorf("Failed to create standby status %v", err)
+ }
+ replicationConn.SendStandbyStatus(status)
+
+ restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test")
+ integerRestartLsn, _ := pgx.ParseLSN(restartLsn)
+ if integerRestartLsn != maxWal {
+ t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn)
+ }
+
+ closeReplicationConn(t, replicationConn)
+
+ replicationConn2 := mustReplicationConnect(t, *replicationConnConfig)
+ defer closeReplicationConn(t, replicationConn2)
+
+ err = replicationConn2.DropReplicationSlot("pgx_test")
+ if err != nil {
+ t.Fatalf("Failed to drop replication slot: %v", err)
+ }
+
+ droppedLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test")
+ if droppedLsn != "" {
+ t.Errorf("Got odd flush lsn %s for supposedly dropped slot", droppedLsn)
+ }
+}
+
+func TestReplicationConn_DropReplicationSlot(t *testing.T) {
+ if replicationConnConfig == nil {
+ t.Skip("Skipping due to undefined replicationConnConfig")
+ }
+
+ replicationConn := mustReplicationConnect(t, *replicationConnConfig)
+ defer closeReplicationConn(t, replicationConn)
+
+ err := replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding")
+ if err != nil {
+ t.Logf("replication slot create failed: %v", err)
+ }
+ err = replicationConn.DropReplicationSlot("pgx_slot_test")
+ if err != nil {
+ t.Fatalf("Failed to drop replication slot: %v", err)
+ }
+
+ // We re-create to ensure the drop worked.
+ err = replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding")
+ if err != nil {
+ t.Logf("replication slot create failed: %v", err)
+ }
+
+ // And finally we drop to ensure we don't leave dirty state
+ err = replicationConn.DropReplicationSlot("pgx_slot_test")
+ if err != nil {
+ t.Fatalf("Failed to drop replication slot: %v", err)
+ }
+}
+
+func TestIdentifySystem(t *testing.T) {
+ if replicationConnConfig == nil {
+ t.Skip("Skipping due to undefined replicationConnConfig")
+ }
+
+ replicationConn2 := mustReplicationConnect(t, *replicationConnConfig)
+ defer closeReplicationConn(t, replicationConn2)
+
+ r, err := replicationConn2.IdentifySystem()
+ if err != nil {
+ t.Error(err)
+ }
+ defer r.Close()
+ for _, fd := range r.FieldDescriptions() {
+ t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
+ }
+
+ var rowCount int
+ for r.Next() {
+ rowCount++
+ values, err := r.Values()
+ if err != nil {
+ t.Error(err)
+ }
+ t.Logf("Row values: %v", values)
+ }
+ if r.Err() != nil {
+ t.Error(r.Err())
+ }
+
+ if rowCount == 0 {
+ t.Errorf("Failed to find any rows: %d", rowCount)
+ }
+}
+
+func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int {
+ r, err := rc.IdentifySystem()
+ if err != nil {
+ t.Error(err)
+ }
+ defer r.Close()
+ for r.Next() {
+ values, e := r.Values()
+ if e != nil {
+ t.Error(e)
+ }
+ timeline, e := strconv.Atoi(values[1].(string))
+ if e != nil {
+ t.Error(e)
+ }
+ return timeline
+ }
+ t.Fatal("Failed to read timeline")
+ return -1
+}
+
+func TestGetTimelineHistory(t *testing.T) {
+ if replicationConnConfig == nil {
+ t.Skip("Skipping due to undefined replicationConnConfig")
+ }
+
+ replicationConn := mustReplicationConnect(t, *replicationConnConfig)
+ defer closeReplicationConn(t, replicationConn)
+
+ timeline := getCurrentTimeline(t, replicationConn)
+
+ r, err := replicationConn.TimelineHistory(timeline)
+ if err != nil {
+ t.Errorf("%#v", err)
+ }
+ defer r.Close()
+
+ for _, fd := range r.FieldDescriptions() {
+ t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
+ }
+
+ var rowCount int
+ for r.Next() {
+ rowCount++
+ values, err := r.Values()
+ if err != nil {
+ t.Error(err)
+ }
+ t.Logf("Row values: %v", values)
+ }
+ if r.Err() != nil {
+ if strings.Contains(r.Err().Error(), "No such file or directory") {
+ // This is normal, this means the timeline we're on has no
+ // history, which is the common case in a test db that
+ // has only one timeline
+ return
+ }
+ t.Error(r.Err())
+ }
+
+ // If we have a timeline history (see above) there should have been
+ // rows emitted
+ if rowCount == 0 {
+ t.Errorf("Failed to find any rows: %d", rowCount)
+ }
+}
+
+func TestStandbyStatusParsing(t *testing.T) {
+ // Let's push the boundary conditions of the standby status and ensure it errors correctly
+ status, err := pgx.NewStandbyStatus(0, 1, 2, 3, 4)
+ if err == nil {
+ t.Errorf("Expected error from new standby status, got %v", status)
+ }
+
+ // And if you provide 3 args, ensure the right fields are set
+ status, err = pgx.NewStandbyStatus(1, 2, 3)
+ if err != nil {
+ t.Errorf("Failed to create test status: %v", err)
+ }
+ if status.WalFlushPosition != 1 {
+ t.Errorf("Unexpected flush position %d", status.WalFlushPosition)
+ }
+ if status.WalApplyPosition != 2 {
+ t.Errorf("Unexpected apply position %d", status.WalApplyPosition)
+ }
+ if status.WalWritePosition != 3 {
+ t.Errorf("Unexpected write position %d", status.WalWritePosition)
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/sql.go b/vendor/github.com/jackc/pgx/sql.go
new file mode 100644
index 0000000..7ee0f2a
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/sql.go
@@ -0,0 +1,29 @@
+package pgx
+
+import (
+ "strconv"
+)
+
+// QueryArgs is a container for arguments to an SQL query. It is helpful when
+// building SQL statements where the number of arguments is variable.
+type QueryArgs []interface{}
+
+var placeholders []string
+
+func init() {
+ placeholders = make([]string, 64)
+
+ for i := 1; i < 64; i++ {
+ placeholders[i] = "$" + strconv.Itoa(i)
+ }
+}
+
+// Append adds a value to qa and returns the placeholder value for the
+// argument. e.g. $1, $2, etc.
+func (qa *QueryArgs) Append(v interface{}) string {
+ *qa = append(*qa, v)
+ if len(*qa) < len(placeholders) {
+ return placeholders[len(*qa)]
+ }
+ return "$" + strconv.Itoa(len(*qa))
+}
diff --git a/vendor/github.com/jackc/pgx/sql_test.go b/vendor/github.com/jackc/pgx/sql_test.go
new file mode 100644
index 0000000..dd03603
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/sql_test.go
@@ -0,0 +1,36 @@
+package pgx_test
+
+import (
+ "strconv"
+ "testing"
+
+ "github.com/jackc/pgx"
+)
+
+func TestQueryArgs(t *testing.T) {
+ var qa pgx.QueryArgs
+
+ for i := 1; i < 512; i++ {
+ expectedPlaceholder := "$" + strconv.Itoa(i)
+ placeholder := qa.Append(i)
+ if placeholder != expectedPlaceholder {
+ t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder)
+ }
+ }
+}
+
+func BenchmarkQueryArgs(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ qa := pgx.QueryArgs(make([]interface{}, 0, 16))
+ qa.Append("foo1")
+ qa.Append("foo2")
+ qa.Append("foo3")
+ qa.Append("foo4")
+ qa.Append("foo5")
+ qa.Append("foo6")
+ qa.Append("foo7")
+ qa.Append("foo8")
+ qa.Append("foo9")
+ qa.Append("foo10")
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/stdlib/sql.go b/vendor/github.com/jackc/pgx/stdlib/sql.go
new file mode 100644
index 0000000..8c78cd3
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/stdlib/sql.go
@@ -0,0 +1,348 @@
+// Package stdlib is the compatibility layer from pgx to database/sql.
+//
+// A database/sql connection can be established through sql.Open.
+//
+// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
+// if err != nil {
+// return err
+// }
+//
+// Or from a DSN string.
+//
+// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
+// if err != nil {
+// return err
+// }
+//
+// Or a normal pgx connection pool can be established and the database/sql
+// connection can be created through stdlib.OpenFromConnPool(). This allows
+// more control over the connection process (such as TLS), more control
+// over the connection pool, setting an AfterConnect hook, and using both
+// database/sql and pgx interfaces as needed.
+//
+// connConfig := pgx.ConnConfig{
+// Host: "localhost",
+// User: "pgx_md5",
+// Password: "secret",
+// Database: "pgx_test",
+// }
+//
+// config := pgx.ConnPoolConfig{ConnConfig: connConfig}
+// pool, err := pgx.NewConnPool(config)
+// if err != nil {
+// return err
+// }
+//
+// db, err := stdlib.OpenFromConnPool(pool)
+// if err != nil {
+// t.Fatalf("Unable to create connection pool: %v", err)
+// }
+//
+// If the database/sql connection is established through
+// stdlib.OpenFromConnPool then access to a pgx *ConnPool can be regained
+// through db.Driver(). This allows writing a fast path for pgx while
+// preserving compatibility with other drivers and databases.
+//
+// if driver, ok := db.Driver().(*stdlib.Driver); ok && driver.Pool != nil {
+// // fast path with pgx
+// } else {
+// // normal path for other drivers and databases
+// }
+package stdlib
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "sync"
+
+ "github.com/jackc/pgx"
+)
+
+var (
+ openFromConnPoolCountMu sync.Mutex
+ openFromConnPoolCount int
+)
+
+// oids that map to intrinsic database/sql types. These will be allowed to be
+// binary, anything else will be forced to text format
+var databaseSqlOids map[pgx.Oid]bool
+
+func init() {
+ d := &Driver{}
+ sql.Register("pgx", d)
+
+ databaseSqlOids = make(map[pgx.Oid]bool)
+ databaseSqlOids[pgx.BoolOid] = true
+ databaseSqlOids[pgx.ByteaOid] = true
+ databaseSqlOids[pgx.Int2Oid] = true
+ databaseSqlOids[pgx.Int4Oid] = true
+ databaseSqlOids[pgx.Int8Oid] = true
+ databaseSqlOids[pgx.Float4Oid] = true
+ databaseSqlOids[pgx.Float8Oid] = true
+ databaseSqlOids[pgx.DateOid] = true
+ databaseSqlOids[pgx.TimestampTzOid] = true
+ databaseSqlOids[pgx.TimestampOid] = true
+}
+
+type Driver struct {
+ Pool *pgx.ConnPool
+}
+
+func (d *Driver) Open(name string) (driver.Conn, error) {
+ if d.Pool != nil {
+ conn, err := d.Pool.Acquire()
+ if err != nil {
+ return nil, err
+ }
+
+ return &Conn{conn: conn, pool: d.Pool}, nil
+ }
+
+ connConfig, err := pgx.ParseConnectionString(name)
+ if err != nil {
+ return nil, err
+ }
+
+ conn, err := pgx.Connect(connConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ c := &Conn{conn: conn}
+ return c, nil
+}
+
+// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
+// with pool as the backend. This enables full control over the connection
+// process and configuration while maintaining compatibility with the
+// database/sql interface. In addition, by calling Driver() on the returned
+// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can
+// be reacquired later. This allows fast paths targeting pgx to be used while
+// still maintaining compatibility with other databases and drivers.
+//
+// pool connection size must be at least 3 (the pool is rejected if its maximum is 2 or fewer connections).
+func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) {
+ d := &Driver{Pool: pool}
+
+ openFromConnPoolCountMu.Lock()
+ name := fmt.Sprintf("pgx-%d", openFromConnPoolCount)
+ openFromConnPoolCount++
+ openFromConnPoolCountMu.Unlock()
+
+ sql.Register(name, d)
+ db, err := sql.Open(name, "")
+ if err != nil {
+ return nil, err
+ }
+
+ // Presumably OpenFromConnPool is being used because the user wants to use
+ // database/sql most of the time, but fast path with pgx some of the time.
+ // Allow database/sql to use all the connections, but release 2 idle ones.
+ // Don't have database/sql immediately release all idle connections because
+ // that would mean that prepared statements would be lost (which kills
+ // performance if the prepared statements constantly have to be reprepared)
+ stat := pool.Stat()
+
+ if stat.MaxConnections <= 2 {
+ return nil, errors.New("pool connection size must be at least 3")
+ }
+ db.SetMaxIdleConns(stat.MaxConnections - 2)
+ db.SetMaxOpenConns(stat.MaxConnections)
+
+ return db, nil
+}
+
+type Conn struct {
+ conn *pgx.Conn
+ pool *pgx.ConnPool
+ psCount int64 // Counter used for creating unique prepared statement names
+}
+
+func (c *Conn) Prepare(query string) (driver.Stmt, error) {
+ if !c.conn.IsAlive() {
+ return nil, driver.ErrBadConn
+ }
+
+ name := fmt.Sprintf("pgx_%d", c.psCount)
+ c.psCount++
+
+ ps, err := c.conn.Prepare(name, query)
+ if err != nil {
+ return nil, err
+ }
+
+ restrictBinaryToDatabaseSqlTypes(ps)
+
+ return &Stmt{ps: ps, conn: c}, nil
+}
+
+func (c *Conn) Close() error {
+ err := c.conn.Close()
+ if c.pool != nil {
+ c.pool.Release(c.conn)
+ }
+
+ return err
+}
+
+func (c *Conn) Begin() (driver.Tx, error) {
+ if !c.conn.IsAlive() {
+ return nil, driver.ErrBadConn
+ }
+
+ _, err := c.conn.Exec("begin")
+ if err != nil {
+ return nil, err
+ }
+
+ return &Tx{conn: c.conn}, nil
+}
+
+func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
+ if !c.conn.IsAlive() {
+ return nil, driver.ErrBadConn
+ }
+
+ args := valueToInterface(argsV)
+ commandTag, err := c.conn.Exec(query, args...)
+ return driver.RowsAffected(commandTag.RowsAffected()), err
+}
+
+func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
+ if !c.conn.IsAlive() {
+ return nil, driver.ErrBadConn
+ }
+
+ ps, err := c.conn.Prepare("", query)
+ if err != nil {
+ return nil, err
+ }
+
+ restrictBinaryToDatabaseSqlTypes(ps)
+
+ return c.queryPrepared("", argsV)
+}
+
+func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
+ if !c.conn.IsAlive() {
+ return nil, driver.ErrBadConn
+ }
+
+ args := valueToInterface(argsV)
+
+ rows, err := c.conn.Query(name, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Rows{rows: rows}, nil
+}
+
+// Anything that isn't a database/sql compatible type needs to be forced to
+// text format so that pgx.Rows.Values doesn't decode it into a native type
+// (e.g. []int32)
+func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
+ for i := range ps.FieldDescriptions {
+ intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType]
+ if !intrinsic {
+ ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
+ }
+ }
+}
+
+type Stmt struct {
+ ps *pgx.PreparedStatement
+ conn *Conn
+}
+
+func (s *Stmt) Close() error {
+ return s.conn.conn.Deallocate(s.ps.Name)
+}
+
+func (s *Stmt) NumInput() int {
+ return len(s.ps.ParameterOids)
+}
+
+func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
+ return s.conn.Exec(s.ps.Name, argsV)
+}
+
+func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
+ return s.conn.queryPrepared(s.ps.Name, argsV)
+}
+
+// TODO - rename to avoid alloc
+type Rows struct {
+ rows *pgx.Rows
+}
+
+func (r *Rows) Columns() []string {
+ fieldDescriptions := r.rows.FieldDescriptions()
+ names := make([]string, 0, len(fieldDescriptions))
+ for _, fd := range fieldDescriptions {
+ names = append(names, fd.Name)
+ }
+ return names
+}
+
+func (r *Rows) Close() error {
+ r.rows.Close()
+ return nil
+}
+
+func (r *Rows) Next(dest []driver.Value) error {
+ more := r.rows.Next()
+ if !more {
+ if r.rows.Err() == nil {
+ return io.EOF
+ } else {
+ return r.rows.Err()
+ }
+ }
+
+ values, err := r.rows.Values()
+ if err != nil {
+ return err
+ }
+
+ if len(dest) < len(values) {
+ fmt.Printf("%d: %#v\n", len(dest), dest)
+ fmt.Printf("%d: %#v\n", len(values), values)
+ return errors.New("expected more values than were received")
+ }
+
+ for i, v := range values {
+ dest[i] = driver.Value(v)
+ }
+
+ return nil
+}
+
+func valueToInterface(argsV []driver.Value) []interface{} {
+ args := make([]interface{}, 0, len(argsV))
+ for _, v := range argsV {
+ if v != nil {
+ args = append(args, v.(interface{}))
+ } else {
+ args = append(args, nil)
+ }
+ }
+ return args
+}
+
+type Tx struct {
+ conn *pgx.Conn
+}
+
+func (t *Tx) Commit() error {
+ _, err := t.conn.Exec("commit")
+ return err
+}
+
+func (t *Tx) Rollback() error {
+ _, err := t.conn.Exec("rollback")
+ return err
+}
diff --git a/vendor/github.com/jackc/pgx/stdlib/sql_test.go b/vendor/github.com/jackc/pgx/stdlib/sql_test.go
new file mode 100644
index 0000000..1455ca1
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/stdlib/sql_test.go
@@ -0,0 +1,691 @@
+package stdlib_test
+
+import (
+ "bytes"
+ "database/sql"
+ "github.com/jackc/pgx"
+ "github.com/jackc/pgx/stdlib"
+ "sync"
+ "testing"
+)
+
+func openDB(t *testing.T) *sql.DB {
+ db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
+ if err != nil {
+ t.Fatalf("sql.Open failed: %v", err)
+ }
+
+ return db
+}
+
+func closeDB(t *testing.T, db *sql.DB) {
+ err := db.Close()
+ if err != nil {
+ t.Fatalf("db.Close unexpectedly failed: %v", err)
+ }
+}
+
+// Do a simple query to ensure the connection is still usable
+func ensureConnValid(t *testing.T, db *sql.DB) {
+ var sum, rowCount int32
+
+ rows, err := db.Query("select generate_series(1,$1)", 10)
+ if err != nil {
+ t.Fatalf("db.Query failed: %v", err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var n int32
+ rows.Scan(&n)
+ sum += n
+ rowCount++
+ }
+
+ if rows.Err() != nil {
+ t.Fatalf("db.Query failed: %v", err)
+ }
+
+ if rowCount != 10 {
+ t.Error("Select called onDataRow wrong number of times")
+ }
+ if sum != 55 {
+ t.Error("Wrong values returned")
+ }
+}
+
+type preparer interface {
+ Prepare(query string) (*sql.Stmt, error)
+}
+
+func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
+ stmt, err := p.Prepare(sql)
+ if err != nil {
+ t.Fatalf("%v Prepare unexpectedly failed: %v", p, err)
+ }
+
+ return stmt
+}
+
+func closeStmt(t *testing.T, stmt *sql.Stmt) {
+ err := stmt.Close()
+ if err != nil {
+ t.Fatalf("stmt.Close unexpectedly failed: %v", err)
+ }
+}
+
+func TestNormalLifeCycle(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
+ defer closeStmt(t, stmt)
+
+ rows, err := stmt.Query(int32(1), int32(10))
+ if err != nil {
+ t.Fatalf("stmt.Query unexpectedly failed: %v", err)
+ }
+
+ rowCount := int64(0)
+
+ for rows.Next() {
+ rowCount++
+
+ var s string
+ var n int64
+ if err := rows.Scan(&s, &n); err != nil {
+ t.Fatalf("rows.Scan unexpectedly failed: %v", err)
+ }
+ if s != "foo" {
+ t.Errorf(`Expected "foo", received "%v"`, s)
+ }
+ if n != rowCount {
+ t.Errorf("Expected %d, received %d", rowCount, n)
+ }
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("rows.Err unexpectedly is: %v", err)
+ }
+ if rowCount != 10 {
+ t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
+ }
+
+ err = rows.Close()
+ if err != nil {
+ t.Fatalf("rows.Close unexpectedly failed: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestSqlOpenDoesNotHavePool(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ driver := db.Driver().(*stdlib.Driver)
+ if driver.Pool != nil {
+ t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did")
+ }
+}
+
+func TestOpenFromConnPool(t *testing.T) {
+ connConfig := pgx.ConnConfig{
+ Host: "127.0.0.1",
+ User: "pgx_md5",
+ Password: "secret",
+ Database: "pgx_test",
+ }
+
+ config := pgx.ConnPoolConfig{ConnConfig: connConfig}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ db, err := stdlib.OpenFromConnPool(pool)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer closeDB(t, db)
+
+ // Can get pgx.ConnPool from driver
+ driver := db.Driver().(*stdlib.Driver)
+ if driver.Pool == nil {
+ t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not")
+ }
+
+ // Normal sql/database still works
+ var n int64
+ err = db.QueryRow("select 1").Scan(&n)
+ if err != nil {
+ t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
+ }
+}
+
+func TestOpenFromConnPoolRace(t *testing.T) {
+ wg := &sync.WaitGroup{}
+ connConfig := pgx.ConnConfig{
+ Host: "127.0.0.1",
+ User: "pgx_md5",
+ Password: "secret",
+ Database: "pgx_test",
+ }
+
+ config := pgx.ConnPoolConfig{ConnConfig: connConfig}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ wg.Add(10)
+ for i := 0; i < 10; i++ {
+ go func() {
+ defer wg.Done()
+ db, err := stdlib.OpenFromConnPool(pool)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer closeDB(t, db)
+
+ // Can get pgx.ConnPool from driver
+ driver := db.Driver().(*stdlib.Driver)
+ if driver.Pool == nil {
+ t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not")
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+func TestStmtExec(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("db.Begin unexpectedly failed: %v", err)
+ }
+
+ createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
+ _, err = createStmt.Exec()
+ if err != nil {
+ t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
+ }
+ closeStmt(t, createStmt)
+
+ insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
+ result, err := insertStmt.Exec("foo")
+ if err != nil {
+ t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
+ }
+
+ n, err := result.RowsAffected()
+ if err != nil {
+ t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
+ }
+ if n != 1 {
+ t.Fatalf("Expected 1, received %d", n)
+ }
+ closeStmt(t, insertStmt)
+
+ if err != nil {
+ t.Fatalf("tx.Commit unexpectedly failed: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestQueryCloseRowsEarly(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
+ defer closeStmt(t, stmt)
+
+ rows, err := stmt.Query(int32(1), int32(10))
+ if err != nil {
+ t.Fatalf("stmt.Query unexpectedly failed: %v", err)
+ }
+
+ // Close rows immediately without having read them
+ err = rows.Close()
+ if err != nil {
+ t.Fatalf("rows.Close unexpectedly failed: %v", err)
+ }
+
+ // Run the query again to ensure the connection and statement are still ok
+ rows, err = stmt.Query(int32(1), int32(10))
+ if err != nil {
+ t.Fatalf("stmt.Query unexpectedly failed: %v", err)
+ }
+
+ rowCount := int64(0)
+
+ for rows.Next() {
+ rowCount++
+
+ var s string
+ var n int64
+ if err := rows.Scan(&s, &n); err != nil {
+ t.Fatalf("rows.Scan unexpectedly failed: %v", err)
+ }
+ if s != "foo" {
+ t.Errorf(`Expected "foo", received "%v"`, s)
+ }
+ if n != rowCount {
+ t.Errorf("Expected %d, received %d", rowCount, n)
+ }
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("rows.Err unexpectedly is: %v", err)
+ }
+ if rowCount != 10 {
+ t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
+ }
+
+ err = rows.Close()
+ if err != nil {
+ t.Fatalf("rows.Close unexpectedly failed: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestConnExec(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ _, err := db.Exec("create temporary table t(a varchar not null)")
+ if err != nil {
+ t.Fatalf("db.Exec unexpectedly failed: %v", err)
+ }
+
+ result, err := db.Exec("insert into t values('hey')")
+ if err != nil {
+ t.Fatalf("db.Exec unexpectedly failed: %v", err)
+ }
+
+ n, err := result.RowsAffected()
+ if err != nil {
+ t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
+ }
+ if n != 1 {
+ t.Fatalf("Expected 1, received %d", n)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestConnQuery(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
+ if err != nil {
+ t.Fatalf("db.Query unexpectedly failed: %v", err)
+ }
+
+ rowCount := int64(0)
+
+ for rows.Next() {
+ rowCount++
+
+ var s string
+ var n int64
+ if err := rows.Scan(&s, &n); err != nil {
+ t.Fatalf("rows.Scan unexpectedly failed: %v", err)
+ }
+ if s != "foo" {
+ t.Errorf(`Expected "foo", received "%v"`, s)
+ }
+ if n != rowCount {
+ t.Errorf("Expected %d, received %d", rowCount, n)
+ }
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("rows.Err unexpectedly is: %v", err)
+ }
+ if rowCount != 10 {
+ t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
+ }
+
+ err = rows.Close()
+ if err != nil {
+ t.Fatalf("rows.Close unexpectedly failed: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+type testLog struct {
+ lvl int
+ msg string
+ ctx []interface{}
+}
+
+type testLogger struct {
+ logs []testLog
+}
+
+func (l *testLogger) Debug(msg string, ctx ...interface{}) {
+ l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Info(msg string, ctx ...interface{}) {
+ l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Warn(msg string, ctx ...interface{}) {
+ l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
+}
+func (l *testLogger) Error(msg string, ctx ...interface{}) {
+ l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
+}
+
+func TestConnQueryLog(t *testing.T) {
+ logger := &testLogger{}
+
+ connConfig := pgx.ConnConfig{
+ Host: "127.0.0.1",
+ User: "pgx_md5",
+ Password: "secret",
+ Database: "pgx_test",
+ Logger: logger,
+ }
+
+ config := pgx.ConnPoolConfig{ConnConfig: connConfig}
+ pool, err := pgx.NewConnPool(config)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer pool.Close()
+
+ db, err := stdlib.OpenFromConnPool(pool)
+ if err != nil {
+ t.Fatalf("Unable to create connection pool: %v", err)
+ }
+ defer closeDB(t, db)
+
+ // clear logs from initial connection
+ logger.logs = []testLog{}
+
+ var n int64
+ err = db.QueryRow("select 1").Scan(&n)
+ if err != nil {
+ t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
+ }
+
+ l := logger.logs[0]
+ if l.msg != "Query" {
+ t.Errorf("Expected to log Query, but got %v", l)
+ }
+
+ if !(l.ctx[0] == "sql" && l.ctx[1] == "select 1") {
+ t.Errorf("Expected to log Query with sql 'select 1', but got %v", l)
+ }
+}
+
+func TestConnQueryNull(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ rows, err := db.Query("select $1::int", nil)
+ if err != nil {
+ t.Fatalf("db.Query unexpectedly failed: %v", err)
+ }
+
+ rowCount := int64(0)
+
+ for rows.Next() {
+ rowCount++
+
+ var n sql.NullInt64
+ if err := rows.Scan(&n); err != nil {
+ t.Fatalf("rows.Scan unexpectedly failed: %v", err)
+ }
+ if n.Valid != false {
+ t.Errorf("Expected n to be null, but it was %v", n)
+ }
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("rows.Err unexpectedly is: %v", err)
+ }
+ if rowCount != 1 {
+ t.Fatalf("Expected to receive 11 rows, instead received %d", rowCount)
+ }
+
+ err = rows.Close()
+ if err != nil {
+ t.Fatalf("rows.Close unexpectedly failed: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestConnQueryRowByteSlice(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ expected := []byte{222, 173, 190, 239}
+ var actual []byte
+
+ err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
+ if err != nil {
+ t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
+ }
+
+ if bytes.Compare(actual, expected) != 0 {
+ t.Fatalf("Expected %v, but got %v", expected, actual)
+ }
+
+ ensureConnValid(t, db)
+}
+
+func TestConnQueryFailure(t *testing.T) {
+ db := openDB(t)
+ defer closeDB(t, db)
+
+ _, err := db.Query("select 'foo")
+ if _, ok := err.(pgx.PgError); !ok {
+ t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", err)
+ }
+
+ ensureConnValid(t, db)
+}
+
+// Test type that pgx would handle natively in binary, but since it is not a
+// database/sql native type should be passed through as a string
+func TestConnQueryRowPgxBinary(t *testing.T) {
+	db := openDB(t)
+	defer closeDB(t, db)
+
+	sql := "select $1::int4[]"
+	expected := "{1,2,3}"
+	var actual string
+
+	// The int4[] parameter and result travel as their text representation.
+	err := db.QueryRow(sql, expected).Scan(&actual)
+	if err != nil {
+		t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
+	}
+
+	if actual != expected {
+		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
+	}
+
+	ensureConnValid(t, db)
+}
+
+// TestConnQueryRowUnknownType verifies that a type pgx has no codec for
+// (point) round-trips through database/sql as a string.
+func TestConnQueryRowUnknownType(t *testing.T) {
+	db := openDB(t)
+	defer closeDB(t, db)
+
+	sql := "select $1::point"
+	expected := "(1,2)"
+	var actual string
+
+	err := db.QueryRow(sql, expected).Scan(&actual)
+	if err != nil {
+		t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
+	}
+
+	if actual != expected {
+		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql)
+	}
+
+	ensureConnValid(t, db)
+}
+
+// TestConnQueryJSONIntoByteSlice verifies that a json column scans into a
+// []byte. Skipped on servers that lack the json type.
+func TestConnQueryJSONIntoByteSlice(t *testing.T) {
+	db := openDB(t)
+	defer closeDB(t, db)
+
+	if !serverHasJSON(t, db) {
+		t.Skip("Skipping due to server's lack of JSON type")
+	}
+
+	_, err := db.Exec(`
+		create temporary table docs(
+			body json not null
+		);
+
+		insert into docs(body) values('{"foo":"bar"}');
+`)
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	sql := `select * from docs`
+	expected := []byte(`{"foo":"bar"}`)
+	var actual []byte
+
+	err = db.QueryRow(sql).Scan(&actual)
+	if err != nil {
+		t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
+	}
+
+	// bytes.Equal is the idiomatic equality test (staticcheck S1004).
+	if !bytes.Equal(actual, expected) {
+		t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
+	}
+
+	_, err = db.Exec(`drop table docs`)
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	ensureConnValid(t, db)
+}
+
+// TestConnExecInsertByteSliceIntoJSON verifies that a []byte argument can be
+// inserted into a json column and read back unchanged.
+func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
+	db := openDB(t)
+	defer closeDB(t, db)
+
+	if !serverHasJSON(t, db) {
+		t.Skip("Skipping due to server's lack of JSON type")
+	}
+
+	_, err := db.Exec(`
+		create temporary table docs(
+			body json not null
+		);
+`)
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	expected := []byte(`{"foo":"bar"}`)
+
+	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	var actual []byte
+	err = db.QueryRow(`select body from docs`).Scan(&actual)
+	if err != nil {
+		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
+	}
+
+	// bytes.Equal is the idiomatic equality test (staticcheck S1004).
+	if !bytes.Equal(actual, expected) {
+		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
+	}
+
+	_, err = db.Exec(`drop table docs`)
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	ensureConnValid(t, db)
+}
+
+// serverHasJSON reports whether the connected server defines the json type.
+func serverHasJSON(t *testing.T, db *sql.DB) bool {
+	var hasJSON bool
+	if err := db.QueryRow(`select exists(select 1 from pg_type where typname='json')`).Scan(&hasJSON); err != nil {
+		t.Fatalf("db.QueryRow unexpectedly failed: %v", err)
+	}
+	return hasJSON
+}
+
+// TestTransactionLifeCycle exercises Begin/Rollback (no row persists) followed
+// by Begin/Commit (one row persists) through the database/sql interface.
+func TestTransactionLifeCycle(t *testing.T) {
+	db := openDB(t)
+	defer closeDB(t, db)
+
+	_, err := db.Exec("create temporary table t(a varchar not null)")
+	if err != nil {
+		t.Fatalf("db.Exec unexpectedly failed: %v", err)
+	}
+
+	tx, err := db.Begin()
+	if err != nil {
+		t.Fatalf("db.Begin unexpectedly failed: %v", err)
+	}
+
+	_, err = tx.Exec("insert into t values('hi')")
+	if err != nil {
+		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
+	}
+
+	err = tx.Rollback()
+	if err != nil {
+		t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
+	}
+
+	var n int64
+	err = db.QueryRow("select count(*) from t").Scan(&n)
+	if err != nil {
+		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
+	}
+	if n != 0 {
+		t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
+	}
+
+	tx, err = db.Begin()
+	if err != nil {
+		t.Fatalf("db.Begin unexpectedly failed: %v", err)
+	}
+
+	_, err = tx.Exec("insert into t values('hi')")
+	if err != nil {
+		t.Fatalf("tx.Exec unexpectedly failed: %v", err)
+	}
+
+	err = tx.Commit()
+	if err != nil {
+		t.Fatalf("tx.Commit unexpectedly failed: %v", err)
+	}
+
+	err = db.QueryRow("select count(*) from t").Scan(&n)
+	if err != nil {
+		t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
+	}
+	// This assertion follows the commit; the message previously said
+	// "due to rollback", copied from the branch above.
+	if n != 1 {
+		t.Fatalf("Expected 1 row due to commit, instead found %d", n)
+	}
+
+	ensureConnValid(t, db)
+}
diff --git a/vendor/github.com/jackc/pgx/stress_test.go b/vendor/github.com/jackc/pgx/stress_test.go
new file mode 100644
index 0000000..150d13c
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/stress_test.go
@@ -0,0 +1,346 @@
+package pgx_test
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/jackc/fake"
+ "github.com/jackc/pgx"
+)
+
+// execer, queryer, and queryRower describe the subset of the pgx API shared
+// by *pgx.ConnPool and *pgx.Tx, letting the stress-test actions run against
+// either a pool or an open transaction.
+type execer interface {
+	Exec(sql string, arguments ...interface{}) (commandTag pgx.CommandTag, err error)
+}
+type queryer interface {
+	Query(sql string, args ...interface{}) (*pgx.Rows, error)
+}
+type queryRower interface {
+	QueryRow(sql string, args ...interface{}) *pgx.Row
+}
+
+// TestStressConnPool hammers a ConnPool with 16 workers running randomly
+// chosen actions (queries, transactions, notifications, pool resets) until a
+// timer expires or an action fails.
+func TestStressConnPool(t *testing.T) {
+	maxConnections := 8
+	pool := createConnPool(t, maxConnections)
+	defer pool.Close()
+
+	setupStressDB(t, pool)
+
+	actions := []struct {
+		name string
+		fn   func(*pgx.ConnPool, int) error
+	}{
+		{"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }},
+		{"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }},
+		// The "query" and "queryCloseEarly" entries were previously wired to
+		// each other's functions; the labels now match the helpers invoked.
+		{"query", func(p *pgx.ConnPool, n int) error { return query(p, n) }},
+		{"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }},
+		{"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }},
+		{"txInsertRollback", txInsertRollback},
+		{"txInsertCommit", txInsertCommit},
+		{"txMultipleQueries", txMultipleQueries},
+		{"notify", notify},
+		{"listenAndPoolUnlistens", listenAndPoolUnlistens},
+		{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
+		{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
+	}
+
+	var timer *time.Timer
+	if testing.Short() {
+		timer = time.NewTimer(5 * time.Second)
+	} else {
+		timer = time.NewTimer(60 * time.Second)
+	}
+	workerCount := 16
+
+	workChan := make(chan int)
+	doneChan := make(chan struct{})
+	errChan := make(chan error)
+
+	work := func() {
+		for n := range workChan {
+			action := actions[rand.Intn(len(actions))]
+			err := action.fn(pool, n)
+			if err != nil {
+				errChan <- err
+				break
+			}
+		}
+		doneChan <- struct{}{}
+	}
+
+	for i := 0; i < workerCount; i++ {
+		go work()
+	}
+
+	var stop bool
+	for i := 0; !stop; i++ {
+		select {
+		case <-timer.C:
+			stop = true
+		case workChan <- i:
+		case err := <-errChan:
+			close(workChan)
+			t.Fatal(err)
+		}
+	}
+	close(workChan)
+
+	// Wait for every worker to drain and exit before the pool is closed.
+	for i := 0; i < workerCount; i++ {
+		<-doneChan
+	}
+}
+
+// TestStressTLSConnection repeatedly streams large result sets over a TLS
+// connection. Skipped when no TLS config is provided or with -short.
+func TestStressTLSConnection(t *testing.T) {
+	t.Parallel()
+
+	if tlsConnConfig == nil {
+		t.Skip("Skipping due to undefined tlsConnConfig")
+	}
+
+	if testing.Short() {
+		t.Skip("Skipping due to testing -short")
+	}
+
+	conn, err := pgx.Connect(*tlsConnConfig)
+	if err != nil {
+		t.Fatalf("Unable to establish connection: %v", err)
+	}
+	defer conn.Close()
+
+	for i := 0; i < 50; i++ {
+		sql := `select * from generate_series(1, $1)`
+
+		rows, err := conn.Query(sql, 2000000)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// Drain 2M rows; only the final Err matters for this stress test.
+		var n int32
+		for rows.Next() {
+			rows.Scan(&n)
+		}
+
+		if rows.Err() != nil {
+			t.Fatalf("queryCount: %d, Row number: %d. %v", i, n, rows.Err())
+		}
+	}
+}
+
+// setupStressDB (re)creates the widgets table used by the stress actions.
+func setupStressDB(t *testing.T, pool *pgx.ConnPool) {
+	_, err := pool.Exec(`
+		drop table if exists widgets;
+		create table widgets(
+			id serial primary key,
+			name varchar not null,
+			description text,
+			creation_time timestamptz
+		);
+`)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// insertUnprepared inserts one randomly generated widget without using a
+// prepared statement.
+func insertUnprepared(e execer, actionNum int) error {
+	insertSQL := `
+	insert into widgets(name, description, creation_time)
+	values($1, $2, $3)`
+
+	if _, err := e.Exec(insertSQL, fake.ProductName(), fake.Sentences(), time.Now()); err != nil {
+		return err
+	}
+	return nil
+}
+
+// queryRowWithoutParams reads one random widget row; an empty table
+// (pgx.ErrNoRows) is not treated as a failure.
+func queryRowWithoutParams(qr queryRower, actionNum int) error {
+	var (
+		id           int32
+		name         string
+		description  string
+		creationTime time.Time
+	)
+
+	err := qr.QueryRow(`select * from widgets order by random() limit 1`).Scan(&id, &name, &description, &creationTime)
+	if err != nil && err != pgx.ErrNoRows {
+		return err
+	}
+	return nil
+}
+
+// query reads up to 10 random widget rows and drains the result set.
+func query(q queryer, actionNum int) error {
+	rows, err := q.Query(`select * from widgets order by random() limit $1`, 10)
+	if err != nil {
+		return err
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var (
+			id           int32
+			name         string
+			description  string
+			creationTime time.Time
+		)
+		rows.Scan(&id, &name, &description, &creationTime)
+	}
+
+	return rows.Err()
+}
+
+// queryCloseEarly starts a 100-row result set but closes it after reading at
+// most 10 rows, exercising early-Close handling.
+func queryCloseEarly(q queryer, actionNum int) error {
+	rows, err := q.Query(`select * from generate_series(1,$1)`, 100)
+	if err != nil {
+		return err
+	}
+	// The explicit Close below ends the result set early; this deferred Close
+	// is a harmless no-op safety net.
+	defer rows.Close()
+
+	for read := 0; read < 10 && rows.Next(); read++ {
+		var n int32
+		rows.Scan(&n)
+	}
+	rows.Close()
+
+	return rows.Err()
+}
+
+// queryErrorWhileReturningRows runs a query that is expected to fail with a
+// division-by-zero PgError while rows are being returned. A PgError during
+// row iteration is the expected outcome and reports success.
+func queryErrorWhileReturningRows(q queryer, actionNum int) error {
+	// This query should divide by 0 within the first number of rows
+	sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
+
+	rows, err := q.Query(sql)
+	if err != nil {
+		// Previously `return nil`, which silently swallowed a failure to
+		// issue the query at all; every sibling helper propagates this error.
+		return err
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var n int32
+		rows.Scan(&n)
+	}
+
+	if _, ok := rows.Err().(pgx.PgError); ok {
+		return nil
+	}
+	return rows.Err()
+}
+
+// notify sends a NOTIFY on the "stress" channel.
+func notify(pool *pgx.ConnPool, actionNum int) error {
+	_, err := pool.Exec("notify stress")
+	return err
+}
+
+// listenAndPoolUnlistens acquires a connection, LISTENs on "stress", and
+// briefly waits for a notification. A timeout is the normal, successful
+// outcome; releasing the connection back to the pool clears the listen state.
+func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error {
+	conn, err := pool.Acquire()
+	if err != nil {
+		return err
+	}
+	defer pool.Release(conn)
+
+	err = conn.Listen("stress")
+	if err != nil {
+		return err
+	}
+
+	_, err = conn.WaitForNotification(100 * time.Millisecond)
+	if err == pgx.ErrNotificationTimeout {
+		return nil
+	}
+	return err
+}
+
+// poolPrepareUseAndDeallocate prepares a uniquely named statement on the
+// pool, executes it once, verifies the result, and deallocates it.
+func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error {
+	psName := fmt.Sprintf("poolPreparedStatement%d", actionNum)
+
+	if _, err := pool.Prepare(psName, "select $1::text"); err != nil {
+		return err
+	}
+
+	var got string
+	if err := pool.QueryRow(psName, "hello").Scan(&got); err != nil {
+		return err
+	}
+	if got != "hello" {
+		return fmt.Errorf("Prepared statement did not return expected value: %v", got)
+	}
+
+	return pool.Deallocate(psName)
+}
+
+// txInsertRollback inserts one widget inside a transaction and rolls it back.
+func txInsertRollback(pool *pgx.ConnPool, actionNum int) error {
+	tx, err := pool.Begin()
+	if err != nil {
+		return err
+	}
+
+	sql := `
+	insert into widgets(name, description, creation_time)
+	values($1, $2, $3)`
+
+	_, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now())
+	if err != nil {
+		// Roll back so the connection does not return to the pool with an
+		// open transaction (matches txInsertCommit's error handling).
+		tx.Rollback()
+		return err
+	}
+
+	return tx.Rollback()
+}
+
+// txInsertCommit inserts one widget inside a transaction and commits it.
+func txInsertCommit(pool *pgx.ConnPool, actionNum int) error {
+	tx, err := pool.Begin()
+	if err != nil {
+		return err
+	}
+	// Safe after a successful Commit too: Rollback on a closed Tx returns
+	// ErrTxClosed, which is discarded here.
+	defer tx.Rollback()
+
+	insertSQL := `
+	insert into widgets(name, description, creation_time)
+	values($1, $2, $3)`
+
+	if _, err := tx.Exec(insertSQL, fake.ProductName(), fake.Sentences(), time.Now()); err != nil {
+		return err
+	}
+
+	return tx.Commit()
+}
+
+// txMultipleQueries runs up to 20 random read/write actions inside a single
+// transaction. If the deliberately failing action kills the transaction, that
+// is treated as success (signalled via errExpectedTxDeath); any other error
+// fails the run.
+func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
+	tx, err := pool.Begin()
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+
+	errExpectedTxDeath := errors.New("Expected tx death")
+
+	actions := []struct {
+		name string
+		fn   func() error
+	}{
+		{"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }},
+		{"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }},
+		{"query", func() error { return query(tx, actionNum) }},
+		{"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }},
+		{"queryErrorWhileReturningRows", func() error {
+			// queryErrorWhileReturningRows returns nil when it hits the
+			// expected PgError; translate that into the sentinel so the
+			// loop below knows the transaction is now aborted.
+			err := queryErrorWhileReturningRows(tx, actionNum)
+			if err != nil {
+				return err
+			}
+			return errExpectedTxDeath
+		}},
+	}
+
+	for i := 0; i < 20; i++ {
+		action := actions[rand.Intn(len(actions))]
+		err := action.fn()
+		if err == errExpectedTxDeath {
+			return nil
+		} else if err != nil {
+			return err
+		}
+	}
+
+	return tx.Commit()
+}
diff --git a/vendor/github.com/jackc/pgx/tx.go b/vendor/github.com/jackc/pgx/tx.go
new file mode 100644
index 0000000..deb6c01
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/tx.go
@@ -0,0 +1,207 @@
+package pgx
+
+import (
+ "errors"
+ "fmt"
+)
+
+// Transaction isolation levels
+const (
+	Serializable    = "serializable"
+	RepeatableRead  = "repeatable read"
+	ReadCommitted   = "read committed"
+	ReadUncommitted = "read uncommitted"
+)
+
+// Transaction status codes returned by Tx.Status. Zero means the transaction
+// is still open; positive values are successful closes, negative are failures.
+const (
+	TxStatusInProgress      = 0
+	TxStatusCommitFailure   = -1
+	TxStatusRollbackFailure = -2
+	TxStatusCommitSuccess   = 1
+	TxStatusRollbackSuccess = 2
+)
+
+// ErrTxClosed is returned by Tx methods after Commit or Rollback has closed
+// the transaction.
+var ErrTxClosed = errors.New("tx is closed")
+
+// ErrTxCommitRollback occurs when an error has occurred in a transaction and
+// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
+// it is treated as ROLLBACK.
+var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
+
+// Begin starts a transaction with the default isolation level for the current
+// connection. To use a specific isolation level see BeginIso.
+func (c *Conn) Begin() (*Tx, error) {
+	return c.begin("")
+}
+
+// BeginIso starts a transaction with isoLevel as the transaction isolation
+// level.
+//
+// Valid isolation levels (and their constants) are:
+//   serializable (pgx.Serializable)
+//   repeatable read (pgx.RepeatableRead)
+//   read committed (pgx.ReadCommitted)
+//   read uncommitted (pgx.ReadUncommitted)
+func (c *Conn) BeginIso(isoLevel string) (*Tx, error) {
+	return c.begin(isoLevel)
+}
+
+func (c *Conn) begin(isoLevel string) (*Tx, error) {
+ var beginSQL string
+ if isoLevel == "" {
+ beginSQL = "begin"
+ } else {
+ beginSQL = fmt.Sprintf("begin isolation level %s", isoLevel)
+ }
+
+ _, err := c.Exec(beginSQL)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Tx{conn: c}, nil
+}
+
+// Tx represents a database transaction.
+//
+// All Tx methods return ErrTxClosed if Commit or Rollback has already been
+// called on the Tx.
+type Tx struct {
+	conn       *Conn      // connection the transaction runs on
+	afterClose func(*Tx)  // callback chain invoked on Commit/Rollback
+	err        error      // final error from Commit/Rollback, if any
+	status     int8       // one of the TxStatus* constants
+}
+
+// Commit commits the transaction
+func (tx *Tx) Commit() error {
+	if tx.status != TxStatusInProgress {
+		return ErrTxClosed
+	}
+
+	commandTag, err := tx.conn.Exec("commit")
+	switch {
+	case err == nil && commandTag == "COMMIT":
+		tx.status = TxStatusCommitSuccess
+	case err == nil && commandTag == "ROLLBACK":
+		// PostgreSQL accepts COMMIT on an aborted transaction but performs a
+		// rollback instead; surface that as ErrTxCommitRollback.
+		tx.status = TxStatusCommitFailure
+		tx.err = ErrTxCommitRollback
+	default:
+		tx.status = TxStatusCommitFailure
+		tx.err = err
+	}
+
+	if tx.afterClose != nil {
+		tx.afterClose(tx)
+	}
+	return tx.err
+}
+
+// Rollback rolls back the transaction. Rollback will return ErrTxClosed if the
+// Tx is already closed, but is otherwise safe to call multiple times. Hence, a
+// defer tx.Rollback() is safe even if tx.Commit() will be called first in a
+// non-error condition.
+func (tx *Tx) Rollback() error {
+	if tx.status != TxStatusInProgress {
+		return ErrTxClosed
+	}
+
+	if _, err := tx.conn.Exec("rollback"); err != nil {
+		tx.status = TxStatusRollbackFailure
+		tx.err = err
+	} else {
+		tx.status = TxStatusRollbackSuccess
+	}
+
+	if tx.afterClose != nil {
+		tx.afterClose(tx)
+	}
+	return tx.err
+}
+
+// Exec delegates to the underlying *Conn
+func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
+	if tx.status != TxStatusInProgress {
+		return CommandTag(""), ErrTxClosed
+	}
+
+	return tx.conn.Exec(sql, arguments...)
+}
+
+// Prepare delegates to the underlying *Conn
+func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) {
+	return tx.PrepareEx(name, sql, nil)
+}
+
+// PrepareEx delegates to the underlying *Conn
+func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
+	if tx.status != TxStatusInProgress {
+		return nil, ErrTxClosed
+	}
+
+	return tx.conn.PrepareEx(name, sql, opts)
+}
+
+// Query delegates to the underlying *Conn
+func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) {
+	if tx.status != TxStatusInProgress {
+		// Because checking for errors can be deferred to the *Rows, build one with the error
+		err := ErrTxClosed
+		return &Rows{closed: true, err: err}, err
+	}
+
+	return tx.conn.Query(sql, args...)
+}
+
+// QueryRow delegates to the underlying *Conn
+func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row {
+	// Query never returns a nil *Rows, so the error can be carried inside the
+	// returned *Row and surfaced on Scan.
+	rows, _ := tx.Query(sql, args...)
+	return (*Row)(rows)
+}
+
+// CopyTo delegates to the underlying *Conn.
+//
+// Deprecated: Use CopyFrom instead.
+func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) {
+	if tx.status != TxStatusInProgress {
+		return 0, ErrTxClosed
+	}
+
+	return tx.conn.CopyTo(tableName, columnNames, rowSrc)
+}
+
+// CopyFrom delegates to the underlying *Conn
+func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
+	if tx.status != TxStatusInProgress {
+		return 0, ErrTxClosed
+	}
+
+	return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
+}
+
+// Conn returns the *Conn this transaction is using.
+func (tx *Tx) Conn() *Conn {
+	return tx.conn
+}
+
+// Status returns the status of the transaction from the set of
+// pgx.TxStatus* constants.
+func (tx *Tx) Status() int8 {
+	return tx.status
+}
+
+// Err returns the final error state, if any, of calling Commit or Rollback.
+func (tx *Tx) Err() error {
+	return tx.err
+}
+
+// AfterClose adds f to a LIFO stack of functions that will be called when
+// the transaction is closed (either Commit or Rollback). The most recently
+// added function runs first: each registration wraps the previous chain and
+// invokes f before it.
+func (tx *Tx) AfterClose(f func(*Tx)) {
+	if tx.afterClose == nil {
+		tx.afterClose = f
+	} else {
+		prevFn := tx.afterClose
+		tx.afterClose = func(tx *Tx) {
+			f(tx)
+			prevFn(tx)
+		}
+	}
+}
diff --git a/vendor/github.com/jackc/pgx/tx_test.go b/vendor/github.com/jackc/pgx/tx_test.go
new file mode 100644
index 0000000..435521a
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/tx_test.go
@@ -0,0 +1,297 @@
+package pgx_test
+
+import (
+ "github.com/jackc/pgx"
+ "testing"
+ "time"
+)
+
+// TestTransactionSuccessfulCommit verifies that a committed insert persists.
+func TestTransactionSuccessfulCommit(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	if _, err := conn.Exec(createSql); err != nil {
+		t.Fatalf("Failed to create table: %v", err)
+	}
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatalf("conn.Begin failed: %v", err)
+	}
+
+	_, err = tx.Exec("insert into foo(id) values (1)")
+	if err != nil {
+		t.Fatalf("tx.Exec failed: %v", err)
+	}
+
+	err = tx.Commit()
+	if err != nil {
+		t.Fatalf("tx.Commit failed: %v", err)
+	}
+
+	var n int64
+	err = conn.QueryRow("select count(*) from foo").Scan(&n)
+	if err != nil {
+		t.Fatalf("QueryRow Scan failed: %v", err)
+	}
+	if n != 1 {
+		t.Fatalf("Did not receive correct number of rows: %v", n)
+	}
+}
+
+// TestTxCommitWhenTxBroken verifies that committing an aborted transaction
+// returns ErrTxCommitRollback and persists nothing.
+func TestTxCommitWhenTxBroken(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	if _, err := conn.Exec(createSql); err != nil {
+		t.Fatalf("Failed to create table: %v", err)
+	}
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatalf("conn.Begin failed: %v", err)
+	}
+
+	if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil {
+		t.Fatalf("tx.Exec failed: %v", err)
+	}
+
+	// Purposely break transaction
+	if _, err := tx.Exec("syntax error"); err == nil {
+		t.Fatal("Unexpected success")
+	}
+
+	err = tx.Commit()
+	if err != pgx.ErrTxCommitRollback {
+		t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
+	}
+
+	var n int64
+	err = conn.QueryRow("select count(*) from foo").Scan(&n)
+	if err != nil {
+		t.Fatalf("QueryRow Scan failed: %v", err)
+	}
+	if n != 0 {
+		t.Fatalf("Did not receive correct number of rows: %v", n)
+	}
+}
+
+// TestTxCommitSerializationFailure runs two overlapping serializable
+// transactions and expects the second commit to fail with SQLSTATE 40001.
+func TestTxCommitSerializationFailure(t *testing.T) {
+	t.Parallel()
+
+	pool := createConnPool(t, 5)
+	defer pool.Close()
+
+	// A permanent table is required: both transactions must see it, and each
+	// pool connection has its own temporary-table namespace.
+	pool.Exec(`drop table if exists tx_serializable_sums`)
+	_, err := pool.Exec(`create table tx_serializable_sums(num integer);`)
+	if err != nil {
+		t.Fatalf("Unable to create temporary table: %v", err)
+	}
+	defer pool.Exec(`drop table tx_serializable_sums`)
+
+	tx1, err := pool.BeginIso(pgx.Serializable)
+	if err != nil {
+		t.Fatalf("BeginIso failed: %v", err)
+	}
+	defer tx1.Rollback()
+
+	tx2, err := pool.BeginIso(pgx.Serializable)
+	if err != nil {
+		t.Fatalf("BeginIso failed: %v", err)
+	}
+	defer tx2.Rollback()
+
+	_, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
+	if err != nil {
+		t.Fatalf("Exec failed: %v", err)
+	}
+
+	_, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
+	if err != nil {
+		t.Fatalf("Exec failed: %v", err)
+	}
+
+	err = tx1.Commit()
+	if err != nil {
+		t.Fatalf("Commit failed: %v", err)
+	}
+
+	err = tx2.Commit()
+	if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "40001" {
+		t.Fatalf("Expected serialization error 40001, got %#v", err)
+	}
+}
+
+// TestTransactionSuccessfulRollback verifies that a rolled-back insert does
+// not persist.
+func TestTransactionSuccessfulRollback(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	if _, err := conn.Exec(createSql); err != nil {
+		t.Fatalf("Failed to create table: %v", err)
+	}
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatalf("conn.Begin failed: %v", err)
+	}
+
+	_, err = tx.Exec("insert into foo(id) values (1)")
+	if err != nil {
+		t.Fatalf("tx.Exec failed: %v", err)
+	}
+
+	err = tx.Rollback()
+	if err != nil {
+		t.Fatalf("tx.Rollback failed: %v", err)
+	}
+
+	var n int64
+	err = conn.QueryRow("select count(*) from foo").Scan(&n)
+	if err != nil {
+		t.Fatalf("QueryRow Scan failed: %v", err)
+	}
+	if n != 0 {
+		t.Fatalf("Did not receive correct number of rows: %v", n)
+	}
+}
+
+// TestBeginIso verifies that BeginIso starts a transaction in each of the
+// four standard isolation levels.
+func TestBeginIso(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
+	for _, iso := range isoLevels {
+		tx, err := conn.BeginIso(iso)
+		if err != nil {
+			t.Fatalf("conn.BeginIso failed: %v", err)
+		}
+
+		var level string
+		// Previously the Scan error was ignored, which on failure reported a
+		// misleading empty isolation level instead of the real error.
+		if err := conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level); err != nil {
+			t.Fatalf("QueryRow Scan failed: %v", err)
+		}
+		if level != iso {
+			t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
+		}
+
+		err = tx.Rollback()
+		if err != nil {
+			t.Fatalf("tx.Rollback failed: %v", err)
+		}
+	}
+}
+
+// TestTxAfterClose verifies both AfterClose callbacks run on Rollback and in
+// LIFO order (the second-registered callback runs first, so t2 <= t1).
+func TestTxAfterClose(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var zeroTime, t1, t2 time.Time
+	tx.AfterClose(func(tx *pgx.Tx) {
+		t1 = time.Now()
+	})
+
+	tx.AfterClose(func(tx *pgx.Tx) {
+		t2 = time.Now()
+	})
+
+	tx.Rollback()
+
+	if t1 == zeroTime {
+		t.Error("First Tx.AfterClose callback not called")
+	}
+
+	if t2 == zeroTime {
+		t.Error("Second Tx.AfterClose callback not called")
+	}
+
+	// LIFO: t2 is stamped before t1, so t1 before t2 means wrong order.
+	if t1.Before(t2) {
+		t.Errorf("AfterClose callbacks called out of order: %v, %v", t1, t2)
+	}
+}
+
+// TestTxStatus verifies the status transitions from in-progress to
+// rollback-success.
+func TestTxStatus(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if status := tx.Status(); status != pgx.TxStatusInProgress {
+		t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
+	}
+
+	if err := tx.Rollback(); err != nil {
+		t.Fatal(err)
+	}
+
+	if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
+		t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
+	}
+}
+
+// TestTxErr verifies that after committing a broken transaction the Tx
+// reports ErrTxCommitRollback via both Commit's return value and Err, and
+// the status is TxStatusCommitFailure.
+func TestTxErr(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tx, err := conn.Begin()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Purposely break transaction
+	if _, err := tx.Exec("syntax error"); err == nil {
+		t.Fatal("Unexpected success")
+	}
+
+	if err := tx.Commit(); err != pgx.ErrTxCommitRollback {
+		t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
+	}
+
+	// The message previously interpolated TxStatusRollbackSuccess although
+	// the condition checks for TxStatusCommitFailure.
+	if status := tx.Status(); status != pgx.TxStatusCommitFailure {
+		t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusCommitFailure, status)
+	}
+
+	if err := tx.Err(); err != pgx.ErrTxCommitRollback {
+		t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
+	}
+}
diff --git a/vendor/github.com/jackc/pgx/value_reader.go b/vendor/github.com/jackc/pgx/value_reader.go
new file mode 100644
index 0000000..a489754
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/value_reader.go
@@ -0,0 +1,156 @@
+package pgx
+
+import (
+ "errors"
+)
+
+// ValueReader is used by the Scanner interface to decode values.
+type ValueReader struct {
+	mr                  *msgReader        // underlying protocol message reader
+	fd                  *FieldDescription // description of the field being read
+	valueBytesRemaining int32             // bytes of this value not yet consumed
+	err                 error             // sticky error; once set, reads return zero values
+}
+
+// Err returns any error that the ValueReader has experienced
+func (r *ValueReader) Err() error {
+	return r.err
+}
+
+// Fatal tells r that a Fatal error has occurred
+func (r *ValueReader) Fatal(err error) {
+	r.err = err
+}
+
+// Len returns the number of unread bytes
+func (r *ValueReader) Len() int32 {
+	return r.valueBytesRemaining
+}
+
+// Type returns the *FieldDescription of the value
+func (r *ValueReader) Type() *FieldDescription {
+	return r.fd
+}
+
+// reserve consumes n bytes of the remaining value length. It returns false —
+// recording a fatal "read past end of value" error when appropriate — if a
+// previous error occurred or fewer than n bytes remain. This replaces six
+// identical copies of the err-check/decrement/overrun preamble.
+func (r *ValueReader) reserve(n int32) bool {
+	if r.err != nil {
+		return false
+	}
+
+	r.valueBytesRemaining -= n
+	if r.valueBytesRemaining < 0 {
+		r.Fatal(errors.New("read past end of value"))
+		return false
+	}
+
+	return true
+}
+
+// ReadByte reads one byte of the value.
+func (r *ValueReader) ReadByte() byte {
+	if !r.reserve(1) {
+		return 0
+	}
+	return r.mr.readByte()
+}
+
+// ReadInt16 reads a big-endian int16 from the value.
+func (r *ValueReader) ReadInt16() int16 {
+	if !r.reserve(2) {
+		return 0
+	}
+	return r.mr.readInt16()
+}
+
+// ReadUint16 reads a big-endian uint16 from the value.
+func (r *ValueReader) ReadUint16() uint16 {
+	if !r.reserve(2) {
+		return 0
+	}
+	return r.mr.readUint16()
+}
+
+// ReadInt32 reads a big-endian int32 from the value.
+func (r *ValueReader) ReadInt32() int32 {
+	if !r.reserve(4) {
+		return 0
+	}
+	return r.mr.readInt32()
+}
+
+// ReadUint32 reads a big-endian uint32 from the value.
+func (r *ValueReader) ReadUint32() uint32 {
+	if !r.reserve(4) {
+		return 0
+	}
+	return r.mr.readUint32()
+}
+
+// ReadInt64 reads a big-endian int64 from the value.
+func (r *ValueReader) ReadInt64() int64 {
+	if !r.reserve(8) {
+		return 0
+	}
+	return r.mr.readInt64()
+}
+
+// ReadOid reads an object ID from the value.
+func (r *ValueReader) ReadOid() Oid {
+	return Oid(r.ReadUint32())
+}
+
+// ReadString reads count bytes and returns as string
+func (r *ValueReader) ReadString(count int32) string {
+	if !r.reserve(count) {
+		return ""
+	}
+	return r.mr.readString(count)
+}
+
+// ReadBytes reads count bytes and returns as []byte
+func (r *ValueReader) ReadBytes(count int32) []byte {
+	if r.err != nil {
+		return nil
+	}
+
+	if count < 0 {
+		r.Fatal(errors.New("count must not be negative"))
+		return nil
+	}
+
+	if !r.reserve(count) {
+		return nil
+	}
+	return r.mr.readBytes(count)
+}
diff --git a/vendor/github.com/jackc/pgx/values.go b/vendor/github.com/jackc/pgx/values.go
new file mode 100644
index 0000000..a189e18
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/values.go
@@ -0,0 +1,3439 @@
+package pgx
+
+import (
+ "bytes"
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "reflect"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// PostgreSQL oids for common types, listed in ascending oid order. The
+// values mirror the entries in PostgreSQL's pg_type catalog.
+const (
+	BoolOid             = 16
+	ByteaOid            = 17
+	CharOid             = 18
+	NameOid             = 19
+	Int8Oid             = 20
+	Int2Oid             = 21
+	Int4Oid             = 23
+	TextOid             = 25
+	OidOid              = 26
+	TidOid              = 27
+	XidOid              = 28
+	CidOid              = 29
+	JsonOid             = 114
+	CidrOid             = 650
+	CidrArrayOid        = 651
+	Float4Oid           = 700
+	Float8Oid           = 701
+	UnknownOid          = 705
+	InetOid             = 869
+	BoolArrayOid        = 1000
+	ByteaArrayOid       = 1001
+	Int2ArrayOid        = 1005
+	Int4ArrayOid        = 1007
+	TextArrayOid        = 1009
+	VarcharArrayOid     = 1015
+	Int8ArrayOid        = 1016
+	Float4ArrayOid      = 1021
+	Float8ArrayOid      = 1022
+	AclItemOid          = 1033
+	AclItemArrayOid     = 1034
+	InetArrayOid        = 1041
+	VarcharOid          = 1043
+	DateOid             = 1082
+	TimestampOid        = 1114
+	TimestampArrayOid   = 1115
+	TimestampTzOid      = 1184
+	TimestampTzArrayOid = 1185
+	RecordOid           = 2249
+	UuidOid             = 2950
+	JsonbOid            = 3802
+)
+
+// PostgreSQL format codes
+const (
+ TextFormatCode = 0
+ BinaryFormatCode = 1
+)
+
+// Platform-dependent bounds of the built-in int type, derived from the
+// width of uint on this architecture (32- or 64-bit).
+const maxUint = ^uint(0)
+const maxInt = int(maxUint >> 1)
+const minInt = -maxInt - 1
+
+// DefaultTypeFormats maps type names to their default requested format (text
+// or binary). In theory the Scanner interface should be the one to determine
+// the format of the returned values. However, the query has already been
+// executed by the time Scan is called so it has no chance to set the format.
+// So for types that should always be returned in binary the format should be
+// set here.
+//
+// The map is initialized at declaration rather than in an init() function so
+// it is never observable in its nil state and the package needs one fewer
+// init hook.
+var DefaultTypeFormats = map[string]int16{
+	"_aclitem":     TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
+	"_bool":        BinaryFormatCode,
+	"_bytea":       BinaryFormatCode,
+	"_cidr":        BinaryFormatCode,
+	"_float4":      BinaryFormatCode,
+	"_float8":      BinaryFormatCode,
+	"_inet":        BinaryFormatCode,
+	"_int2":        BinaryFormatCode,
+	"_int4":        BinaryFormatCode,
+	"_int8":        BinaryFormatCode,
+	"_text":        BinaryFormatCode,
+	"_timestamp":   BinaryFormatCode,
+	"_timestamptz": BinaryFormatCode,
+	"_varchar":     BinaryFormatCode,
+	"aclitem":      TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
+	"bool":         BinaryFormatCode,
+	"bytea":        BinaryFormatCode,
+	"char":         BinaryFormatCode,
+	"cid":          BinaryFormatCode,
+	"cidr":         BinaryFormatCode,
+	"date":         BinaryFormatCode,
+	"float4":       BinaryFormatCode,
+	"float8":       BinaryFormatCode,
+	"json":         BinaryFormatCode,
+	"jsonb":        BinaryFormatCode,
+	"inet":         BinaryFormatCode,
+	"int2":         BinaryFormatCode,
+	"int4":         BinaryFormatCode,
+	"int8":         BinaryFormatCode,
+	"name":         BinaryFormatCode,
+	"oid":          BinaryFormatCode,
+	"record":       BinaryFormatCode,
+	"text":         BinaryFormatCode,
+	"tid":          BinaryFormatCode,
+	"timestamp":    BinaryFormatCode,
+	"timestamptz":  BinaryFormatCode,
+	"varchar":      BinaryFormatCode,
+	"xid":          BinaryFormatCode,
+}
+
+// SerializationError occurs on failure to encode or decode a value
+type SerializationError string
+
+// Error implements the error interface, returning the message verbatim.
+func (e SerializationError) Error() string {
+ return string(e)
+}
+
+// Deprecated: Scanner is an interface used to decode values from the PostgreSQL
+// server. To allow types to support pgx and database/sql.Scan this interface
+// has been deprecated in favor of PgxScanner. Existing implementations keep
+// working; new code should implement PgxScanner instead.
+type Scanner interface {
+ // Scan MUST check r.Type().DataType (to check by OID) or
+ // r.Type().DataTypeName (to check by name) to ensure that it is scanning an
+ // expected column type. It also MUST check r.Type().FormatCode before
+ // decoding. It should not assume that it was called on a data type or format
+ // that it understands.
+ Scan(r *ValueReader) error
+}
+
+// PgxScanner is an interface used to decode values from the PostgreSQL server.
+// It is used exactly the same as the Scanner interface. It simply has renamed
+// the method (so a type can also implement database/sql's Scan without the
+// two signatures colliding).
+type PgxScanner interface {
+ // ScanPgx MUST check r.Type().DataType (to check by OID) or
+ // r.Type().DataTypeName (to check by name) to ensure that it is scanning an
+ // expected column type. It also MUST check r.Type().FormatCode before
+ // decoding. It should not assume that it was called on a data type or format
+ // that it understands.
+ ScanPgx(r *ValueReader) error
+}
+
+// Encoder is an interface used to encode values for transmission to the
+// PostgreSQL server.
+type Encoder interface {
+ // Encode writes the value to w.
+ //
+ // If the value is NULL an int32(-1) should be written.
+ //
+ // Encode MUST check oid to see if the parameter data type is compatible. If
+ // this is not done, the PostgreSQL server may detect the error if the
+ // expected data size or format of the encoded data does not match. But if
+ // the encoded data is a valid representation of the data type PostgreSQL
+ // expects such as date and int4, incorrect data may be stored.
+ Encode(w *WriteBuf, oid Oid) error
+
+ // FormatCode returns the format that the encoder writes the value. It must be
+ // either pgx.TextFormatCode or pgx.BinaryFormatCode. It must agree with the
+ // bytes Encode actually produces.
+ FormatCode() int16
+}
+
+// NullFloat32 represents a float4 that may be NULL. It implements both the
+// Scanner and Encoder interfaces, so it can be passed as an argument to
+// Query[Row] and used as a destination for Scan.
+//
+// When Valid is false the value is NULL.
+type NullFloat32 struct {
+	Float32 float32
+	Valid   bool // Valid is true if Float32 is not NULL
+}
+
+// Scan decodes a float4 column value, mapping SQL NULL to Valid == false.
+func (n *NullFloat32) Scan(vr *ValueReader) error {
+	if vr.Type().DataType != Float4Oid {
+		return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType))
+	}
+
+	if vr.Len() == -1 {
+		*n = NullFloat32{}
+		return nil
+	}
+
+	*n = NullFloat32{Float32: decodeFloat4(vr), Valid: true}
+	return vr.Err()
+}
+
+// FormatCode reports that NullFloat32 uses the binary wire format.
+func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode writes the value as a float4, emitting length -1 (SQL NULL) when
+// Valid is false.
+func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error {
+	if oid != Float4Oid {
+		return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid))
+	}
+
+	if !n.Valid {
+		w.WriteInt32(-1)
+		return nil
+	}
+
+	return encodeFloat32(w, oid, n.Float32)
+}
+
+// NullFloat64 represents a float8 that may be NULL. It implements both the
+// Scanner and Encoder interfaces, so it can be passed as an argument to
+// Query[Row] and used as a destination for Scan.
+//
+// When Valid is false the value is NULL.
+type NullFloat64 struct {
+	Float64 float64
+	Valid   bool // Valid is true if Float64 is not NULL
+}
+
+// Scan decodes a float8 column value, mapping SQL NULL to Valid == false.
+func (n *NullFloat64) Scan(vr *ValueReader) error {
+	if vr.Type().DataType != Float8Oid {
+		return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType))
+	}
+
+	if vr.Len() == -1 {
+		*n = NullFloat64{}
+		return nil
+	}
+
+	*n = NullFloat64{Float64: decodeFloat8(vr), Valid: true}
+	return vr.Err()
+}
+
+// FormatCode reports that NullFloat64 uses the binary wire format.
+func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode writes the value as a float8, emitting length -1 (SQL NULL) when
+// Valid is false.
+func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error {
+	if oid != Float8Oid {
+		return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid))
+	}
+
+	if !n.Valid {
+		w.WriteInt32(-1)
+		return nil
+	}
+
+	return encodeFloat64(w, oid, n.Float64)
+}
+
+// NullString represents a string that may be NULL. It implements both the
+// Scanner and Encoder interfaces, so it can be passed as an argument to
+// Query[Row] and used as a destination for Scan.
+//
+// When Valid is false the value is NULL.
+type NullString struct {
+	String string
+	Valid  bool // Valid is true if String is not NULL
+}
+
+// Scan decodes a column as text, mapping SQL NULL to Valid == false.
+// Deliberately no OID check, so anything can be scanned into a NullString -
+// may revisit this decision later.
+func (n *NullString) Scan(vr *ValueReader) error {
+	if vr.Len() == -1 {
+		*n = NullString{}
+		return nil
+	}
+
+	*n = NullString{String: decodeText(vr), Valid: true}
+	return vr.Err()
+}
+
+// FormatCode reports that NullString uses the text wire format.
+func (n NullString) FormatCode() int16 { return TextFormatCode }
+
+// Encode writes the string, emitting length -1 (SQL NULL) when Valid is
+// false.
+func (n NullString) Encode(w *WriteBuf, oid Oid) error {
+	if !n.Valid {
+		w.WriteInt32(-1)
+		return nil
+	}
+
+	return encodeString(w, oid, n.String)
+}
+
+// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem
+// might look like this:
+//
+// postgres=arwdDxt/postgres
+//
+// Note, however, that because the user/role name part of an aclitem is
+// an identifier, it follows all the usual formatting rules for SQL
+// identifiers: if it contains spaces and other special characters,
+// it should appear in double-quotes:
+//
+// postgres=arwdDxt/"role with spaces"
+//
+type AclItem string
+
+// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullAclItem struct {
+ AclItem AclItem
+ Valid bool // Valid is true if AclItem is not NULL
+}
+
+// Scan implements the Scanner interface: decodes an aclitem column as text,
+// mapping SQL NULL to Valid == false.
+func (n *NullAclItem) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != AclItemOid {
+ return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.AclItem, n.Valid = "", false
+ return nil
+ }
+
+ n.Valid = true
+ n.AclItem = AclItem(decodeText(vr))
+ return vr.Err()
+}
+
+// Particularly important to return TextFormatCode, seeing as Postgres
+// only ever sends aclitem as text, not binary.
+func (n NullAclItem) FormatCode() int16 { return TextFormatCode }
+
+// Encode implements the Encoder interface: writes the aclitem as text,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullAclItem) Encode(w *WriteBuf, oid Oid) error {
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeString(w, oid, string(n.AclItem))
+}
+
+// Name is a type used for PostgreSQL's special 63-byte
+// name data type, used for identifiers like table names.
+// The pg_class.relname column is a good example of where the
+// name data type is used.
+//
+// Note that the underlying Go data type of pgx.Name is string,
+// so there is no way to enforce the 63-byte length. Inputting
+// a longer name into PostgreSQL will result in silent truncation
+// to 63 bytes.
+//
+// Also, if you have custom-compiled PostgreSQL and set
+// NAMEDATALEN to a different value, obviously that number of
+// bytes applies, rather than the default 63.
+type Name string
+
+// NullName represents a pgx.Name that may be null. NullName implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullName struct {
+ Name Name
+ Valid bool // Valid is true if Name is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a name column as text,
+// mapping SQL NULL to Valid == false.
+func (n *NullName) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != NameOid {
+ return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Name, n.Valid = "", false
+ return nil
+ }
+
+ n.Valid = true
+ n.Name = Name(decodeText(vr))
+ return vr.Err()
+}
+
+// FormatCode reports that NullName uses the text wire format.
+func (n NullName) FormatCode() int16 { return TextFormatCode }
+
+// Encode implements the Encoder interface: writes the name as text,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullName) Encode(w *WriteBuf, oid Oid) error {
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeString(w, oid, string(n.Name))
+}
+
+// The pgx.Char type is for PostgreSQL's special 8-bit-only
+// "char" type more akin to the C language's char type, or Go's byte type.
+// (Note that the name in PostgreSQL itself is "char", in double-quotes,
+// and not char.) It gets used a lot in PostgreSQL's system tables to hold
+// a single ASCII character value (eg pg_class.relkind).
+type Char byte
+
+// NullChar represents a pgx.Char that may be null. NullChar implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan for prepared and unprepared queries.
+//
+// If Valid is false then the value is NULL.
+type NullChar struct {
+ Char Char
+ Valid bool // Valid is true if Char is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a "char" column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullChar) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != CharOid {
+ return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Char, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Char = decodeChar(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullChar uses the binary wire format.
+func (n NullChar) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as a "char",
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullChar) Encode(w *WriteBuf, oid Oid) error {
+ if oid != CharOid {
+ return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeChar(w, oid, n.Char)
+}
+
+// NullInt16 represents a smallint that may be NULL. It implements both the
+// Scanner and Encoder interfaces, so it can be passed as an argument to
+// Query[Row] and used as a destination for Scan in prepared and unprepared
+// queries.
+//
+// When Valid is false the value is NULL.
+type NullInt16 struct {
+	Int16 int16
+	Valid bool // Valid is true if Int16 is not NULL
+}
+
+// Scan decodes an int2 column value, mapping SQL NULL to Valid == false.
+func (n *NullInt16) Scan(vr *ValueReader) error {
+	if vr.Type().DataType != Int2Oid {
+		return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType))
+	}
+
+	if vr.Len() == -1 {
+		*n = NullInt16{}
+		return nil
+	}
+
+	*n = NullInt16{Int16: decodeInt2(vr), Valid: true}
+	return vr.Err()
+}
+
+// FormatCode reports that NullInt16 uses the binary wire format.
+func (n NullInt16) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode writes the value as an int2, emitting length -1 (SQL NULL) when
+// Valid is false.
+func (n NullInt16) Encode(w *WriteBuf, oid Oid) error {
+	if oid != Int2Oid {
+		return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid))
+	}
+
+	if !n.Valid {
+		w.WriteInt32(-1)
+		return nil
+	}
+
+	return encodeInt16(w, oid, n.Int16)
+}
+
+// NullInt32 represents an integer that may be null. NullInt32 implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullInt32 struct {
+ Int32 int32
+ Valid bool // Valid is true if Int32 is not NULL
+}
+
+// Scan implements the Scanner interface: decodes an int4 column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullInt32) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != Int4Oid {
+ return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Int32, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Int32 = decodeInt4(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullInt32 uses the binary wire format.
+func (n NullInt32) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as an int4,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullInt32) Encode(w *WriteBuf, oid Oid) error {
+ if oid != Int4Oid {
+ return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeInt32(w, oid, n.Int32)
+}
+
+// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html,
+// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented
+// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h
+// in the PostgreSQL sources.
+type Oid uint32
+
+// NullOid represents an Object Identifier (Oid) that may be null. NullOid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullOid struct {
+ Oid Oid
+ Valid bool // Valid is true if Oid is not NULL
+}
+
+// Scan implements the Scanner interface: decodes an oid column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullOid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != OidOid {
+ return SerializationError(fmt.Sprintf("NullOid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Oid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Oid = decodeOid(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullOid uses the binary wire format.
+func (n NullOid) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as an oid,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullOid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != OidOid {
+ return SerializationError(fmt.Sprintf("NullOid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeOid(w, oid, n.Oid)
+}
+
+// Xid is PostgreSQL's Transaction ID type.
+//
+// In later versions of PostgreSQL, it is the type used for the backend_xid
+// and backend_xmin columns of the pg_stat_activity system view.
+//
+// Also, when one does
+//
+// select xmin, xmax, * from some_table;
+//
+// it is the data type of the xmin and xmax hidden system columns.
+//
+// It is currently implemented as an unsigned four byte integer.
+// Its definition can be found in src/include/postgres_ext.h as TransactionId
+// in the PostgreSQL sources.
+type Xid uint32
+
+// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullXid struct {
+ Xid Xid
+ Valid bool // Valid is true if Xid is not NULL
+}
+
+// Scan implements the Scanner interface: decodes an xid column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullXid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != XidOid {
+ return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Xid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Xid = decodeXid(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullXid uses the binary wire format.
+func (n NullXid) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as an xid,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullXid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != XidOid {
+ return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeXid(w, oid, n.Xid)
+}
+
+// Cid is PostgreSQL's Command Identifier type.
+//
+// When one does
+//
+// select cmin, cmax, * from some_table;
+//
+// it is the data type of the cmin and cmax hidden system columns.
+//
+// It is currently implemented as an unsigned four byte integer.
+// Its definition can be found in src/include/c.h as CommandId
+// in the PostgreSQL sources.
+type Cid uint32
+
+// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullCid struct {
+ Cid Cid
+ Valid bool // Valid is true if Cid is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a cid column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullCid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != CidOid {
+ return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Cid, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ n.Cid = decodeCid(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullCid uses the binary wire format.
+func (n NullCid) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as a cid,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullCid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != CidOid {
+ return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeCid(w, oid, n.Cid)
+}
+
+// Tid is PostgreSQL's Tuple Identifier type.
+//
+// When one does
+//
+// select ctid, * from some_table;
+//
+// it is the data type of the ctid hidden system column.
+//
+// It is represented here as a 4-byte block number plus a 2-byte offset
+// number (see the struct fields below).
+// Its conversion functions can be found in src/backend/utils/adt/tid.c
+// in the PostgreSQL sources.
+type Tid struct {
+ BlockNumber uint32
+ OffsetNumber uint16
+}
+
+// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullTid struct {
+ Tid Tid
+ Valid bool // Valid is true if Tid is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a tid column value,
+// mapping SQL NULL to Valid == false (with a zero Tid).
+func (n *NullTid) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != TidOid {
+ return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false
+ return nil
+ }
+ n.Valid = true
+ n.Tid = decodeTid(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullTid uses the binary wire format.
+func (n NullTid) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as a tid,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullTid) Encode(w *WriteBuf, oid Oid) error {
+ if oid != TidOid {
+ return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeTid(w, oid, n.Tid)
+}
+
+// NullInt64 represents a bigint that may be NULL. It implements both the
+// Scanner and Encoder interfaces, so it can be passed as an argument to
+// Query[Row] and used as a destination for Scan.
+//
+// When Valid is false the value is NULL.
+type NullInt64 struct {
+	Int64 int64
+	Valid bool // Valid is true if Int64 is not NULL
+}
+
+// Scan decodes an int8 column value, mapping SQL NULL to Valid == false.
+func (n *NullInt64) Scan(vr *ValueReader) error {
+	if vr.Type().DataType != Int8Oid {
+		return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType))
+	}
+
+	if vr.Len() == -1 {
+		*n = NullInt64{}
+		return nil
+	}
+
+	*n = NullInt64{Int64: decodeInt8(vr), Valid: true}
+	return vr.Err()
+}
+
+// FormatCode reports that NullInt64 uses the binary wire format.
+func (n NullInt64) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode writes the value as an int8, emitting length -1 (SQL NULL) when
+// Valid is false.
+func (n NullInt64) Encode(w *WriteBuf, oid Oid) error {
+	if oid != Int8Oid {
+		return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid))
+	}
+
+	if !n.Valid {
+		w.WriteInt32(-1)
+		return nil
+	}
+
+	return encodeInt64(w, oid, n.Int64)
+}
+
+// NullBool represents an bool that may be null. NullBool implements the Scanner
+// and Encoder interfaces so it may be used both as an argument to Query[Row]
+// and a destination for Scan.
+//
+// If Valid is false then the value is NULL.
+type NullBool struct {
+ Bool bool
+ Valid bool // Valid is true if Bool is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a bool column value,
+// mapping SQL NULL to Valid == false.
+func (n *NullBool) Scan(vr *ValueReader) error {
+ if vr.Type().DataType != BoolOid {
+ return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Bool, n.Valid = false, false
+ return nil
+ }
+ n.Valid = true
+ n.Bool = decodeBool(vr)
+ return vr.Err()
+}
+
+// FormatCode reports that NullBool uses the binary wire format.
+func (n NullBool) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the value as a bool,
+// emitting length -1 (SQL NULL) when Valid is false.
+func (n NullBool) Encode(w *WriteBuf, oid Oid) error {
+ if oid != BoolOid {
+ return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeBool(w, oid, n.Bool)
+}
+
+// NullTime represents an time.Time that may be null. NullTime implements the
+// Scanner and Encoder interfaces so it may be used both as an argument to
+// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL
+// types timestamptz, timestamp, and date.
+//
+// If Valid is false then the value is NULL.
+type NullTime struct {
+ Time time.Time
+ Valid bool // Valid is true if Time is not NULL
+}
+
+// Scan implements the Scanner interface: decodes a timestamptz, timestamp,
+// or date column value, mapping SQL NULL to Valid == false.
+func (n *NullTime) Scan(vr *ValueReader) error {
+ oid := vr.Type().DataType
+ if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
+ return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType))
+ }
+
+ if vr.Len() == -1 {
+ n.Time, n.Valid = time.Time{}, false
+ return nil
+ }
+
+ n.Valid = true
+ // Dispatch to the decoder matching the column's actual type; the three
+ // oids use different wire encodings.
+ switch oid {
+ case TimestampTzOid:
+ n.Time = decodeTimestampTz(vr)
+ case TimestampOid:
+ n.Time = decodeTimestamp(vr)
+ case DateOid:
+ n.Time = decodeDate(vr)
+ }
+
+ return vr.Err()
+}
+
+// FormatCode reports that NullTime uses the binary wire format.
+func (n NullTime) FormatCode() int16 { return BinaryFormatCode }
+
+// Encode implements the Encoder interface: writes the time using the
+// encoding appropriate for oid, emitting length -1 (SQL NULL) when Valid is
+// false.
+func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
+ if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
+ return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid))
+ }
+
+ if !n.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ return encodeTime(w, oid, n.Time)
+}
+
+// Hstore represents an hstore column. It does not support a NULL column or
+// NULL key values (use NullHstore for those). Hstore implements the Scanner
+// and Encoder interfaces so it may be used both as an argument to Query[Row]
+// and a destination for Scan.
+type Hstore map[string]string
+
+// Scan decodes an hstore column. Because hstore has no fixed OID it is
+// recognized by type name; only the text wire format is supported, and a
+// NULL column is a fatal protocol error.
+func (h *Hstore) Scan(vr *ValueReader) error {
+	//oid for hstore not standardized, so we check its type name
+	if vr.Type().DataTypeName != "hstore" {
+		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName)))
+		return nil
+	}
+
+	if vr.Len() == -1 {
+		vr.Fatal(ProtocolError("Cannot decode null column into Hstore"))
+		return nil
+	}
+
+	if fc := vr.Type().FormatCode; fc != TextFormatCode {
+		if fc == BinaryFormatCode {
+			vr.Fatal(ProtocolError("Can't decode binary hstore"))
+		} else {
+			vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fc)))
+		}
+		return nil
+	}
+
+	m, err := parseHstoreToMap(vr.ReadString(vr.Len()))
+	if err != nil {
+		vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
+		return nil
+	}
+	*h = Hstore(m)
+	return nil
+}
+
+// FormatCode reports that Hstore uses the text wire format.
+func (h Hstore) FormatCode() int16 { return TextFormatCode }
+
+// Encode writes the map in hstore text syntax ("k"=>"v", ...), escaping
+// backslashes and double quotes in both keys and values. Pair order follows
+// Go map iteration and is therefore unspecified.
+func (h Hstore) Encode(w *WriteBuf, oid Oid) error {
+	escape := func(s string) string {
+		s = strings.Replace(s, `\`, `\\`, -1)
+		return strings.Replace(s, `"`, `\"`, -1)
+	}
+
+	pairs := make([]string, 0, len(h))
+	for k, v := range h {
+		pairs = append(pairs, `"`+escape(k)+`"=>"`+escape(v)+`"`)
+	}
+
+	body := strings.Join(pairs, ", ")
+	w.WriteInt32(int32(len(body)))
+	w.WriteBytes([]byte(body))
+	return nil
+}
+
+// NullHstore represents an hstore column that can be null or have null values
+// associated with its keys. NullHstore implements the Scanner and Encoder
+// interfaces so it may be used both as an argument to Query[Row] and a
+// destination for Scan.
+//
+// If Valid is false, then the value of the entire hstore column is NULL
+// If any of the NullString values in Store has Valid set to false, the key
+// appears in the hstore column, but its value is explicitly set to NULL.
+type NullHstore struct {
+ Hstore map[string]NullString
+ Valid bool
+}
+
+// Scan implements the Scanner interface: decodes a text-format hstore
+// column, mapping SQL NULL to Valid == false. Only the text wire format is
+// supported; anything else fails fatally on the ValueReader.
+func (h *NullHstore) Scan(vr *ValueReader) error {
+ //oid for hstore not standardized, so we check its type name
+ if vr.Type().DataTypeName != "hstore" {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName)))
+ return nil
+ }
+
+ if vr.Len() == -1 {
+ h.Valid = false
+ return nil
+ }
+
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len()))
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
+ return nil
+ }
+ h.Valid = true
+ h.Hstore = store
+ return nil
+ case BinaryFormatCode:
+ vr.Fatal(ProtocolError("Can't decode binary hstore"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+}
+
+// FormatCode reports that NullHstore uses the text wire format.
+func (h NullHstore) FormatCode() int16 { return TextFormatCode }
+
+// Encode implements the Encoder interface: writes the map in hstore text
+// syntax, emitting length -1 (SQL NULL) when Valid is false. Keys/values
+// have backslashes and double quotes escaped; invalid (NULL) values are
+// written as the unquoted keyword NULL.
+func (h NullHstore) Encode(w *WriteBuf, oid Oid) error {
+ var buf bytes.Buffer
+
+ if !h.Valid {
+ w.WriteInt32(-1)
+ return nil
+ }
+
+ i := 0
+ for k, v := range h.Hstore {
+ i++
+ ks := strings.Replace(k, `\`, `\\`, -1)
+ ks = strings.Replace(ks, `"`, `\"`, -1)
+ if v.Valid {
+ vs := strings.Replace(v.String, `\`, `\\`, -1)
+ vs = strings.Replace(vs, `"`, `\"`, -1)
+ buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
+ } else {
+ buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks))
+ }
+ // Separator after every pair except the last (map order is random).
+ if i < len(h.Hstore) {
+ buf.WriteString(", ")
+ }
+ }
+ w.WriteInt32(int32(buf.Len()))
+ w.WriteBytes(buf.Bytes())
+ return nil
+}
+
+// Encode encodes arg into wbuf as the type oid. This allows implementations
+// of the Encoder interface to delegate the actual work of encoding to the
+// built-in functionality. A nil arg (or nil pointer) is written as SQL NULL
+// (length -1). The dispatch order below is significant: custom Encoder and
+// driver.Valuer implementations win over built-in handling, string-like
+// types are handled before pointer dereferencing, and JSON oids are handled
+// before the concrete-type switch.
+func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
+ if arg == nil {
+ wbuf.WriteInt32(-1)
+ return nil
+ }
+
+ // First switch: interfaces and byte/string-like types that must take
+ // precedence over reflection-based handling below.
+ switch arg := arg.(type) {
+ case Encoder:
+ return arg.Encode(wbuf, oid)
+ case driver.Valuer:
+ v, err := arg.Value()
+ if err != nil {
+ return err
+ }
+ return Encode(wbuf, oid, v)
+ case string:
+ return encodeString(wbuf, oid, arg)
+ case []AclItem:
+ return encodeAclItemSlice(wbuf, oid, arg)
+ case []byte:
+ return encodeByteSlice(wbuf, oid, arg)
+ case [][]byte:
+ return encodeByteSliceSlice(wbuf, oid, arg)
+ }
+
+ refVal := reflect.ValueOf(arg)
+
+ // Dereference a pointer and retry; a nil pointer encodes as SQL NULL.
+ if refVal.Kind() == reflect.Ptr {
+ if refVal.IsNil() {
+ wbuf.WriteInt32(-1)
+ return nil
+ }
+ arg = refVal.Elem().Interface()
+ return Encode(wbuf, oid, arg)
+ }
+
+ // JSON targets accept any marshalable value, so they are handled before
+ // the concrete-type switch.
+ if oid == JsonOid {
+ return encodeJSON(wbuf, oid, arg)
+ }
+ if oid == JsonbOid {
+ return encodeJSONB(wbuf, oid, arg)
+ }
+
+ // Second switch: concrete built-in and pgx-declared types.
+ switch arg := arg.(type) {
+ case []string:
+ return encodeStringSlice(wbuf, oid, arg)
+ case bool:
+ return encodeBool(wbuf, oid, arg)
+ case []bool:
+ return encodeBoolSlice(wbuf, oid, arg)
+ case int:
+ return encodeInt(wbuf, oid, arg)
+ case uint:
+ return encodeUInt(wbuf, oid, arg)
+ case Char:
+ return encodeChar(wbuf, oid, arg)
+ case AclItem:
+ // The aclitem data type goes over the wire using the same format as string,
+ // so just cast to string and use encodeString
+ return encodeString(wbuf, oid, string(arg))
+ case Name:
+ // The name data type goes over the wire using the same format as string,
+ // so just cast to string and use encodeString
+ return encodeString(wbuf, oid, string(arg))
+ case int8:
+ return encodeInt8(wbuf, oid, arg)
+ case uint8:
+ return encodeUInt8(wbuf, oid, arg)
+ case int16:
+ return encodeInt16(wbuf, oid, arg)
+ case []int16:
+ return encodeInt16Slice(wbuf, oid, arg)
+ case uint16:
+ return encodeUInt16(wbuf, oid, arg)
+ case []uint16:
+ return encodeUInt16Slice(wbuf, oid, arg)
+ case int32:
+ return encodeInt32(wbuf, oid, arg)
+ case []int32:
+ return encodeInt32Slice(wbuf, oid, arg)
+ case uint32:
+ return encodeUInt32(wbuf, oid, arg)
+ case []uint32:
+ return encodeUInt32Slice(wbuf, oid, arg)
+ case int64:
+ return encodeInt64(wbuf, oid, arg)
+ case []int64:
+ return encodeInt64Slice(wbuf, oid, arg)
+ case uint64:
+ return encodeUInt64(wbuf, oid, arg)
+ case []uint64:
+ return encodeUInt64Slice(wbuf, oid, arg)
+ case float32:
+ return encodeFloat32(wbuf, oid, arg)
+ case []float32:
+ return encodeFloat32Slice(wbuf, oid, arg)
+ case float64:
+ return encodeFloat64(wbuf, oid, arg)
+ case []float64:
+ return encodeFloat64Slice(wbuf, oid, arg)
+ case time.Time:
+ return encodeTime(wbuf, oid, arg)
+ case []time.Time:
+ return encodeTimeSlice(wbuf, oid, arg)
+ case net.IP:
+ return encodeIP(wbuf, oid, arg)
+ case []net.IP:
+ return encodeIPSlice(wbuf, oid, arg)
+ case net.IPNet:
+ return encodeIPNet(wbuf, oid, arg)
+ case []net.IPNet:
+ return encodeIPNetSlice(wbuf, oid, arg)
+ case Oid:
+ return encodeOid(wbuf, oid, arg)
+ case Xid:
+ return encodeXid(wbuf, oid, arg)
+ case Cid:
+ return encodeCid(wbuf, oid, arg)
+ default:
+ // Last resort: strip a named type down to its underlying kind and
+ // retry the whole dispatch.
+ if strippedArg, ok := stripNamedType(&refVal); ok {
+ return Encode(wbuf, oid, strippedArg)
+ }
+ return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
+ }
+}
+
+// stripNamedType converts a value of a user-defined named type whose
+// underlying kind is a string or an integer into the corresponding built-in
+// type, so Encode can retry its concrete-type dispatch. The second return
+// value is false for kinds it does not handle.
+func stripNamedType(v *reflect.Value) (interface{}, bool) {
+	switch v.Kind() {
+	case reflect.String:
+		return v.String(), true
+	case reflect.Int:
+		return int(v.Int()), true
+	case reflect.Int8:
+		return int8(v.Int()), true
+	case reflect.Int16:
+		return int16(v.Int()), true
+	case reflect.Int32:
+		return int32(v.Int()), true
+	case reflect.Int64:
+		return int64(v.Int()), true
+	case reflect.Uint:
+		return uint(v.Uint()), true
+	case reflect.Uint8:
+		return uint8(v.Uint()), true
+	case reflect.Uint16:
+		return uint16(v.Uint()), true
+	case reflect.Uint32:
+		return uint32(v.Uint()), true
+	case reflect.Uint64:
+		return uint64(v.Uint()), true
+	}
+
+	return nil, false
+}
+
+// Decode decodes from vr into d. d must be a pointer. This allows
+// implementations of the Decoder interface to delegate the actual work of
+// decoding to the built-in functionality.
+func Decode(vr *ValueReader, d interface{}) error {
+ switch v := d.(type) {
+ case *bool:
+ *v = decodeBool(vr)
+ case *int:
+ n := decodeInt(vr)
+ if n < int64(minInt) {
+ return fmt.Errorf("%d is less than minimum value for int", n)
+ } else if n > int64(maxInt) {
+ return fmt.Errorf("%d is greater than maximum value for int", n)
+ }
+ *v = int(n)
+ case *int8:
+ n := decodeInt(vr)
+ if n < math.MinInt8 {
+ return fmt.Errorf("%d is less than minimum value for int8", n)
+ } else if n > math.MaxInt8 {
+ return fmt.Errorf("%d is greater than maximum value for int8", n)
+ }
+ *v = int8(n)
+ case *int16:
+ n := decodeInt(vr)
+ if n < math.MinInt16 {
+ return fmt.Errorf("%d is less than minimum value for int16", n)
+ } else if n > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than maximum value for int16", n)
+ }
+ *v = int16(n)
+ case *int32:
+ n := decodeInt(vr)
+ if n < math.MinInt32 {
+ return fmt.Errorf("%d is less than minimum value for int32", n)
+ } else if n > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than maximum value for int32", n)
+ }
+ *v = int32(n)
+ case *int64:
+ n := decodeInt(vr)
+ if n < math.MinInt64 {
+ return fmt.Errorf("%d is less than minimum value for int64", n)
+ } else if n > math.MaxInt64 {
+ return fmt.Errorf("%d is greater than maximum value for int64", n)
+ }
+ *v = int64(n)
+ case *uint:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint8", n)
+ } else if maxInt == math.MaxInt32 && n > math.MaxUint32 {
+ return fmt.Errorf("%d is greater than maximum value for uint", n)
+ }
+ *v = uint(n)
+ case *uint8:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint8", n)
+ } else if n > math.MaxUint8 {
+ return fmt.Errorf("%d is greater than maximum value for uint8", n)
+ }
+ *v = uint8(n)
+ case *uint16:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint16", n)
+ } else if n > math.MaxUint16 {
+ return fmt.Errorf("%d is greater than maximum value for uint16", n)
+ }
+ *v = uint16(n)
+ case *uint32:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint32", n)
+ } else if n > math.MaxUint32 {
+ return fmt.Errorf("%d is greater than maximum value for uint32", n)
+ }
+ *v = uint32(n)
+ case *uint64:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for uint64", n)
+ }
+ *v = uint64(n)
+ case *Char:
+ *v = decodeChar(vr)
+ case *AclItem:
+ // aclitem goes over the wire just like text
+ *v = AclItem(decodeText(vr))
+ case *Name:
+ // name goes over the wire just like text
+ *v = Name(decodeText(vr))
+ case *Oid:
+ *v = decodeOid(vr)
+ case *Xid:
+ *v = decodeXid(vr)
+ case *Tid:
+ *v = decodeTid(vr)
+ case *Cid:
+ *v = decodeCid(vr)
+ case *string:
+ *v = decodeText(vr)
+ case *float32:
+ *v = decodeFloat4(vr)
+ case *float64:
+ *v = decodeFloat8(vr)
+ case *[]AclItem:
+ *v = decodeAclItemArray(vr)
+ case *[]bool:
+ *v = decodeBoolArray(vr)
+ case *[]int16:
+ *v = decodeInt2Array(vr)
+ case *[]uint16:
+ *v = decodeInt2ArrayToUInt(vr)
+ case *[]int32:
+ *v = decodeInt4Array(vr)
+ case *[]uint32:
+ *v = decodeInt4ArrayToUInt(vr)
+ case *[]int64:
+ *v = decodeInt8Array(vr)
+ case *[]uint64:
+ *v = decodeInt8ArrayToUInt(vr)
+ case *[]float32:
+ *v = decodeFloat4Array(vr)
+ case *[]float64:
+ *v = decodeFloat8Array(vr)
+ case *[]string:
+ *v = decodeTextArray(vr)
+ case *[]time.Time:
+ *v = decodeTimestampArray(vr)
+ case *[][]byte:
+ *v = decodeByteaArray(vr)
+ case *[]interface{}:
+ *v = decodeRecord(vr)
+ case *time.Time:
+ switch vr.Type().DataType {
+ case DateOid:
+ *v = decodeDate(vr)
+ case TimestampTzOid:
+ *v = decodeTimestampTz(vr)
+ case TimestampOid:
+ *v = decodeTimestamp(vr)
+ default:
+ return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)
+ }
+ case *net.IP:
+ ipnet := decodeInet(vr)
+ if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
+ return fmt.Errorf("Cannot decode netmask into *net.IP")
+ }
+ *v = ipnet.IP
+ case *[]net.IP:
+ ipnets := decodeInetArray(vr)
+ ips := make([]net.IP, len(ipnets))
+ for i, ipnet := range ipnets {
+ if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
+ return fmt.Errorf("Cannot decode netmask into *net.IP")
+ }
+ ips[i] = ipnet.IP
+ }
+ *v = ips
+ case *net.IPNet:
+ *v = decodeInet(vr)
+ case *[]net.IPNet:
+ *v = decodeInetArray(vr)
+ default:
+ if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
+ el := v.Elem()
+ switch el.Kind() {
+ // if d is a pointer to pointer, strip the pointer and try again
+ case reflect.Ptr:
+ // -1 is a null value
+ if vr.Len() == -1 {
+ if !el.IsNil() {
+ // if the destination pointer is not nil, nil it out
+ el.Set(reflect.Zero(el.Type()))
+ }
+ return nil
+ }
+ if el.IsNil() {
+ // allocate destination
+ el.Set(reflect.New(el.Type().Elem()))
+ }
+ d = el.Interface()
+ return Decode(vr, d)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n := decodeInt(vr)
+ if el.OverflowInt(n) {
+ return fmt.Errorf("Scan cannot decode %d into %T", n, d)
+ }
+ el.SetInt(n)
+ return nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ n := decodeInt(vr)
+ if n < 0 {
+ return fmt.Errorf("%d is less than zero for %T", n, d)
+ }
+ if el.OverflowUint(uint64(n)) {
+ return fmt.Errorf("Scan cannot decode %d into %T", n, d)
+ }
+ el.SetUint(uint64(n))
+ return nil
+ case reflect.String:
+ el.SetString(decodeText(vr))
+ return nil
+ }
+ }
+ return fmt.Errorf("Scan cannot decode into %T", d)
+ }
+
+ return nil
+}
+
+func decodeBool(vr *ValueReader) bool {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into bool"))
+ return false
+ }
+
+ if vr.Type().DataType != BoolOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
+ return false
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return false
+ }
+
+ if vr.Len() != 1 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len())))
+ return false
+ }
+
+ b := vr.ReadByte()
+ return b != 0
+}
+
+func encodeBool(w *WriteBuf, oid Oid, value bool) error {
+ if oid != BoolOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid)
+ }
+
+ w.WriteInt32(1)
+
+ var n byte
+ if value {
+ n = 1
+ }
+
+ w.WriteByte(n)
+
+ return nil
+}
+
+func decodeInt(vr *ValueReader) int64 {
+ switch vr.Type().DataType {
+ case Int2Oid:
+ return int64(decodeInt2(vr))
+ case Int4Oid:
+ return int64(decodeInt4(vr))
+ case Int8Oid:
+ return int64(decodeInt8(vr))
+ }
+
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType)))
+ return 0
+}
+
+func decodeInt8(vr *ValueReader) int64 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int64"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int8Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt64()
+}
+
+func decodeChar(vr *ValueReader) Char {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into char"))
+ return Char(0)
+ }
+
+ if vr.Type().DataType != CharOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType)))
+ return Char(0)
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Char(0)
+ }
+
+ if vr.Len() != 1 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len())))
+ return Char(0)
+ }
+
+ return Char(vr.ReadByte())
+}
+
+func decodeInt2(vr *ValueReader) int16 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int16"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int2Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 2 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt16()
+}
+
+func encodeInt(w *WriteBuf, oid Oid, value int) error {
+ switch oid {
+ case Int2Oid:
+ if value < math.MinInt16 {
+ return fmt.Errorf("%d is less than min pg:int2", value)
+ } else if value > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than max pg:int2", value)
+ }
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ if value < math.MinInt32 {
+ return fmt.Errorf("%d is less than min pg:int4", value)
+ } else if value > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than max pg:int4", value)
+ }
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ if int64(value) <= int64(math.MaxInt64) {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ } else {
+ return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64))
+ }
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int8", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt(w *WriteBuf, oid Oid, value uint) error {
+ switch oid {
+ case Int2Oid:
+ if value > math.MaxInt16 {
+ return fmt.Errorf("%d is greater than max pg:int2", value)
+ }
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ if value > math.MaxInt32 {
+ return fmt.Errorf("%d is greater than max pg:int4", value)
+ }
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64)
+ if int64(value) > int64(math.MaxInt64) {
+ return fmt.Errorf("%d is greater than max pg:int8", value)
+ }
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid)
+ }
+
+ return nil
+}
+
+func encodeChar(w *WriteBuf, oid Oid, value Char) error {
+ w.WriteInt32(1)
+ w.WriteByte(byte(value))
+ return nil
+}
+
+func encodeInt8(w *WriteBuf, oid Oid, value int8) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int8", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid)
+ }
+
+ return nil
+}
+
+func encodeInt16(w *WriteBuf, oid Oid, value int16) error {
+ switch oid {
+ case Int2Oid:
+ w.WriteInt32(2)
+ w.WriteInt16(value)
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int16", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int16", oid)
+ }
+
+ return nil
+}
+
+func encodeInt32(w *WriteBuf, oid Oid, value int32) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(value)
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int32", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid)
+ }
+
+ return nil
+}
+
+func encodeInt64(w *WriteBuf, oid Oid, value int64) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(value)
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "int64", oid)
+ }
+
+ return nil
+}
+
+func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error {
+ switch oid {
+ case Int2Oid:
+ if value <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
+ }
+ case Int4Oid:
+ if value <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
+ }
+ case Int8Oid:
+
+ if value <= math.MaxInt64 {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(value))
+ } else {
+ return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64))
+ }
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid)
+ }
+
+ return nil
+}
+
+func decodeInt4(vr *ValueReader) int32 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into int32"))
+ return 0
+ }
+
+ if vr.Type().DataType != Int4Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len())))
+ return 0
+ }
+
+ return vr.ReadInt32()
+}
+
+func decodeOid(vr *ValueReader) Oid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Oid"))
+ return Oid(0)
+ }
+
+ if vr.Type().DataType != OidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType)))
+ return Oid(0)
+ }
+
+ // Oid needs to decode text format because it is used in loadPgTypes
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Oid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Oid(0)
+ }
+ return Oid(vr.ReadInt32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Oid(0)
+ }
+}
+
+func encodeOid(w *WriteBuf, oid Oid, value Oid) error {
+ if oid != OidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
+func decodeXid(vr *ValueReader) Xid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Xid"))
+ return Xid(0)
+ }
+
+ if vr.Type().DataType != XidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType)))
+ return Xid(0)
+ }
+
+ // Unlikely Xid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Xid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Xid(0)
+ }
+ return Xid(vr.ReadUint32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Xid(0)
+ }
+}
+
+func encodeXid(w *WriteBuf, oid Oid, value Xid) error {
+ if oid != XidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
+func decodeCid(vr *ValueReader) Cid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Cid"))
+ return Cid(0)
+ }
+
+ if vr.Type().DataType != CidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType)))
+ return Cid(0)
+ }
+
+ // Unlikely Cid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+ n, err := strconv.ParseUint(s, 10, 32)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ }
+ return Cid(n)
+ case BinaryFormatCode:
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Cid(0)
+ }
+ return Cid(vr.ReadUint32())
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Cid(0)
+ }
+}
+
+func encodeCid(w *WriteBuf, oid Oid, value Cid) error {
+ if oid != CidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid)
+ }
+
+ w.WriteInt32(4)
+ w.WriteUint32(uint32(value))
+
+ return nil
+}
+
// tidRegexp matches the text representation of a tid, e.g. "(123,45)".
// Capture group 1 is the BlockNumber and group 2 the OffsetNumber.
// Note that we do not match negative numbers, because neither the
// BlockNumber nor OffsetNumber of a Tid can be negative.
var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`)
+
+func decodeTid(vr *ValueReader) Tid {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into Tid"))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ if vr.Type().DataType != TidOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ // Unlikely Tid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode {
+ case TextFormatCode:
+ s := vr.ReadString(vr.Len())
+
+ match := tidRegexp.FindStringSubmatch(s)
+ if match == nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+
+ blockNumber, err := strconv.ParseUint(s, 10, 16)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s)))
+ }
+
+ offsetNumber, err := strconv.ParseUint(s, 10, 16)
+ if err != nil {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s)))
+ }
+ return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)}
+ case BinaryFormatCode:
+ if vr.Len() != 6 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+ return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()}
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return Tid{BlockNumber: 0, OffsetNumber: 0}
+ }
+}
+
+func encodeTid(w *WriteBuf, oid Oid, value Tid) error {
+ if oid != TidOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid)
+ }
+
+ w.WriteInt32(6)
+ w.WriteUint32(value.BlockNumber)
+ w.WriteUint16(value.OffsetNumber)
+
+ return nil
+}
+
+func decodeFloat4(vr *ValueReader) float32 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into float32"))
+ return 0
+ }
+
+ if vr.Type().DataType != Float4Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
+ return 0
+ }
+
+ i := vr.ReadInt32()
+ return math.Float32frombits(uint32(i))
+}
+
+func encodeFloat32(w *WriteBuf, oid Oid, value float32) error {
+ switch oid {
+ case Float4Oid:
+ w.WriteInt32(4)
+ w.WriteInt32(int32(math.Float32bits(value)))
+ case Float8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(float64(value))))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "float32", oid)
+ }
+
+ return nil
+}
+
+func decodeFloat8(vr *ValueReader) float64 {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into float64"))
+ return 0
+ }
+
+ if vr.Type().DataType != Float8Oid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType)))
+ return 0
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return 0
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len())))
+ return 0
+ }
+
+ i := vr.ReadInt64()
+ return math.Float64frombits(uint64(i))
+}
+
+func encodeFloat64(w *WriteBuf, oid Oid, value float64) error {
+ switch oid {
+ case Float8Oid:
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(value)))
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "float64", oid)
+ }
+
+ return nil
+}
+
+func decodeText(vr *ValueReader) string {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into string"))
+ return ""
+ }
+
+ return vr.ReadString(vr.Len())
+}
+
+func encodeString(w *WriteBuf, oid Oid, value string) error {
+ w.WriteInt32(int32(len(value)))
+ w.WriteBytes([]byte(value))
+ return nil
+}
+
+func decodeBytea(vr *ValueReader) []byte {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != ByteaOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ return vr.ReadBytes(vr.Len())
+}
+
+func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error {
+ w.WriteInt32(int32(len(value)))
+ w.WriteBytes(value)
+
+ return nil
+}
+
+func decodeJSON(vr *ValueReader, d interface{}) error {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != JsonOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType)))
+ }
+
+ bytes := vr.ReadBytes(vr.Len())
+ err := json.Unmarshal(bytes, d)
+ if err != nil {
+ vr.Fatal(err)
+ }
+ return err
+}
+
+func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
+ if oid != JsonOid {
+ return fmt.Errorf("cannot encode JSON into oid %v", oid)
+ }
+
+ s, err := json.Marshal(value)
+ if err != nil {
+ return fmt.Errorf("Failed to encode json from type: %T", value)
+ }
+
+ w.WriteInt32(int32(len(s)))
+ w.WriteBytes(s)
+
+ return nil
+}
+
+func decodeJSONB(vr *ValueReader, d interface{}) error {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != JsonbOid {
+ err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))
+ vr.Fatal(err)
+ return err
+ }
+
+ bytes := vr.ReadBytes(vr.Len())
+ if vr.Type().FormatCode == BinaryFormatCode {
+ if bytes[0] != 1 {
+ err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0]))
+ vr.Fatal(err)
+ return err
+ }
+ bytes = bytes[1:]
+ }
+
+ err := json.Unmarshal(bytes, d)
+ if err != nil {
+ vr.Fatal(err)
+ }
+ return err
+}
+
+func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error {
+ if oid != JsonbOid {
+ return fmt.Errorf("cannot encode JSON into oid %v", oid)
+ }
+
+ s, err := json.Marshal(value)
+ if err != nil {
+ return fmt.Errorf("Failed to encode json from type: %T", value)
+ }
+
+ w.WriteInt32(int32(len(s) + 1))
+ w.WriteByte(1) // JSONB format header
+ w.WriteBytes(s)
+
+ return nil
+}
+
+func decodeDate(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != DateOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 4 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len())))
+ }
+ dayOffset := vr.ReadInt32()
+ return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
+}
+
+func encodeTime(w *WriteBuf, oid Oid, value time.Time) error {
+ switch oid {
+ case DateOid:
+ tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix()
+ dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix()
+
+ secSinceDateEpoch := tUnix - dateEpoch
+ daysSinceDateEpoch := secSinceDateEpoch / 86400
+
+ w.WriteInt32(4)
+ w.WriteInt32(int32(daysSinceDateEpoch))
+
+ return nil
+ case TimestampTzOid, TimestampOid:
+ microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000
+ microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
+
+ w.WriteInt32(8)
+ w.WriteInt64(microsecSinceY2K)
+
+ return nil
+ default:
+ return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid)
+ }
+}
+
// microsecFromUnixEpochToY2K is the number of microseconds between the Unix
// epoch (1970-01-01) and the PostgreSQL timestamp epoch (2000-01-01).
const microsecFromUnixEpochToY2K = 946684800 * 1000000
+
+func decodeTimestampTz(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != TimestampTzOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len())))
+ return zeroTime
+ }
+
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+}
+
+func decodeTimestamp(vr *ValueReader) time.Time {
+ var zeroTime time.Time
+
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
+ return zeroTime
+ }
+
+ if vr.Type().DataType != TimestampOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
+ return zeroTime
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return zeroTime
+ }
+
+ if vr.Len() != 8 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len())))
+ return zeroTime
+ }
+
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+}
+
// decodeInet reads a binary-format PostgreSQL inet or cidr value into a
// net.IPNet. The wire layout is: family byte, netmask-bits byte, is_cidr
// byte, address-length byte, then the address bytes (4 for IPv4, 16 for
// IPv6, giving total payload sizes 8 and 20). Protocol violations are
// reported via vr.Fatal and yield the zero net.IPNet.
func decodeInet(vr *ValueReader) net.IPNet {
	var zero net.IPNet

	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
		return zero
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return zero
	}

	pgType := vr.Type()
	if pgType.DataType != InetOid && pgType.DataType != CidrOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
		return zero
	}
	// 8 = 4 header bytes + IPv4 address; 20 = 4 header bytes + IPv6 address.
	if vr.Len() != 8 && vr.Len() != 20 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
		return zero
	}

	vr.ReadByte() // ignore family
	bits := vr.ReadByte()
	vr.ReadByte() // ignore is_cidr
	addressLength := vr.ReadByte()

	var ipnet net.IPNet
	ipnet.IP = vr.ReadBytes(int32(addressLength))
	ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)

	return ipnet
}
+
// encodeIPNet writes a net.IPNet as a binary inet/cidr value: family byte,
// netmask-bits byte, is_cidr byte (ignored by the server), address-length
// byte, then the address bytes. The family byte values come from the
// connection because they are platform-dependent (looked up server-side).
func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error {
	if oid != InetOid && oid != CidrOid {
		return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid)
	}

	var size int32
	var family byte
	switch len(value.IP) {
	case net.IPv4len:
		size = 8 // 4 header bytes + 4 address bytes
		family = *w.conn.pgsqlAfInet
	case net.IPv6len:
		size = 20 // 4 header bytes + 16 address bytes
		family = *w.conn.pgsqlAfInet6
	default:
		return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
	}

	w.WriteInt32(size)
	w.WriteByte(family)
	ones, _ := value.Mask.Size()
	w.WriteByte(byte(ones))
	w.WriteByte(0) // is_cidr is ignored on server
	w.WriteByte(byte(len(value.IP)))
	w.WriteBytes(value.IP)

	return nil
}
+
+func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
+ if oid != InetOid && oid != CidrOid {
+ return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid)
+ }
+
+ var ipnet net.IPNet
+ ipnet.IP = value
+ bitCount := len(value) * 8
+ ipnet.Mask = net.CIDRMask(bitCount, bitCount)
+ return encodeIPNet(w, oid, ipnet)
+}
+
// decodeRecord reads a binary-format PostgreSQL record (anonymous row) into
// a []interface{}. The payload is an int32 field count followed by, per
// field, an oid and an int32 length, then the field's binary value. Each
// field is decoded via a child ValueReader that shares vr's message reader
// but has its own byte budget. Only the oids listed in the switch below are
// supported; anything else is a fatal error. SQL NULL yields nil.
func decodeRecord(vr *ValueReader) []interface{} {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	if vr.Type().DataType != RecordOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
		return nil
	}

	valueCount := vr.ReadInt32()
	record := make([]interface{}, 0, int(valueCount))

	for i := int32(0); i < valueCount; i++ {
		// Build a per-field reader limited to this field's declared length.
		fd := FieldDescription{FormatCode: BinaryFormatCode}
		fieldVR := ValueReader{mr: vr.mr, fd: &fd}
		fd.DataType = vr.ReadOid()
		fieldVR.valueBytesRemaining = vr.ReadInt32()
		// Account for the bytes the field will consume from the record.
		vr.valueBytesRemaining -= fieldVR.valueBytesRemaining

		switch fd.DataType {
		case BoolOid:
			record = append(record, decodeBool(&fieldVR))
		case ByteaOid:
			record = append(record, decodeBytea(&fieldVR))
		case Int8Oid:
			record = append(record, decodeInt8(&fieldVR))
		case Int2Oid:
			record = append(record, decodeInt2(&fieldVR))
		case Int4Oid:
			record = append(record, decodeInt4(&fieldVR))
		case OidOid:
			record = append(record, decodeOid(&fieldVR))
		case Float4Oid:
			record = append(record, decodeFloat4(&fieldVR))
		case Float8Oid:
			record = append(record, decodeFloat8(&fieldVR))
		case DateOid:
			record = append(record, decodeDate(&fieldVR))
		case TimestampTzOid:
			record = append(record, decodeTimestampTz(&fieldVR))
		case TimestampOid:
			record = append(record, decodeTimestamp(&fieldVR))
		case InetOid, CidrOid:
			record = append(record, decodeInet(&fieldVR))
		case TextOid, VarcharOid, UnknownOid:
			record = append(record, decodeText(&fieldVR))
		default:
			vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
			return nil
		}

		// Consume any remaining data
		if fieldVR.Len() > 0 {
			fieldVR.ReadBytes(fieldVR.Len())
		}

		if fieldVR.Err() != nil {
			vr.Fatal(fieldVR.Err())
			return nil
		}
	}

	return record
}
+
+func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
+ numDims := vr.ReadInt32()
+ if numDims > 1 {
+ return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims))
+ }
+
+ vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care
+ vr.ReadInt32() // element oid
+
+ if numDims == 0 {
+ return 0, nil
+ }
+
+ length = vr.ReadInt32()
+
+ idxFirstElem := vr.ReadInt32()
+ if idxFirstElem != 1 {
+ return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem))
+ }
+
+ return length, nil
+}
+
+func decodeBoolArray(vr *ValueReader) []bool {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != BoolArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]bool, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 1:
+ if vr.ReadByte() == 1 {
+ a[i] = true
+ }
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error {
+ if oid != BoolArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid)
+ }
+
+ encodeArrayHeader(w, BoolOid, len(slice), 5)
+ for _, v := range slice {
+ w.WriteInt32(1)
+ var b byte
+ if v {
+ b = 1
+ }
+ w.WriteByte(b)
+ }
+
+ return nil
+}
+
+func decodeByteaArray(vr *ValueReader) [][]byte {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != ByteaArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([][]byte, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ a[i] = vr.ReadBytes(elSize)
+ }
+ }
+
+ return a
+}
+
+func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error {
+ if oid != ByteaArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid)
+ }
+
+ size := 20 // array header size
+ for _, el := range value {
+ size += 4 + len(el)
+ }
+
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(ByteaOid)) // type of elements
+ w.WriteInt32(int32(len(value))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, el := range value {
+ encodeByteSlice(w, ByteaOid, el)
+ }
+
+ return nil
+}
+
+func decodeInt2Array(vr *ValueReader) []int16 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int2ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int16, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 2:
+ a[i] = vr.ReadInt16()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int2ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint16, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 2:
+ tmp := vr.ReadInt16()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp)))
+ return nil
+ }
+ a[i] = uint16(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error {
+ if oid != Int2ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
+ }
+
+ encodeArrayHeader(w, Int2Oid, len(slice), 6)
+ for _, v := range slice {
+ w.WriteInt32(2)
+ w.WriteInt16(v)
+ }
+
+ return nil
+}
+
+func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error {
+ if oid != Int2ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
+ }
+
+ encodeArrayHeader(w, Int2Oid, len(slice), 6)
+ for _, v := range slice {
+ if v <= math.MaxInt16 {
+ w.WriteInt32(2)
+ w.WriteInt16(int16(v))
+ } else {
+ return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
+ }
+ }
+
+ return nil
+}
+
+func decodeInt4Array(vr *ValueReader) []int32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ a[i] = vr.ReadInt32()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ tmp := vr.ReadInt32()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp)))
+ return nil
+ }
+ a[i] = uint32(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error {
+ if oid != Int4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid)
+ }
+
+ encodeArrayHeader(w, Int4Oid, len(slice), 8)
+ for _, v := range slice {
+ w.WriteInt32(4)
+ w.WriteInt32(v)
+ }
+
+ return nil
+}
+
+func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error {
+ if oid != Int4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid)
+ }
+
+ encodeArrayHeader(w, Int4Oid, len(slice), 8)
+ for _, v := range slice {
+ if v <= math.MaxInt32 {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(v))
+ } else {
+ return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32)
+ }
+ }
+
+ return nil
+}
+
+func decodeInt8Array(vr *ValueReader) []int64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]int64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ a[i] = vr.ReadInt64()
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Int8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]uint64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ tmp := vr.ReadInt64()
+ if tmp < 0 {
+ vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp)))
+ return nil
+ }
+ a[i] = uint64(tmp)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error {
+ if oid != Int8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid)
+ }
+
+ encodeArrayHeader(w, Int8Oid, len(slice), 12)
+ for _, v := range slice {
+ w.WriteInt32(8)
+ w.WriteInt64(v)
+ }
+
+ return nil
+}
+
+func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error {
+ if oid != Int8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid)
+ }
+
+ encodeArrayHeader(w, Int8Oid, len(slice), 12)
+ for _, v := range slice {
+ if v <= math.MaxInt64 {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(v))
+ } else {
+ return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64))
+ }
+ }
+
+ return nil
+}
+
+func decodeFloat4Array(vr *ValueReader) []float32 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Float4ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]float32, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 4:
+ n := vr.ReadInt32()
+ a[i] = math.Float32frombits(uint32(n))
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error {
+ if oid != Float4ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid)
+ }
+
+ encodeArrayHeader(w, Float4Oid, len(slice), 8)
+ for _, v := range slice {
+ w.WriteInt32(4)
+ w.WriteInt32(int32(math.Float32bits(v)))
+ }
+
+ return nil
+}
+
+func decodeFloat8Array(vr *ValueReader) []float64 {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != Float8ArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]float64, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ n := vr.ReadInt64()
+ a[i] = math.Float64frombits(uint64(n))
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error {
+ if oid != Float8ArrayOid {
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid)
+ }
+
+ encodeArrayHeader(w, Float8Oid, len(slice), 12)
+ for _, v := range slice {
+ w.WriteInt32(8)
+ w.WriteInt64(int64(math.Float64bits(v)))
+ }
+
+ return nil
+}
+
+func decodeTextArray(vr *ValueReader) []string {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]string, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ if elSize == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ }
+
+ a[i] = vr.ReadString(elSize)
+ }
+
+ return a
+}
+
+// escapeAclItem escapes an AclItem before it is added to
+// its aclitem[] string representation. The PostgreSQL aclitem
+// datatype itself can need escapes because it follows the
+// formatting rules of SQL identifiers. Think of this function
+// as escaping the escapes, so that PostgreSQL's array parser
+// will do the right thing.
+func escapeAclItem(acl string) (string, error) {
+ var escapedAclItem bytes.Buffer
+ reader := strings.NewReader(acl)
+ for {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error.
+ return escapedAclItem.String(), nil
+ }
+ // This error was not expected
+ return "", err
+ }
+ if needsEscape(rn) {
+ escapedAclItem.WriteRune('\\')
+ }
+ escapedAclItem.WriteRune(rn)
+ }
+}
+
// needsEscape reports whether rn must be backslash-escaped before being
// placed in the textual representation of an aclitem[] array.
func needsEscape(rn rune) bool {
	switch rn {
	case '\\', ',', '"', '}':
		return true
	}
	return false
}
+
+// encodeAclItemSlice encodes a slice of AclItems in
+// their textual represention for PostgreSQL.
+func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error {
+ strs := make([]string, len(aclitems))
+ var escapedAclItem string
+ var err error
+ for i := range strs {
+ escapedAclItem, err = escapeAclItem(string(aclitems[i]))
+ if err != nil {
+ return err
+ }
+ strs[i] = string(escapedAclItem)
+ }
+
+ var buf bytes.Buffer
+ buf.WriteRune('{')
+ buf.WriteString(strings.Join(strs, ","))
+ buf.WriteRune('}')
+ str := buf.String()
+ w.WriteInt32(int32(len(str)))
+ w.WriteBytes([]byte(str))
+ return nil
+}
+
+// parseAclItemArray parses the textual representation
+// of the aclitem[] type. The textual representation is chosen because
+// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin).
+// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
+// for formatting notes.
+func parseAclItemArray(arr string) ([]AclItem, error) {
+ reader := strings.NewReader(arr)
+ // Difficult to guess a performant initial capacity for a slice of
+ // aclitems, but let's go with 5.
+ aclItems := make([]AclItem, 0, 5)
+ // A single value
+ aclItem := AclItem("")
+ for {
+ // Grab the first/next/last rune to see if we are dealing with a
+ // quoted value, an unquoted value, or the end of the string.
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error.
+ return aclItems, nil
+ }
+ // This error was not expected
+ return nil, err
+ }
+
+ if rn == '"' {
+ // Discard the opening quote of the quoted value.
+ aclItem, err = parseQuotedAclItem(reader)
+ } else {
+ // We have just read the first rune of an unquoted (bare) value;
+ // put it back so that ParseBareValue can read it.
+ err := reader.UnreadRune()
+ if err != nil {
+ return nil, err
+ }
+ aclItem, err = parseBareAclItem(reader)
+ }
+
+ if err != nil {
+ if err == io.EOF {
+ // Here, EOF is an expected end state, not an error..
+ aclItems = append(aclItems, aclItem)
+ return aclItems, nil
+ }
+ // This error was not expected.
+ return nil, err
+ }
+ aclItems = append(aclItems, aclItem)
+ }
+}
+
+// parseBareAclItem parses a bare (unquoted) aclitem from reader
+func parseBareAclItem(reader *strings.Reader) (AclItem, error) {
+ var aclItem bytes.Buffer
+ for {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ // Return the read value in case the error is a harmless io.EOF.
+ // (io.EOF marks the end of a bare aclitem at the end of a string)
+ return AclItem(aclItem.String()), err
+ }
+ if rn == ',' {
+ // A comma marks the end of a bare aclitem.
+ return AclItem(aclItem.String()), nil
+ } else {
+ aclItem.WriteRune(rn)
+ }
+ }
+}
+
// parseQuotedAclItem parses an aclitem which is in double quotes from reader.
// The opening quote must already have been consumed by the caller. Runes
// preceded by a backslash are unescaped as they are read. The closing
// (unescaped) quote must be followed by either a comma or the end of the
// string; the resulting io.EOF is passed to the caller, which treats it
// as a normal end-of-array.
func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) {
	var aclItem bytes.Buffer
	for {
		rn, escaped, err := readPossiblyEscapedRune(reader)
		if err != nil {
			if err == io.EOF {
				// Even when it is the last value, the final rune of
				// a quoted aclitem should be the final closing quote, not io.EOF.
				return AclItem(""), fmt.Errorf("unexpected end of quoted value")
			}
			// err is not io.EOF here; return the partial value with it.
			return AclItem(aclItem.String()), err
		}
		if !escaped && rn == '"' {
			// An unescaped double quote marks the end of a quoted value.
			// The next rune should either be a comma or the end of the string.
			rn, _, err := reader.ReadRune()
			if err != nil {
				// Return the read value in case the error is a harmless io.EOF,
				// which will be determined by the caller.
				return AclItem(aclItem.String()), err
			}
			if rn != ',' {
				return AclItem(""), fmt.Errorf("unexpected rune after quoted value")
			}
			return AclItem(aclItem.String()), nil
		}
		aclItem.WriteRune(rn)
	}
}
+
+// Returns the next rune from r, unless it is a backslash;
+// in that case, it returns the rune after the backslash. The second
+// return value tells us whether or not the rune was
+// preceeded by a backslash (escaped).
+func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) {
+ rn, _, err := reader.ReadRune()
+ if err != nil {
+ return 0, false, err
+ }
+ if rn == '\\' {
+ // Discard the backslash and read the next rune.
+ rn, _, err = reader.ReadRune()
+ if err != nil {
+ return 0, false, err
+ }
+ return rn, true, nil
+ }
+ return rn, false, nil
+}
+
+func decodeAclItemArray(vr *ValueReader) []AclItem {
+ if vr.Len() == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null into []AclItem"))
+ return nil
+ }
+
+ str := vr.ReadString(vr.Len())
+
+ // Short-circuit empty array.
+ if str == "{}" {
+ return []AclItem{}
+ }
+
+ // Remove the '{' at the front and the '}' at the end,
+ // so that parseAclItemArray doesn't have to deal with them.
+ str = str[1 : len(str)-1]
+ aclItems, err := parseAclItemArray(str)
+ if err != nil {
+ vr.Fatal(ProtocolError(err.Error()))
+ return nil
+ }
+ return aclItems
+}
+
+func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error {
+ var elOid Oid
+ switch oid {
+ case VarcharArrayOid:
+ elOid = VarcharOid
+ case TextArrayOid:
+ elOid = TextOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid)
+ }
+
+ var totalStringSize int
+ for _, v := range slice {
+ totalStringSize += len(v)
+ }
+
+ size := 20 + len(slice)*4 + totalStringSize
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, v := range slice {
+ w.WriteInt32(int32(len(v)))
+ w.WriteBytes([]byte(v))
+ }
+
+ return nil
+}
+
+func decodeTimestampArray(vr *ValueReader) []time.Time {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]time.Time, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ switch elSize {
+ case 8:
+ microsecSinceY2K := vr.ReadInt64()
+ microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
+ a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
+ case -1:
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ default:
+ vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize)))
+ return nil
+ }
+ }
+
+ return a
+}
+
+func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error {
+ var elOid Oid
+ switch oid {
+ case TimestampArrayOid:
+ elOid = TimestampOid
+ case TimestampTzArrayOid:
+ elOid = TimestampTzOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid)
+ }
+
+ encodeArrayHeader(w, int(elOid), len(slice), 12)
+ for _, t := range slice {
+ w.WriteInt32(8)
+ microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
+ microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
+ w.WriteInt64(microsecSinceY2K)
+ }
+
+ return nil
+}
+
+func decodeInetArray(vr *ValueReader) []net.IPNet {
+ if vr.Len() == -1 {
+ return nil
+ }
+
+ if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType)))
+ return nil
+ }
+
+ if vr.Type().FormatCode != BinaryFormatCode {
+ vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
+ return nil
+ }
+
+ numElems, err := decode1dArrayHeader(vr)
+ if err != nil {
+ vr.Fatal(err)
+ return nil
+ }
+
+ a := make([]net.IPNet, int(numElems))
+ for i := 0; i < len(a); i++ {
+ elSize := vr.ReadInt32()
+ if elSize == -1 {
+ vr.Fatal(ProtocolError("Cannot decode null element"))
+ return nil
+ }
+
+ vr.ReadByte() // ignore family
+ bits := vr.ReadByte()
+ vr.ReadByte() // ignore is_cidr
+ addressLength := vr.ReadByte()
+
+ var ipnet net.IPNet
+ ipnet.IP = vr.ReadBytes(int32(addressLength))
+ ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
+
+ a[i] = ipnet
+ }
+
+ return a
+}
+
+func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error {
+ var elOid Oid
+ switch oid {
+ case InetArrayOid:
+ elOid = InetOid
+ case CidrArrayOid:
+ elOid = CidrOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
+ }
+
+ size := int32(20) // array header size
+ for _, ipnet := range slice {
+ size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes
+ }
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, ipnet := range slice {
+ encodeIPNet(w, elOid, ipnet)
+ }
+
+ return nil
+}
+
+func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error {
+ var elOid Oid
+ switch oid {
+ case InetArrayOid:
+ elOid = InetOid
+ case CidrArrayOid:
+ elOid = CidrOid
+ default:
+ return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
+ }
+
+ size := int32(20) // array header size
+ for _, ip := range slice {
+ size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes
+ }
+ w.WriteInt32(int32(size))
+
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(elOid)) // type of elements
+ w.WriteInt32(int32(len(slice))) // number of elements
+ w.WriteInt32(1) // index of first element
+
+ for _, ip := range slice {
+ encodeIP(w, elOid, ip)
+ }
+
+ return nil
+}
+
+func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
+ w.WriteInt32(int32(20 + length*sizePerItem))
+ w.WriteInt32(1) // number of dimensions
+ w.WriteInt32(0) // no nulls
+ w.WriteInt32(int32(oid)) // type of elements
+ w.WriteInt32(int32(length)) // number of elements
+ w.WriteInt32(1) // index of first element
+}
diff --git a/vendor/github.com/jackc/pgx/values_test.go b/vendor/github.com/jackc/pgx/values_test.go
new file mode 100644
index 0000000..42d5bd3
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/values_test.go
@@ -0,0 +1,1183 @@
+package pgx_test
+
+import (
+ "bytes"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx"
+)
+
+// TestDateTranscode round-trips a spread of dates — year 1 through 9999,
+// including leap days and century boundaries — through a ::date parameter
+// and verifies each scanned value equals its input.
+func TestDateTranscode(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	dates := []time.Time{
+		time.Date(1, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local),
+		time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(2001, 1, 2, 0, 0, 0, 0, time.Local),
+		time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local),
+		time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local),
+		time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local),
+		time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local),
+		time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local),
+		time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local),
+	}
+
+	for _, actualDate := range dates {
+		var d time.Time
+
+		err := conn.QueryRow("select $1::date", actualDate).Scan(&d)
+		if err != nil {
+			t.Fatalf("Unexpected failure on QueryRow Scan: %v", err)
+		}
+		if !actualDate.Equal(d) {
+			t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate)
+		}
+	}
+}
+
+// TestTimestampTzTranscode round-trips a time.Time with sub-microsecond-free
+// nanoseconds through a ::timestamptz parameter and checks equality.
+func TestTimestampTzTranscode(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local)
+
+	var outputTime time.Time
+
+	err := conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime)
+	if err != nil {
+		t.Fatalf("QueryRow Scan failed: %v", err)
+	}
+	if !inputTime.Equal(outputTime) {
+		t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
+	}
+
+	// NOTE(review): the identical query is deliberately issued a second time —
+	// presumably to exercise the reused/prepared-statement path; confirm intent.
+	err = conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime)
+	if err != nil {
+		t.Fatalf("QueryRow Scan failed: %v", err)
+	}
+	if !inputTime.Equal(outputTime) {
+		t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
+	}
+}
+
+// TestJsonAndJsonbTranscode drives the shared testJson* helpers against both
+// the json and jsonb oids, in both text and binary format codes. If the
+// server does not know either oid the whole test returns early (old PostgreSQL).
+func TestJsonAndJsonbTranscode(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
+		if _, ok := conn.PgTypes[oid]; !ok {
+			return // No JSON/JSONB type -- must be running against old PostgreSQL
+		}
+
+		for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} {
+			// Force the connection's default format for this oid, then restore is
+			// unnecessary because the conn is closed at test end.
+			pgtype := conn.PgTypes[oid]
+			pgtype.DefaultFormat = format
+			conn.PgTypes[oid] = pgtype
+
+			typename := conn.PgTypes[oid].Name
+
+			testJsonString(t, conn, typename, format)
+			testJsonStringPointer(t, conn, typename, format)
+			testJsonSingleLevelStringMap(t, conn, typename, format)
+			testJsonNestedMap(t, conn, typename, format)
+			testJsonStringArray(t, conn, typename, format)
+			testJsonInt64Array(t, conn, typename, format)
+			testJsonInt16ArrayFailureDueToOverflow(t, conn, typename, format)
+			testJsonStruct(t, conn, typename, format)
+		}
+	}
+}
+
+// testJsonString sends a JSON document as a Go string and scans it back into
+// a map, verifying the decoded structure.
+func testJsonString(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := `{"key": "value"}`
+	expectedOutput := map[string]string{"key": "value"}
+	var output map[string]string
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+		return
+	}
+
+	if !reflect.DeepEqual(expectedOutput, output) {
+		t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
+		return
+	}
+}
+
+// testJsonStringPointer is testJsonString but passes the input as *string,
+// exercising pointer-argument encoding.
+func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := `{"key": "value"}`
+	expectedOutput := map[string]string{"key": "value"}
+	var output map[string]string
+	err := conn.QueryRow("select $1::"+typename, &input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+		return
+	}
+
+	if !reflect.DeepEqual(expectedOutput, output) {
+		t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output)
+		return
+	}
+}
+
+// testJsonSingleLevelStringMap round-trips a flat map[string]string through
+// the given json/jsonb type.
+func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := map[string]string{"key": "value"}
+	var output map[string]string
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+		return
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output)
+		return
+	}
+}
+
+// testJsonNestedMap round-trips a nested map containing a sub-map and a slice.
+// Numbers are float64 because encoding/json decodes JSON numbers to float64.
+func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := map[string]interface{}{
+		"name":      "Uncanny",
+		"stats":     map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
+		"inventory": []interface{}{"phone", "key"},
+	}
+	var output map[string]interface{}
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+		return
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output)
+		return
+	}
+}
+
+// testJsonStringArray round-trips a []string encoded as a JSON array.
+func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := []string{"foo", "bar", "baz"}
+	var output []string
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output)
+	}
+}
+
+// testJsonInt64Array round-trips a []int64 encoded as a JSON array.
+func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := []int64{1, 2, 234432}
+	var output []int64
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output)
+	}
+}
+
+// testJsonInt16ArrayFailureDueToOverflow asserts that scanning a JSON array
+// containing 234432 into []int16 fails with json's unmarshal-overflow error.
+func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	input := []int{1, 2, 234432}
+	var output []int16
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
+		// Fixed typo: "UnmarkalTypeError" -> "UnmarshalTypeError".
+		t.Errorf("%s %d: Expected *json.UnmarshalTypeError, but got %v", typename, format, err)
+	}
+}
+
+// testJsonStruct round-trips a tagged struct through json/jsonb and compares
+// field-by-field with reflect.DeepEqual.
+func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) {
+	type person struct {
+		Name string `json:"name"`
+		Age  int    `json:"age"`
+	}
+
+	input := person{
+		Name: "John",
+		Age:  42,
+	}
+
+	var output person
+
+	err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
+	if err != nil {
+		t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err)
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output)
+	}
+}
+
+// mustParseCIDR parses a CIDR string into a net.IPNet value, failing the test
+// immediately on malformed input. Note it returns the network (masked) value,
+// per net.ParseCIDR.
+func mustParseCIDR(t *testing.T, s string) net.IPNet {
+	_, ipnet, err := net.ParseCIDR(s)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return *ipnet
+}
+
+// TestStringToNotTextTypeTranscode verifies a Go string (and *string) can be
+// sent to a non-text column type (uuid) and scanned back as a string.
+func TestStringToNotTextTypeTranscode(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	input := "01086ee0-4963-4e35-9116-30c173a8d0bd"
+
+	var output string
+	err := conn.QueryRow("select $1::uuid", input).Scan(&output)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if input != output {
+		t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output)
+	}
+
+	err = conn.QueryRow("select $1::uuid", &input).Scan(&output)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if input != output {
+		t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output)
+	}
+}
+
+// TestInetCidrTranscodeIPNet round-trips net.IPNet values — IPv4 and IPv6,
+// host and network masks — through both ::inet and ::cidr, comparing the
+// string renderings.
+func TestInetCidrTranscodeIPNet(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql   string
+		value net.IPNet
+	}{
+		{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
+		{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
+		{"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")},
+		{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
+		{"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")},
+		{"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")},
+		{"select $1::inet", mustParseCIDR(t, "::/128")},
+		{"select $1::inet", mustParseCIDR(t, "::/0")},
+		{"select $1::inet", mustParseCIDR(t, "::1/128")},
+		{"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
+		{"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")},
+		{"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")},
+		{"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")},
+		{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
+		{"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")},
+		{"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")},
+		{"select $1::cidr", mustParseCIDR(t, "::/128")},
+		{"select $1::cidr", mustParseCIDR(t, "::/0")},
+		{"select $1::cidr", mustParseCIDR(t, "::1/128")},
+		{"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
+	}
+
+	for i, tt := range tests {
+		var actual net.IPNet
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		// Compare CIDR string forms; net.IPNet has no Equal method.
+		if actual.String() != tt.value.String() {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestInetCidrTranscodeIP round-trips bare net.IP values through ::inet and
+// ::cidr, then verifies that an IPNet with a non-host netmask cannot be
+// scanned into a net.IP.
+func TestInetCidrTranscodeIP(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql   string
+		value net.IP
+	}{
+		{"select $1::inet", net.ParseIP("0.0.0.0")},
+		{"select $1::inet", net.ParseIP("127.0.0.1")},
+		{"select $1::inet", net.ParseIP("12.34.56.0")},
+		{"select $1::inet", net.ParseIP("255.255.255.255")},
+		{"select $1::inet", net.ParseIP("::1")},
+		{"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")},
+		{"select $1::cidr", net.ParseIP("0.0.0.0")},
+		{"select $1::cidr", net.ParseIP("127.0.0.1")},
+		{"select $1::cidr", net.ParseIP("12.34.56.0")},
+		{"select $1::cidr", net.ParseIP("255.255.255.255")},
+		{"select $1::cidr", net.ParseIP("::1")},
+		{"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")},
+	}
+
+	for i, tt := range tests {
+		var actual net.IP
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		if !actual.Equal(tt.value) {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
+		}
+
+		ensureConnValid(t, conn)
+	}
+
+	failTests := []struct {
+		sql   string
+		value net.IPNet
+	}{
+		{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
+		{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
+	}
+	for i, tt := range failTests {
+		var actual net.IP
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		// BUG FIX: guard against a nil error before calling err.Error();
+		// matches the equivalent loop in TestInetCidrArrayTranscodeIP.
+		if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") {
+			t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestInetCidrArrayTranscodeIPNet round-trips []net.IPNet through inet[] and
+// cidr[] array parameters and compares with reflect.DeepEqual.
+func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql   string
+		value []net.IPNet
+	}{
+		{
+			"select $1::inet[]",
+			[]net.IPNet{
+				mustParseCIDR(t, "0.0.0.0/32"),
+				mustParseCIDR(t, "127.0.0.1/32"),
+				mustParseCIDR(t, "12.34.56.0/32"),
+				mustParseCIDR(t, "192.168.1.0/24"),
+				mustParseCIDR(t, "255.0.0.0/8"),
+				mustParseCIDR(t, "255.255.255.255/32"),
+				mustParseCIDR(t, "::/128"),
+				mustParseCIDR(t, "::/0"),
+				mustParseCIDR(t, "::1/128"),
+				mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
+			},
+		},
+		{
+			"select $1::cidr[]",
+			[]net.IPNet{
+				mustParseCIDR(t, "0.0.0.0/32"),
+				mustParseCIDR(t, "127.0.0.1/32"),
+				mustParseCIDR(t, "12.34.56.0/32"),
+				mustParseCIDR(t, "192.168.1.0/24"),
+				mustParseCIDR(t, "255.0.0.0/8"),
+				mustParseCIDR(t, "255.255.255.255/32"),
+				mustParseCIDR(t, "::/128"),
+				mustParseCIDR(t, "::/0"),
+				mustParseCIDR(t, "::1/128"),
+				mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		var actual []net.IPNet
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		if !reflect.DeepEqual(actual, tt.value) {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestInetCidrArrayTranscodeIP round-trips []net.IP through inet[] and cidr[]
+// arrays, then verifies that arrays containing a non-host netmask fail to
+// decode into []net.IP.
+func TestInetCidrArrayTranscodeIP(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql   string
+		value []net.IP
+	}{
+		{
+			"select $1::inet[]",
+			[]net.IP{
+				net.ParseIP("0.0.0.0"),
+				net.ParseIP("127.0.0.1"),
+				net.ParseIP("12.34.56.0"),
+				net.ParseIP("255.255.255.255"),
+				net.ParseIP("2607:f8b0:4009:80b::200e"),
+			},
+		},
+		{
+			"select $1::cidr[]",
+			[]net.IP{
+				net.ParseIP("0.0.0.0"),
+				net.ParseIP("127.0.0.1"),
+				net.ParseIP("12.34.56.0"),
+				net.ParseIP("255.255.255.255"),
+				net.ParseIP("2607:f8b0:4009:80b::200e"),
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		var actual []net.IP
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		if !reflect.DeepEqual(actual, tt.value) {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
+		}
+
+		ensureConnValid(t, conn)
+	}
+
+	failTests := []struct {
+		sql   string
+		value []net.IPNet
+	}{
+		{
+			"select $1::inet[]",
+			[]net.IPNet{
+				mustParseCIDR(t, "12.34.56.0/32"),
+				mustParseCIDR(t, "192.168.1.0/24"),
+			},
+		},
+		{
+			"select $1::cidr[]",
+			[]net.IPNet{
+				mustParseCIDR(t, "12.34.56.0/32"),
+				mustParseCIDR(t, "192.168.1.0/24"),
+			},
+		},
+	}
+
+	for i, tt := range failTests {
+		var actual []net.IP
+
+		err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
+		if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") {
+			t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestInetCidrTranscodeWithJustIP sends only the IP part of a host-masked
+// CIDR and checks the scanned net.IPNet renders as the full /32 or /128.
+func TestInetCidrTranscodeWithJustIP(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql   string
+		value string
+	}{
+		{"select $1::inet", "0.0.0.0/32"},
+		{"select $1::inet", "127.0.0.1/32"},
+		{"select $1::inet", "12.34.56.0/32"},
+		{"select $1::inet", "255.255.255.255/32"},
+		{"select $1::inet", "::/128"},
+		{"select $1::inet", "2607:f8b0:4009:80b::200e/128"},
+		{"select $1::cidr", "0.0.0.0/32"},
+		{"select $1::cidr", "127.0.0.1/32"},
+		{"select $1::cidr", "12.34.56.0/32"},
+		{"select $1::cidr", "255.255.255.255/32"},
+		{"select $1::cidr", "::/128"},
+		{"select $1::cidr", "2607:f8b0:4009:80b::200e/128"},
+	}
+
+	for i, tt := range tests {
+		expected := mustParseCIDR(t, tt.value)
+		var actual net.IPNet
+
+		// Pass only expected.IP: the netmask should be inferred server-side.
+		err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
+			continue
+		}
+
+		if actual.String() != expected.String() {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestNullX table-tests every pgx.NullXxx wrapper type in both Valid and
+// not-Valid (SQL NULL) states, round-tripping each through a cast parameter
+// and comparing the whole result struct by value.
+func TestNullX(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	// One field per nullable wrapper type so a single comparable struct can
+	// hold the expected and actual results of every table row.
+	type allTypes struct {
+		s   pgx.NullString
+		i16 pgx.NullInt16
+		i32 pgx.NullInt32
+		c   pgx.NullChar
+		a   pgx.NullAclItem
+		n   pgx.NullName
+		oid pgx.NullOid
+		xid pgx.NullXid
+		cid pgx.NullCid
+		tid pgx.NullTid
+		i64 pgx.NullInt64
+		f32 pgx.NullFloat32
+		f64 pgx.NullFloat64
+		b   pgx.NullBool
+		t   pgx.NullTime
+	}
+
+	var actual, zero allTypes
+
+	tests := []struct {
+		sql       string
+		queryArgs []interface{}
+		scanArgs  []interface{}
+		expected  allTypes
+	}{
+		{"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}},
+		{"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}},
+		{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}},
+		{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}},
+		{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}},
+		{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}},
+		{"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 1, Valid: true}}},
+		{"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 0, Valid: false}}},
+		{"select $1::oid", []interface{}{pgx.NullOid{Oid: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 4294967295, Valid: true}}},
+		{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}},
+		{"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}},
+		{"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}},
+		{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}},
+		{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}},
+		{"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}},
+		{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}},
+		{"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}},
+		{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}},
+		{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}},
+		// A tricky (and valid) aclitem can still be used, especially with Go's useful backticks
+		{"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}},
+		{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}},
+		{"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}},
+		{"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}},
+		{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}},
+		{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}},
+		{"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}},
+		{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
+		{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
+		{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},
+		{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}},
+		{"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}},
+		{"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}},
+		{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}},
+		{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}},
+		{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}},
+		{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
+		{"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}},
+		{"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
+		{"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}},
+		{"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
+		{"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}},
+	}
+
+	for i, tt := range tests {
+		// Reset the shared result struct before each row so stale fields from
+		// the previous row cannot mask a failed scan.
+		actual = zero
+
+		err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
+		}
+
+		if actual != tt.expected {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// assertAclItemSlicesEqual reports a test error (with lengths and contents)
+// when the queried and scanned aclitem slices differ.
+func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) {
+	if !reflect.DeepEqual(query, scan) {
+		t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan)
+	}
+}
+
+// TestAclArrayDecoding round-trips aclitem[] values of increasing difficulty,
+// ending with an aclitem containing quotes, commas, and backslashes.
+func TestAclArrayDecoding(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	sql := "select $1::aclitem[]"
+	var scan []pgx.AclItem
+
+	tests := []struct {
+		query []pgx.AclItem
+	}{
+		{
+			[]pgx.AclItem{},
+		},
+		{
+			[]pgx.AclItem{"=r/postgres"},
+		},
+		{
+			[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"},
+		},
+		{
+			[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`},
+		},
+	}
+	for i, tt := range tests {
+		err := conn.QueryRow(sql, tt.query).Scan(&scan)
+		if err != nil {
+			// Removed a leftover commented-out duplicate of this Errorf call.
+			t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query)
+			if pgerr, ok := err.(pgx.PgError); ok {
+				t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail)
+			}
+			continue
+		}
+		assertAclItemSlicesEqual(t, tt.query, scan)
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestArrayDecoding round-trips Go slices of many element types (bool, signed
+// and unsigned ints, string, time.Time, [][]byte) through the matching
+// PostgreSQL array types; each row supplies its own assertion closure.
+func TestArrayDecoding(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	tests := []struct {
+		sql    string
+		query  interface{}
+		scan   interface{}
+		assert func(*testing.T, interface{}, interface{})
+	}{
+		{
+			"select $1::bool[]", []bool{true, false, true}, &[]bool{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
+					t.Errorf("failed to encode bool[]")
+				}
+			},
+		},
+		{
+			"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
+					t.Errorf("failed to encode smallint[]")
+				}
+			},
+		},
+		{
+			"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
+					t.Errorf("failed to encode smallint[]")
+				}
+			},
+		},
+		{
+			"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
+					t.Errorf("failed to encode int[]")
+				}
+			},
+		},
+		{
+			"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
+					t.Errorf("failed to encode int[]")
+				}
+			},
+		},
+		{
+			"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
+					t.Errorf("failed to encode bigint[]")
+				}
+			},
+		},
+		{
+			"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
+					t.Errorf("failed to encode bigint[]")
+				}
+			},
+		},
+		{
+			"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]string))) {
+					t.Errorf("failed to encode text[]")
+				}
+			},
+		},
+		{
+			"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
+					t.Errorf("failed to encode time.Time[] to timestamp[]")
+				}
+			},
+		},
+		{
+			"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
+			func(t *testing.T, query, scan interface{}) {
+				if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
+					t.Errorf("failed to encode time.Time[] to timestamptz[]")
+				}
+			},
+		},
+		{
+			"select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
+			func(t *testing.T, query, scan interface{}) {
+				// bytea[] is compared element-wise with bytes.Equal rather than
+				// DeepEqual so a length mismatch is reported distinctly.
+				queryBytesSliceSlice := query.([][]byte)
+				scanBytesSliceSlice := *(scan.(*[][]byte))
+				if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) {
+					t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice))
+				}
+				for i := range queryBytesSliceSlice {
+					qb := queryBytesSliceSlice[i]
+					sb := scanBytesSliceSlice[i]
+					if !bytes.Equal(qb, sb) {
+						t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
+					}
+				}
+			},
+		},
+	}
+
+	for i, tt := range tests {
+		err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan)
+		if err != nil {
+			t.Errorf(`%d. error reading array: %v`, i, err)
+			continue
+		}
+		tt.assert(t, tt.query, tt.scan)
+		ensureConnValid(t, conn)
+	}
+}
+
+// shortScanner is a pgx.Scanner that deliberately reads only one byte of the
+// value, leaving the reader short — used to verify the connection recovers.
+type shortScanner struct{}
+
+func (*shortScanner) Scan(r *pgx.ValueReader) error {
+	r.ReadByte()
+	return nil
+}
+
+// TestShortScanner scans multi-byte values with shortScanner (which consumes
+// only one byte each) and then checks the connection is still usable.
+func TestShortScanner(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'")
+	if err != nil {
+		t.Error(err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var s1, s2 shortScanner
+		err = rows.Scan(&s1, &s2)
+		if err != nil {
+			t.Error(err)
+		}
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestEmptyArrayDecoding verifies that an empty text[] decodes to a
+// zero-length Go slice, alone and sandwiched between other result columns,
+// and that empty arrays interleave correctly across multiple rows.
+func TestEmptyArrayDecoding(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	var val []string
+
+	err := conn.QueryRow("select array[]::text[]").Scan(&val)
+	if err != nil {
+		t.Errorf(`error reading array: %v`, err)
+	}
+	if len(val) != 0 {
+		t.Errorf("Expected 0 values, got %d", len(val))
+	}
+
+	var n, m int32
+
+	err = conn.QueryRow("select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m)
+	if err != nil {
+		t.Errorf(`error reading array: %v`, err)
+	}
+	if len(val) != 0 {
+		t.Errorf("Expected 0 values, got %d", len(val))
+	}
+	if n != 1 {
+		t.Errorf("Expected n to be 1, but it was %d", n)
+	}
+	if m != 42 {
+		// BUG FIX: this branch is about m; the message previously named and
+		// printed n, masking the actual failing value.
+		t.Errorf("Expected m to be 42, but it was %d", m)
+	}
+
+	rows, err := conn.Query("select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]")
+	if err != nil {
+		t.Errorf(`error retrieving rows with array: %v`, err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		err = rows.Scan(&n, &val)
+		if err != nil {
+			t.Errorf(`error reading array: %v`, err)
+		}
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestNullXMismatch sends each pgx.NullXxx wrapper to an incompatible column
+// type and asserts the error message contains the expected encode/parse
+// failure text.
+func TestNullXMismatch(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	type allTypes struct {
+		s   pgx.NullString
+		i16 pgx.NullInt16
+		i32 pgx.NullInt32
+		i64 pgx.NullInt64
+		f32 pgx.NullFloat32
+		f64 pgx.NullFloat64
+		b   pgx.NullBool
+		t   pgx.NullTime
+	}
+
+	var actual, zero allTypes
+
+	tests := []struct {
+		sql       string
+		queryArgs []interface{}
+		scanArgs  []interface{}
+		err       string
+	}{
+		{"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"},
+		{"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"},
+		{"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"},
+		{"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"},
+		{"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"},
+		{"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"},
+		{"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"},
+		{"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 23"},
+	}
+
+	for i, tt := range tests {
+		actual = zero
+
+		err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
+		if err == nil || !strings.Contains(err.Error(), tt.err) {
+			t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestPointerPointer verifies round-tripping of pointer-to-scalar values:
+// a non-nil pointer argument round-trips its pointee, and a nil pointer
+// argument yields SQL NULL which scans back into a nil pointer.
+func TestPointerPointer(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	type allTypes struct {
+		s   *string
+		i16 *int16
+		i32 *int32
+		i64 *int64
+		f32 *float32
+		f64 *float64
+		b   *bool
+		t   *time.Time
+	}
+
+	var actual, zero allTypes
+
+	// Concrete values for the non-nil cases. The time value is named ts to
+	// avoid shadowing the *testing.T parameter.
+	s := "foo"
+	i16 := int16(1)
+	i32 := int32(1)
+	i64 := int64(1)
+	f32 := float32(1.23)
+	f64 := float64(1.23)
+	b := true
+	ts := time.Unix(123, 5000)
+	expected := allTypes{s: &s, i16: &i16, i32: &i32, i64: &i64, f32: &f32, f64: &f64, b: &b, t: &ts}
+
+	cases := []struct {
+		query string
+		args  []interface{}
+		dest  []interface{}
+		want  allTypes
+	}{
+		{"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}},
+		{"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}},
+		{"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}},
+		{"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}},
+		{"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}},
+		{"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}},
+		{"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}},
+		{"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}},
+		{"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}},
+		{"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}},
+		{"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}},
+		{"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}},
+		{"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}},
+		{"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}},
+		{"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
+		{"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
+		{"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
+		{"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
+	}
+
+	for idx, tc := range cases {
+		actual = zero
+
+		err := conn.QueryRow(tc.query, tc.args...).Scan(tc.dest...)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", idx, err, tc.query, tc.args)
+		}
+
+		if !reflect.DeepEqual(actual, tc.want) {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", idx, tc.want, actual, tc.query, tc.args)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}
+
+// TestPointerPointerNonZero verifies that scanning a SQL NULL into a
+// pointer destination overwrites a previously non-nil pointer with nil,
+// rather than leaving the stale value in place.
+func TestPointerPointerNonZero(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	// dest starts non-nil so a no-op Scan would be detected.
+	f := "foo"
+	dest := &f
+
+	err := conn.QueryRow("select $1::text", nil).Scan(&dest)
+	if err != nil {
+		t.Errorf("Unexpected failure scanning: %v", err)
+	}
+	if dest != nil {
+		t.Errorf("Expected dest to be nil, got %#v", dest)
+	}
+
+	// Consistency with the sibling tests in this file: confirm the
+	// connection is still usable after the scan.
+	ensureConnValid(t, conn)
+}
+
+// TestEncodeTypeRename verifies that locally defined types whose underlying
+// type is a supported scalar (type _int int, type _string string, etc.)
+// encode as query arguments and decode as scan destinations correctly.
+func TestEncodeTypeRename(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	// One renamed type per supported underlying kind; wantX is sent as an
+	// argument and gotX receives the scanned result.
+	type _int int
+	wantInt := _int(3)
+	var gotInt _int
+
+	type _int8 int8
+	wantInt8 := _int8(3)
+	var gotInt8 _int8
+
+	type _int16 int16
+	wantInt16 := _int16(3)
+	var gotInt16 _int16
+
+	type _int32 int32
+	wantInt32 := _int32(4)
+	var gotInt32 _int32
+
+	type _int64 int64
+	wantInt64 := _int64(5)
+	var gotInt64 _int64
+
+	type _uint uint
+	wantUint := _uint(6)
+	var gotUint _uint
+
+	type _uint8 uint8
+	wantUint8 := _uint8(7)
+	var gotUint8 _uint8
+
+	type _uint16 uint16
+	wantUint16 := _uint16(8)
+	var gotUint16 _uint16
+
+	type _uint32 uint32
+	wantUint32 := _uint32(9)
+	var gotUint32 _uint32
+
+	type _uint64 uint64
+	wantUint64 := _uint64(10)
+	var gotUint64 _uint64
+
+	type _string string
+	wantString := _string("foo")
+	var gotString _string
+
+	// Round-trip every renamed type through a single query.
+	err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
+		wantInt, wantInt8, wantInt16, wantInt32, wantInt64, wantUint, wantUint8, wantUint16, wantUint32, wantUint64, wantString,
+	).Scan(&gotInt, &gotInt8, &gotInt16, &gotInt32, &gotInt64, &gotUint, &gotUint8, &gotUint16, &gotUint32, &gotUint64, &gotString)
+	if err != nil {
+		t.Fatalf("Failed with type rename: %v", err)
+	}
+
+	if wantInt != gotInt {
+		t.Errorf("int rename: expected %v, got %v", wantInt, gotInt)
+	}
+
+	if wantInt8 != gotInt8 {
+		t.Errorf("int8 rename: expected %v, got %v", wantInt8, gotInt8)
+	}
+
+	if wantInt16 != gotInt16 {
+		t.Errorf("int16 rename: expected %v, got %v", wantInt16, gotInt16)
+	}
+
+	if wantInt32 != gotInt32 {
+		t.Errorf("int32 rename: expected %v, got %v", wantInt32, gotInt32)
+	}
+
+	if wantInt64 != gotInt64 {
+		t.Errorf("int64 rename: expected %v, got %v", wantInt64, gotInt64)
+	}
+
+	if wantUint != gotUint {
+		t.Errorf("uint rename: expected %v, got %v", wantUint, gotUint)
+	}
+
+	if wantUint8 != gotUint8 {
+		t.Errorf("uint8 rename: expected %v, got %v", wantUint8, gotUint8)
+	}
+
+	if wantUint16 != gotUint16 {
+		t.Errorf("uint16 rename: expected %v, got %v", wantUint16, gotUint16)
+	}
+
+	if wantUint32 != gotUint32 {
+		t.Errorf("uint32 rename: expected %v, got %v", wantUint32, gotUint32)
+	}
+
+	if wantUint64 != gotUint64 {
+		t.Errorf("uint64 rename: expected %v, got %v", wantUint64, gotUint64)
+	}
+
+	if wantString != gotString {
+		t.Errorf("string rename: expected %v, got %v", wantString, gotString)
+	}
+
+	ensureConnValid(t, conn)
+}
+
+// TestRowDecode verifies that a PostgreSQL composite (row) value decodes
+// into a []interface{} whose elements carry the Go types of each field.
+func TestRowDecode(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnect(t, *defaultConnConfig)
+	defer closeConn(t, conn)
+
+	cases := []struct {
+		query string
+		want  []interface{}
+	}{
+		{
+			"select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)",
+			[]interface{}{
+				int32(1),
+				"cat",
+				// Converted to the local zone to match how timestamptz decodes.
+				time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(),
+			},
+		},
+	}
+
+	for idx, tc := range cases {
+		var got []interface{}
+
+		err := conn.QueryRow(tc.query).Scan(&got)
+		if err != nil {
+			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", idx, err, tc.query)
+			continue
+		}
+
+		if !reflect.DeepEqual(got, tc.want) {
+			t.Errorf("%d. Expected %v, got %v (sql -> %v)", idx, tc.want, got, tc.query)
+		}
+
+		ensureConnValid(t, conn)
+	}
+}