Skip to content

Commit

Permalink
feature: make API more convenient (#6)
Browse files Browse the repository at this point in the history
- Consolidate the shadowsocks package. No more client subpackage. This makes all the API available in one place.
- The consolidation allows me to remove "Shadowsocks" from the name of functions and types.
- Replace `net.UDPAddr` and `net.TCPAddr` with strings, which support domain names and are easier to work with.
- Add convenience `TCPStreamDialer` to convert a Dialer to a StreamDialer.
- Hide `Cipher` type and make `NewEncryptionKey` take a string which is easier to use
- Add tests.

This makes the library easier and more pleasant to use. I think I'm finally happy with how the transport API is looking like.
  • Loading branch information
fortuna authored May 2, 2023
1 parent ce98f18 commit c0cf642
Show file tree
Hide file tree
Showing 18 changed files with 387 additions and 305 deletions.
23 changes: 10 additions & 13 deletions transport/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,24 @@ type PacketEndpoint interface {
Connect(ctx context.Context) (net.Conn, error)
}

// PacketListener provides a way to create a local unbound packet connection to send packets to different destinations.
type PacketListener interface {
// ListenPacket creates a PacketConn that can be used to relay packets (such as UDP) through some proxy.
ListenPacket(ctx context.Context) (net.PacketConn, error)
}

// UDPEndpoint is a [PacketEndpoint] that connects to the given address via UDP
type UDPEndpoint struct {
// The Dialer used to create the net.Conn on Connect().
Dialer net.Dialer
// The remote address to pass to Dial.
RemoteAddr net.UDPAddr
// The endpoint address (host:port) to pass to Dial.
// If the host is a domain name, consider pre-resolving it to avoid resolution calls.
Address string
}

var _ PacketEndpoint = (*UDPEndpoint)(nil)

// Connect implements [PacketEndpoint.Connect].
func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) {
conn, err := e.Dialer.DialContext(ctx, "udp", e.RemoteAddr.String())
if err != nil {
return nil, err
}
return conn, nil
return e.Dialer.DialContext(ctx, "udp", e.Address)
}

// PacketListener provides a way to create a local unbound packet connection to send packets to different destinations.
type PacketListener interface {
// ListenPacket creates a PacketConn that can be used to relay packets (such as UDP) through some proxy.
ListenPacket(ctx context.Context) (net.PacketConn, error)
}
49 changes: 49 additions & 0 deletions transport/packet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2023 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
"context"
"syscall"
"testing"

"github.com/stretchr/testify/require"
)

func TestUDPEndpointIPv4(t *testing.T) {
const serverAddr = "127.0.0.10:8888"
ep := &UDPEndpoint{Address: serverAddr}
ep.Dialer.Control = func(network, address string, c syscall.RawConn) error {
require.Equal(t, "udp4", network)
require.Equal(t, serverAddr, address)
return nil
}
conn, err := ep.Connect(context.Background())
require.Nil(t, err)
require.Equal(t, serverAddr, conn.RemoteAddr().String())
}

func TestUDPEndpointIPv6(t *testing.T) {
const serverAddr = "[::1]:8888"
ep := &UDPEndpoint{Address: serverAddr}
ep.Dialer.Control = func(network, address string, c syscall.RawConn) error {
require.Equal(t, "udp6", network)
require.Equal(t, serverAddr, address)
return nil
}
conn, err := ep.Connect(context.Background())
require.Nil(t, err)
require.Equal(t, serverAddr, conn.RemoteAddr().String())
}
62 changes: 35 additions & 27 deletions transport/shadowsocks/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ import (
"golang.org/x/crypto/hkdf"
)

type Cipher struct {
name string
type cipherSpec struct {
newInstance func(key []byte) (cipher.AEAD, error)
keySize int
saltSize int
Expand All @@ -36,13 +35,20 @@ type Cipher struct {

// List of supported AEAD ciphers, as specified at https://shadowsocks.org/guide/aead.html
var (
CHACHA20IETFPOLY1305 = &Cipher{"AEAD_CHACHA20_POLY1305", chacha20poly1305.New, chacha20poly1305.KeySize, 32, 16}
AES256GCM = &Cipher{"AEAD_AES_256_GCM", newAesGCM, 32, 32, 16}
AES192GCM = &Cipher{"AEAD_AES_192_GCM", newAesGCM, 24, 24, 16}
AES128GCM = &Cipher{"AEAD_AES_128_GCM", newAesGCM, 16, 16, 16}
CHACHA20IETFPOLY1305 = "AEAD_CHACHA20_POLY1305"
AES256GCM = "AEAD_AES_256_GCM"
AES192GCM = "AEAD_AES_192_GCM"
AES128GCM = "AEAD_AES_128_GCM"
)

var supportedCiphers = [](*Cipher){CHACHA20IETFPOLY1305, AES256GCM, AES192GCM, AES128GCM}
var (
chacha20IETFPOLY1305Cipher = &cipherSpec{chacha20poly1305.New, chacha20poly1305.KeySize, 32, 16}
aes256GCMCipher = &cipherSpec{newAesGCM, 32, 32, 16}
aes192GCMCipher = &cipherSpec{newAesGCM, 24, 24, 16}
aes128GCMCipher = &cipherSpec{newAesGCM, 16, 16, 16}
)

var supportedCiphers = [](string){CHACHA20IETFPOLY1305, AES256GCM, AES192GCM, AES128GCM}

// ErrUnsupportedCipher is returned by [CypherByName] when the named cipher is not supported.
type ErrUnsupportedCipher struct {
Expand All @@ -54,19 +60,22 @@ func (err ErrUnsupportedCipher) Error() string {
return "unsupported cipher " + err.Name
}

// Largest tag size among the supported ciphers. Used by the TCP buffer pool
const maxTagSize = 16

// CipherByName returns a [*Cipher] with the given name, or an error if the cipher is not supported.
// The name must be the IETF name (as per https://www.iana.org/assignments/aead-parameters/aead-parameters.xhtml) or the
// Shadowsocks alias from https://shadowsocks.org/guide/aead.html.
func CipherByName(name string) (*Cipher, error) {
func cipherByName(name string) (*cipherSpec, error) {
switch strings.ToUpper(name) {
case "AEAD_CHACHA20_POLY1305", "CHACHA20-IETF-POLY1305":
return CHACHA20IETFPOLY1305, nil
return chacha20IETFPOLY1305Cipher, nil
case "AEAD_AES_256_GCM", "AES-256-GCM":
return AES256GCM, nil
return aes256GCMCipher, nil
case "AEAD_AES_192_GCM", "AES-192-GCM":
return AES192GCM, nil
return aes192GCMCipher, nil
case "AEAD_AES_128_GCM", "AES-128-GCM":
return AES128GCM, nil
return aes128GCMCipher, nil
default:
return nil, ErrUnsupportedCipher{name}
}
Expand All @@ -80,19 +89,9 @@ func newAesGCM(key []byte) (cipher.AEAD, error) {
return cipher.NewGCM(blk)
}

func maxTagSize() int {
max := 0
for _, spec := range supportedCiphers {
if spec.tagSize > max {
max = spec.tagSize
}
}
return max
}

// EncryptionKey encapsulates a Shadowsocks AEAD spec and a secret
type EncryptionKey struct {
cipher *Cipher
cipher *cipherSpec
secret []byte
}

Expand Down Expand Up @@ -138,12 +137,21 @@ func simpleEVPBytesToKey(data []byte, keyLen int) ([]byte, error) {
return derived[:keyLen], nil
}

// NewEncryptionKey creates a Cipher given a cipher name and a secret
func NewEncryptionKey(cipher *Cipher, secretText string) (*EncryptionKey, error) {
// NewEncryptionKey creates a Cipher given a cipher name and a secret.
// The cipher name must be the IETF name (as per https://www.iana.org/assignments/aead-parameters/aead-parameters.xhtml)
// or the Shadowsocks alias from https://shadowsocks.org/guide/aead.html.
func NewEncryptionKey(cipherName string, secretText string) (*EncryptionKey, error) {
var key EncryptionKey
var err error
key.cipher, err = cipherByName(cipherName)
if err != nil {
return nil, err
}

// Key derivation as per https://shadowsocks.org/en/spec/AEAD-Ciphers.html
secret, err := simpleEVPBytesToKey([]byte(secretText), cipher.keySize)
key.secret, err = simpleEVPBytesToKey([]byte(secretText), key.cipher.keySize)
if err != nil {
return nil, err
}
return &EncryptionKey{cipher, secret}, nil
return &key, nil
}
46 changes: 32 additions & 14 deletions transport/shadowsocks/cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ package shadowsocks
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func assertCipher(t *testing.T, cipher *Cipher, saltSize, tagSize int) {
func assertCipher(t *testing.T, cipher string, saltSize, tagSize int) {
key, err := NewEncryptionKey(cipher, "")
require.Nil(t, err)
require.Equal(t, saltSize, key.SaltSize())

dummyAead, err := key.NewAEAD(make([]byte, cipher.keySize))
dummyAead, err := key.NewAEAD(make([]byte, key.SaltSize()))
require.Nil(t, err)
require.Equal(t, dummyAead.Overhead(), key.TagSize())
require.Equal(t, tagSize, key.TagSize())
require.Equal(t, key.TagSize(), dummyAead.Overhead())
}

func TestSizes(t *testing.T) {
Expand All @@ -40,27 +41,29 @@ func TestSizes(t *testing.T) {
}

func TestShadowsocksCipherNames(t *testing.T) {
cipher, err := CipherByName("chacha20-ietf-poly1305")
key, err := NewEncryptionKey("chacha20-ietf-poly1305", "")
require.Nil(t, err)
require.Equal(t, CHACHA20IETFPOLY1305, cipher)
require.Equal(t, chacha20IETFPOLY1305Cipher, key.cipher)

cipher, err = CipherByName("aes-256-gcm")
key, err = NewEncryptionKey("aes-256-gcm", "")
require.Nil(t, err)
require.Equal(t, AES256GCM, cipher)
require.Equal(t, aes256GCMCipher, key.cipher)

cipher, err = CipherByName("aes-192-gcm")
key, err = NewEncryptionKey("aes-192-gcm", "")
require.Nil(t, err)
require.Equal(t, AES192GCM, cipher)
require.Equal(t, aes192GCMCipher, key.cipher)

cipher, err = CipherByName("aes-128-gcm")
key, err = NewEncryptionKey("aes-128-gcm", "")
require.Nil(t, err)
require.Equal(t, AES128GCM, cipher)
require.Equal(t, aes128GCMCipher, key.cipher)
}

func TestUnsupportedCipher(t *testing.T) {
_, err := CipherByName("aes-256-cfb")
if err == nil {
t.Errorf("Should get an error for unsupported cipher")
_, err := NewEncryptionKey("aes-256-cfb", "")
var unsupportedErr ErrUnsupportedCipher
if assert.ErrorAs(t, err, &unsupportedErr) {
assert.Equal(t, "aes-256-cfb", unsupportedErr.Name)
assert.Equal(t, "unsupported cipher aes-256-cfb", unsupportedErr.Error())
}
}

Expand All @@ -79,3 +82,18 @@ func TestMaxNonceSize(t *testing.T) {
}
}
}

func TestMaxTagSize(t *testing.T) {
var calculatedMax int
for _, cipher := range supportedCiphers {
key, err := NewEncryptionKey(cipher, "")
if !assert.Nilf(t, err, "Failed to create cipher %v", cipher) {
continue
}
assert.LessOrEqualf(t, key.TagSize(), maxTagSize, "Tag size for cipher %v (%v) is greater than the max (%v)", cipher, key.TagSize(), maxTagSize)
if key.TagSize() > calculatedMax {
calculatedMax = key.TagSize()
}
}
require.Equal(t, maxTagSize, calculatedMax)
}
50 changes: 0 additions & 50 deletions transport/shadowsocks/client/salt.go

This file was deleted.

Loading

0 comments on commit c0cf642

Please sign in to comment.