summaryrefslogtreecommitdiff
path: root/vendor/github.com/pires/go-proxyproto/protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pires/go-proxyproto/protocol.go')
-rw-r--r--vendor/github.com/pires/go-proxyproto/protocol.go147
1 files changed, 46 insertions, 101 deletions
diff --git a/vendor/github.com/pires/go-proxyproto/protocol.go b/vendor/github.com/pires/go-proxyproto/protocol.go
index 270b90d..658900a 100644
--- a/vendor/github.com/pires/go-proxyproto/protocol.go
+++ b/vendor/github.com/pires/go-proxyproto/protocol.go
@@ -2,8 +2,6 @@ package proxyproto
import (
"bufio"
- "errors"
- "fmt"
"io"
"net"
"sync"
@@ -11,17 +9,11 @@ import (
"time"
)
-var (
- // DefaultReadHeaderTimeout is how long header processing waits for header to
- // be read from the wire, if Listener.ReaderHeaderTimeout is not set.
- // It's kept as a global variable so to make it easier to find and override,
- // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
- DefaultReadHeaderTimeout = 10 * time.Second
-
- // ErrInvalidUpstream should be returned when an upstream connection address
- // is not trusted, and therefore is invalid.
- ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
-)
+// DefaultReadHeaderTimeout is how long header processing waits for header to
+// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
+// It's kept as a global variable so to make it easier to find and override,
+// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
+var DefaultReadHeaderTimeout = 10 * time.Second
// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
@@ -51,11 +43,10 @@ type Conn struct {
once sync.Once
readErr error
conn net.Conn
+ Validate Validator
bufReader *bufio.Reader
- reader io.Reader
header *Header
ProxyHeaderPolicy Policy
- Validate Validator
readHeaderTimeout time.Duration
}
@@ -72,70 +63,53 @@ func ValidateHeader(v Validator) func(*Conn) {
}
}
-// SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn()
-func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
- return func(c *Conn) {
- if t >= 0 {
- c.readHeaderTimeout = t
- }
+// Accept waits for and returns the next connection to the listener.
+func (p *Listener) Accept() (net.Conn, error) {
+ // Get the underlying connection
+ conn, err := p.Listener.Accept()
+ if err != nil {
+ return nil, err
}
-}
-// Accept waits for and returns the next valid connection to the listener.
-func (p *Listener) Accept() (net.Conn, error) {
- for {
- // Get the underlying connection
- conn, err := p.Listener.Accept()
+ proxyHeaderPolicy := USE
+ if p.Policy != nil && p.ConnPolicy != nil {
+ panic("only one of policy or connpolicy must be provided.")
+ }
+ if p.Policy != nil || p.ConnPolicy != nil {
+ if p.Policy != nil {
+ proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
+ } else {
+ proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
+ Upstream: conn.RemoteAddr(),
+ Downstream: conn.LocalAddr(),
+ })
+ }
if err != nil {
+ // can't decide the policy, we can't accept the connection
+ conn.Close()
return nil, err
}
-
- proxyHeaderPolicy := USE
- if p.Policy != nil && p.ConnPolicy != nil {
- panic("only one of policy or connpolicy must be provided.")
- }
- if p.Policy != nil || p.ConnPolicy != nil {
- if p.Policy != nil {
- proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
- } else {
- proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
- Upstream: conn.RemoteAddr(),
- Downstream: conn.LocalAddr(),
- })
- }
- if err != nil {
- // can't decide the policy, we can't accept the connection
- conn.Close()
-
- if errors.Is(err, ErrInvalidUpstream) {
- // keep listening for other connections
- continue
- }
-
- return nil, err
- }
- // Handle a connection as a regular one
- if proxyHeaderPolicy == SKIP {
- return conn, nil
- }
+ // Handle a connection as a regular one
+ if proxyHeaderPolicy == SKIP {
+ return conn, nil
}
+ }
- newConn := NewConn(
- conn,
- WithPolicy(proxyHeaderPolicy),
- ValidateHeader(p.ValidateHeader),
- )
+ newConn := NewConn(
+ conn,
+ WithPolicy(proxyHeaderPolicy),
+ ValidateHeader(p.ValidateHeader),
+ )
- // If the ReadHeaderTimeout for the listener is unset, use the default timeout.
- if p.ReadHeaderTimeout == 0 {
- p.ReadHeaderTimeout = DefaultReadHeaderTimeout
- }
+ // If the ReadHeaderTimeout for the listener is unset, use the default timeout.
+ if p.ReadHeaderTimeout == 0 {
+ p.ReadHeaderTimeout = DefaultReadHeaderTimeout
+ }
- // Set the readHeaderTimeout of the new conn to the value of the listener
- newConn.readHeaderTimeout = p.ReadHeaderTimeout
+ // Set the readHeaderTimeout of the new conn to the value of the listener
+ newConn.readHeaderTimeout = p.ReadHeaderTimeout
- return newConn, nil
- }
+ return newConn, nil
}
// Close closes the underlying listener.
@@ -151,15 +125,8 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
- // For v1 the header length is at most 108 bytes.
- // For v2 the header length is at most 52 bytes plus the length of the TLVs.
- // We use 256 bytes to be safe.
- const bufSize = 256
- br := bufio.NewReaderSize(conn, bufSize)
-
pConn := &Conn{
- bufReader: br,
- reader: io.MultiReader(br, conn),
+ bufReader: bufio.NewReader(conn),
conn: conn,
}
@@ -181,7 +148,7 @@ func (p *Conn) Read(b []byte) (int, error) {
return 0, p.readErr
}
- return p.reader.Read(b)
+ return p.bufReader.Read(b)
}
// Write wraps original conn.Write
@@ -363,27 +330,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
if p.readErr != nil {
return 0, p.readErr
}
-
- b := make([]byte, p.bufReader.Buffered())
- if _, err := p.bufReader.Read(b); err != nil {
- return 0, err // this should never as we read buffered data
- }
-
- var n int64
- {
- nn, err := w.Write(b)
- n += int64(nn)
- if err != nil {
- return n, err
- }
- }
- {
- nn, err := io.Copy(w, p.conn)
- n += nn
- if err != nil {
- return n, err
- }
- }
-
- return n, nil
+ return p.bufReader.WriteTo(w)
}