Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tudor-malene committed Jun 13, 2024
1 parent 0894a39 commit fabb8a0
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 36 deletions.
21 changes: 7 additions & 14 deletions go/enclave/l2chain/l2_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"math/big"

"github.com/ten-protocol/go-ten/go/common/errutil"

"github.com/ten-protocol/go-ten/go/config"

"github.com/ten-protocol/go-ten/go/enclave/storage"
Expand Down Expand Up @@ -66,13 +68,14 @@ func NewChain(

func (oc *obscuroChain) AccountOwner(ctx context.Context, address gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*gethcommon.Address, error) {
// check if account is a contract
isContract, err := oc.isAccountContractAtBlock(ctx, address, blockNumber)
_, err := oc.storage.ReadContractAddress(ctx, address)
if err != nil {
if errors.Is(err, errutil.ErrNotFound) {
// the account is not a contract, so it must be an EOA
return &address, nil
}
return nil, err
}
if !isContract {
return &address, nil
}

// If the address is a contract, find the signer of the deploy transaction
txHash, err := oc.storage.GetContractCreationTx(ctx, address)
Expand Down Expand Up @@ -212,13 +215,3 @@ func (oc *obscuroChain) GetChainStateAtTransaction(ctx context.Context, batch *c
}
return nil, vm.BlockContext{}, nil, fmt.Errorf("transaction index %d out of range for batch %#x", txIndex, batch.Hash())
}

// Returns whether the account is a contract
func (oc *obscuroChain) isAccountContractAtBlock(ctx context.Context, accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (bool, error) {
chainState, err := oc.Registry.GetBatchStateAtHeight(ctx, blockNumber)
if err != nil {
return false, fmt.Errorf("unable to get blockchain state - %w", err)
}

return len(chainState.GetCode(accountAddr)) > 0, nil
}
2 changes: 1 addition & 1 deletion go/enclave/rpc/GetBalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func GetBalanceValidate(reqParams []any, builder *CallBuilder[BalanceReq, hexuti
func GetBalanceExecute(builder *CallBuilder[BalanceReq, hexutil.Big], rpc *EncryptionManager) error {
acctOwner, err := rpc.chain.AccountOwner(builder.ctx, *builder.Param.Addr, builder.Param.Block.BlockNumber)
if err != nil {
return err
return fmt.Errorf("cannot determine account owner. Cause: %w", err)
}

// authorise the call
Expand Down
2 changes: 1 addition & 1 deletion go/enclave/storage/enclavedb/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func ReadBatchTransactions(ctx context.Context, db *sql.DB, height uint64) ([]*c
}

func GetContractCreationTx(ctx context.Context, db *sql.DB, address gethcommon.Address) (*gethcommon.Hash, error) {
row := db.QueryRowContext(ctx, "select tx.hash from exec_tx join tx on tx.id=exec_tx.tx where created_contract_address=? ", address.Bytes())
row := db.QueryRowContext(ctx, "select tx.hash from exec_tx join tx on tx.id=exec_tx.tx join contract on created_contract_address=contract.id where address=? ", address.Bytes())

var txHashBytes []byte
err := row.Scan(&txHashBytes)
Expand Down
13 changes: 7 additions & 6 deletions go/enclave/storage/enclavedb/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
"where b.is_canonical=true "
)

func WriteEventType(ctx context.Context, dbTX *sql.Tx, contractID uint64, eventSignature gethcommon.Hash, isLifecycle bool) (uint64, error) {
func WriteEventType(ctx context.Context, dbTX *sql.Tx, contractID *uint64, eventSignature gethcommon.Hash, isLifecycle bool) (uint64, error) {
res, err := dbTX.ExecContext(ctx, "insert into event_type (contract, event_sig, lifecycle_event) values (?, ?, ?)", contractID, eventSignature.Bytes(), isLifecycle)
if err != nil {
return 0, err
Expand Down Expand Up @@ -284,7 +284,7 @@ func WriteEoa(ctx context.Context, dbTX *sql.Tx, sender *gethcommon.Address) (ui
return uint64(id), nil
}

func ReadEoa(ctx context.Context, dbTx *sql.Tx, addr *gethcommon.Address) (uint64, error) {
func ReadEoa(ctx context.Context, dbTx *sql.Tx, addr gethcommon.Address) (uint64, error) {
row := dbTx.QueryRowContext(ctx, "select id from externally_owned_account where address = ?", addr.Bytes())

var id uint64
Expand All @@ -300,17 +300,18 @@ func ReadEoa(ctx context.Context, dbTx *sql.Tx, addr *gethcommon.Address) (uint6
return id, nil
}

func WriteContractAddress(ctx context.Context, dbTX *sql.Tx, contractAddress *gethcommon.Address) (uint64, error) {
func WriteContractAddress(ctx context.Context, dbTX *sql.Tx, contractAddress *gethcommon.Address) (*uint64, error) {
insert := "insert into contract (address) values (?)"
res, err := dbTX.ExecContext(ctx, insert, contractAddress.Bytes())
if err != nil {
return 0, err
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return 0, err
return nil, err
}
return uint64(id), nil
v := uint64(id)
return &v, nil
}

func ReadContractAddress(ctx context.Context, dbTx *sql.Tx, addr gethcommon.Address) (uint64, error) {
Expand Down
4 changes: 4 additions & 0 deletions go/enclave/storage/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ type Storage interface {

// StateDB - return the underlying state database
StateDB() state.Database

ReadEOA(ctx context.Context, addr gethcommon.Address) (*uint64, error)

ReadContractAddress(ctx context.Context, addr gethcommon.Address) (*uint64, error)
}

type ScanStorage interface {
Expand Down
46 changes: 32 additions & 14 deletions go/enclave/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ func (s *storageImpl) StoreBatch(ctx context.Context, batch *core.Batch, convert
if err != nil {
return fmt.Errorf("could not read tx sender. Cause: %w", err)
}
_, err = s.findEOA(ctx, dbTx, sender)
_, err = s.readEOA(ctx, dbTx, *sender)
if err != nil {
if errors.Is(err, errutil.ErrNotFound) {
_, err := enclavedb.WriteEoa(ctx, dbTx, sender)
Expand Down Expand Up @@ -675,7 +675,7 @@ func (s *storageImpl) storeReceiptAndEventLogs(ctx context.Context, dbTX *sql.Tx
var createdContract *uint64
var nilAddr gethcommon.Address
if receipt.ContractAddress != nilAddr {
createdContractId, err := s.findContractAddress(ctx, dbTX, receipt.ContractAddress)
createdContractId, err := s.readContractAddress(ctx, dbTX, receipt.ContractAddress)
if err != nil {
if errors.Is(err, errutil.ErrNotFound) {
createdContractId, err = enclavedb.WriteContractAddress(ctx, dbTX, &receipt.ContractAddress)
Expand All @@ -685,7 +685,7 @@ func (s *storageImpl) storeReceiptAndEventLogs(ctx context.Context, dbTX *sql.Tx
}
// return fmt.Errorf("could not read contract address. Cause: %w", err)
}
createdContract = &createdContractId
createdContract = createdContractId
}
// Convert the receipt into their storage form and serialize them
storageReceipt := (*types.ReceiptForStorage)(receipt)
Expand Down Expand Up @@ -749,7 +749,7 @@ func (s *storageImpl) storeEventLog(ctx context.Context, dbTX *sql.Tx, execTxId
eventT, err := s.readEventType(ctx, dbTX, l.Address, l.Topics[0])
if err != nil {
if errors.Is(err, errutil.ErrNotFound) {
contractAddId, err := s.findContractAddress(ctx, dbTX, l.Address)
contractAddId, err := s.readContractAddress(ctx, dbTX, l.Address)
if err != nil {
if errors.Is(err, errutil.ErrNotFound) {
contractAddId, err = enclavedb.WriteContractAddress(ctx, dbTX, &l.Address)
Expand Down Expand Up @@ -796,11 +796,11 @@ func (s *storageImpl) storeEventLog(ctx context.Context, dbTX *sql.Tx, execTxId
func (s *storageImpl) findRelevantAddress(ctx context.Context, dbTX *sql.Tx, topic gethcommon.Hash) (*uint64, error) {
potentialAddr := common.ExtractPotentialAddress(topic)
if potentialAddr != nil {
eoaID, err := s.findEOA(ctx, dbTX, potentialAddr)
eoaID, err := s.readEOA(ctx, dbTX, *potentialAddr)
if err != nil {
return nil, err
}
return &eoaID, nil
return eoaID, nil
// todo - do we need to check anything else?
}
return nil, nil
Expand Down Expand Up @@ -988,8 +988,17 @@ func (s *storageImpl) CountTransactionsPerAddress(ctx context.Context, address *
return enclavedb.CountTransactionsPerAddress(ctx, s.db.GetSQLDB(), address)
}

func (s *storageImpl) findEOA(ctx context.Context, dbTX *sql.Tx, addr *gethcommon.Address) (uint64, error) {
defer s.logDuration("findEOA", measure.NewStopwatch())
func (s *storageImpl) ReadEOA(ctx context.Context, addr gethcommon.Address) (*uint64, error) {
dbtx, err := s.db.NewDBTransaction(ctx)
if err != nil {
return nil, err
}
defer dbtx.Rollback()
return s.readEOA(ctx, dbtx, addr)
}

func (s *storageImpl) readEOA(ctx context.Context, dbTX *sql.Tx, addr gethcommon.Address) (*uint64, error) {
defer s.logDuration("readEOA", measure.NewStopwatch())
id, err := common.GetCachedValue(ctx, s.eoaCache, s.logger, addr, func(v any) (*uint64, error) {
id, err := enclavedb.ReadEoa(ctx, dbTX, addr)
if err != nil {
Expand All @@ -998,13 +1007,22 @@ func (s *storageImpl) findEOA(ctx context.Context, dbTX *sql.Tx, addr *gethcommo
return &id, nil
})
if err != nil {
return 0, err
return nil, err
}
return id, err
}

func (s *storageImpl) ReadContractAddress(ctx context.Context, addr gethcommon.Address) (*uint64, error) {
dbtx, err := s.db.NewDBTransaction(ctx)
if err != nil {
return nil, err
}
return *id, err
defer dbtx.Rollback()
return s.readContractAddress(ctx, dbtx, addr)
}

func (s *storageImpl) findContractAddress(ctx context.Context, dbTX *sql.Tx, addr gethcommon.Address) (uint64, error) {
defer s.logDuration("findContractAddress", measure.NewStopwatch())
func (s *storageImpl) readContractAddress(ctx context.Context, dbTX *sql.Tx, addr gethcommon.Address) (*uint64, error) {
defer s.logDuration("readContractAddress", measure.NewStopwatch())
id, err := common.GetCachedValue(ctx, s.contractAddressCache, s.logger, addr, func(v any) (*uint64, error) {
id, err := enclavedb.ReadContractAddress(ctx, dbTX, addr)
if err != nil {
Expand All @@ -1013,9 +1031,9 @@ func (s *storageImpl) findContractAddress(ctx context.Context, dbTX *sql.Tx, add
return &id, nil
})
if err != nil {
return 0, err
return nil, err
}
return *id, err
return id, err
}

func (s *storageImpl) findEventTopic(ctx context.Context, dbTX *sql.Tx, topic []byte) (uint64, *uint64, error) {
Expand Down

0 comments on commit fabb8a0

Please sign in to comment.