Skip to content

Commit

Permalink
feat: CHAINID constant (#171)
Browse files Browse the repository at this point in the history
* feat: CHAINID constant

* chore: `black .`

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
BobTheBuidler and github-actions[bot] authored Dec 22, 2024
1 parent 36f0e8f commit f488f19
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 57 deletions.
118 changes: 62 additions & 56 deletions eth_portfolio/_db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
from multicall.utils import get_event_loop
from pony.orm import BindingError, OperationalError, commit, db_session, flush, select
from y._db.entities import db
from y.constants import CHAINID
from y.exceptions import reraise_excs_with_extra_context

from eth_portfolio._db import entities
from eth_portfolio._db.decorators import break_locks, requery_objs_on_diff_tx_err
from eth_portfolio._db.entities import (
AddressExtended,
BlockExtended,
ContractExtended,
TokenExtended,
)
from eth_portfolio._decimal import Decimal
from eth_portfolio.structs import InternalTransfer, TokenTransfer, Transaction, TransactionRLP
from eth_portfolio.typing import _P, _T, Fn
Expand Down Expand Up @@ -84,14 +91,14 @@ def robust_db_session(fn: Fn[_P, _T]) -> Fn[_P, _T]:

@a_sync(default="async", executor=_block_executor)
@robust_db_session
def get_block(block: int) -> entities.BlockExtended:
if b := entities.BlockExtended.get(chain=chain.id, number=block):
def get_block(block: int) -> BlockExtended:
if b := BlockExtended.get(chain=CHAINID, number=block):
return b
elif b := Block.get(chain=chain.id, number=block):
if isinstance(b, entities.BlockExtended):
elif b := Block.get(chain=CHAINID, number=block):
if isinstance(b, BlockExtended):
# in case of race cndtn
return b
raise ValueError(b, b.number, b.chain.id)
raise ValueError(b, b.number, b.CHAINID)
hash = b.hash
ts = b.timestamp
prices = [(price.token.address, price.price) for price in b.prices]
Expand All @@ -107,7 +114,7 @@ def get_block(block: int) -> entities.BlockExtended:
b.delete()
commit()
b = insert(
type=entities.BlockExtended,
type=BlockExtended,
chain=get_chain(sync=True),
number=block,
hash=hash,
Expand All @@ -130,9 +137,9 @@ def get_block(block: int) -> entities.BlockExtended:
if not isinstance(asdasd, Chain):
raise TypeError(asdasd)
commit()
if b := insert(type=entities.BlockExtended, chain=asdasd, number=block):
if b := insert(type=BlockExtended, chain=asdasd, number=block):
return b
return entities.BlockExtended.get(chain=chain.id, number=block)
return BlockExtended.get(chain=CHAINID, number=block)


@a_sync(default="async", executor=_block_executor)
Expand Down Expand Up @@ -184,20 +191,20 @@ def __is_token(address) -> bool:

@a_sync(default="async", executor=_address_executor)
@robust_db_session
def get_address(address: ChecksumAddress) -> entities.AddressExtended:
entity_type = entities.TokenExtended
entity = entities.Address.get(chain=chain.id, address=address)
def get_address(address: ChecksumAddress) -> AddressExtended:
entity_type = TokenExtended
entity = entities.Address.get(chain=CHAINID, address=address)
""" TODO: fix this later
entity = entities.Address.get(chain=chain, address=address)
if isinstance(entity, (Token, entities.TokenExtended)):
entity_type = entities.TokenExtended
elif isinstance(entity, (Contract, entities.ContractExtended)):
entity_type = entities.ContractExtended
elif isinstance(entity, (Address, entities.AddressExtended)):
entity_type = entities.AddressExtended
if isinstance(entity, (Token, TokenExtended)):
entity_type = TokenExtended
elif isinstance(entity, (Contract, ContractExtended)):
entity_type = ContractExtended
elif isinstance(entity, (Address, AddressExtended)):
entity_type = AddressExtended
elif entity is None:
# TODO: this logic should live in ypm, prob
entity_type = entities.AddressExtended if not is_contract(address) else entities.TokenExtended if is_token(address) else entities.ContractExtended
entity_type = AddressExtended if not is_contract(address) else TokenExtended if is_token(address) else ContractExtended
else:
raise NotImplementedError(entity, entity_type)
Expand All @@ -209,12 +216,12 @@ def get_address(address: ChecksumAddress) -> entities.AddressExtended:
entity.delete()
commit()
"""
if entity := entities.Address.get(chain=chain.id, address=address):
if entity := entities.Address.get(chain=CHAINID, address=address):
return entity

ensure_chain()
return insert(type=entity_type, chain=chain.id, address=address) or entity_type.get(
chain=chain.id, address=address
return insert(type=entity_type, chain=CHAINID, address=address) or entity_type.get(
chain=CHAINID, address=address
)


Expand All @@ -233,12 +240,12 @@ def ensure_addresses(*addresses: ChecksumAddress) -> None:

@a_sync(default="async", executor=_token_executor)
@robust_db_session
def get_token(address: ChecksumAddress) -> entities.TokenExtended:
if t := entities.TokenExtended.get(chain=chain.id, address=address):
def get_token(address: ChecksumAddress) -> TokenExtended:
if t := TokenExtended.get(chain=CHAINID, address=address):
return t
kwargs = {}
if t := Address.get(chain=chain.id, address=address):
if isinstance(t, entities.TokenExtended):
if t := Address.get(chain=CHAINID, address=address):
if isinstance(t, TokenExtended):
# double check due to possible race cntdn
return t
"""
Expand Down Expand Up @@ -274,8 +281,8 @@ def get_token(address: ChecksumAddress) -> entities.TokenExtended:
ensure_chain()
commit()
return insert(
type=entities.TokenExtended, chain=chain.id, address=address, **kwargs
) or entities.TokenExtended.get(chain=chain.id, address=address)
type=TokenExtended, chain=CHAINID, address=address, **kwargs
) or TokenExtended.get(chain=CHAINID, address=address)


@a_sync(default="async", executor=_token_executor)
Expand All @@ -285,7 +292,7 @@ def ensure_token(token_address: ChecksumAddress) -> None:


async def get_transaction(sender: ChecksumAddress, nonce: int) -> Optional[Transaction]:
startup_txs = await transactions_known_at_startup(chain.id, sender)
startup_txs = await transactions_known_at_startup(CHAINID, sender)
data = startup_txs.pop(nonce, None) or await __get_transaction_bytes_from_db(sender, nonce)
if data:
await _yield_to_loop()
Expand All @@ -307,7 +314,7 @@ async def _yield_to_loop():
@robust_db_session
def __get_transaction_bytes_from_db(sender: ChecksumAddress, nonce: int) -> Optional[bytes]:
entity: entities.Transaction
if entity := entities.Transaction.get(from_address=(chain.id, sender), nonce=nonce):
if entity := entities.Transaction.get(from_address=(CHAINID, sender), nonce=nonce):
return entity.raw


Expand Down Expand Up @@ -350,10 +357,10 @@ def _insert_transaction(transaction: Transaction) -> None:
with reraise_excs_with_extra_context(transaction):
entities.Transaction(
**transaction.__db_primary_key__,
block=(chain.id, transaction.block_number),
block=(CHAINID, transaction.block_number),
transaction_index=transaction.transaction_index,
hash=transaction.hash.hex(),
to_address=(chain.id, transaction.to_address) if transaction.to_address else None,
to_address=(CHAINID, transaction.to_address) if transaction.to_address else None,
value=transaction.value,
price=transaction.price,
value_usd=transaction.value_usd,
Expand All @@ -372,21 +379,21 @@ def get_internal_transfer(trace: evmspec.FilterTrace) -> Optional[InternalTransf
block = trace.blockNumber
entity: entities.InternalTransfer
if entity := entities.InternalTransfer.get(
block=(chain.id, block),
block=(CHAINID, block),
transaction_index=trace.transactionPosition,
hash=trace.transactionHash,
type=trace.type.name,
call_type=trace.callType,
from_address=(chain.id, trace.sender),
to_address=(chain.id, trace.to),
from_address=(CHAINID, trace.sender),
to_address=(CHAINID, trace.to),
value=trace.value.scaled,
trace_address=(chain.id, trace.traceAddress),
trace_address=(CHAINID, trace.traceAddress),
gas=trace.gas,
gas_used=trace.gasUsed if "gasUsed" in trace else None,
input=trace.input,
output=trace.output,
subtraces=trace.subtraces,
address=(chain.id, trace.address),
address=(CHAINID, trace.address),
):
return json.decode(entity.raw, type=InternalTransfer, dec_hook=_decode_hook)

Expand All @@ -395,21 +402,21 @@ def get_internal_transfer(trace: evmspec.FilterTrace) -> Optional[InternalTransf
@robust_db_session
def delete_internal_transfer(transfer: InternalTransfer) -> None:
if entity := entities.InternalTransfer.get(
block=(chain.id, transfer.block_number),
block=(CHAINID, transfer.block_number),
transaction_index=transfer.transaction_index,
hash=transfer.hash,
type=transfer.type,
call_type=transfer.call_type,
from_address=(chain.id, transfer.from_address),
to_address=(chain.id, transfer.to_address),
from_address=(CHAINID, transfer.from_address),
to_address=(CHAINID, transfer.to_address),
value=transfer.value,
trace_address=(chain.id, transfer.trace_address),
trace_address=(CHAINID, transfer.trace_address),
gas=transfer.gas,
gas_used=transfer.gas_used,
input=transfer.input,
output=transfer.output,
subtraces=transfer.subtraces,
address=(chain.id, transfer.address),
address=(CHAINID, transfer.address),
):
entity.delete()

Expand All @@ -429,13 +436,13 @@ async def insert_internal_transfer(transfer: InternalTransfer) -> None:
@robust_db_session
def _insert_internal_transfer(transfer: InternalTransfer) -> None:
entities.InternalTransfer(
block=(chain.id, transfer.block_number),
block=(CHAINID, transfer.block_number),
transaction_index=transfer.transaction_index,
hash=transfer.hash,
type=transfer.type,
call_type=transfer.call_type,
from_address=(chain.id, transfer.from_address),
to_address=(chain.id, transfer.to_address),
from_address=(CHAINID, transfer.from_address),
to_address=(CHAINID, transfer.to_address),
value=transfer.value,
price=transfer.price,
value_usd=transfer.value_usd,
Expand All @@ -448,11 +455,11 @@ def _insert_internal_transfer(transfer: InternalTransfer) -> None:

async def get_token_transfer(transfer: evmspec.Log) -> Optional[TokenTransfer]:
pk = {
"block": (chain.id, transfer.blockNumber),
"block": (CHAINID, transfer.blockNumber),
"transaction_index": transfer.transactionIndex,
"log_index": transfer.logIndex,
}
startup_xfers = await token_transfers_known_at_startup()
startup_xfers = await token_transfers_known_at_startup(CHAINID)
data = startup_xfers.pop(tuple(pk.values()), None) or await __get_token_transfer_bytes_from_db(
pk
)
Expand Down Expand Up @@ -490,18 +497,17 @@ def transactions_known_at_startup(chainid: int, from_address: ChecksumAddress) -

@a_sync(default="async", executor=_transaction_read_executor, ram_cache_maxsize=None)
@robust_db_session
def token_transfers_known_at_startup() -> Dict[_TokenTransferPK, bytes]:
chainid: int
def token_transfers_known_at_startup(chainid: int) -> Dict[_TokenTransferPK, bytes]:
block: int
tx_index: int
log_index: int
raw: bytes

transfers = {}
for chainid, block, tx_index, log_index, raw in select(
(t.block.chain.id, t.block.number, t.transaction_index, t.log_index, t.raw)
for block, tx_index, log_index, raw in select(
(t.block.number, t.transaction_index, t.log_index, t.raw)
for t in entities.TokenTransfer # type: ignore [attr-defined]
if t.block.chain.id == chain.id
if t.block.chain.id == chainid
):
pk = ((chainid, block), tx_index, log_index)
transfers[pk] = raw
Expand All @@ -512,7 +518,7 @@ def token_transfers_known_at_startup() -> Dict[_TokenTransferPK, bytes]:
@robust_db_session
def delete_token_transfer(token_transfer: TokenTransfer) -> None:
if entity := entities.TokenTransfer.get(
block=(chain.id, token_transfer.block_number),
block=(CHAINID, token_transfer.block_number),
transaction_index=token_transfer.transaction_index,
log_index=token_transfer.log_index,
):
Expand Down Expand Up @@ -553,13 +559,13 @@ async def insert_token_transfer(token_transfer: TokenTransfer) -> None:
@robust_db_session
def _insert_token_transfer(token_transfer: TokenTransfer) -> None:
entities.TokenTransfer(
block=(chain.id, token_transfer.block_number),
block=(CHAINID, token_transfer.block_number),
transaction_index=token_transfer.transaction_index,
log_index=token_transfer.log_index,
hash=token_transfer.hash.hex(),
token=(chain.id, token_transfer.token_address),
from_address=(chain.id, token_transfer.from_address),
to_address=(chain.id, token_transfer.to_address),
token=(CHAINID, token_transfer.token_address),
from_address=(CHAINID, token_transfer.from_address),
to_address=(CHAINID, token_transfer.to_address),
value=token_transfer.value,
price=token_transfer.price,
value_usd=token_transfer.value_usd,
Expand Down
2 changes: 1 addition & 1 deletion eth_portfolio/_ledgers/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ async def _load_new_objects(
start_block = int(start_block)
if isinstance(end_block, float) and int(end_block) == end_block:
end_block = int(end_block)

block_ranges = [
[hex(i), hex(i + BATCH_SIZE - 1)] for i in range(start_block, end_block, BATCH_SIZE)
]
Expand Down

0 comments on commit f488f19

Please sign in to comment.