Skip to content

Commit

Permalink
Merge pull request #43 from crytic/improve-generate-cli
Browse files Browse the repository at this point in the history
Improve unit test generation CLI arguments
  • Loading branch information
tuturu-tech authored Apr 2, 2024
2 parents 3a6bf68 + 3bb3141 commit ba33ab9
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 125 deletions.
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ The available tool commands are:
The `generate` command is used to generate Foundry unit tests from Echidna or Medusa corpus call sequences.

**Command-line options:**
- `compilation_path`: The path to the Solidity file or Foundry directory
- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory.
- `-c`/`--contract` `contract_name`: The name of the target contract.
- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory.
- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for inheritance).
- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa`
- `--named-inputs`: Includes function input names when making calls
- `--config`: Path to the fuzz-utils config JSON file
- `--all-sequences`: Include all corpus sequences when generating unit tests.
- `compilation_path`: The path to the Solidity file or Foundry directory. By default `.`
- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory. By default `corpus`
- `-c`/`--contract` `contract_name`: The name of the target contract. If the compilation path only contains one contract the target will be automatically derived.
- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory. By default `test`
- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for overriding inheritance). If this configuration option is not provided the inheritance path will be automatically derived.
- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa`. By default `medusa`
- `--named-inputs`: Includes function input names when making calls. By default`false`
- `--config`: Path to the fuzz-utils config JSON file. Empty by default.
- `--all-sequences`: Include all corpus sequences when generating unit tests. By default `false`

**Example**

In order to generate a test file for the [BasicTypes.sol](tests/test_data/src/BasicTypes.sol) contract, based on the Echidna corpus reproducers for this contract ([corpus-basic](tests/test_data/echidna-corpora/corpus-basic/)), we need to `cd` into the `tests/test_data` directory which contains the Foundry project and run the command:
```bash
fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --test-directory "./test/" --inheritance-path "../src/" --fuzzer echidna
fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --fuzzer echidna
```

Running this command should generate a `BasicTypes_Echidna_Test.sol` file in the [test](/tests/test_data/test/) directory of the Foundry project.
Expand All @@ -65,9 +65,9 @@ The `template` command is used to generate a fuzzing harness. The harness can in

**Command-line options:**
- `compilation_path`: The path to the Solidity file or Foundry directory
- `-n`/`--name` `name: str`: The name of the fuzzing harness.
- `-c`/`--contracts` `target_contracts: list`: The name of the target contract.
- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default it is `fuzzing`
- `-n`/`--name` `name: str`: The name of the fuzzing harness. By default `DefaultHarness`
- `-c`/`--contracts` `target_contracts: list`: The name of the target contract. Empty by default.
- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default `fuzzing`
- `--config`: Path to the `fuzz-utils` config JSON file
- `--mode`: The strategy to use when generating the harnesses. Valid options: `simple`, `prank`, `actor`

Expand Down
45 changes: 18 additions & 27 deletions fuzz_utils/generate/FoundryTest.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,49 @@
"""The FoundryTest class that handles generation of unit tests from call sequences"""
import os
import sys
import json
import copy
from typing import Any
import jinja2

from slither import Slither
from slither.core.declarations.contract import Contract
from fuzz_utils.utils.crytic_print import CryticPrint
from fuzz_utils.utils.slither_utils import get_target_contract
from fuzz_utils.templates.default_config import default_config

from fuzz_utils.generate.fuzzers.Medusa import Medusa
from fuzz_utils.generate.fuzzers.Echidna import Echidna
from fuzz_utils.templates.foundry_templates import templates


class FoundryTest: # pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods,too-many-instance-attributes
class FoundryTest:
"""
Handles the generation of Foundry test files
"""

config: dict = copy.deepcopy(default_config["generate"])

def __init__(
self,
config: dict,
slither: Slither,
fuzzer: Echidna | Medusa,
) -> None:
self.inheritance_path = config["inheritancePath"]
self.target_name = config["targetContract"]
self.corpus_path = config["corpusDir"]
self.test_dir = config["testsDir"]
self.all_sequences = config["allSequences"]
self.slither = slither
self.target = self.get_target_contract()
self.fuzzer = fuzzer
for key, value in config.items():
if key in self.config:
self.config[key] = value

def get_target_contract(self) -> Contract:
"""Gets the Slither Contract object for the specified contract file"""
contracts = self.slither.get_contract_from_name(self.target_name)
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
for contract in contracts:
if contract.name == self.target_name:
return contract

# TODO throw error if no contract found
sys.exit(-1)
self.target = get_target_contract(self.slither, self.config["targetContract"])
self.target_file_name = self.target.source_mapping.filename.relative.split("/")[-1]
self.fuzzer = fuzzer

def create_poc(self) -> str:
"""Takes in a directory path to the echidna reproducers and generates a test file"""

file_list: list[dict[str, Any]] = []
tests_list = []
dir_list = []
if self.all_sequences:
if self.config["allSequences"]:
dir_list = self.fuzzer.corpus_dirs
else:
dir_list = [self.fuzzer.reproducer_dir]
Expand Down Expand Up @@ -79,13 +71,12 @@ def create_poc(self) -> str:

# 4. Generate the test file
template = jinja2.Template(templates["CONTRACT"])
write_path = f"{self.test_dir}{self.target_name}"
inheritance_path = f"{self.inheritance_path}{self.target_name}"

write_path = os.path.join(self.config["testsDir"], self.config["targetContract"])
inheritance_path = os.path.join(self.config["inheritancePath"])
# 5. Save the test file
test_file_str = template.render(
file_path=f"{inheritance_path}.sol",
target_name=self.target_name,
file_path=inheritance_path,
target_name=self.config["targetContract"],
amount=0,
tests=tests_list,
fuzzer=self.fuzzer.name,
Expand Down
16 changes: 3 additions & 13 deletions fuzz_utils/generate/fuzzers/Echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jinja2

from slither import Slither
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.user_defined_type import UserDefinedType
Expand All @@ -16,9 +15,10 @@
from fuzz_utils.templates.foundry_templates import templates
from fuzz_utils.utils.encoding import parse_echidna_byte_string
from fuzz_utils.utils.error_handler import handle_exit
from fuzz_utils.utils.slither_utils import get_target_contract


# pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods,too-many-instance-attributes
class Echidna:
"""
Handles the generation of Foundry test files from Echidna reproducers
Expand All @@ -30,22 +30,12 @@ def __init__(
self.name = "Echidna"
self.target_name = target_name
self.slither = slither
self.target = self.get_target_contract()
self.target = get_target_contract(slither, target_name)
self.reproducer_dir = f"{corpus_path}/reproducers"
self.corpus_dirs = [f"{corpus_path}/coverage", self.reproducer_dir]
self.named_inputs = named_inputs
self.declared_variables: set[tuple[str, str]] = set()

def get_target_contract(self) -> Contract:
"""Finds and returns Slither Contract"""
contracts = self.slither.get_contract_from_name(self.target_name)
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
for contract in contracts:
if contract.name == self.target_name:
return contract

handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
Expand Down
18 changes: 4 additions & 14 deletions fuzz_utils/generate/fuzzers/Medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from eth_abi import abi
from eth_utils import to_checksum_address
from slither import Slither
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.user_defined_type import UserDefinedType
Expand All @@ -16,9 +15,10 @@
from fuzz_utils.templates.foundry_templates import templates
from fuzz_utils.utils.encoding import byte_to_escape_sequence
from fuzz_utils.utils.error_handler import handle_exit
from fuzz_utils.utils.slither_utils import get_target_contract


class Medusa: # pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods,too-many-instance-attributes
class Medusa:
"""
Handles the generation of Foundry test files from Medusa reproducers
"""
Expand All @@ -30,7 +30,7 @@ def __init__(
self.target_name = target_name
self.corpus_path = corpus_path
self.slither = slither
self.target = self.get_target_contract()
self.target = get_target_contract(slither, target_name)
self.reproducer_dir = f"{corpus_path}/test_results"
self.corpus_dirs = [
f"{corpus_path}/call_sequences/immutable",
Expand All @@ -40,16 +40,6 @@ def __init__(
self.named_inputs = named_inputs
self.declared_variables: set[tuple[str, str]] = set()

def get_target_contract(self) -> Contract:
"""Finds and returns Slither Contract"""
contracts = self.slither.get_contract_from_name(self.target_name)
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
for contract in contracts:
if contract.name == self.target_name:
return contract

handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")

def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
Expand Down
44 changes: 39 additions & 5 deletions fuzz_utils/parsing/commands/generate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Defines the flags and logic associated with the `generate` command"""
import json
from pathlib import Path
from argparse import Namespace, ArgumentParser
from slither import Slither
from fuzz_utils.utils.crytic_print import CryticPrint
from fuzz_utils.generate.FoundryTest import FoundryTest
from fuzz_utils.generate.fuzzers.Medusa import Medusa
from fuzz_utils.generate.fuzzers.Echidna import Echidna
from fuzz_utils.utils.error_handler import handle_exit
from fuzz_utils.parsing.parser_util import check_config_and_set_default_values, open_config
from fuzz_utils.utils.slither_utils import get_target_contract

COMMAND: str = "generate"


def generate_flags(parser: ArgumentParser) -> None:
Expand Down Expand Up @@ -57,15 +61,13 @@ def generate_flags(parser: ArgumentParser) -> None:
)


# pylint: disable=too-many-branches
def generate_command(args: Namespace) -> None:
"""The execution logic of the `generate` command"""
config: dict = {}
# If the config file is defined, read it
if args.config:
with open(args.config, "r", encoding="utf-8") as readFile:
complete_config = json.load(readFile)
if "generate" in complete_config:
config = complete_config["generate"]
config = open_config(args.config, COMMAND)
# Override the config with the CLI values
if args.compilation_path:
config["compilationPath"] = args.compilation_path
Expand All @@ -90,10 +92,18 @@ def generate_command(args: Namespace) -> None:
if "allSequences" not in config:
config["allSequences"] = False

check_config_and_set_default_values(
config,
["compilationPath", "testsDir", "fuzzer", "corpusDir"],
[".", "test", "medusa", "corpus"],
)

CryticPrint().print_information("Running Slither...")
slither = Slither(args.compilation_path)
fuzzer: Echidna | Medusa

derive_config(slither, config)

match config["fuzzer"]:
case "echidna":
fuzzer = Echidna(
Expand All @@ -114,3 +124,27 @@ def generate_command(args: Namespace) -> None:
foundry_test = FoundryTest(config, slither, fuzzer)
foundry_test.create_poc()
CryticPrint().print_success("Done!")


def derive_config(slither: Slither, config: dict) -> None:
"""Derive values for the target contract and inheritance path"""
# Derive target if it is not defined but the compilationPath only contains one contract
if "targetContract" not in config or len(config["targetContract"]) == 0:
if len(slither.contracts_derived) == 1:
config["targetContract"] = slither.contracts_derived[0].name
CryticPrint().print_information(
f"Target contract not specified. Using derived target: {config['targetContract']}."
)
else:
handle_exit(
"Target contract cannot be determined. Please specify the target with `-c targetName`"
)

# Derive inheritance path if it is not defined
if "inheritancePath" not in config or len(config["inheritancePath"]) == 0:
contract = get_target_contract(slither, config["targetContract"])
contract_path = Path(contract.source_mapping.filename.relative)
tests_path = Path(config["testsDir"])
config["inheritancePath"] = str(
Path(*([".." * len(tests_path.parts)])).joinpath(contract_path)
)
20 changes: 8 additions & 12 deletions fuzz_utils/parsing/commands/template.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Defines the flags and logic associated with the `template` command"""
import os
import json
from argparse import Namespace, ArgumentParser
from slither import Slither
from fuzz_utils.template.HarnessGenerator import HarnessGenerator
from fuzz_utils.utils.crytic_print import CryticPrint
from fuzz_utils.utils.remappings import find_remappings
from fuzz_utils.utils.error_handler import handle_exit
from fuzz_utils.parsing.parser_util import (
check_configuration_field_exists_and_non_empty,
open_config,
)

COMMAND: str = "template"


def template_flags(parser: ArgumentParser) -> None:
Expand Down Expand Up @@ -42,10 +47,7 @@ def template_command(args: Namespace) -> None:
else:
output_dir = os.path.join("./test", "fuzzing")
if args.config:
with open(args.config, "r", encoding="utf-8") as readFile:
complete_config = json.load(readFile)
if "template" in complete_config:
config = complete_config["template"]
config = open_config(args.config, COMMAND)

if args.target_contracts:
config["targets"] = args.target_contracts
Expand All @@ -72,15 +74,9 @@ def check_configuration(config: dict) -> None:
"""Checks the configuration"""
mandatory_configuration_fields = ["mode", "targets", "compilationPath"]
for field in mandatory_configuration_fields:
check_configuration_field_exists_and_non_empty(config, field)
check_configuration_field_exists_and_non_empty(config, COMMAND, field)

if config["mode"].lower() not in ("simple", "prank", "actor"):
handle_exit(
f"The selected mode {config['mode']} is not a valid harness generation strategy."
)


def check_configuration_field_exists_and_non_empty(config: dict, field: str) -> None:
"""Checks that the configuration dictionary contains a non-empty field"""
if field not in config or len(config[field]) == 0:
handle_exit(f"The template configuration field {field} is not configured.")
31 changes: 31 additions & 0 deletions fuzz_utils/parsing/parser_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Utility functions used in the command parsers"""
import json
from fuzz_utils.utils.error_handler import handle_exit


def check_config_and_set_default_values(
config: dict, fields: list[str], defaults: list[str]
) -> None:
"""Checks that the configuration dictionary contains a non-empty field"""
assert len(fields) == len(defaults)
for idx, field in enumerate(fields):
if field not in config or len(config[field]) == 0:
config[field] = defaults[idx]


def check_configuration_field_exists_and_non_empty(config: dict, command: str, field: str) -> None:
"""Checks that the configuration dictionary contains a non-empty field"""
if field not in config or len(config[field]) == 0:
handle_exit(f"The {command} configuration field {field} is not configured.")


def open_config(cli_config: str, command: str) -> dict:
"""Open config file if provided return its contents"""
with open(cli_config, "r", encoding="utf-8") as readFile:
complete_config = json.load(readFile)
if command in complete_config:
return complete_config[command]

handle_exit(
f"The provided configuration file does not contain the `{command}` command configuration field."
)
Loading

0 comments on commit ba33ab9

Please sign in to comment.