From 2ed8b1ce1c655181f431bf41becfdc25d44f2f06 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 02:58:14 +0000 Subject: [PATCH 1/9] Simplify APIs --- transport/packet.go | 23 ++++---- transport/packet_test.go | 49 +++++++++++++++++ transport/shadowsocks/cipher_test.go | 7 ++- transport/stream.go | 79 +++++++++++++++++----------- transport/stream_test.go | 71 ++++++++++++++++++++----- 5 files changed, 170 insertions(+), 59 deletions(-) create mode 100644 transport/packet_test.go diff --git a/transport/packet.go b/transport/packet.go index 87775f45..5685428e 100644 --- a/transport/packet.go +++ b/transport/packet.go @@ -25,27 +25,24 @@ type PacketEndpoint interface { Connect(ctx context.Context) (net.Conn, error) } -// PacketListener provides a way to create a local unbound packet connection to send packets to different destinations. -type PacketListener interface { - // ListenPacket creates a PacketConn that can be used to relay packets (such as UDP) through some proxy. - ListenPacket(ctx context.Context) (net.PacketConn, error) -} - // UDPEndpoint is a [PacketEndpoint] that connects to the given address via UDP type UDPEndpoint struct { // The Dialer used to create the net.Conn on Connect(). Dialer net.Dialer - // The remote address to pass to Dial. - RemoteAddr net.UDPAddr + // The remote address (host:port) to pass to Dial. + // If the host is a domain name, consider pre-resolving it to avoid resolution calls. + RemoteAddr string } var _ PacketEndpoint = (*UDPEndpoint)(nil) // Connect implements [PacketEndpoint.Connect]. func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) { - conn, err := e.Dialer.DialContext(ctx, "udp", e.RemoteAddr.String()) - if err != nil { - return nil, err - } - return conn, nil + return e.Dialer.DialContext(ctx, "udp", e.RemoteAddr) +} + +// PacketListener provides a way to create a local unbound packet connection to send packets to different destinations. +type PacketListener interface { + // ListenPacket creates a PacketConn that can be used to relay packets (such as UDP) through some proxy. + ListenPacket(ctx context.Context) (net.PacketConn, error) } diff --git a/transport/packet_test.go b/transport/packet_test.go new file mode 100644 index 00000000..cc862a70 --- /dev/null +++ b/transport/packet_test.go @@ -0,0 +1,49 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "context" + "syscall" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUDPEndpointIPv4(t *testing.T) { + const serverAddr = "127.0.0.10:8888" + ep := &UDPEndpoint{RemoteAddr: serverAddr} + ep.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "udp4", network) + require.Equal(t, serverAddr, address) + return nil + } + conn, err := ep.Connect(context.Background()) + require.Nil(t, err) + require.Equal(t, serverAddr, conn.RemoteAddr().String()) +} + +func TestUDPEndpointIPv6(t *testing.T) { + const serverAddr = "[::1]:8888" + ep := &UDPEndpoint{RemoteAddr: serverAddr} + ep.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "udp6", network) + require.Equal(t, serverAddr, address) + return nil + } + conn, err := ep.Connect(context.Background()) + require.Nil(t, err) + require.Equal(t, serverAddr, conn.RemoteAddr().String()) +} diff --git a/transport/shadowsocks/cipher_test.go b/transport/shadowsocks/cipher_test.go index 17d2c3d5..b76ba324 100644 --- a/transport/shadowsocks/cipher_test.go +++ b/transport/shadowsocks/cipher_test.go @@ -17,6 +17,7 @@ package shadowsocks import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -59,8 +60,10 @@ func TestShadowsocksCipherNames(t *testing.T) { func TestUnsupportedCipher(t *testing.T) { _, err := CipherByName("aes-256-cfb") - if err == nil { - t.Errorf("Should get an error for unsupported cipher") + var unsupportedErr ErrUnsupportedCipher + if assert.ErrorAs(t, err, &unsupportedErr) { + assert.Equal(t, "aes-256-cfb", unsupportedErr.Name) + assert.Equal(t, "unsupported cipher aes-256-cfb", unsupportedErr.Error()) } } diff --git a/transport/stream.go b/transport/stream.go index 5f3d6486..9e18320e 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -32,44 +32,14 @@ type StreamConn interface { CloseWrite() error } -// StreamEndpoint represents an endpoint that can be used to established stream connections (like TCP) to a fixed destination. -type StreamEndpoint interface { - // Connect establishes a connection with the endpoint, returning the connection. - Connect(ctx context.Context) (StreamConn, error) -} - -// StreamDialer provides a way to dial a destination and establish stream connections. -type StreamDialer interface { - // Dial connects to `raddr`. - // `raddr` has the form `host:port`, where `host` can be a domain name or IP address. - Dial(ctx context.Context, raddr string) (StreamConn, error) -} - -// TCPEndpoint is a [StreamEndpoint] that connects to the given address via TCP. -type TCPEndpoint struct { - // The Dialer used to create the connection on Connect(). - Dialer net.Dialer - // The remote address to pass to DialTCP. - RemoteAddr net.TCPAddr -} - -var _ StreamEndpoint = (*TCPEndpoint)(nil) - -// Connect implements [StreamEndpoint.Connect]. -func (e TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) { - conn, err := e.Dialer.DialContext(ctx, "tcp", e.RemoteAddr.String()) - if err != nil { - return nil, err - } - return conn.(*net.TCPConn), nil -} - type duplexConnAdaptor struct { StreamConn r io.Reader w io.Writer } +var _ StreamConn = (*duplexConnAdaptor)(nil) + func (dc *duplexConnAdaptor) Read(b []byte) (int, error) { return dc.r.Read(b) } @@ -99,3 +69,48 @@ func WrapConn(c StreamConn, r io.Reader, w io.Writer) StreamConn { } return &duplexConnAdaptor{StreamConn: conn, r: r, w: w} } + +// StreamEndpoint represents an endpoint that can be used to established stream connections (like TCP) to a fixed destination. +type StreamEndpoint interface { + // Connect establishes a connection with the endpoint, returning the connection. + Connect(ctx context.Context) (StreamConn, error) +} + +// DialEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. +type DialEndpoint struct { + // The Dialer used to create the connection on Connect(). + Dialer StreamDialer + // The remote address (host:port) to pass to Dial. + // If the host is a domain name, consider pre-resolving it to avoid resolution calls. + RemoteAddr string +} + +var _ StreamEndpoint = (*DialEndpoint)(nil) + +// Connect implements [StreamEndpoint.Connect]. +func (e DialEndpoint) Connect(ctx context.Context) (StreamConn, error) { + return e.Dialer.Dial(ctx, e.RemoteAddr) +} + +// StreamDialer provides a way to dial a destination and establish stream connections. +type StreamDialer interface { + // Dial connects to `raddr`. + // `raddr` has the form `host:port`, where `host` can be a domain name or IP address. + Dial(ctx context.Context, raddr string) (StreamConn, error) +} + +// TCPStreamDialer is a [StreamDialer] that uses the standard [net.Dialer] to dial. +// It provides a convenient way to use a [net.Dialer] when you need a [StreamDialer]. +type TCPStreamDialer struct { + Dialer net.Dialer +} + +var _ StreamDialer = (*TCPStreamDialer)(nil) + +func (d *TCPStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) { + conn, err := d.Dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil +} diff --git a/transport/stream_test.go b/transport/stream_test.go index 8adf6cb8..d55495b1 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -16,20 +16,23 @@ package transport import ( "context" + "errors" "net" "sync" + "syscall" "testing" "testing/iotest" + + "github.com/stretchr/testify/require" ) -func TestNewTCPEndpointIPv4(t *testing.T) { +func TestNewTCPStreamDialerIPv4(t *testing.T) { requestText := []byte("Request") responseText := []byte("Response") - listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}) - if err != nil { - t.Fatalf("Failed to create TCP listener: %v", err) - } + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 10)}) + require.Nil(t, err, "Failed to create TCP listener") + var running sync.WaitGroup running.Add(1) go func() { @@ -59,17 +62,61 @@ func TestNewTCPEndpointIPv4(t *testing.T) { } }() - e := TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} - serverConn, err := e.Connect(context.Background()) - if err != nil { - t.Fatalf("Connect failed: %v", err) + dialer := &TCPStreamDialer{} + dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "tcp4", network) + require.Equal(t, listener.Addr().String(), address) + return nil } + serverConn, err := dialer.Dial(context.Background(), listener.Addr().String()) + require.Nil(t, err, "Dial failed") + require.Equal(t, listener.Addr().String(), serverConn.RemoteAddr().String()) defer serverConn.Close() + serverConn.Write(requestText) serverConn.CloseWrite() - if err = iotest.TestReader(serverConn, responseText); err != nil { - t.Fatalf("Response read failed: %v", err) - } + + require.Nil(t, iotest.TestReader(serverConn, responseText), "Response read failed") serverConn.CloseRead() + running.Wait() } + +func TestNewTCPStreamDialerAddress(t *testing.T) { + errCancel := errors.New("cancelled") + dialer := &TCPStreamDialer{} + + dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "tcp4", network) + require.Equal(t, "8.8.8.8:53", address) + return errCancel + } + _, err := dialer.Dial(context.Background(), "8.8.8.8:53") + require.ErrorIs(t, err, errCancel) + + dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "tcp6", network) + require.Equal(t, "[2001:4860:4860::8888]:53", address) + return errCancel + } + _, err = dialer.Dial(context.Background(), "[2001:4860:4860::8888]:53") + require.ErrorIs(t, err, errCancel) +} + +func TestDialEndpointAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 2)}) + require.Nil(t, err, "Failed to create TCP listener") + defer listener.Close() + + dialer := &TCPStreamDialer{} + dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { + require.Equal(t, "tcp4", network) + require.Equal(t, listener.Addr().String(), address) + return nil + } + endpoint := DialEndpoint{Dialer: dialer, RemoteAddr: listener.Addr().String()} + conn, err := endpoint.Connect(context.Background()) + require.Nil(t, err) + require.Equal(t, listener.Addr().String(), conn.RemoteAddr().String()) + require.Nil(t, conn.Close()) +} From dc81900173419a40632005d020c11f1b47f61f98 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 03:01:57 +0000 Subject: [PATCH 2/9] Rename --- transport/stream.go | 8 ++++---- transport/stream_test.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/transport/stream.go b/transport/stream.go index 9e18320e..7be9ff4a 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -76,8 +76,8 @@ type StreamEndpoint interface { Connect(ctx context.Context) (StreamConn, error) } -// DialEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. -type DialEndpoint struct { +// DialStreamEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. +type DialStreamEndpoint struct { // The Dialer used to create the connection on Connect(). Dialer StreamDialer // The remote address (host:port) to pass to Dial. @@ -85,10 +85,10 @@ type DialEndpoint struct { RemoteAddr string } -var _ StreamEndpoint = (*DialEndpoint)(nil) +var _ StreamEndpoint = (*DialStreamEndpoint)(nil) // Connect implements [StreamEndpoint.Connect]. -func (e DialEndpoint) Connect(ctx context.Context) (StreamConn, error) { +func (e DialStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) { return e.Dialer.Dial(ctx, e.RemoteAddr) } diff --git a/transport/stream_test.go b/transport/stream_test.go index d55495b1..8eb6111a 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -103,7 +103,7 @@ func TestNewTCPStreamDialerAddress(t *testing.T) { require.ErrorIs(t, err, errCancel) } -func TestDialEndpointAddr(t *testing.T) { +func TestDialStreamEndpointAddr(t *testing.T) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 2)}) require.Nil(t, err, "Failed to create TCP listener") defer listener.Close() @@ -114,7 +114,7 @@ func TestDialEndpointAddr(t *testing.T) { require.Equal(t, listener.Addr().String(), address) return nil } - endpoint := DialEndpoint{Dialer: dialer, RemoteAddr: listener.Addr().String()} + endpoint := DialStreamEndpoint{Dialer: dialer, RemoteAddr: listener.Addr().String()} conn, err := endpoint.Connect(context.Background()) require.Nil(t, err) require.Equal(t, listener.Addr().String(), conn.RemoteAddr().String()) From 2d5f1bf0f10fd4802620f27554c482b8938a8360 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 05:21:11 +0000 Subject: [PATCH 3/9] Fix tests --- transport/shadowsocks/client/packet_listener_test.go | 4 ++-- transport/shadowsocks/client/stream_dialer_test.go | 10 +++++----- transport/stream.go | 5 ++++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/transport/shadowsocks/client/packet_listener_test.go b/transport/shadowsocks/client/packet_listener_test.go index f7a33eda..5111c3d1 100644 --- a/transport/shadowsocks/client/packet_listener_test.go +++ b/transport/shadowsocks/client/packet_listener_test.go @@ -30,7 +30,7 @@ import ( func TestShadowsocksPacketListener_ListenPacket(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, t) - proxyEndpoint := transport.UDPEndpoint{RemoteAddr: *proxy.LocalAddr().(*net.UDPAddr)} + proxyEndpoint := transport.UDPEndpoint{RemoteAddr: proxy.LocalAddr().String()} d, err := NewShadowsocksPacketListener(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create PacketListener: %v", err) @@ -54,7 +54,7 @@ func BenchmarkShadowsocksPacketListener_ListenPacket(b *testing.B) { key := makeTestKey(b) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, b) - proxyEndpoint := transport.UDPEndpoint{RemoteAddr: *proxy.LocalAddr().(*net.UDPAddr)} + proxyEndpoint := transport.UDPEndpoint{RemoteAddr: proxy.LocalAddr().String()} d, err := NewShadowsocksPacketListener(proxyEndpoint, key) if err != nil { b.Fatalf("Failed to create PacketListener: %v", err) diff --git a/transport/shadowsocks/client/stream_dialer_test.go b/transport/shadowsocks/client/stream_dialer_test.go index 5a6fd165..be87dc7b 100644 --- a/transport/shadowsocks/client/stream_dialer_test.go +++ b/transport/shadowsocks/client/stream_dialer_test.go @@ -30,7 +30,7 @@ import ( func TestShadowsocksStreamDialer_Dial(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *proxy.Addr().(*net.TCPAddr)} + proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) @@ -50,7 +50,7 @@ func TestShadowsocksStreamDialer_Dial(t *testing.T) { func TestShadowsocksStreamDialer_DialNoPayload(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *proxy.Addr().(*net.TCPAddr)} + proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) @@ -93,7 +93,7 @@ func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { }() key := makeTestKey(t) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} + proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: listener.Addr().String()} d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) @@ -142,7 +142,7 @@ func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { }() key := makeTestKey(t) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} + proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: listener.Addr().String()} d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) @@ -163,7 +163,7 @@ func BenchmarkShadowsocksStreamDialer_Dial(b *testing.B) { key := makeTestKey(b) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, b) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *proxy.Addr().(*net.TCPAddr)} + proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) if err != nil { b.Fatalf("Failed to create StreamDialer: %v", err) diff --git a/transport/stream.go b/transport/stream.go index 7be9ff4a..9b120164 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -78,7 +78,7 @@ type StreamEndpoint interface { // DialStreamEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. type DialStreamEndpoint struct { - // The Dialer used to create the connection on Connect(). + // The Dialer used to create the connection on Connect(). If nil, it uses a TCP net.Dialer. Dialer StreamDialer // The remote address (host:port) to pass to Dial. // If the host is a domain name, consider pre-resolving it to avoid resolution calls. @@ -89,6 +89,9 @@ var _ StreamEndpoint = (*DialStreamEndpoint)(nil) // Connect implements [StreamEndpoint.Connect]. func (e DialStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) { + if e.Dialer == nil { + return (&TCPStreamDialer{}).Dial(ctx, e.RemoteAddr) + } return e.Dialer.Dial(ctx, e.RemoteAddr) } From 9d092c89a57e568d14d4374955b770fb4b4178cf Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 05:44:22 +0000 Subject: [PATCH 4/9] Restore TCPEndpoint --- transport/packet.go | 6 ++--- transport/packet_test.go | 4 ++-- .../client/packet_listener_test.go | 4 ++-- .../shadowsocks/client/stream_dialer_test.go | 15 ++++-------- transport/stream.go | 23 ++++++++++--------- transport/stream_test.go | 5 ++-- 6 files changed, 26 insertions(+), 31 deletions(-) diff --git a/transport/packet.go b/transport/packet.go index 5685428e..63ec0c0b 100644 --- a/transport/packet.go +++ b/transport/packet.go @@ -29,16 +29,16 @@ type PacketEndpoint interface { type UDPEndpoint struct { // The Dialer used to create the net.Conn on Connect(). Dialer net.Dialer - // The remote address (host:port) to pass to Dial. + // The endpoint address (host:port) to pass to Dial. // If the host is a domain name, consider pre-resolving it to avoid resolution calls. - RemoteAddr string + Address string } var _ PacketEndpoint = (*UDPEndpoint)(nil) // Connect implements [PacketEndpoint.Connect]. func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) { - return e.Dialer.DialContext(ctx, "udp", e.RemoteAddr) + return e.Dialer.DialContext(ctx, "udp", e.Address) } // PacketListener provides a way to create a local unbound packet connection to send packets to different destinations. diff --git a/transport/packet_test.go b/transport/packet_test.go index cc862a70..31fddbc3 100644 --- a/transport/packet_test.go +++ b/transport/packet_test.go @@ -24,7 +24,7 @@ import ( func TestUDPEndpointIPv4(t *testing.T) { const serverAddr = "127.0.0.10:8888" - ep := &UDPEndpoint{RemoteAddr: serverAddr} + ep := &UDPEndpoint{Address: serverAddr} ep.Dialer.Control = func(network, address string, c syscall.RawConn) error { require.Equal(t, "udp4", network) require.Equal(t, serverAddr, address) @@ -37,7 +37,7 @@ func TestUDPEndpointIPv4(t *testing.T) { func TestUDPEndpointIPv6(t *testing.T) { const serverAddr = "[::1]:8888" - ep := &UDPEndpoint{RemoteAddr: serverAddr} + ep := &UDPEndpoint{Address: serverAddr} ep.Dialer.Control = func(network, address string, c syscall.RawConn) error { require.Equal(t, "udp6", network) require.Equal(t, serverAddr, address) diff --git a/transport/shadowsocks/client/packet_listener_test.go b/transport/shadowsocks/client/packet_listener_test.go index 5111c3d1..1e77b37b 100644 --- a/transport/shadowsocks/client/packet_listener_test.go +++ b/transport/shadowsocks/client/packet_listener_test.go @@ -30,7 +30,7 @@ import ( func TestShadowsocksPacketListener_ListenPacket(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, t) - proxyEndpoint := transport.UDPEndpoint{RemoteAddr: proxy.LocalAddr().String()} + proxyEndpoint := transport.UDPEndpoint{Address: proxy.LocalAddr().String()} d, err := NewShadowsocksPacketListener(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create PacketListener: %v", err) @@ -54,7 +54,7 @@ func BenchmarkShadowsocksPacketListener_ListenPacket(b *testing.B) { key := makeTestKey(b) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, b) - proxyEndpoint := transport.UDPEndpoint{RemoteAddr: proxy.LocalAddr().String()} + proxyEndpoint := transport.UDPEndpoint{Address: proxy.LocalAddr().String()} d, err := NewShadowsocksPacketListener(proxyEndpoint, key) if err != nil { b.Fatalf("Failed to create PacketListener: %v", err) diff --git a/transport/shadowsocks/client/stream_dialer_test.go b/transport/shadowsocks/client/stream_dialer_test.go index be87dc7b..26931515 100644 --- a/transport/shadowsocks/client/stream_dialer_test.go +++ b/transport/shadowsocks/client/stream_dialer_test.go @@ -30,8 +30,7 @@ import ( func TestShadowsocksStreamDialer_Dial(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -50,8 +49,7 @@ func TestShadowsocksStreamDialer_Dial(t *testing.T) { func TestShadowsocksStreamDialer_DialNoPayload(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -93,8 +91,7 @@ func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { }() key := makeTestKey(t) - proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: listener.Addr().String()} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -142,8 +139,7 @@ func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { }() key := makeTestKey(t) - proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: listener.Addr().String()} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -163,8 +159,7 @@ func BenchmarkShadowsocksStreamDialer_Dial(b *testing.B) { key := makeTestKey(b) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, b) - proxyEndpoint := transport.DialStreamEndpoint{RemoteAddr: proxy.Addr().String()} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { b.Fatalf("Failed to create StreamDialer: %v", err) } diff --git a/transport/stream.go b/transport/stream.go index 9b120164..cae32bae 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -76,23 +76,24 @@ type StreamEndpoint interface { Connect(ctx context.Context) (StreamConn, error) } -// DialStreamEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. -type DialStreamEndpoint struct { - // The Dialer used to create the connection on Connect(). If nil, it uses a TCP net.Dialer. - Dialer StreamDialer - // The remote address (host:port) to pass to Dial. +// TCPEndpoint is a [StreamEndpoint] that connects to the given address using the given [StreamDialer]. +type TCPEndpoint struct { + // The Dialer used to create the net.Conn on Connect(). + Dialer net.Dialer + // The endpoint address (host:port) to pass to Dial. // If the host is a domain name, consider pre-resolving it to avoid resolution calls. - RemoteAddr string + Address string } -var _ StreamEndpoint = (*DialStreamEndpoint)(nil) +var _ StreamEndpoint = (*TCPEndpoint)(nil) // Connect implements [StreamEndpoint.Connect]. -func (e DialStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) { - if e.Dialer == nil { - return (&TCPStreamDialer{}).Dial(ctx, e.RemoteAddr) +func (e *TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) { + conn, err := e.Dialer.DialContext(ctx, "tcp", e.Address) + if err != nil { + return nil, err } - return e.Dialer.Dial(ctx, e.RemoteAddr) + return conn.(*net.TCPConn), nil } // StreamDialer provides a way to dial a destination and establish stream connections. diff --git a/transport/stream_test.go b/transport/stream_test.go index 8eb6111a..e1f37dae 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -108,13 +108,12 @@ func TestDialStreamEndpointAddr(t *testing.T) { require.Nil(t, err, "Failed to create TCP listener") defer listener.Close() - dialer := &TCPStreamDialer{} - dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { + endpoint := TCPEndpoint{Address: listener.Addr().String()} + endpoint.Dialer.Control = func(network, address string, c syscall.RawConn) error { require.Equal(t, "tcp4", network) require.Equal(t, listener.Addr().String(), address) return nil } - endpoint := DialStreamEndpoint{Dialer: dialer, RemoteAddr: listener.Addr().String()} conn, err := endpoint.Connect(context.Background()) require.Nil(t, err) require.Equal(t, listener.Addr().String(), conn.RemoteAddr().String()) From c434aa78ad91366790c5f77c17859e17f6e90ab4 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 20:03:16 +0000 Subject: [PATCH 5/9] Hide Cipher type --- transport/shadowsocks/cipher.go | 62 ++++++++++++--------- transport/shadowsocks/cipher_test.go | 39 +++++++++---- transport/shadowsocks/compatibility_test.go | 4 +- transport/shadowsocks/stream.go | 4 +- 4 files changed, 65 insertions(+), 44 deletions(-) diff --git a/transport/shadowsocks/cipher.go b/transport/shadowsocks/cipher.go index 6ec4735d..4efd91ef 100644 --- a/transport/shadowsocks/cipher.go +++ b/transport/shadowsocks/cipher.go @@ -26,8 +26,7 @@ import ( "golang.org/x/crypto/hkdf" ) -type Cipher struct { - name string +type cipherSpec struct { newInstance func(key []byte) (cipher.AEAD, error) keySize int saltSize int @@ -36,13 +35,20 @@ type Cipher struct { // List of supported AEAD ciphers, as specified at https://shadowsocks.org/guide/aead.html var ( - CHACHA20IETFPOLY1305 = &Cipher{"AEAD_CHACHA20_POLY1305", chacha20poly1305.New, chacha20poly1305.KeySize, 32, 16} - AES256GCM = &Cipher{"AEAD_AES_256_GCM", newAesGCM, 32, 32, 16} - AES192GCM = &Cipher{"AEAD_AES_192_GCM", newAesGCM, 24, 24, 16} - AES128GCM = &Cipher{"AEAD_AES_128_GCM", newAesGCM, 16, 16, 16} + CHACHA20IETFPOLY1305 = "AEAD_CHACHA20_POLY1305" + AES256GCM = "AEAD_AES_256_GCM" + AES192GCM = "AEAD_AES_192_GCM" + AES128GCM = "AEAD_AES_128_GCM" ) -var supportedCiphers = [](*Cipher){CHACHA20IETFPOLY1305, AES256GCM, AES192GCM, AES128GCM} +var ( + chacha20IETFPOLY1305Cipher = &cipherSpec{chacha20poly1305.New, chacha20poly1305.KeySize, 32, 16} + aes256GCMCipher = &cipherSpec{newAesGCM, 32, 32, 16} + aes192GCMCipher = &cipherSpec{newAesGCM, 24, 24, 16} + aes128GCMCipher = &cipherSpec{newAesGCM, 16, 16, 16} +) + +var supportedCiphers = [](string){CHACHA20IETFPOLY1305, AES256GCM, AES192GCM, AES128GCM} // ErrUnsupportedCipher is returned by [CypherByName] when the named cipher is not supported. type ErrUnsupportedCipher struct { @@ -54,19 +60,22 @@ func (err ErrUnsupportedCipher) Error() string { return "unsupported cipher " + err.Name } +// Largest tag size among the supported ciphers. Used by the TCP buffer pool +const maxTagSize = 16 + // CipherByName returns a [*Cipher] with the given name, or an error if the cipher is not supported. // The name must be the IETF name (as per https://www.iana.org/assignments/aead-parameters/aead-parameters.xhtml) or the // Shadowsocks alias from https://shadowsocks.org/guide/aead.html. -func CipherByName(name string) (*Cipher, error) { +func cipherByName(name string) (*cipherSpec, error) { switch strings.ToUpper(name) { case "AEAD_CHACHA20_POLY1305", "CHACHA20-IETF-POLY1305": - return CHACHA20IETFPOLY1305, nil + return chacha20IETFPOLY1305Cipher, nil case "AEAD_AES_256_GCM", "AES-256-GCM": - return AES256GCM, nil + return aes256GCMCipher, nil case "AEAD_AES_192_GCM", "AES-192-GCM": - return AES192GCM, nil + return aes192GCMCipher, nil case "AEAD_AES_128_GCM", "AES-128-GCM": - return AES128GCM, nil + return aes128GCMCipher, nil default: return nil, ErrUnsupportedCipher{name} } @@ -80,19 +89,9 @@ func newAesGCM(key []byte) (cipher.AEAD, error) { return cipher.NewGCM(blk) } -func maxTagSize() int { - max := 0 - for _, spec := range supportedCiphers { - if spec.tagSize > max { - max = spec.tagSize - } - } - return max -} - // EncryptionKey encapsulates a Shadowsocks AEAD spec and a secret type EncryptionKey struct { - cipher *Cipher + cipher *cipherSpec secret []byte } @@ -138,12 +137,21 @@ func simpleEVPBytesToKey(data []byte, keyLen int) ([]byte, error) { return derived[:keyLen], nil } -// NewEncryptionKey creates a Cipher given a cipher name and a secret -func NewEncryptionKey(cipher *Cipher, secretText string) (*EncryptionKey, error) { +// NewEncryptionKey creates a Cipher given a cipher name and a secret. +// The cipher name must be the IETF name (as per https://www.iana.org/assignments/aead-parameters/aead-parameters.xhtml) +// or the Shadowsocks alias from https://shadowsocks.org/guide/aead.html. +func NewEncryptionKey(cipherName string, secretText string) (*EncryptionKey, error) { + var key EncryptionKey + var err error + key.cipher, err = cipherByName(cipherName) + if err != nil { + return nil, err + } + // Key derivation as per https://shadowsocks.org/en/spec/AEAD-Ciphers.html - secret, err := simpleEVPBytesToKey([]byte(secretText), cipher.keySize) + key.secret, err = simpleEVPBytesToKey([]byte(secretText), key.cipher.keySize) if err != nil { return nil, err } - return &EncryptionKey{cipher, secret}, nil + return &key, nil } diff --git a/transport/shadowsocks/cipher_test.go b/transport/shadowsocks/cipher_test.go index b76ba324..f15c4b8e 100644 --- a/transport/shadowsocks/cipher_test.go +++ b/transport/shadowsocks/cipher_test.go @@ -21,15 +21,15 @@ import ( "github.com/stretchr/testify/require" ) -func assertCipher(t *testing.T, cipher *Cipher, saltSize, tagSize int) { +func assertCipher(t *testing.T, cipher string, saltSize, tagSize int) { key, err := NewEncryptionKey(cipher, "") require.Nil(t, err) require.Equal(t, saltSize, key.SaltSize()) - dummyAead, err := key.NewAEAD(make([]byte, cipher.keySize)) + dummyAead, err := key.NewAEAD(make([]byte, key.SaltSize())) require.Nil(t, err) - require.Equal(t, dummyAead.Overhead(), key.TagSize()) require.Equal(t, tagSize, key.TagSize()) + require.Equal(t, key.TagSize(), dummyAead.Overhead()) } func TestSizes(t *testing.T) { @@ -41,25 +41,25 @@ func TestSizes(t *testing.T) { } func TestShadowsocksCipherNames(t *testing.T) { - cipher, err := CipherByName("chacha20-ietf-poly1305") + key, err := NewEncryptionKey("chacha20-ietf-poly1305", "") require.Nil(t, err) - require.Equal(t, CHACHA20IETFPOLY1305, cipher) + require.Equal(t, chacha20IETFPOLY1305Cipher, key.cipher) - cipher, err = CipherByName("aes-256-gcm") + key, err = NewEncryptionKey("aes-256-gcm", "") require.Nil(t, err) - require.Equal(t, AES256GCM, cipher) + require.Equal(t, aes256GCMCipher, key.cipher) - cipher, err = CipherByName("aes-192-gcm") + key, err = NewEncryptionKey("aes-192-gcm", "") require.Nil(t, err) - require.Equal(t, AES192GCM, cipher) + require.Equal(t, aes192GCMCipher, key.cipher) - cipher, err = CipherByName("aes-128-gcm") + key, err = NewEncryptionKey("aes-128-gcm", "") require.Nil(t, err) - require.Equal(t, AES128GCM, cipher) + require.Equal(t, aes128GCMCipher, key.cipher) } func TestUnsupportedCipher(t *testing.T) { - _, err := CipherByName("aes-256-cfb") + _, err := NewEncryptionKey("aes-256-cfb", "") var unsupportedErr ErrUnsupportedCipher if assert.ErrorAs(t, err, &unsupportedErr) { assert.Equal(t, "aes-256-cfb", unsupportedErr.Name) @@ -82,3 +82,18 @@ func TestMaxNonceSize(t *testing.T) { } } } + +func TestMaxTagSize(t *testing.T) { + var calculatedMax int + for _, cipher := range supportedCiphers { + key, err := NewEncryptionKey(cipher, "") + if !assert.Nilf(t, err, "Failed to create cipher %v", cipher) { + continue + } + assert.LessOrEqualf(t, key.TagSize(), maxTagSize, "Tag size for cipher %v (%v) is greater than the max (%v)", cipher, key.TagSize(), maxTagSize) + if key.TagSize() > calculatedMax { + calculatedMax = key.TagSize() + } + } + require.Equal(t, maxTagSize, calculatedMax) +} diff --git a/transport/shadowsocks/compatibility_test.go b/transport/shadowsocks/compatibility_test.go index 8d2650ad..10cfa0db 100644 --- a/transport/shadowsocks/compatibility_test.go +++ b/transport/shadowsocks/compatibility_test.go @@ -34,9 +34,7 @@ func TestCompatibility(t *testing.T) { var wait sync.WaitGroup wait.Add(1) - cipher, err := CipherByName(cipherName) - require.Nil(t, err) - key, err := NewEncryptionKey(cipher, secret) + key, err := NewEncryptionKey(cipherName, secret) require.Nil(t, err, "NewCipher failed: %v", err) ssWriter := NewShadowsocksWriter(left, key) go func() { diff --git a/transport/shadowsocks/stream.go b/transport/shadowsocks/stream.go index d42e30ef..27755462 100644 --- a/transport/shadowsocks/stream.go +++ b/transport/shadowsocks/stream.go @@ -25,12 +25,12 @@ import ( "github.com/Jigsaw-Code/outline-internal-sdk/internal/slicepool" ) -// payloadSizeMask is the maximum size of payload in bytes. +// payloadSizeMask is the maximum size of payload in bytes, as per https://shadowsocks.org/guide/aead.html#tcp. const payloadSizeMask = 0x3FFF // 16*1024 - 1 // Buffer pool used for decrypting Shadowsocks streams. // The largest buffer we could need is for decrypting a max-length payload. -var readBufPool = slicepool.MakePool(payloadSizeMask + maxTagSize()) +var readBufPool = slicepool.MakePool(payloadSizeMask + maxTagSize) // Writer is an [io.Writer] that also implements [io.ReaderFrom] to // allow for piping the data without extra allocations and copies. From 66327eedd6137539bbc03b29d655c9bde7acfccd Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 22:00:53 +0000 Subject: [PATCH 6/9] Consolidate shadowsocks package --- transport/shadowsocks/client/salt.go | 50 ------------- transport/shadowsocks/client/salt_test.go | 75 ------------------- .../{client => }/client_testing.go | 8 +- transport/shadowsocks/compatibility_test.go | 4 +- .../{client => }/packet_listener.go | 13 ++-- .../{client => }/packet_listener_test.go | 13 ++-- transport/shadowsocks/salt.go | 29 +++++++ transport/shadowsocks/salt_test.go | 54 +++++++++++++ transport/shadowsocks/stream.go | 8 +- .../shadowsocks/{client => }/stream_dialer.go | 25 +++---- .../{client => }/stream_dialer_test.go | 29 ++++--- transport/shadowsocks/stream_test.go | 34 ++++----- 12 files changed, 147 insertions(+), 195 deletions(-) delete mode 100644 transport/shadowsocks/client/salt.go delete mode 100644 transport/shadowsocks/client/salt_test.go rename transport/shadowsocks/{client => }/client_testing.go (86%) rename transport/shadowsocks/{client => }/packet_listener.go (90%) rename transport/shadowsocks/{client => }/packet_listener_test.go (89%) rename transport/shadowsocks/{client => }/stream_dialer.go (80%) rename transport/shadowsocks/{client => }/stream_dialer_test.go (82%) diff --git a/transport/shadowsocks/client/salt.go b/transport/shadowsocks/client/salt.go deleted file mode 100644 index 9a3b22f4..00000000 --- a/transport/shadowsocks/client/salt.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2022 Jigsaw Operations LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "crypto/rand" - "errors" - - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" -) - -type prefixSaltGenerator struct { - prefix []byte -} - -func (g prefixSaltGenerator) GetSalt(salt []byte) error { - n := copy(salt, g.prefix) - if n != len(g.prefix) { - return errors.New("prefix is too long") - } - _, err := rand.Read(salt[n:]) - return err -} - -// NewPrefixSaltGenerator returns a SaltGenerator whose output consists of -// the provided prefix, followed by random bytes. This is useful to change -// how shadowsocks traffic is classified by middleboxes. -// -// Note: Prefixes steal entropy from the initialization vector. This weakens -// security by increasing the likelihood that the same IV is used in two -// different connections (which becomes likely once 2^(N/2) connections are -// made, due to the birthday attack). If an IV is reused, the attacker can -// not only decrypt the ciphertext of those two connections; they can also -// easily recover the shadowsocks key and decrypt all other connections to -// this server. Use with care! -func NewPrefixSaltGenerator(prefix []byte) shadowsocks.SaltGenerator { - return prefixSaltGenerator{prefix} -} diff --git a/transport/shadowsocks/client/salt_test.go b/transport/shadowsocks/client/salt_test.go deleted file mode 100644 index 32a80e3b..00000000 --- a/transport/shadowsocks/client/salt_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2022 Jigsaw Operations LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "testing" - - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" -) - -// setRandomBitsToOne replaces any random bits in the output with 1. -func setRandomBitsToOne(salter shadowsocks.SaltGenerator, output []byte) error { - salt := make([]byte, len(output)) - // OR together 128 salts. The probability that any random bit is - // 0 for all 128 random salts is 2^-128, which is close enough to zero. - for i := 0; i < 128; i++ { - if err := salter.GetSalt(salt); err != nil { - return err - } - for i := range salt { - output[i] |= salt[i] - } - } - return nil -} - -// Test that the prefix bytes are respected, and the remainder are random. -func TestTypicalPrefix(t *testing.T) { - prefix := []byte("twelve bytes") - salter := NewPrefixSaltGenerator(prefix) - - output := make([]byte, 32) - if err := setRandomBitsToOne(salter, output); err != nil { - t.Error(err) - } - - for i := 0; i < 12; i++ { - if output[i] != prefix[i] { - t.Error("prefix mismatch") - } - } - - for _, b := range output[12:] { - if b != 0xFF { - t.Error("unexpected zero bit") - } - } -} - -// Test that all bytes are random when the prefix is nil -func TestNilPrefix(t *testing.T) { - salter := NewPrefixSaltGenerator(nil) - - output := make([]byte, 64) - if err := setRandomBitsToOne(salter, output); err != nil { - t.Error(err) - } - for _, b := range output { - if b != 0xFF { - t.Error("unexpected zero bit") - } - } -} diff --git a/transport/shadowsocks/client/client_testing.go b/transport/shadowsocks/client_testing.go similarity index 86% rename from transport/shadowsocks/client/client_testing.go rename to transport/shadowsocks/client_testing.go index 6e5799c4..d63029bb 100644 --- a/transport/shadowsocks/client/client_testing.go +++ b/transport/shadowsocks/client_testing.go @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package shadowsocks import ( "bytes" "io" "testing" - - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" ) const ( @@ -42,8 +40,8 @@ func expectEchoPayload(conn io.ReadWriter, payload, buf []byte, t testing.TB) { } } -func makeTestKey(tb testing.TB) *shadowsocks.EncryptionKey { - key, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "testPassword") +func makeTestKey(tb testing.TB) *EncryptionKey { + key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "testPassword") if err != nil { tb.Fatalf("Failed to create key: %v", err) } diff --git a/transport/shadowsocks/compatibility_test.go b/transport/shadowsocks/compatibility_test.go index 10cfa0db..fb83330c 100644 --- a/transport/shadowsocks/compatibility_test.go +++ b/transport/shadowsocks/compatibility_test.go @@ -36,13 +36,13 @@ func TestCompatibility(t *testing.T) { wait.Add(1) key, err := NewEncryptionKey(cipherName, secret) require.Nil(t, err, "NewCipher failed: %v", err) - ssWriter := NewShadowsocksWriter(left, key) + ssWriter := NewWriter(left, key) go func() { defer wait.Done() var err error ssWriter.Write([]byte(fromLeft)) - ssReader := NewShadowsocksReader(left, key) + ssReader := NewReader(left, key) receivedByLeft := make([]byte, len(fromRight)) _, err = ssReader.Read(receivedByLeft) require.Nil(t, err, "Read failed: %v", err) diff --git a/transport/shadowsocks/client/packet_listener.go b/transport/shadowsocks/packet_listener.go similarity index 90% rename from transport/shadowsocks/client/packet_listener.go rename to transport/shadowsocks/packet_listener.go index d525b6d1..da497ef4 100644 --- a/transport/shadowsocks/client/packet_listener.go +++ b/transport/shadowsocks/packet_listener.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package shadowsocks import ( "context" @@ -23,7 +23,6 @@ import ( "github.com/Jigsaw-Code/outline-internal-sdk/internal/slicepool" "github.com/Jigsaw-Code/outline-internal-sdk/transport" - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -35,12 +34,12 @@ var udpPool = slicepool.MakePool(clientUDPBufferSize) type packetListener struct { endpoint transport.PacketEndpoint - key *shadowsocks.EncryptionKey + key *EncryptionKey } var _ transport.PacketListener = (*packetListener)(nil) -func NewShadowsocksPacketListener(endpoint transport.PacketEndpoint, key *shadowsocks.EncryptionKey) (transport.PacketListener, error) { +func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (transport.PacketListener, error) { if endpoint == nil { return nil, errors.New("argument endpoint must not be nil") } @@ -61,7 +60,7 @@ func (c *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, erro type packetConn struct { net.Conn - key *shadowsocks.EncryptionKey + key *EncryptionKey } var _ net.PacketConn = (*packetConn)(nil) @@ -80,7 +79,7 @@ func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { // partially overlapping the plaintext and cipher slices since `Pack` skips the salt when calling // `AEAD.Seal` (see https://golang.org/pkg/crypto/cipher/#AEAD). plaintextBuf := append(append(cipherBuf[saltSize:saltSize], socksTargetAddr...), b...) - buf, err := shadowsocks.Pack(cipherBuf, plaintextBuf, c.key) + buf, err := Pack(cipherBuf, plaintextBuf, c.key) if err != nil { return 0, err } @@ -98,7 +97,7 @@ func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { return 0, nil, err } // Decrypt in-place. - buf, err := shadowsocks.Unpack(nil, cipherBuf[:n], c.key) + buf, err := Unpack(nil, cipherBuf[:n], c.key) if err != nil { return 0, nil, err } diff --git a/transport/shadowsocks/client/packet_listener_test.go b/transport/shadowsocks/packet_listener_test.go similarity index 89% rename from transport/shadowsocks/client/packet_listener_test.go rename to transport/shadowsocks/packet_listener_test.go index 1e77b37b..8ff3d097 100644 --- a/transport/shadowsocks/client/packet_listener_test.go +++ b/transport/shadowsocks/packet_listener_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package shadowsocks import ( "context" @@ -23,7 +23,6 @@ import ( "time" "github.com/Jigsaw-Code/outline-internal-sdk/transport" - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -31,7 +30,7 @@ func TestShadowsocksPacketListener_ListenPacket(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, t) proxyEndpoint := transport.UDPEndpoint{Address: proxy.LocalAddr().String()} - d, err := NewShadowsocksPacketListener(proxyEndpoint, key) + d, err := NewPacketListener(proxyEndpoint, key) if err != nil { t.Fatalf("Failed to create PacketListener: %v", err) } @@ -55,7 +54,7 @@ func BenchmarkShadowsocksPacketListener_ListenPacket(b *testing.B) { key := makeTestKey(b) proxy, running := startShadowsocksUDPEchoServer(key, testTargetAddr, b) proxyEndpoint := transport.UDPEndpoint{Address: proxy.LocalAddr().String()} - d, err := NewShadowsocksPacketListener(proxyEndpoint, key) + d, err := NewPacketListener(proxyEndpoint, key) if err != nil { b.Fatalf("Failed to create PacketListener: %v", err) } @@ -78,7 +77,7 @@ func BenchmarkShadowsocksPacketListener_ListenPacket(b *testing.B) { running.Wait() } -func startShadowsocksUDPEchoServer(key *shadowsocks.EncryptionKey, expectedTgtAddr string, t testing.TB) (net.Conn, *sync.WaitGroup) { +func startShadowsocksUDPEchoServer(key *EncryptionKey, expectedTgtAddr string, t testing.TB) (net.Conn, *sync.WaitGroup) { conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("Proxy ListenUDP failed: %v", err) @@ -97,7 +96,7 @@ func startShadowsocksUDPEchoServer(key *shadowsocks.EncryptionKey, expectedTgtAd t.Logf("Failed to read from UDP conn: %v", err) return } - buf, err := shadowsocks.Unpack(clientBuf, cipherBuf[:n], key) + buf, err := Unpack(clientBuf, cipherBuf[:n], key) if err != nil { t.Fatalf("Failed to decrypt: %v", err) } @@ -109,7 +108,7 @@ func startShadowsocksUDPEchoServer(key *shadowsocks.EncryptionKey, expectedTgtAd t.Fatalf("Expected target address '%v'. Got '%v'", expectedTgtAddr, tgtAddr) } // Echo both the payload and SOCKS address. - buf, err = shadowsocks.Pack(cipherBuf, buf, key) + buf, err = Pack(cipherBuf, buf, key) if err != nil { t.Fatalf("Failed to encrypt: %v", err) } diff --git a/transport/shadowsocks/salt.go b/transport/shadowsocks/salt.go index 64c60cb4..ead5bffd 100644 --- a/transport/shadowsocks/salt.go +++ b/transport/shadowsocks/salt.go @@ -16,6 +16,7 @@ package shadowsocks import ( "crypto/rand" + "errors" ) // SaltGenerator generates unique salts to use in Shadowsocks connections. @@ -35,3 +36,31 @@ func (randomSaltGenerator) GetSalt(salt []byte) error { // RandomSaltGenerator is a basic SaltGenerator. var RandomSaltGenerator SaltGenerator = randomSaltGenerator{} + +type prefixSaltGenerator struct { + prefix []byte +} + +func (g prefixSaltGenerator) GetSalt(salt []byte) error { + n := copy(salt, g.prefix) + if n != len(g.prefix) { + return errors.New("prefix is too long") + } + _, err := rand.Read(salt[n:]) + return err +} + +// NewPrefixSaltGenerator returns a SaltGenerator whose output consists of +// the provided prefix, followed by random bytes. This is useful to change +// how shadowsocks traffic is classified by middleboxes. +// +// Note: Prefixes steal entropy from the initialization vector. This weakens +// security by increasing the likelihood that the same IV is used in two +// different connections (which becomes likely once 2^(N/2) connections are +// made, due to the birthday attack). If an IV is reused, the attacker can +// not only decrypt the ciphertext of those two connections; they can also +// easily recover the shadowsocks key and decrypt all other connections to +// this server. Use with care! +func NewPrefixSaltGenerator(prefix []byte) SaltGenerator { + return prefixSaltGenerator{prefix} +} diff --git a/transport/shadowsocks/salt_test.go b/transport/shadowsocks/salt_test.go index 647e8b7b..f56cb2c4 100644 --- a/transport/shadowsocks/salt_test.go +++ b/transport/shadowsocks/salt_test.go @@ -42,3 +42,57 @@ func BenchmarkRandomSaltGenerator(b *testing.B) { } }) } + +// setRandomBitsToOne replaces any random bits in the output with 1. +func setRandomBitsToOne(salter SaltGenerator, output []byte) error { + salt := make([]byte, len(output)) + // OR together 128 salts. The probability that any random bit is + // 0 for all 128 random salts is 2^-128, which is close enough to zero. + for i := 0; i < 128; i++ { + if err := salter.GetSalt(salt); err != nil { + return err + } + for i := range salt { + output[i] |= salt[i] + } + } + return nil +} + +// Test that the prefix bytes are respected, and the remainder are random. +func TestTypicalPrefix(t *testing.T) { + prefix := []byte("twelve bytes") + salter := NewPrefixSaltGenerator(prefix) + + output := make([]byte, 32) + if err := setRandomBitsToOne(salter, output); err != nil { + t.Error(err) + } + + for i := 0; i < 12; i++ { + if output[i] != prefix[i] { + t.Error("prefix mismatch") + } + } + + for _, b := range output[12:] { + if b != 0xFF { + t.Error("unexpected zero bit") + } + } +} + +// Test that all bytes are random when the prefix is nil +func TestNilPrefix(t *testing.T) { + salter := NewPrefixSaltGenerator(nil) + + output := make([]byte, 64) + if err := setRandomBitsToOne(salter, output); err != nil { + t.Error(err) + } + for _, b := range output { + if b != 0xFF { + t.Error("unexpected zero bit") + } + } +} diff --git a/transport/shadowsocks/stream.go b/transport/shadowsocks/stream.go index 27755462..cf0f296f 100644 --- a/transport/shadowsocks/stream.go +++ b/transport/shadowsocks/stream.go @@ -63,9 +63,9 @@ var ( _ io.ReaderFrom = (*Writer)(nil) ) -// NewShadowsocksWriter creates a [Writer] that encrypts the given [io.Writer] using +// NewWriter creates a [Writer] that encrypts the given [io.Writer] using // the shadowsocks protocol with the given encryption key. -func NewShadowsocksWriter(writer io.Writer, key *EncryptionKey) *Writer { +func NewWriter(writer io.Writer, key *EncryptionKey) *Writer { return &Writer{writer: writer, key: key, saltGenerator: RandomSaltGenerator} } @@ -286,9 +286,9 @@ type Reader interface { io.WriterTo } -// NewShadowsocksReader creates a [Reader] that decrypts the given [io.Reader] using +// NewReader creates a [Reader] that decrypts the given [io.Reader] using // the shadowsocks protocol with the given encryption key. -func NewShadowsocksReader(reader io.Reader, key *EncryptionKey) Reader { +func NewReader(reader io.Reader, key *EncryptionKey) Reader { return &readConverter{ cr: &chunkReader{ reader: reader, diff --git a/transport/shadowsocks/client/stream_dialer.go b/transport/shadowsocks/stream_dialer.go similarity index 80% rename from transport/shadowsocks/client/stream_dialer.go rename to transport/shadowsocks/stream_dialer.go index ef215f99..fe3d07fe 100644 --- a/transport/shadowsocks/client/stream_dialer.go +++ b/transport/shadowsocks/stream_dialer.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package shadowsocks import ( "context" @@ -20,35 +20,34 @@ import ( "time" "github.com/Jigsaw-Code/outline-internal-sdk/transport" - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" ) -// NewShadowsocksStreamDialer creates a client that routes connections to a Shadowsocks proxy listening at +// NewStreamDialer creates a client that routes connections to a Shadowsocks proxy listening at // the given StreamEndpoint, with `key` as the Shadowsocks encyption key. -func NewShadowsocksStreamDialer(endpoint transport.StreamEndpoint, key *shadowsocks.EncryptionKey) (*ShadowsocksStreamDialer, error) { +func NewStreamDialer(endpoint transport.StreamEndpoint, key *EncryptionKey) (*StreamDialer, error) { if endpoint == nil { return nil, errors.New("argument endpoint must not be nil") } if key == nil { return nil, errors.New("argument key must not be nil") } - d := ShadowsocksStreamDialer{endpoint: endpoint, key: key, ClientDataWait: 10 * time.Millisecond} + d := StreamDialer{endpoint: endpoint, key: key, ClientDataWait: 10 * time.Millisecond} return &d, nil } -type ShadowsocksStreamDialer struct { +type StreamDialer struct { endpoint transport.StreamEndpoint - key *shadowsocks.EncryptionKey + key *EncryptionKey // SaltGenerator is used by Shadowsocks to generate the connection salts. // `SaltGenerator` may be `nil`, which defaults to [shadowsocks.RandomSaltGenerator]. - SaltGenerator shadowsocks.SaltGenerator + SaltGenerator SaltGenerator // ClientDataWait specifies the amount of time to wait for client data before sending // the Shadowsocks connection request to the proxy server. It's 10 milliseconds by default. // - // ShadowsocksStreamDialer has an optimization to send the initial client payload along with + // StreamDialer has an optimization to send the initial client payload along with // the Shadowsocks connection request. This saves one packet during connection, and also // reduces the distinctiveness of the connection pattern. // @@ -61,7 +60,7 @@ type ShadowsocksStreamDialer struct { ClientDataWait time.Duration } -var _ transport.StreamDialer = (*ShadowsocksStreamDialer)(nil) +var _ transport.StreamDialer = (*StreamDialer)(nil) // Dial implements StreamDialer.Dial via a Shadowsocks server. // @@ -78,7 +77,7 @@ var _ transport.StreamDialer = (*ShadowsocksStreamDialer)(nil) // initial data from the application in order to send the Shadowsocks salt, SOCKS address and initial data // all in one packet. This makes the size of the initial packet hard to predict, avoiding packet size // fingerprinting. We can only get the application initial data if we return a connection first. -func (c *ShadowsocksStreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { +func (c *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { socksTargetAddr := socks.ParseAddr(remoteAddr) if socksTargetAddr == nil { return nil, errors.New("failed to parse target address") @@ -87,7 +86,7 @@ func (c *ShadowsocksStreamDialer) Dial(ctx context.Context, remoteAddr string) ( if err != nil { return nil, err } - ssw := shadowsocks.NewShadowsocksWriter(proxyConn, c.key) + ssw := NewWriter(proxyConn, c.key) if c.SaltGenerator != nil { ssw.SetSaltGenerator(c.SaltGenerator) } @@ -99,6 +98,6 @@ func (c *ShadowsocksStreamDialer) Dial(ctx context.Context, remoteAddr string) ( time.AfterFunc(c.ClientDataWait, func() { ssw.Flush() }) - ssr := shadowsocks.NewShadowsocksReader(proxyConn, c.key) + ssr := NewReader(proxyConn, c.key) return transport.WrapConn(proxyConn, ssr, ssw), nil } diff --git a/transport/shadowsocks/client/stream_dialer_test.go b/transport/shadowsocks/stream_dialer_test.go similarity index 82% rename from transport/shadowsocks/client/stream_dialer_test.go rename to transport/shadowsocks/stream_dialer_test.go index 26931515..9a34d82c 100644 --- a/transport/shadowsocks/client/stream_dialer_test.go +++ b/transport/shadowsocks/stream_dialer_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package shadowsocks import ( "context" @@ -23,14 +23,13 @@ import ( "time" "github.com/Jigsaw-Code/outline-internal-sdk/transport" - "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" ) -func TestShadowsocksStreamDialer_Dial(t *testing.T) { +func TestStreamDialer_Dial(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) + d, err := NewStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -46,10 +45,10 @@ func TestShadowsocksStreamDialer_Dial(t *testing.T) { running.Wait() } -func TestShadowsocksStreamDialer_DialNoPayload(t *testing.T) { +func TestStreamDialer_DialNoPayload(t *testing.T) { key := makeTestKey(t) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, t) - d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) + d, err := NewStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -68,7 +67,7 @@ func TestShadowsocksStreamDialer_DialNoPayload(t *testing.T) { running.Wait() } -func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { +func TestStreamDialer_DialFastClose(t *testing.T) { // Set up a listener that verifies no data is sent. listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -91,7 +90,7 @@ func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { }() key := makeTestKey(t) - d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) + d, err := NewStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -109,7 +108,7 @@ func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { <-done } -func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { +func TestStreamDialer_TCPPrefix(t *testing.T) { prefix := []byte("test prefix") listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) @@ -139,7 +138,7 @@ func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { }() key := makeTestKey(t) - d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) + d, err := NewStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}, key) if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } @@ -153,13 +152,13 @@ func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { running.Wait() } -func BenchmarkShadowsocksStreamDialer_Dial(b *testing.B) { +func BenchmarkStreamDialer_Dial(b *testing.B) { b.StopTimer() b.ResetTimer() key := makeTestKey(b) proxy, running := startShadowsocksTCPEchoProxy(key, testTargetAddr, b) - d, err := NewShadowsocksStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) + d, err := NewStreamDialer(&transport.TCPEndpoint{Address: proxy.Addr().String()}, key) if err != nil { b.Fatalf("Failed to create StreamDialer: %v", err) } @@ -181,7 +180,7 @@ func BenchmarkShadowsocksStreamDialer_Dial(b *testing.B) { running.Wait() } -func startShadowsocksTCPEchoProxy(key *shadowsocks.EncryptionKey, expectedTgtAddr string, t testing.TB) (net.Listener, *sync.WaitGroup) { +func startShadowsocksTCPEchoProxy(key *EncryptionKey, expectedTgtAddr string, t testing.TB) (net.Listener, *sync.WaitGroup) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("ListenTCP failed: %v", err) @@ -202,8 +201,8 @@ func startShadowsocksTCPEchoProxy(key *shadowsocks.EncryptionKey, expectedTgtAdd go func() { defer running.Done() defer clientConn.Close() - ssr := shadowsocks.NewShadowsocksReader(clientConn, key) - ssw := shadowsocks.NewShadowsocksWriter(clientConn, key) + ssr := NewReader(clientConn, key) + ssw := NewWriter(clientConn, key) ssClientConn := transport.WrapConn(clientConn, ssr, ssw) tgtAddr, err := socks.ReadAddr(ssClientConn) diff --git a/transport/shadowsocks/stream_test.go b/transport/shadowsocks/stream_test.go index 12387aa7..165ca20c 100644 --- a/transport/shadowsocks/stream_test.go +++ b/transport/shadowsocks/stream_test.go @@ -22,7 +22,7 @@ func TestCipherReaderAuthenticationFailure(t *testing.T) { require.Nil(t, err) clientReader := strings.NewReader("Fails Authentication") - reader := NewShadowsocksReader(clientReader, key) + reader := NewReader(clientReader, key) _, err = reader.Read(make([]byte, 1)) if err == nil { t.Fatalf("Expected authentication failure, got %v", err) @@ -34,7 +34,7 @@ func TestCipherReaderUnexpectedEOF(t *testing.T) { require.Nil(t, err) clientReader := strings.NewReader("short") - server := NewShadowsocksReader(clientReader, key) + server := NewReader(clientReader, key) _, err = server.Read(make([]byte, 10)) require.Equal(t, io.ErrUnexpectedEOF, err) } @@ -44,7 +44,7 @@ func TestCipherReaderEOF(t *testing.T) { require.Nil(t, err) clientReader := strings.NewReader("") - server := NewShadowsocksReader(clientReader, key) + server := NewReader(clientReader, key) _, err = server.Read(make([]byte, 10)) if err != io.EOF { t.Fatalf("Expected EOF, got %v", err) @@ -96,7 +96,7 @@ func TestCipherReaderGoodReads(t *testing.T) { t.Fatal(err) } - reader := NewShadowsocksReader(ssText, key) + reader := NewReader(ssText, key) plainText := make([]byte, len("[First Block]")+len("[Third Block]")) n, err := io.ReadFull(reader, plainText) if err != nil { @@ -117,7 +117,7 @@ func TestCipherReaderClose(t *testing.T) { require.Nil(t, err) pipeReader, pipeWriter := io.Pipe() - server := NewShadowsocksReader(pipeReader, key) + server := NewReader(pipeReader, key) result := make(chan error) go func() { _, err := server.Read(make([]byte, 10)) @@ -135,7 +135,7 @@ func TestCipherReaderCloseError(t *testing.T) { require.Nil(t, err) pipeReader, pipeWriter := io.Pipe() - server := NewShadowsocksReader(pipeReader, key) + server := NewReader(pipeReader, key) result := make(chan error) go func() { _, err := server.Read(make([]byte, 10)) @@ -153,8 +153,8 @@ func TestEndToEnd(t *testing.T) { require.Nil(t, err) connReader, connWriter := io.Pipe() - writer := NewShadowsocksWriter(connWriter, key) - reader := NewShadowsocksReader(connReader, key) + writer := NewWriter(connWriter, key) + reader := NewReader(connReader, key) expected := "Test" wg := sync.WaitGroup{} var writeErr error @@ -182,7 +182,7 @@ func TestLazyWriteFlush(t *testing.T) { key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret") require.Nil(t, err) buf := new(bytes.Buffer) - writer := NewShadowsocksWriter(buf, key) + writer := NewWriter(buf, key) header := []byte{1, 2, 3, 4} n, err := writer.LazyWrite(header) if n != len(header) { @@ -216,7 +216,7 @@ func TestLazyWriteFlush(t *testing.T) { } // Verify content arrives in two blocks - reader := NewShadowsocksReader(buf, key) + reader := NewReader(buf, key) decrypted := make([]byte, len(header)+len(body)) n, err = reader.Read(decrypted) if n != len(header) { @@ -244,7 +244,7 @@ func TestLazyWriteConcat(t *testing.T) { key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret") require.Nil(t, err) buf := new(bytes.Buffer) - writer := NewShadowsocksWriter(buf, key) + writer := NewWriter(buf, key) header := []byte{1, 2, 3, 4} n, err := writer.LazyWrite(header) if n != len(header) { @@ -280,7 +280,7 @@ func TestLazyWriteConcat(t *testing.T) { } // Verify content arrives in one block - reader := NewShadowsocksReader(buf, key) + reader := NewReader(buf, key) decrypted := make([]byte, len(body)+len(header)) n, err = reader.Read(decrypted) if n != len(decrypted) { @@ -299,7 +299,7 @@ func TestLazyWriteOversize(t *testing.T) { key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret") require.Nil(t, err) buf := new(bytes.Buffer) - writer := NewShadowsocksWriter(buf, key) + writer := NewWriter(buf, key) N := 25000 // More than one block, less than two. data := make([]byte, N) for i := range data { @@ -323,7 +323,7 @@ func TestLazyWriteOversize(t *testing.T) { } // Verify content - reader := NewShadowsocksReader(buf, key) + reader := NewReader(buf, key) decrypted, err := ioutil.ReadAll(reader) if len(decrypted) != N { t.Errorf("Wrong number of bytes out: %d", len(decrypted)) @@ -340,7 +340,7 @@ func TestLazyWriteConcurrentFlush(t *testing.T) { key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret") require.Nil(t, err) buf := new(bytes.Buffer) - writer := NewShadowsocksWriter(buf, key) + writer := NewWriter(buf, key) header := []byte{1, 2, 3, 4} n, err := writer.LazyWrite(header) if n != len(header) { @@ -395,7 +395,7 @@ func TestLazyWriteConcurrentFlush(t *testing.T) { } // Verify content arrives in two blocks - reader := NewShadowsocksReader(buf, key) + reader := NewReader(buf, key) decrypted := make([]byte, len(header)+len(body)) n, err = reader.Read(decrypted) if n != len(header) { @@ -436,7 +436,7 @@ func BenchmarkWriter(b *testing.B) { key, err := NewEncryptionKey(CHACHA20IETFPOLY1305, "test secret") require.Nil(b, err) - writer := NewShadowsocksWriter(new(nullIO), key) + writer := NewWriter(new(nullIO), key) start := time.Now() b.StartTimer() From eabd5778e20f144599ae23f4bcc2aba2145003c8 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 28 Apr 2023 23:43:08 +0000 Subject: [PATCH 7/9] Add macOS tests Enable macOS tests Add TODO Try Windows Fix job name Use setup-go@v4 Read Go version from go.mod Add CWD Fix go-version-file Use ./go.mod Use github.workspace Reorder Spaces Remove cwd Try CD Test Again Again Use localhost IP Add bench Refactor test Add bench Again Re-enable --- .github/workflows/test.yml | 51 ++++++++++-- .../shadowsocks/client/stream_dialer_test.go | 54 ++++++------ transport/stream_test.go | 83 ++++++++++--------- 3 files changed, 121 insertions(+), 67 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 21e243b8..e27f898e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,21 +11,62 @@ permissions: # added using https://github.com/step-security/secure-workflows jobs: - build: - name: Build + test_linux: + name: Linux runs-on: ubuntu-latest steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: Set up Go 1.20 + uses: actions/setup-go@v4 + with: + go-version-file: '${{ github.workspace }}/go.mod' + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v -race -bench '.' ./... -benchtime=100ms + + + test_macos: + name: macOS + runs-on: macos-latest + steps: + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + - name: Set up Go 1.20 - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: ^1.20 + go-version-file: '${{ github.workspace }}/go.mod' + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v -race -bench '.' ./... -benchtime=100ms + + + test_windows: + name: Windows + runs-on: windows-latest + steps: - name: Check out code into the Go module directory uses: actions/checkout@v3 + - name: Set up Go 1.20 + uses: actions/setup-go@v4 + with: + go-version-file: '${{ github.workspace }}/go.mod' + - name: Build run: go build -v ./... - name: Test - run: go test -v -race -bench=. ./... -benchtime=100ms + run: go test -v -race -bench '.' -benchtime=100ms ./... + diff --git a/transport/shadowsocks/client/stream_dialer_test.go b/transport/shadowsocks/client/stream_dialer_test.go index 5a6fd165..b8c583a9 100644 --- a/transport/shadowsocks/client/stream_dialer_test.go +++ b/transport/shadowsocks/client/stream_dialer_test.go @@ -25,6 +25,7 @@ import ( "github.com/Jigsaw-Code/outline-internal-sdk/transport" "github.com/Jigsaw-Code/outline-internal-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" + "github.com/stretchr/testify/require" ) func TestShadowsocksStreamDialer_Dial(t *testing.T) { @@ -73,43 +74,46 @@ func TestShadowsocksStreamDialer_DialNoPayload(t *testing.T) { func TestShadowsocksStreamDialer_DialFastClose(t *testing.T) { // Set up a listener that verifies no data is sent. listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) - if err != nil { - t.Fatalf("ListenTCP failed: %v", err) - } + require.Nilf(t, err, "ListenTCP failed: %v", err) + defer listener.Close() - done := make(chan struct{}) + var running sync.WaitGroup + running.Add(2) + // Server go func() { + defer running.Done() conn, err := listener.Accept() - if err != nil { - t.Error(err) - } + require.Nil(t, err) + defer conn.Close() buf := make([]byte, 64) n, err := conn.Read(buf) if n > 0 || err != io.EOF { t.Errorf("Expected EOF, got %v, %v", buf[:n], err) } - listener.Close() - close(done) }() - key := makeTestKey(t) - proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} - d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) - if err != nil { - t.Fatalf("Failed to create StreamDialer: %v", err) - } - conn, err := d.Dial(context.Background(), testTargetAddr) - if err != nil { - t.Fatalf("StreamDialer.Dial failed: %v", err) - } + // Client + go func() { + defer running.Done() + key := makeTestKey(t) + proxyEndpoint := transport.TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} + d, err := NewShadowsocksStreamDialer(proxyEndpoint, key) + require.Nilf(t, err, "Failed to create StreamDialer: %v", err) + // Extend the wait to be safer. + d.ClientDataWait = 100 * time.Millisecond + + conn, err := d.Dial(context.Background(), testTargetAddr) + require.Nilf(t, err, "StreamDialer.Dial failed: %v", err) + + // Wait for less than 100 milliseconds to ensure that the target + // address is not sent. + time.Sleep(1 * time.Millisecond) + // Close the connection before the target address is sent. + conn.Close() + }() - // Wait for less than 10 milliseconds to ensure that the target - // address is not sent. - time.Sleep(1 * time.Millisecond) - // Close the connection before the target address is sent. - conn.Close() // Wait for the listener to verify the close. - <-done + running.Wait() } func TestShadowsocksStreamDialer_TCPPrefix(t *testing.T) { diff --git a/transport/stream_test.go b/transport/stream_test.go index 8adf6cb8..b97ba650 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -20,6 +20,9 @@ import ( "sync" "testing" "testing/iotest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewTCPEndpointIPv4(t *testing.T) { @@ -27,49 +30,55 @@ func TestNewTCPEndpointIPv4(t *testing.T) { responseText := []byte("Response") listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}) - if err != nil { - t.Fatalf("Failed to create TCP listener: %v", err) - } + require.Nilf(t, err, "Failed to create TCP listener: %v", err) + defer listener.Close() + var running sync.WaitGroup - running.Add(1) + running.Add(2) + + // Server go func() { defer running.Done() - defer listener.Close() clientConn, err := listener.AcceptTCP() - if err != nil { - t.Errorf("AcceptTCP failed: %v", err) - return - } + require.Nilf(t, err, "AcceptTCP failed: %v", err) + defer clientConn.Close() - if err = iotest.TestReader(clientConn, requestText); err != nil { - t.Errorf("Request read failed: %v", err) - return - } - if err = clientConn.CloseRead(); err != nil { - t.Errorf("CloseRead failed: %v", err) - return - } - if _, err = clientConn.Write(responseText); err != nil { - t.Errorf("Write failed: %v", err) - return - } - if err = clientConn.CloseWrite(); err != nil { - t.Errorf("CloseWrite failed: %v", err) - return - } + err = iotest.TestReader(clientConn, requestText) + assert.Nilf(t, err, "Request read failed: %v", err) + + // This works on Linux, but on macOS it errors with "shutdown: socket is not connected" (syscall.ENOTCONN). + // It seems that on macOS you cannot call CloseRead() if you've already received a FIN and read all the data. + // TODO(fortuna): Consider wrapping StreamConns on macOS to make CloseRead a no-op if Read has returned io.EOF + // or WriteTo has been called. + // err = clientConn.CloseRead() + // assert.Nilf(t, err, "clientConn.CloseRead failed: %v", err) + + _, err = clientConn.Write(responseText) + assert.Nilf(t, err, "Write failed: %v", err) + + err = clientConn.CloseWrite() + assert.Nilf(t, err, "CloseWrite failed: %v", err) + }() + + // Client + go func() { + defer running.Done() + e := TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} + serverConn, err := e.Connect(context.Background()) + require.Nilf(t, err, "Connect failed: %v", err) + defer serverConn.Close() + + n, err := serverConn.Write(requestText) + require.Nil(t, err) + require.Equal(t, 7, n) + assert.Nil(t, serverConn.CloseWrite()) + + err = iotest.TestReader(serverConn, responseText) + require.Nilf(t, err, "Response read failed: %v", err) + // See CloseRead comment on the server go-routine. + // err = serverConn.CloseRead() + // assert.Nilf(t, err, "serverConn.CloseRead failed: %v", err) }() - e := TCPEndpoint{RemoteAddr: *listener.Addr().(*net.TCPAddr)} - serverConn, err := e.Connect(context.Background()) - if err != nil { - t.Fatalf("Connect failed: %v", err) - } - defer serverConn.Close() - serverConn.Write(requestText) - serverConn.CloseWrite() - if err = iotest.TestReader(serverConn, responseText); err != nil { - t.Fatalf("Response read failed: %v", err) - } - serverConn.CloseRead() running.Wait() } From b7a48b6f80cecbbdf921604c17b5944ea502e6be Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Sat, 29 Apr 2023 21:54:48 +0000 Subject: [PATCH 8/9] Try windows 2019 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e27f898e..4f582bfc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,7 +53,7 @@ jobs: test_windows: name: Windows - runs-on: windows-latest + runs-on: windows-2019 steps: - name: Check out code into the Go module directory From 501e0708fcb9770c333c8cb7271546572e0efa92 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Sat, 29 Apr 2023 21:57:59 +0000 Subject: [PATCH 9/9] Add comment --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f582bfc..17e811f8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,6 +53,8 @@ jobs: test_windows: name: Windows + # Use windows-2019, which is a lot faster than windows-2022: + # https://github.com/actions/runner-images/issues/5166 runs-on: windows-2019 steps: