Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxyproto: PROXY protocol net.Conn and net.Listener impl #919

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/kevinburke/hostsfile v0.0.0-20220522040509-e5e984885321
github.com/mitchellh/go-wordwrap v1.0.1
github.com/mmatczuk/anyflag v0.0.0-20240709090339-eb9e24cd1b44
github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/client_model v0.6.1
github.com/prometheus/common v0.59.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down
177 changes: 177 additions & 0 deletions proxyproto/proxyproto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package proxyproto

import (
"bufio"
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/pires/go-proxyproto"
)

// Conn wraps a net.Conn and provides access to the proxy protocol header.
// If the header is not present or cannot be read within the timeout,
// the connection is closed.
type Conn struct {
net.Conn

readHeaderTimeout time.Duration
isHeaderRead atomic.Bool
headerMu sync.Mutex
header proxyproto.Header
headerErr error
}

func (c *Conn) NetConn() net.Conn {
return c.Conn
}

func (c *Conn) LocalAddr() net.Addr {
if err := c.readHeader(); err != nil {
return c.Conn.LocalAddr()
}

if c.headerErr != nil || c.header.Command.IsLocal() {
return c.Conn.LocalAddr()
}

return c.header.DestinationAddr
}

func (c *Conn) RemoteAddr() net.Addr {
if err := c.readHeader(); err != nil {
return c.Conn.RemoteAddr()
}

if c.headerErr != nil || c.header.Command.IsLocal() {
return c.Conn.RemoteAddr()
}

return c.header.SourceAddr
}

func (c *Conn) Read(b []byte) (n int, err error) {
if err := c.readHeader(); err != nil {
return 0, err
}
return c.Conn.Read(b)
}

func (c *Conn) Write(b []byte) (n int, err error) {
if err := c.readHeader(); err != nil {
return 0, err
}
return c.Conn.Write(b)
}

func (c *Conn) Header() (proxyproto.Header, error) {
return c.HeaderContext(context.Background())
}

func (c *Conn) HeaderContext(ctx context.Context) (proxyproto.Header, error) {
if err := c.readHeaderContext(ctx); err != nil {
return proxyproto.Header{}, err
}
return c.header, nil
}

func (c *Conn) readHeader() error {
return c.readHeaderContext(context.Background())
}

func (c *Conn) readHeaderContext(ctx context.Context) error {
if c.isHeaderRead.Load() {
return c.headerErr
}

c.headerMu.Lock()
defer c.headerMu.Unlock()

if c.isHeaderRead.Load() {
return c.headerErr
}

t0 := time.Now()
if c.readHeaderTimeout > 0 {
if d, ok := ctx.Deadline(); !ok || d.Sub(t0) > c.readHeaderTimeout {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, c.readHeaderTimeout)
defer cancel()
}
}

type result struct {
header *proxyproto.Header
err error
}
resCh := make(chan result)

go func() {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
// Use a byteReader to read only one byte at a time,
// so we can read the header without consuming more bytes than needed.
// On success, the reader must be empty.
// Otherwise, the connection is closed on timeout or never read on error.
br := bufio.NewReaderSize(byteReader{c.Conn}, bufSize)

var r result
r.header, r.err = proxyproto.Read(br)

if r.err == nil && br.Buffered() > 0 {
panic("proxy protocol header read: unexpected data after header")
}

resCh <- r
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may leak.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make buffered chanel.

}()

select {
case <-ctx.Done():
c.Conn.Close()
c.headerErr = fmt.Errorf("proxy protocol header read timeout: %w", ctx.Err())
case r := <-resCh:
c.header = *r.header
c.headerErr = r.err
}

c.isHeaderRead.Store(true)

return c.headerErr
}

type Listener struct {
net.Listener
ReadHeaderTimeout time.Duration
}

func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}

return &Conn{
Conn: c,
readHeaderTimeout: l.ReadHeaderTimeout,
}, nil
}

type byteReader struct {
r io.Reader
}

func (r byteReader) Read(p []byte) (int, error) {
return r.r.Read(p[:1])
}