Skip to content

Commit

Permalink
Add type hints/annotations and refactoring
Browse files Browse the repository at this point in the history
- add type hints/annotations and refactoring to fix mypy and other errors
- remove metaclass from EVM.All profile, use simple tasks attribute assignment instead
  • Loading branch information
erwin-wee committed Dec 23, 2023
1 parent 5078432 commit 8234a12
Show file tree
Hide file tree
Showing 18 changed files with 211 additions and 181 deletions.
42 changes: 22 additions & 20 deletions chainbench/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import click
from click import Context, Parameter
from locust import runners

from chainbench.user.evm import EVMMethods
Expand Down Expand Up @@ -44,11 +45,11 @@
)
@click.version_option(message="%(prog)s-%(version)s")
@click.pass_context
def cli(ctx: click.Context):
def cli(ctx: Context):
ctx.obj = ContextData()


def validate_method(ctx, param, value) -> str:
def validate_method(ctx: Context, param: Parameter, value: str) -> str:
if value is not None:
method_list = [task_to_method(task) for task in get_subclass_methods(EVMMethods)]
if value not in method_list:
Expand All @@ -60,7 +61,7 @@ def validate_method(ctx, param, value) -> str:
return value


def profile_exists(profile, profile_dir):
def profile_exists(profile: str, profile_dir: Path) -> None:
profile_list = get_profiles(profile_dir)
if profile not in profile_list:
raise click.BadParameter(
Expand All @@ -69,21 +70,17 @@ def profile_exists(profile, profile_dir):
)


def validate_profile_dir(ctx, param, value) -> Path | None:
def validate_profile_dir(ctx: Context, param: Parameter, value: Path) -> Path | None:
if value is not None:
profile_dir = Path(value)
if not profile_dir.exists():
raise click.BadParameter(f"Profile directory {value} does not exist.")
if not profile_dir.is_dir():
raise click.BadParameter(f"Profile directory {value} is not a directory.")
if "profile" in ctx.params:
profile_exists(ctx.params["profile"], profile_dir)
return profile_dir
else:
return None


def validate_profile(ctx, param, value) -> str:
def validate_profile(ctx: Context, param: Parameter, value: str) -> str:
if value is not None:
if "profile_dir" in ctx.params:
profile_exists(value, ctx.params["profile_dir"])
Expand All @@ -100,7 +97,12 @@ def validate_profile(ctx, param, value) -> str:
)
@click.argument("method", default=None, callback=validate_method, required=False)
@click.option(
"-d", "--profile-dir", default=None, callback=validate_profile_dir, type=click.Path(), help="Profile directory"
"-d",
"--profile-dir",
default=None,
callback=validate_profile_dir,
type=click.Path(exists=True, dir_okay=True, file_okay=False, path_type=Path),
help="Profile directory",
)
@click.option(
"-p",
Expand Down Expand Up @@ -133,7 +135,7 @@ def validate_profile(ctx, param, value) -> str:
"--results-dir",
default=Path("results"),
help="Results directory",
type=click.Path(),
type=click.Path(dir_okay=True, file_okay=False, writable=True, path_type=Path),
show_default=True,
)
@click.option("--headless", is_flag=True, help="Run in headless mode")
Expand Down Expand Up @@ -176,7 +178,7 @@ def validate_profile(ctx, param, value) -> str:
@click.option("--size", default=None, help="Set the size of the test data. e.g. --size S")
@click.pass_context
def start(
ctx: click.Context,
ctx: Context,
profile: str,
profile_dir: Path | None,
host: str,
Expand All @@ -203,7 +205,7 @@ def start(
use_recent_blocks: bool,
size: str | None,
method: str | None = None,
):
) -> None:
if notify:
click.echo(f"Notify when test is finished using topic: {notify}")
notifier = Notifier(topic=notify)
Expand Down Expand Up @@ -383,7 +385,7 @@ def start(
ctx.obj.notifier.notify(title="Test finished", message=f"Test finished for {profile}", tags=["tada"])


def validate_clients(ctx, param, value) -> list[str]:
def validate_clients(ctx: Context, param: Parameter, value: str) -> list[str]:
from chainbench.tools.discovery.rpc import RPCDiscovery

if value is not None:
Expand Down Expand Up @@ -414,7 +416,7 @@ def validate_clients(ctx, param, value) -> list[str]:
help="List of methods used to test the endpoint will "
"be based on the clients specified here, default to eth. e.g. --clients eth,bsc.\n",
)
def discover(endpoint: str | None, clients: list[str]):
def discover(endpoint: str | None, clients: list[str]) -> None:
if not endpoint:
click.echo("Target endpoint is required.")
sys.exit(1)
Expand All @@ -435,14 +437,14 @@ def discover(endpoint: str | None, clients: list[str]):


@cli.group(name="list", help="Lists values of the given type.")
def _list():
def _list() -> None:
pass


@_list.command(
help="Lists all available client options for method discovery.",
)
def clients():
def clients() -> None:
from chainbench.tools.discovery.rpc import RPCDiscovery

for client in RPCDiscovery.get_clients():
Expand All @@ -458,18 +460,18 @@ def clients():
"--profile-dir",
default=get_base_path(__file__) / "profile",
callback=validate_profile_dir,
type=click.Path(),
type=click.Path(exists=True, dir_okay=True, file_okay=False, path_type=Path),
help="Profile directory",
)
def profiles(profile_dir: Path):
def profiles(profile_dir: Path) -> None:
for profile in get_profiles(profile_dir):
click.echo(profile)


@_list.command(
help="Lists all available evm methods.",
)
def methods():
def methods() -> None:
task_list = get_subclass_methods(EVMMethods)
for task in task_list:
click.echo(task_to_method(task))
Expand Down
11 changes: 2 additions & 9 deletions chainbench/profile/evm/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,11 @@
# """

from locust import constant_pacing
from locust.user.users import UserMeta

from chainbench.user.evm import EVMMethods
from chainbench.util.cli import get_subclass_methods


class EVMMethodsMeta(UserMeta):
def __new__(cls, name, bases, attrs):
new_cls = super().__new__(cls, name, bases, attrs)
new_cls.tasks = [EVMMethods.get_method(method) for method in get_subclass_methods(EVMMethods)]
return new_cls


class EVMAllProfile(EVMMethods, metaclass=EVMMethodsMeta):
class EVMAllProfile(EVMMethods):
wait_time = constant_pacing(1)
tasks = [EVMMethods.get_method(method) for method in get_subclass_methods(EVMMethods)]
45 changes: 30 additions & 15 deletions chainbench/test_data/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import logging
import typing as t
from argparse import Namespace
from dataclasses import dataclass, field
from secrets import token_hex

import httpx
from configargparse import Namespace
from gevent.lock import Semaphore as GeventSemaphore
from tenacity import retry, stop_after_attempt

Expand Down Expand Up @@ -59,10 +59,10 @@ class BlockchainData:
tx_hashes: TxHashes = field(default_factory=list)
accounts: Accounts = field(default_factory=list)

def to_json(self):
def to_json(self) -> str:
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

def from_json(self, json_data):
def from_json(self, json_data: str) -> None:
data = json.loads(json_data)
self.start_block_number = data["start_block_number"]
self.end_block_number = data["end_block_number"]
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, rpc_version: str = "2.0"):

self._data: BlockchainData | None = None

def update(self, host_url: str, parsed_options: Namespace):
def update(self, host_url: str, parsed_options: Namespace) -> None:
self._logger.info("Updating data")
self._host = host_url
self._logger.debug("Host: %s", self._host)
Expand All @@ -104,14 +104,24 @@ def update(self, host_url: str, parsed_options: Namespace):
self._lock.release()
self._logger.info("Lock released")

def _process_block(self, block_number, block, txs, tx_hashes, accounts, blocks, size, return_txs):
def _process_block(
self,
block_number: BlockNumber,
block: Block,
txs: list[Tx],
tx_hashes: set[TxHash],
accounts: set[Account],
blocks: set[tuple[BlockNumber, BlockHash]],
size: BlockchainDataSize,
return_txs: bool = True,
) -> None:
raise NotImplementedError

def _get_start_and_end_blocks(self, parsed_options):
def _get_start_and_end_blocks(self, parsed_options: Namespace) -> tuple[BlockNumber, BlockNumber]:
raise NotImplementedError

# get initial data from blockchain
def _get_init_data_from_blockchain(self, parsed_options) -> BlockchainData:
def _get_init_data_from_blockchain(self, parsed_options: Namespace) -> BlockchainData:
def print_progress():
print(
f"txs = {len(txs)}/{size.txs} "
Expand Down Expand Up @@ -155,7 +165,7 @@ def print_progress():
accounts=sorted(list(accounts)),
)

def init_data_from_json(self, json_data: str):
def init_data_from_json(self, json_data: str) -> None:
self._data = BlockchainData()
self._data.from_json(json_data)
self._logger.info("Data updated. Releasing lock")
Expand Down Expand Up @@ -185,14 +195,14 @@ def _parse_hex_to_int(value: str) -> int:
return int(value, 16)

@staticmethod
def _append_if_not_none(data, val):
def _append_if_not_none(data: list | set, val: t.Any) -> None:
if val is not None:
if isinstance(data, list):
data.append(val)
elif isinstance(data, set):
data.add(val)

def _make_body(self, method: str, params: list[t.Any] | None = None):
def _make_body(self, method: str, params: list[t.Any] | None = None) -> dict[str, t.Any]:
if params is None:
params = []

Expand All @@ -203,7 +213,7 @@ def _make_body(self, method: str, params: list[t.Any] | None = None):
"id": token_hex(8),
}

def _make_call(self, method: str, params: list[t.Any] | None = None):
def _make_call(self, method: str, params: list[t.Any] | None = None) -> t.Any:
if params is None:
params = []

Expand Down Expand Up @@ -236,17 +246,22 @@ def _make_call(self, method: str, params: list[t.Any] | None = None):

return data["result"]

def close(self):
def close(self) -> None:
self._client.close()

def wait(self):
def wait(self) -> None:
self._lock.wait()

def _fetch_block(self, block_number, return_txs: bool = True) -> tuple[BlockNumber, Block]:
def _fetch_block(self, block_number: BlockNumber, return_txs: bool = True) -> tuple[BlockNumber, Block]:
raise NotImplementedError

@retry(reraise=True, stop=stop_after_attempt(5))
def _fetch_random_block(self, start: int, end: int, return_txs: bool = True) -> tuple[BlockNumber, Block]:
def _fetch_random_block(
self,
start: BlockNumber,
end: BlockNumber,
return_txs: bool = True,
) -> tuple[BlockNumber, Block]:
rng = get_rng()
block_number = rng.random.randint(start, end)
return self._fetch_block(block_number, return_txs=return_txs)
Expand Down
4 changes: 3 additions & 1 deletion chainbench/test_data/dummy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from configargparse import Namespace

from chainbench.test_data.base import BaseTestData, BlockchainData


class DummyTestData(BaseTestData):
def _get_init_data(self):
def _get_init_data_from_blockchain(self, parsed_options: Namespace) -> BlockchainData:
return BlockchainData()
8 changes: 5 additions & 3 deletions chainbench/test_data/evm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import Mapping

from configargparse import Namespace

from chainbench.test_data.base import (
Account,
BaseTestData,
Expand Down Expand Up @@ -51,7 +53,7 @@ class EVMTestData(BaseTestData):
def _fetch_chain_id(self) -> int:
return self._parse_hex_to_int(self._make_call("eth_chainId"))

def _fetch_latest_block_number(self) -> int:
def _fetch_latest_block_number(self) -> BlockNumber:
result = self._make_call("eth_blockNumber")
return self._parse_hex_to_int(result)

Expand All @@ -67,7 +69,7 @@ def _fetch_block(self, block_number: int | str, return_txs: bool = True) -> tupl
result = self._make_call("eth_getBlockByNumber", [block_number, return_txs])
return self._parse_hex_to_int(result["number"]), result

def _get_start_and_end_blocks(self, parsed_options) -> tuple[BlockNumber, BlockNumber]:
def _get_start_and_end_blocks(self, parsed_options: Namespace) -> tuple[BlockNumber, BlockNumber]:
chain_id: int = self._fetch_chain_id()
end_block_number = self._fetch_latest_block_number()
if not parsed_options.use_recent_blocks and chain_id in self.CHAIN_INFO:
Expand All @@ -88,7 +90,7 @@ def _process_block(
blocks: set[tuple[BlockNumber, BlockHash]],
size: BlockchainDataSize,
return_txs: bool = True,
):
) -> None:
if size.blocks > len(blocks):
self._append_if_not_none(blocks, (block_number, block["hash"]))
if return_txs:
Expand Down
21 changes: 11 additions & 10 deletions chainbench/test_data/solana.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from configargparse import Namespace
from tenacity import retry, stop_after_attempt

from chainbench.test_data.base import (
Expand All @@ -19,7 +20,7 @@
class SolanaTestData(BaseTestData):
BLOCK_TIME = 0.4

def _fetch_block(self, block_number: int, return_txs: bool = True) -> tuple[BlockNumber, Block]:
def _fetch_block(self, block_number: BlockNumber, return_txs: bool = True) -> tuple[BlockNumber, Block]:
if return_txs:
transaction_details = "accounts"
else:
Expand All @@ -37,21 +38,21 @@ def _fetch_block(self, block_number: int, return_txs: bool = True) -> tuple[Bloc
raise e
return block_number, result

def _fetch_latest_slot_number(self):
def _fetch_latest_slot_number(self) -> BlockNumber:
slot = self._make_call("getLatestBlockhash")["context"]["slot"]
return slot

@retry(reraise=True, stop=stop_after_attempt(5))
def _fetch_latest_block(self):
def _fetch_latest_block(self) -> tuple[BlockNumber, Block]:
slot_number = self._fetch_latest_slot_number()
latest_block = self._fetch_block(slot_number, return_txs=True)
return slot_number, latest_block
latest_block_number, latest_block = self._fetch_block(slot_number, return_txs=True)
return latest_block_number, latest_block

def _fetch_first_available_block(self):
block = self._make_call("getFirstAvailableBlock")
return block
def _fetch_first_available_block(self) -> BlockNumber:
block_number = self._make_call("getFirstAvailableBlock")
return block_number

def _get_start_and_end_blocks(self, parsed_options) -> tuple[int, int]:
def _get_start_and_end_blocks(self, parsed_options: Namespace) -> tuple[BlockNumber, BlockNumber]:
end_block_number, _latest_block = self._fetch_latest_block()
start_block_number = self._fetch_first_available_block()

Expand All @@ -71,7 +72,7 @@ def _process_block(
blocks: set[tuple[BlockNumber, BlockHash]],
size: BlockchainDataSize,
return_txs: bool = True,
):
) -> None:
if size.blocks > len(blocks):
self._append_if_not_none(blocks, (block_number, block["blockhash"]))
if return_txs:
Expand Down
Loading

0 comments on commit 8234a12

Please sign in to comment.