Skip to content

Commit

Permalink
chore: refactor asyncio imports (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Dec 17, 2024
1 parent a6e4c45 commit ff34d35
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 51 deletions.
20 changes: 10 additions & 10 deletions eth_portfolio/_db/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import logging
from asyncio import gather, get_event_loop
from contextlib import suppress
from functools import lru_cache
from typing import Any, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -141,7 +140,7 @@ def is_token(address) -> bool:

async def _is_token(address) -> bool:
# just breaking a weird lock, dont mind me
if retval := await asyncio.get_event_loop().run_in_executor(process, __is_token, address):
if retval := await get_event_loop().run_in_executor(process, __is_token, address):
logger.debug("%s is token")
else:
logger.debug("%s is not token")
Expand All @@ -153,7 +152,7 @@ def __is_token(address) -> bool:
erc = ERC20(address, asynchronous=True)
if all(
get_event_loop().run_until_complete(
asyncio.gather(erc._symbol(), erc._name(), erc.total_supply_readable())
gather(erc._symbol(), erc._name(), erc.total_supply_readable())
)
):
return True
Expand Down Expand Up @@ -287,10 +286,11 @@ def delete_transaction(transaction: Transaction) -> None:

async def insert_transaction(transaction: Transaction) -> None:
# Make sure these are in the db so below we can call them and use the results all in one transaction
coros = [ensure_block(transaction.block_number), ensure_address(transaction.from_address)]
if transaction.to_address:
coros.append(ensure_address(transaction.to_address))
await asyncio.gather(*coros)
coros = [ensure_block(transaction.block_number), ensure_address(transaction.from_address)] # type: ignore [arg-type]
address = transaction.to_address
if address is not None:
coros.append(ensure_address(address))
await gather(*coros)
await _insert_transaction(transaction)


Expand Down Expand Up @@ -369,7 +369,7 @@ async def insert_internal_transfer(transfer: InternalTransfer) -> None:
coros = [ensure_block(transfer.block_number), ensure_address(transfer.from_address)]
if to_address := getattr(transfer, "to_address", None):
coros.append(ensure_address(to_address))
await asyncio.gather(*coros)
await gather(*coros)
await _insert_internal_transfer(transfer)


Expand Down Expand Up @@ -465,7 +465,7 @@ async def insert_token_transfer(token_transfer: TokenTransfer) -> None:
ensure_address(token_transfer.from_address),
ensure_address(token_transfer.to_address),
]
await asyncio.gather(*coros)
await gather(*coros)
await _insert_token_transfer(token_transfer)


Expand Down
4 changes: 2 additions & 2 deletions eth_portfolio/_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
from asyncio import gather
from typing import TYPE_CHECKING

from y.datatypes import Block
Expand All @@ -18,7 +18,7 @@ def __init__(self, start_block: Block, end_block: Block, ledger: "AddressLedgerB
self.end_block = end_block

async def load_remaining(self) -> None:
return await asyncio.gather(
return await gather(
self.ledger._load_new_objects(self.start_block, self.ledger.cached_thru - 1),
self.ledger._load_new_objects(self.ledger.cached_from + 1, self.end_block),
)
10 changes: 6 additions & 4 deletions eth_portfolio/_ledgers/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,14 @@ async def _load_new_objects(
for direction, (start, end) in product(["toAddress", "fromAddress"], block_ranges)
]

generator_function = a_sync.as_completed

# NOTE: We only want tqdm progress bar when there is work to do
if len(block_ranges) > 1:
if len(block_ranges) == 1:
generator_function = a_sync.as_completed
else:
generator_function = partial( # type: ignore [assignment]
generator_function, tqdm=True, desc=f"Trace Filters {self.address}"
)
a_sync.as_completed, tqdm=True, desc=f"Trace Filters {self.address}"
)

if tasks := [
create_task(
Expand Down
15 changes: 7 additions & 8 deletions eth_portfolio/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
Union,
)

import a_sync
import dank_mids
from a_sync import ASyncGenericBase, ASyncIterable, ASyncIterator, as_yielded
from async_lru import alru_cache
from brownie import chain
from brownie.exceptions import ContractNotFound
Expand Down Expand Up @@ -229,22 +229,21 @@ def _unpack_indicies(indicies: Union[Block, Tuple[Block, Block]]) -> Tuple[Block
return start_block, end_block


class _AiterMixin(a_sync.ASyncIterable[_T]):
__doc__ = a_sync.ASyncIterable.__doc__
class _AiterMixin(ASyncIterable[_T]):

def __aiter__(self) -> AsyncIterator[_T]:
return self[self._start_block : chain.height].__aiter__()

def __getitem__(self, slice: slice) -> a_sync.ASyncIterator[_T]:
def __getitem__(self, slice: slice) -> ASyncIterator[_T]:
if slice.start is not None and not isinstance(slice.start, (int, datetime)):
raise TypeError(f"start must be int or datetime. you passed {slice.start}")
if slice.stop and not isinstance(slice.stop, (int, datetime)):
raise TypeError(f"start must be int or datetime. you passed {slice.start}")
if slice.step is not None:
raise ValueError("You cannot use a step here.")
return a_sync.ASyncIterator(self._get_and_yield(slice.start or 0, slice.stop))
return ASyncIterator(self._get_and_yield(slice.start or 0, slice.stop))

def yield_forever(self) -> a_sync.ASyncIterator[_T]:
def yield_forever(self) -> ASyncIterator[_T]:
return self[self._start_block : None]

@abstractmethod
Expand All @@ -261,7 +260,7 @@ def _start_block(self) -> int: ...
_LT = TypeVar("_LT")


class _LedgeredBase(a_sync.ASyncGenericBase, _AiterMixin["LedgerEntry"], Generic[_LT]):
class _LedgeredBase(ASyncGenericBase, _AiterMixin["LedgerEntry"], Generic[_LT]):
"""A mixin class for things with ledgers"""

transactions: _LT
Expand All @@ -285,4 +284,4 @@ def _ledgers(self) -> Iterator[_LT]:
def _get_and_yield(
self, start_block: Block, end_block: Block
) -> AsyncGenerator["LedgerEntry", None]:
return a_sync.as_yielded(*(ledger[start_block:end_block] for ledger in self._ledgers)) # type: ignore [return-value,index]
return as_yielded(*(ledger[start_block:end_block] for ledger in self._ledgers)) # type: ignore [return-value,index]
6 changes: 3 additions & 3 deletions eth_portfolio/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
with external protocols.
"""

import asyncio
import logging
from asyncio import gather
from typing import TYPE_CHECKING, Dict, Optional

import a_sync
Expand Down Expand Up @@ -225,7 +225,7 @@ async def external_balances(self, block: Optional[Block] = None) -> RemoteTokenB
Examples:
>>> external_balances = await address.external_balances(12345678)
"""
balances = await asyncio.gather(
balances = await gather(
self.staking(block, sync=False), self.collateral(block, sync=False)
)
return sum(balances) # type: ignore [arg-type, return-value]
Expand All @@ -246,7 +246,7 @@ async def balances(self, block: Optional[Block]) -> TokenBalances:
Examples:
>>> balances = await address.balances(12345678)
"""
eth_balance, token_balances = await asyncio.gather(
eth_balance, token_balances = await gather(
self.eth_balance(block, sync=False),
self.token_balances(block, sync=False),
)
Expand Down
4 changes: 2 additions & 2 deletions eth_portfolio/buckets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import logging
from asyncio import gather
from typing import Optional, Set

from async_lru import alru_cache
Expand Down Expand Up @@ -50,7 +50,7 @@ async def _unwrap_token(token) -> str:
underlying = await YearnInspiredVault(token, asynchronous=True).underlying
return await _unwrap_token(underlying)
if curve and (pool := await curve.get_pool(token)):
pool_tokens = set(await asyncio.gather(*[_unwrap_token(coin) for coin in await pool.coins]))
pool_tokens = set(await gather(*[_unwrap_token(coin) for coin in await pool.coins]))
if pool_bucket := _pool_bucket(pool_tokens):
return pool_bucket # type: ignore
if aave and await aave.is_atoken(token):
Expand Down
4 changes: 2 additions & 2 deletions eth_portfolio/portfolio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
This file is part of a larger system that includes modules for handling portfolio addresses, ledger entries, and other related tasks.
"""

import asyncio
import logging
from asyncio import gather
from functools import wraps
from typing import Any, AsyncIterator, Dict, Iterable, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -509,7 +509,7 @@ async def df(self, start_block: Block, end_block: Block, full: bool = False) ->
>>> print(df)
"""
df = concat(
await asyncio.gather(
await gather(
*(ledger.df(start_block, end_block, sync=False) for ledger in self._ledgers)
)
)
Expand Down
6 changes: 3 additions & 3 deletions eth_portfolio/protocols/_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import asyncio
from asyncio import gather
from typing import List, Optional

import a_sync
Expand Down Expand Up @@ -28,7 +28,7 @@ class ProtocolWithStakingABC(ProtocolABC, metaclass=abc.ABCMeta):

@stuck_coro_debugger
async def _balances(self, address: Address, block: Optional[Block] = None) -> TokenBalances:
return sum(await asyncio.gather(*[pool.balances(address, block) for pool in self.pools])) # type: ignore
return sum(await gather(*[pool.balances(address, block) for pool in self.pools])) # type: ignore


class StakingPoolABC(ProtocolABC, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -99,7 +99,7 @@ async def _balances(self, address: Address, block: Optional[Block] = None) -> To
if self.should_check(block):
balance = Decimal(await self(address, block=block)) # type: ignore
if balance:
scale, price = await asyncio.gather(self.scale, self.price(block, sync=False))
scale, price = await gather(self.scale, self.price(block, sync=False))
balance /= scale # type: ignore
balances[self.token.address] = Balance(
balance, balance * Decimal(price), token=self.token.address, block=block
Expand Down
4 changes: 2 additions & 2 deletions eth_portfolio/protocols/dsr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
from asyncio import gather
from typing import Optional

from y import Contract, Network, dai
Expand All @@ -18,7 +18,7 @@ def __init__(self) -> None:

async def _balances(self, address: Address, block: Optional[Block] = None) -> TokenBalances:
balances = TokenBalances(block=block)
pie, exchange_rate = await asyncio.gather(
pie, exchange_rate = await gather(
self.dsr_manager.pieOf.coroutine(address, block_identifier=block),
self._exchange_rate(block),
)
Expand Down
14 changes: 7 additions & 7 deletions eth_portfolio/protocols/lending/compound.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
from asyncio import gather
from typing import List, Optional

import a_sync
Expand Down Expand Up @@ -30,7 +30,7 @@ class Compound(LendingProtocol):
@alru_cache(ttl=300)
@stuck_coro_debugger
async def underlyings(self) -> List[ERC20]:
all_markets: List[List[CToken]] = await asyncio.gather(
all_markets: List[List[CToken]] = await gather(
*[comp.markets for comp in compound.trollers.values()]
)
markets: List[Contract] = [
Expand All @@ -43,7 +43,7 @@ async def underlyings(self) -> List[ERC20]:
other_markets = [market for market in markets if hasattr(market, "underlying")]

markets = gas_token_markets + other_markets
underlyings = [weth for market in gas_token_markets] + await asyncio.gather(
underlyings = [weth for market in gas_token_markets] + await gather(
*[market.underlying.coroutine() for market in other_markets]
)

Expand All @@ -69,10 +69,10 @@ async def _debt(self, address: Address, block: Optional[Block] = None) -> TokenB
address = str(address)
markets: List[Contract]
underlyings: List[ERC20]
markets, underlyings = await asyncio.gather(*[self.markets(), self.underlyings()])
debt_data, underlying_scale = await asyncio.gather(
asyncio.gather(*[_borrow_balance_stored(market, address, block) for market in markets]),
asyncio.gather(*[underlying.__scale__ for underlying in underlyings]),
markets, underlyings = await gather(*[self.markets(), self.underlyings()])
debt_data, underlying_scale = await gather(
gather(*[_borrow_balance_stored(market, address, block) for market in markets]),
gather(*[underlying.__scale__ for underlying in underlyings]),
)

balances: TokenBalances = TokenBalances(block=block)
Expand Down
16 changes: 8 additions & 8 deletions eth_portfolio/protocols/lending/maker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
from asyncio import gather
from typing import List, Optional

from async_lru import alru_cache
Expand Down Expand Up @@ -37,14 +37,14 @@ def __init__(self) -> None:

@stuck_coro_debugger
async def _balances(self, address: Address, block: Optional[Block] = None) -> TokenBalances:
ilks, urn = await asyncio.gather(self.get_ilks(block), self._urn(address))
ilks, urn = await gather(self.get_ilks(block), self._urn(address))

gem_coros = asyncio.gather(*[self.get_gem(str(ilk)) for ilk in ilks])
ink_coros = asyncio.gather(
gem_coros = gather(*[self.get_gem(str(ilk)) for ilk in ilks])
ink_coros = gather(
*[self.vat.urns.coroutine(ilk, urn, block_identifier=block) for ilk in ilks]
)

gems, ink_data = await asyncio.gather(gem_coros, ink_coros)
gems, ink_data = await gather(gem_coros, ink_coros)

balances: TokenBalances = TokenBalances(block=block)
for token, data in zip(gems, ink_data):
Expand All @@ -59,11 +59,11 @@ async def _debt(self, address: Address, block: Optional[int] = None) -> TokenBal
if block is not None and block <= await contract_creation_block_async(self.ilk_registry):
return TokenBalances(block=block)

ilks, urn = await asyncio.gather(self.get_ilks(block), self._urn(address))
ilks, urn = await gather(self.get_ilks(block), self._urn(address))

data = await asyncio.gather(
data = await gather(
*[
asyncio.gather(
gather(
self.vat.urns.coroutine(ilk, urn, block_identifier=block),
self.vat.ilks.coroutine(ilk, block_identifier=block),
)
Expand Down

0 comments on commit ff34d35

Please sign in to comment.