Skip to content

Commit

Permalink
fix: typing updates
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Jan 28, 2025
1 parent d1966fd commit 2d7edfe
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion silverback/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SilverbackException(ApeException):

# TODO: `ExceptionGroup` added in Python 3.11
class StartupFailure(SilverbackException):
def __init__(self, *exceptions: Exception | str):
def __init__(self, *exceptions: BaseException | str | None):
if len(exceptions) == 1 and isinstance(exceptions[0], str):
super().__init__(exceptions[0])
elif error_str := "\n".join(str(e) for e in exceptions):
Expand Down
23 changes: 9 additions & 14 deletions silverback/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@
from .recorder import BaseRecorder, TaskResult
from .state import Datastore, StateSnapshot
from .types import TaskType
from .utils import (
async_wrap_iter,
hexbytes_dict,
run_taskiq_task_group_wait_results,
run_taskiq_task_wait_result,
)
from .utils import async_wrap_iter, run_taskiq_task_group_wait_results, run_taskiq_task_wait_result


class BaseRunner(ABC):
Expand Down Expand Up @@ -156,12 +151,12 @@ async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]):
"Silverback no longer supports runner-based snapshotting, "
"please upgrade your bot SDK version to latest to use snapshots."
)
startup_state = StateSnapshot(
startup_state: StateSnapshot | None = StateSnapshot(
last_block_seen=-1,
last_block_processed=-1,
) # Use empty snapshot

elif not (startup_state := await self.datastore.init(bot_id=self.bot.identifier)):
elif not (startup_state := await self.datastore.init(self.bot.identifier)):
logger.warning("No state snapshot detected, using empty snapshot")
startup_state = StateSnapshot(
# TODO: Migrate these to parameters (remove explicitly from state)
Expand All @@ -186,7 +181,7 @@ async def run(self, *runtime_tasks: asyncio.Task | Callable[[], asyncio.Task]):

# Initialize recorder (if available)
if self.recorder:
await self.recorder.init(bot_id=self.bot.identifier)
await self.recorder.init(self.bot.identifier)

# Execute Silverback startup task before we init the rest
startup_taskdata_result = await run_taskiq_task_wait_result(
Expand Down Expand Up @@ -318,11 +313,11 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs):

self.ws_uri = ws_uri

async def _block_task(self, task_data: TaskData) -> asyncio.Task | None:
async def _block_task(self, task_data: TaskData) -> None:
new_block_task_kicker = self._create_task_kicker(task_data)

async def block_handler(ctx: NewHeadsSubscriptionContext):
block = self.provider.network.ecosystem.decode_block(hexbytes_dict(ctx.result))
block = self.provider.network.ecosystem.decode_block(dict(ctx.result))
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)
Expand All @@ -332,7 +327,7 @@ async def block_handler(ctx: NewHeadsSubscriptionContext):
)
logger.debug(f"Handling blocks via {sub_id}")

async def _event_task(self, task_data: TaskData) -> asyncio.Task | None:
async def _event_task(self, task_data: TaskData) -> None:
if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

Expand Down Expand Up @@ -391,7 +386,7 @@ def __init__(self, bot: SilverbackBot, *args, **kwargs):
"Do not use in production over long time periods unless you know what you're doing."
)

async def _block_task(self, task_data: TaskData) -> asyncio.Task | None:
async def _block_task(self, task_data: TaskData) -> asyncio.Task:
new_block_task_kicker = self._create_task_kicker(task_data)

if block_settings := self.bot.poll_settings.get("_blocks_"):
Expand All @@ -416,7 +411,7 @@ async def block_handler():

return asyncio.create_task(block_handler())

async def _event_task(self, task_data: TaskData) -> asyncio.Task | None:
async def _event_task(self, task_data: TaskData) -> asyncio.Task:
if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

Expand Down

0 comments on commit 2d7edfe

Please sign in to comment.