From 3185a43420dab2efb85aa601bdfe267e8b2e11f2 Mon Sep 17 00:00:00 2001 From: Felix Hanley Date: Sun, 11 Jun 2017 23:25:57 +1000 Subject: Add torrent tagging - Move to postgresql to reduce locking issues - Move ih hash to single thread - Clean up stuff --- vendor/github.com/jackc/pgx/.gitignore | 24 + vendor/github.com/jackc/pgx/.travis.yml | 60 + vendor/github.com/jackc/pgx/CHANGELOG.md | 190 ++ vendor/github.com/jackc/pgx/LICENSE | 22 + vendor/github.com/jackc/pgx/README.md | 137 + vendor/github.com/jackc/pgx/aclitem_parse_test.go | 126 + vendor/github.com/jackc/pgx/bench_test.go | 765 +++++ vendor/github.com/jackc/pgx/conn.go | 1322 ++++++++ .../jackc/pgx/conn_config_test.go.example | 25 + .../jackc/pgx/conn_config_test.go.travis | 30 + vendor/github.com/jackc/pgx/conn_pool.go | 529 +++ .../github.com/jackc/pgx/conn_pool_private_test.go | 44 + vendor/github.com/jackc/pgx/conn_pool_test.go | 982 ++++++ vendor/github.com/jackc/pgx/conn_test.go | 1744 ++++++++++ vendor/github.com/jackc/pgx/copy_from.go | 241 ++ vendor/github.com/jackc/pgx/copy_from_test.go | 428 +++ vendor/github.com/jackc/pgx/copy_to.go | 222 ++ vendor/github.com/jackc/pgx/copy_to_test.go | 367 +++ vendor/github.com/jackc/pgx/doc.go | 265 ++ .../jackc/pgx/example_custom_type_test.go | 104 + vendor/github.com/jackc/pgx/example_json_test.go | 43 + vendor/github.com/jackc/pgx/examples/README.md | 7 + .../github.com/jackc/pgx/examples/chat/README.md | 25 + vendor/github.com/jackc/pgx/examples/chat/main.go | 95 + .../github.com/jackc/pgx/examples/todo/README.md | 72 + vendor/github.com/jackc/pgx/examples/todo/main.go | 140 + .../jackc/pgx/examples/todo/structure.sql | 4 + .../jackc/pgx/examples/url_shortener/README.md | 25 + .../jackc/pgx/examples/url_shortener/main.go | 128 + .../jackc/pgx/examples/url_shortener/structure.sql | 4 + vendor/github.com/jackc/pgx/fastpath.go | 108 + vendor/github.com/jackc/pgx/helper_test.go | 74 + vendor/github.com/jackc/pgx/hstore.go | 222 ++ vendor/github.com/jackc/pgx/hstore_test.go | 181 ++ vendor/github.com/jackc/pgx/large_objects.go | 147 + vendor/github.com/jackc/pgx/large_objects_test.go | 121 + vendor/github.com/jackc/pgx/logger.go | 81 + vendor/github.com/jackc/pgx/messages.go | 159 + vendor/github.com/jackc/pgx/msg_reader.go | 316 ++ vendor/github.com/jackc/pgx/pgpass.go | 85 + vendor/github.com/jackc/pgx/pgpass_test.go | 57 + vendor/github.com/jackc/pgx/query.go | 494 +++ vendor/github.com/jackc/pgx/query_test.go | 1414 ++++++++ vendor/github.com/jackc/pgx/replication.go | 429 +++ vendor/github.com/jackc/pgx/replication_test.go | 329 ++ vendor/github.com/jackc/pgx/sql.go | 29 + vendor/github.com/jackc/pgx/sql_test.go | 36 + vendor/github.com/jackc/pgx/stdlib/sql.go | 348 ++ vendor/github.com/jackc/pgx/stdlib/sql_test.go | 691 ++++ vendor/github.com/jackc/pgx/stress_test.go | 346 ++ vendor/github.com/jackc/pgx/tx.go | 207 ++ vendor/github.com/jackc/pgx/tx_test.go | 297 ++ vendor/github.com/jackc/pgx/value_reader.go | 156 + vendor/github.com/jackc/pgx/values.go | 3439 ++++++++++++++++++++ vendor/github.com/jackc/pgx/values_test.go | 1183 +++++++ 55 files changed, 19119 insertions(+) create mode 100644 vendor/github.com/jackc/pgx/.gitignore create mode 100644 vendor/github.com/jackc/pgx/.travis.yml create mode 100644 vendor/github.com/jackc/pgx/CHANGELOG.md create mode 100644 vendor/github.com/jackc/pgx/LICENSE create mode 100644 vendor/github.com/jackc/pgx/README.md create mode 100644 vendor/github.com/jackc/pgx/aclitem_parse_test.go create mode 100644 vendor/github.com/jackc/pgx/bench_test.go create mode 100644 vendor/github.com/jackc/pgx/conn.go create mode 100644 vendor/github.com/jackc/pgx/conn_config_test.go.example create mode 100644 vendor/github.com/jackc/pgx/conn_config_test.go.travis create mode 100644 vendor/github.com/jackc/pgx/conn_pool.go create mode 100644 vendor/github.com/jackc/pgx/conn_pool_private_test.go create mode 100644 vendor/github.com/jackc/pgx/conn_pool_test.go create mode 100644 vendor/github.com/jackc/pgx/conn_test.go create mode 100644 vendor/github.com/jackc/pgx/copy_from.go create mode 100644 vendor/github.com/jackc/pgx/copy_from_test.go create mode 100644 vendor/github.com/jackc/pgx/copy_to.go create mode 100644 vendor/github.com/jackc/pgx/copy_to_test.go create mode 100644 vendor/github.com/jackc/pgx/doc.go create mode 100644 vendor/github.com/jackc/pgx/example_custom_type_test.go create mode 100644 vendor/github.com/jackc/pgx/example_json_test.go create mode 100644 vendor/github.com/jackc/pgx/examples/README.md create mode 100644 vendor/github.com/jackc/pgx/examples/chat/README.md create mode 100644 vendor/github.com/jackc/pgx/examples/chat/main.go create mode 100644 vendor/github.com/jackc/pgx/examples/todo/README.md create mode 100644 vendor/github.com/jackc/pgx/examples/todo/main.go create mode 100644 vendor/github.com/jackc/pgx/examples/todo/structure.sql create mode 100644 vendor/github.com/jackc/pgx/examples/url_shortener/README.md create mode 100644 vendor/github.com/jackc/pgx/examples/url_shortener/main.go create mode 100644 vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql create mode 100644 vendor/github.com/jackc/pgx/fastpath.go create mode 100644 vendor/github.com/jackc/pgx/helper_test.go create mode 100644 vendor/github.com/jackc/pgx/hstore.go create mode 100644 vendor/github.com/jackc/pgx/hstore_test.go create mode 100644 vendor/github.com/jackc/pgx/large_objects.go create mode 100644 vendor/github.com/jackc/pgx/large_objects_test.go create mode 100644 vendor/github.com/jackc/pgx/logger.go create mode 100644 vendor/github.com/jackc/pgx/messages.go create mode 100644 vendor/github.com/jackc/pgx/msg_reader.go create mode 100644 vendor/github.com/jackc/pgx/pgpass.go create mode 100644 vendor/github.com/jackc/pgx/pgpass_test.go create mode 100644 vendor/github.com/jackc/pgx/query.go create mode 100644 vendor/github.com/jackc/pgx/query_test.go create mode 100644 vendor/github.com/jackc/pgx/replication.go create mode 100644 vendor/github.com/jackc/pgx/replication_test.go create mode 100644 vendor/github.com/jackc/pgx/sql.go create mode 100644 vendor/github.com/jackc/pgx/sql_test.go create mode 100644 vendor/github.com/jackc/pgx/stdlib/sql.go create mode 100644 vendor/github.com/jackc/pgx/stdlib/sql_test.go create mode 100644 vendor/github.com/jackc/pgx/stress_test.go create mode 100644 vendor/github.com/jackc/pgx/tx.go create mode 100644 vendor/github.com/jackc/pgx/tx_test.go create mode 100644 vendor/github.com/jackc/pgx/value_reader.go create mode 100644 vendor/github.com/jackc/pgx/values.go create mode 100644 vendor/github.com/jackc/pgx/values_test.go (limited to 'vendor/github.com/jackc/pgx') diff --git a/vendor/github.com/jackc/pgx/.gitignore b/vendor/github.com/jackc/pgx/.gitignore new file mode 100644 index 0000000..cb0cd90 --- /dev/null +++ b/vendor/github.com/jackc/pgx/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +conn_config_test.go diff --git a/vendor/github.com/jackc/pgx/.travis.yml b/vendor/github.com/jackc/pgx/.travis.yml new file mode 100644 index 0000000..d9ea43b --- /dev/null +++ b/vendor/github.com/jackc/pgx/.travis.yml @@ -0,0 +1,60 @@ +language: go + +go: + - 1.7.4 + - 1.6.4 + - tip + +# Derived from https://github.com/lib/pq/blob/master/.travis.yml +before_install: + - sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + - sudo rm -rf /var/lib/postgresql + - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + - sudo apt-get update -qq + - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + - sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf + - "[[ $PGVERSION < 9.6 ]] || echo \"wal_level='logical'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf" + - "[[ $PGVERSION < 9.6 ]] || echo \"max_wal_senders=5\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf" + - "[[ $PGVERSION < 9.6 ]] || echo \"max_replication_slots=5\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf" + - sudo /etc/init.d/postgresql restart + +env: + matrix: + - PGVERSION=9.6 + - PGVERSION=9.5 + - PGVERSION=9.4 + - PGVERSION=9.3 + - PGVERSION=9.2 + +# The tricky test user, below, has to actually exist so that it can be used in a test +# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. +before_script: + - mv conn_config_test.go.travis conn_config_test.go + - psql -U postgres -c 'create database pgx_test' + - "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'" + - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + - psql -U postgres -c "create user pgx_replication with replication password 'secret'" + - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" + +install: + - go get -u github.com/shopspring/decimal + - go get -u gopkg.in/inconshreveable/log15.v2 + - go get -u github.com/jackc/fake + +script: + - go test -v -race -short ./... + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/pgx/CHANGELOG.md b/vendor/github.com/jackc/pgx/CHANGELOG.md new file mode 100644 index 0000000..88c663b --- /dev/null +++ b/vendor/github.com/jackc/pgx/CHANGELOG.md @@ -0,0 +1,190 @@ +# Unreleased + +## Fixes + +* database/sql driver created through stdlib.OpenFromConnPool closes connections when requested by database/sql rather than release to underlying connection pool. + +# 2.11.0 (June 5, 2017) + +## Fixes + +* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock) + +## Features + +* .pgpass support (j7b) +* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen) +* Add ParseConnectionString (James Lawrence) + +## Performance + +* Optimize HStore encoding (René Kroon) + +# 2.10.0 (March 17, 2017) + +## Fixes + +* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood) +* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC + +## Features + +* Add xid type support (Manni Wood) +* Add cid type support (Manni Wood) +* Add tid type support (Manni Wood) +* Add "char" type support (Manni Wood) +* Add NullOid type (Manni Wood) +* Add json/jsonb binary support to allow use with CopyTo +* Add named error ErrAcquireTimeout (Alexander Staubo) +* Add logical replication decoding (Kris Wehner) +* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen) +* Add CopyFrom with schema support (Jack Christensen) + +## Compatibility + +* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work. +* CopyTo is now deprecated but will continue to work. + +# 2.9.0 (August 26, 2016) + +## Fixes + +* Fix *ConnPool.Deallocate() not deleting prepared statement from map +* Fix stdlib not logging unprepared query SQL (Krzysztof Dryś) +* Fix Rows.Values() with varchar binary format +* Concurrent ConnPool.Acquire calls with Dialer timeouts now timeout in the expected amount of time (Konstantin Dzreev) + +## Features + +* Add CopyTo +* Add PrepareEx +* Add basic record to []interface{} decoding +* Encode and decode between all Go and PostgreSQL integer types with bounds checking +* Decode inet/cidr to net.IP +* Encode/decode [][]byte to/from bytea[] +* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 + +## Performance + +* Substantial reduction in memory allocations + +# 2.8.1 (March 24, 2016) + +## Features + +* Scan accepts nil argument to ignore a column + +## Fixes + +* Fix compilation on 32-bit architecture +* Fix Tx.status not being set on error on Commit +* Fix Listen/Unlisten with special characters + +# 2.8.0 (March 18, 2016) + +## Fixes + +* Fix unrecognized commit failure +* Fix msgReader.rxMsg bug when msgReader already has error +* Go float64 can no longer be encoded to a PostgreSQL float4 +* Fix connection corruption when query with error is closed early + +## Features + +This release adds multiple extension points helpful when wrapping pgx with +custom application behavior. pgx can now use custom types designed for the +standard database/sql package such as +[github.com/shopspring/decimal](https://github.com/shopspring/decimal). + +* Add *Tx.AfterClose() hook +* Add *Tx.Conn() +* Add *Tx.Status() +* Add *Tx.Err() +* Add *Rows.AfterClose() hook +* Add *Rows.Conn() +* Add *Conn.SetLogger() to allow changing logger +* Add *Conn.SetLogLevel() to allow changing log level +* Add ConnPool.Reset method +* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces +* Rows.Scan errors now include which argument caused error +* Add Encode() to allow custom Encoders to reuse internal encoding functionality +* Add Decode() to allow customer Decoders to reuse internal decoding functionality +* Add ConnPool.Prepare method +* Add ConnPool.Deallocate method +* Add Scan to uint32 and uint64 (utrack) +* Add encode and decode to []uint16, []uint32, and []uint64 (Max Musatov) + +## Performance + +* []byte skips encoding/decoding + +# 2.7.1 (October 26, 2015) + +* Disable SSL renegotiation + +# 2.7.0 (October 16, 2015) + +* Add RuntimeParams to ConnConfig +* ParseURI extracts RuntimeParams +* ParseDSN extracts RuntimeParams +* ParseEnvLibpq extracts PGAPPNAME +* Prepare is now idempotent +* Rows.Values now supports oid type +* ConnPool.Release automatically unlistens connections (Joseph Glanville) +* Add trace log level +* Add more efficient log leveling +* Retry automatically on ConnPool.Begin (Joseph Glanville) +* Encode from net.IP to inet and cidr +* Generalize encoding pointer to string to any PostgreSQL type +* Add UUID encoding from pointer to string (Joseph Glanville) +* Add null mapping to pointer to pointer (Jonathan Rudenberg) +* Add JSON and JSONB type support (Joseph Glanville) + +# 2.6.0 (September 3, 2015) + +* Add inet and cidr type support +* Add binary decoding to TimestampOid in stdlib driver (Samuel Stauffer) +* Add support for specifying sslmode in connection strings (Rick Snyder) +* Allow ConnPool to have MaxConnections of 1 +* Add basic PGSSLMODE to support to ParseEnvLibpq +* Add fallback TLS config +* Expose specific error for TSL refused +* More error details exposed in PgError +* Support custom dialer (Lewis Marshall) + +# 2.5.0 (April 15, 2015) + +* Fix stdlib nil support (Blaž Hrastnik) +* Support custom Scanner not reading entire value +* Fix empty array scanning (Laurent Debacker) +* Add ParseDSN (deoxxa) +* Add timestamp support to NullTime +* Remove unused text format scanners +* Return error when too many parameters on Prepare +* Add Travis CI integration (Jonathan Rudenberg) +* Large object support (Jonathan Rudenberg) +* Fix reading null byte arrays (Karl Seguin) +* Add timestamptz[] support +* Add timestamp[] support (Karl Seguin) +* Add bool[] support (Karl Seguin) +* Allow writing []byte into text and varchar columns without type conversion (Hari Bhaskaran) +* Fix ConnPool Close panic +* Add Listen / notify example +* Reduce memory allocations (Karl Seguin) + +# 2.4.0 (October 3, 2014) + +* Add per connection oid to name map +* Add Hstore support (Andy Walker) +* Move introductory docs to godoc from readme +* Fix documentation references to TextEncoder and BinaryEncoder +* Add keep-alive to TCP connections (Andy Walker) +* Add support for EmptyQueryResponse / Allow no-op Exec (Andy Walker) +* Allow reading any type into []byte +* WaitForNotification detects lost connections quicker + +# 2.3.0 (September 16, 2014) + +* Truncate logged strings and byte slices +* Extract more error information from PostgreSQL +* Fix data race with Rows and ConnPool diff --git a/vendor/github.com/jackc/pgx/LICENSE b/vendor/github.com/jackc/pgx/LICENSE new file mode 100644 index 0000000..7dee3da --- /dev/null +++ b/vendor/github.com/jackc/pgx/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/vendor/github.com/jackc/pgx/README.md b/vendor/github.com/jackc/pgx/README.md new file mode 100644 index 0000000..5550f6b --- /dev/null +++ b/vendor/github.com/jackc/pgx/README.md @@ -0,0 +1,137 @@ +[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) + +# Pgx + +## Master Branch + +This is the `master` branch which tracks the stable release of the current +version. At the moment this is `v2`. The `v3` branch which is currently in beta. +General release is planned for July. `v3` is considered to be stable in the +sense of lack of known bugs, but the API is not considered stable until general +release. No further changes are planned, but the beta process may surface +desirable changes. If possible API changes are acceptable, then `v3` is the +recommented branch for new development. Regardless, please lock to the `v2` or +`v3` branch as when `v3` is released breaking changes will be applied to the +master branch. + +Pgx is a pure Go database connection library designed specifically for +PostgreSQL. Pgx is different from other drivers such as +[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a +database/sql compatible driver, pgx is primarily intended to be used directly. +It offers a native interface similar to database/sql that offers better +performance and more features. + +## Features + +Pgx supports many additional features beyond what is available through database/sql. + +* Listen / notify +* Transaction isolation level control +* Full TLS connection control +* Binary format support for custom types (can be much faster) +* Copy from protocol support for faster bulk data loads +* Logging support +* Configurable connection pool with after connect hooks to do arbitrary connection setup +* PostgreSQL array to Go slice mapping for integers, floats, and strings +* Hstore support +* JSON and JSONB support +* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP +* Large object support +* Null mapping to Null* struct or pointer to pointer. +* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types +* Logical replication connections, including receiving WAL and sending standby status updates + +## Performance + +Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and +[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single +row, but it is substantially faster when selecting multiple entire rows (6893 +queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster). + +See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the +underlying benchmark results or checkout +[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. + +## database/sql + +Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for +database/sql. It is possible to retrieve a pgx connection from database/sql on +demand. This allows using the database/sql interface in most places, but using +pgx directly when more performance or PostgreSQL specific features are needed. + +## Documentation + +pgx includes extensive documentation in the godoc format. It is viewable online at [godoc.org](https://godoc.org/github.com/jackc/pgx). + +## Testing + +pgx supports multiple connection and authentication types. Setting up a test +environment that can test all of them can be cumbersome. In particular, +Windows cannot test Unix domain socket connections. Because of this pgx will +skip tests for connection types that are not configured. + +### Normal Test Environment + +To setup the normal test environment, first install these dependencies: + + go get github.com/jackc/fake + go get github.com/shopspring/decimal + go get gopkg.in/inconshreveable/log15.v2 + +Then run the following SQL: + + create user pgx_md5 password 'secret'; + create user " tricky, ' } "" \ test user " password 'secret'; + create database pgx_test; + create user pgx_replication with replication password 'secret'; + +Connect to database pgx_test and run: + + create extension hstore; + +Next open conn_config_test.go.example and make a copy without the +.example. If your PostgreSQL server is accepting connections on 127.0.0.1, +then you are done. + +### Connection and Authentication Test Environment + +Complete the normal test environment setup and also do the following. + +Run the following SQL: + + create user pgx_none; + create user pgx_pw password 'secret'; + +Add the following to your pg_hba.conf: + +If you are developing on Unix with domain socket connections: + + local pgx_test pgx_none trust + local pgx_test pgx_pw password + local pgx_test pgx_md5 md5 + +If you are developing on Windows with TCP connections: + + host pgx_test pgx_none 127.0.0.1/32 trust + host pgx_test pgx_pw 127.0.0.1/32 password + host pgx_test pgx_md5 127.0.0.1/32 md5 + +### Replication Test Environment + +Add a replication user: + + create user pgx_replication with replication password 'secret'; + +Add a replication line to your pg_hba.conf: + + host replication pgx_replication 127.0.0.1/32 md5 + +Change the following settings in your postgresql.conf: + + wal_level=logical + max_wal_senders=5 + max_replication_slots=5 + +## Version Policy + +pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely). diff --git a/vendor/github.com/jackc/pgx/aclitem_parse_test.go b/vendor/github.com/jackc/pgx/aclitem_parse_test.go new file mode 100644 index 0000000..5c7c748 --- /dev/null +++ b/vendor/github.com/jackc/pgx/aclitem_parse_test.go @@ -0,0 +1,126 @@ +package pgx + +import ( + "reflect" + "testing" +) + +func TestEscapeAclItem(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + "foo", + "foo", + }, + { + `foo, "\}`, + `foo\, \"\\\}`, + }, + } + + for i, tt := range tests { + actual, err := escapeAclItem(tt.input) + + if err != nil { + t.Errorf("%d. Unexpected error %v", i, err) + } + + if actual != tt.expected { + t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) + } + } +} + +func TestParseAclItemArray(t *testing.T) { + tests := []struct { + input string + expected []AclItem + errMsg string + }{ + { + "", + []AclItem{}, + "", + }, + { + "one", + []AclItem{"one"}, + "", + }, + { + `"one"`, + []AclItem{"one"}, + "", + }, + { + "one,two,three", + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two","three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one",two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `one,two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two",three`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","t w o",three`, + []AclItem{"one", "t w o", "three"}, + "", + }, + { + `"one","t, w o\"\}\\",three`, + []AclItem{"one", `t, w o"}\`, "three"}, + "", + }, + { + `"one","two",three"`, + []AclItem{"one", "two", `three"`}, + "", + }, + { + `"one","two,"three"`, + nil, + "unexpected rune after quoted value", + }, + { + `"one","two","three`, + nil, + "unexpected end of quoted value", + }, + } + + for i, tt := range tests { + actual, err := parseAclItemArray(tt.input) + + if err != nil { + if tt.errMsg == "" { + t.Errorf("%d. Unexpected error %v", i, err) + } else if err.Error() != tt.errMsg { + t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) + } + } else if tt.errMsg != "" { + t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) + } + } +} diff --git a/vendor/github.com/jackc/pgx/bench_test.go b/vendor/github.com/jackc/pgx/bench_test.go new file mode 100644 index 0000000..30e31e2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/bench_test.go @@ -0,0 +1,765 @@ +package pgx_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + "time" + + "github.com/jackc/pgx" + log "gopkg.in/inconshreveable/log15.v2" +) + +func BenchmarkConnPool(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var conn *pgx.Conn + if conn, err = pool.Acquire(); err != nil { + b.Fatalf("Unable to acquire connection: %v", err) + } + pool.Release(conn) + } +} + +func BenchmarkConnPoolQueryRow(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + num := float64(-1) + if err := pool.QueryRow("select random()").Scan(&num); err != nil { + b.Fatal(err) + } + + if num < 0 { + b.Fatalf("expected `select random()` to return between 0 and 1 but it was: %v", num) + } + } +} + +func BenchmarkNullXWithNullValues(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var record struct { + id int32 + userName string + email pgx.NullString + name pgx.NullString + sex pgx.NullString + birthDate pgx.NullTime + lastLoginTime pgx.NullTime + } + + err = conn.QueryRow("selectNulls").Scan( + &record.id, + &record.userName, + &record.email, + &record.name, + &record.sex, + &record.birthDate, + &record.lastLoginTime, + ) + if err != nil { + b.Fatal(err) + } + + // These checks both ensure that the correct data was returned + // and provide a benchmark of accessing the returned values. + if record.id != 1 { + b.Fatalf("bad value for id: %v", record.id) + } + if record.userName != "johnsmith" { + b.Fatalf("bad value for userName: %v", record.userName) + } + if record.email.Valid { + b.Fatalf("bad value for email: %v", record.email) + } + if record.name.Valid { + b.Fatalf("bad value for name: %v", record.name) + } + if record.sex.Valid { + b.Fatalf("bad value for sex: %v", record.sex) + } + if record.birthDate.Valid { + b.Fatalf("bad value for birthDate: %v", record.birthDate) + } + if record.lastLoginTime.Valid { + b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + } + } +} + +func BenchmarkNullXWithPresentValues(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var record struct { + id int32 + userName string + email pgx.NullString + name pgx.NullString + sex pgx.NullString + birthDate pgx.NullTime + lastLoginTime pgx.NullTime + } + + err = conn.QueryRow("selectNulls").Scan( + &record.id, + &record.userName, + &record.email, + &record.name, + &record.sex, + &record.birthDate, + &record.lastLoginTime, + ) + if err != nil { + b.Fatal(err) + } + + // These checks both ensure that the correct data was returned + // and provide a benchmark of accessing the returned values. + if record.id != 1 { + b.Fatalf("bad value for id: %v", record.id) + } + if record.userName != "johnsmith" { + b.Fatalf("bad value for userName: %v", record.userName) + } + if !record.email.Valid || record.email.String != "johnsmith@example.com" { + b.Fatalf("bad value for email: %v", record.email) + } + if !record.name.Valid || record.name.String != "John Smith" { + b.Fatalf("bad value for name: %v", record.name) + } + if !record.sex.Valid || record.sex.String != "male" { + b.Fatalf("bad value for sex: %v", record.sex) + } + if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for birthDate: %v", record.birthDate) + } + if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + } + } +} + +func BenchmarkPointerPointerWithNullValues(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var record struct { + id int32 + userName string + email *string + name *string + sex *string + birthDate *time.Time + lastLoginTime *time.Time + } + + err = conn.QueryRow("selectNulls").Scan( + &record.id, + &record.userName, + &record.email, + &record.name, + &record.sex, + &record.birthDate, + &record.lastLoginTime, + ) + if err != nil { + b.Fatal(err) + } + + // These checks both ensure that the correct data was returned + // and provide a benchmark of accessing the returned values. + if record.id != 1 { + b.Fatalf("bad value for id: %v", record.id) + } + if record.userName != "johnsmith" { + b.Fatalf("bad value for userName: %v", record.userName) + } + if record.email != nil { + b.Fatalf("bad value for email: %v", record.email) + } + if record.name != nil { + b.Fatalf("bad value for name: %v", record.name) + } + if record.sex != nil { + b.Fatalf("bad value for sex: %v", record.sex) + } + if record.birthDate != nil { + b.Fatalf("bad value for birthDate: %v", record.birthDate) + } + if record.lastLoginTime != nil { + b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + } + } +} + +func BenchmarkPointerPointerWithPresentValues(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var record struct { + id int32 + userName string + email *string + name *string + sex *string + birthDate *time.Time + lastLoginTime *time.Time + } + + err = conn.QueryRow("selectNulls").Scan( + &record.id, + &record.userName, + &record.email, + &record.name, + &record.sex, + &record.birthDate, + &record.lastLoginTime, + ) + if err != nil { + b.Fatal(err) + } + + // These checks both ensure that the correct data was returned + // and provide a benchmark of accessing the returned values. + if record.id != 1 { + b.Fatalf("bad value for id: %v", record.id) + } + if record.userName != "johnsmith" { + b.Fatalf("bad value for userName: %v", record.userName) + } + if record.email == nil || *record.email != "johnsmith@example.com" { + b.Fatalf("bad value for email: %v", record.email) + } + if record.name == nil || *record.name != "John Smith" { + b.Fatalf("bad value for name: %v", record.name) + } + if record.sex == nil || *record.sex != "male" { + b.Fatalf("bad value for sex: %v", record.sex) + } + if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for birthDate: %v", record.birthDate) + } + if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + } + } +} + +func BenchmarkSelectWithoutLogging(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + benchmarkSelectWithLog(b, conn) +} + +func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) { + connConfig := *defaultConnConfig + + logger := log.New() + lvl, err := log.LvlFromString("debug") + if err != nil { + b.Fatal(err) + } + logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) + connConfig.Logger = logger + connConfig.LogLevel = pgx.LogLevelTrace + conn := mustConnect(b, connConfig) + defer closeConn(b, conn) + + benchmarkSelectWithLog(b, conn) +} + +func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) { + connConfig := *defaultConnConfig + + logger := log.New() + lvl, err := log.LvlFromString("debug") + if err != nil { + b.Fatal(err) + } + logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) + connConfig.Logger = logger + connConfig.LogLevel = pgx.LogLevelDebug + conn := mustConnect(b, connConfig) + defer closeConn(b, conn) + + benchmarkSelectWithLog(b, conn) +} + +func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) { + connConfig := *defaultConnConfig + + logger := log.New() + lvl, err := log.LvlFromString("info") + if err != nil { + b.Fatal(err) + } + logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) + connConfig.Logger = logger + connConfig.LogLevel = pgx.LogLevelInfo + conn := mustConnect(b, connConfig) + defer closeConn(b, conn) + + benchmarkSelectWithLog(b, conn) +} + +func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) { + connConfig := *defaultConnConfig + + logger := log.New() + lvl, err := log.LvlFromString("error") + if err != nil { + b.Fatal(err) + } + logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) + connConfig.Logger = logger + connConfig.LogLevel = pgx.LogLevelError + conn := mustConnect(b, connConfig) + defer closeConn(b, conn) + + benchmarkSelectWithLog(b, conn) +} + +func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { + _, err := conn.Prepare("test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var record struct { + id int32 + userName string + email string + name string + sex string + birthDate time.Time + lastLoginTime time.Time + } + + err = conn.QueryRow("test").Scan( + &record.id, + &record.userName, + &record.email, + &record.name, + &record.sex, + &record.birthDate, + &record.lastLoginTime, + ) + if err != nil { + b.Fatal(err) + } + + // These checks both ensure that the correct data was returned + // and provide a benchmark of accessing the returned values. + if record.id != 1 { + b.Fatalf("bad value for id: %v", record.id) + } + if record.userName != "johnsmith" { + b.Fatalf("bad value for userName: %v", record.userName) + } + if record.email != "johnsmith@example.com" { + b.Fatalf("bad value for email: %v", record.email) + } + if record.name != "John Smith" { + b.Fatalf("bad value for name: %v", record.name) + } + if record.sex != "male" { + b.Fatalf("bad value for sex: %v", record.sex) + } + if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for birthDate: %v", record.birthDate) + } + if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { + b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) + } + } +} + +func BenchmarkLog15Discard(b *testing.B) { + logger := log.New() + lvl, err := log.LvlFromString("error") + if err != nil { + b.Fatal(err) + } + logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Debug("benchmark", "i", i, "b.N", b.N) + } +} + +const benchmarkWriteTableCreateSQL = `drop table if exists t; + +create table t( + varchar_1 varchar not null, + varchar_2 varchar not null, + varchar_null_1 varchar, + date_1 date not null, + date_null_1 date, + int4_1 int4 not null, + int4_2 int4 not null, + int4_null_1 int4, + tstz_1 timestamptz not null, + tstz_2 timestamptz, + bool_1 bool not null, + bool_2 bool not null, + bool_3 bool not null +); +` + +const benchmarkWriteTableInsertSQL = `insert into t( + varchar_1, + varchar_2, + varchar_null_1, + date_1, + date_null_1, + int4_1, + int4_2, + int4_null_1, + tstz_1, + tstz_2, + bool_1, + bool_2, + bool_3 +) values ( + $1::varchar, + $2::varchar, + $3::varchar, + $4::date, + $5::date, + $6::int4, + $7::int4, + $8::int4, + $9::timestamptz, + $10::timestamptz, + $11::bool, + $12::bool, + $13::bool +)` + +type benchmarkWriteTableCopyFromSrc struct { + count int + idx int + row []interface{} +} + +func (s *benchmarkWriteTableCopyFromSrc) Next() bool { + s.idx++ + return s.idx < s.count +} + +func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { + return s.row, nil +} + +func (s *benchmarkWriteTableCopyFromSrc) Err() error { + return nil +} + +func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { + return &benchmarkWriteTableCopyFromSrc{ + count: count, + row: []interface{}{ + "varchar_1", + "varchar_2", + pgx.NullString{}, + time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), + pgx.NullTime{}, + 1, + 2, + pgx.NullInt32{}, + time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), + true, + false, + true, + }, + } +} + +func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyFromSrc(n) + + tx, err := conn.Begin() + if err != nil { + b.Fatal(err) + } + + for src.Next() { + values, _ := src.Values() + if _, err = tx.Exec("insert_t", values...); err != nil { + b.Fatalf("Exec unexpectedly failed with: %v", err) + } + } + + err = tx.Commit() + if err != nil { + b.Fatal(err) + } + } +} + +// note this function is only used for benchmarks -- it doesn't escape tableName +// or columnNames +func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) { + maxRowsPerInsert := 65535 / len(columnNames) + rowsThisInsert := 0 + rowCount := 0 + + sqlBuf := &bytes.Buffer{} + args := make(pgx.QueryArgs, 0) + + resetQuery := func() { + sqlBuf.Reset() + fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", ")) + + args = args[0:0] + + rowsThisInsert = 0 + } + resetQuery() + + tx, err := conn.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + for rowSrc.Next() { + if rowsThisInsert > 0 { + sqlBuf.WriteByte(',') + } + + sqlBuf.WriteByte('(') + + values, err := rowSrc.Values() + if err != nil { + return 0, err + } + + for i, val := range values { + if i > 0 { + sqlBuf.WriteByte(',') + } + sqlBuf.WriteString(args.Append(val)) + } + + sqlBuf.WriteByte(')') + + rowsThisInsert++ + + if rowsThisInsert == maxRowsPerInsert { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + resetQuery() + } + } + + if rowsThisInsert > 0 { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + } + + if err := tx.Commit(); err != nil { + return 0, nil + } + + return rowCount, nil + +} + +func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyFromSrc(n) + + _, err := multiInsert(conn, "t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyFromSrc(n) + + _, err := conn.CopyFrom(pgx.Identifier{"t"}, + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite5RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 5) +} + +func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 5) +} + +func BenchmarkWrite5RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 5) +} + +func BenchmarkWrite10RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10) +} + +func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10) +} + +func BenchmarkWrite10RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10) +} + +func BenchmarkWrite100RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 100) +} + +func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 100) +} + +func BenchmarkWrite100RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 100) +} + +func BenchmarkWrite1000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 1000) +} + +func BenchmarkWrite10000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10000) +} diff --git a/vendor/github.com/jackc/pgx/conn.go b/vendor/github.com/jackc/pgx/conn.go new file mode 100644 index 0000000..a2d60e7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn.go @@ -0,0 +1,1322 @@ +package pgx + +import ( + "bufio" + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "net/url" + "os" + "os/user" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" +) + +// DialFunc is a function that can be used to connect to a PostgreSQL server +type DialFunc func(network, addr string) (net.Conn, error) + +// ConnConfig contains all the options used to establish a connection. +type ConnConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 // default: 5432 + Database string + User string // default: OS user name + Password string + TLSConfig *tls.Config // config for TLS connection -- nil disables TLS + UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa + FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS + Logger Logger + LogLevel int + Dial DialFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) +} + +// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. +// Use ConnPool to manage access to multiple database connections from multiple +// goroutines. +type Conn struct { + conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used + reader *bufio.Reader // buffered reader to improve read performance + wbuf [1024]byte + writeBuf WriteBuf + Pid int32 // backend pid + SecretKey int32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + PgTypes map[Oid]PgType // oids to PgTypes + config ConnConfig // config used when establishing this connection + TxStatus byte + preparedStatements map[string]*PreparedStatement + channels map[string]struct{} + notifications []*Notification + alive bool + causeOfDeath error + logger Logger + logLevel int + mr msgReader + fp *fastpath + pgsqlAfInet *byte + pgsqlAfInet6 *byte + busy bool + poolResetCount int + preallocatedRows []Rows +} + +// PreparedStatement is a description of a prepared statement +type PreparedStatement struct { + Name string + SQL string + FieldDescriptions []FieldDescription + ParameterOids []Oid +} + +// PrepareExOptions is an option struct that can be passed to PrepareEx +type PrepareExOptions struct { + ParameterOids []Oid +} + +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + Pid int32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + +// PgType is information about PostgreSQL type and how to encode and decode it +type PgType struct { + Name string // name of type e.g. int4, text, date + DefaultFormat int16 // default format (text or binary) this type will be requested in +} + +// CommandTag is the result of an Exec function +type CommandTag string + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (such as "CREATE TABLE") then it returns 0 +func (ct CommandTag) RowsAffected() int64 { + s := string(ct) + index := strings.LastIndex(s, " ") + if index == -1 { + return 0 + } + n, _ := strconv.ParseInt(s[index+1:], 10, 64) + return n +} + +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + } + return strings.Join(parts, ".") +} + +// ErrNoRows occurs when rows are expected but none are returned. +var ErrNoRows = errors.New("no rows in result set") + +// ErrNotificationTimeout occurs when WaitForNotification times out. +var ErrNotificationTimeout = errors.New("notification timeout") + +// ErrDeadConn occurs on an attempt to use a dead connection +var ErrDeadConn = errors.New("conn is dead") + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +// ErrConnBusy occurs when the connection is busy (for example, in the middle of +// reading query results) and another action is attempts. +var ErrConnBusy = errors.New("conn is busy") + +// ErrInvalidLogLevel occurs on attempt to set an invalid log level. +var ErrInvalidLogLevel = errors.New("invalid log level") + +// ProtocolError occurs when unexpected data is received from PostgreSQL +type ProtocolError string + +func (e ProtocolError) Error() string { + return string(e) +} + +// Connect establishes a connection with a PostgreSQL server using config. +// config.Host must be specified. config.User will default to the OS user name. +// Other config fields are optional. +func Connect(config ConnConfig) (c *Conn, err error) { + return connect(config, nil, nil, nil) +} + +func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { + c = new(Conn) + + c.config = config + + if pgTypes != nil { + c.PgTypes = make(map[Oid]PgType, len(pgTypes)) + for k, v := range pgTypes { + c.PgTypes[k] = v + } + } + + if pgsqlAfInet != nil { + c.pgsqlAfInet = new(byte) + *c.pgsqlAfInet = *pgsqlAfInet + } + if pgsqlAfInet6 != nil { + c.pgsqlAfInet6 = new(byte) + *c.pgsqlAfInet6 = *pgsqlAfInet6 + } + + if c.config.LogLevel != 0 { + c.logLevel = c.config.LogLevel + } else { + // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug + c.logLevel = LogLevelDebug + } + c.logger = c.config.Logger + c.mr.log = c.log + c.mr.shouldLog = c.shouldLog + + if c.config.User == "" { + user, err := user.Current() + if err != nil { + return nil, err + } + c.config.User = user.Username + if c.shouldLog(LogLevelDebug) { + c.log(LogLevelDebug, "Using default connection config", "User", c.config.User) + } + } + + if c.config.Port == 0 { + c.config.Port = 5432 + if c.shouldLog(LogLevelDebug) { + c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port) + } + } + + network := "tcp" + address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(c.config.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = c.config.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) + } + } + if c.config.Dial == nil { + c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial + } + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + } + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err)) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) + } + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err)) + } + return nil, err + } + + return c, nil +} + +func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { + c.conn, err = c.config.Dial(network, address) + if err != nil { + return err + } + defer func() { + if c != nil && err != nil { + c.conn.Close() + c.alive = false + } + }() + + c.RuntimeParams = make(map[string]string) + c.preparedStatements = make(map[string]*PreparedStatement) + c.channels = make(map[string]struct{}) + c.alive = true + c.lastActivityTime = time.Now() + + if tlsConfig != nil { + if c.shouldLog(LogLevelDebug) { + c.log(LogLevelDebug, "Starting TLS handshake") + } + if err := c.startTLS(tlsConfig); err != nil { + return err + } + } + + c.reader = bufio.NewReader(c.conn) + c.mr.reader = c.reader + + msg := newStartupMessage() + + // Default to disabling TLS renegotiation. + // + // Go does not support (https://github.com/golang/go/issues/5742) + // PostgreSQL recommends disabling (http://www.postgresql.org/docs/9.4/static/runtime-config-connection.html#GUC-SSL-RENEGOTIATION-LIMIT) + if tlsConfig != nil { + msg.options["ssl_renegotiation_limit"] = "0" + } + + // Copy default run-time params + for k, v := range config.RuntimeParams { + msg.options[k] = v + } + + msg.options["user"] = c.config.User + if c.config.Database != "" { + msg.options["database"] = c.config.Database + } + + if err = c.txStartupMessage(msg); err != nil { + return err + } + + for { + var t byte + var r *msgReader + t, r, err = c.rxMsg() + if err != nil { + return err + } + + switch t { + case backendKeyData: + c.rxBackendKeyData(r) + case authenticationX: + if err = c.rxAuthenticationX(r); err != nil { + return err + } + case readyForQuery: + c.rxReadyForQuery(r) + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Connection established") + } + + // Replication connections can't execute the queries to + // populate the c.PgTypes and c.pgsqlAfInet + if _, ok := msg.options["replication"]; ok { + return nil + } + + if c.PgTypes == nil { + err = c.loadPgTypes() + if err != nil { + return err + } + } + + if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { + err = c.loadInetConstants() + if err != nil { + return err + } + } + + return nil + default: + if err = c.processContextFreeMsg(t, r); err != nil { + return err + } + } + } +} + +func (c *Conn) loadPgTypes() error { + rows, err := c.Query(`select t.oid, t.typname +from pg_type t +left join pg_type base_type on t.typelem=base_type.oid +where ( + t.typtype='b' + and (base_type.oid is null or base_type.typtype='b') + ) + or t.typname in('record');`) + if err != nil { + return err + } + + c.PgTypes = make(map[Oid]PgType, 128) + + for rows.Next() { + var oid Oid + var t PgType + + rows.Scan(&oid, &t.Name) + + // The zero value is text format so we ignore any types without a default type format + t.DefaultFormat, _ = DefaultTypeFormats[t.Name] + + c.PgTypes[oid] = t + } + + return rows.Err() +} + +// Family is needed for binary encoding of inet/cidr. The constant is based on +// the server's definition of AF_INET. In theory, this could differ between +// platforms, so request an IPv4 and an IPv6 inet and get the family from that. +func (c *Conn) loadInetConstants() error { + var ipv4, ipv6 []byte + + err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6) + if err != nil { + return err + } + + c.pgsqlAfInet = &ipv4[0] + c.pgsqlAfInet6 = &ipv6[0] + + return nil +} + +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (c *Conn) Close() (err error) { + if !c.IsAlive() { + return nil + } + + wbuf := newWriteBuf(c, 'X') + wbuf.closeMsg() + + _, err = c.conn.Write(wbuf.buf) + + c.die(errors.New("Closed")) + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Closed connection") + } + return err +} + +// ParseURI parses a database URI into ConnConfig +// +// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams. +func ParseURI(uri string) (ConnConfig, error) { + var cp ConnConfig + + url, err := url.Parse(uri) + if err != nil { + return cp, err + } + + if url.User != nil { + cp.User = url.User.Username() + cp.Password, _ = url.User.Password() + } + + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) + } + cp.Database = strings.TrimLeft(url.Path, "/") + + err = configSSL(url.Query().Get("sslmode"), &cp) + if err != nil { + return cp, err + } + + ignoreKeys := map[string]struct{}{ + "sslmode": {}, + } + + cp.RuntimeParams = make(map[string]string) + + for k, v := range url.Query() { + if _, ok := ignoreKeys[k]; ok { + continue + } + + cp.RuntimeParams[k] = v[0] + } + if cp.Password == "" { + pgpass(&cp) + } + return cp, nil +} + +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) + +// ParseDSN parses a database DSN (data source name) into a ConnConfig +// +// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") +// +// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams. +// +// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb") +// +// ParseDSN tries to match libpq behavior with regard to sslmode. See comments +// for ParseEnvLibpq for more information on the security implications of +// sslmode options. +func ParseDSN(s string) (ConnConfig, error) { + var cp ConnConfig + + m := dsnRegexp.FindAllStringSubmatch(s, -1) + + var sslmode string + + cp.RuntimeParams = make(map[string]string) + + for _, b := range m { + switch b[1] { + case "user": + cp.User = b[2] + case "password": + cp.Password = b[2] + case "host": + cp.Host = b[2] + case "port": + p, err := strconv.ParseUint(b[2], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) + case "dbname": + cp.Database = b[2] + case "sslmode": + sslmode = b[2] + default: + cp.RuntimeParams[b[1]] = b[2] + } + } + + err := configSSL(sslmode, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { + pgpass(&cp) + } + return cp, nil +} + +// ParseConnectionString parses either a URI or a DSN connection string. +// see ParseURI and ParseDSN for details. +func ParseConnectionString(s string) (ConnConfig, error) { + if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") { + return ParseURI(s) + } + return ParseDSN(s) +} + +// ParseEnvLibpq parses the environment like libpq does into a ConnConfig +// +// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details +// on the meaning of environment variables. +// +// ParseEnvLibpq currently recognizes the following environment variables: +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGSSLMODE +// PGAPPNAME +// +// Important TLS Security Notes: +// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This +// includes defaulting to "prefer" behavior if no environment variable is set. +// +// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION +// for details on what level of security each sslmode provides. +// +// "require" and "verify-ca" modes currently are treated as "verify-full". e.g. +// They have stronger security guarantees than they would with libpq. Do not +// rely on this behavior as it may be possible to match libpq in the future. If +// you need full security use "verify-full". +// +// Several of the PGSSLMODE options (including the default behavior of "prefer") +// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or +// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is +// later set from a different source that UseFallbackTLS MUST be set false to +// avoid the possibility of falling back to weaker or disabled security. +func ParseEnvLibpq() (ConnConfig, error) { + var cc ConnConfig + + cc.Host = os.Getenv("PGHOST") + + if pgport := os.Getenv("PGPORT"); pgport != "" { + if port, err := strconv.ParseUint(pgport, 10, 16); err == nil { + cc.Port = uint16(port) + } else { + return cc, err + } + } + + cc.Database = os.Getenv("PGDATABASE") + cc.User = os.Getenv("PGUSER") + cc.Password = os.Getenv("PGPASSWORD") + + sslmode := os.Getenv("PGSSLMODE") + + err := configSSL(sslmode, &cc) + if err != nil { + return cc, err + } + + cc.RuntimeParams = make(map[string]string) + if appname := os.Getenv("PGAPPNAME"); appname != "" { + cc.RuntimeParams["application_name"] = appname + } + if cc.Password == "" { + pgpass(&cc) + } + return cc, nil +} + +func configSSL(sslmode string, cc *ConnConfig) error { + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + switch sslmode { + case "disable": + case "allow": + cc.UseFallbackTLS = true + cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + case "prefer": + cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} + cc.UseFallbackTLS = true + cc.FallbackTLSConfig = nil + case "require", "verify-ca", "verify-full": + cc.TLSConfig = &tls.Config{ + ServerName: cc.Host, + } + default: + return errors.New("sslmode is invalid") + } + + return nil +} + +// Prepare creates a prepared statement with name and sql. sql can contain placeholders +// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// concern for if the statement has already been prepared. +func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { + return c.PrepareEx(name, sql, nil) +} + +// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders +// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// +// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same +// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without +// concern for if the statement has already been prepared. +func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + if name != "" { + if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + } + + if c.shouldLog(LogLevelError) { + defer func() { + if err != nil { + c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) + } + }() + } + + // parse + wbuf := newWriteBuf(c, 'P') + wbuf.WriteCString(name) + wbuf.WriteCString(sql) + + if opts != nil { + if len(opts.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) + } + wbuf.WriteInt16(int16(len(opts.ParameterOids))) + for _, oid := range opts.ParameterOids { + wbuf.WriteInt32(int32(oid)) + } + } else { + wbuf.WriteInt16(0) + } + + // describe + wbuf.startMsg('D') + wbuf.WriteByte('S') + wbuf.WriteCString(name) + + // sync + wbuf.startMsg('S') + wbuf.closeMsg() + + _, err = c.conn.Write(wbuf.buf) + if err != nil { + c.die(err) + return nil, err + } + + ps = &PreparedStatement{Name: name, SQL: sql} + + var softErr error + + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return nil, err + } + + switch t { + case parseComplete: + case parameterDescription: + ps.ParameterOids = c.rxParameterDescription(r) + + if len(ps.ParameterOids) > 65535 && softErr == nil { + softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) + } + case rowDescription: + ps.FieldDescriptions = c.rxRowDescription(r) + for i := range ps.FieldDescriptions { + t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] + ps.FieldDescriptions[i].DataTypeName = t.Name + ps.FieldDescriptions[i].FormatCode = t.DefaultFormat + } + case noData: + case readyForQuery: + c.rxReadyForQuery(r) + + if softErr == nil { + c.preparedStatements[name] = ps + } + + return ps, softErr + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } + } + } +} + +// Deallocate released a prepared statement +func (c *Conn) Deallocate(name string) (err error) { + delete(c.preparedStatements, name) + + // close + wbuf := newWriteBuf(c, 'C') + wbuf.WriteByte('S') + wbuf.WriteCString(name) + + // flush + wbuf.startMsg('H') + wbuf.closeMsg() + + _, err = c.conn.Write(wbuf.buf) + if err != nil { + c.die(err) + return err + } + + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case closeComplete: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +// Listen establishes a PostgreSQL listen/notify to channel +func (c *Conn) Listen(channel string) error { + _, err := c.Exec("listen " + quoteIdentifier(channel)) + if err != nil { + return err + } + + c.channels[channel] = struct{}{} + + return nil +} + +// Unlisten unsubscribes from a listen channel +func (c *Conn) Unlisten(channel string) error { + _, err := c.Exec("unlisten " + quoteIdentifier(channel)) + if err != nil { + return err + } + + delete(c.channels, channel) + return nil +} + +// WaitForNotification waits for a PostgreSQL notification for up to timeout. +// If the timeout occurs it returns pgx.ErrNotificationTimeout +func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { + // Return already received notification immediately + if len(c.notifications) > 0 { + notification := c.notifications[0] + c.notifications = c.notifications[1:] + return notification, nil + } + + stopTime := time.Now().Add(timeout) + + for { + now := time.Now() + + if now.After(stopTime) { + return nil, ErrNotificationTimeout + } + + // If there has been no activity on this connection for a while send a nop message just to ensure + // the connection is alive + nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) + if nextEnsureAliveTime.Before(now) { + // If the server can't respond to a nop in 15 seconds, assume it's dead + err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) + if err != nil { + return nil, err + } + + _, err = c.Exec("--;") + if err != nil { + return nil, err + } + + c.lastActivityTime = now + } + + var deadline time.Time + if stopTime.Before(nextEnsureAliveTime) { + deadline = stopTime + } else { + deadline = nextEnsureAliveTime + } + + notification, err := c.waitForNotification(deadline) + if err != ErrNotificationTimeout { + return notification, err + } + } +} + +func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { + var zeroTime time.Time + + for { + // Use SetReadDeadline to implement the timeout. SetReadDeadline will + // cause operations to fail with a *net.OpError that has a Timeout() + // of true. Because the normal pgx rxMsg path considers any error to + // have potentially corrupted the state of the connection, it dies + // on any errors. So to avoid timeout errors in rxMsg we set the + // deadline and peek into the reader. If a timeout error occurs there + // we don't break the pgx connection. If the Peek returns that data + // is available then we turn off the read deadline before the rxMsg. + err := c.conn.SetReadDeadline(deadline) + if err != nil { + return nil, err + } + + // Wait until there is a byte available before continuing onto the normal msg reading path + _, err = c.reader.Peek(1) + if err != nil { + c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline + if err, ok := err.(*net.OpError); ok && err.Timeout() { + return nil, ErrNotificationTimeout + } + return nil, err + } + + err = c.conn.SetReadDeadline(zeroTime) + if err != nil { + return nil, err + } + + var t byte + var r *msgReader + if t, r, err = c.rxMsg(); err == nil { + if err = c.processContextFreeMsg(t, r); err != nil { + return nil, err + } + } else { + return nil, err + } + + if len(c.notifications) > 0 { + notification := c.notifications[0] + c.notifications = c.notifications[1:] + return notification, nil + } + } +} + +func (c *Conn) IsAlive() bool { + return c.alive +} + +func (c *Conn) CauseOfDeath() error { + return c.causeOfDeath +} + +func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { + if ps, present := c.preparedStatements[sql]; present { + return c.sendPreparedQuery(ps, arguments...) + } + return c.sendSimpleQuery(sql, arguments...) +} + +func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + + if len(args) == 0 { + wbuf := newWriteBuf(c, 'Q') + wbuf.WriteCString(sql) + wbuf.closeMsg() + + _, err := c.conn.Write(wbuf.buf) + if err != nil { + c.die(err) + return err + } + + return nil + } + + ps, err := c.Prepare("", sql) + if err != nil { + return err + } + + return c.sendPreparedQuery(ps, args...) +} + +func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { + if len(ps.ParameterOids) != len(arguments) { + return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) + } + + // bind + wbuf := newWriteBuf(c, 'B') + wbuf.WriteByte(0) + wbuf.WriteCString(ps.Name) + + wbuf.WriteInt16(int16(len(ps.ParameterOids))) + for i, oid := range ps.ParameterOids { + switch arg := arguments[i].(type) { + case Encoder: + wbuf.WriteInt16(arg.FormatCode()) + case string, *string: + wbuf.WriteInt16(TextFormatCode) + default: + switch oid { + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid: + wbuf.WriteInt16(BinaryFormatCode) + default: + wbuf.WriteInt16(TextFormatCode) + } + } + } + + wbuf.WriteInt16(int16(len(arguments))) + for i, oid := range ps.ParameterOids { + if err := Encode(wbuf, oid, arguments[i]); err != nil { + return err + } + } + + wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) + for _, fd := range ps.FieldDescriptions { + wbuf.WriteInt16(fd.FormatCode) + } + + // execute + wbuf.startMsg('E') + wbuf.WriteByte(0) + wbuf.WriteInt32(0) + + // sync + wbuf.startMsg('S') + wbuf.closeMsg() + + _, err = c.conn.Write(wbuf.buf) + if err != nil { + c.die(err) + } + + return err +} + +// Exec executes sql. sql can be either a prepared statement name or an SQL string. +// arguments should be referenced positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + if err = c.lock(); err != nil { + return commandTag, err + } + + startTime := time.Now() + c.lastActivityTime = startTime + + defer func() { + if err == nil { + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + } + } else { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + } + } + + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() + + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + + var softErr error + + for { + var t byte + var r *msgReader + t, r, err = c.rxMsg() + if err != nil { + return commandTag, err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return commandTag, softErr + case rowDescription: + case dataRow: + case bindComplete: + case commandComplete: + commandTag = CommandTag(r.readCString()) + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } + } + } +} + +// Processes messages that are not exclusive to one context such as +// authentication or query response. The response to these messages +// is the same regardless of when they occur. +func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { + switch t { + case 'S': + c.rxParameterStatus(r) + return nil + case errorResponse: + return c.rxErrorResponse(r) + case noticeResponse: + return nil + case emptyQueryResponse: + return nil + case notificationResponse: + c.rxNotificationResponse(r) + return nil + default: + return fmt.Errorf("Received unknown message type: %c", t) + } +} + +func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { + if !c.alive { + return 0, nil, ErrDeadConn + } + + t, err = c.mr.rxMsg() + if err != nil { + c.die(err) + } + + c.lastActivityTime = time.Now() + + if c.shouldLog(LogLevelTrace) { + c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) + } + + return t, &c.mr, err +} + +func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { + switch r.readInt32() { + case 0: // AuthenticationOk + case 3: // AuthenticationCleartextPassword + err = c.txPasswordMessage(c.config.Password) + case 5: // AuthenticationMD5Password + salt := r.readString(4) + digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) + err = c.txPasswordMessage(digestedPassword) + default: + err = errors.New("Received unknown authentication message") + } + + return +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (c *Conn) rxParameterStatus(r *msgReader) { + key := r.readCString() + value := r.readCString() + c.RuntimeParams[key] = value +} + +func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { + for { + switch r.readByte() { + case 'S': + err.Severity = r.readCString() + case 'C': + err.Code = r.readCString() + case 'M': + err.Message = r.readCString() + case 'D': + err.Detail = r.readCString() + case 'H': + err.Hint = r.readCString() + case 'P': + s := r.readCString() + n, _ := strconv.ParseInt(s, 10, 32) + err.Position = int32(n) + case 'p': + s := r.readCString() + n, _ := strconv.ParseInt(s, 10, 32) + err.InternalPosition = int32(n) + case 'q': + err.InternalQuery = r.readCString() + case 'W': + err.Where = r.readCString() + case 's': + err.SchemaName = r.readCString() + case 't': + err.TableName = r.readCString() + case 'c': + err.ColumnName = r.readCString() + case 'd': + err.DataTypeName = r.readCString() + case 'n': + err.ConstraintName = r.readCString() + case 'F': + err.File = r.readCString() + case 'L': + s := r.readCString() + n, _ := strconv.ParseInt(s, 10, 32) + err.Line = int32(n) + case 'R': + err.Routine = r.readCString() + + case 0: // End of error message + if err.Severity == "FATAL" { + c.die(err) + } + return + default: // Ignore other error fields + r.readCString() + } + } +} + +func (c *Conn) rxBackendKeyData(r *msgReader) { + c.Pid = r.readInt32() + c.SecretKey = r.readInt32() +} + +func (c *Conn) rxReadyForQuery(r *msgReader) { + c.TxStatus = r.readByte() +} + +func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { + fieldCount := r.readInt16() + fields = make([]FieldDescription, fieldCount) + for i := int16(0); i < fieldCount; i++ { + f := &fields[i] + f.Name = r.readCString() + f.Table = r.readOid() + f.AttributeNumber = r.readInt16() + f.DataType = r.readOid() + f.DataTypeSize = r.readInt16() + f.Modifier = r.readInt32() + f.FormatCode = r.readInt16() + } + return +} + +func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { + // Internally, PostgreSQL supports greater than 64k parameters to a prepared + // statement. But the parameter description uses a 16-bit integer for the + // count of parameters. If there are more than 64K parameters, this count is + // wrong. So read the count, ignore it, and compute the proper value from + // the size of the message. + r.readInt16() + parameterCount := r.msgBytesRemaining / 4 + + parameters = make([]Oid, 0, parameterCount) + + for i := int32(0); i < parameterCount; i++ { + parameters = append(parameters, r.readOid()) + } + return +} + +func (c *Conn) rxNotificationResponse(r *msgReader) { + n := new(Notification) + n.Pid = r.readInt32() + n.Channel = r.readCString() + n.Payload = r.readCString() + c.notifications = append(c.notifications, n) +} + +func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(c.conn, response); err != nil { + return + } + + if response[0] != 'S' { + return ErrTLSRefused + } + + c.conn = tls.Client(c.conn, tlsConfig) + + return nil +} + +func (c *Conn) txStartupMessage(msg *startupMessage) error { + _, err := c.conn.Write(msg.Bytes()) + return err +} + +func (c *Conn) txPasswordMessage(password string) (err error) { + wbuf := newWriteBuf(c, 'p') + wbuf.WriteCString(password) + wbuf.closeMsg() + + _, err = c.conn.Write(wbuf.buf) + + return err +} + +func (c *Conn) die(err error) { + c.alive = false + c.causeOfDeath = err + c.conn.Close() +} + +func (c *Conn) lock() error { + if c.busy { + return ErrConnBusy + } + c.busy = true + return nil +} + +func (c *Conn) unlock() error { + if !c.busy { + return errors.New("unlock conn that is not busy") + } + c.busy = false + return nil +} + +func (c *Conn) shouldLog(lvl int) bool { + return c.logger != nil && c.logLevel >= lvl +} + +func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { + if c.Pid != 0 { + ctx = append(ctx, "pid", c.Pid) + } + + switch lvl { + case LogLevelTrace: + c.logger.Debug(msg, ctx...) + case LogLevelDebug: + c.logger.Debug(msg, ctx...) + case LogLevelInfo: + c.logger.Info(msg, ctx...) + case LogLevelWarn: + c.logger.Warn(msg, ctx...) + case LogLevelError: + c.logger.Error(msg, ctx...) + } +} + +// SetLogger replaces the current logger and returns the previous logger. +func (c *Conn) SetLogger(logger Logger) Logger { + oldLogger := c.logger + c.logger = logger + return oldLogger +} + +// SetLogLevel replaces the current log level and returns the previous log +// level. +func (c *Conn) SetLogLevel(lvl int) (int, error) { + oldLvl := c.logLevel + + if lvl < LogLevelNone || lvl > LogLevelTrace { + return oldLvl, ErrInvalidLogLevel + } + + c.logLevel = lvl + return lvl, nil +} + +func quoteIdentifier(s string) string { + return `"` + strings.Replace(s, `"`, `""`, -1) + `"` +} diff --git a/vendor/github.com/jackc/pgx/conn_config_test.go.example b/vendor/github.com/jackc/pgx/conn_config_test.go.example new file mode 100644 index 0000000..cac798b --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_config_test.go.example @@ -0,0 +1,25 @@ +package pgx_test + +import ( + "github.com/jackc/pgx" +) + +var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} + +// To skip tests for specific connection / authentication types set that connection param to nil +var tcpConnConfig *pgx.ConnConfig = nil +var unixSocketConnConfig *pgx.ConnConfig = nil +var md5ConnConfig *pgx.ConnConfig = nil +var plainPasswordConnConfig *pgx.ConnConfig = nil +var invalidUserConnConfig *pgx.ConnConfig = nil +var tlsConnConfig *pgx.ConnConfig = nil +var customDialerConnConfig *pgx.ConnConfig = nil +var replicationConnConfig *pgx.ConnConfig = nil + +// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} +// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} +// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} +// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} +// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/vendor/github.com/jackc/pgx/conn_config_test.go.travis b/vendor/github.com/jackc/pgx/conn_config_test.go.travis new file mode 100644 index 0000000..75714bf --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_config_test.go.travis @@ -0,0 +1,30 @@ +package pgx_test + +import ( + "crypto/tls" + "github.com/jackc/pgx" + "os" + "strconv" +) + +var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} +var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} +var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} +var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} +var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var replicationConnConfig *pgx.ConnConfig = nil + +func init() { + version := os.Getenv("PGVERSION") + + if len(version) > 0 { + v, err := strconv.ParseFloat(version,64) + if err == nil && v >= 9.6 { + replicationConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} + } + } +} + diff --git a/vendor/github.com/jackc/pgx/conn_pool.go b/vendor/github.com/jackc/pgx/conn_pool.go new file mode 100644 index 0000000..1913699 --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_pool.go @@ -0,0 +1,529 @@ +package pgx + +import ( + "errors" + "sync" + "time" +) + +type ConnPoolConfig struct { + ConnConfig + MaxConnections int // max simultaneous connections to use, default 5, must be at least 2 + AfterConnect func(*Conn) error // function to call on every new connection + AcquireTimeout time.Duration // max wait time when all connections are busy (0 means no timeout) +} + +type ConnPool struct { + allConnections []*Conn + availableConnections []*Conn + cond *sync.Cond + config ConnConfig // config used when establishing connection + inProgressConnects int + maxConnections int + resetCount int + afterConnect func(*Conn) error + logger Logger + logLevel int + closed bool + preparedStatements map[string]*PreparedStatement + acquireTimeout time.Duration + pgTypes map[Oid]PgType + pgsqlAfInet *byte + pgsqlAfInet6 *byte + txAfterClose func(tx *Tx) + rowsAfterClose func(rows *Rows) +} + +type ConnPoolStat struct { + MaxConnections int // max simultaneous connections to use + CurrentConnections int // current live connections + AvailableConnections int // unused live connections +} + +// ErrAcquireTimeout occurs when an attempt to acquire a connection times out. +var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") + +// NewConnPool creates a new ConnPool. config.ConnConfig is passed through to +// Connect directly. +func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { + p = new(ConnPool) + p.config = config.ConnConfig + p.maxConnections = config.MaxConnections + if p.maxConnections == 0 { + p.maxConnections = 5 + } + if p.maxConnections < 1 { + return nil, errors.New("MaxConnections must be at least 1") + } + p.acquireTimeout = config.AcquireTimeout + if p.acquireTimeout < 0 { + return nil, errors.New("AcquireTimeout must be equal to or greater than 0") + } + + p.afterConnect = config.AfterConnect + + if config.LogLevel != 0 { + p.logLevel = config.LogLevel + } else { + // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug + p.logLevel = LogLevelDebug + } + p.logger = config.Logger + if p.logger == nil { + p.logLevel = LogLevelNone + } + + p.txAfterClose = func(tx *Tx) { + p.Release(tx.Conn()) + } + + p.rowsAfterClose = func(rows *Rows) { + p.Release(rows.Conn()) + } + + p.allConnections = make([]*Conn, 0, p.maxConnections) + p.availableConnections = make([]*Conn, 0, p.maxConnections) + p.preparedStatements = make(map[string]*PreparedStatement) + p.cond = sync.NewCond(new(sync.Mutex)) + + // Initially establish one connection + var c *Conn + c, err = p.createConnection() + if err != nil { + return + } + p.allConnections = append(p.allConnections, c) + p.availableConnections = append(p.availableConnections, c) + + return +} + +// Acquire takes exclusive use of a connection until it is released. +func (p *ConnPool) Acquire() (*Conn, error) { + p.cond.L.Lock() + c, err := p.acquire(nil) + p.cond.L.Unlock() + return c, err +} + +// deadlinePassed returns true if the given deadline has passed. +func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { + return deadline != nil && time.Now().After(*deadline) +} + +// acquire performs acquision assuming pool is already locked +func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { + if p.closed { + return nil, errors.New("cannot acquire from closed pool") + } + + // A connection is available + if len(p.availableConnections) > 0 { + c := p.availableConnections[len(p.availableConnections)-1] + c.poolResetCount = p.resetCount + p.availableConnections = p.availableConnections[:len(p.availableConnections)-1] + return c, nil + } + + // Set initial timeout/deadline value. If the method (acquire) happens to + // recursively call itself the deadline should retain its value. + if deadline == nil && p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + // Make sure the deadline (if it is) has not passed yet + if p.deadlinePassed(deadline) { + return nil, ErrAcquireTimeout + } + + // If there is a deadline then start a timeout timer + var timer *time.Timer + if deadline != nil { + timer = time.AfterFunc(deadline.Sub(time.Now()), func() { + p.cond.Broadcast() + }) + defer timer.Stop() + } + + // No connections are available, but we can create more + if len(p.allConnections)+p.inProgressConnects < p.maxConnections { + // Create a new connection. + // Careful here: createConnectionUnlocked() removes the current lock, + // creates a connection and then locks it back. + c, err := p.createConnectionUnlocked() + if err != nil { + return nil, err + } + c.poolResetCount = p.resetCount + p.allConnections = append(p.allConnections, c) + return c, nil + } + // All connections are in use and we cannot create more + if p.logLevel >= LogLevelWarn { + p.logger.Warn("All connections in pool are busy - waiting...") + } + + // Wait until there is an available connection OR room to create a new connection + for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { + if p.deadlinePassed(deadline) { + return nil, ErrAcquireTimeout + } + p.cond.Wait() + } + + // Stop the timer so that we do not spawn it on every acquire call. + if timer != nil { + timer.Stop() + } + return p.acquire(deadline) +} + +// Release gives up use of a connection. +func (p *ConnPool) Release(conn *Conn) { + if conn.TxStatus != 'I' { + conn.Exec("rollback") + } + + if len(conn.channels) > 0 { + if err := conn.Unlisten("*"); err != nil { + conn.die(err) + } + conn.channels = make(map[string]struct{}) + } + conn.notifications = nil + + p.cond.L.Lock() + + if conn.poolResetCount != p.resetCount { + conn.Close() + p.cond.L.Unlock() + p.cond.Signal() + return + } + + if conn.IsAlive() { + p.availableConnections = append(p.availableConnections, conn) + } else { + p.removeFromAllConnections(conn) + } + p.cond.L.Unlock() + p.cond.Signal() +} + +// removeFromAllConnections Removes the given connection from the list. +// It returns true if the connection was found and removed or false otherwise. +func (p *ConnPool) removeFromAllConnections(conn *Conn) bool { + for i, c := range p.allConnections { + if conn == c { + p.allConnections = append(p.allConnections[:i], p.allConnections[i+1:]...) + return true + } + } + return false +} + +// Close ends the use of a connection pool. It prevents any new connections +// from being acquired, waits until all acquired connections are released, +// then closes all underlying connections. +func (p *ConnPool) Close() { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + p.closed = true + + // Wait until all connections are released + if len(p.availableConnections) != len(p.allConnections) { + for len(p.availableConnections) != len(p.allConnections) { + p.cond.Wait() + } + } + + for _, c := range p.allConnections { + _ = c.Close() + } +} + +// Reset closes all open connections, but leaves the pool open. It is intended +// for use when an error is detected that would disrupt all connections (such as +// a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those +// connections will be closed when they are returned to the pool. +func (p *ConnPool) Reset() { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + p.resetCount++ + p.allConnections = p.allConnections[0:0] + + for _, conn := range p.availableConnections { + conn.Close() + } + + p.availableConnections = p.availableConnections[0:0] +} + +// invalidateAcquired causes all acquired connections to be closed when released. +// The pool must already be locked. +func (p *ConnPool) invalidateAcquired() { + p.resetCount++ + + for _, c := range p.availableConnections { + c.poolResetCount = p.resetCount + } + + p.allConnections = p.allConnections[:len(p.availableConnections)] + copy(p.allConnections, p.availableConnections) +} + +// Stat returns connection pool statistics +func (p *ConnPool) Stat() (s ConnPoolStat) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + s.MaxConnections = p.maxConnections + s.CurrentConnections = len(p.allConnections) + s.AvailableConnections = len(p.availableConnections) + return +} + +func (p *ConnPool) createConnection() (*Conn, error) { + c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) + if err != nil { + return nil, err + } + return p.afterConnectionCreated(c) +} + +// createConnectionUnlocked Removes the current lock, creates a new connection, and +// then locks it back. +// Here is the point: lets say our pool dialer's OpenTimeout is set to 3 seconds. +// And we have a pool with 20 connections in it, and we try to acquire them all at +// startup. +// If it happens that the remote server is not accessible, then the first connection +// in the pool blocks all the others for 3 secs, before it gets the timeout. Then +// connection #2 holds the lock and locks everything for the next 3 secs until it +// gets OpenTimeout err, etc. And the very last 20th connection will fail only after +// 3 * 20 = 60 secs. +// To avoid this we put Connect(p.config) outside of the lock (it is thread safe) +// what would allow us to make all the 20 connection in parallel (more or less). +func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { + p.inProgressConnects++ + p.cond.L.Unlock() + c, err := Connect(p.config) + p.cond.L.Lock() + p.inProgressConnects-- + + if err != nil { + return nil, err + } + return p.afterConnectionCreated(c) +} + +// afterConnectionCreated executes (if it is) afterConnect() callback and prepares +// all the known statements for the new connection. +func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { + p.pgTypes = c.PgTypes + p.pgsqlAfInet = c.pgsqlAfInet + p.pgsqlAfInet6 = c.pgsqlAfInet6 + + if p.afterConnect != nil { + err := p.afterConnect(c) + if err != nil { + c.die(err) + return nil, err + } + } + + for _, ps := range p.preparedStatements { + if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + c.die(err) + return nil, err + } + } + + return c, nil +} + +// Exec acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.Exec(sql, arguments...) +} + +// Query acquires a connection and delegates the call to that connection. When +// *Rows are closed, the connection is released automatically. +func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.Query(sql, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.AfterClose(p.rowsAfterClose) + + return rows, nil +} + +// QueryRow acquires a connection and delegates the call to that connection. The +// connection is released automatically after Scan is called on the returned +// *Row. +func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { + rows, _ := p.Query(sql, args...) + return (*Row)(rows) +} + +// Begin acquires a connection and begins a transaction on it. When the +// transaction is closed the connection will be automatically released. +func (p *ConnPool) Begin() (*Tx, error) { + return p.BeginIso("") +} + +// Prepare creates a prepared statement on a connection in the pool to test the +// statement is valid. If it succeeds all connections accessed through the pool +// will have the statement available. +// +// Prepare creates a prepared statement with name and sql. sql can contain +// placeholders for bound parameters. These placeholders are referenced +// positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with +// the same name and sql arguments. This allows a code path to Prepare and +// Query/Exec/PrepareEx without concern for if the statement has already been prepared. +func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { + return p.PrepareEx(name, sql, nil) +} + +// PrepareEx creates a prepared statement on a connection in the pool to test the +// statement is valid. If it succeeds all connections accessed through the pool +// will have the statement available. +// +// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders +// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// +// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same +// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without +// concern for if the statement has already been prepared. +func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + + c, err := p.acquire(nil) + if err != nil { + return nil, err + } + + p.availableConnections = append(p.availableConnections, c) + + // Double check that the statement was not prepared by someone else + // while we were acquiring the connection (since acquire is not fully + // blocking now, see createConnectionUnlocked()) + if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + + ps, err := c.PrepareEx(name, sql, opts) + if err != nil { + return nil, err + } + + for _, c := range p.availableConnections { + _, err := c.PrepareEx(name, sql, opts) + if err != nil { + return nil, err + } + } + + p.invalidateAcquired() + p.preparedStatements[name] = ps + + return ps, err +} + +// Deallocate releases a prepared statement from all connections in the pool. +func (p *ConnPool) Deallocate(name string) (err error) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + + for _, c := range p.availableConnections { + if err := c.Deallocate(name); err != nil { + return err + } + } + + p.invalidateAcquired() + delete(p.preparedStatements, name) + + return nil +} + +// BeginIso acquires a connection and begins a transaction in isolation mode iso +// on it. When the transaction is closed the connection will be automatically +// released. +func (p *ConnPool) BeginIso(iso string) (*Tx, error) { + for { + c, err := p.Acquire() + if err != nil { + return nil, err + } + + tx, err := c.BeginIso(iso) + if err != nil { + alive := c.IsAlive() + p.Release(c) + + // If connection is still alive then the error is not something trying + // again on a new connection would fix, so just return the error. But + // if the connection is dead try to acquire a new connection and try + // again. + if alive { + return nil, err + } + continue + } + + tx.AfterClose(p.txAfterClose) + return tx, nil + } +} + +// Deprecated. Use CopyFrom instead. CopyTo acquires a connection, delegates the +// call to that connection, and releases the connection. +func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + c, err := p.Acquire() + if err != nil { + return 0, err + } + defer p.Release(c) + + return c.CopyTo(tableName, columnNames, rowSrc) +} + +// CopyFrom acquires a connection, delegates the call to that connection, and +// releases the connection. +func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { + c, err := p.Acquire() + if err != nil { + return 0, err + } + defer p.Release(c) + + return c.CopyFrom(tableName, columnNames, rowSrc) +} diff --git a/vendor/github.com/jackc/pgx/conn_pool_private_test.go b/vendor/github.com/jackc/pgx/conn_pool_private_test.go new file mode 100644 index 0000000..ef0ec1d --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_pool_private_test.go @@ -0,0 +1,44 @@ +package pgx + +import ( + "testing" +) + +func compareConnSlices(slice1, slice2 []*Conn) bool { + if len(slice1) != len(slice2) { + return false + } + for i, c := range slice1 { + if c != slice2[i] { + return false + } + } + return true +} + +func TestConnPoolRemoveFromAllConnections(t *testing.T) { + t.Parallel() + pool := ConnPool{} + conn1 := &Conn{} + conn2 := &Conn{} + conn3 := &Conn{} + + // First element + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn1) + if !compareConnSlices(pool.allConnections, []*Conn{conn2, conn3}) { + t.Fatal("First element test failed") + } + // Element somewhere in the middle + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn2) + if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn3}) { + t.Fatal("Middle element test failed") + } + // Last element + pool.allConnections = []*Conn{conn1, conn2, conn3} + pool.removeFromAllConnections(conn3) + if !compareConnSlices(pool.allConnections, []*Conn{conn1, conn2}) { + t.Fatal("Last element test failed") + } +} diff --git a/vendor/github.com/jackc/pgx/conn_pool_test.go b/vendor/github.com/jackc/pgx/conn_pool_test.go new file mode 100644 index 0000000..ab76bfb --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_pool_test.go @@ -0,0 +1,982 @@ +package pgx_test + +import ( + "errors" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: maxConnections} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + return pool +} + +func acquireAllConnections(t *testing.T, pool *pgx.ConnPool, maxConnections int) []*pgx.Conn { + connections := make([]*pgx.Conn, maxConnections) + for i := 0; i < maxConnections; i++ { + var err error + if connections[i], err = pool.Acquire(); err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + } + return connections +} + +func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { + for _, c := range connections { + pool.Release(c) + } +} + +func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.Acquire() + return c, time.Since(startTime), err +} + +func TestNewConnPool(t *testing.T) { + t.Parallel() + + var numCallbacks int + afterConnect := func(c *pgx.Conn) error { + numCallbacks++ + return nil + } + + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatal("Unable to establish connection pool") + } + defer pool.Close() + + // It initially connects once + stat := pool.Stat() + if stat.CurrentConnections != 1 { + t.Errorf("Expected 1 connection to be established immediately, but %v were", numCallbacks) + } + + // Pool creation returns an error if any AfterConnect callback does + errAfterConnect := errors.New("Some error") + afterConnect = func(c *pgx.Conn) error { + return errAfterConnect + } + + config = pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 2, AfterConnect: afterConnect} + pool, err = pgx.NewConnPool(config) + if err != errAfterConnect { + t.Errorf("Expected errAfterConnect but received unexpected: %v", err) + } +} + +func TestNewConnPoolDefaultsTo5MaxConnections(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatal("Unable to establish connection pool") + } + defer pool.Close() + + if n := pool.Stat().MaxConnections; n != 5 { + t.Fatalf("Expected pool to default to 5 max connections, but it was %d", n) + } +} + +func TestPoolAcquireAndReleaseCycle(t *testing.T) { + t.Parallel() + + maxConnections := 2 + incrementCount := int32(100) + completeSync := make(chan int) + pool := createConnPool(t, maxConnections) + defer pool.Close() + + allConnections := acquireAllConnections(t, pool, maxConnections) + + for _, c := range allConnections { + mustExec(t, c, "create temporary table t(counter integer not null)") + mustExec(t, c, "insert into t(counter) values(0);") + } + + releaseAllConnections(pool, allConnections) + + f := func() { + conn, err := pool.Acquire() + if err != nil { + t.Fatal("Unable to acquire connection") + } + defer pool.Release(conn) + + // Increment counter... + mustExec(t, conn, "update t set counter = counter + 1") + completeSync <- 0 + } + + for i := int32(0); i < incrementCount; i++ { + go f() + } + + // Wait for all f() to complete + for i := int32(0); i < incrementCount; i++ { + <-completeSync + } + + // Check that temp table in each connection has been incremented some number of times + actualCount := int32(0) + allConnections = acquireAllConnections(t, pool, maxConnections) + + for _, c := range allConnections { + var n int32 + c.QueryRow("select counter from t").Scan(&n) + if n == 0 { + t.Error("A connection was never used") + } + + actualCount += n + } + + if actualCount != incrementCount { + fmt.Println(actualCount) + t.Error("Wrong number of increments") + } + + releaseAllConnections(pool, allConnections) +} + +func TestPoolNonBlockingConnections(t *testing.T) { + t.Parallel() + + var dialCountLock sync.Mutex + dialCount := 0 + openTimeout := 1 * time.Second + testDialer := func(network, address string) (net.Conn, error) { + var firstDial bool + dialCountLock.Lock() + dialCount++ + firstDial = dialCount == 1 + dialCountLock.Unlock() + + if firstDial { + return net.Dial(network, address) + } else { + time.Sleep(openTimeout) + return nil, &net.OpError{Op: "dial", Net: "tcp"} + } + } + + maxConnections := 3 + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: maxConnections, + } + config.ConnConfig.Dial = testDialer + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Expected NewConnPool not to fail, instead it failed with: %v", err) + } + defer pool.Close() + + // NewConnPool establishes an initial connection + // so we need to close that for the rest of the test + if conn, err := pool.Acquire(); err == nil { + conn.Close() + pool.Release(conn) + } else { + t.Fatalf("pool.Acquire unexpectedly failed: %v", err) + } + + var wg sync.WaitGroup + wg.Add(maxConnections) + + startedAt := time.Now() + for i := 0; i < maxConnections; i++ { + go func() { + _, err := pool.Acquire() + wg.Done() + if err == nil { + t.Fatal("Acquire() expected to fail but it did not") + } + }() + } + wg.Wait() + + // Prior to createConnectionUnlocked() use the test took + // maxConnections * openTimeout seconds to complete. + // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. + timeTaken := time.Since(startedAt) + if timeTaken > openTimeout+1*time.Second { + t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken) + } + +} + +func TestAcquireTimeoutSanity(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + // case 1: default 0 value + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Expected NewConnPool with default config.AcquireTimeout not to fail, instead it failed with '%v'", err) + } + pool.Close() + + // case 2: negative value + config.AcquireTimeout = -1 * time.Second + _, err = pgx.NewConnPool(config) + if err == nil { + t.Fatal("Expected NewConnPool with negative config.AcquireTimeout to fail, instead it did not") + } + + // case 3: positive value + config.AcquireTimeout = 1 * time.Second + pool, err = pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Expected NewConnPool with positive config.AcquireTimeout not to fail, instead it failed with '%v'", err) + } + defer pool.Close() +} + +func TestPoolWithAcquireTimeoutSet(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireWithTimeTaken(pool) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireWithTimeTaken(pool) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + +func TestPoolReleaseWithTransactions(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + conn, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + mustExec(t, conn, "begin") + if _, err = conn.Exec("selct"); err == nil { + t.Fatal("Did not receive expected error") + } + + if conn.TxStatus != 'E' { + t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus) + } + + pool.Release(conn) + + if conn.TxStatus != 'I' { + t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus) + } + + conn, err = pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + mustExec(t, conn, "begin") + if conn.TxStatus != 'T' { + t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus) + } + + pool.Release(conn) + + if conn.TxStatus != 'I' { + t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus) + } +} + +func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { + t.Parallel() + + maxConnections := 3 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + doSomething := func() { + c, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to Acquire: %v", err) + } + rows, _ := c.Query("select 1, pg_sleep(0.02)") + rows.Close() + pool.Release(c) + } + + for i := 0; i < 10; i++ { + doSomething() + } + + stat := pool.Stat() + if stat.CurrentConnections != 1 { + t.Fatalf("Pool shouldn't have established more connections when no contention: %v", stat.CurrentConnections) + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doSomething() + }() + } + wg.Wait() + + stat = pool.Stat() + if stat.CurrentConnections != stat.MaxConnections { + t.Fatalf("Pool should have used all possible connections: %v", stat.CurrentConnections) + } +} + +func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { + t.Parallel() + + // Run timing sensitive test many times + for i := 0; i < 50; i++ { + func() { + maxConnections := 3 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + var c1, c2 *pgx.Conn + var err error + var stat pgx.ConnPoolStat + + if c1, err = pool.Acquire(); err != nil { + t.Fatalf("Unexpected error acquiring connection: %v", err) + } + defer func() { + if c1 != nil { + pool.Release(c1) + } + }() + + if c2, err = pool.Acquire(); err != nil { + t.Fatalf("Unexpected error acquiring connection: %v", err) + } + defer func() { + if c2 != nil { + pool.Release(c2) + } + }() + + if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + // do something with the connection so it knows it's dead + rows, _ := c1.Query("select 1") + rows.Close() + if rows.Err() == nil { + t.Fatal("Expected error but none occurred") + } + + if c1.IsAlive() { + t.Fatal("Expected connection to be dead but it wasn't") + } + + stat = pool.Stat() + if stat.CurrentConnections != 2 { + t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections) + } + if stat.AvailableConnections != 0 { + t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections) + } + + pool.Release(c1) + c1 = nil // so it doesn't get released again by the defer + + stat = pool.Stat() + if stat.CurrentConnections != 1 { + t.Fatalf("Unexpected CurrentConnections: %v", stat.CurrentConnections) + } + if stat.AvailableConnections != 0 { + t.Fatalf("Unexpected AvailableConnections: %v", stat.CurrentConnections) + } + }() + } +} + +func TestConnPoolResetClosesCheckedOutConnectionsOnRelease(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 5) + defer pool.Close() + + inProgressRows := []*pgx.Rows{} + var inProgressPIDs []int32 + + // Start some queries and reset pool while they are in progress + for i := 0; i < 10; i++ { + rows, err := pool.Query("select pg_backend_pid() union all select 1 union all select 2") + if err != nil { + t.Fatal(err) + } + + rows.Next() + var pid int32 + rows.Scan(&pid) + inProgressPIDs = append(inProgressPIDs, pid) + + inProgressRows = append(inProgressRows, rows) + pool.Reset() + } + + // Check that the queries are completed + for _, rows := range inProgressRows { + var expectedN int32 + + for rows.Next() { + expectedN++ + var n int32 + err := rows.Scan(&n) + if err != nil { + t.Fatal(err) + } + if expectedN != n { + t.Fatalf("Expected n to be %d, but it was %d", expectedN, n) + } + } + + if err := rows.Err(); err != nil { + t.Fatal(err) + } + } + + // pool should be in fresh state due to previous reset + stats := pool.Stat() + if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + var connCount int + err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) + if err != nil { + t.Fatal(err) + } + if connCount != 0 { + t.Fatalf("%d connections not closed", connCount) + } +} + +func TestConnPoolResetClosesCheckedInConnections(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 5) + defer pool.Close() + + inProgressRows := []*pgx.Rows{} + var inProgressPIDs []int32 + + // Start some queries and reset pool while they are in progress + for i := 0; i < 5; i++ { + rows, err := pool.Query("select pg_backend_pid()") + if err != nil { + t.Fatal(err) + } + + inProgressRows = append(inProgressRows, rows) + } + + // Check that the queries are completed + for _, rows := range inProgressRows { + for rows.Next() { + var pid int32 + err := rows.Scan(&pid) + if err != nil { + t.Fatal(err) + } + inProgressPIDs = append(inProgressPIDs, pid) + + } + + if err := rows.Err(); err != nil { + t.Fatal(err) + } + } + + // Ensure pool is fully connected and available + stats := pool.Stat() + if stats.CurrentConnections != 5 || stats.AvailableConnections != 5 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + pool.Reset() + + // Pool should be empty after reset + stats = pool.Stat() + if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + var connCount int + err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) + if err != nil { + t.Fatal(err) + } + if connCount != 0 { + t.Fatalf("%d connections not closed", connCount) + } +} + +func TestConnPoolTransaction(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + tx, err := pool.Begin() + if err != nil { + t.Fatalf("pool.Begin failed: %v", err) + } + defer tx.Rollback() + + var n int32 + err = tx.QueryRow("select 40+$1", 2).Scan(&n) + if err != nil { + t.Fatalf("tx.QueryRow Scan failed: %v", err) + } + if n != 42 { + t.Errorf("Expected 42, got %d", n) + } + + stats = pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + + stats = pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} + +func TestConnPoolTransactionIso(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + tx, err := pool.BeginIso(pgx.Serializable) + if err != nil { + t.Fatalf("pool.Begin failed: %v", err) + } + defer tx.Rollback() + + var level string + err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + if err != nil { + t.Fatalf("tx.QueryRow failed: %v", level) + } + + if level != "serializable" { + t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level) + } +} + +func TestConnPoolBeginRetry(t *testing.T) { + t.Parallel() + + // Run timing sensitive test many times + for i := 0; i < 50; i++ { + func() { + pool := createConnPool(t, 2) + defer pool.Close() + + killerConn, err := pool.Acquire() + if err != nil { + t.Fatal(err) + } + defer pool.Release(killerConn) + + victimConn, err := pool.Acquire() + if err != nil { + t.Fatal(err) + } + pool.Release(victimConn) + + // Terminate connection that was released to pool + if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + // Since victimConn is the only available connection in the pool, pool.Begin should + // try to use it, fail, and allocate another connection + tx, err := pool.Begin() + if err != nil { + t.Fatalf("pool.Begin failed: %v", err) + } + defer tx.Rollback() + + var txPid int32 + err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid) + if err != nil { + t.Fatalf("tx.QueryRow Scan failed: %v", err) + } + if txPid == victimConn.Pid { + t.Error("Expected txPid to defer from killed conn pid, but it didn't") + } + }() + } +} + +func TestConnPoolQuery(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + var sum, rowCount int32 + + rows, err := pool.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("pool.Query failed: %v", err) + } + + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } + + stats = pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} + +func TestConnPoolQueryConcurrentLoad(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 10) + defer pool.Close() + + n := 100 + done := make(chan bool) + + for i := 0; i < n; i++ { + go func() { + defer func() { done <- true }() + var rowCount int32 + + rows, err := pool.Query("select generate_series(1,$1)", 1000) + if err != nil { + t.Fatalf("pool.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + err = rows.Scan(&n) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + if n != rowCount+1 { + t.Fatalf("Expected n to be %d, but it was %d", rowCount+1, n) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", rows.Err()) + } + + if rowCount != 1000 { + t.Error("Select called onDataRow wrong number of times") + } + + _, err = pool.Exec("--;") + if err != nil { + t.Fatalf("pool.Exec failed: %v", err) + } + }() + } + + for i := 0; i < n; i++ { + <-done + } +} + +func TestConnPoolQueryRow(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + var n int32 + err := pool.QueryRow("select 40+$1", 2).Scan(&n) + if err != nil { + t.Fatalf("pool.QueryRow Scan failed: %v", err) + } + + if n != 42 { + t.Errorf("Expected 42, got %d", n) + } + + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} + +func TestConnPoolExec(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + results, err := pool.Exec("create temporary table foo(id integer primary key);") + if err != nil { + t.Fatalf("Unexpected error from pool.Exec: %v", err) + } + if results != "CREATE TABLE" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + results, err = pool.Exec("insert into foo(id) values($1)", 1) + if err != nil { + t.Fatalf("Unexpected error from pool.Exec: %v", err) + } + if results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + results, err = pool.Exec("drop table foo;") + if err != nil { + t.Fatalf("Unexpected error from pool.Exec: %v", err) + } + if results != "DROP TABLE" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} + +func TestConnPoolPrepare(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + _, err := pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + + var s string + err = pool.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = pool.Deallocate("test") + if err != nil { + t.Errorf("Deallocate failed: %v", err) + } + + err = pool.QueryRow("test", "hello").Scan(&s) + if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { + t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) + } +} + +func TestConnPoolPrepareDeallocatePrepare(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + _, err := pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + err = pool.Deallocate("test") + if err != nil { + t.Fatalf("Unable to deallocate statement: %v", err) + } + _, err = pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + + var s string + err = pool.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Fatalf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } +} + +func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + testPreparedStatement := func(db queryRower, desc string) { + var s string + err := db.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Fatalf("%s. Executing prepared statement failed: %v", desc, err) + } + + if s != "hello" { + t.Fatalf("%s. Prepared statement did not return expected value: %v", desc, s) + } + } + + newReleaseOnce := func(c *pgx.Conn) func() { + var once sync.Once + return func() { + once.Do(func() { pool.Release(c) }) + } + } + + c1, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c1Release := newReleaseOnce(c1) + defer c1Release() + + _, err = pool.Prepare("test", "select $1::varchar") + if err != nil { + t.Fatalf("Unable to prepare statement: %v", err) + } + + testPreparedStatement(pool, "pool") + + c1Release() + + c2, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c2Release := newReleaseOnce(c2) + defer c2Release() + + // This conn will not be available and will be connection at this point + c3, err := pool.Acquire() + if err != nil { + t.Fatalf("Unable to acquire connection: %v", err) + } + c3Release := newReleaseOnce(c3) + defer c3Release() + + testPreparedStatement(c2, "c2") + testPreparedStatement(c3, "c3") + + c2Release() + c3Release() + + err = pool.Deallocate("test") + if err != nil { + t.Errorf("Deallocate failed: %v", err) + } + + var s string + err = pool.QueryRow("test", "hello").Scan(&s) + if err, ok := err.(pgx.PgError); !(ok && err.Code == "42601") { + t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) + } +} diff --git a/vendor/github.com/jackc/pgx/conn_test.go b/vendor/github.com/jackc/pgx/conn_test.go new file mode 100644 index 0000000..cfb9956 --- /dev/null +++ b/vendor/github.com/jackc/pgx/conn_test.go @@ -0,0 +1,1744 @@ +package pgx_test + +import ( + "crypto/tls" + "fmt" + "net" + "os" + "reflect" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnect(t *testing.T) { + t.Parallel() + + conn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.Pid == 0 { + t.Error("Backend PID not stored") + } + + if conn.SecretKey == 0 { + t.Error("Backend secret key not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithUnixSocketDirectory(t *testing.T) { + t.Parallel() + + // /.s.PGSQL.5432 + if unixSocketConnConfig == nil { + t.Skip("Skipping due to undefined unixSocketConnConfig") + } + + conn, err := pgx.Connect(*unixSocketConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithUnixSocketFile(t *testing.T) { + t.Parallel() + + if unixSocketConnConfig == nil { + t.Skip("Skipping due to undefined unixSocketConnConfig") + } + + connParams := *unixSocketConnConfig + connParams.Host = connParams.Host + "/.s.PGSQL.5432" + conn, err := pgx.Connect(connParams) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithTcp(t *testing.T) { + t.Parallel() + + if tcpConnConfig == nil { + t.Skip("Skipping due to undefined tcpConnConfig") + } + + conn, err := pgx.Connect(*tcpConnConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithTLS(t *testing.T) { + t.Parallel() + + if tlsConnConfig == nil { + t.Skip("Skipping due to undefined tlsConnConfig") + } + + conn, err := pgx.Connect(*tlsConnConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithInvalidUser(t *testing.T) { + t.Parallel() + + if invalidUserConnConfig == nil { + t.Skip("Skipping due to undefined invalidUserConnConfig") + } + + _, err := pgx.Connect(*invalidUserConnConfig) + pgErr, ok := err.(pgx.PgError) + if !ok { + t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithPlainTextPassword(t *testing.T) { + t.Parallel() + + if plainPasswordConnConfig == nil { + t.Skip("Skipping due to undefined plainPasswordConnConfig") + } + + conn, err := pgx.Connect(*plainPasswordConnConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithMD5Password(t *testing.T) { + t.Parallel() + + if md5ConnConfig == nil { + t.Skip("Skipping due to undefined md5ConnConfig") + } + + conn, err := pgx.Connect(*md5ConnConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithTLSFallback(t *testing.T) { + t.Parallel() + + if tlsConnConfig == nil { + t.Skip("Skipping due to undefined tlsConnConfig") + } + + connConfig := *tlsConnConfig + connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure + + conn, err := pgx.Connect(connConfig) + if err == nil { + t.Fatal("Expected failed connection, but succeeded") + } + + connConfig.UseFallbackTLS = true + connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + + conn, err = pgx.Connect(connConfig) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + bad := *defaultConnConfig + bad.Host = "127.0.0.1" + bad.Port = 1 + + _, err := pgx.Connect(bad) + if err == nil { + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + + if customDialerConnConfig == nil { + t.Skip("Skipping due to undefined customDialerConnConfig") + } + + dialled := false + conf := *customDialerConnConfig + conf.Dial = func(network, address string) (net.Conn, error) { + dialled = true + return net.Dial(network, address) + } + + conn, err := pgx.Connect(conf) + if err != nil { + t.Fatalf("Unable to establish connection: %s", err) + } + if !dialled { + t.Fatal("Connect did not use custom dialer") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + + connConfig := *defaultConnConfig + connConfig.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgx.Connect(connConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer conn.Close() + + var s string + err = conn.QueryRow("show application_name").Scan(&s) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if s != "pgxtest" { + t.Errorf("Expected application_name to be %s, but it was %s", "pgxtest", s) + } + + err = conn.QueryRow("show search_path").Scan(&s) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if s != "myschema" { + t.Errorf("Expected search_path to be %s, but it was %s", "myschema", s) + } +} + +func TestParseURI(t *testing.T) { + t.Parallel() + + tests := []struct { + url string + connParams pgx.ConnConfig + }{ + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + UseFallbackTLS: false, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgresql://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + connParams, err := pgx.ParseURI(tt.url) + if err != nil { + t.Errorf("%d. Unexpected error from pgx.ParseURL(%q) => %v", i, tt.url, err) + continue + } + + if !reflect.DeepEqual(connParams, tt.connParams) { + t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + } + } +} + +func TestParseDSN(t *testing.T) { + t.Parallel() + + tests := []struct { + url string + connParams pgx.ConnConfig + }{ + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + connParams, err := pgx.ParseDSN(tt.url) + if err != nil { + t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) + continue + } + + if !reflect.DeepEqual(connParams, tt.connParams) { + t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + } + } +} + +func TestParseConnectionString(t *testing.T) { + t.Parallel() + + tests := []struct { + url string + connParams pgx.ConnConfig + }{ + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + UseFallbackTLS: false, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgresql://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + connParams, err := pgx.ParseConnectionString(tt.url) + if err != nil { + t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) + continue + } + + if !reflect.DeepEqual(connParams, tt.connParams) { + t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + } + } +} + +func TestParseEnvLibpq(t *testing.T) { + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config pgx.ConnConfig + }{ + { + name: "No environment", + envvars: map[string]string{}, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "Normal PG vars", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + }, + config: pgx.ConnConfig{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "application_name", + envvars: map[string]string{ + "PGAPPNAME": "pgxtest", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + { + name: "sslmode=disable", + envvars: map[string]string{ + "PGSSLMODE": "disable", + }, + config: pgx.ConnConfig{ + TLSConfig: nil, + UseFallbackTLS: false, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=allow", + envvars: map[string]string{ + "PGSSLMODE": "allow", + }, + config: pgx.ConnConfig{ + TLSConfig: nil, + UseFallbackTLS: true, + FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=prefer", + envvars: map[string]string{ + "PGSSLMODE": "prefer", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=require", + envvars: map[string]string{ + "PGSSLMODE": "require", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=verify-ca", + envvars: map[string]string{ + "PGSSLMODE": "verify-ca", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=verify-full", + envvars: map[string]string{ + "PGSSLMODE": "verify-full", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode=verify-full with host", + envvars: map[string]string{ + "PGHOST": "pgx.example", + "PGSSLMODE": "verify-full", + }, + config: pgx.ConnConfig{ + Host: "pgx.example", + TLSConfig: &tls.Config{ + ServerName: "pgx.example", + }, + UseFallbackTLS: false, + RuntimeParams: map[string]string{}, + }, + }, + } + + for _, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + if err != nil { + t.Fatalf("%s: Unable to clear environment: %v", tt.name, err) + } + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("%s: Unable to set environment: %v", tt.name, err) + } + } + + config, err := pgx.ParseEnvLibpq() + if err != nil { + t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err) + continue + } + + if config.Host != tt.config.Host { + t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host) + } + if config.Port != tt.config.Port { + t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) + } + if config.Port != tt.config.Port { + t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) + } + if config.User != tt.config.User { + t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User) + } + if config.Password != tt.config.Password { + t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password) + } + + if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) { + t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams) + } + + tlsTests := []struct { + name string + expected *tls.Config + actual *tls.Config + }{ + { + name: "TLSConfig", + expected: tt.config.TLSConfig, + actual: config.TLSConfig, + }, + { + name: "FallbackTLSConfig", + expected: tt.config.FallbackTLSConfig, + actual: config.FallbackTLSConfig, + }, + } + for _, tlsTest := range tlsTests { + name := tlsTest.name + expected := tlsTest.expected + actual := tlsTest.actual + + if expected == nil && actual != nil { + t.Errorf("%s / %s: expected nil, but it was set", tt.name, name) + } else if expected != nil && actual == nil { + t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name) + } else if expected != nil && actual != nil { + if actual.InsecureSkipVerify != expected.InsecureSkipVerify { + t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) + } + + if actual.ServerName != expected.ServerName { + t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName) + } + } + } + + if config.UseFallbackTLS != tt.config.UseFallbackTLS { + t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS) + } + } +} + +func TestExec(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + if results := mustExec(t, conn, "drop table foo;"); results != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } + + // Multiple statements can be executed -- last command tag is returned + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } + + // Can execute longer SQL strings than sharedBufferSize + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + + // Exec no-op which does not return a command tag + if results := mustExec(t, conn, "--;"); results != "" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} + +func TestExecFailure(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if _, err := conn.Exec("selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestPrepare(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.Prepare("test", "select $1::varchar") + if err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return + } + + var s string + err = conn.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = conn.Deallocate("test") + if err != nil { + t.Errorf("conn.Deallocate failed: %v", err) + } + + // Create another prepared statement to ensure Deallocate left the connection + // in a working state and that we can reuse the prepared statement name. + + _, err = conn.Prepare("test", "select $1::integer") + if err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return + } + + var n int32 + err = conn.QueryRow("test", int32(1)).Scan(&n) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if n != 1 { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = conn.Deallocate("test") + if err != nil { + t.Errorf("conn.Deallocate failed: %v", err) + } +} + +func TestPrepareBadSQLFailure(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if _, err := conn.Prepare("badSQL", "select foo"); err == nil { + t.Fatal("Prepare should have failed with syntax error") + } + + ensureConnValid(t, conn) +} + +func TestPrepareQueryManyParameters(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + count int + succeed bool + }{ + { + count: 65534, + succeed: true, + }, + { + count: 65535, + succeed: true, + }, + { + count: 65536, + succeed: false, + }, + { + count: 65537, + succeed: false, + }, + } + + for i, tt := range tests { + params := make([]string, 0, tt.count) + args := make([]interface{}, 0, tt.count) + for j := 0; j < tt.count; j++ { + params = append(params, fmt.Sprintf("($%d::text)", j+1)) + args = append(args, strconv.Itoa(j)) + } + + sql := "values" + strings.Join(params, ", ") + + psName := fmt.Sprintf("manyParams%d", i) + _, err := conn.Prepare(psName, sql) + if err != nil { + if tt.succeed { + t.Errorf("%d. %v", i, err) + } + continue + } + if !tt.succeed { + t.Errorf("%d. Expected error but succeeded", i) + continue + } + + rows, err := conn.Query(psName, args...) + if err != nil { + t.Errorf("conn.Query failed: %v", err) + continue + } + + for rows.Next() { + var s string + rows.Scan(&s) + } + + if rows.Err() != nil { + t.Errorf("Reading query result failed: %v", err) + } + } + + ensureConnValid(t, conn) +} + +func TestPrepareIdempotency(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for i := 0; i < 2; i++ { + _, err := conn.Prepare("test", "select 42::integer") + if err != nil { + t.Fatalf("%d. Unable to prepare statement: %v", i, err) + } + + var n int32 + err = conn.QueryRow("test").Scan(&n) + if err != nil { + t.Errorf("%d. Executing prepared statement failed: %v", i, err) + } + + if n != int32(42) { + t.Errorf("%d. Prepared statement did not return expected value: %v", i, n) + } + } + + _, err := conn.Prepare("test", "select 'fail'::varchar") + if err == nil { + t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't") + return + } +} + +func TestPrepareEx(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) + if err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return + } + + var s string + err = conn.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = conn.Deallocate("test") + if err != nil { + t.Errorf("conn.Deallocate failed: %v", err) + } +} + +func TestListenNotify(t *testing.T) { + t.Parallel() + + listener := mustConnect(t, *defaultConnConfig) + defer closeConn(t, listener) + + if err := listener.Listen("chat"); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + notifier := mustConnect(t, *defaultConnConfig) + defer closeConn(t, notifier) + + mustExec(t, notifier, "notify chat") + + // when notification is waiting on the socket to be read + notification, err := listener.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + // when notification has already been read during previous query + mustExec(t, notifier, "notify chat") + rows, _ := listener.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Unexpected error on Query: %v", rows.Err()) + } + notification, err = listener.WaitForNotification(0) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + // when timeout occurs + notification, err = listener.WaitForNotification(time.Millisecond) + if err != pgx.ErrNotificationTimeout { + t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) + } + if notification != nil { + t.Errorf("WaitForNotification returned an unexpected notification: %v", notification) + } + + // listener can listen again after a timeout + mustExec(t, notifier, "notify chat") + notification, err = listener.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "chat" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } +} + +func TestUnlistenSpecificChannel(t *testing.T) { + t.Parallel() + + listener := mustConnect(t, *defaultConnConfig) + defer closeConn(t, listener) + + if err := listener.Listen("unlisten_test"); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + notifier := mustConnect(t, *defaultConnConfig) + defer closeConn(t, notifier) + + mustExec(t, notifier, "notify unlisten_test") + + // when notification is waiting on the socket to be read + notification, err := listener.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "unlisten_test" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + err = listener.Unlisten("unlisten_test") + if err != nil { + t.Fatalf("Unexpected error on Unlisten: %v", err) + } + + // when notification has already been read during previous query + mustExec(t, notifier, "notify unlisten_test") + rows, _ := listener.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Unexpected error on Query: %v", rows.Err()) + } + notification, err = listener.WaitForNotification(100 * time.Millisecond) + if err != pgx.ErrNotificationTimeout { + t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) + } +} + +func TestListenNotifyWhileBusyIsSafe(t *testing.T) { + t.Parallel() + + listenerDone := make(chan bool) + go func() { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + defer func() { + listenerDone <- true + }() + + if err := conn.Listen("busysafe"); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + for i := 0; i < 5000; i++ { + var sum int32 + var rowCount int32 + + rows, err := conn.Query("select generate_series(1,$1)", 100) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if sum != 5050 { + t.Fatalf("Wrong rows sum: %v", sum) + } + + if rowCount != 100 { + t.Fatalf("Wrong number of rows: %v", rowCount) + } + + time.Sleep(1 * time.Microsecond) + } + }() + + notifierDone := make(chan bool) + go func() { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + defer func() { + notifierDone <- true + }() + + for i := 0; i < 100000; i++ { + mustExec(t, conn, "notify busysafe, 'hello'") + time.Sleep(1 * time.Microsecond) + } + }() + + <-listenerDone +} + +func TestListenNotifySelfNotification(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if err := conn.Listen("self"); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + // Notify self and WaitForNotification immediately + mustExec(t, conn, "notify self") + + notification, err := conn.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "self" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } + + // Notify self and do something else before WaitForNotification + mustExec(t, conn, "notify self") + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Unexpected error on Query: %v", rows.Err()) + } + + notification, err = conn.WaitForNotification(time.Second) + if err != nil { + t.Fatalf("Unexpected error on WaitForNotification: %v", err) + } + if notification.Channel != "self" { + t.Errorf("Did not receive notification on expected channel: %v", notification.Channel) + } +} + +func TestListenUnlistenSpecialCharacters(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + chanName := "special characters !@#{$%^&*()}" + if err := conn.Listen(chanName); err != nil { + t.Fatalf("Unable to start listening: %v", err) + } + + if err := conn.Unlisten(chanName); err != nil { + t.Fatalf("Unable to stop listening: %v", err) + } +} + +func TestFatalRxError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var n int32 + var s string + err := conn.QueryRow("select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s) + if err == pgx.ErrDeadConn { + } else if pgErr, ok := err.(pgx.PgError); ok && pgErr.Severity == "FATAL" { + } else { + t.Fatalf("Expected QueryRow Scan to return fatal PgError or ErrDeadConn, but instead received %v", err) + } + }() + + otherConn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer otherConn.Close() + + if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + wg.Wait() + + if conn.IsAlive() { + t.Fatal("Connection should not be live but was") + } +} + +func TestFatalTxError(t *testing.T) { + t.Parallel() + + // Run timing sensitive test many times + for i := 0; i < 50; i++ { + func() { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + otherConn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer otherConn.Close() + + _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid) + if err != nil { + t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) + } + + _, err = conn.Query("select 1") + if err == nil { + t.Fatal("Expected error but none occurred") + } + + if conn.IsAlive() { + t.Fatalf("Connection should not be live but was. Previous Query err: %v", err) + } + }() + } +} + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag pgx.CommandTag + rowsAffected int64 + }{ + {commandTag: "INSERT 0 5", rowsAffected: 5}, + {commandTag: "UPDATE 0", rowsAffected: 0}, + {commandTag: "UPDATE 1", rowsAffected: 1}, + {commandTag: "DELETE 0", rowsAffected: 0}, + {commandTag: "DELETE 1", rowsAffected: 1}, + {commandTag: "CREATE TABLE", rowsAffected: 0}, + {commandTag: "ALTER TABLE", rowsAffected: 0}, + {commandTag: "DROP TABLE", rowsAffected: 0}, + } + + for i, tt := range tests { + actual := tt.commandTag.RowsAffected() + if tt.rowsAffected != actual { + t.Errorf(`%d. "%s" should have affected %d rows but it was %d`, i, tt.commandTag, tt.rowsAffected, actual) + } + } +} + +func TestInsertBoolArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} + +func TestInsertTimestampArray(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } + + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } +} + +func TestCatchSimultaneousConnectionQueries(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows1, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows1.Close() + + _, err = conn.Query("select generate_series(1,$1)", 10) + if err != pgx.ErrConnBusy { + t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err) + } +} + +func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + _, err = conn.Exec("create temporary table foo(spice timestamp[])") + if err != pgx.ErrConnBusy { + t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err) + } +} + +type testLog struct { + lvl int + msg string + ctx []interface{} +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Debug(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) +} +func (l *testLogger) Info(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) +} +func (l *testLogger) Warn(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) +} +func (l *testLogger) Error(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +} + +func TestSetLogger(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + l1 := &testLogger{} + oldLogger := conn.SetLogger(l1) + if oldLogger != nil { + t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger) + } + + if err := conn.Listen("foo"); err != nil { + t.Fatal(err) + } + + if len(l1.logs) == 0 { + t.Fatal("Expected new logger l1 to be called, but it wasn't") + } + + l2 := &testLogger{} + oldLogger = conn.SetLogger(l2) + if oldLogger != l1 { + t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger) + } + + if err := conn.Listen("bar"); err != nil { + t.Fatal(err) + } + + if len(l2.logs) == 0 { + t.Fatal("Expected new logger l2 to be called, but it wasn't") + } +} + +func TestSetLogLevel(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + logger := &testLogger{} + conn.SetLogger(logger) + + if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel { + t.Fatal("SetLogLevel with invalid level did not return error") + } + + if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil { + t.Fatal(err) + } + + if err := conn.Listen("foo"); err != nil { + t.Fatal(err) + } + + if len(logger.logs) != 0 { + t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs) + } + + if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil { + t.Fatal(err) + } + + if err := conn.Listen("bar"); err != nil { + t.Fatal(err) + } + + if len(logger.logs) == 0 { + t.Fatal("Expected logger to be called, but it wasn't") + } +} + +func TestIdentifierSanitize(t *testing.T) { + t.Parallel() + + tests := []struct { + ident pgx.Identifier + expected string + }{ + { + ident: pgx.Identifier{`foo`}, + expected: `"foo"`, + }, + { + ident: pgx.Identifier{`select`}, + expected: `"select"`, + }, + { + ident: pgx.Identifier{`foo`, `bar`}, + expected: `"foo"."bar"`, + }, + { + ident: pgx.Identifier{`you should " not do this`}, + expected: `"you should "" not do this"`, + }, + { + ident: pgx.Identifier{`you should " not do this`, `please don't`}, + expected: `"you should "" not do this"."please don't"`, + }, + } + + for i, tt := range tests { + qval := tt.ident.Sanitize() + if qval != tt.expected { + t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) + } + } +} diff --git a/vendor/github.com/jackc/pgx/copy_from.go b/vendor/github.com/jackc/pgx/copy_from.go new file mode 100644 index 0000000..1f8a230 --- /dev/null +++ b/vendor/github.com/jackc/pgx/copy_from.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyFromRows returns a CopyFromSource interface over the provided rows slice +// making it usable by *Conn.CopyFrom. +func CopyFromRows(rows [][]interface{}) CopyFromSource { + return ©FromRows{rows: rows, idx: -1} +} + +type copyFromRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyFromRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyFromRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyFromRows) Err() error { + return nil +} + +// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. +type CopyFromSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyFromSource. If + // this is not nil *Conn.CopyFrom will abort the copy. + Err() error +} + +type copyFrom struct { + conn *Conn + tableName Identifier + columnNames []string + rowSrc CopyFromSource + readerErrChan chan error +} + +func (ct *copyFrom) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyFrom) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyFrom) run() (int, error) { + quotedTableName := ct.tableName.Sanitize() + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyFrom) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyFrom requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { + ct := ©From{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/vendor/github.com/jackc/pgx/copy_from_test.go b/vendor/github.com/jackc/pgx/copy_from_test.go new file mode 100644 index 0000000..54da698 --- /dev/null +++ b/vendor/github.com/jackc/pgx/copy_from_test.go @@ -0,0 +1,428 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyFromSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + +func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/vendor/github.com/jackc/pgx/copy_to.go b/vendor/github.com/jackc/pgx/copy_to.go new file mode 100644 index 0000000..229e9a4 --- /dev/null +++ b/vendor/github.com/jackc/pgx/copy_to.go @@ -0,0 +1,222 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource +// interface over the provided rows slice making it usable by *Conn.CopyTo. +func CopyToRows(rows [][]interface{}) CopyToSource { + return ©ToRows{rows: rows, idx: -1} +} + +type copyToRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyToRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyToRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyToRows) Err() error { + return nil +} + +// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by +// *Conn.CopyTo as the source for copy data. +type CopyToSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyToSource. If + // this is not nil *Conn.CopyTo will abort the copy. + Err() error +} + +type copyTo struct { + conn *Conn + tableName string + columnNames []string + rowSrc CopyToSource + readerErrChan chan error +} + +func (ct *copyTo) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyTo) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyTo) run() (int, error) { + quotedTableName := quoteIdentifier(ct.tableName) + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (ct *copyTo) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to +// perform bulk data insertion. It returns the number of rows copied and an +// error. +// +// CopyTo requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + ct := ©To{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/vendor/github.com/jackc/pgx/copy_to_test.go b/vendor/github.com/jackc/pgx/copy_to_test.go new file mode 100644 index 0000000..ac27042 --- /dev/null +++ b/vendor/github.com/jackc/pgx/copy_to_test.go @@ -0,0 +1,367 @@ +package pgx_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/vendor/github.com/jackc/pgx/doc.go b/vendor/github.com/jackc/pgx/doc.go new file mode 100644 index 0000000..f527d11 --- /dev/null +++ b/vendor/github.com/jackc/pgx/doc.go @@ -0,0 +1,265 @@ +// Package pgx is a PostgreSQL database driver. +/* +pgx provides lower level access to PostgreSQL than the standard database/sql +It remains as similar to the database/sql interface as possible while +providing better speed and access to PostgreSQL specific features. Import +github.com/jack/pgx/stdlib to use pgx as a database/sql compatible driver. + +Query Interface + +pgx implements Query and Scan in the familiar database/sql style. + + var sum int32 + + // Send the query to the server. The returned rows MUST be closed + // before conn can be used again. + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + return err + } + + // rows.Close is called by rows.Next when all rows are read + // or an error occurs in Next or Scan. So it may optionally be + // omitted if nothing in the rows.Next loop can panic. It is + // safe to close rows multiple times. + defer rows.Close() + + // Iterate through the result set + for rows.Next() { + var n int32 + err = rows.Scan(&n) + if err != nil { + return err + } + sum += n + } + + // Any errors encountered by rows.Next or rows.Scan will be returned here + if rows.Err() != nil { + return err + } + + // No errors found - do something with sum + +pgx also implements QueryRow in the same style as database/sql. + + var name string + var weight int64 + err := conn.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + if err != nil { + return err + } + +Use Exec to execute a query that does not return a result set. + + commandTag, err := conn.Exec("delete from widgets where id=$1", 42) + if err != nil { + return err + } + if commandTag.RowsAffected() != 1 { + return errors.New("No row found to delete") + } + +Connection Pool + +Connection pool usage is explicit and configurable. In pgx, a connection can +be created and managed directly, or a connection pool with a configurable +maximum connections can be used. Also, the connection pool offers an after +connect hook that allows every connection to be automatically setup before +being made available in the connection pool. This is especially useful to +ensure all connections have the same prepared statements available or to +change any other connection settings. + +It delegates Query, QueryRow, Exec, and Begin functions to an automatically +checked out and released connection so you can avoid manually acquiring and +releasing connections when you do not need that level of control. + + var name string + var weight int64 + err := pool.QueryRow("select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + if err != nil { + return err + } + +Base Type Mapping + +pgx maps between all common base types directly between Go and PostgreSQL. In +particular: + + Go PostgreSQL + ----------------------- + string varchar + text + + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint + + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 + + time.Time date + timestamp + timestamptz + + []byte bytea + + +Null Mapping + +pgx can map nulls in two ways. The first is Null* types that have a data field +and a valid field. They work in a similar fashion to database/sql. The second +is to use a pointer to a pointer. + + var foo pgx.NullString + var bar *string + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b) + if err != nil { + return err + } + +Array Mapping + +pgx maps between int16, int32, int64, float32, float64, and string Go slices +and the equivalent PostgreSQL array type. Go slices of native types do not +support nulls, so if a PostgreSQL array that contains a null is read into a +native Go slice an error will occur. + +Hstore Mapping + +pgx includes an Hstore type and a NullHstore type. Hstore is simply a +map[string]string and is preferred when the hstore contains no nulls. NullHstore +follows the Null* pattern and supports null values. + +JSON and JSONB Mapping + +pgx includes built-in support to marshal and unmarshal between Go types and +the PostgreSQL JSON and JSONB. + +Inet and Cidr Mapping + +pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In +addition, as a convenience pgx will encode from a net.IP; it will assume a /32 +netmask for IPv4 and a /128 for IPv6. + +Custom Type Support + +pgx includes support for the common data types like integers, floats, strings, +dates, and times that have direct mappings between Go and SQL. Support can be +added for additional types like point, hstore, numeric, etc. that do not have +direct mappings in Go by the types implementing ScannerPgx and Encoder. + +Custom types can support text or binary formats. Binary format can provide a +large performance increase. The natural place for deciding the format for a +value would be in ScannerPgx as it is responsible for decoding the returned +data. However, that is impossible as the query has already been sent by the time +the ScannerPgx is invoked. The solution to this is the global +DefaultTypeFormats. If a custom type prefers binary format it should register it +there. + + pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode + +Note that the type is referred to by name, not by OID. This is because custom +PostgreSQL types like hstore will have different OIDs on different servers. When +pgx establishes a connection it queries the pg_type table for all types. It then +matches the names in DefaultTypeFormats with the returned OIDs and stores it in +Conn.PgTypes. + +See example_custom_type_test.go for an example of a custom type for the +PostgreSQL point type. + +pgx also includes support for custom types implementing the database/sql.Scanner +and database/sql/driver.Valuer interfaces. + +Raw Bytes Mapping + +[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified +to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with +the raw bytes returned by PostgreSQL. This can be especially useful for reading +varchar, text, json, and jsonb values directly into a []byte and avoiding the +type conversion from string. + +Transactions + +Transactions are started by calling Begin or BeginIso. The BeginIso variant +creates a transaction with a specified isolation level. + + tx, err := conn.Begin() + if err != nil { + return err + } + // Rollback is safe to call even if the tx is already closed, so if + // the tx commits successfully, this is a no-op + defer tx.Rollback() + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + +Copy Protocol + +Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already +in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or +implement CopyFromSource to avoid buffering the entire data set in memory. + + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyFrom( + "people", + []string{"first_name", "last_name", "age"}, + pgx.CopyFromRows(rows), + ) + +CopyFrom can be faster than an insert with as few as 5 rows. + +Listen and Notify + +pgx can listen to the PostgreSQL notification system with the +WaitForNotification function. It takes a maximum time to wait for a +notification. + + err := conn.Listen("channelname") + if err != nil { + return nil + } + + if notification, err := conn.WaitForNotification(time.Second); err != nil { + // do something with notification + } + +TLS + +The pgx ConnConfig struct has a TLSConfig field. If this field is +nil, then TLS will be disabled. If it is present, then it will be used to +configure the TLS connection. This allows total configuration of the TLS +connection. + +Logging + +pgx defines a simple logger interface. Connections optionally accept a logger +that satisfies this interface. The log15 package +(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is +simple to define adapters for other loggers. Set LogLevel to control logging +verbosity. +*/ +package pgx diff --git a/vendor/github.com/jackc/pgx/example_custom_type_test.go b/vendor/github.com/jackc/pgx/example_custom_type_test.go new file mode 100644 index 0000000..34cc316 --- /dev/null +++ b/vendor/github.com/jackc/pgx/example_custom_type_test.go @@ -0,0 +1,104 @@ +package pgx_test + +import ( + "errors" + "fmt" + "github.com/jackc/pgx" + "regexp" + "strconv" +) + +var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) + +// NullPoint represents a point that may be null. +// +// If Valid is false then the value is NULL. +type NullPoint struct { + X, Y float64 // Coordinates of point + Valid bool // Valid is true if not NULL +} + +func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { + if vr.Type().DataTypeName != "point" { + return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType)) + } + + if vr.Len() == -1 { + p.X, p.Y, p.Valid = 0, 0, false + return nil + } + + switch vr.Type().FormatCode { + case pgx.TextFormatCode: + s := vr.ReadString(vr.Len()) + match := pointRegexp.FindStringSubmatch(s) + if match == nil { + return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) + } + + var err error + p.X, err = strconv.ParseFloat(match[1], 64) + if err != nil { + return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) + } + p.Y, err = strconv.ParseFloat(match[2], 64) + if err != nil { + return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) + } + case pgx.BinaryFormatCode: + return errors.New("binary format not implemented") + default: + return fmt.Errorf("unknown format %v", vr.Type().FormatCode) + } + + p.Valid = true + return vr.Err() +} + +func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode } + +func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { + if !p.Valid { + w.WriteInt32(-1) + return nil + } + + s := fmt.Sprintf("point(%v,%v)", p.X, p.Y) + w.WriteInt32(int32(len(s))) + w.WriteBytes([]byte(s)) + + return nil +} + +func (p NullPoint) String() string { + if p.Valid { + return fmt.Sprintf("%v, %v", p.X, p.Y) + } + return "null point" +} + +func Example_CustomType() { + conn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + var p NullPoint + err = conn.QueryRow("select null::point").Scan(&p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + + err = conn.QueryRow("select point(1.5,2.5)").Scan(&p) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(p) + // Output: + // null point + // 1.5, 2.5 +} diff --git a/vendor/github.com/jackc/pgx/example_json_test.go b/vendor/github.com/jackc/pgx/example_json_test.go new file mode 100644 index 0000000..c153415 --- /dev/null +++ b/vendor/github.com/jackc/pgx/example_json_test.go @@ -0,0 +1,43 @@ +package pgx_test + +import ( + "fmt" + "github.com/jackc/pgx" +) + +func Example_JSON() { + conn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if _, ok := conn.PgTypes[pgx.JsonOid]; !ok { + // No JSON type -- must be running against very old PostgreSQL + // Pretend it works + fmt.Println("John", 42) + return + } + + type person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + input := person{ + Name: "John", + Age: 42, + } + + var output person + + err = conn.QueryRow("select $1::json", input).Scan(&output) + if err != nil { + fmt.Println(err) + return + } + + fmt.Println(output.Name, output.Age) + // Output: + // John 42 +} diff --git a/vendor/github.com/jackc/pgx/examples/README.md b/vendor/github.com/jackc/pgx/examples/README.md new file mode 100644 index 0000000..6a97bc0 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/README.md @@ -0,0 +1,7 @@ +# Examples + +* chat is a command line chat program using listen/notify. +* todo is a command line todo list that demonstrates basic CRUD actions. +* url_shortener contains a simple example of using pgx in a web context. +* [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx (uses v1 of pgx). +* [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx. diff --git a/vendor/github.com/jackc/pgx/examples/chat/README.md b/vendor/github.com/jackc/pgx/examples/chat/README.md new file mode 100644 index 0000000..a093525 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/chat/README.md @@ -0,0 +1,25 @@ +# Description + +This is a sample chat program implemented using PostgreSQL's listen/notify +functionality with pgx. + +Start multiple instances of this program connected to the same database to chat +between them. + +## Connection configuration + +The database connection is configured via enviroment variables. + +* CHAT_DB_HOST - defaults to localhost +* CHAT_DB_USER - defaults to current OS user +* CHAT_DB_PASSWORD - defaults to empty string +* CHAT_DB_DATABASE - defaults to postgres + +You can either export them then run chat: + + export CHAT_DB_HOST=/private/tmp + ./chat + +Or you can prefix the chat execution with the environment variables: + + CHAT_DB_HOST=/private/tmp ./chat diff --git a/vendor/github.com/jackc/pgx/examples/chat/main.go b/vendor/github.com/jackc/pgx/examples/chat/main.go new file mode 100644 index 0000000..517508c --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/chat/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "bufio" + "fmt" + "github.com/jackc/pgx" + "os" + "time" +) + +var pool *pgx.ConnPool + +func main() { + var err error + pool, err = pgx.NewConnPool(extractConfig()) + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) + os.Exit(1) + } + + go listen() + + fmt.Println(`Type a message and press enter. + +This message should appear in any other chat instances connected to the same +database. + +Type "exit" to quit. +`) + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + msg := scanner.Text() + if msg == "exit" { + os.Exit(0) + } + + _, err = pool.Exec("select pg_notify('chat', $1)", msg) + if err != nil { + fmt.Fprintln(os.Stderr, "Error sending notification:", err) + os.Exit(1) + } + } + if err := scanner.Err(); err != nil { + fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) + os.Exit(1) + } +} + +func listen() { + conn, err := pool.Acquire() + if err != nil { + fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) + os.Exit(1) + } + defer pool.Release(conn) + + conn.Listen("chat") + + for { + notification, err := conn.WaitForNotification(time.Second) + if err == pgx.ErrNotificationTimeout { + continue + } + if err != nil { + fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) + os.Exit(1) + } + + fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload) + } +} + +func extractConfig() pgx.ConnPoolConfig { + var config pgx.ConnPoolConfig + + config.Host = os.Getenv("CHAT_DB_HOST") + if config.Host == "" { + config.Host = "localhost" + } + + config.User = os.Getenv("CHAT_DB_USER") + if config.User == "" { + config.User = os.Getenv("USER") + } + + config.Password = os.Getenv("CHAT_DB_PASSWORD") + + config.Database = os.Getenv("CHAT_DB_DATABASE") + if config.Database == "" { + config.Database = "postgres" + } + + return config +} diff --git a/vendor/github.com/jackc/pgx/examples/todo/README.md b/vendor/github.com/jackc/pgx/examples/todo/README.md new file mode 100644 index 0000000..eb3d95b --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/todo/README.md @@ -0,0 +1,72 @@ +# Description + +This is a sample todo list implemented using pgx as the connector to a +PostgreSQL data store. + +# Usage + +Create a PostgreSQL database and run structure.sql into it to create the +necessary data schema. + +Example: + + createdb todo + psql todo < structure.sql + +Build todo: + + go build + +## Connection configuration + +The database connection is configured via enviroment variables. + +* TODO_DB_HOST - defaults to localhost +* TODO_DB_USER - defaults to current OS user +* TODO_DB_PASSWORD - defaults to empty string +* TODO_DB_DATABASE - defaults to todo + +You can either export them then run todo: + + export TODO_DB_HOST=/private/tmp + ./todo list + +Or you can prefix the todo execution with the environment variables: + + TODO_DB_HOST=/private/tmp ./todo list + +## Add a todo item + + ./todo add 'Learn go' + +## List tasks + + ./todo list + +## Update a task + + ./todo add 1 'Learn more go' + +## Delete a task + + ./todo remove 1 + +# Example Setup and Execution + + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ createdb todo + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ psql todo < structure.sql + Expanded display is used automatically. + Timing is on. + CREATE TABLE + Time: 6.363 ms + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ go build + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ export TODO_DB_HOST=/private/tmp + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo add 'Learn Go' + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list + 1. Learn Go + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo update 1 'Learn more Go' + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list + 1. Learn more Go + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo remove 1 + jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list diff --git a/vendor/github.com/jackc/pgx/examples/todo/main.go b/vendor/github.com/jackc/pgx/examples/todo/main.go new file mode 100644 index 0000000..bacdf24 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/todo/main.go @@ -0,0 +1,140 @@ +package main + +import ( + "fmt" + "github.com/jackc/pgx" + "os" + "strconv" +) + +var conn *pgx.Conn + +func main() { + var err error + conn, err = pgx.Connect(extractConfig()) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err) + os.Exit(1) + } + + if len(os.Args) == 1 { + printHelp() + os.Exit(0) + } + + switch os.Args[1] { + case "list": + err = listTasks() + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to list tasks: %v\n", err) + os.Exit(1) + } + + case "add": + err = addTask(os.Args[2]) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to add task: %v\n", err) + os.Exit(1) + } + + case "update": + n, err := strconv.ParseInt(os.Args[2], 10, 32) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) + os.Exit(1) + } + err = updateTask(int32(n), os.Args[3]) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to update task: %v\n", err) + os.Exit(1) + } + + case "remove": + n, err := strconv.ParseInt(os.Args[2], 10, 32) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) + os.Exit(1) + } + err = removeTask(int32(n)) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to remove task: %v\n", err) + os.Exit(1) + } + + default: + fmt.Fprintln(os.Stderr, "Invalid command") + printHelp() + os.Exit(1) + } +} + +func listTasks() error { + rows, _ := conn.Query("select * from tasks") + + for rows.Next() { + var id int32 + var description string + err := rows.Scan(&id, &description) + if err != nil { + return err + } + fmt.Printf("%d. %s\n", id, description) + } + + return rows.Err() +} + +func addTask(description string) error { + _, err := conn.Exec("insert into tasks(description) values($1)", description) + return err +} + +func updateTask(itemNum int32, description string) error { + _, err := conn.Exec("update tasks set description=$1 where id=$2", description, itemNum) + return err +} + +func removeTask(itemNum int32) error { + _, err := conn.Exec("delete from tasks where id=$1", itemNum) + return err +} + +func printHelp() { + fmt.Print(`Todo pgx demo + +Usage: + + todo list + todo add task + todo update task_num item + todo remove task_num + +Example: + + todo add 'Learn Go' + todo list +`) +} + +func extractConfig() pgx.ConnConfig { + var config pgx.ConnConfig + + config.Host = os.Getenv("TODO_DB_HOST") + if config.Host == "" { + config.Host = "localhost" + } + + config.User = os.Getenv("TODO_DB_USER") + if config.User == "" { + config.User = os.Getenv("USER") + } + + config.Password = os.Getenv("TODO_DB_PASSWORD") + + config.Database = os.Getenv("TODO_DB_DATABASE") + if config.Database == "" { + config.Database = "todo" + } + + return config +} diff --git a/vendor/github.com/jackc/pgx/examples/todo/structure.sql b/vendor/github.com/jackc/pgx/examples/todo/structure.sql new file mode 100644 index 0000000..567685d --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/todo/structure.sql @@ -0,0 +1,4 @@ +create table tasks ( + id serial primary key, + description text not null +); diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/README.md b/vendor/github.com/jackc/pgx/examples/url_shortener/README.md new file mode 100644 index 0000000..cc04d60 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/url_shortener/README.md @@ -0,0 +1,25 @@ +# Description + +This is a sample REST URL shortener service implemented using pgx as the connector to a PostgreSQL data store. + +# Usage + +Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. + +Edit connectionOptions in main.go with the location and credentials for your database. + +Run main.go: + + go run main.go + +## Create or Update a Shortened URL + + curl -X PUT -d 'http://www.google.com' http://localhost:8080/google + +## Get a Shortened URL + + curl http://localhost:8080/google + +## Delete a Shortened URL + + curl -X DELETE http://localhost:8080/google diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/main.go b/vendor/github.com/jackc/pgx/examples/url_shortener/main.go new file mode 100644 index 0000000..f6a22c3 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/url_shortener/main.go @@ -0,0 +1,128 @@ +package main + +import ( + "github.com/jackc/pgx" + log "gopkg.in/inconshreveable/log15.v2" + "io/ioutil" + "net/http" + "os" +) + +var pool *pgx.ConnPool + +// afterConnect creates the prepared statements that this application uses +func afterConnect(conn *pgx.Conn) (err error) { + _, err = conn.Prepare("getUrl", ` + select url from shortened_urls where id=$1 + `) + if err != nil { + return + } + + _, err = conn.Prepare("deleteUrl", ` + delete from shortened_urls where id=$1 + `) + if err != nil { + return + } + + // There technically is a small race condition in doing an upsert with a CTE + // where one of two simultaneous requests to the shortened URL would fail + // with a unique index violation. As the point of this demo is pgx usage and + // not how to perfectly upsert in PostgreSQL it is deemed acceptable. + _, err = conn.Prepare("putUrl", ` + with upsert as ( + update shortened_urls + set url=$2 + where id=$1 + returning * + ) + insert into shortened_urls(id, url) + select $1, $2 where not exists(select 1 from upsert) + `) + return +} + +func getUrlHandler(w http.ResponseWriter, req *http.Request) { + var url string + err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url) + switch err { + case nil: + http.Redirect(w, req, url, http.StatusSeeOther) + case pgx.ErrNoRows: + http.NotFound(w, req) + default: + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +func putUrlHandler(w http.ResponseWriter, req *http.Request) { + id := req.URL.Path + var url string + if body, err := ioutil.ReadAll(req.Body); err == nil { + url = string(body) + } else { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + if _, err := pool.Exec("putUrl", id, url); err == nil { + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +func deleteUrlHandler(w http.ResponseWriter, req *http.Request) { + if _, err := pool.Exec("deleteUrl", req.URL.Path); err == nil { + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +func urlHandler(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "GET": + getUrlHandler(w, req) + + case "PUT": + putUrlHandler(w, req) + + case "DELETE": + deleteUrlHandler(w, req) + + default: + w.Header().Add("Allow", "GET, PUT, DELETE") + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func main() { + var err error + connPoolConfig := pgx.ConnPoolConfig{ + ConnConfig: pgx.ConnConfig{ + Host: "127.0.0.1", + User: "jack", + Password: "jack", + Database: "url_shortener", + Logger: log.New("module", "pgx"), + }, + MaxConnections: 5, + AfterConnect: afterConnect, + } + pool, err = pgx.NewConnPool(connPoolConfig) + if err != nil { + log.Crit("Unable to create connection pool", "error", err) + os.Exit(1) + } + + http.HandleFunc("/", urlHandler) + + log.Info("Starting URL shortener on localhost:8080") + err = http.ListenAndServe("localhost:8080", nil) + if err != nil { + log.Crit("Unable to start web server", "error", err) + os.Exit(1) + } +} diff --git a/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql b/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql new file mode 100644 index 0000000..0b9de25 --- /dev/null +++ b/vendor/github.com/jackc/pgx/examples/url_shortener/structure.sql @@ -0,0 +1,4 @@ +create table shortened_urls ( + id text primary key, + url text not null +); \ No newline at end of file diff --git a/vendor/github.com/jackc/pgx/fastpath.go b/vendor/github.com/jackc/pgx/fastpath.go new file mode 100644 index 0000000..19b9878 --- /dev/null +++ b/vendor/github.com/jackc/pgx/fastpath.go @@ -0,0 +1,108 @@ +package pgx + +import ( + "encoding/binary" +) + +func newFastpath(cn *Conn) *fastpath { + return &fastpath{cn: cn, fns: make(map[string]Oid)} +} + +type fastpath struct { + cn *Conn + fns map[string]Oid +} + +func (f *fastpath) functionOID(name string) Oid { + return f.fns[name] +} + +func (f *fastpath) addFunction(name string, oid Oid) { + f.fns[name] = oid +} + +func (f *fastpath) addFunctions(rows *Rows) error { + for rows.Next() { + var name string + var oid Oid + if err := rows.Scan(&name, &oid); err != nil { + return err + } + f.addFunction(name, oid) + } + return rows.Err() +} + +type fpArg []byte + +func fpIntArg(n int32) fpArg { + res := make([]byte, 4) + binary.BigEndian.PutUint32(res, uint32(n)) + return res +} + +func fpInt64Arg(n int64) fpArg { + res := make([]byte, 8) + binary.BigEndian.PutUint64(res, uint64(n)) + return res +} + +func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { + wbuf := newWriteBuf(f.cn, 'F') // function call + wbuf.WriteInt32(int32(oid)) // function object id + wbuf.WriteInt16(1) // # of argument format codes + wbuf.WriteInt16(1) // format code: binary + wbuf.WriteInt16(int16(len(args))) // # of arguments + for _, arg := range args { + wbuf.WriteInt32(int32(len(arg))) // length of argument + wbuf.WriteBytes(arg) // argument value + } + wbuf.WriteInt16(1) // response format code (binary) + wbuf.closeMsg() + + if _, err := f.cn.conn.Write(wbuf.buf); err != nil { + return nil, err + } + + for { + var t byte + var r *msgReader + t, r, err = f.cn.rxMsg() + if err != nil { + return nil, err + } + switch t { + case 'V': // FunctionCallResponse + data := r.readBytes(r.readInt32()) + res = make([]byte, len(data)) + copy(res, data) + case 'Z': // Ready for query + f.cn.rxReadyForQuery(r) + // done + return + default: + if err := f.cn.processContextFreeMsg(t, r); err != nil { + return nil, err + } + } + } +} + +func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) { + return f.Call(f.functionOID(fn), args) +} + +func fpInt32(data []byte, err error) (int32, error) { + if err != nil { + return 0, err + } + n := int32(binary.BigEndian.Uint32(data)) + return n, nil +} + +func fpInt64(data []byte, err error) (int64, error) { + if err != nil { + return 0, err + } + return int64(binary.BigEndian.Uint64(data)), nil +} diff --git a/vendor/github.com/jackc/pgx/helper_test.go b/vendor/github.com/jackc/pgx/helper_test.go new file mode 100644 index 0000000..eff731e --- /dev/null +++ b/vendor/github.com/jackc/pgx/helper_test.go @@ -0,0 +1,74 @@ +package pgx_test + +import ( + "github.com/jackc/pgx" + "testing" +) + +func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { + conn, err := pgx.Connect(config) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + return conn +} + +func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn { + conn, err := pgx.ReplicationConnect(config) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + return conn +} + + +func closeConn(t testing.TB, conn *pgx.Conn) { + err := conn.Close() + if err != nil { + t.Fatalf("conn.Close unexpectedly failed: %v", err) + } +} + +func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) { + err := conn.Close() + if err != nil { + t.Fatalf("conn.Close unexpectedly failed: %v", err) + } +} + +func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { + var err error + if commandTag, err = conn.Exec(sql, arguments...); err != nil { + t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) + } + return +} + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, conn *pgx.Conn) { + var sum, rowCount int32 + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} diff --git a/vendor/github.com/jackc/pgx/hstore.go b/vendor/github.com/jackc/pgx/hstore.go new file mode 100644 index 0000000..0ab9f77 --- /dev/null +++ b/vendor/github.com/jackc/pgx/hstore.go @@ -0,0 +1,222 @@ +package pgx + +import ( + "bytes" + "errors" + "fmt" + "unicode" + "unicode/utf8" +) + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +func parseHstoreToMap(s string) (m map[string]string, err error) { + keys, values, err := ParseHstore(s) + if err != nil { + return + } + m = make(map[string]string, len(keys)) + for i, key := range keys { + if !values[i].Valid { + err = fmt.Errorf("key '%s' has NULL value", key) + m = nil + return + } + m[key] = values[i].String + } + return +} + +func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) { + keys, values, err := ParseHstore(s) + if err != nil { + return + } + + store = make(map[string]NullString, len(keys)) + + for i, key := range keys { + store[key] = values[i] + } + return +} + +// ParseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores, but is exported for use +// in handling custom data structures backed by an hstore column without the +// overhead of creating a map[string]string +func ParseHstore(s string) (k []string, v []NullString, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []NullString{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, NullString{String: buf.String(), Valid: true}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, NullString{String: "", Valid: false}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/vendor/github.com/jackc/pgx/hstore_test.go b/vendor/github.com/jackc/pgx/hstore_test.go new file mode 100644 index 0000000..c948f0c --- /dev/null +++ b/vendor/github.com/jackc/pgx/hstore_test.go @@ -0,0 +1,181 @@ +package pgx_test + +import ( + "github.com/jackc/pgx" + "testing" +) + +func TestHstoreTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type test struct { + hstore pgx.Hstore + description string + } + + tests := []test{ + {pgx.Hstore{}, "empty"}, + {pgx.Hstore{"foo": "bar"}, "single key/value"}, + {pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"}, + {pgx.Hstore{"NULL": "bar"}, `string "NULL" key`}, + {pgx.Hstore{"foo": "NULL"}, `string "NULL" value`}, + } + + specialStringTests := []struct { + input string + description string + }{ + {`"`, `double quote (")`}, + {`'`, `single quote (')`}, + {`\`, `backslash (\)`}, + {`\\`, `multiple backslashes (\\)`}, + {`=>`, `separator (=>)`}, + {` `, `space`}, + {`\ / / \\ => " ' " '`, `multiple special characters`}, + } + for _, sst := range specialStringTests { + tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"}) + tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"}) + tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"}) + tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description}) + + tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"}) + tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"}) + tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"}) + tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description}) + } + + for _, tt := range tests { + var result pgx.Hstore + err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result) + if err != nil { + t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) + } + + for key, inValue := range tt.hstore { + outValue, ok := result[key] + if ok { + if inValue != outValue { + t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue) + } + } else { + t.Errorf(`%s: Missing key %s`, tt.description, key) + } + } + + ensureConnValid(t, conn) + } +} + +func TestNullHstoreTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type test struct { + nullHstore pgx.NullHstore + description string + } + + tests := []test{ + {pgx.NullHstore{}, "null"}, + {pgx.NullHstore{Valid: true}, "empty"}, + {pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, + Valid: true}, + "single key/value"}, + {pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, + Valid: true}, + "multiple key/values"}, + {pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, + Valid: true}, + `string "NULL" key`}, + {pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, + Valid: true}, + `string "NULL" value`}, + {pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, + Valid: true}, + `NULL value`}, + } + + specialStringTests := []struct { + input string + description string + }{ + {`"`, `double quote (")`}, + {`'`, `single quote (')`}, + {`\`, `backslash (\)`}, + {`\\`, `multiple backslashes (\\)`}, + {`=>`, `separator (=>)`}, + {` `, `space`}, + {`\ / / \\ => " ' " '`, `multiple special characters`}, + } + for _, sst := range specialStringTests { + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, + Valid: true}, + "key with " + sst.description + " at beginning"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, + Valid: true}, + "key with " + sst.description + " in middle"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, + Valid: true}, + "key with " + sst.description + " at end"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, + Valid: true}, + "key is " + sst.description}) + + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, + Valid: true}, + "value with " + sst.description + " at beginning"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, + Valid: true}, + "value with " + sst.description + " in middle"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, + Valid: true}, + "value with " + sst.description + " at end"}) + tests = append(tests, test{pgx.NullHstore{ + Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, + Valid: true}, + "value is " + sst.description}) + } + + for _, tt := range tests { + var result pgx.NullHstore + err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result) + if err != nil { + t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) + } + + if result.Valid != tt.nullHstore.Valid { + t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid) + } + + for key, inValue := range tt.nullHstore.Hstore { + outValue, ok := result.Hstore[key] + if ok { + if inValue != outValue { + t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue) + } + } else { + t.Errorf(`%s: Missing key %s`, tt.description, key) + } + } + + ensureConnValid(t, conn) + } +} diff --git a/vendor/github.com/jackc/pgx/large_objects.go b/vendor/github.com/jackc/pgx/large_objects.go new file mode 100644 index 0000000..a4922ef --- /dev/null +++ b/vendor/github.com/jackc/pgx/large_objects.go @@ -0,0 +1,147 @@ +package pgx + +import ( + "io" +) + +// LargeObjects is a structure used to access the large objects API. It is only +// valid within the transaction where it was created. +// +// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html +type LargeObjects struct { + // Has64 is true if the server is capable of working with 64-bit numbers + Has64 bool + fp *fastpath +} + +const largeObjectFns = `select proname, oid from pg_catalog.pg_proc +where proname in ( +'lo_open', +'lo_close', +'lo_create', +'lo_unlink', +'lo_lseek', +'lo_lseek64', +'lo_tell', +'lo_tell64', +'lo_truncate', +'lo_truncate64', +'loread', +'lowrite') +and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')` + +// LargeObjects returns a LargeObjects instance for the transaction. +func (tx *Tx) LargeObjects() (*LargeObjects, error) { + if tx.conn.fp == nil { + tx.conn.fp = newFastpath(tx.conn) + } + if _, exists := tx.conn.fp.fns["lo_open"]; !exists { + res, err := tx.Query(largeObjectFns) + if err != nil { + return nil, err + } + if err := tx.conn.fp.addFunctions(res); err != nil { + return nil, err + } + } + + lo := &LargeObjects{fp: tx.conn.fp} + _, lo.Has64 = lo.fp.fns["lo_lseek64"] + + return lo, nil +} + +type LargeObjectMode int32 + +const ( + LargeObjectModeWrite LargeObjectMode = 0x20000 + LargeObjectModeRead LargeObjectMode = 0x40000 +) + +// Create creates a new large object. If id is zero, the server assigns an +// unused OID. +func (o *LargeObjects) Create(id Oid) (Oid, error) { + newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) + return Oid(newOid), err +} + +// Open opens an existing large object with the given mode. +func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) { + fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) + return &LargeObject{fd: fd, lo: o}, err +} + +// Unlink removes a large object from the database. +func (o *LargeObjects) Unlink(oid Oid) error { + _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) + return err +} + +// A LargeObject is a large object stored on the server. It is only valid within +// the transaction that it was initialized in. It implements these interfaces: +// +// io.Writer +// io.Reader +// io.Seeker +// io.Closer +type LargeObject struct { + fd int32 + lo *LargeObjects +} + +// Write writes p to the large object and returns the number of bytes written +// and an error if not all of p was written. +func (o *LargeObject) Write(p []byte) (int, error) { + n, err := fpInt32(o.lo.fp.CallFn("lowrite", []fpArg{fpIntArg(o.fd), p})) + return int(n), err +} + +// Read reads up to len(p) bytes into p returning the number of bytes read. +func (o *LargeObject) Read(p []byte) (int, error) { + res, err := o.lo.fp.CallFn("loread", []fpArg{fpIntArg(o.fd), fpIntArg(int32(len(p)))}) + if len(res) < len(p) { + err = io.EOF + } + return copy(p, res), err +} + +// Seek moves the current location pointer to the new location specified by offset. +func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) { + if o.lo.Has64 { + n, err = fpInt64(o.lo.fp.CallFn("lo_lseek64", []fpArg{fpIntArg(o.fd), fpInt64Arg(offset), fpIntArg(int32(whence))})) + } else { + var n32 int32 + n32, err = fpInt32(o.lo.fp.CallFn("lo_lseek", []fpArg{fpIntArg(o.fd), fpIntArg(int32(offset)), fpIntArg(int32(whence))})) + n = int64(n32) + } + return +} + +// Tell returns the current read or write location of the large object +// descriptor. +func (o *LargeObject) Tell() (n int64, err error) { + if o.lo.Has64 { + n, err = fpInt64(o.lo.fp.CallFn("lo_tell64", []fpArg{fpIntArg(o.fd)})) + } else { + var n32 int32 + n32, err = fpInt32(o.lo.fp.CallFn("lo_tell", []fpArg{fpIntArg(o.fd)})) + n = int64(n32) + } + return +} + +// Trunctes the large object to size. +func (o *LargeObject) Truncate(size int64) (err error) { + if o.lo.Has64 { + _, err = o.lo.fp.CallFn("lo_truncate64", []fpArg{fpIntArg(o.fd), fpInt64Arg(size)}) + } else { + _, err = o.lo.fp.CallFn("lo_truncate", []fpArg{fpIntArg(o.fd), fpIntArg(int32(size))}) + } + return +} + +// Close closees the large object descriptor. +func (o *LargeObject) Close() error { + _, err := o.lo.fp.CallFn("lo_close", []fpArg{fpIntArg(o.fd)}) + return err +} diff --git a/vendor/github.com/jackc/pgx/large_objects_test.go b/vendor/github.com/jackc/pgx/large_objects_test.go new file mode 100644 index 0000000..a19c851 --- /dev/null +++ b/vendor/github.com/jackc/pgx/large_objects_test.go @@ -0,0 +1,121 @@ +package pgx_test + +import ( + "io" + "testing" + + "github.com/jackc/pgx" +) + +func TestLargeObjects(t *testing.T) { + t.Parallel() + + conn, err := pgx.Connect(*defaultConnConfig) + if err != nil { + t.Fatal(err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + lo, err := tx.LargeObjects() + if err != nil { + t.Fatal(err) + } + + id, err := lo.Create(0) + if err != nil { + t.Fatal(err) + } + + obj, err := lo.Open(id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) + if err != nil { + t.Fatal(err) + } + + n, err := obj.Write([]byte("testing")) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Errorf("Expected n to be 7, got %d", n) + } + + pos, err := obj.Seek(1, 0) + if err != nil { + t.Fatal(err) + } + if pos != 1 { + t.Errorf("Expected pos to be 1, got %d", pos) + } + + res := make([]byte, 6) + n, err = obj.Read(res) + if err != nil { + t.Fatal(err) + } + if string(res) != "esting" { + t.Errorf(`Expected res to be "esting", got %q`, res) + } + if n != 6 { + t.Errorf("Expected n to be 6, got %d", n) + } + + n, err = obj.Read(res) + if err != io.EOF { + t.Error("Expected io.EOF, go nil") + } + if n != 0 { + t.Errorf("Expected n to be 0, got %d", n) + } + + pos, err = obj.Tell() + if err != nil { + t.Fatal(err) + } + if pos != 7 { + t.Errorf("Expected pos to be 7, got %d", pos) + } + + err = obj.Truncate(1) + if err != nil { + t.Fatal(err) + } + + pos, err = obj.Seek(-1, 2) + if err != nil { + t.Fatal(err) + } + if pos != 0 { + t.Errorf("Expected pos to be 0, got %d", pos) + } + + res = make([]byte, 2) + n, err = obj.Read(res) + if err != io.EOF { + t.Errorf("Expected err to be io.EOF, got %v", err) + } + if n != 1 { + t.Errorf("Expected n to be 1, got %d", n) + } + if res[0] != 't' { + t.Errorf("Expected res[0] to be 't', got %v", res[0]) + } + + err = obj.Close() + if err != nil { + t.Fatal(err) + } + + err = lo.Unlink(id) + if err != nil { + t.Fatal(err) + } + + _, err = lo.Open(id, pgx.LargeObjectModeRead) + if e, ok := err.(pgx.PgError); !ok || e.Code != "42704" { + t.Errorf("Expected undefined_object error (42704), got %#v", err) + } +} diff --git a/vendor/github.com/jackc/pgx/logger.go b/vendor/github.com/jackc/pgx/logger.go new file mode 100644 index 0000000..4423325 --- /dev/null +++ b/vendor/github.com/jackc/pgx/logger.go @@ -0,0 +1,81 @@ +package pgx + +import ( + "encoding/hex" + "errors" + "fmt" +) + +// The values for log levels are chosen such that the zero value means that no +// log level was specified and we can default to LogLevelDebug to preserve +// the behavior that existed prior to log level introduction. +const ( + LogLevelTrace = 6 + LogLevelDebug = 5 + LogLevelInfo = 4 + LogLevelWarn = 3 + LogLevelError = 2 + LogLevelNone = 1 +) + +// Logger is the interface used to get logging from pgx internals. +// https://github.com/inconshreveable/log15 is the recommended logging package. +// This logging interface was extracted from there. However, it should be simple +// to adapt any logger to this interface. +type Logger interface { + // Log a message at the given level with context key/value pairs + Debug(msg string, ctx ...interface{}) + Info(msg string, ctx ...interface{}) + Warn(msg string, ctx ...interface{}) + Error(msg string, ctx ...interface{}) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (int, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []interface{}) []interface{} { + logArgs := make([]interface{}, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} diff --git a/vendor/github.com/jackc/pgx/messages.go b/vendor/github.com/jackc/pgx/messages.go new file mode 100644 index 0000000..317ba27 --- /dev/null +++ b/vendor/github.com/jackc/pgx/messages.go @@ -0,0 +1,159 @@ +package pgx + +import ( + "encoding/binary" +) + +const ( + protocolVersionNumber = 196608 // 3.0 +) + +const ( + backendKeyData = 'K' + authenticationX = 'R' + readyForQuery = 'Z' + rowDescription = 'T' + dataRow = 'D' + commandComplete = 'C' + errorResponse = 'E' + noticeResponse = 'N' + parseComplete = '1' + parameterDescription = 't' + bindComplete = '2' + notificationResponse = 'A' + emptyQueryResponse = 'I' + noData = 'n' + closeComplete = '3' + flush = 'H' + copyInResponse = 'G' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' +) + +type startupMessage struct { + options map[string]string +} + +func newStartupMessage() *startupMessage { + return &startupMessage{map[string]string{}} +} + +func (s *startupMessage) Bytes() (buf []byte) { + buf = make([]byte, 8, 128) + binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) + for key, value := range s.options { + buf = append(buf, key...) + buf = append(buf, 0) + buf = append(buf, value...) + buf = append(buf, 0) + } + buf = append(buf, ("\000")...) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) + return buf +} + +type FieldDescription struct { + Name string + Table Oid + AttributeNumber int16 + DataType Oid + DataTypeSize int16 + DataTypeName string + Modifier int32 + FormatCode int16 +} + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +func newWriteBuf(c *Conn, t byte) *WriteBuf { + buf := append(c.wbuf[0:0], t, 0, 0, 0, 0) + c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c} + return &c.writeBuf +} + +// WriteBuf is used build messages to send to the PostgreSQL server. It is used +// by the Encoder interface when implementing custom encoders. +type WriteBuf struct { + buf []byte + sizeIdx int + conn *Conn +} + +func (wb *WriteBuf) startMsg(t byte) { + wb.closeMsg() + wb.buf = append(wb.buf, t, 0, 0, 0, 0) + wb.sizeIdx = len(wb.buf) - 4 +} + +func (wb *WriteBuf) closeMsg() { + binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx)) +} + +func (wb *WriteBuf) WriteByte(b byte) { + wb.buf = append(wb.buf, b) +} + +func (wb *WriteBuf) WriteCString(s string) { + wb.buf = append(wb.buf, []byte(s)...) + wb.buf = append(wb.buf, 0) +} + +func (wb *WriteBuf) WriteInt16(n int16) { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, uint16(n)) + wb.buf = append(wb.buf, b...) +} + +func (wb *WriteBuf) WriteUint16(n uint16) { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + wb.buf = append(wb.buf, b...) +} + +func (wb *WriteBuf) WriteInt32(n int32) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, uint32(n)) + wb.buf = append(wb.buf, b...) +} + +func (wb *WriteBuf) WriteUint32(n uint32) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + wb.buf = append(wb.buf, b...) +} + +func (wb *WriteBuf) WriteInt64(n int64) { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(n)) + wb.buf = append(wb.buf, b...) +} + +func (wb *WriteBuf) WriteBytes(b []byte) { + wb.buf = append(wb.buf, b...) +} diff --git a/vendor/github.com/jackc/pgx/msg_reader.go b/vendor/github.com/jackc/pgx/msg_reader.go new file mode 100644 index 0000000..21db5d2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/msg_reader.go @@ -0,0 +1,316 @@ +package pgx + +import ( + "bufio" + "encoding/binary" + "errors" + "io" +) + +// msgReader is a helper that reads values from a PostgreSQL message. +type msgReader struct { + reader *bufio.Reader + msgBytesRemaining int32 + err error + log func(lvl int, msg string, ctx ...interface{}) + shouldLog func(lvl int) bool +} + +// Err returns any error that the msgReader has experienced +func (r *msgReader) Err() error { + return r.err +} + +// fatal tells rc that a Fatal error has occurred +func (r *msgReader) fatal(err error) { + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) + } + r.err = err +} + +// rxMsg reads the type and size of the next message. +func (r *msgReader) rxMsg() (byte, error) { + if r.err != nil { + return 0, r.err + } + + if r.msgBytesRemaining > 0 { + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) + } + + _, err := r.reader.Discard(int(r.msgBytesRemaining)) + if err != nil { + return 0, err + } + } + + b, err := r.reader.Peek(5) + if err != nil { + r.fatal(err) + return 0, err + } + msgType := b[0] + r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 + r.reader.Discard(5) + return msgType, nil +} + +func (r *msgReader) readByte() byte { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining-- + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.ReadByte() + if err != nil { + r.fatal(err) + return 0 + } + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) + } + + return b +} + +func (r *msgReader) readInt16() int16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := int16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readInt32() int32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(4) + if err != nil { + r.fatal(err) + return 0 + } + + n := int32(binary.BigEndian.Uint32(b)) + + r.reader.Discard(4) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readUint16() uint16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readUint32() uint32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(4) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint32(binary.BigEndian.Uint32(b)) + + r.reader.Discard(4) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readInt64() int64 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 8 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(8) + if err != nil { + r.fatal(err) + return 0 + } + + n := int64(binary.BigEndian.Uint64(b)) + + r.reader.Discard(8) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + +func (r *msgReader) readOid() Oid { + return Oid(r.readInt32()) +} + +// readCString reads a null terminated string +func (r *msgReader) readCString() string { + if r.err != nil { + return "" + } + + b, err := r.reader.ReadBytes(0) + if err != nil { + r.fatal(err) + return "" + } + + r.msgBytesRemaining -= int32(len(b)) + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return "" + } + + s := string(b[0 : len(b)-1]) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + } + + return s +} + +// readString reads count bytes and returns as string +func (r *msgReader) readString(countI32 int32) string { + if r.err != nil { + return "" + } + + r.msgBytesRemaining -= countI32 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return "" + } + + count := int(countI32) + var s string + + if r.reader.Buffered() >= count { + buf, _ := r.reader.Peek(count) + s = string(buf) + r.reader.Discard(count) + } else { + buf := make([]byte, count) + _, err := io.ReadFull(r.reader, buf) + if err != nil { + r.fatal(err) + return "" + } + s = string(buf) + } + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + } + + return s +} + +// readBytes reads count bytes and returns as []byte +func (r *msgReader) readBytes(count int32) []byte { + if r.err != nil { + return nil + } + + r.msgBytesRemaining -= count + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return nil + } + + b := make([]byte, int(count)) + + _, err := io.ReadFull(r.reader, b) + if err != nil { + r.fatal(err) + return nil + } + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) + } + + return b +} diff --git a/vendor/github.com/jackc/pgx/pgpass.go b/vendor/github.com/jackc/pgx/pgpass.go new file mode 100644 index 0000000..b6f028d --- /dev/null +++ b/vendor/github.com/jackc/pgx/pgpass.go @@ -0,0 +1,85 @@ +package pgx + +import ( + "bufio" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" +) + +func parsepgpass(cfg *ConnConfig, line string) *string { + const ( + backslash = "\r" + colon = "\n" + ) + const ( + host int = iota + port + database + username + pw + ) + line = strings.Replace(line, `\:`, colon, -1) + line = strings.Replace(line, `\\`, backslash, -1) + parts := strings.Split(line, `:`) + if len(parts) != 5 { + return nil + } + for i := range parts { + if parts[i] == `*` { + continue + } + parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) + switch i { + case host: + if parts[i] != cfg.Host { + return nil + } + case port: + portstr := fmt.Sprintf(`%v`, cfg.Port) + if portstr == "0" { + portstr = "5432" + } + if parts[i] != portstr { + return nil + } + case database: + if parts[i] != cfg.Database { + return nil + } + case username: + if parts[i] != cfg.User { + return nil + } + } + } + return &parts[4] +} + +func pgpass(cfg *ConnConfig) (found bool) { + passfile := os.Getenv("PGPASSFILE") + if passfile == "" { + u, err := user.Current() + if err != nil { + return + } + passfile = filepath.Join(u.HomeDir, ".pgpass") + } + f, err := os.Open(passfile) + if err != nil { + return + } + defer f.Close() + scanner := bufio.NewScanner(f) + var pw *string + for scanner.Scan() { + pw = parsepgpass(cfg, scanner.Text()) + if pw != nil { + cfg.Password = *pw + return true + } + } + return false +} diff --git a/vendor/github.com/jackc/pgx/pgpass_test.go b/vendor/github.com/jackc/pgx/pgpass_test.go new file mode 100644 index 0000000..f6094c8 --- /dev/null +++ b/vendor/github.com/jackc/pgx/pgpass_test.go @@ -0,0 +1,57 @@ +package pgx + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func unescape(s string) string { + s = strings.Replace(s, `\:`, `:`, -1) + s = strings.Replace(s, `\\`, `\`, -1) + return s +} + +var passfile = [][]string{ + []string{"test1", "5432", "larrydb", "larry", "whatstheidea"}, + []string{"test1", "5432", "moedb", "moe", "imbecile"}, + []string{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, + []string{"test2", "5432", "*", "shemp", "heymoe"}, + []string{"test2", "5432", "*", "*", `test\\ing\:`}, +} + +func TestPGPass(t *testing.T) { + tf, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer tf.Close() + defer os.Remove(tf.Name()) + os.Setenv("PGPASSFILE", tf.Name()) + for _, l := range passfile { + _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) + if err != nil { + t.Fatal(err) + } + } + if err = tf.Close(); err != nil { + t.Fatal(err) + } + for i, l := range passfile { + cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]} + found := pgpass(&cfg) + if !found { + t.Fatalf("Entry %v not found", i) + } + if cfg.Password != unescape(l[4]) { + t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) + } + } + cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"} + found := pgpass(&cfg) + if found { + t.Fatal("bad found") + } +} diff --git a/vendor/github.com/jackc/pgx/query.go b/vendor/github.com/jackc/pgx/query.go new file mode 100644 index 0000000..19b867e --- /dev/null +++ b/vendor/github.com/jackc/pgx/query.go @@ -0,0 +1,494 @@ +package pgx + +import ( + "database/sql" + "errors" + "fmt" + "time" +) + +// Row is a convenience wrapper over Rows that is returned by QueryRow. +type Row Rows + +// Scan works the same as (*Rows Scan) with the following exceptions. If no +// rows were found it returns ErrNoRows. If multiple rows are returned it +// ignores all but the first. +func (r *Row) Scan(dest ...interface{}) (err error) { + rows := (*Rows)(r) + + if rows.Err() != nil { + return rows.Err() + } + + if !rows.Next() { + if rows.Err() == nil { + return ErrNoRows + } + return rows.Err() + } + + rows.Scan(dest...) + rows.Close() + return rows.Err() +} + +// Rows is the result set returned from *Conn.Query. Rows must be closed before +// the *Conn can be used again. Rows are closed by explicitly calling Close(), +// calling Next() until it returns false, or when a fatal error occurs. +type Rows struct { + conn *Conn + mr *msgReader + fields []FieldDescription + vr ValueReader + rowCount int + columnIdx int + err error + startTime time.Time + sql string + args []interface{} + afterClose func(*Rows) + unlockConn bool + closed bool +} + +func (rows *Rows) FieldDescriptions() []FieldDescription { + return rows.fields +} + +func (rows *Rows) close() { + if rows.closed { + return + } + + if rows.unlockConn { + rows.conn.unlock() + rows.unlockConn = false + } + + rows.closed = true + + if rows.err == nil { + if rows.conn.shouldLog(LogLevelInfo) { + endTime := time.Now() + rows.conn.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) + } + } else if rows.conn.shouldLog(LogLevelError) { + rows.conn.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) + } + + if rows.afterClose != nil { + rows.afterClose(rows) + } +} + +func (rows *Rows) readUntilReadyForQuery() { + for { + t, r, err := rows.conn.rxMsg() + if err != nil { + rows.close() + return + } + + switch t { + case readyForQuery: + rows.conn.rxReadyForQuery(r) + rows.close() + return + case rowDescription: + case dataRow: + case commandComplete: + case bindComplete: + case errorResponse: + err = rows.conn.rxErrorResponse(r) + if rows.err == nil { + rows.err = err + } + default: + err = rows.conn.processContextFreeMsg(t, r) + if err != nil { + rows.close() + return + } + } + } +} + +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. +func (rows *Rows) Close() { + if rows.closed { + return + } + rows.readUntilReadyForQuery() + rows.close() +} + +func (rows *Rows) Err() error { + return rows.err +} + +// abort signals that the query was not successfully sent to the server. +// This differs from Fatal in that it is not necessary to readUntilReadyForQuery +func (rows *Rows) abort(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.close() +} + +// Fatal signals an error occurred after the query was sent to the server. It +// closes the rows automatically. +func (rows *Rows) Fatal(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.Close() +} + +// Next prepares the next row for reading. It returns true if there is another +// row and false if no more rows are available. It automatically closes rows +// when all rows are read. +func (rows *Rows) Next() bool { + if rows.closed { + return false + } + + rows.rowCount++ + rows.columnIdx = 0 + rows.vr = ValueReader{} + + for { + t, r, err := rows.conn.rxMsg() + if err != nil { + rows.Fatal(err) + return false + } + + switch t { + case readyForQuery: + rows.conn.rxReadyForQuery(r) + rows.close() + return false + case dataRow: + fieldCount := r.readInt16() + if int(fieldCount) != len(rows.fields) { + rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) + return false + } + + rows.mr = r + return true + case commandComplete: + case bindComplete: + default: + err = rows.conn.processContextFreeMsg(t, r) + if err != nil { + rows.Fatal(err) + return false + } + } + } +} + +// Conn returns the *Conn this *Rows is using. +func (rows *Rows) Conn() *Conn { + return rows.conn +} + +func (rows *Rows) nextColumn() (*ValueReader, bool) { + if rows.closed { + return nil, false + } + if len(rows.fields) <= rows.columnIdx { + rows.Fatal(ProtocolError("No next column available")) + return nil, false + } + + if rows.vr.Len() > 0 { + rows.mr.readBytes(rows.vr.Len()) + } + + fd := &rows.fields[rows.columnIdx] + rows.columnIdx++ + size := rows.mr.readInt32() + rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} + return &rows.vr, true +} + +type scanArgError struct { + col int + err error +} + +func (e scanArgError) Error() string { + return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) +} + +// Scan reads the values from the current row into dest values positionally. +// dest can include pointers to core types, values implementing the Scanner +// interface, []byte, and nil. []byte will skip the decoding process and directly +// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. +func (rows *Rows) Scan(dest ...interface{}) (err error) { + if len(rows.fields) != len(dest) { + err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) + rows.Fatal(err) + return err + } + + for i, d := range dest { + vr, _ := rows.nextColumn() + + if d == nil { + continue + } + + // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes + if b, ok := d.(*[]byte); ok { + // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) + // Otherwise read the bytes directly regardless of what the actual type is. + if vr.Type().DataType == ByteaOid { + *b = decodeBytea(vr) + } else { + if vr.Len() != -1 { + *b = vr.ReadBytes(vr.Len()) + } else { + *b = nil + } + } + } else if s, ok := d.(Scanner); ok { + err = s.Scan(vr) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(PgxScanner); ok { + err = s.ScanPgx(vr) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(sql.Scanner); ok { + var val interface{} + if 0 <= vr.Len() { + switch vr.Type().DataType { + case BoolOid: + val = decodeBool(vr) + case Int8Oid: + val = int64(decodeInt8(vr)) + case Int2Oid: + val = int64(decodeInt2(vr)) + case Int4Oid: + val = int64(decodeInt4(vr)) + case TextOid, VarcharOid: + val = decodeText(vr) + case OidOid: + val = int64(decodeOid(vr)) + case Float4Oid: + val = float64(decodeFloat4(vr)) + case Float8Oid: + val = decodeFloat8(vr) + case DateOid: + val = decodeDate(vr) + case TimestampOid: + val = decodeTimestamp(vr) + case TimestampTzOid: + val = decodeTimestampTz(vr) + default: + val = vr.ReadBytes(vr.Len()) + } + } + err = s.Scan(val) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if vr.Type().DataType == JsonOid { + // Because the argument passed to decodeJSON will escape the heap. + // This allows d to be stack allocated and only copied to the heap when + // we actually are decoding JSON. This saves one memory allocation per + // row. + d2 := d + decodeJSON(vr, &d2) + } else if vr.Type().DataType == JsonbOid { + // Same trick as above for getting stack allocation + d2 := d + decodeJSONB(vr, &d2) + } else { + if err := Decode(vr, d); err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } + if vr.Err() != nil { + rows.Fatal(scanArgError{col: i, err: vr.Err()}) + } + + if rows.Err() != nil { + return rows.Err() + } + } + + return nil +} + +// Values returns an array of the row values +func (rows *Rows) Values() ([]interface{}, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]interface{}, 0, len(rows.fields)) + + for range rows.fields { + vr, _ := rows.nextColumn() + + if vr.Len() == -1 { + values = append(values, nil) + continue + } + + switch vr.Type().FormatCode { + // All intrinsic types (except string) are encoded with binary + // encoding so anything else should be treated as a string + case TextFormatCode: + values = append(values, vr.ReadString(vr.Len())) + case BinaryFormatCode: + switch vr.Type().DataType { + case TextOid, VarcharOid: + values = append(values, decodeText(vr)) + case BoolOid: + values = append(values, decodeBool(vr)) + case ByteaOid: + values = append(values, decodeBytea(vr)) + case Int8Oid: + values = append(values, decodeInt8(vr)) + case Int2Oid: + values = append(values, decodeInt2(vr)) + case Int4Oid: + values = append(values, decodeInt4(vr)) + case OidOid: + values = append(values, decodeOid(vr)) + case Float4Oid: + values = append(values, decodeFloat4(vr)) + case Float8Oid: + values = append(values, decodeFloat8(vr)) + case BoolArrayOid: + values = append(values, decodeBoolArray(vr)) + case Int2ArrayOid: + values = append(values, decodeInt2Array(vr)) + case Int4ArrayOid: + values = append(values, decodeInt4Array(vr)) + case Int8ArrayOid: + values = append(values, decodeInt8Array(vr)) + case Float4ArrayOid: + values = append(values, decodeFloat4Array(vr)) + case Float8ArrayOid: + values = append(values, decodeFloat8Array(vr)) + case TextArrayOid, VarcharArrayOid: + values = append(values, decodeTextArray(vr)) + case TimestampArrayOid, TimestampTzArrayOid: + values = append(values, decodeTimestampArray(vr)) + case DateOid: + values = append(values, decodeDate(vr)) + case TimestampTzOid: + values = append(values, decodeTimestampTz(vr)) + case TimestampOid: + values = append(values, decodeTimestamp(vr)) + case InetOid, CidrOid: + values = append(values, decodeInet(vr)) + case JsonOid: + var d interface{} + decodeJSON(vr, &d) + values = append(values, d) + case JsonbOid: + var d interface{} + decodeJSONB(vr, &d) + values = append(values, d) + default: + rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + } + default: + rows.Fatal(errors.New("Unknown format code")) + } + + if vr.Err() != nil { + rows.Fatal(vr.Err()) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + } + + return values, rows.Err() +} + +// AfterClose adds f to a LILO queue of functions that will be called when +// rows is closed. +func (rows *Rows) AfterClose(f func(*Rows)) { + if rows.afterClose == nil { + rows.afterClose = f + } else { + prevFn := rows.afterClose + rows.afterClose = func(rows *Rows) { + f(rows) + prevFn(rows) + } + } +} + +// Query executes sql with args. If there is an error the returned *Rows will +// be returned in an error state. So it is allowed to ignore the error returned +// from Query and handle it in *Rows. +func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { + c.lastActivityTime = time.Now() + + rows := c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.abort(err) + return rows, err + } + rows.unlockConn = true + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.Prepare("", sql) + if err != nil { + rows.abort(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + err := c.sendPreparedQuery(ps, args...) + if err != nil { + rows.abort(err) + } + return rows, rows.err +} + +func (c *Conn) getRows(sql string, args []interface{}) *Rows { + if len(c.preallocatedRows) == 0 { + c.preallocatedRows = make([]Rows, 64) + } + + r := &c.preallocatedRows[len(c.preallocatedRows)-1] + c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] + + r.conn = c + r.startTime = c.lastActivityTime + r.sql = sql + r.args = args + + return r +} + +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned *Row. That *Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { + rows, _ := c.Query(sql, args...) + return (*Row)(rows) +} diff --git a/vendor/github.com/jackc/pgx/query_test.go b/vendor/github.com/jackc/pgx/query_test.go new file mode 100644 index 0000000..f08887b --- /dev/null +++ b/vendor/github.com/jackc/pgx/query_test.go @@ -0,0 +1,1414 @@ +package pgx_test + +import ( + "bytes" + "database/sql" + "fmt" + "strings" + "testing" + "time" + + "github.com/jackc/pgx" + + "github.com/shopspring/decimal" +) + +func TestConnQueryScan(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var sum, rowCount int32 + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} + +func TestConnQueryValues(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var rowCount int32 + + rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n::oid from generate_series(1,$1) n", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + rowCount++ + + values, err := rows.Values() + if err != nil { + t.Fatalf("rows.Values failed: %v", err) + } + if len(values) != 5 { + t.Errorf("Expected rows.Values to return 5 values, but it returned %d", len(values)) + } + if values[0] != "foo" { + t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) + } + if values[1] != "bar" { + t.Errorf(`Expected values[1] to be "bar", but it was %v`, values[1]) + } + + if values[2] != rowCount { + t.Errorf(`Expected values[2] to be %d, but it was %d`, rowCount, values[2]) + } + + if values[3] != nil { + t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) + } + + if values[4] != pgx.Oid(rowCount) { + t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) + } + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } +} + +// Test that a connection stays valid when query results are closed early +func TestConnQueryCloseEarly(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Immediately close query without reading any rows + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + rows.Close() + + ensureConnValid(t, conn) + + // Read partial response then close + rows, err = conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n int32 + rows.Scan(&n) + if n != 1 { + t.Fatalf("Expected 1 from first row, but got %v", n) + } + + rows.Close() + + ensureConnValid(t, conn) +} + +func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select 1/(10-n) from generate_series(1,10) n") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + rows.Close() + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadWrongTypeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read a single value incorrectly + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + rowsRead := 0 + + for rows.Next() { + var t time.Time + rows.Scan(&t) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if rows.Err() == nil { + t.Fatal("Expected Rows to have an error after an improper read but it didn't") + } + + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" { + t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) + } + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadTooManyValues(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read too many values + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + rowsRead := 0 + + for rows.Next() { + var n, m int32 + rows.Scan(&n, &m) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if rows.Err() == nil { + t.Fatal("Expected Rows to have an error after an improper read but it didn't") + } + + ensureConnValid(t, conn) +} + +func TestConnQueryScanIgnoreColumn(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select 1::int8, 2::int8, 3::int8") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m int64 + err = rows.Scan(&n, nil, &m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if n != 1 { + t.Errorf("Expected n to equal 1, but it was %d", n) + } + + if m != 3 { + t.Errorf("Expected n to equal 3, but it was %d", m) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select null::int8, 1::int8") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m pgx.NullInt64 + err = rows.Scan(&n, &m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +type pgxNullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error { + if vr.Type().DataType != pgx.Int8Oid { + return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + + err := pgx.Decode(vr, &n.Int64) + if err != nil { + return err + } + return vr.Err() +} + +func TestConnQueryPgxScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select null::int8, 1::int8") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m pgxNullInt64 + err = rows.Scan(&n, &m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryErrorWhileReturningRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for i := 0; i < 100; i++ { + func() { + sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)` + + rows, err := conn.Query(sql) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + } + + if err, ok := rows.Err().(pgx.PgError); !ok { + t.Fatalf("Expected pgx.PgError, got %v", err) + } + + ensureConnValid(t, conn) + }() + } + +} + +func TestConnQueryEncoder(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + n := pgx.NullInt64{Int64: 1, Valid: true} + + rows, err := conn.Query("select $1::int8", &n) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var m pgx.NullInt64 + err = rows.Scan(&m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if !m.Valid { + t.Error("m should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestQueryEncodeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select $1::integer", "wrong") + if err != nil { + t.Errorf("conn.Query failure: %v", err) + } + defer rows.Close() + + rows.Next() + + if rows.Err() == nil { + t.Error("Expected rows.Err() to return error, but it didn't") + } + if rows.Err().Error() != `ERROR: invalid input syntax for integer: "wrong" (SQLSTATE 22P02)` { + t.Error("Expected rows.Err() to return different error:", rows.Err()) + } +} + +// Ensure that an argument that implements Encoder works when the parameter type +// is a core type. +type coreEncoder struct{} + +func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } + +func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { + w.WriteInt32(int32(2)) + w.WriteBytes([]byte("42")) + return nil +} + +func TestQueryEncodeCoreTextFormatError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var n int32 + err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n) + if err != nil { + t.Fatalf("Unexpected conn.QueryRow error: %v", err) + } + + if n != 42 { + t.Errorf("Expected 42, got %v", n) + } +} + +func TestQueryRowCoreTypes(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s string + f32 float32 + f64 float64 + b bool + t time.Time + oid pgx.Oid + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, + {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, + {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, + {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, + {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + + // Check that Scan errors when a core type is null + err = conn.QueryRow(tt.sql, nil).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowCoreIntegerEncoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + ui uint + ui8 uint8 + ui16 uint16 + ui32 uint32 + ui64 uint64 + i int + i8 int8 + i16 int16 + i32 int32 + i64 int64 + } + + var actual, zero allTypes + + successfulEncodeTests := []struct { + sql string + queryArg interface{} + scanArg interface{} + expected allTypes + }{ + // Check any integer type where value is within int2 range can be encoded + {"select $1::int2", int(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", int8(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", int16(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", int32(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", int64(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", uint(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", uint8(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", uint16(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", uint32(42), &actual.i16, allTypes{i16: 42}}, + {"select $1::int2", uint64(42), &actual.i16, allTypes{i16: 42}}, + + // Check any integer type where value is within int4 range can be encoded + {"select $1::int4", int(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", int8(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", int16(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", int32(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", int64(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", uint(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", uint8(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", uint16(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", uint32(42), &actual.i32, allTypes{i32: 42}}, + {"select $1::int4", uint64(42), &actual.i32, allTypes{i32: 42}}, + + // Check any integer type where value is within int8 range can be encoded + {"select $1::int8", int(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", int8(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", int16(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", int32(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", int64(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", uint(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", uint8(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", uint16(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", uint32(42), &actual.i64, allTypes{i64: 42}}, + {"select $1::int8", uint64(42), &actual.i64, allTypes{i64: 42}}, + } + + for i, tt := range successfulEncodeTests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArg).Scan(tt.scanArg) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) + continue + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArg -> %v)", i, tt.expected, actual, tt.sql, tt.queryArg) + } + + ensureConnValid(t, conn) + } + + failedEncodeTests := []struct { + sql string + queryArg interface{} + }{ + // Check any integer type where value is outside pg:int2 range cannot be encoded + {"select $1::int2", int(32769)}, + {"select $1::int2", int32(32769)}, + {"select $1::int2", int32(32769)}, + {"select $1::int2", int64(32769)}, + {"select $1::int2", uint(32769)}, + {"select $1::int2", uint16(32769)}, + {"select $1::int2", uint32(32769)}, + {"select $1::int2", uint64(32769)}, + + // Check any integer type where value is outside pg:int4 range cannot be encoded + {"select $1::int4", int64(2147483649)}, + {"select $1::int4", uint32(2147483649)}, + {"select $1::int4", uint64(2147483649)}, + + // Check any integer type where value is outside pg:int8 range cannot be encoded + {"select $1::int8", uint64(9223372036854775809)}, + } + + for i, tt := range failedEncodeTests { + err := conn.QueryRow(tt.sql, tt.queryArg).Scan(nil) + if err == nil { + t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) + } else if !strings.Contains(err.Error(), "is greater than") { + t.Errorf("%d. Expected failure to encode, but got: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowCoreIntegerDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + ui uint + ui8 uint8 + ui16 uint16 + ui32 uint32 + ui64 uint64 + i int + i8 int8 + i16 int16 + i32 int32 + i64 int64 + } + + var actual, zero allTypes + + successfulDecodeTests := []struct { + sql string + scanArg interface{} + expected allTypes + }{ + // Check any integer type where value is within Go:int range can be decoded + {"select 42::int2", &actual.i, allTypes{i: 42}}, + {"select 42::int4", &actual.i, allTypes{i: 42}}, + {"select 42::int8", &actual.i, allTypes{i: 42}}, + {"select -42::int2", &actual.i, allTypes{i: -42}}, + {"select -42::int4", &actual.i, allTypes{i: -42}}, + {"select -42::int8", &actual.i, allTypes{i: -42}}, + + // Check any integer type where value is within Go:int8 range can be decoded + {"select 42::int2", &actual.i8, allTypes{i8: 42}}, + {"select 42::int4", &actual.i8, allTypes{i8: 42}}, + {"select 42::int8", &actual.i8, allTypes{i8: 42}}, + {"select -42::int2", &actual.i8, allTypes{i8: -42}}, + {"select -42::int4", &actual.i8, allTypes{i8: -42}}, + {"select -42::int8", &actual.i8, allTypes{i8: -42}}, + + // Check any integer type where value is within Go:int16 range can be decoded + {"select 42::int2", &actual.i16, allTypes{i16: 42}}, + {"select 42::int4", &actual.i16, allTypes{i16: 42}}, + {"select 42::int8", &actual.i16, allTypes{i16: 42}}, + {"select -42::int2", &actual.i16, allTypes{i16: -42}}, + {"select -42::int4", &actual.i16, allTypes{i16: -42}}, + {"select -42::int8", &actual.i16, allTypes{i16: -42}}, + + // Check any integer type where value is within Go:int32 range can be decoded + {"select 42::int2", &actual.i32, allTypes{i32: 42}}, + {"select 42::int4", &actual.i32, allTypes{i32: 42}}, + {"select 42::int8", &actual.i32, allTypes{i32: 42}}, + {"select -42::int2", &actual.i32, allTypes{i32: -42}}, + {"select -42::int4", &actual.i32, allTypes{i32: -42}}, + {"select -42::int8", &actual.i32, allTypes{i32: -42}}, + + // Check any integer type where value is within Go:int64 range can be decoded + {"select 42::int2", &actual.i64, allTypes{i64: 42}}, + {"select 42::int4", &actual.i64, allTypes{i64: 42}}, + {"select 42::int8", &actual.i64, allTypes{i64: 42}}, + {"select -42::int2", &actual.i64, allTypes{i64: -42}}, + {"select -42::int4", &actual.i64, allTypes{i64: -42}}, + {"select -42::int8", &actual.i64, allTypes{i64: -42}}, + + // Check any integer type where value is within Go:uint range can be decoded + {"select 128::int2", &actual.ui, allTypes{ui: 128}}, + {"select 128::int4", &actual.ui, allTypes{ui: 128}}, + {"select 128::int8", &actual.ui, allTypes{ui: 128}}, + + // Check any integer type where value is within Go:uint8 range can be decoded + {"select 128::int2", &actual.ui8, allTypes{ui8: 128}}, + {"select 128::int4", &actual.ui8, allTypes{ui8: 128}}, + {"select 128::int8", &actual.ui8, allTypes{ui8: 128}}, + + // Check any integer type where value is within Go:uint16 range can be decoded + {"select 42::int2", &actual.ui16, allTypes{ui16: 42}}, + {"select 32768::int4", &actual.ui16, allTypes{ui16: 32768}}, + {"select 32768::int8", &actual.ui16, allTypes{ui16: 32768}}, + + // Check any integer type where value is within Go:uint32 range can be decoded + {"select 42::int2", &actual.ui32, allTypes{ui32: 42}}, + {"select 42::int4", &actual.ui32, allTypes{ui32: 42}}, + {"select 2147483648::int8", &actual.ui32, allTypes{ui32: 2147483648}}, + + // Check any integer type where value is within Go:uint64 range can be decoded + {"select 42::int2", &actual.ui64, allTypes{ui64: 42}}, + {"select 42::int4", &actual.ui64, allTypes{ui64: 42}}, + {"select 42::int8", &actual.ui64, allTypes{ui64: 42}}, + } + + for i, tt := range successfulDecodeTests { + actual = zero + + err := conn.QueryRow(tt.sql).Scan(tt.scanArg) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) + continue + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + + failedDecodeTests := []struct { + sql string + scanArg interface{} + expectedErr string + }{ + // Check any integer type where value is outside Go:int8 range cannot be decoded + {"select 128::int2", &actual.i8, "is greater than"}, + {"select 128::int4", &actual.i8, "is greater than"}, + {"select 128::int8", &actual.i8, "is greater than"}, + {"select -129::int2", &actual.i8, "is less than"}, + {"select -129::int4", &actual.i8, "is less than"}, + {"select -129::int8", &actual.i8, "is less than"}, + + // Check any integer type where value is outside Go:int16 range cannot be decoded + {"select 32768::int4", &actual.i16, "is greater than"}, + {"select 32768::int8", &actual.i16, "is greater than"}, + {"select -32769::int4", &actual.i16, "is less than"}, + {"select -32769::int8", &actual.i16, "is less than"}, + + // Check any integer type where value is outside Go:int32 range cannot be decoded + {"select 2147483648::int8", &actual.i32, "is greater than"}, + {"select -2147483649::int8", &actual.i32, "is less than"}, + + // Check any integer type where value is outside Go:uint range cannot be decoded + {"select -1::int2", &actual.ui, "is less than"}, + {"select -1::int4", &actual.ui, "is less than"}, + {"select -1::int8", &actual.ui, "is less than"}, + + // Check any integer type where value is outside Go:uint8 range cannot be decoded + {"select 256::int2", &actual.ui8, "is greater than"}, + {"select 256::int4", &actual.ui8, "is greater than"}, + {"select 256::int8", &actual.ui8, "is greater than"}, + {"select -1::int2", &actual.ui8, "is less than"}, + {"select -1::int4", &actual.ui8, "is less than"}, + {"select -1::int8", &actual.ui8, "is less than"}, + + // Check any integer type where value is outside Go:uint16 cannot be decoded + {"select 65536::int4", &actual.ui16, "is greater than"}, + {"select 65536::int8", &actual.ui16, "is greater than"}, + {"select -1::int2", &actual.ui16, "is less than"}, + {"select -1::int4", &actual.ui16, "is less than"}, + {"select -1::int8", &actual.ui16, "is less than"}, + + // Check any integer type where value is outside Go:uint32 range cannot be decoded + {"select 4294967296::int8", &actual.ui32, "is greater than"}, + {"select -1::int2", &actual.ui32, "is less than"}, + {"select -1::int4", &actual.ui32, "is less than"}, + {"select -1::int8", &actual.ui32, "is less than"}, + + // Check any integer type where value is outside Go:uint64 range cannot be decoded + {"select -1::int2", &actual.ui64, "is less than"}, + {"select -1::int4", &actual.ui64, "is less than"}, + {"select -1::int8", &actual.ui64, "is less than"}, + } + + for i, tt := range failedDecodeTests { + err := conn.QueryRow(tt.sql).Scan(tt.scanArg) + if err == nil { + t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) + } else if !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowCoreByteSlice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + queryArg interface{} + expected []byte + }{ + {"select $1::text", "Jack", []byte("Jack")}, + {"select $1::text", []byte("Jack"), []byte("Jack")}, + {"select $1::int4", int32(239023409), []byte{14, 63, 53, 49}}, + {"select $1::varchar", []byte("Jack"), []byte("Jack")}, + {"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}}, + } + + for i, tt := range tests { + var actual []byte + + err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) + } + + if !bytes.Equal(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowByteSliceArgument(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::int4" + queryArg := []byte{14, 63, 53, 49} + expected := int32(239023409) + + var actual int32 + + err := conn.QueryRow(sql, queryArg).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if expected != actual { + t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowUnknownType(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::point" + expected := "(1,0)" + var actual string + + err := conn.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + + } + + ensureConnValid(t, conn) +} + +func TestQueryRowErrors(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + i16 int16 + i int + s string + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"}, + {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into any integer type"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode int8 into oid 600"}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) + } + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowNoResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var n int32 + err := conn.QueryRow("select 1 where 1=0").Scan(&n) + if err != pgx.ErrNoRows { + t.Errorf("Expected pgx.ErrNoRows, got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreInt16Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int16 + + tests := []struct { + sql string + expected []int16 + }{ + {"select $1::int2[]", []int16{1, 2, 3, 4, 5}}, + {"select $1::int2[]", []int16{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreInt32Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int32 + + tests := []struct { + sql string + expected []int32 + }{ + {"select $1::int4[]", []int32{1, 2, 3, 4, 5}}, + {"select $1::int4[]", []int32{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreInt64Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []int64 + + tests := []struct { + sql string + expected []int64 + }{ + {"select $1::int8[]", []int64{1, 2, 3, 4, 5}}, + {"select $1::int8[]", []int64{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreFloat32Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []float32 + + tests := []struct { + sql string + expected []float32 + }{ + {"select $1::float4[]", []float32{1.5, 2.0, 3.5}}, + {"select $1::float4[]", []float32{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreFloat64Slice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []float64 + + tests := []struct { + sql string + expected []float64 + }{ + {"select $1::float8[]", []float64{1.5, 2.0, 3.5}}, + {"select $1::float8[]", []float64{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowCoreStringSlice(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []string + + tests := []struct { + sql string + expected []string + }{ + {"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}}, + {"select $1::text[]", []string{}}, + {"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}}, + {"select $1::varchar[]", []string{}}, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v", i, err) + } + + if len(actual) != len(tt.expected) { + t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) + } + + for j := 0; j < len(actual); j++ { + if actual[j] != tt.expected[j] { + t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) + } + } + + ensureConnValid(t, conn) + } + + // Check that Scan errors when an array with a null is scanned into a core slice type + err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual) + if err == nil { + t.Error("Expected null to cause error when scanned into slice, but it didn't") + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) + } + + ensureConnValid(t, conn) +} + +func TestReadingValueAfterEmptyArray(t *testing.T) { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var a []string + var b int32 + err := conn.QueryRow("select '{}'::text[], 42::integer").Scan(&a, &b) + if err != nil { + t.Fatalf("conn.QueryRow failed: %v", err) + } + + if len(a) != 0 { + t.Errorf("Expected 'a' to have length 0, but it was: %d", len(a)) + } + + if b != 42 { + t.Errorf("Expected 'b' to 42, but it was: %d", b) + } +} + +func TestReadingNullByteArray(t *testing.T) { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var a []byte + err := conn.QueryRow("select null::text").Scan(&a) + if err != nil { + t.Fatalf("conn.QueryRow failed: %v", err) + } + + if a != nil { + t.Errorf("Expected 'a' to be nil, but it was: %v", a) + } +} + +func TestReadingNullByteArrays(t *testing.T) { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select null::text union all select null::text") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + count := 0 + for rows.Next() { + count++ + var a []byte + if err := rows.Scan(&a); err != nil { + t.Fatalf("failed to scan row: %v", err) + } + if a != nil { + t.Errorf("Expected 'a' to be nil, but it was: %v", a) + } + } + if count != 2 { + t.Errorf("Expected to read 2 rows, read: %d", count) + } +} + +// Use github.com/shopspring/decimal as real-world database/sql custom type +// to test against. +func TestConnQueryDatabaseSQLScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var num decimal.Decimal + + err := conn.QueryRow("select '1234.567'::decimal").Scan(&num) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + expected, err := decimal.NewFromString("1234.567") + if err != nil { + t.Fatal(err) + } + + if !num.Equals(expected) { + t.Errorf("Expected num to be %v, but it was %v", expected, num) + } + + ensureConnValid(t, conn) +} + +// Use github.com/shopspring/decimal as real-world database/sql custom type +// to test against. +func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + expected, err := decimal.NewFromString("1234.567") + if err != nil { + t.Fatal(err) + } + var num decimal.Decimal + + err = conn.QueryRow("select $1::decimal", &expected).Scan(&num) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + if !num.Equals(expected) { + t.Errorf("Expected num to be %v, but it was %v", expected, num) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryDatabaseSQLNullX(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type row struct { + boolValid sql.NullBool + boolNull sql.NullBool + int64Valid sql.NullInt64 + int64Null sql.NullInt64 + float64Valid sql.NullFloat64 + float64Null sql.NullFloat64 + stringValid sql.NullString + stringNull sql.NullString + } + + expected := row{ + boolValid: sql.NullBool{Bool: true, Valid: true}, + int64Valid: sql.NullInt64{Int64: 123, Valid: true}, + float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true}, + stringValid: sql.NullString{String: "pgx", Valid: true}, + } + + var actual row + + err := conn.QueryRow( + "select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text", + expected.boolValid, + expected.boolNull, + expected.int64Valid, + expected.int64Null, + expected.float64Valid, + expected.float64Null, + expected.stringValid, + expected.stringNull, + ).Scan( + &actual.boolValid, + &actual.boolNull, + &actual.int64Valid, + &actual.int64Null, + &actual.float64Valid, + &actual.float64Null, + &actual.stringValid, + &actual.stringNull, + ) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + + if expected != actual { + t.Errorf("Expected %v, but got %v", expected, actual) + } + + ensureConnValid(t, conn) +} diff --git a/vendor/github.com/jackc/pgx/replication.go b/vendor/github.com/jackc/pgx/replication.go new file mode 100644 index 0000000..7b28d6b --- /dev/null +++ b/vendor/github.com/jackc/pgx/replication.go @@ -0,0 +1,429 @@ +package pgx + +import ( + "errors" + "fmt" + "net" + "time" +) + +const ( + copyBothResponse = 'W' + walData = 'w' + senderKeepalive = 'k' + standbyStatusUpdate = 'r' + initialReplicationResponseTimeout = 5 * time.Second +) + +var epochNano int64 + +func init() { + epochNano = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixNano() +} + +// Format the given 64bit LSN value into the XXX/XXX format, +// which is the format reported by postgres. +func FormatLSN(lsn uint64) string { + return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) +} + +// Parse the given XXX/XXX format LSN as reported by postgres, +// into a 64 bit integer as used internally by the wire procotols +func ParseLSN(lsn string) (outputLsn uint64, err error) { + var upperHalf uint64 + var lowerHalf uint64 + var nparsed int + nparsed, err = fmt.Sscanf(lsn, "%X/%X", &upperHalf, &lowerHalf) + if err != nil { + return + } + + if nparsed != 2 { + err = errors.New(fmt.Sprintf("Failed to parsed LSN: %s", lsn)) + return + } + + outputLsn = (upperHalf << 32) + lowerHalf + return +} + +// The WAL message contains WAL payload entry data +type WalMessage struct { + // The WAL start position of this data. This + // is the WAL position we need to track. + WalStart uint64 + // The server wal end and server time are + // documented to track the end position and current + // time of the server, both of which appear to be + // unimplemented in pg 9.5. + ServerWalEnd uint64 + ServerTime uint64 + // The WAL data is the raw unparsed binary WAL entry. + // The contents of this are determined by the output + // logical encoding plugin. + WalData []byte +} + +func (w *WalMessage) Time() time.Time { + return time.Unix(0, (int64(w.ServerTime)*1000)+epochNano) +} + +func (w *WalMessage) ByteLag() uint64 { + return (w.ServerWalEnd - w.WalStart) +} + +func (w *WalMessage) String() string { + return fmt.Sprintf("Wal: %s Time: %s Lag: %d", FormatLSN(w.WalStart), w.Time(), w.ByteLag()) +} + +// The server heartbeat is sent periodically from the server, +// including server status, and a reply request field +type ServerHeartbeat struct { + // The current max wal position on the server, + // used for lag tracking + ServerWalEnd uint64 + // The server time, in microseconds since jan 1 2000 + ServerTime uint64 + // If 1, the server is requesting a standby status message + // to be sent immediately. + ReplyRequested byte +} + +func (s *ServerHeartbeat) Time() time.Time { + return time.Unix(0, (int64(s.ServerTime)*1000)+epochNano) +} + +func (s *ServerHeartbeat) String() string { + return fmt.Sprintf("WalEnd: %s ReplyRequested: %d T: %s", FormatLSN(s.ServerWalEnd), s.ReplyRequested, s.Time()) +} + +// The replication message wraps all possible messages from the +// server received during replication. At most one of the wal message +// or server heartbeat will be non-nil +type ReplicationMessage struct { + WalMessage *WalMessage + ServerHeartbeat *ServerHeartbeat +} + +// The standby status is the client side heartbeat sent to the postgresql +// server to track the client wal positions. For practical purposes, +// all wal positions are typically set to the same value. +type StandbyStatus struct { + // The WAL position that's been locally written + WalWritePosition uint64 + // The WAL position that's been locally flushed + WalFlushPosition uint64 + // The WAL position that's been locally applied + WalApplyPosition uint64 + // The client time in microseconds since jan 1 2000 + ClientTime uint64 + // If 1, requests the server to immediately send a + // server heartbeat + ReplyRequested byte +} + +// Create a standby status struct, which sets all the WAL positions +// to the given wal position, and the client time to the current time. +// The wal positions are, in order: +// WalFlushPosition +// WalApplyPosition +// WalWritePosition +// +// If only one position is provided, it will be used as the value for all 3 +// status fields. Note you must provide either 1 wal position, or all 3 +// in order to initialize the standby status. +func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) { + if len(walPositions) == 1 { + status = new(StandbyStatus) + status.WalFlushPosition = walPositions[0] + status.WalApplyPosition = walPositions[0] + status.WalWritePosition = walPositions[0] + } else if len(walPositions) == 3 { + status = new(StandbyStatus) + status.WalFlushPosition = walPositions[0] + status.WalApplyPosition = walPositions[1] + status.WalWritePosition = walPositions[2] + } else { + err = errors.New(fmt.Sprintf("Invalid number of wal positions provided, need 1 or 3, got %d", len(walPositions))) + return + } + status.ClientTime = uint64((time.Now().UnixNano() - epochNano) / 1000) + return +} + +func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + config.RuntimeParams["replication"] = "database" + + c, err := Connect(config) + if err != nil { + return + } + return &ReplicationConn{c: c}, nil +} + +type ReplicationConn struct { + c *Conn +} + +// Send standby status to the server, which both acts as a keepalive +// message to the server, as well as carries the WAL position of the +// client, which then updates the server's replication slot position. +func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { + writeBuf := newWriteBuf(rc.c, copyData) + writeBuf.WriteByte(standbyStatusUpdate) + writeBuf.WriteInt64(int64(k.WalWritePosition)) + writeBuf.WriteInt64(int64(k.WalFlushPosition)) + writeBuf.WriteInt64(int64(k.WalApplyPosition)) + writeBuf.WriteInt64(int64(k.ClientTime)) + writeBuf.WriteByte(k.ReplyRequested) + + writeBuf.closeMsg() + + _, err = rc.c.conn.Write(writeBuf.buf) + if err != nil { + rc.c.die(err) + } + + return +} + +func (rc *ReplicationConn) Close() error { + return rc.c.Close() +} + +func (rc *ReplicationConn) IsAlive() bool { + return rc.c.IsAlive() +} + +func (rc *ReplicationConn) CauseOfDeath() error { + return rc.c.CauseOfDeath() +} + +func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { + var t byte + var reader *msgReader + t, reader, err = rc.c.rxMsg() + if err != nil { + return + } + + switch t { + case noticeResponse: + pgError := rc.c.rxErrorResponse(reader) + if rc.c.shouldLog(LogLevelInfo) { + rc.c.log(LogLevelInfo, pgError.Error()) + } + case errorResponse: + err = rc.c.rxErrorResponse(reader) + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, err.Error()) + } + return + case copyBothResponse: + // This is the tail end of the replication process start, + // and can be safely ignored + return + case copyData: + var msgType byte + msgType = reader.readByte() + switch msgType { + case walData: + walStart := reader.readInt64() + serverWalEnd := reader.readInt64() + serverTime := reader.readInt64() + walData := reader.readBytes(reader.msgBytesRemaining) + walMessage := WalMessage{WalStart: uint64(walStart), + ServerWalEnd: uint64(serverWalEnd), + ServerTime: uint64(serverTime), + WalData: walData, + } + + return &ReplicationMessage{WalMessage: &walMessage}, nil + case senderKeepalive: + serverWalEnd := reader.readInt64() + serverTime := reader.readInt64() + replyNow := reader.readByte() + h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} + return &ReplicationMessage{ServerHeartbeat: h}, nil + default: + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) + } + } + default: + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unexpected replication message type %v", t) + } + } + return +} + +// Wait for a single replication message up to timeout time. +// +// Properly using this requires some knowledge of the postgres replication mechanisms, +// as the client can receive both WAL data (the ultimate payload) and server heartbeat +// updates. The caller also must send standby status updates in order to keep the connection +// alive and working. +// +// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified +// duration. +func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { + var zeroTime time.Time + + deadline := time.Now().Add(timeout) + + // Use SetReadDeadline to implement the timeout. SetReadDeadline will + // cause operations to fail with a *net.OpError that has a Timeout() + // of true. Because the normal pgx rxMsg path considers any error to + // have potentially corrupted the state of the connection, it dies + // on any errors. So to avoid timeout errors in rxMsg we set the + // deadline and peek into the reader. If a timeout error occurs there + // we don't break the pgx connection. If the Peek returns that data + // is available then we turn off the read deadline before the rxMsg. + err = rc.c.conn.SetReadDeadline(deadline) + if err != nil { + return nil, err + } + + // Wait until there is a byte available before continuing onto the normal msg reading path + _, err = rc.c.reader.Peek(1) + if err != nil { + rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline + if err, ok := err.(*net.OpError); ok && err.Timeout() { + return nil, ErrNotificationTimeout + } + return nil, err + } + + err = rc.c.conn.SetReadDeadline(zeroTime) + if err != nil { + return nil, err + } + + return rc.readReplicationMessage() +} + +func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { + rc.c.lastActivityTime = time.Now() + + rows := rc.c.getRows(sql, nil) + + if err := rc.c.lock(); err != nil { + rows.abort(err) + return rows, err + } + rows.unlockConn = true + + err := rc.c.sendSimpleQuery(sql) + if err != nil { + rows.abort(err) + } + + var t byte + var r *msgReader + t, r, err = rc.c.rxMsg() + if err != nil { + return nil, err + } + + switch t { + case rowDescription: + rows.fields = rc.c.rxRowDescription(r) + // We don't have c.PgTypes here because we're a replication + // connection. This means the field descriptions will have + // only Oids. Not much we can do about this. + default: + if e := rc.c.processContextFreeMsg(t, r); e != nil { + rows.abort(e) + return rows, e + } + } + + return rows, rows.err +} + +// Execute the "IDENTIFY_SYSTEM" command as documented here: +// https://www.postgresql.org/docs/9.5/static/protocol-replication.html +// +// This will return (if successful) a result set that has a single row +// that contains the systemid, current timeline, xlogpos and database +// name. +// +// NOTE: Because this is a replication mode connection, we don't have +// type names, so the field descriptions in the result will have only +// Oids and no DataTypeName values +func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { + return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") +} + +// Execute the "TIMELINE_HISTORY" command as documented here: +// https://www.postgresql.org/docs/9.5/static/protocol-replication.html +// +// This will return (if successful) a result set that has a single row +// that contains the filename of the history file and the content +// of the history file. If called for timeline 1, typically this will +// generate an error that the timeline history file does not exist. +// +// NOTE: Because this is a replication mode connection, we don't have +// type names, so the field descriptions in the result will have only +// Oids and no DataTypeName values +func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { + return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) +} + +// Start a replication connection, sending WAL data to the given replication +// receiver. This function wraps a START_REPLICATION command as documented +// here: +// https://www.postgresql.org/docs/9.5/static/protocol-replication.html +// +// Once started, the client needs to invoke WaitForReplicationMessage() in order +// to fetch the WAL and standby status. Also, it is the responsibility of the caller +// to periodically send StandbyStatus messages to update the replication slot position. +// +// This function assumes that slotName has already been created. In order to omit the timeline argument +// pass a -1 for the timeline to get the server default behavior. +func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) { + var queryString string + if timeline >= 0 { + queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s TIMELINE %d", slotName, FormatLSN(startLsn), timeline) + } else { + queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s", slotName, FormatLSN(startLsn)) + } + + for _, arg := range pluginArguments { + queryString += fmt.Sprintf(" %s", arg) + } + + if err = rc.c.sendQuery(queryString); err != nil { + return + } + + // The first replication message that comes back here will be (in a success case) + // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has + // started. This call will either return nil, nil or if it returns an error + // that indicates the start replication command failed + var r *ReplicationMessage + r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) + if err != nil && r != nil { + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unxpected replication message %v", r) + } + } + + return +} + +// Create the replication slot, using the given name and output plugin. +func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { + _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) + return +} + +// Drop the replication slot for the given name +func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { + _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) + return +} diff --git a/vendor/github.com/jackc/pgx/replication_test.go b/vendor/github.com/jackc/pgx/replication_test.go new file mode 100644 index 0000000..4f810c7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/replication_test.go @@ -0,0 +1,329 @@ +package pgx_test + +import ( + "fmt" + "github.com/jackc/pgx" + "reflect" + "strconv" + "strings" + "testing" + "time" +) + +// This function uses a postgresql 9.6 specific column +func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string { + // Fetch the restart LSN of the slot, to establish a starting point + rows, err := conn.Query(fmt.Sprintf("select confirmed_flush_lsn from pg_replication_slots where slot_name='%s'", slot)) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + var restartLsn string + for rows.Next() { + rows.Scan(&restartLsn) + } + return restartLsn +} + +// This battleship test (at least somewhat by necessity) does +// several things all at once in a single run. It: +// - Establishes a replication connection & slot +// - Does a series of operations to create some known WAL entries +// - Replicates the entries down, and checks that the rows it +// created come down in order +// - Sends a standby status message to update the server with the +// wal position of the slot +// - Checks the wal position of the slot on the server to make sure +// the update succeeded +func TestSimpleReplicationConnection(t *testing.T) { + t.Parallel() + + var err error + + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + conn := mustConnect(t, *replicationConnConfig) + defer closeConn(t, conn) + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") + if err != nil { + t.Logf("replication slot create failed: %v", err) + } + + // Do a simple change so we can get some wal data + _, err = conn.Exec("create table if not exists replication_test (a integer)") + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + err = replicationConn.StartReplication("pgx_test", 0, -1) + if err != nil { + t.Fatalf("Failed to start replication: %v", err) + } + + var i int32 + var insertedTimes []int64 + for i < 5 { + var ct pgx.CommandTag + currentTime := time.Now().Unix() + insertedTimes = append(insertedTimes, currentTime) + ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + t.Logf("Inserted %d rows", ct.RowsAffected()) + i++ + } + + i = 0 + var foundTimes []int64 + var foundCount int + var maxWal uint64 + for { + var message *pgx.ReplicationMessage + + message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) + if err != nil { + if err != pgx.ErrNotificationTimeout { + t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) + } + } + if message != nil { + if message.WalMessage != nil { + // The waldata payload with the test_decoding plugin looks like: + // public.replication_test: INSERT: a[integer]:2 + // What we wanna do here is check that once we find one of our inserted times, + // that they occur in the wal stream in the order we executed them. + walString := string(message.WalMessage.WalData) + if strings.Contains(walString, "public.replication_test: INSERT") { + stringParts := strings.Split(walString, ":") + offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) + if err != nil { + t.Fatalf("Failed to parse walString %s", walString) + } + if foundCount > 0 || offset == insertedTimes[0] { + foundTimes = append(foundTimes, offset) + foundCount++ + } + } + if message.WalMessage.WalStart > maxWal { + maxWal = message.WalMessage.WalStart + } + + } + if message.ServerHeartbeat != nil { + t.Logf("Got heartbeat: %s", message.ServerHeartbeat) + } + } else { + t.Log("Timed out waiting for wal message") + i++ + } + if i > 3 { + t.Log("Actual timeout") + break + } + } + + if foundCount != len(insertedTimes) { + t.Fatalf("Failed to find all inserted time values in WAL stream (found %d expected %d)", foundCount, len(insertedTimes)) + } + + for i := range insertedTimes { + if foundTimes[i] != insertedTimes[i] { + t.Fatalf("Found %d expected %d", foundTimes[i], insertedTimes[i]) + } + } + + t.Logf("Found %d times, as expected", len(foundTimes)) + + // Before closing our connection, let's send a standby status to update our wal + // position, which should then be reflected if we fetch out our current wal position + // for the slot + status, err := pgx.NewStandbyStatus(maxWal) + if err != nil { + t.Errorf("Failed to create standby status %v", err) + } + replicationConn.SendStandbyStatus(status) + + restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") + integerRestartLsn, _ := pgx.ParseLSN(restartLsn) + if integerRestartLsn != maxWal { + t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn) + } + + closeReplicationConn(t, replicationConn) + + replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn2) + + err = replicationConn2.DropReplicationSlot("pgx_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } + + droppedLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") + if droppedLsn != "" { + t.Errorf("Got odd flush lsn %s for supposedly dropped slot", droppedLsn) + } +} + +func TestReplicationConn_DropReplicationSlot(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + err := replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + if err != nil { + t.Logf("replication slot create failed: %v", err) + } + err = replicationConn.DropReplicationSlot("pgx_slot_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } + + // We re-create to ensure the drop worked. + err = replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + if err != nil { + t.Logf("replication slot create failed: %v", err) + } + + // And finally we drop to ensure we don't leave dirty state + err = replicationConn.DropReplicationSlot("pgx_slot_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } +} + +func TestIdentifySystem(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn2) + + r, err := replicationConn2.IdentifySystem() + if err != nil { + t.Error(err) + } + defer r.Close() + for _, fd := range r.FieldDescriptions() { + t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + } + + var rowCount int + for r.Next() { + rowCount++ + values, err := r.Values() + if err != nil { + t.Error(err) + } + t.Logf("Row values: %v", values) + } + if r.Err() != nil { + t.Error(r.Err()) + } + + if rowCount == 0 { + t.Errorf("Failed to find any rows: %d", rowCount) + } +} + +func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int { + r, err := rc.IdentifySystem() + if err != nil { + t.Error(err) + } + defer r.Close() + for r.Next() { + values, e := r.Values() + if e != nil { + t.Error(e) + } + timeline, e := strconv.Atoi(values[1].(string)) + if e != nil { + t.Error(e) + } + return timeline + } + t.Fatal("Failed to read timeline") + return -1 +} + +func TestGetTimelineHistory(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + timeline := getCurrentTimeline(t, replicationConn) + + r, err := replicationConn.TimelineHistory(timeline) + if err != nil { + t.Errorf("%#v", err) + } + defer r.Close() + + for _, fd := range r.FieldDescriptions() { + t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + } + + var rowCount int + for r.Next() { + rowCount++ + values, err := r.Values() + if err != nil { + t.Error(err) + } + t.Logf("Row values: %v", values) + } + if r.Err() != nil { + if strings.Contains(r.Err().Error(), "No such file or directory") { + // This is normal, this means the timeline we're on has no + // history, which is the common case in a test db that + // has only one timeline + return + } + t.Error(r.Err()) + } + + // If we have a timeline history (see above) there should have been + // rows emitted + if rowCount == 0 { + t.Errorf("Failed to find any rows: %d", rowCount) + } +} + +func TestStandbyStatusParsing(t *testing.T) { + // Let's push the boundary conditions of the standby status and ensure it errors correctly + status, err := pgx.NewStandbyStatus(0, 1, 2, 3, 4) + if err == nil { + t.Errorf("Expected error from new standby status, got %v", status) + } + + // And if you provide 3 args, ensure the right fields are set + status, err = pgx.NewStandbyStatus(1, 2, 3) + if err != nil { + t.Errorf("Failed to create test status: %v", err) + } + if status.WalFlushPosition != 1 { + t.Errorf("Unexpected flush position %d", status.WalFlushPosition) + } + if status.WalApplyPosition != 2 { + t.Errorf("Unexpected apply position %d", status.WalApplyPosition) + } + if status.WalWritePosition != 3 { + t.Errorf("Unexpected write position %d", status.WalWritePosition) + } +} diff --git a/vendor/github.com/jackc/pgx/sql.go b/vendor/github.com/jackc/pgx/sql.go new file mode 100644 index 0000000..7ee0f2a --- /dev/null +++ b/vendor/github.com/jackc/pgx/sql.go @@ -0,0 +1,29 @@ +package pgx + +import ( + "strconv" +) + +// QueryArgs is a container for arguments to an SQL query. It is helpful when +// building SQL statements where the number of arguments is variable. +type QueryArgs []interface{} + +var placeholders []string + +func init() { + placeholders = make([]string, 64) + + for i := 1; i < 64; i++ { + placeholders[i] = "$" + strconv.Itoa(i) + } +} + +// Append adds a value to qa and returns the placeholder value for the +// argument. e.g. $1, $2, etc. +func (qa *QueryArgs) Append(v interface{}) string { + *qa = append(*qa, v) + if len(*qa) < len(placeholders) { + return placeholders[len(*qa)] + } + return "$" + strconv.Itoa(len(*qa)) +} diff --git a/vendor/github.com/jackc/pgx/sql_test.go b/vendor/github.com/jackc/pgx/sql_test.go new file mode 100644 index 0000000..dd03603 --- /dev/null +++ b/vendor/github.com/jackc/pgx/sql_test.go @@ -0,0 +1,36 @@ +package pgx_test + +import ( + "strconv" + "testing" + + "github.com/jackc/pgx" +) + +func TestQueryArgs(t *testing.T) { + var qa pgx.QueryArgs + + for i := 1; i < 512; i++ { + expectedPlaceholder := "$" + strconv.Itoa(i) + placeholder := qa.Append(i) + if placeholder != expectedPlaceholder { + t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder) + } + } +} + +func BenchmarkQueryArgs(b *testing.B) { + for i := 0; i < b.N; i++ { + qa := pgx.QueryArgs(make([]interface{}, 0, 16)) + qa.Append("foo1") + qa.Append("foo2") + qa.Append("foo3") + qa.Append("foo4") + qa.Append("foo5") + qa.Append("foo6") + qa.Append("foo7") + qa.Append("foo8") + qa.Append("foo9") + qa.Append("foo10") + } +} diff --git a/vendor/github.com/jackc/pgx/stdlib/sql.go b/vendor/github.com/jackc/pgx/stdlib/sql.go new file mode 100644 index 0000000..8c78cd3 --- /dev/null +++ b/vendor/github.com/jackc/pgx/stdlib/sql.go @@ -0,0 +1,348 @@ +// Package stdlib is the compatibility layer from pgx to database/sql. +// +// A database/sql connection can be established through sql.Open. +// +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } +// +// Or from a DSN string. +// +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } +// +// Or a normal pgx connection pool can be established and the database/sql +// connection can be created through stdlib.OpenFromConnPool(). This allows +// more control over the connection process (such as TLS), more control +// over the connection pool, setting an AfterConnect hook, and using both +// database/sql and pgx interfaces as needed. +// +// connConfig := pgx.ConnConfig{ +// Host: "localhost", +// User: "pgx_md5", +// Password: "secret", +// Database: "pgx_test", +// } +// +// config := pgx.ConnPoolConfig{ConnConfig: connConfig} +// pool, err := pgx.NewConnPool(config) +// if err != nil { +// return err +// } +// +// db, err := stdlib.OpenFromConnPool(pool) +// if err != nil { +// t.Fatalf("Unable to create connection pool: %v", err) +// } +// +// If the database/sql connection is established through +// stdlib.OpenFromConnPool then access to a pgx *ConnPool can be regained +// through db.Driver(). This allows writing a fast path for pgx while +// preserving compatibility with other drivers and database +// +// if driver, ok := db.Driver().(*stdlib.Driver); ok && driver.Pool != nil { +// // fast path with pgx +// } else { +// // normal path for other drivers and databases +// } +package stdlib + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "sync" + + "github.com/jackc/pgx" +) + +var ( + openFromConnPoolCountMu sync.Mutex + openFromConnPoolCount int +) + +// oids that map to intrinsic database/sql types. These will be allowed to be +// binary, anything else will be forced to text format +var databaseSqlOids map[pgx.Oid]bool + +func init() { + d := &Driver{} + sql.Register("pgx", d) + + databaseSqlOids = make(map[pgx.Oid]bool) + databaseSqlOids[pgx.BoolOid] = true + databaseSqlOids[pgx.ByteaOid] = true + databaseSqlOids[pgx.Int2Oid] = true + databaseSqlOids[pgx.Int4Oid] = true + databaseSqlOids[pgx.Int8Oid] = true + databaseSqlOids[pgx.Float4Oid] = true + databaseSqlOids[pgx.Float8Oid] = true + databaseSqlOids[pgx.DateOid] = true + databaseSqlOids[pgx.TimestampTzOid] = true + databaseSqlOids[pgx.TimestampOid] = true +} + +type Driver struct { + Pool *pgx.ConnPool +} + +func (d *Driver) Open(name string) (driver.Conn, error) { + if d.Pool != nil { + conn, err := d.Pool.Acquire() + if err != nil { + return nil, err + } + + return &Conn{conn: conn, pool: d.Pool}, nil + } + + connConfig, err := pgx.ParseConnectionString(name) + if err != nil { + return nil, err + } + + conn, err := pgx.Connect(connConfig) + if err != nil { + return nil, err + } + + c := &Conn{conn: conn} + return c, nil +} + +// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB +// with pool as the backend. This enables full control over the connection +// process and configuration while maintaining compatibility with the +// database/sql interface. In addition, by calling Driver() on the returned +// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can +// be reaquired later. This allows fast paths targeting pgx to be used while +// still maintaining compatibility with other databases and drivers. +// +// pool connection size must be at least 2. +func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { + d := &Driver{Pool: pool} + + openFromConnPoolCountMu.Lock() + name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) + openFromConnPoolCount++ + openFromConnPoolCountMu.Unlock() + + sql.Register(name, d) + db, err := sql.Open(name, "") + if err != nil { + return nil, err + } + + // Presumably OpenFromConnPool is being used because the user wants to use + // database/sql most of the time, but fast path with pgx some of the time. + // Allow database/sql to use all the connections, but release 2 idle ones. + // Don't have database/sql immediately release all idle connections because + // that would mean that prepared statements would be lost (which kills + // performance if the prepared statements constantly have to be reprepared) + stat := pool.Stat() + + if stat.MaxConnections <= 2 { + return nil, errors.New("pool connection size must be at least 3") + } + db.SetMaxIdleConns(stat.MaxConnections - 2) + db.SetMaxOpenConns(stat.MaxConnections) + + return db, nil +} + +type Conn struct { + conn *pgx.Conn + pool *pgx.ConnPool + psCount int64 // Counter used for creating unique prepared statement names +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + name := fmt.Sprintf("pgx_%d", c.psCount) + c.psCount++ + + ps, err := c.conn.Prepare(name, query) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return &Stmt{ps: ps, conn: c}, nil +} + +func (c *Conn) Close() error { + err := c.conn.Close() + if c.pool != nil { + c.pool.Release(c.conn) + } + + return err +} + +func (c *Conn) Begin() (driver.Tx, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + _, err := c.conn.Exec("begin") + if err != nil { + return nil, err + } + + return &Tx{conn: c.conn}, nil +} + +func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := valueToInterface(argsV) + commandTag, err := c.conn.Exec(query, args...) + return driver.RowsAffected(commandTag.RowsAffected()), err +} + +func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.Prepare("", query) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPrepared("", argsV) +} + +func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := valueToInterface(argsV) + + rows, err := c.conn.Query(name, args...) + if err != nil { + return nil, err + } + + return &Rows{rows: rows}, nil +} + +// Anything that isn't a database/sql compatible type needs to be forced to +// text format so that pgx.Rows.Values doesn't decode it into a native type +// (e.g. []int32) +func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { + for i := range ps.FieldDescriptions { + intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] + if !intrinsic { + ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode + } + } +} + +type Stmt struct { + ps *pgx.PreparedStatement + conn *Conn +} + +func (s *Stmt) Close() error { + return s.conn.conn.Deallocate(s.ps.Name) +} + +func (s *Stmt) NumInput() int { + return len(s.ps.ParameterOids) +} + +func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { + return s.conn.Exec(s.ps.Name, argsV) +} + +func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { + return s.conn.queryPrepared(s.ps.Name, argsV) +} + +// TODO - rename to avoid alloc +type Rows struct { + rows *pgx.Rows +} + +func (r *Rows) Columns() []string { + fieldDescriptions := r.rows.FieldDescriptions() + names := make([]string, 0, len(fieldDescriptions)) + for _, fd := range fieldDescriptions { + names = append(names, fd.Name) + } + return names +} + +func (r *Rows) Close() error { + r.rows.Close() + return nil +} + +func (r *Rows) Next(dest []driver.Value) error { + more := r.rows.Next() + if !more { + if r.rows.Err() == nil { + return io.EOF + } else { + return r.rows.Err() + } + } + + values, err := r.rows.Values() + if err != nil { + return err + } + + if len(dest) < len(values) { + fmt.Printf("%d: %#v\n", len(dest), dest) + fmt.Printf("%d: %#v\n", len(values), values) + return errors.New("expected more values than were received") + } + + for i, v := range values { + dest[i] = driver.Value(v) + } + + return nil +} + +func valueToInterface(argsV []driver.Value) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v != nil { + args = append(args, v.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + +type Tx struct { + conn *pgx.Conn +} + +func (t *Tx) Commit() error { + _, err := t.conn.Exec("commit") + return err +} + +func (t *Tx) Rollback() error { + _, err := t.conn.Exec("rollback") + return err +} diff --git a/vendor/github.com/jackc/pgx/stdlib/sql_test.go b/vendor/github.com/jackc/pgx/stdlib/sql_test.go new file mode 100644 index 0000000..1455ca1 --- /dev/null +++ b/vendor/github.com/jackc/pgx/stdlib/sql_test.go @@ -0,0 +1,691 @@ +package stdlib_test + +import ( + "bytes" + "database/sql" + "github.com/jackc/pgx" + "github.com/jackc/pgx/stdlib" + "sync" + "testing" +) + +func openDB(t *testing.T) *sql.DB { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + return db +} + +func closeDB(t *testing.T, db *sql.DB) { + err := db.Close() + if err != nil { + t.Fatalf("db.Close unexpectedly failed: %v", err) + } +} + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, db *sql.DB) { + var sum, rowCount int32 + + rows, err := db.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("db.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("db.Query failed: %v", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} + +type preparer interface { + Prepare(query string) (*sql.Stmt, error) +} + +func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { + stmt, err := p.Prepare(sql) + if err != nil { + t.Fatalf("%v Prepare unexpectedly failed: %v", p, err) + } + + return stmt +} + +func closeStmt(t *testing.T, stmt *sql.Stmt) { + err := stmt.Close() + if err != nil { + t.Fatalf("stmt.Close unexpectedly failed: %v", err) + } +} + +func TestNormalLifeCycle(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") + defer closeStmt(t, stmt) + + rows, err := stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestSqlOpenDoesNotHavePool(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + driver := db.Driver().(*stdlib.Driver) + if driver.Pool != nil { + t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did") + } +} + +func TestOpenFromConnPool(t *testing.T) { + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // Can get pgx.ConnPool from driver + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } + + // Normal sql/database still works + var n int64 + err = db.QueryRow("select 1").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } +} + +func TestOpenFromConnPoolRace(t *testing.T) { + wg := &sync.WaitGroup{} + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // Can get pgx.ConnPool from driver + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } + }() + } + + wg.Wait() +} + +func TestStmtExec(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") + _, err = createStmt.Exec() + if err != nil { + t.Fatalf("stmt.Exec unexpectedly failed: %v", err) + } + closeStmt(t, createStmt) + + insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") + result, err := insertStmt.Exec("foo") + if err != nil { + t.Fatalf("stmt.Exec unexpectedly failed: %v", err) + } + + n, err := result.RowsAffected() + if err != nil { + t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1, received %d", n) + } + closeStmt(t, insertStmt) + + if err != nil { + t.Fatalf("tx.Commit unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestQueryCloseRowsEarly(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") + defer closeStmt(t, stmt) + + rows, err := stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + // Close rows immediately without having read them + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + // Run the query again to ensure the connection and statement are still ok + rows, err = stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnExec(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(a varchar not null)") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + result, err := db.Exec("insert into t values('hey')") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + n, err := result.RowsAffected() + if err != nil { + t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1, received %d", n) + } + + ensureConnValid(t, db) +} + +func TestConnQuery(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) + if err != nil { + t.Fatalf("db.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +type testLog struct { + lvl int + msg string + ctx []interface{} +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Debug(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) +} +func (l *testLogger) Info(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) +} +func (l *testLogger) Warn(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) +} +func (l *testLogger) Error(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +} + +func TestConnQueryLog(t *testing.T) { + logger := &testLogger{} + + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + Logger: logger, + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // clear logs from initial connection + logger.logs = []testLog{} + + var n int64 + err = db.QueryRow("select 1").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + + l := logger.logs[0] + if l.msg != "Query" { + t.Errorf("Expected to log Query, but got %v", l) + } + + if !(l.ctx[0] == "sql" && l.ctx[1] == "select 1") { + t.Errorf("Expected to log Query with sql 'select 1', but got %v", l) + } +} + +func TestConnQueryNull(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("select $1::int", nil) + if err != nil { + t.Fatalf("db.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var n sql.NullInt64 + if err := rows.Scan(&n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if n.Valid != false { + t.Errorf("Expected n to be null, but it was %v", n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 1 { + t.Fatalf("Expected to receive 11 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnQueryRowByteSlice(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + expected := []byte{222, 173, 190, 239} + var actual []byte + + err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + + if bytes.Compare(actual, expected) != 0 { + t.Fatalf("Expected %v, but got %v", expected, actual) + } + + ensureConnValid(t, db) +} + +func TestConnQueryFailure(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Query("select 'foo") + if _, ok := err.(pgx.PgError); !ok { + t.Fatalf("Expected db.Query to return pgx.PgError, but instead received: %v", err) + } + + ensureConnValid(t, db) +} + +// Test type that pgx would handle natively in binary, but since it is not a +// database/sql native type should be passed through as a string +func TestConnQueryRowPgxBinary(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + sql := "select $1::int4[]" + expected := "{1,2,3}" + var actual string + + err := db.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + } + + ensureConnValid(t, db) +} + +func TestConnQueryRowUnknownType(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + sql := "select $1::point" + expected := "(1,2)" + var actual string + + err := db.QueryRow(sql, expected).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if actual != expected { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) + } + + ensureConnValid(t, db) +} + +func TestConnQueryJSONIntoByteSlice(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + if !serverHasJSON(t, db) { + t.Skip("Skipping due to server's lack of JSON type") + } + + _, err := db.Exec(` + create temporary table docs( + body json not null + ); + + insert into docs(body) values('{"foo":"bar"}'); +`) + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + sql := `select * from docs` + expected := []byte(`{"foo":"bar"}`) + var actual []byte + + err = db.QueryRow(sql).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if bytes.Compare(actual, expected) != 0 { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) + } + + _, err = db.Exec(`drop table docs`) + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + if !serverHasJSON(t, db) { + t.Skip("Skipping due to server's lack of JSON type") + } + + _, err := db.Exec(` + create temporary table docs( + body json not null + ); +`) + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + expected := []byte(`{"foo":"bar"}`) + + _, err = db.Exec(`insert into docs(body) values($1)`, expected) + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + var actual []byte + err = db.QueryRow(`select body from docs`).Scan(&actual) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + + if bytes.Compare(actual, expected) != 0 { + t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) + } + + _, err = db.Exec(`drop table docs`) + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + +func serverHasJSON(t *testing.T, db *sql.DB) bool { + var hasJSON bool + err := db.QueryRow(`select exists(select 1 from pg_type where typname='json')`).Scan(&hasJSON) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + return hasJSON +} + +func TestTransactionLifeCycle(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(a varchar not null)") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + _, err = tx.Exec("insert into t values('hi')") + if err != nil { + t.Fatalf("tx.Exec unexpectedly failed: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback unexpectedly failed: %v", err) + } + + var n int64 + err = db.QueryRow("select count(*) from t").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) + } + if n != 0 { + t.Fatalf("Expected 0 rows due to rollback, instead found %d", n) + } + + tx, err = db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + _, err = tx.Exec("insert into t values('hi')") + if err != nil { + t.Fatalf("tx.Exec unexpectedly failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("tx.Commit unexpectedly failed: %v", err) + } + + err = db.QueryRow("select count(*) from t").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1 rows due to rollback, instead found %d", n) + } + + ensureConnValid(t, db) +} diff --git a/vendor/github.com/jackc/pgx/stress_test.go b/vendor/github.com/jackc/pgx/stress_test.go new file mode 100644 index 0000000..150d13c --- /dev/null +++ b/vendor/github.com/jackc/pgx/stress_test.go @@ -0,0 +1,346 @@ +package pgx_test + +import ( + "errors" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/jackc/fake" + "github.com/jackc/pgx" +) + +type execer interface { + Exec(sql string, arguments ...interface{}) (commandTag pgx.CommandTag, err error) +} +type queryer interface { + Query(sql string, args ...interface{}) (*pgx.Rows, error) +} +type queryRower interface { + QueryRow(sql string, args ...interface{}) *pgx.Row +} + +func TestStressConnPool(t *testing.T) { + maxConnections := 8 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + setupStressDB(t, pool) + + actions := []struct { + name string + fn func(*pgx.ConnPool, int) error + }{ + {"insertUnprepared", func(p *pgx.ConnPool, n int) error { return insertUnprepared(p, n) }}, + {"queryRowWithoutParams", func(p *pgx.ConnPool, n int) error { return queryRowWithoutParams(p, n) }}, + {"query", func(p *pgx.ConnPool, n int) error { return queryCloseEarly(p, n) }}, + {"queryCloseEarly", func(p *pgx.ConnPool, n int) error { return query(p, n) }}, + {"queryErrorWhileReturningRows", func(p *pgx.ConnPool, n int) error { return queryErrorWhileReturningRows(p, n) }}, + {"txInsertRollback", txInsertRollback}, + {"txInsertCommit", txInsertCommit}, + {"txMultipleQueries", txMultipleQueries}, + {"notify", notify}, + {"listenAndPoolUnlistens", listenAndPoolUnlistens}, + {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, + {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + } + + var timer *time.Timer + if testing.Short() { + timer = time.NewTimer(5 * time.Second) + } else { + timer = time.NewTimer(60 * time.Second) + } + workerCount := 16 + + workChan := make(chan int) + doneChan := make(chan struct{}) + errChan := make(chan error) + + work := func() { + for n := range workChan { + action := actions[rand.Intn(len(actions))] + err := action.fn(pool, n) + if err != nil { + errChan <- err + break + } + } + doneChan <- struct{}{} + } + + for i := 0; i < workerCount; i++ { + go work() + } + + var stop bool + for i := 0; !stop; i++ { + select { + case <-timer.C: + stop = true + case workChan <- i: + case err := <-errChan: + close(workChan) + t.Fatal(err) + } + } + close(workChan) + + for i := 0; i < workerCount; i++ { + <-doneChan + } +} + +func TestStressTLSConnection(t *testing.T) { + t.Parallel() + + if tlsConnConfig == nil { + t.Skip("Skipping due to undefined tlsConnConfig") + } + + if testing.Short() { + t.Skip("Skipping due to testing -short") + } + + conn, err := pgx.Connect(*tlsConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer conn.Close() + + for i := 0; i < 50; i++ { + sql := `select * from generate_series(1, $1)` + + rows, err := conn.Query(sql, 2000000) + if err != nil { + t.Fatal(err) + } + + var n int32 + for rows.Next() { + rows.Scan(&n) + } + + if rows.Err() != nil { + t.Fatalf("queryCount: %d, Row number: %d. %v", i, n, rows.Err()) + } + } +} + +func setupStressDB(t *testing.T, pool *pgx.ConnPool) { + _, err := pool.Exec(` + drop table if exists widgets; + create table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz + ); +`) + if err != nil { + t.Fatal(err) + } +} + +func insertUnprepared(e execer, actionNum int) error { + sql := ` + insert into widgets(name, description, creation_time) + values($1, $2, $3)` + + _, err := e.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) + return err +} + +func queryRowWithoutParams(qr queryRower, actionNum int) error { + var id int32 + var name, description string + var creationTime time.Time + + sql := `select * from widgets order by random() limit 1` + + err := qr.QueryRow(sql).Scan(&id, &name, &description, &creationTime) + if err == pgx.ErrNoRows { + return nil + } + return err +} + +func query(q queryer, actionNum int) error { + sql := `select * from widgets order by random() limit $1` + + rows, err := q.Query(sql, 10) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var id int32 + var name, description string + var creationTime time.Time + rows.Scan(&id, &name, &description, &creationTime) + } + + return rows.Err() +} + +func queryCloseEarly(q queryer, actionNum int) error { + sql := `select * from generate_series(1,$1)` + + rows, err := q.Query(sql, 100) + if err != nil { + return err + } + defer rows.Close() + + for i := 0; i < 10 && rows.Next(); i++ { + var n int32 + rows.Scan(&n) + } + rows.Close() + + return rows.Err() +} + +func queryErrorWhileReturningRows(q queryer, actionNum int) error { + // This query should divide by 0 within the first number of rows + sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)` + + rows, err := q.Query(sql) + if err != nil { + return nil + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + } + + if _, ok := rows.Err().(pgx.PgError); ok { + return nil + } + return rows.Err() +} + +func notify(pool *pgx.ConnPool, actionNum int) error { + _, err := pool.Exec("notify stress") + return err +} + +func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { + conn, err := pool.Acquire() + if err != nil { + return err + } + defer pool.Release(conn) + + err = conn.Listen("stress") + if err != nil { + return err + } + + _, err = conn.WaitForNotification(100 * time.Millisecond) + if err == pgx.ErrNotificationTimeout { + return nil + } + return err +} + +func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error { + psName := fmt.Sprintf("poolPreparedStatement%d", actionNum) + + _, err := pool.Prepare(psName, "select $1::text") + if err != nil { + return err + } + + var s string + err = pool.QueryRow(psName, "hello").Scan(&s) + if err != nil { + return err + } + + if s != "hello" { + return fmt.Errorf("Prepared statement did not return expected value: %v", s) + } + + return pool.Deallocate(psName) +} + +func txInsertRollback(pool *pgx.ConnPool, actionNum int) error { + tx, err := pool.Begin() + if err != nil { + return err + } + + sql := ` + insert into widgets(name, description, creation_time) + values($1, $2, $3)` + + _, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) + if err != nil { + return err + } + + return tx.Rollback() +} + +func txInsertCommit(pool *pgx.ConnPool, actionNum int) error { + tx, err := pool.Begin() + if err != nil { + return err + } + + sql := ` + insert into widgets(name, description, creation_time) + values($1, $2, $3)` + + _, err = tx.Exec(sql, fake.ProductName(), fake.Sentences(), time.Now()) + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + +func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { + tx, err := pool.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + errExpectedTxDeath := errors.New("Expected tx death") + + actions := []struct { + name string + fn func() error + }{ + {"insertUnprepared", func() error { return insertUnprepared(tx, actionNum) }}, + {"queryRowWithoutParams", func() error { return queryRowWithoutParams(tx, actionNum) }}, + {"query", func() error { return query(tx, actionNum) }}, + {"queryCloseEarly", func() error { return queryCloseEarly(tx, actionNum) }}, + {"queryErrorWhileReturningRows", func() error { + err := queryErrorWhileReturningRows(tx, actionNum) + if err != nil { + return err + } + return errExpectedTxDeath + }}, + } + + for i := 0; i < 20; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn() + if err == errExpectedTxDeath { + return nil + } else if err != nil { + return err + } + } + + return tx.Commit() +} diff --git a/vendor/github.com/jackc/pgx/tx.go b/vendor/github.com/jackc/pgx/tx.go new file mode 100644 index 0000000..deb6c01 --- /dev/null +++ b/vendor/github.com/jackc/pgx/tx.go @@ -0,0 +1,207 @@ +package pgx + +import ( + "errors" + "fmt" +) + +// Transaction isolation levels +const ( + Serializable = "serializable" + RepeatableRead = "repeatable read" + ReadCommitted = "read committed" + ReadUncommitted = "read uncommitted" +) + +const ( + TxStatusInProgress = 0 + TxStatusCommitFailure = -1 + TxStatusRollbackFailure = -2 + TxStatusCommitSuccess = 1 + TxStatusRollbackSuccess = 2 +) + +var ErrTxClosed = errors.New("tx is closed") + +// ErrTxCommitRollback occurs when an error has occurred in a transaction and +// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but +// it is treated as ROLLBACK. +var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") + +// Begin starts a transaction with the default isolation level for the current +// connection. To use a specific isolation level see BeginIso. +func (c *Conn) Begin() (*Tx, error) { + return c.begin("") +} + +// BeginIso starts a transaction with isoLevel as the transaction isolation +// level. +// +// Valid isolation levels (and their constants) are: +// serializable (pgx.Serializable) +// repeatable read (pgx.RepeatableRead) +// read committed (pgx.ReadCommitted) +// read uncommitted (pgx.ReadUncommitted) +func (c *Conn) BeginIso(isoLevel string) (*Tx, error) { + return c.begin(isoLevel) +} + +func (c *Conn) begin(isoLevel string) (*Tx, error) { + var beginSQL string + if isoLevel == "" { + beginSQL = "begin" + } else { + beginSQL = fmt.Sprintf("begin isolation level %s", isoLevel) + } + + _, err := c.Exec(beginSQL) + if err != nil { + return nil, err + } + + return &Tx{conn: c}, nil +} + +// Tx represents a database transaction. +// +// All Tx methods return ErrTxClosed if Commit or Rollback has already been +// called on the Tx. +type Tx struct { + conn *Conn + afterClose func(*Tx) + err error + status int8 +} + +// Commit commits the transaction +func (tx *Tx) Commit() error { + if tx.status != TxStatusInProgress { + return ErrTxClosed + } + + commandTag, err := tx.conn.Exec("commit") + if err == nil && commandTag == "COMMIT" { + tx.status = TxStatusCommitSuccess + } else if err == nil && commandTag == "ROLLBACK" { + tx.status = TxStatusCommitFailure + tx.err = ErrTxCommitRollback + } else { + tx.status = TxStatusCommitFailure + tx.err = err + } + + if tx.afterClose != nil { + tx.afterClose(tx) + } + return tx.err +} + +// Rollback rolls back the transaction. Rollback will return ErrTxClosed if the +// Tx is already closed, but is otherwise safe to call multiple times. Hence, a +// defer tx.Rollback() is safe even if tx.Commit() will be called first in a +// non-error condition. +func (tx *Tx) Rollback() error { + if tx.status != TxStatusInProgress { + return ErrTxClosed + } + + _, tx.err = tx.conn.Exec("rollback") + if tx.err == nil { + tx.status = TxStatusRollbackSuccess + } else { + tx.status = TxStatusRollbackFailure + } + + if tx.afterClose != nil { + tx.afterClose(tx) + } + return tx.err +} + +// Exec delegates to the underlying *Conn +func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + if tx.status != TxStatusInProgress { + return CommandTag(""), ErrTxClosed + } + + return tx.conn.Exec(sql, arguments...) +} + +// Prepare delegates to the underlying *Conn +func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { + return tx.PrepareEx(name, sql, nil) +} + +// PrepareEx delegates to the underlying *Conn +func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { + if tx.status != TxStatusInProgress { + return nil, ErrTxClosed + } + + return tx.conn.PrepareEx(name, sql, opts) +} + +// Query delegates to the underlying *Conn +func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { + if tx.status != TxStatusInProgress { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &Rows{closed: true, err: err}, err + } + + return tx.conn.Query(sql, args...) +} + +// QueryRow delegates to the underlying *Conn +func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { + rows, _ := tx.Query(sql, args...) + return (*Row)(rows) +} + +// Deprecated. Use CopyFrom instead. CopyTo delegates to the underlying *Conn +func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + if tx.status != TxStatusInProgress { + return 0, ErrTxClosed + } + + return tx.conn.CopyTo(tableName, columnNames, rowSrc) +} + +// CopyFrom delegates to the underlying *Conn +func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { + if tx.status != TxStatusInProgress { + return 0, ErrTxClosed + } + + return tx.conn.CopyFrom(tableName, columnNames, rowSrc) +} + +// Conn returns the *Conn this transaction is using. +func (tx *Tx) Conn() *Conn { + return tx.conn +} + +// Status returns the status of the transaction from the set of +// pgx.TxStatus* constants. +func (tx *Tx) Status() int8 { + return tx.status +} + +// Err returns the final error state, if any, of calling Commit or Rollback. +func (tx *Tx) Err() error { + return tx.err +} + +// AfterClose adds f to a LILO queue of functions that will be called when +// the transaction is closed (either Commit or Rollback). +func (tx *Tx) AfterClose(f func(*Tx)) { + if tx.afterClose == nil { + tx.afterClose = f + } else { + prevFn := tx.afterClose + tx.afterClose = func(tx *Tx) { + f(tx) + prevFn(tx) + } + } +} diff --git a/vendor/github.com/jackc/pgx/tx_test.go b/vendor/github.com/jackc/pgx/tx_test.go new file mode 100644 index 0000000..435521a --- /dev/null +++ b/vendor/github.com/jackc/pgx/tx_test.go @@ -0,0 +1,297 @@ +package pgx_test + +import ( + "github.com/jackc/pgx" + "testing" + "time" +) + +func TestTransactionSuccessfulCommit(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 1 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestTxCommitWhenTxBroken(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + // Purposely break transaction + if _, err := tx.Exec("syntax error"); err == nil { + t.Fatal("Unexpected success") + } + + err = tx.Commit() + if err != pgx.ErrTxCommitRollback { + t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestTxCommitSerializationFailure(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 5) + defer pool.Close() + + pool.Exec(`drop table if exists tx_serializable_sums`) + _, err := pool.Exec(`create table tx_serializable_sums(num integer);`) + if err != nil { + t.Fatalf("Unable to create temporary table: %v", err) + } + defer pool.Exec(`drop table tx_serializable_sums`) + + tx1, err := pool.BeginIso(pgx.Serializable) + if err != nil { + t.Fatalf("BeginIso failed: %v", err) + } + defer tx1.Rollback() + + tx2, err := pool.BeginIso(pgx.Serializable) + if err != nil { + t.Fatalf("BeginIso failed: %v", err) + } + defer tx2.Rollback() + + _, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) + if err != nil { + t.Fatalf("Exec failed: %v", err) + } + + _, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) + if err != nil { + t.Fatalf("Exec failed: %v", err) + } + + err = tx1.Commit() + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + err = tx2.Commit() + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "40001" { + t.Fatalf("Expected serialization error 40001, got %#v", err) + } +} + +func TestTransactionSuccessfulRollback(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestBeginIso(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} + for _, iso := range isoLevels { + tx, err := conn.BeginIso(iso) + if err != nil { + t.Fatalf("conn.BeginIso failed: %v", err) + } + + var level string + conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + if level != iso { + t.Errorf("Expected to be in isolation level %v but was %v", iso, level) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + } +} + +func TestTxAfterClose(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + var zeroTime, t1, t2 time.Time + tx.AfterClose(func(tx *pgx.Tx) { + t1 = time.Now() + }) + + tx.AfterClose(func(tx *pgx.Tx) { + t2 = time.Now() + }) + + tx.Rollback() + + if t1 == zeroTime { + t.Error("First Tx.AfterClose callback not called") + } + + if t2 == zeroTime { + t.Error("Second Tx.AfterClose callback not called") + } + + if t1.Before(t2) { + t.Errorf("AfterClose callbacks called out of order: %v, %v", t1, t2) + } +} + +func TestTxStatus(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + if status := tx.Status(); status != pgx.TxStatusInProgress { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if status := tx.Status(); status != pgx.TxStatusRollbackSuccess { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) + } +} + +func TestTxErr(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + // Purposely break transaction + if _, err := tx.Exec("syntax error"); err == nil { + t.Fatal("Unexpected success") + } + + if err := tx.Commit(); err != pgx.ErrTxCommitRollback { + t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) + } + + if status := tx.Status(); status != pgx.TxStatusCommitFailure { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) + } + + if err := tx.Err(); err != pgx.ErrTxCommitRollback { + t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) + } +} diff --git a/vendor/github.com/jackc/pgx/value_reader.go b/vendor/github.com/jackc/pgx/value_reader.go new file mode 100644 index 0000000..a489754 --- /dev/null +++ b/vendor/github.com/jackc/pgx/value_reader.go @@ -0,0 +1,156 @@ +package pgx + +import ( + "errors" +) + +// ValueReader is used by the Scanner interface to decode values. +type ValueReader struct { + mr *msgReader + fd *FieldDescription + valueBytesRemaining int32 + err error +} + +// Err returns any error that the ValueReader has experienced +func (r *ValueReader) Err() error { + return r.err +} + +// Fatal tells r that a Fatal error has occurred +func (r *ValueReader) Fatal(err error) { + r.err = err +} + +// Len returns the number of unread bytes +func (r *ValueReader) Len() int32 { + return r.valueBytesRemaining +} + +// Type returns the *FieldDescription of the value +func (r *ValueReader) Type() *FieldDescription { + return r.fd +} + +func (r *ValueReader) ReadByte() byte { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining-- + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readByte() +} + +func (r *ValueReader) ReadInt16() int16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readInt16() +} + +func (r *ValueReader) ReadUint16() uint16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint16() +} + +func (r *ValueReader) ReadInt32() int32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readInt32() +} + +func (r *ValueReader) ReadUint32() uint32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint32() +} + +func (r *ValueReader) ReadInt64() int64 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 8 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readInt64() +} + +func (r *ValueReader) ReadOid() Oid { + return Oid(r.ReadUint32()) +} + +// ReadString reads count bytes and returns as string +func (r *ValueReader) ReadString(count int32) string { + if r.err != nil { + return "" + } + + r.valueBytesRemaining -= count + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return "" + } + + return r.mr.readString(count) +} + +// ReadBytes reads count bytes and returns as []byte +func (r *ValueReader) ReadBytes(count int32) []byte { + if r.err != nil { + return nil + } + + if count < 0 { + r.Fatal(errors.New("count must not be negative")) + return nil + } + + r.valueBytesRemaining -= count + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return nil + } + + return r.mr.readBytes(count) +} diff --git a/vendor/github.com/jackc/pgx/values.go b/vendor/github.com/jackc/pgx/values.go new file mode 100644 index 0000000..a189e18 --- /dev/null +++ b/vendor/github.com/jackc/pgx/values.go @@ -0,0 +1,3439 @@ +package pgx + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" + "io" + "math" + "net" + "reflect" + "regexp" + "strconv" + "strings" + "time" +) + +// PostgreSQL oids for common types +const ( + BoolOid = 16 + ByteaOid = 17 + CharOid = 18 + NameOid = 19 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + TidOid = 27 + XidOid = 28 + CidOid = 29 + JsonOid = 114 + CidrOid = 650 + CidrArrayOid = 651 + Float4Oid = 700 + Float8Oid = 701 + UnknownOid = 705 + InetOid = 869 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + ByteaArrayOid = 1001 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + AclItemOid = 1033 + AclItemArrayOid = 1034 + InetArrayOid = 1041 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + TimestampTzOid = 1184 + TimestampTzArrayOid = 1185 + RecordOid = 2249 + UuidOid = 2950 + JsonbOid = 3802 +) + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// DefaultTypeFormats maps type names to their default requested format (text +// or binary). In theory the Scanner interface should be the one to determine +// the format of the returned values. However, the query has already been +// executed by the time Scan is called so it has no chance to set the format. +// So for types that should always be returned in binary the format should be +// set here. +var DefaultTypeFormats map[string]int16 + +func init() { + DefaultTypeFormats = map[string]int16{ + "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) + "_bool": BinaryFormatCode, + "_bytea": BinaryFormatCode, + "_cidr": BinaryFormatCode, + "_float4": BinaryFormatCode, + "_float8": BinaryFormatCode, + "_inet": BinaryFormatCode, + "_int2": BinaryFormatCode, + "_int4": BinaryFormatCode, + "_int8": BinaryFormatCode, + "_text": BinaryFormatCode, + "_timestamp": BinaryFormatCode, + "_timestamptz": BinaryFormatCode, + "_varchar": BinaryFormatCode, + "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) + "bool": BinaryFormatCode, + "bytea": BinaryFormatCode, + "char": BinaryFormatCode, + "cid": BinaryFormatCode, + "cidr": BinaryFormatCode, + "date": BinaryFormatCode, + "float4": BinaryFormatCode, + "float8": BinaryFormatCode, + "json": BinaryFormatCode, + "jsonb": BinaryFormatCode, + "inet": BinaryFormatCode, + "int2": BinaryFormatCode, + "int4": BinaryFormatCode, + "int8": BinaryFormatCode, + "name": BinaryFormatCode, + "oid": BinaryFormatCode, + "record": BinaryFormatCode, + "text": BinaryFormatCode, + "tid": BinaryFormatCode, + "timestamp": BinaryFormatCode, + "timestamptz": BinaryFormatCode, + "varchar": BinaryFormatCode, + "xid": BinaryFormatCode, + } +} + +// SerializationError occurs on failure to encode or decode a value +type SerializationError string + +func (e SerializationError) Error() string { + return string(e) +} + +// Deprecated: Scanner is an interface used to decode values from the PostgreSQL +// server. To allow types to support pgx and database/sql.Scan this interface +// has been deprecated in favor of PgxScanner. +type Scanner interface { + // Scan MUST check r.Type().DataType (to check by OID) or + // r.Type().DataTypeName (to check by name) to ensure that it is scanning an + // expected column type. It also MUST check r.Type().FormatCode before + // decoding. It should not assume that it was called on a data type or format + // that it understands. + Scan(r *ValueReader) error +} + +// PgxScanner is an interface used to decode values from the PostgreSQL server. +// It is used exactly the same as the Scanner interface. It simply has renamed +// the method. +type PgxScanner interface { + // ScanPgx MUST check r.Type().DataType (to check by OID) or + // r.Type().DataTypeName (to check by name) to ensure that it is scanning an + // expected column type. It also MUST check r.Type().FormatCode before + // decoding. It should not assume that it was called on a data type or format + // that it understands. + ScanPgx(r *ValueReader) error +} + +// Encoder is an interface used to encode values for transmission to the +// PostgreSQL server. +type Encoder interface { + // Encode writes the value to w. + // + // If the value is NULL an int32(-1) should be written. + // + // Encode MUST check oid to see if the parameter data type is compatible. If + // this is not done, the PostgreSQL server may detect the error if the + // expected data size or format of the encoded data does not match. But if + // the encoded data is a valid representation of the data type PostgreSQL + // expects such as date and int4, incorrect data may be stored. + Encode(w *WriteBuf, oid Oid) error + + // FormatCode returns the format that the encoder writes the value. It must be + // either pgx.TextFormatCode or pgx.BinaryFormatCode. + FormatCode() int16 +} + +// NullFloat32 represents an float4 that may be null. NullFloat32 implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullFloat32 struct { + Float32 float32 + Valid bool // Valid is true if Float32 is not NULL +} + +func (n *NullFloat32) Scan(vr *ValueReader) error { + if vr.Type().DataType != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Float32, n.Valid = 0, false + return nil + } + n.Valid = true + n.Float32 = decodeFloat4(vr) + return vr.Err() +} + +func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } + +func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { + if oid != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeFloat32(w, oid, n.Float32) +} + +// NullFloat64 represents an float8 that may be null. NullFloat64 implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL +} + +func (n *NullFloat64) Scan(vr *ValueReader) error { + if vr.Type().DataType != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + n.Float64 = decodeFloat8(vr) + return vr.Err() +} + +func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } + +func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { + if oid != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeFloat64(w, oid, n.Float64) +} + +// NullString represents an string that may be null. NullString implements the +// Scanner Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +func (n *NullString) Scan(vr *ValueReader) error { + // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later + + if vr.Len() == -1 { + n.String, n.Valid = "", false + return nil + } + + n.Valid = true + n.String = decodeText(vr) + return vr.Err() +} + +func (n NullString) FormatCode() int16 { return TextFormatCode } + +func (n NullString) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, n.String) +} + +// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type AclItem string + +// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullAclItem struct { + AclItem AclItem + Valid bool // Valid is true if AclItem is not NULL +} + +func (n *NullAclItem) Scan(vr *ValueReader) error { + if vr.Type().DataType != AclItemOid { + return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.AclItem, n.Valid = "", false + return nil + } + + n.Valid = true + n.AclItem = AclItem(decodeText(vr)) + return vr.Err() +} + +// Particularly important to return TextFormatCode, seeing as Postgres +// only ever sends aclitem as text, not binary. +func (n NullAclItem) FormatCode() int16 { return TextFormatCode } + +func (n NullAclItem) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.AclItem)) +} + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name string + +// NullName represents a pgx.Name that may be null. NullName implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullName struct { + Name Name + Valid bool // Valid is true if Name is not NULL +} + +func (n *NullName) Scan(vr *ValueReader) error { + if vr.Type().DataType != NameOid { + return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Name, n.Valid = "", false + return nil + } + + n.Valid = true + n.Name = Name(decodeText(vr)) + return vr.Err() +} + +func (n NullName) FormatCode() int16 { return TextFormatCode } + +func (n NullName) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.Name)) +} + +// The pgx.Char type is for PostgreSQL's special 8-bit-only +// "char" type more akin to the C language's char type, or Go's byte type. +// (Note that the name in PostgreSQL itself is "char", in double-quotes, +// and not char.) It gets used a lot in PostgreSQL's system tables to hold +// a single ASCII character value (eg pg_class.relkind). +type Char byte + +// NullChar represents a pgx.Char that may be null. NullChar implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullChar struct { + Char Char + Valid bool // Valid is true if Char is not NULL +} + +func (n *NullChar) Scan(vr *ValueReader) error { + if vr.Type().DataType != CharOid { + return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Char, n.Valid = 0, false + return nil + } + n.Valid = true + n.Char = decodeChar(vr) + return vr.Err() +} + +func (n NullChar) FormatCode() int16 { return BinaryFormatCode } + +func (n NullChar) Encode(w *WriteBuf, oid Oid) error { + if oid != CharOid { + return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeChar(w, oid, n.Char) +} + +// NullInt16 represents a smallint that may be null. NullInt16 implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullInt16 struct { + Int16 int16 + Valid bool // Valid is true if Int16 is not NULL +} + +func (n *NullInt16) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Int16, n.Valid = 0, false + return nil + } + n.Valid = true + n.Int16 = decodeInt2(vr) + return vr.Err() +} + +func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } + +func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { + if oid != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeInt16(w, oid, n.Int16) +} + +// NullInt32 represents an integer that may be null. NullInt32 implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullInt32 struct { + Int32 int32 + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullInt32) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Int32, n.Valid = 0, false + return nil + } + n.Valid = true + n.Int32 = decodeInt4(vr) + return vr.Err() +} + +func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } + +func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { + if oid != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeInt32(w, oid, n.Int32) +} + +// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, +// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented +// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h +// in the PostgreSQL sources. +type Oid uint32 + +// NullOid represents a Command Identifier (Oid) that may be null. NullOid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullOid struct { + Oid Oid + Valid bool // Valid is true if Oid is not NULL +} + +func (n *NullOid) Scan(vr *ValueReader) error { + if vr.Type().DataType != OidOid { + return SerializationError(fmt.Sprintf("NullOid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Oid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Oid = decodeOid(vr) + return vr.Err() +} + +func (n NullOid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullOid) Encode(w *WriteBuf, oid Oid) error { + if oid != OidOid { + return SerializationError(fmt.Sprintf("NullOid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeOid(w, oid, n.Oid) +} + +// Xid is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type Xid uint32 + +// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullXid struct { + Xid Xid + Valid bool // Valid is true if Xid is not NULL +} + +func (n *NullXid) Scan(vr *ValueReader) error { + if vr.Type().DataType != XidOid { + return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Xid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Xid = decodeXid(vr) + return vr.Err() +} + +func (n NullXid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullXid) Encode(w *WriteBuf, oid Oid) error { + if oid != XidOid { + return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeXid(w, oid, n.Xid) +} + +// Cid is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type Cid uint32 + +// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullCid struct { + Cid Cid + Valid bool // Valid is true if Cid is not NULL +} + +func (n *NullCid) Scan(vr *ValueReader) error { + if vr.Type().DataType != CidOid { + return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Cid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Cid = decodeCid(vr) + return vr.Err() +} + +func (n NullCid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullCid) Encode(w *WriteBuf, oid Oid) error { + if oid != CidOid { + return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeCid(w, oid, n.Cid) +} + +// Tid is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type Tid struct { + BlockNumber uint32 + OffsetNumber uint16 +} + +// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullTid struct { + Tid Tid + Valid bool // Valid is true if Tid is not NULL +} + +func (n *NullTid) Scan(vr *ValueReader) error { + if vr.Type().DataType != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false + return nil + } + n.Valid = true + n.Tid = decodeTid(vr) + return vr.Err() +} + +func (n NullTid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTid) Encode(w *WriteBuf, oid Oid) error { + if oid != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTid(w, oid, n.Tid) +} + +// NullInt64 represents an bigint that may be null. NullInt64 implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +func (n *NullInt64) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + n.Int64 = decodeInt8(vr) + return vr.Err() +} + +func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } + +func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { + if oid != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeInt64(w, oid, n.Int64) +} + +// NullBool represents an bool that may be null. NullBool implements the Scanner +// and Encoder interfaces so it may be used both as an argument to Query[Row] +// and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +func (n *NullBool) Scan(vr *ValueReader) error { + if vr.Type().DataType != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + n.Bool = decodeBool(vr) + return vr.Err() +} + +func (n NullBool) FormatCode() int16 { return BinaryFormatCode } + +func (n NullBool) Encode(w *WriteBuf, oid Oid) error { + if oid != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeBool(w, oid, n.Bool) +} + +// NullTime represents an time.Time that may be null. NullTime implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL +// types timestamptz, timestamp, and date. +// +// If Valid is false then the value is NULL. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +func (n *NullTime) Scan(vr *ValueReader) error { + oid := vr.Type().DataType + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { + return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Time, n.Valid = time.Time{}, false + return nil + } + + n.Valid = true + switch oid { + case TimestampTzOid: + n.Time = decodeTimestampTz(vr) + case TimestampOid: + n.Time = decodeTimestamp(vr) + case DateOid: + n.Time = decodeDate(vr) + } + + return vr.Err() +} + +func (n NullTime) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTime) Encode(w *WriteBuf, oid Oid) error { + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { + return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTime(w, oid, n.Time) +} + +// Hstore represents an hstore column. It does not support a null column or null +// key values (use NullHstore for this). Hstore implements the Scanner and +// Encoder interfaces so it may be used both as an argument to Query[Row] and a +// destination for Scan. +type Hstore map[string]string + +func (h *Hstore) Scan(vr *ValueReader) error { + //oid for hstore not standardized, so we check its type name + if vr.Type().DataTypeName != "hstore" { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) + return nil + } + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) + return nil + } + + switch vr.Type().FormatCode { + case TextFormatCode: + m, err := parseHstoreToMap(vr.ReadString(vr.Len())) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) + return nil + } + hm := Hstore(m) + *h = hm + return nil + case BinaryFormatCode: + vr.Fatal(ProtocolError("Can't decode binary hstore")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } +} + +func (h Hstore) FormatCode() int16 { return TextFormatCode } + +func (h Hstore) Encode(w *WriteBuf, oid Oid) error { + var buf bytes.Buffer + + i := 0 + for k, v := range h { + i++ + ks := strings.Replace(k, `\`, `\\`, -1) + ks = strings.Replace(ks, `"`, `\"`, -1) + vs := strings.Replace(v, `\`, `\\`, -1) + vs = strings.Replace(vs, `"`, `\"`, -1) + buf.WriteString(`"`) + buf.WriteString(ks) + buf.WriteString(`"=>"`) + buf.WriteString(vs) + buf.WriteString(`"`) + if i < len(h) { + buf.WriteString(", ") + } + } + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + return nil +} + +// NullHstore represents an hstore column that can be null or have null values +// associated with its keys. NullHstore implements the Scanner and Encoder +// interfaces so it may be used both as an argument to Query[Row] and a +// destination for Scan. +// +// If Valid is false, then the value of the entire hstore column is NULL +// If any of the NullString values in Store has Valid set to false, the key +// appears in the hstore column, but its value is explicitly set to NULL. +type NullHstore struct { + Hstore map[string]NullString + Valid bool +} + +func (h *NullHstore) Scan(vr *ValueReader) error { + //oid for hstore not standardized, so we check its type name + if vr.Type().DataTypeName != "hstore" { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) + return nil + } + + if vr.Len() == -1 { + h.Valid = false + return nil + } + + switch vr.Type().FormatCode { + case TextFormatCode: + store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) + return nil + } + h.Valid = true + h.Hstore = store + return nil + case BinaryFormatCode: + vr.Fatal(ProtocolError("Can't decode binary hstore")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } +} + +func (h NullHstore) FormatCode() int16 { return TextFormatCode } + +func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { + var buf bytes.Buffer + + if !h.Valid { + w.WriteInt32(-1) + return nil + } + + i := 0 + for k, v := range h.Hstore { + i++ + ks := strings.Replace(k, `\`, `\\`, -1) + ks = strings.Replace(ks, `"`, `\"`, -1) + if v.Valid { + vs := strings.Replace(v.String, `\`, `\\`, -1) + vs = strings.Replace(vs, `"`, `\"`, -1) + buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) + } else { + buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) + } + if i < len(h.Hstore) { + buf.WriteString(", ") + } + } + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + return nil +} + +// Encode encodes arg into wbuf as the type oid. This allows implementations +// of the Encoder interface to delegate the actual work of encoding to the +// built-in functionality. +func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { + if arg == nil { + wbuf.WriteInt32(-1) + return nil + } + + switch arg := arg.(type) { + case Encoder: + return arg.Encode(wbuf, oid) + case driver.Valuer: + v, err := arg.Value() + if err != nil { + return err + } + return Encode(wbuf, oid, v) + case string: + return encodeString(wbuf, oid, arg) + case []AclItem: + return encodeAclItemSlice(wbuf, oid, arg) + case []byte: + return encodeByteSlice(wbuf, oid, arg) + case [][]byte: + return encodeByteSliceSlice(wbuf, oid, arg) + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + wbuf.WriteInt32(-1) + return nil + } + arg = refVal.Elem().Interface() + return Encode(wbuf, oid, arg) + } + + if oid == JsonOid { + return encodeJSON(wbuf, oid, arg) + } + if oid == JsonbOid { + return encodeJSONB(wbuf, oid, arg) + } + + switch arg := arg.(type) { + case []string: + return encodeStringSlice(wbuf, oid, arg) + case bool: + return encodeBool(wbuf, oid, arg) + case []bool: + return encodeBoolSlice(wbuf, oid, arg) + case int: + return encodeInt(wbuf, oid, arg) + case uint: + return encodeUInt(wbuf, oid, arg) + case Char: + return encodeChar(wbuf, oid, arg) + case AclItem: + // The aclitem data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) + case Name: + // The name data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) + case int8: + return encodeInt8(wbuf, oid, arg) + case uint8: + return encodeUInt8(wbuf, oid, arg) + case int16: + return encodeInt16(wbuf, oid, arg) + case []int16: + return encodeInt16Slice(wbuf, oid, arg) + case uint16: + return encodeUInt16(wbuf, oid, arg) + case []uint16: + return encodeUInt16Slice(wbuf, oid, arg) + case int32: + return encodeInt32(wbuf, oid, arg) + case []int32: + return encodeInt32Slice(wbuf, oid, arg) + case uint32: + return encodeUInt32(wbuf, oid, arg) + case []uint32: + return encodeUInt32Slice(wbuf, oid, arg) + case int64: + return encodeInt64(wbuf, oid, arg) + case []int64: + return encodeInt64Slice(wbuf, oid, arg) + case uint64: + return encodeUInt64(wbuf, oid, arg) + case []uint64: + return encodeUInt64Slice(wbuf, oid, arg) + case float32: + return encodeFloat32(wbuf, oid, arg) + case []float32: + return encodeFloat32Slice(wbuf, oid, arg) + case float64: + return encodeFloat64(wbuf, oid, arg) + case []float64: + return encodeFloat64Slice(wbuf, oid, arg) + case time.Time: + return encodeTime(wbuf, oid, arg) + case []time.Time: + return encodeTimeSlice(wbuf, oid, arg) + case net.IP: + return encodeIP(wbuf, oid, arg) + case []net.IP: + return encodeIPSlice(wbuf, oid, arg) + case net.IPNet: + return encodeIPNet(wbuf, oid, arg) + case []net.IPNet: + return encodeIPNetSlice(wbuf, oid, arg) + case Oid: + return encodeOid(wbuf, oid, arg) + case Xid: + return encodeXid(wbuf, oid, arg) + case Cid: + return encodeCid(wbuf, oid, arg) + default: + if strippedArg, ok := stripNamedType(&refVal); ok { + return Encode(wbuf, oid, strippedArg) + } + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + } +} + +func stripNamedType(val *reflect.Value) (interface{}, bool) { + switch val.Kind() { + case reflect.Int: + return int(val.Int()), true + case reflect.Int8: + return int8(val.Int()), true + case reflect.Int16: + return int16(val.Int()), true + case reflect.Int32: + return int32(val.Int()), true + case reflect.Int64: + return int64(val.Int()), true + case reflect.Uint: + return uint(val.Uint()), true + case reflect.Uint8: + return uint8(val.Uint()), true + case reflect.Uint16: + return uint16(val.Uint()), true + case reflect.Uint32: + return uint32(val.Uint()), true + case reflect.Uint64: + return uint64(val.Uint()), true + case reflect.String: + return val.String(), true + } + + return nil, false +} + +// Decode decodes from vr into d. d must be a pointer. This allows +// implementations of the Decoder interface to delegate the actual work of +// decoding to the built-in functionality. +func Decode(vr *ValueReader, d interface{}) error { + switch v := d.(type) { + case *bool: + *v = decodeBool(vr) + case *int: + n := decodeInt(vr) + if n < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", n) + } else if n > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", n) + } + *v = int(n) + case *int8: + n := decodeInt(vr) + if n < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", n) + } else if n > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", n) + } + *v = int8(n) + case *int16: + n := decodeInt(vr) + if n < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", n) + } else if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", n) + } + *v = int16(n) + case *int32: + n := decodeInt(vr) + if n < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", n) + } else if n > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", n) + } + *v = int32(n) + case *int64: + n := decodeInt(vr) + if n < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", n) + } else if n > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", n) + } + *v = int64(n) + case *uint: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for uint8", n) + } else if maxInt == math.MaxInt32 && n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint", n) + } + *v = uint(n) + case *uint8: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for uint8", n) + } else if n > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", n) + } + *v = uint8(n) + case *uint16: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for uint16", n) + } else if n > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", n) + } + *v = uint16(n) + case *uint32: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for uint32", n) + } else if n > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", n) + } + *v = uint32(n) + case *uint64: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for uint64", n) + } + *v = uint64(n) + case *Char: + *v = decodeChar(vr) + case *AclItem: + // aclitem goes over the wire just like text + *v = AclItem(decodeText(vr)) + case *Name: + // name goes over the wire just like text + *v = Name(decodeText(vr)) + case *Oid: + *v = decodeOid(vr) + case *Xid: + *v = decodeXid(vr) + case *Tid: + *v = decodeTid(vr) + case *Cid: + *v = decodeCid(vr) + case *string: + *v = decodeText(vr) + case *float32: + *v = decodeFloat4(vr) + case *float64: + *v = decodeFloat8(vr) + case *[]AclItem: + *v = decodeAclItemArray(vr) + case *[]bool: + *v = decodeBoolArray(vr) + case *[]int16: + *v = decodeInt2Array(vr) + case *[]uint16: + *v = decodeInt2ArrayToUInt(vr) + case *[]int32: + *v = decodeInt4Array(vr) + case *[]uint32: + *v = decodeInt4ArrayToUInt(vr) + case *[]int64: + *v = decodeInt8Array(vr) + case *[]uint64: + *v = decodeInt8ArrayToUInt(vr) + case *[]float32: + *v = decodeFloat4Array(vr) + case *[]float64: + *v = decodeFloat8Array(vr) + case *[]string: + *v = decodeTextArray(vr) + case *[]time.Time: + *v = decodeTimestampArray(vr) + case *[][]byte: + *v = decodeByteaArray(vr) + case *[]interface{}: + *v = decodeRecord(vr) + case *time.Time: + switch vr.Type().DataType { + case DateOid: + *v = decodeDate(vr) + case TimestampTzOid: + *v = decodeTimestampTz(vr) + case TimestampOid: + *v = decodeTimestamp(vr) + default: + return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) + } + case *net.IP: + ipnet := decodeInet(vr) + if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("Cannot decode netmask into *net.IP") + } + *v = ipnet.IP + case *[]net.IP: + ipnets := decodeInetArray(vr) + ips := make([]net.IP, len(ipnets)) + for i, ipnet := range ipnets { + if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("Cannot decode netmask into *net.IP") + } + ips[i] = ipnet.IP + } + *v = ips + case *net.IPNet: + *v = decodeInet(vr) + case *[]net.IPNet: + *v = decodeInetArray(vr) + default: + if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if d is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + // -1 is a null value + if vr.Len() == -1 { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + return Decode(vr, d) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := decodeInt(vr) + if el.OverflowInt(n) { + return fmt.Errorf("Scan cannot decode %d into %T", n, d) + } + el.SetInt(n) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for %T", n, d) + } + if el.OverflowUint(uint64(n)) { + return fmt.Errorf("Scan cannot decode %d into %T", n, d) + } + el.SetUint(uint64(n)) + return nil + case reflect.String: + el.SetString(decodeText(vr)) + return nil + } + } + return fmt.Errorf("Scan cannot decode into %T", d) + } + + return nil +} + +func decodeBool(vr *ValueReader) bool { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into bool")) + return false + } + + if vr.Type().DataType != BoolOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) + return false + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return false + } + + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) + return false + } + + b := vr.ReadByte() + return b != 0 +} + +func encodeBool(w *WriteBuf, oid Oid, value bool) error { + if oid != BoolOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) + } + + w.WriteInt32(1) + + var n byte + if value { + n = 1 + } + + w.WriteByte(n) + + return nil +} + +func decodeInt(vr *ValueReader) int64 { + switch vr.Type().DataType { + case Int2Oid: + return int64(decodeInt2(vr)) + case Int4Oid: + return int64(decodeInt4(vr)) + case Int8Oid: + return int64(decodeInt8(vr)) + } + + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType))) + return 0 +} + +func decodeInt8(vr *ValueReader) int64 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int64")) + return 0 + } + + if vr.Type().DataType != Int8Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) + return 0 + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return 0 + } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt64() +} + +func decodeChar(vr *ValueReader) Char { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into char")) + return Char(0) + } + + if vr.Type().DataType != CharOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) + return Char(0) + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Char(0) + } + + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) + return Char(0) + } + + return Char(vr.ReadByte()) +} + +func decodeInt2(vr *ValueReader) int16 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + + if vr.Type().DataType != Int2Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) + return 0 + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return 0 + } + + if vr.Len() != 2 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt16() +} + +func encodeInt(w *WriteBuf, oid Oid, value int) error { + switch oid { + case Int2Oid: + if value < math.MinInt16 { + return fmt.Errorf("%d is less than min pg:int2", value) + } else if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than max pg:int2", value) + } + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + if value < math.MinInt32 { + return fmt.Errorf("%d is less than min pg:int4", value) + } else if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than max pg:int4", value) + } + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + if int64(value) <= int64(math.MaxInt64) { + w.WriteInt32(8) + w.WriteInt64(int64(value)) + } else { + return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) + } + default: + return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) + } + + return nil +} + +func encodeUInt(w *WriteBuf, oid Oid, value uint) error { + switch oid { + case Int2Oid: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than max pg:int2", value) + } + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than max pg:int4", value) + } + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64) + if int64(value) > int64(math.MaxInt64) { + return fmt.Errorf("%d is greater than max pg:int8", value) + } + w.WriteInt32(8) + w.WriteInt64(int64(value)) + + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) + } + + return nil +} + +func encodeChar(w *WriteBuf, oid Oid, value Char) error { + w.WriteInt32(1) + w.WriteByte(byte(value)) + return nil +} + +func encodeInt8(w *WriteBuf, oid Oid, value int8) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) + } + + return nil +} + +func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) + } + + return nil +} + +func encodeInt16(w *WriteBuf, oid Oid, value int16) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(value) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) + } + + return nil +} + +func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) + } + + return nil +} + +func encodeInt32(w *WriteBuf, oid Oid, value int32) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(value) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) + } + + return nil +} + +func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) + } + + return nil +} + +func encodeInt64(w *WriteBuf, oid Oid, value int64) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(value) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) + } + + return nil +} + +func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + + if value <= math.MaxInt64 { + w.WriteInt32(8) + w.WriteInt64(int64(value)) + } else { + return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64)) + } + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) + } + + return nil +} + +func decodeInt4(vr *ValueReader) int32 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int32")) + return 0 + } + + if vr.Type().DataType != Int4Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) + return 0 + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return 0 + } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt32() +} + +func decodeOid(vr *ValueReader) Oid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Oid")) + return Oid(0) + } + + if vr.Type().DataType != OidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) + return Oid(0) + } + + // Oid needs to decode text format because it is used in loadPgTypes + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Oid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Oid(0) + } + return Oid(vr.ReadInt32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Oid(0) + } +} + +func encodeOid(w *WriteBuf, oid Oid, value Oid) error { + if oid != OidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + +func decodeXid(vr *ValueReader) Xid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Xid")) + return Xid(0) + } + + if vr.Type().DataType != XidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType))) + return Xid(0) + } + + // Unlikely Xid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Xid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Xid(0) + } + return Xid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Xid(0) + } +} + +func encodeXid(w *WriteBuf, oid Oid, value Xid) error { + if oid != XidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + +func decodeCid(vr *ValueReader) Cid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Cid")) + return Cid(0) + } + + if vr.Type().DataType != CidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) + return Cid(0) + } + + // Unlikely Cid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Cid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Cid(0) + } + return Cid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Cid(0) + } +} + +func encodeCid(w *WriteBuf, oid Oid, value Cid) error { + if oid != CidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + +// Note that we do not match negative numbers, because neither the +// BlockNumber nor OffsetNumber of a Tid can be negative. +var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) + +func decodeTid(vr *ValueReader) Tid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Tid")) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + if vr.Type().DataType != TidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + // Unlikely Tid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + + match := tidRegexp.FindStringSubmatch(s) + if match == nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + blockNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) + } + + offsetNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) + } + return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} + case BinaryFormatCode: + if vr.Len() != 6 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } +} + +func encodeTid(w *WriteBuf, oid Oid, value Tid) error { + if oid != TidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) + } + + w.WriteInt32(6) + w.WriteUint32(value.BlockNumber) + w.WriteUint16(value.OffsetNumber) + + return nil +} + +func decodeFloat4(vr *ValueReader) float32 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into float32")) + return 0 + } + + if vr.Type().DataType != Float4Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) + return 0 + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return 0 + } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) + return 0 + } + + i := vr.ReadInt32() + return math.Float32frombits(uint32(i)) +} + +func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { + switch oid { + case Float4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(math.Float32bits(value))) + case Float8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(math.Float64bits(float64(value)))) + default: + return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) + } + + return nil +} + +func decodeFloat8(vr *ValueReader) float64 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into float64")) + return 0 + } + + if vr.Type().DataType != Float8Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) + return 0 + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return 0 + } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) + return 0 + } + + i := vr.ReadInt64() + return math.Float64frombits(uint64(i)) +} + +func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { + switch oid { + case Float8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(math.Float64bits(value))) + default: + return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) + } + + return nil +} + +func decodeText(vr *ValueReader) string { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into string")) + return "" + } + + return vr.ReadString(vr.Len()) +} + +func encodeString(w *WriteBuf, oid Oid, value string) error { + w.WriteInt32(int32(len(value))) + w.WriteBytes([]byte(value)) + return nil +} + +func decodeBytea(vr *ValueReader) []byte { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != ByteaOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + return vr.ReadBytes(vr.Len()) +} + +func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { + w.WriteInt32(int32(len(value))) + w.WriteBytes(value) + + return nil +} + +func decodeJSON(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JsonOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) + } + + bytes := vr.ReadBytes(vr.Len()) + err := json.Unmarshal(bytes, d) + if err != nil { + vr.Fatal(err) + } + return err +} + +func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonOid { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + + s, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("Failed to encode json from type: %T", value) + } + + w.WriteInt32(int32(len(s))) + w.WriteBytes(s) + + return nil +} + +func decodeJSONB(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JsonbOid { + err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) + vr.Fatal(err) + return err + } + + bytes := vr.ReadBytes(vr.Len()) + if vr.Type().FormatCode == BinaryFormatCode { + if bytes[0] != 1 { + err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) + vr.Fatal(err) + return err + } + bytes = bytes[1:] + } + + err := json.Unmarshal(bytes, d) + if err != nil { + vr.Fatal(err) + } + return err +} + +func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonbOid { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + + s, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("Failed to encode json from type: %T", value) + } + + w.WriteInt32(int32(len(s) + 1)) + w.WriteByte(1) // JSONB format header + w.WriteBytes(s) + + return nil +} + +func decodeDate(vr *ValueReader) time.Time { + var zeroTime time.Time + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return zeroTime + } + + if vr.Type().DataType != DateOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) + return zeroTime + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return zeroTime + } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) + } + dayOffset := vr.ReadInt32() + return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) +} + +func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { + switch oid { + case DateOid: + tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch := secSinceDateEpoch / 86400 + + w.WriteInt32(4) + w.WriteInt32(int32(daysSinceDateEpoch)) + + return nil + case TimestampTzOid, TimestampOid: + microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + + w.WriteInt32(8) + w.WriteInt64(microsecSinceY2K) + + return nil + default: + return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) + } +} + +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +func decodeTimestampTz(vr *ValueReader) time.Time { + var zeroTime time.Time + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return zeroTime + } + + if vr.Type().DataType != TimestampTzOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) + return zeroTime + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return zeroTime + } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) + return zeroTime + } + + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) +} + +func decodeTimestamp(vr *ValueReader) time.Time { + var zeroTime time.Time + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into timestamp")) + return zeroTime + } + + if vr.Type().DataType != TimestampOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) + return zeroTime + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return zeroTime + } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) + return zeroTime + } + + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) +} + +func decodeInet(vr *ValueReader) net.IPNet { + var zero net.IPNet + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into net.IPNet")) + return zero + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return zero + } + + pgType := vr.Type() + if pgType.DataType != InetOid && pgType.DataType != CidrOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) + return zero + } + if vr.Len() != 8 && vr.Len() != 20 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len()))) + return zero + } + + vr.ReadByte() // ignore family + bits := vr.ReadByte() + vr.ReadByte() // ignore is_cidr + addressLength := vr.ReadByte() + + var ipnet net.IPNet + ipnet.IP = vr.ReadBytes(int32(addressLength)) + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + return ipnet +} + +func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { + if oid != InetOid && oid != CidrOid { + return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid) + } + + var size int32 + var family byte + switch len(value.IP) { + case net.IPv4len: + size = 8 + family = *w.conn.pgsqlAfInet + case net.IPv6len: + size = 20 + family = *w.conn.pgsqlAfInet6 + default: + return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) + } + + w.WriteInt32(size) + w.WriteByte(family) + ones, _ := value.Mask.Size() + w.WriteByte(byte(ones)) + w.WriteByte(0) // is_cidr is ignored on server + w.WriteByte(byte(len(value.IP))) + w.WriteBytes(value.IP) + + return nil +} + +func encodeIP(w *WriteBuf, oid Oid, value net.IP) error { + if oid != InetOid && oid != CidrOid { + return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid) + } + + var ipnet net.IPNet + ipnet.IP = value + bitCount := len(value) * 8 + ipnet.Mask = net.CIDRMask(bitCount, bitCount) + return encodeIPNet(w, oid, ipnet) +} + +func decodeRecord(vr *ValueReader) []interface{} { + if vr.Len() == -1 { + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + if vr.Type().DataType != RecordOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) + return nil + } + + valueCount := vr.ReadInt32() + record := make([]interface{}, 0, int(valueCount)) + + for i := int32(0); i < valueCount; i++ { + fd := FieldDescription{FormatCode: BinaryFormatCode} + fieldVR := ValueReader{mr: vr.mr, fd: &fd} + fd.DataType = vr.ReadOid() + fieldVR.valueBytesRemaining = vr.ReadInt32() + vr.valueBytesRemaining -= fieldVR.valueBytesRemaining + + switch fd.DataType { + case BoolOid: + record = append(record, decodeBool(&fieldVR)) + case ByteaOid: + record = append(record, decodeBytea(&fieldVR)) + case Int8Oid: + record = append(record, decodeInt8(&fieldVR)) + case Int2Oid: + record = append(record, decodeInt2(&fieldVR)) + case Int4Oid: + record = append(record, decodeInt4(&fieldVR)) + case OidOid: + record = append(record, decodeOid(&fieldVR)) + case Float4Oid: + record = append(record, decodeFloat4(&fieldVR)) + case Float8Oid: + record = append(record, decodeFloat8(&fieldVR)) + case DateOid: + record = append(record, decodeDate(&fieldVR)) + case TimestampTzOid: + record = append(record, decodeTimestampTz(&fieldVR)) + case TimestampOid: + record = append(record, decodeTimestamp(&fieldVR)) + case InetOid, CidrOid: + record = append(record, decodeInet(&fieldVR)) + case TextOid, VarcharOid, UnknownOid: + record = append(record, decodeText(&fieldVR)) + default: + vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) + return nil + } + + // Consume any remaining data + if fieldVR.Len() > 0 { + fieldVR.ReadBytes(fieldVR.Len()) + } + + if fieldVR.Err() != nil { + vr.Fatal(fieldVR.Err()) + return nil + } + } + + return record +} + +func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { + numDims := vr.ReadInt32() + if numDims > 1 { + return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims)) + } + + vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care + vr.ReadInt32() // element oid + + if numDims == 0 { + return 0, nil + } + + length = vr.ReadInt32() + + idxFirstElem := vr.ReadInt32() + if idxFirstElem != 1 { + return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem)) + } + + return length, nil +} + +func decodeBoolArray(vr *ValueReader) []bool { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != BoolArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]bool, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 1: + if vr.ReadByte() == 1 { + a[i] = true + } + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error { + if oid != BoolArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid) + } + + encodeArrayHeader(w, BoolOid, len(slice), 5) + for _, v := range slice { + w.WriteInt32(1) + var b byte + if v { + b = 1 + } + w.WriteByte(b) + } + + return nil +} + +func decodeByteaArray(vr *ValueReader) [][]byte { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != ByteaArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([][]byte, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + a[i] = vr.ReadBytes(elSize) + } + } + + return a +} + +func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error { + if oid != ByteaArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid) + } + + size := 20 // array header size + for _, el := range value { + size += 4 + len(el) + } + + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(ByteaOid)) // type of elements + w.WriteInt32(int32(len(value))) // number of elements + w.WriteInt32(1) // index of first element + + for _, el := range value { + encodeByteSlice(w, ByteaOid, el) + } + + return nil +} + +func decodeInt2Array(vr *ValueReader) []int16 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int2ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int16, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 2: + a[i] = vr.ReadInt16() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) + return nil + } + } + + return a +} + +func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int2ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint16, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 2: + tmp := vr.ReadInt16() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp))) + return nil + } + a[i] = uint16(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { + if oid != Int2ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) + } + + encodeArrayHeader(w, Int2Oid, len(slice), 6) + for _, v := range slice { + w.WriteInt32(2) + w.WriteInt16(v) + } + + return nil +} + +func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error { + if oid != Int2ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) + } + + encodeArrayHeader(w, Int2Oid, len(slice), 6) + for _, v := range slice { + if v <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(v)) + } else { + return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16) + } + } + + return nil +} + +func decodeInt4Array(vr *ValueReader) []int32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + a[i] = vr.ReadInt32() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) + return nil + } + } + + return a +} + +func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + tmp := vr.ReadInt32() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp))) + return nil + } + a[i] = uint32(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { + if oid != Int4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) + } + + encodeArrayHeader(w, Int4Oid, len(slice), 8) + for _, v := range slice { + w.WriteInt32(4) + w.WriteInt32(v) + } + + return nil +} + +func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error { + if oid != Int4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid) + } + + encodeArrayHeader(w, Int4Oid, len(slice), 8) + for _, v := range slice { + if v <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(v)) + } else { + return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32) + } + } + + return nil +} + +func decodeInt8Array(vr *ValueReader) []int64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]int64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + a[i] = vr.ReadInt64() + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) + return nil + } + } + + return a +} + +func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Int8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]uint64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + tmp := vr.ReadInt64() + if tmp < 0 { + vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp))) + return nil + } + a[i] = uint64(tmp) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { + if oid != Int8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) + } + + encodeArrayHeader(w, Int8Oid, len(slice), 12) + for _, v := range slice { + w.WriteInt32(8) + w.WriteInt64(v) + } + + return nil +} + +func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error { + if oid != Int8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid) + } + + encodeArrayHeader(w, Int8Oid, len(slice), 12) + for _, v := range slice { + if v <= math.MaxInt64 { + w.WriteInt32(8) + w.WriteInt64(int64(v)) + } else { + return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64)) + } + } + + return nil +} + +func decodeFloat4Array(vr *ValueReader) []float32 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Float4ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]float32, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 4: + n := vr.ReadInt32() + a[i] = math.Float32frombits(uint32(n)) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error { + if oid != Float4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid) + } + + encodeArrayHeader(w, Float4Oid, len(slice), 8) + for _, v := range slice { + w.WriteInt32(4) + w.WriteInt32(int32(math.Float32bits(v))) + } + + return nil +} + +func decodeFloat8Array(vr *ValueReader) []float64 { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != Float8ArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]float64, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + n := vr.ReadInt64() + a[i] = math.Float64frombits(uint64(n)) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error { + if oid != Float8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid) + } + + encodeArrayHeader(w, Float8Oid, len(slice), 12) + for _, v := range slice { + w.WriteInt32(8) + w.WriteInt64(int64(math.Float64bits(v))) + } + + return nil +} + +func decodeTextArray(vr *ValueReader) []string { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]string, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + if elSize == -1 { + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + } + + a[i] = vr.ReadString(elSize) + } + + return a +} + +// escapeAclItem escapes an AclItem before it is added to +// its aclitem[] string representation. The PostgreSQL aclitem +// datatype itself can need escapes because it follows the +// formatting rules of SQL identifiers. Think of this function +// as escaping the escapes, so that PostgreSQL's array parser +// will do the right thing. +func escapeAclItem(acl string) (string, error) { + var escapedAclItem bytes.Buffer + reader := strings.NewReader(acl) + for { + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return escapedAclItem.String(), nil + } + // This error was not expected + return "", err + } + if needsEscape(rn) { + escapedAclItem.WriteRune('\\') + } + escapedAclItem.WriteRune(rn) + } +} + +// needsEscape determines whether or not a rune needs escaping +// before being placed in the textual representation of an +// aclitem[] array. +func needsEscape(rn rune) bool { + return rn == '\\' || rn == ',' || rn == '"' || rn == '}' +} + +// encodeAclItemSlice encodes a slice of AclItems in +// their textual represention for PostgreSQL. +func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { + strs := make([]string, len(aclitems)) + var escapedAclItem string + var err error + for i := range strs { + escapedAclItem, err = escapeAclItem(string(aclitems[i])) + if err != nil { + return err + } + strs[i] = string(escapedAclItem) + } + + var buf bytes.Buffer + buf.WriteRune('{') + buf.WriteString(strings.Join(strs, ",")) + buf.WriteRune('}') + str := buf.String() + w.WriteInt32(int32(len(str))) + w.WriteBytes([]byte(str)) + return nil +} + +// parseAclItemArray parses the textual representation +// of the aclitem[] type. The textual representation is chosen because +// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). +// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +// for formatting notes. +func parseAclItemArray(arr string) ([]AclItem, error) { + reader := strings.NewReader(arr) + // Difficult to guess a performant initial capacity for a slice of + // aclitems, but let's go with 5. + aclItems := make([]AclItem, 0, 5) + // A single value + aclItem := AclItem("") + for { + // Grab the first/next/last rune to see if we are dealing with a + // quoted value, an unquoted value, or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return aclItems, nil + } + // This error was not expected + return nil, err + } + + if rn == '"' { + // Discard the opening quote of the quoted value. + aclItem, err = parseQuotedAclItem(reader) + } else { + // We have just read the first rune of an unquoted (bare) value; + // put it back so that ParseBareValue can read it. + err := reader.UnreadRune() + if err != nil { + return nil, err + } + aclItem, err = parseBareAclItem(reader) + } + + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error.. + aclItems = append(aclItems, aclItem) + return aclItems, nil + } + // This error was not expected. + return nil, err + } + aclItems = append(aclItems, aclItem) + } +} + +// parseBareAclItem parses a bare (unquoted) aclitem from reader +func parseBareAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + // (io.EOF marks the end of a bare aclitem at the end of a string) + return AclItem(aclItem.String()), err + } + if rn == ',' { + // A comma marks the end of a bare aclitem. + return AclItem(aclItem.String()), nil + } else { + aclItem.WriteRune(rn) + } + } +} + +// parseQuotedAclItem parses an aclitem which is in double quotes from reader +func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, escaped, err := readPossiblyEscapedRune(reader) + if err != nil { + if err == io.EOF { + // Even when it is the last value, the final rune of + // a quoted aclitem should be the final closing quote, not io.EOF. + return AclItem(""), fmt.Errorf("unexpected end of quoted value") + } + // Return the read aclitem in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if !escaped && rn == '"' { + // An unescaped double quote marks the end of a quoted value. + // The next rune should either be a comma or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if rn != ',' { + return AclItem(""), fmt.Errorf("unexpected rune after quoted value") + } + return AclItem(aclItem.String()), nil + } + aclItem.WriteRune(rn) + } +} + +// Returns the next rune from r, unless it is a backslash; +// in that case, it returns the rune after the backslash. The second +// return value tells us whether or not the rune was +// preceeded by a backslash (escaped). +func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { + rn, _, err := reader.ReadRune() + if err != nil { + return 0, false, err + } + if rn == '\\' { + // Discard the backslash and read the next rune. + rn, _, err = reader.ReadRune() + if err != nil { + return 0, false, err + } + return rn, true, nil + } + return rn, false, nil +} + +func decodeAclItemArray(vr *ValueReader) []AclItem { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) + return nil + } + + str := vr.ReadString(vr.Len()) + + // Short-circuit empty array. + if str == "{}" { + return []AclItem{} + } + + // Remove the '{' at the front and the '}' at the end, + // so that parseAclItemArray doesn't have to deal with them. + str = str[1 : len(str)-1] + aclItems, err := parseAclItemArray(str) + if err != nil { + vr.Fatal(ProtocolError(err.Error())) + return nil + } + return aclItems +} + +func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { + var elOid Oid + switch oid { + case VarcharArrayOid: + elOid = VarcharOid + case TextArrayOid: + elOid = TextOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) + } + + var totalStringSize int + for _, v := range slice { + totalStringSize += len(v) + } + + size := 20 + len(slice)*4 + totalStringSize + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, v := range slice { + w.WriteInt32(int32(len(v))) + w.WriteBytes([]byte(v)) + } + + return nil +} + +func decodeTimestampArray(vr *ValueReader) []time.Time { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]time.Time, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + switch elSize { + case 8: + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + case -1: + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) + return nil + } + } + + return a +} + +func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error { + var elOid Oid + switch oid { + case TimestampArrayOid: + elOid = TimestampOid + case TimestampTzArrayOid: + elOid = TimestampTzOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid) + } + + encodeArrayHeader(w, int(elOid), len(slice), 12) + for _, t := range slice { + w.WriteInt32(8) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + w.WriteInt64(microsecSinceY2K) + } + + return nil +} + +func decodeInetArray(vr *ValueReader) []net.IPNet { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]net.IPNet, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + if elSize == -1 { + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + } + + vr.ReadByte() // ignore family + bits := vr.ReadByte() + vr.ReadByte() // ignore is_cidr + addressLength := vr.ReadByte() + + var ipnet net.IPNet + ipnet.IP = vr.ReadBytes(int32(addressLength)) + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + a[i] = ipnet + } + + return a +} + +func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error { + var elOid Oid + switch oid { + case InetArrayOid: + elOid = InetOid + case CidrArrayOid: + elOid = CidrOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) + } + + size := int32(20) // array header size + for _, ipnet := range slice { + size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes + } + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, ipnet := range slice { + encodeIPNet(w, elOid, ipnet) + } + + return nil +} + +func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error { + var elOid Oid + switch oid { + case InetArrayOid: + elOid = InetOid + case CidrArrayOid: + elOid = CidrOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) + } + + size := int32(20) // array header size + for _, ip := range slice { + size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes + } + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, ip := range slice { + encodeIP(w, elOid, ip) + } + + return nil +} + +func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { + w.WriteInt32(int32(20 + length*sizePerItem)) + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(oid)) // type of elements + w.WriteInt32(int32(length)) // number of elements + w.WriteInt32(1) // index of first element +} diff --git a/vendor/github.com/jackc/pgx/values_test.go b/vendor/github.com/jackc/pgx/values_test.go new file mode 100644 index 0000000..42d5bd3 --- /dev/null +++ b/vendor/github.com/jackc/pgx/values_test.go @@ -0,0 +1,1183 @@ +package pgx_test + +import ( + "bytes" + "net" + "reflect" + "strings" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestDateTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + dates := []time.Time{ + time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.Local), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local), + } + + for _, actualDate := range dates { + var d time.Time + + err := conn.QueryRow("select $1::date", actualDate).Scan(&d) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } + if !actualDate.Equal(d) { + t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) + } + } +} + +func TestTimestampTzTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) + + var outputTime time.Time + + err := conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if !inputTime.Equal(outputTime) { + t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) + } + + err = conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if !inputTime.Equal(outputTime) { + t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) + } +} + +func TestJsonAndJsonbTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + + for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { + pgtype := conn.PgTypes[oid] + pgtype.DefaultFormat = format + conn.PgTypes[oid] = pgtype + + typename := conn.PgTypes[oid].Name + + testJsonString(t, conn, typename, format) + testJsonStringPointer(t, conn, typename, format) + testJsonSingleLevelStringMap(t, conn, typename, format) + testJsonNestedMap(t, conn, typename, format) + testJsonStringArray(t, conn, typename, format) + testJsonInt64Array(t, conn, typename, format) + testJsonInt16ArrayFailureDueToOverflow(t, conn, typename, format) + testJsonStruct(t, conn, typename, format) + } + } +} + +func testJsonString(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := `{"key": "value"}` + expectedOutput := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + return + } + + if !reflect.DeepEqual(expectedOutput, output) { + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) + return + } +} + +func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := `{"key": "value"}` + expectedOutput := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + return + } + + if !reflect.DeepEqual(expectedOutput, output) { + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) + return + } +} + +func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + return + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) + return + } +} + +func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := map[string]interface{}{ + "name": "Uncanny", + "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []interface{}{"phone", "key"}, + } + var output map[string]interface{} + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + return + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) + return + } +} + +func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := []string{"foo", "bar", "baz"} + var output []string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) + } +} + +func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := []int64{1, 2, 234432} + var output []int64 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) + } +} + +func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { + input := []int{1, 2, 234432} + var output []int16 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { + t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) + } +} + +func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { + type person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + input := person{ + Name: "John", + Age: 42, + } + + var output person + + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) + } +} + +func mustParseCIDR(t *testing.T, s string) net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return *ipnet +} + +func TestStringToNotTextTypeTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + input := "01086ee0-4963-4e35-9116-30c173a8d0bd" + + var output string + err := conn.QueryRow("select $1::uuid", input).Scan(&output) + if err != nil { + t.Fatal(err) + } + if input != output { + t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) + } + + err = conn.QueryRow("select $1::uuid", &input).Scan(&output) + if err != nil { + t.Fatal(err) + } + if input != output { + t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) + } +} + +func TestInetCidrTranscodeIPNet(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCIDR(t, "::/128")}, + {"select $1::inet", mustParseCIDR(t, "::/0")}, + {"select $1::inet", mustParseCIDR(t, "::1/128")}, + {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCIDR(t, "::/128")}, + {"select $1::cidr", mustParseCIDR(t, "::/0")}, + {"select $1::cidr", mustParseCIDR(t, "::1/128")}, + {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + } + + for i, tt := range tests { + var actual net.IPNet + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if actual.String() != tt.value.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestInetCidrTranscodeIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value net.IP + }{ + {"select $1::inet", net.ParseIP("0.0.0.0")}, + {"select $1::inet", net.ParseIP("127.0.0.1")}, + {"select $1::inet", net.ParseIP("12.34.56.0")}, + {"select $1::inet", net.ParseIP("255.255.255.255")}, + {"select $1::inet", net.ParseIP("::1")}, + {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, + {"select $1::cidr", net.ParseIP("0.0.0.0")}, + {"select $1::cidr", net.ParseIP("127.0.0.1")}, + {"select $1::cidr", net.ParseIP("12.34.56.0")}, + {"select $1::cidr", net.ParseIP("255.255.255.255")}, + {"select $1::cidr", net.ParseIP("::1")}, + {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, + } + + for i, tt := range tests { + var actual net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !actual.Equal(tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + + failTests := []struct { + sql string + value net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + } + for i, tt := range failTests { + var actual net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if !strings.Contains(err.Error(), "Cannot decode netmask") { + t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + ensureConnValid(t, conn) + } +} + +func TestInetCidrArrayTranscodeIPNet(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value []net.IPNet + }{ + { + "select $1::inet[]", + []net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, + }, + { + "select $1::cidr[]", + []net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, + }, + } + + for i, tt := range tests { + var actual []net.IPNet + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestInetCidrArrayTranscodeIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value []net.IP + }{ + { + "select $1::inet[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, + }, + { + "select $1::cidr[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, + }, + } + + for i, tt := range tests { + var actual []net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + + failTests := []struct { + sql string + value []net.IPNet + }{ + { + "select $1::inet[]", + []net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, + }, + { + "select $1::cidr[]", + []net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, + }, + } + + for i, tt := range failTests { + var actual []net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") { + t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + ensureConnValid(t, conn) + } +} + +func TestInetCidrTranscodeWithJustIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value string + }{ + {"select $1::inet", "0.0.0.0/32"}, + {"select $1::inet", "127.0.0.1/32"}, + {"select $1::inet", "12.34.56.0/32"}, + {"select $1::inet", "255.255.255.255/32"}, + {"select $1::inet", "::/128"}, + {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, + {"select $1::cidr", "0.0.0.0/32"}, + {"select $1::cidr", "127.0.0.1/32"}, + {"select $1::cidr", "12.34.56.0/32"}, + {"select $1::cidr", "255.255.255.255/32"}, + {"select $1::cidr", "::/128"}, + {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, + } + + for i, tt := range tests { + expected := mustParseCIDR(t, tt.value) + var actual net.IPNet + + err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if actual.String() != expected.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + +func TestNullX(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s pgx.NullString + i16 pgx.NullInt16 + i32 pgx.NullInt32 + c pgx.NullChar + a pgx.NullAclItem + n pgx.NullName + oid pgx.NullOid + xid pgx.NullXid + cid pgx.NullCid + tid pgx.NullTid + i64 pgx.NullInt64 + f32 pgx.NullFloat32 + f64 pgx.NullFloat64 + b pgx.NullBool + t pgx.NullTime + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, + {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, + {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}}, + {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, + {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, + {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 1, Valid: true}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 4294967295, Valid: true}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, + // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, + {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, + {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, + {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, + {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}}, + {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, + {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}}, + {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, + {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, + {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, + {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, + {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { + if !reflect.DeepEqual(query, scan) { + t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) + } +} + +func TestAclArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::aclitem[]" + var scan []pgx.AclItem + + tests := []struct { + query []pgx.AclItem + }{ + { + []pgx.AclItem{}, + }, + { + []pgx.AclItem{"=r/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, + }, + } + for i, tt := range tests { + err := conn.QueryRow(sql, tt.query).Scan(&scan) + if err != nil { + // t.Errorf(`%d. error reading array: %v`, i, err) + t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) + if pgerr, ok := err.(pgx.PgError); ok { + t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) + } + continue + } + assertAclItemSlicesEqual(t, tt.query, scan) + ensureConnValid(t, conn) + } +} + +func TestArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { + t.Errorf("failed to encode bool[]") + } + }, + }, + { + "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { + t.Errorf("failed to encode smallint[]") + } + }, + }, + { + "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { + t.Errorf("failed to encode smallint[]") + } + }, + }, + { + "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { + t.Errorf("failed to encode bigint[]") + } + }, + }, + { + "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { + t.Errorf("failed to encode bigint[]") + } + }, + }, + { + "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { + t.Errorf("failed to encode text[]") + } + }, + }, + { + "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { + t.Errorf("failed to encode time.Time[] to timestamp[]") + } + }, + }, + { + "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { + t.Errorf("failed to encode time.Time[] to timestamptz[]") + } + }, + }, + { + "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, + func(t *testing.T, query, scan interface{}) { + queryBytesSliceSlice := query.([][]byte) + scanBytesSliceSlice := *(scan.(*[][]byte)) + if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) + } + for i := range queryBytesSliceSlice { + qb := queryBytesSliceSlice[i] + sb := scanBytesSliceSlice[i] + if !bytes.Equal(qb, sb) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + } + } + }, + }, + } + + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`%d. error reading array: %v`, i, err) + continue + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + +type shortScanner struct{} + +func (*shortScanner) Scan(r *pgx.ValueReader) error { + r.ReadByte() + return nil +} + +func TestShortScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'") + if err != nil { + t.Error(err) + } + defer rows.Close() + + for rows.Next() { + var s1, s2 shortScanner + err = rows.Scan(&s1, &s2) + if err != nil { + t.Error(err) + } + } + + ensureConnValid(t, conn) +} + +func TestEmptyArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var val []string + + err := conn.QueryRow("select array[]::text[]").Scan(&val) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } + + var n, m int32 + + err = conn.QueryRow("select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } + if n != 1 { + t.Errorf("Expected n to be 1, but it was %d", n) + } + if m != 42 { + t.Errorf("Expected n to be 42, but it was %d", n) + } + + rows, err := conn.Query("select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") + if err != nil { + t.Errorf(`error retrieving rows with array: %v`, err) + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(&n, &val) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + } + + ensureConnValid(t, conn) +} + +func TestNullXMismatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s pgx.NullString + i16 pgx.NullInt16 + i32 pgx.NullInt32 + i64 pgx.NullInt64 + f32 pgx.NullFloat32 + f64 pgx.NullFloat64 + b pgx.NullBool + t pgx.NullTime + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"}, + {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"}, + {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 23"}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil || !strings.Contains(err.Error(), tt.err) { + t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err) + } + + ensureConnValid(t, conn) + } +} + +func TestPointerPointer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s *string + i16 *int16 + i32 *int32 + i64 *int64 + f32 *float32 + f64 *float64 + b *bool + t *time.Time + } + + var actual, zero, expected allTypes + + { + s := "foo" + expected.s = &s + i16 := int16(1) + expected.i16 = &i16 + i32 := int32(1) + expected.i32 = &i32 + i64 := int64(1) + expected.i64 = &i64 + f32 := float32(1.23) + expected.f32 = &f32 + f64 := float64(1.23) + expected.f64 = &f64 + b := true + expected.b = &b + t := time.Unix(123, 5000) + expected.t = &t + } + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, + {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, + {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, + {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, + {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, + {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, + {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + {"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestPointerPointerNonZero(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + f := "foo" + dest := &f + + err := conn.QueryRow("select $1::text", nil).Scan(&dest) + if err != nil { + t.Errorf("Unexpected failure scanning: %v", err) + } + if dest != nil { + t.Errorf("Expected dest to be nil, got %#v", dest) + } +} + +func TestEncodeTypeRename(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type _int int + inInt := _int(3) + var outInt _int + + type _int8 int8 + inInt8 := _int8(3) + var outInt8 _int8 + + type _int16 int16 + inInt16 := _int16(3) + var outInt16 _int16 + + type _int32 int32 + inInt32 := _int32(4) + var outInt32 _int32 + + type _int64 int64 + inInt64 := _int64(5) + var outInt64 _int64 + + type _uint uint + inUint := _uint(6) + var outUint _uint + + type _uint8 uint8 + inUint8 := _uint8(7) + var outUint8 _uint8 + + type _uint16 uint16 + inUint16 := _uint16(8) + var outUint16 _uint16 + + type _uint32 uint32 + inUint32 := _uint32(9) + var outUint32 _uint32 + + type _uint64 uint64 + inUint64 := _uint64(10) + var outUint64 _uint64 + + type _string string + inString := _string("foo") + var outString _string + + err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", + inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, + ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) + if err != nil { + t.Fatalf("Failed with type rename: %v", err) + } + + if inInt != outInt { + t.Errorf("int rename: expected %v, got %v", inInt, outInt) + } + + if inInt8 != outInt8 { + t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) + } + + if inInt16 != outInt16 { + t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) + } + + if inInt32 != outInt32 { + t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) + } + + if inInt64 != outInt64 { + t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) + } + + if inUint != outUint { + t.Errorf("uint rename: expected %v, got %v", inUint, outUint) + } + + if inUint8 != outUint8 { + t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) + } + + if inUint16 != outUint16 { + t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) + } + + if inUint32 != outUint32 { + t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) + } + + if inUint64 != outUint64 { + t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) + } + + if inString != outString { + t.Errorf("string rename: expected %v, got %v", inString, outString) + } + + ensureConnValid(t, conn) +} + +func TestRowDecode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + expected []interface{} + }{ + { + "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", + []interface{}{ + int32(1), + "cat", + time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), + }, + }, + } + + for i, tt := range tests { + var actual []interface{} + + err := conn.QueryRow(tt.sql).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) + continue + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} -- cgit v1.2.3