diff --git a/go/enclave/rpc/TenStorageRead.go b/go/enclave/rpc/TenStorageRead.go index 0de76e6470..1fee8dff54 100644 --- a/go/enclave/rpc/TenStorageRead.go +++ b/go/enclave/rpc/TenStorageRead.go @@ -37,8 +37,15 @@ func TenStorageReadValidate(reqParams []any, builder *CallBuilder[storageReadWit return nil } - if !rpc.whitelist.AllowedStorageSlots[slot] { - builder.Err = fmt.Errorf("eth_getStorageAt is not supported on TEN") + contract, err := rpc.storage.ReadContract(builder.ctx, *address) + if err != nil { + builder.Err = fmt.Errorf("eth_getStorageAt is not supported for this contract") + return nil + } + + // block the call for un-transparent contracts and non-whitelisted slots + if !rpc.whitelist.AllowedStorageSlots[slot] && !contract.IsTransparent() { + builder.Err = fmt.Errorf("eth_getStorageAt is not supported for this contract") return nil } diff --git a/go/enclave/storage/interfaces.go b/go/enclave/storage/interfaces.go index d1629fc426..5bf5107be8 100644 --- a/go/enclave/storage/interfaces.go +++ b/go/enclave/storage/interfaces.go @@ -6,6 +6,8 @@ import ( "io" "math/big" + "github.com/ten-protocol/go-ten/go/enclave/storage/enclavedb" + "github.com/ethereum/go-ethereum/triedb" "github.com/ethereum/go-ethereum/core/state" @@ -150,7 +152,7 @@ type Storage interface { // StateDB - return the underlying state database StateDB() state.Database - ReadContractCreator(ctx context.Context, address gethcommon.Address) (*gethcommon.Address, error) + ReadContract(ctx context.Context, address gethcommon.Address) (*enclavedb.Contract, error) } type ScanStorage interface { diff --git a/go/enclave/storage/storage.go b/go/enclave/storage/storage.go index caba3d462a..6d372f3f0f 100644 --- a/go/enclave/storage/storage.go +++ b/go/enclave/storage/storage.go @@ -803,8 +803,13 @@ func (s *storageImpl) readOrWriteEOA(ctx context.Context, dbTX *sql.Tx, addr get }) } -func (s *storageImpl) ReadContractCreator(ctx context.Context, address gethcommon.Address) (*gethcommon.Address, error) { - return enclavedb.ReadContractCreator(ctx, s.db.GetSQLDB(), address) +func (s *storageImpl) ReadContract(ctx context.Context, address gethcommon.Address) (*enclavedb.Contract, error) { + dbtx, err := s.db.GetSQLDB().BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer dbtx.Rollback() + return enclavedb.ReadContractByAddress(ctx, dbtx, address) } func (s *storageImpl) logDuration(method string, stopWatch *measure.Stopwatch) {