diff --git a/eth_portfolio/address.py b/eth_portfolio/address.py index 0053555a..0f68098b 100644 --- a/eth_portfolio/address.py +++ b/eth_portfolio/address.py @@ -10,6 +10,7 @@ from y.constants import EEE_ADDRESS from y.datatypes import Address, Block +from eth_portfolio import protocols from eth_portfolio._ledgers.address import (AddressInternalTransfersLedger, AddressLedgerBase, AddressTokenTransfersLedger, @@ -17,8 +18,6 @@ PandableLedgerEntryList) from eth_portfolio._loaders import balances from eth_portfolio._utils import _LedgeredBase -from eth_portfolio.protocols import _external -from eth_portfolio.protocols.lending import _lending from eth_portfolio.typing import (Balance, RemoteTokenBalances, TokenBalances, WalletBalances) @@ -73,7 +72,7 @@ async def assets(self, block: Optional[Block] = None) -> TokenBalances: @stuck_coro_debugger async def debt(self, block: Optional[Block] = None) -> RemoteTokenBalances: - return await _lending.debt(self.address, block=block) + return await protocols.lending.debt(self.address, block=block) @stuck_coro_debugger async def external_balances(self, block: Optional[Block] = None) -> RemoteTokenBalances: @@ -109,11 +108,11 @@ async def token_balances(self, block) -> TokenBalances: @stuck_coro_debugger async def collateral(self, block: Optional[Block] = None) -> RemoteTokenBalances: - return await _lending.collateral(self.address, block=block) + return await protocols.lending.collateral(self.address, block=block) @stuck_coro_debugger async def staking(self, block: Optional[Block] = None) -> RemoteTokenBalances: - return await _external.balances(self.address, block=block) + return await protocols.balances(self.address, block=block) # Ledger Entries diff --git a/eth_portfolio/protocols/__init__.py b/eth_portfolio/protocols/__init__.py index 389ed2d6..a04ed390 100644 --- a/eth_portfolio/protocols/__init__.py +++ b/eth_portfolio/protocols/__init__.py @@ -1,5 +1,4 @@ -import asyncio from typing import List, Optional import a_sync @@ -7,24 +6,21 @@ from eth_portfolio._utils import (_get_protocols_for_submodule, _import_submodules) +from eth_portfolio.protocols import lending from eth_portfolio.protocols._base import StakingPoolABC from eth_portfolio.typing import RemoteTokenBalances _import_submodules() +protocols: List[StakingPoolABC] = _get_protocols_for_submodule() # type: ignore [assignment] -class ExternalBalances: - protocols: List[StakingPoolABC] = _get_protocols_for_submodule() # type: ignore - - @a_sync.future - async def balances(self, address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: - if not self.protocols: - return RemoteTokenBalances() - return RemoteTokenBalances({ - type(protocol).__name__: protocol_balances - async for protocol, protocol_balances - in a_sync.map(lambda p: p.balances(address, block), self.protocols) - if protocol_balances is not None - }) - -_external = ExternalBalances() +@a_sync.future +async def balances(address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: + if not protocols: + return RemoteTokenBalances() + return RemoteTokenBalances({ + type(protocol).__name__: protocol_balances + async for protocol, protocol_balances + in a_sync.map(lambda p: p.balances(address, block), protocols) + if protocol_balances is not None + }) diff --git a/eth_portfolio/protocols/lending/__init__.py b/eth_portfolio/protocols/lending/__init__.py index a6060c37..ec319e6e 100644 --- a/eth_portfolio/protocols/lending/__init__.py +++ b/eth_portfolio/protocols/lending/__init__.py @@ -13,30 +13,25 @@ _import_submodules() - -class Lending: - def __init__(self) -> None: - self.protocols: List[Union[LendingProtocol, LendingProtocolWithLockedCollateral]] = _get_protocols_for_submodule() # type: ignore - - @a_sync.future - @stuck_coro_debugger - async def collateral(self, address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: - protocols = (protocol for protocol in self.protocols if isinstance(protocol, LendingProtocolWithLockedCollateral)) - return RemoteTokenBalances({ - type(protocol).__name__: token_balances - async for protocol, token_balances in a_sync.map(lambda p: p.balances(address, block), protocols) - if token_balances is not None - }) - - @a_sync.future - @stuck_coro_debugger - async def debt(self, address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: - if not self.protocols: - return RemoteTokenBalances() - return RemoteTokenBalances({ - type(protocol).__name__: token_balances - async for protocol, token_balances in a_sync.map(lambda p: p.debt(address, block), self.protocols) - if token_balances is not None - }) - -_lending = Lending() +protocols: List[Union[LendingProtocol, LendingProtocolWithLockedCollateral]] = _get_protocols_for_submodule() # type: ignore [assignment] + +@a_sync.future +@stuck_coro_debugger +async def collateral(address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: + protocols = (protocol for protocol in protocols if isinstance(protocol, LendingProtocolWithLockedCollateral)) + return RemoteTokenBalances({ + type(protocol).__name__: token_balances + async for protocol, token_balances in a_sync.map(lambda p: p.balances(address, block), protocols) + if token_balances is not None + }) + +@a_sync.future +@stuck_coro_debugger +async def debt(address: Address, block: Optional[Block] = None) -> RemoteTokenBalances: + if not protocols: + return RemoteTokenBalances() + return RemoteTokenBalances({ + type(protocol).__name__: token_balances + async for protocol, token_balances in a_sync.map(lambda p: p.debt(address, block), protocols) + if token_balances is not None + }) diff --git a/tests/protocols/test_external.py b/tests/protocols/test_external.py index 84a0da3e..21ff2f45 100644 --- a/tests/protocols/test_external.py +++ b/tests/protocols/test_external.py @@ -2,10 +2,9 @@ import pytest from unittest.mock import patch, AsyncMock, create_autospec +from eth_portfolio import protocols from eth_portfolio.protocols._base import StakingPoolABC from eth_portfolio.typing import Balance, RemoteTokenBalances, TokenBalances -from eth_portfolio.protocols import ExternalBalances - SOME_ADDRESS = '0x0000000000000000000000000000000000000001' SOME_TOKEN = '0x0000000000000000000000000000000000000002' @@ -21,10 +20,9 @@ class MockProtocolB(AsyncMock): @patch('a_sync.map') @pytest.mark.asyncio async def test_balances_no_protocols(mock_map): - external_balances = ExternalBalances() - external_balances.protocols = [] + protocols.protocols = [] - balances = await external_balances.balances(SOME_ADDRESS) + balances = await protocols.balances(SOME_ADDRESS) assert balances == RemoteTokenBalances() mock_map.assert_not_called() @@ -37,15 +35,14 @@ async def test_balances_with_protocols(): mock_protocol_b = MockProtocolB() mock_protocol_b.balances.return_value = TokenBalances({SOME_OTHER_TOKEN: Balance(200, 400)}) - external_balances = ExternalBalances() - external_balances.protocols = [mock_protocol_a, mock_protocol_b] + protocols.protocols = [mock_protocol_a, mock_protocol_b] - balances = await external_balances.balances(SOME_ADDRESS) + balances = await protocols.balances(SOME_ADDRESS) assert balances == RemoteTokenBalances({'MockProtocolA': TokenBalances({SOME_TOKEN: Balance(100, 200)}), 'MockProtocolB': TokenBalances({SOME_OTHER_TOKEN: Balance(200, 400)})}) - for protocol in external_balances.protocols: + for protocol in protocols.protocols: protocol.balances.assert_called_once_with(SOME_ADDRESS, None) @pytest.mark.asyncio @@ -59,13 +56,12 @@ async def test_balances_with_protocols_and_block(): mock_protocol_b = MockProtocolB() mock_protocol_b.balances.return_value = TokenBalances({SOME_OTHER_TOKEN: Balance(200, 400)}) - external_balances = ExternalBalances() - external_balances.protocols = [mock_protocol_a, mock_protocol_b] + protocols.protocols = [mock_protocol_a, mock_protocol_b] - balances = await external_balances.balances(SOME_ADDRESS, block) + balances = await protocols.balances(SOME_ADDRESS, block) assert balances == RemoteTokenBalances({'MockProtocolA': TokenBalances({SOME_TOKEN: Balance(100, 200)}), 'MockProtocolB': TokenBalances({SOME_OTHER_TOKEN: Balance(200, 400)})}) - for protocol in external_balances.protocols: + for protocol in protocols.protocols: protocol.balances.assert_called_once_with(SOME_ADDRESS, block)