Skip to content

Commit

Permalink
refactor: cleanup comments and remove DecryptOnce (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Apr 27, 2023
1 parent ae41944 commit ce98f18
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 38 deletions.
7 changes: 5 additions & 2 deletions transport/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"net"
)

// PacketEndpoint represents an endpoint that can be used to established packet connections (like UDP)
// PacketEndpoint represents an endpoint that can be used to established packet connections (like UDP) to a fixed destination.
type PacketEndpoint interface {
// Connect creates a connection bound to an endpoint, returning the connection.
Connect(ctx context.Context) (net.Conn, error)
Expand All @@ -31,14 +31,17 @@ type PacketListener interface {
ListenPacket(ctx context.Context) (net.PacketConn, error)
}

// UDPEndpoint is a PacketEndpoint that connects to the given address via UDP
// UDPEndpoint is a [PacketEndpoint] that connects to the given address via UDP
type UDPEndpoint struct {
// The Dialer used to create the net.Conn on Connect().
Dialer net.Dialer
// The remote address to pass to Dial.
RemoteAddr net.UDPAddr
}

var _ PacketEndpoint = (*UDPEndpoint)(nil)

// Connect implements [PacketEndpoint.Connect].
func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) {
conn, err := e.Dialer.DialContext(ctx, "udp", e.RemoteAddr.String())
if err != nil {
Expand Down
20 changes: 2 additions & 18 deletions transport/shadowsocks/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ var (

var supportedCiphers = [](*Cipher){CHACHA20IETFPOLY1305, AES256GCM, AES192GCM, AES128GCM}

// ErrUnsupportedCipher is returned by [CypherByName] when the named cipher is not supported.
type ErrUnsupportedCipher struct {
// The name of the requested [Cipher]
Name string
}

Expand Down Expand Up @@ -145,21 +147,3 @@ func NewEncryptionKey(cipher *Cipher, secretText string) (*EncryptionKey, error)
}
return &EncryptionKey{cipher, secret}, nil
}

// Assumes all ciphers have NonceSize() <= 12.
var zeroNonce [12]byte

// DecryptOnce will decrypt the cipherText using the cipher and salt, appending the output to plainText.
func DecryptOnce(key *EncryptionKey, salt []byte, plainText, cipherText []byte) ([]byte, error) {
aead, err := key.NewAEAD(salt)
if err != nil {
return nil, err
}
if len(cipherText) < aead.Overhead() {
return nil, io.ErrUnexpectedEOF
}
if cap(plainText)-len(plainText) < len(cipherText)-aead.Overhead() {
return nil, io.ErrShortBuffer
}
return aead.Open(plainText, zeroNonce[:aead.NonceSize()], cipherText, nil)
}
30 changes: 23 additions & 7 deletions transport/shadowsocks/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ import (
"io"
)

// ErrShortPacket is identical to shadowaead.ErrShortPacket
// ErrShortPacket indicates the destination packet given to Unpack is too short.
var ErrShortPacket = errors.New("short packet")

// Assumes all ciphers have NonceSize() <= 12.
var zeroNonce [12]byte

// Pack encrypts a Shadowsocks-UDP packet and returns a slice containing the encrypted packet.
// dst must be big enough to hold the encrypted packet.
// If plaintext and dst overlap but are not aligned for in-place encryption, this
Expand All @@ -47,20 +50,33 @@ func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) {
return aead.Seal(salt, zeroNonce[:aead.NonceSize()], plaintext, nil), nil
}

// Unpack decrypts a Shadowsocks-UDP packet and returns a slice containing the decrypted payload or an error.
// Unpack decrypts a Shadowsocks-UDP packet in the format [salt][cipherText][AEAD tag] and returns a slice containing
// the decrypted payload or an error.
// If dst is present, it is used to store the plaintext, and must have enough capacity.
// If dst is nil, decryption proceeds in-place.
// This function is needed because shadowaead.Unpack() embeds its own replay detection,
// which we do not always want, especially on memory-constrained clients.
func Unpack(dst, pkt []byte, key *EncryptionKey) ([]byte, error) {
saltSize := key.SaltSize()
if len(pkt) < saltSize {
return nil, ErrShortPacket
}

salt := pkt[:saltSize]
msg := pkt[saltSize:]
cipherTextAndTag := pkt[saltSize:]
if len(cipherTextAndTag) < key.TagSize() {
return nil, io.ErrUnexpectedEOF
}

if dst == nil {
dst = msg
dst = cipherTextAndTag
}
if cap(dst) < len(cipherTextAndTag)-key.TagSize() {
return nil, io.ErrShortBuffer
}
return DecryptOnce(key, salt, dst[:0], msg)

aead, err := key.NewAEAD(salt)
if err != nil {
return nil, err
}

return aead.Open(dst[:0], zeroNonce[:aead.NonceSize()], cipherTextAndTag, nil)
}
12 changes: 6 additions & 6 deletions transport/shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const payloadSizeMask = 0x3FFF // 16*1024 - 1
// The largest buffer we could need is for decrypting a max-length payload.
var readBufPool = slicepool.MakePool(payloadSizeMask + maxTagSize())

// Writer is an io.Writer that also implements io.ReaderFrom to
// Writer is an [io.Writer] that also implements [io.ReaderFrom] to
// allow for piping the data without extra allocations and copies.
// The LazyWrite and Flush methods allow a header to be
// added but delayed until the first write, for concatenation.
Expand Down Expand Up @@ -63,8 +63,8 @@ var (
_ io.ReaderFrom = (*Writer)(nil)
)

// NewShadowsocksWriter creates a Writer that encrypts the given Writer using
// the shadowsocks protocol with the given shadowsocks key.
// NewShadowsocksWriter creates a [Writer] that encrypts the given [io.Writer] using
// the shadowsocks protocol with the given encryption key.
func NewShadowsocksWriter(writer io.Writer, key *EncryptionKey) *Writer {
return &Writer{writer: writer, key: key, saltGenerator: RandomSaltGenerator}
}
Expand Down Expand Up @@ -279,15 +279,15 @@ type chunkReader struct {
payload slicepool.LazySlice
}

// Reader is an io.Reader that also implements io.WriterTo to
// Reader is an [io.Reader] that also implements [io.WriterTo] to
// allow for piping the data without extra allocations and copies.
type Reader interface {
io.Reader
io.WriterTo
}

// NewShadowsocksReader creates a Reader that decrypts the given Reader using
// the shadowsocks protocol with the given shadowsocks key.
// NewShadowsocksReader creates a [Reader] that decrypts the given [io.Reader] using
// the shadowsocks protocol with the given encryption key.
func NewShadowsocksReader(reader io.Reader, key *EncryptionKey) Reader {
return &readConverter{
cr: &chunkReader{
Expand Down
13 changes: 8 additions & 5 deletions transport/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,30 @@ type StreamConn interface {
CloseWrite() error
}

// StreamEndpoint represents an endpoint that can be used to established stream connections (like TCP)
// StreamEndpoint represents an endpoint that can be used to established stream connections (like TCP) to a fixed destination.
type StreamEndpoint interface {
// Connect establishes a connection with the endpoint, returning the connection.
Connect(ctx context.Context) (StreamConn, error)
}

// StreamDialer provides a way to establish stream connections to a destination.
// StreamDialer provides a way to dial a destination and establish stream connections.
type StreamDialer interface {
// Dial connects to `raddr`.
// `raddr` has the form `host:port`, where `host` can be a domain name or IP address.
Dial(ctx context.Context, raddr string) (StreamConn, error)
}

// TCPEndpoint is a StreamEndpoint that connects to the given address via TCP
// TCPEndpoint is a [StreamEndpoint] that connects to the given address via TCP.
type TCPEndpoint struct {
// The Dialer used to create the connection on Connect().
Dialer net.Dialer
// The remote address to pass to DialTCP.
RemoteAddr net.TCPAddr
}

var _ StreamEndpoint = (*TCPEndpoint)(nil)

// Connect implements [StreamEndpoint.Connect].
func (e TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) {
conn, err := e.Dialer.DialContext(ctx, "tcp", e.RemoteAddr.String())
if err != nil {
Expand Down Expand Up @@ -86,8 +89,8 @@ func (dc *duplexConnAdaptor) CloseWrite() error {
return dc.StreamConn.CloseWrite()
}

// WrapDuplexConn wraps an existing DuplexConn with new Reader and Writer, but
// preserving the original CloseRead() and CloseWrite().
// WrapDuplexConn wraps an existing [StreamConn] with new Reader and Writer, but
// preserving the original [StreamConn.CloseRead] and [StreamConn.CloseWrite].
func WrapConn(c StreamConn, r io.Reader, w io.Writer) StreamConn {
conn := c
// We special-case duplexConnAdaptor to avoid multiple levels of nesting.
Expand Down

0 comments on commit ce98f18

Please sign in to comment.