diff options
Diffstat (limited to 'vendor/github.com/pires/go-proxyproto/protocol.go')
| -rw-r--r-- | vendor/github.com/pires/go-proxyproto/protocol.go | 147 |
1 files changed, 46 insertions, 101 deletions
diff --git a/vendor/github.com/pires/go-proxyproto/protocol.go b/vendor/github.com/pires/go-proxyproto/protocol.go index 270b90d..658900a 100644 --- a/vendor/github.com/pires/go-proxyproto/protocol.go +++ b/vendor/github.com/pires/go-proxyproto/protocol.go @@ -2,8 +2,6 @@ package proxyproto import ( "bufio" - "errors" - "fmt" "io" "net" "sync" @@ -11,17 +9,11 @@ import ( "time" ) -var ( - // DefaultReadHeaderTimeout is how long header processing waits for header to - // be read from the wire, if Listener.ReaderHeaderTimeout is not set. - // It's kept as a global variable so to make it easier to find and override, - // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" - DefaultReadHeaderTimeout = 10 * time.Second - - // ErrInvalidUpstream should be returned when an upstream connection address - // is not trusted, and therefore is invalid. - ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information") -) +// DefaultReadHeaderTimeout is how long header processing waits for header to +// be read from the wire, if Listener.ReaderHeaderTimeout is not set. +// It's kept as a global variable so to make it easier to find and override, +// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" +var DefaultReadHeaderTimeout = 10 * time.Second // Listener is used to wrap an underlying listener, // whose connections may be using the HAProxy Proxy Protocol. @@ -51,11 +43,10 @@ type Conn struct { once sync.Once readErr error conn net.Conn + Validate Validator bufReader *bufio.Reader - reader io.Reader header *Header ProxyHeaderPolicy Policy - Validate Validator readHeaderTimeout time.Duration } @@ -72,70 +63,53 @@ func ValidateHeader(v Validator) func(*Conn) { } } -// SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn() -func SetReadHeaderTimeout(t time.Duration) func(*Conn) { - return func(c *Conn) { - if t >= 0 { - c.readHeaderTimeout = t - } +// Accept waits for and returns the next connection to the listener. +func (p *Listener) Accept() (net.Conn, error) { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err } -} -// Accept waits for and returns the next valid connection to the listener. -func (p *Listener) Accept() (net.Conn, error) { - for { - // Get the underlying connection - conn, err := p.Listener.Accept() + proxyHeaderPolicy := USE + if p.Policy != nil && p.ConnPolicy != nil { + panic("only one of policy or connpolicy must be provided.") + } + if p.Policy != nil || p.ConnPolicy != nil { + if p.Policy != nil { + proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) + } else { + proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ + Upstream: conn.RemoteAddr(), + Downstream: conn.LocalAddr(), + }) + } if err != nil { + // can't decide the policy, we can't accept the connection + conn.Close() return nil, err } - - proxyHeaderPolicy := USE - if p.Policy != nil && p.ConnPolicy != nil { - panic("only one of policy or connpolicy must be provided.") - } - if p.Policy != nil || p.ConnPolicy != nil { - if p.Policy != nil { - proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) - } else { - proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ - Upstream: conn.RemoteAddr(), - Downstream: conn.LocalAddr(), - }) - } - if err != nil { - // can't decide the policy, we can't accept the connection - conn.Close() - - if errors.Is(err, ErrInvalidUpstream) { - // keep listening for other connections - continue - } - - return nil, err - } - // Handle a connection as a regular one - if proxyHeaderPolicy == SKIP { - return conn, nil - } + // Handle a connection as a regular one + if proxyHeaderPolicy == SKIP { + return conn, nil } + } - newConn := NewConn( - conn, - WithPolicy(proxyHeaderPolicy), - ValidateHeader(p.ValidateHeader), - ) + newConn := NewConn( + conn, + WithPolicy(proxyHeaderPolicy), + ValidateHeader(p.ValidateHeader), + ) - // If the ReadHeaderTimeout for the listener is unset, use the default timeout. - if p.ReadHeaderTimeout == 0 { - p.ReadHeaderTimeout = DefaultReadHeaderTimeout - } + // If the ReadHeaderTimeout for the listener is unset, use the default timeout. + if p.ReadHeaderTimeout == 0 { + p.ReadHeaderTimeout = DefaultReadHeaderTimeout + } - // Set the readHeaderTimeout of the new conn to the value of the listener - newConn.readHeaderTimeout = p.ReadHeaderTimeout + // Set the readHeaderTimeout of the new conn to the value of the listener + newConn.readHeaderTimeout = p.ReadHeaderTimeout - return newConn, nil - } + return newConn, nil } // Close closes the underlying listener. @@ -151,15 +125,8 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { - // For v1 the header length is at most 108 bytes. - // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 - br := bufio.NewReaderSize(conn, bufSize) - pConn := &Conn{ - bufReader: br, - reader: io.MultiReader(br, conn), + bufReader: bufio.NewReader(conn), conn: conn, } @@ -181,7 +148,7 @@ func (p *Conn) Read(b []byte) (int, error) { return 0, p.readErr } - return p.reader.Read(b) + return p.bufReader.Read(b) } // Write wraps original conn.Write @@ -363,27 +330,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - - b := make([]byte, p.bufReader.Buffered()) - if _, err := p.bufReader.Read(b); err != nil { - return 0, err // this should never as we read buffered data - } - - var n int64 - { - nn, err := w.Write(b) - n += int64(nn) - if err != nil { - return n, err - } - } - { - nn, err := io.Copy(w, p.conn) - n += nn - if err != nil { - return n, err - } - } - - return n, nil + return p.bufReader.WriteTo(w) } |
