Skip to content

Commit

Permalink
refactor: use web3py v7 persistent provider for subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Jan 28, 2025
1 parent 5a4ad92 commit d1966fd
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 251 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@
url="https://github.com/ApeWorX/silverback",
include_package_data=True,
install_requires=[
"apepay>=0.3.2,<1",
"apepay>=0.3.3,<1",
"click", # Use same version as eth-ape
"eth-ape>=0.8.19,<1.0",
"eth-ape>=0.8.24,<1",
"ethpm-types>=0.6.10", # lower pin only, `eth-ape` governs upper pin
"eth-pydantic-types", # Use same version as eth-ape
"packaging", # Use same version as eth-ape
"pydantic_settings", # Use same version as eth-ape
"taskiq[metrics]>=0.11.9,<0.12",
"tomlkit>=0.12,<1", # For reading/writing global platform profile
"fief-client[cli]>=0.19,<1", # for platform auth/cluster login
"websockets>=14.1,<15", # For subscriptions
"web3>=7.7,<8", # TODO: Remove when Ape v0.9 is released (Ape v0.8 allows web3 v6)
],
entry_points={
"console_scripts": ["silverback=silverback._cli:cli"],
Expand Down
147 changes: 88 additions & 59 deletions silverback/runner.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Callable

from ape import chain
from ape.logging import logger
from ape.utils import ManagerAccessMixin
from ape_ethereum.ecosystem import keccak
from eth_utils import to_hex
from ethpm_types import EventABI
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from taskiq import AsyncTaskiqTask
from taskiq.kicker import AsyncKicker
from web3 import AsyncWeb3, WebSocketProvider
from web3.utils.subscriptions import (
LogsSubscription,
LogsSubscriptionContext,
NewHeadsSubscription,
NewHeadsSubscriptionContext,
)

from .exceptions import Halt, NoTasksAvailableError, NoWebsocketAvailableError, StartupFailure
from .main import SilverbackBot, SystemConfig, TaskData
from .recorder import BaseRecorder, TaskResult
from .state import Datastore, StateSnapshot
from .subscriptions import SubscriptionType, Web3SubscriptionsManager
from .types import TaskType
from .utils import (
async_wrap_iter,
Expand Down Expand Up @@ -88,18 +96,18 @@ async def _checkpoint(
await self.datastore.save(result.return_value)

@abstractmethod
async def _block_task(self, task_data: TaskData):
async def _block_task(self, task_data: TaskData) -> asyncio.Task | None:
"""
Handle a block_handler task
"""

@abstractmethod
async def _event_task(self, task_data: TaskData):
async def _event_task(self, task_data: TaskData) -> asyncio.Task | None:
"""
Handle an event handler task for the given contract event
"""

async def run(self):
async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]):
"""
Run the task broker client for the assembled ``SilverbackBot`` bot.
Expand Down Expand Up @@ -210,6 +218,7 @@ async def run(self):
# NOTE: No need to handle results otherwise

# Create our long-running event listeners
listener_tasks = []
new_block_taskdata_results = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK
)
Expand All @@ -230,23 +239,22 @@ async def run(self):
raise NoTasksAvailableError()

# NOTE: Any propagated failure in here should be handled such that shutdown tasks also run
# TODO: `asyncio.TaskGroup` added in Python 3.11
listener_tasks = (
*(
asyncio.create_task(self._block_task(task_def))
for task_def in new_block_taskdata_results.return_value
),
*(
asyncio.create_task(self._event_task(task_def))
for task_def in event_log_taskdata_results.return_value
),
)
for task_def in new_block_taskdata_results.return_value:
if (task := await self._block_task(task_def)) is not None:
listener_tasks.append(task)

for task_def in event_log_taskdata_results.return_value:
if (task := await self._event_task(task_def)) is not None:
listener_tasks.append(task)

listener_tasks.extend(t if isinstance(t, asyncio.Task) else t() for t in runtime_tasks)

# NOTE: Safe to do this because no tasks were actually scheduled to run
if len(listener_tasks) == 0:
raise NoTasksAvailableError()

# Run until one task bubbles up an exception that should stop execution
# TODO: `asyncio.TaskGroup` added in Python 3.11
tasks_with_errors, tasks_running = await asyncio.wait(
listener_tasks, return_when=asyncio.FIRST_EXCEPTION
)
Expand Down Expand Up @@ -310,19 +318,21 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs):

self.ws_uri = ws_uri

async def _block_task(self, task_data: TaskData):
async def _block_task(self, task_data: TaskData) -> asyncio.Task | None:
new_block_task_kicker = self._create_task_kicker(task_data)
sub_id = await self.subscriptions.subscribe(SubscriptionType.BLOCKS)
logger.debug(f"Handling blocks via {sub_id}")

async for raw_block in self.subscriptions.get_subscription_data(sub_id):
block = self.provider.network.ecosystem.decode_block(hexbytes_dict(raw_block))

async def block_handler(ctx: NewHeadsSubscriptionContext):
block = self.provider.network.ecosystem.decode_block(hexbytes_dict(ctx.result))
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(raw_block))
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(self, task_data: TaskData):
sub_id = await self._web3.subscription_manager.subscribe(
NewHeadsSubscription(label=task_data.name, handler=block_handler)
)
logger.debug(f"Handling blocks via {sub_id}")

async def _event_task(self, task_data: TaskData) -> asyncio.Task | None:
if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

Expand All @@ -333,26 +343,37 @@ async def _event_task(self, task_data: TaskData):

event_log_task_kicker = self._create_task_kicker(task_data)

sub_id = await self.subscriptions.subscribe(
SubscriptionType.EVENTS,
address=contract_address,
topics=["0x" + keccak(text=event_abi.selector).hex()],
)
logger.debug(f"Handling '{contract_address}:{event_abi.name}' logs via {sub_id}")

async for raw_event in self.subscriptions.get_subscription_data(sub_id):
async def log_handler(ctx: LogsSubscriptionContext):
event = next( # NOTE: `next` is okay since it only has one item
self.provider.network.ecosystem.decode_logs([raw_event], event_abi)
self.provider.network.ecosystem.decode_logs([ctx.result], event_abi)
)

# TODO: Fix upstream w/ web3py
event.transaction_hash = "0x" + event.transaction_hash.hex()
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_log_task_kicker.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)

async def run(self):
async with Web3SubscriptionsManager(self.ws_uri) as subscriptions:
self.subscriptions = subscriptions
await super().run()
sub_id = await self._web3.subscription_manager.subscribe(
LogsSubscription(
label=task_data.name,
address=contract_address,
topics=[to_hex(keccak(text=event_abi.selector))],
handler=log_handler,
)
)
logger.debug(f"Handling '{contract_address}:{event_abi.name}' logs via {sub_id}")

async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]):
async with AsyncWeb3(WebSocketProvider(self.ws_uri)) as web3:
self._web3 = web3

def run_subscriptions() -> asyncio.Task:
return asyncio.create_task(
web3.subscription_manager.handle_subscriptions(run_forever=True)
)

await super().run(*runtime_tasks, run_subscriptions)
await web3.subscription_manager.unsubscribe_all()


class PollingRunner(BaseRunner, ManagerAccessMixin):
Expand All @@ -370,7 +391,7 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs):
"Do not use in production over long time periods unless you know what you're doing."
)

async def _block_task(self, task_data: TaskData):
async def _block_task(self, task_data: TaskData) -> asyncio.Task | None:
new_block_task_kicker = self._create_task_kicker(task_data)

if block_settings := self.bot.poll_settings.get("_blocks_"):
Expand All @@ -381,17 +402,21 @@ async def _block_task(self, task_data: TaskData):
new_block_timeout = (
new_block_timeout if new_block_timeout is not None else self.bot.new_block_timeout
)
async for block in async_wrap_iter(
chain.blocks.poll_blocks(
# NOTE: No start block because we should begin polling from head
new_block_timeout=new_block_timeout,
)
):
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(self, task_data: TaskData):
async def block_handler():
async for block in async_wrap_iter(
chain.blocks.poll_blocks(
# NOTE: No start block because we should begin polling from head
new_block_timeout=new_block_timeout,
)
):
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)

return asyncio.create_task(block_handler())

async def _event_task(self, task_data: TaskData) -> asyncio.Task | None:
if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

Expand All @@ -409,14 +434,18 @@ async def _event_task(self, task_data: TaskData):
new_block_timeout = (
new_block_timeout if new_block_timeout is not None else self.bot.new_block_timeout
)
async for event in async_wrap_iter(
self.provider.poll_logs(
# NOTE: No start block because we should begin polling from head
address=contract_address,
new_block_timeout=new_block_timeout,
events=[event_abi],
)
):
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_log_task_kicker.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)

async def log_handler():
async for event in async_wrap_iter(
self.provider.poll_logs(
# NOTE: No start block because we should begin polling from head
address=contract_address,
new_block_timeout=new_block_timeout,
events=[event_abi],
)
):
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_log_task_kicker.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)

return asyncio.create_task(log_handler())
Loading

0 comments on commit d1966fd

Please sign in to comment.