From 62c06f1eef9b9e1f918a74e40be9047768315961 Mon Sep 17 00:00:00 2001 From: anishnaik Date: Tue, 6 Jun 2023 15:30:11 -0400 Subject: [PATCH] Fix addr cheatcode (#155) - Fix `addr` and `sign` cheatcodes to correctly left-pad private keys that do not fill up 32-byte slices --- chain/cheat_codes.go | 18 ++++++++++---- .../contracts/cheat_codes/utils/addr.sol | 22 +++++++++++++---- utils/crypto_utils.go | 24 +++++++++++++++++++ 3 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 utils/crypto_utils.go diff --git a/chain/cheat_codes.go b/chain/cheat_codes.go index 8404eaa4..49bd4ae3 100644 --- a/chain/cheat_codes.go +++ b/chain/cheat_codes.go @@ -340,8 +340,12 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, // addr: Compute the address for a given private key contract.addMethod("addr", abi.Arguments{{Type: typeUint256}}, abi.Arguments{{Type: typeAddress}}, func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { - // Using TOECDSAUnsafe b/c the private key is guaranteed to be of length 256 bits, at most - privateKey := crypto.ToECDSAUnsafe(inputs[0].(*big.Int).Bytes()) + // Get the private key object + privateKey, err := utils.GetPrivateKey(inputs[0].(*big.Int).Bytes()) + if err != nil { + errorMessage := "addr: " + err.Error() + return nil, cheatCodeRevertData([]byte(errorMessage)) + } // Get ECDSA public key publicKey := privateKey.Public().(*ecdsa.PublicKey) @@ -357,11 +361,15 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, contract.addMethod("sign", abi.Arguments{{Type: typeUint256}, {Type: typeBytes32}}, abi.Arguments{{Type: typeUint8}, {Type: typeBytes32}, {Type: typeBytes32}}, func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { - // Using TOECDSAUnsafe b/c the private key is guaranteed to be of length 256 bits, at most - privateKey := crypto.ToECDSAUnsafe(inputs[0].(*big.Int).Bytes()) - digest := inputs[1].([32]byte) + // Get the private key object + privateKey, err := utils.GetPrivateKey(inputs[0].(*big.Int).Bytes()) + if err != nil { + errorMessage := "sign: " + err.Error() + return nil, cheatCodeRevertData([]byte(errorMessage)) + } // Sign digest + digest := inputs[1].([32]byte) sig, err := crypto.Sign(digest[:], privateKey) if err != nil { return nil, cheatCodeRevertData([]byte("sign: malformed input to signature algorithm")) diff --git a/fuzzing/testdata/contracts/cheat_codes/utils/addr.sol b/fuzzing/testdata/contracts/cheat_codes/utils/addr.sol index 6a4cd930..81aae1e4 100644 --- a/fuzzing/testdata/contracts/cheat_codes/utils/addr.sol +++ b/fuzzing/testdata/contracts/cheat_codes/utils/addr.sol @@ -8,11 +8,23 @@ contract TestContract { // Obtain our cheat code contract reference. CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); - uint256 privateKey = 0x6df21769a2082e03f7e21f6395561279e9a7feb846b2bf740798c794ad196e00; - address expectedAddress = 0xdf8Ef652AdE0FA4790843a726164df8cf8649339; + // Test with random private key + uint256 pkOne = 0x6df21769a2082e03f7e21f6395561279e9a7feb846b2bf740798c794ad196e00; + address addrOne = 0xdf8Ef652AdE0FA4790843a726164df8cf8649339; + address result = cheats.addr(pkOne); + assert(result == addrOne); - // Call cheats.addr - address result = cheats.addr(privateKey); - assert(result == expectedAddress); + // Test with private key that requires padding + uint256 pkTwo = 1; + address addrTwo = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf; + result = cheats.addr(pkTwo); + assert(result == addrTwo); + + // Test with zero + uint256 pkThree = 0; + cheats.addr(pkThree); + // A private key of zero is not allowed so if we hit this assertion, then cheats.addr() did not revert which + // is incorrect + assert(false); } } diff --git a/utils/crypto_utils.go b/utils/crypto_utils.go new file mode 100644 index 00000000..0144e51d --- /dev/null +++ b/utils/crypto_utils.go @@ -0,0 +1,24 @@ +package utils + +import ( + "crypto/ecdsa" + "github.com/ethereum/go-ethereum/crypto" + "github.com/pkg/errors" +) + +// GetPrivateKey will return a private key object given a byte slice. Only slices between lengths 1 and 32 (inclusive) +// are valid. +func GetPrivateKey(b []byte) (*ecdsa.PrivateKey, error) { + // Make sure that private key is not zero + if len(b) < 1 || len(b) > 32 { + return nil, errors.New("invalid private key") + } + + // Then pad the private key slice to a fixed 32-byte array + paddedPrivateKey := make([]byte, 32) + copy(paddedPrivateKey[32-len(b):], b) + + // Next we will actually retrieve the private key object + privateKey, err := crypto.ToECDSA(paddedPrivateKey[:]) + return privateKey, errors.WithStack(err) +}