diff --git a/go.mod b/go.mod index b0e7e6ca13..259b179a98 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/holiman/uint256 v1.2.4 github.com/jolestar/go-commons-pool/v2 v2.1.2 + github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 github.com/pkg/errors v0.9.1 @@ -34,7 +35,6 @@ require ( github.com/sanity-io/litter v1.5.5 github.com/status-im/keycard-go v0.3.2 github.com/stretchr/testify v1.9.0 - github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 github.com/tidwall/gjson v1.17.1 github.com/valyala/fasthttp v1.52.0 gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 @@ -107,7 +107,6 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect @@ -131,6 +130,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/supranational/blst v0.3.11 // indirect + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect diff --git a/go/common/cache_util.go b/go/common/cache_util.go index 52136049f1..b35cc9068b 100644 --- a/go/common/cache_util.go +++ b/go/common/cache_util.go @@ -10,8 +10,8 @@ import ( // GetCachedValue - returns the cached value for the provided key. If the key is not found, then invoke the 'onFailed' function // which returns the value, and cache it -func GetCachedValue[V any](cache *cache.Cache[*V], logger gethlog.Logger, key any, onCacheMiss func(any) (*V, error)) (*V, error) { - value, err := cache.Get(context.Background(), key) +func GetCachedValue[V any](ctx context.Context, cache *cache.Cache[*V], logger gethlog.Logger, key any, onCacheMiss func(any) (*V, error)) (*V, error) { + value, err := cache.Get(ctx, key) if err != nil || value == nil { // todo metrics for cache misses v, err := onCacheMiss(key) @@ -21,18 +21,18 @@ func GetCachedValue[V any](cache *cache.Cache[*V], logger gethlog.Logger, key an if v == nil { logger.Crit("Returned a nil value from the onCacheMiss function. Should not happen.") } - CacheValue(cache, logger, key, v) + CacheValue(ctx, cache, logger, key, v) return v, nil } return value, err } -func CacheValue[V any](cache *cache.Cache[*V], logger gethlog.Logger, key any, v *V) { +func CacheValue[V any](ctx context.Context, cache *cache.Cache[*V], logger gethlog.Logger, key any, v *V) { if v == nil { return } - err := cache.Set(context.Background(), key, v) + err := cache.Set(ctx, key, v) if err != nil { logger.Error("Could not store value in cache", log.ErrKey, err) } diff --git a/go/common/enclave.go b/go/common/enclave.go index 4e923835a4..9e07f59b67 100644 --- a/go/common/enclave.go +++ b/go/common/enclave.go @@ -1,6 +1,7 @@ package common import ( + "context" "encoding/json" "math/big" @@ -34,19 +35,19 @@ type Enclave interface { EnclaveScan // Status checks whether the enclave is ready to process requests - only implemented by the RPC layer - Status() (Status, SystemError) + Status(context.Context) (Status, SystemError) // Attestation - Produces an attestation report which will be used to request the shared secret from another enclave. - Attestation() (*AttestationReport, SystemError) + Attestation(context.Context) (*AttestationReport, SystemError) // GenerateSecret - the genesis enclave is responsible with generating the secret entropy - GenerateSecret() (EncryptedSharedEnclaveSecret, SystemError) + GenerateSecret(context.Context) (EncryptedSharedEnclaveSecret, SystemError) // InitEnclave - initialise an enclave with a seed received by another enclave - InitEnclave(secret EncryptedSharedEnclaveSecret) SystemError + InitEnclave(ctx context.Context, secret EncryptedSharedEnclaveSecret) SystemError // EnclaveID - returns the enclave's ID - EnclaveID() (EnclaveID, SystemError) + EnclaveID(context.Context) (EnclaveID, SystemError) // SubmitL1Block - Used for the host to submit L1 blocks to the enclave, these may be: // a. historic block - if the enclave is behind and in the process of catching up with the L1 state @@ -54,41 +55,41 @@ type Enclave interface { // It is the responsibility of the host to gossip the returned rollup // For good functioning the caller should always submit blocks ordered by height // submitting a block before receiving ancestors of it, will result in it being ignored - SubmitL1Block(block L1Block, receipts L1Receipts, isLatest bool) (*BlockSubmissionResponse, SystemError) + SubmitL1Block(ctx context.Context, block L1Block, receipts L1Receipts, isLatest bool) (*BlockSubmissionResponse, SystemError) // SubmitTx - user transactions - SubmitTx(tx EncryptedTx) (*responses.RawTx, SystemError) + SubmitTx(ctx context.Context, tx EncryptedTx) (*responses.RawTx, SystemError) // SubmitBatch submits a batch received from the sequencer for processing. - SubmitBatch(batch *ExtBatch) SystemError + SubmitBatch(ctx context.Context, batch *ExtBatch) SystemError // ObsCall - Execute a smart contract to retrieve data. The equivalent of "Eth_call" // Todo - return the result with a block delay. To prevent frontrunning. - ObsCall(encryptedParams EncryptedParamsCall) (*responses.Call, SystemError) + ObsCall(ctx context.Context, encryptedParams EncryptedParamsCall) (*responses.Call, SystemError) // GetTransactionCount returns the nonce of the wallet with the given address (encrypted with the acc viewing key) - GetTransactionCount(encryptedParams EncryptedParamsGetTxCount) (*responses.TxCount, SystemError) + GetTransactionCount(ctx context.Context, encryptedParams EncryptedParamsGetTxCount) (*responses.TxCount, SystemError) // Stop gracefully stops the enclave Stop() SystemError // GetTransaction returns a transaction in JSON format, encrypted with the viewing key for the transaction's `from` field. - GetTransaction(encryptedParams EncryptedParamsGetTxByHash) (*responses.TxByHash, SystemError) + GetTransaction(ctx context.Context, encryptedParams EncryptedParamsGetTxByHash) (*responses.TxByHash, SystemError) // GetTransactionReceipt returns a transaction receipt given its signed hash, or nil if the transaction is unknown - GetTransactionReceipt(encryptedParams EncryptedParamsGetTxReceipt) (*responses.TxReceipt, SystemError) + GetTransactionReceipt(ctx context.Context, encryptedParams EncryptedParamsGetTxReceipt) (*responses.TxReceipt, SystemError) // GetBalance returns the balance of the address on the Obscuro network, encrypted with the viewing key for the // address. - GetBalance(encryptedParams EncryptedParamsGetBalance) (*responses.Balance, SystemError) + GetBalance(ctx context.Context, encryptedParams EncryptedParamsGetBalance) (*responses.Balance, SystemError) // GetCode returns the code stored at the given address in the state for the given rollup hash. - GetCode(address gethcommon.Address, rollupHash *gethcommon.Hash) ([]byte, SystemError) + GetCode(ctx context.Context, address gethcommon.Address, rollupHash *gethcommon.Hash) ([]byte, SystemError) // Subscribe adds a log subscription to the enclave under the given ID, provided the request is authenticated // correctly. The events will be populated in the BlockSubmissionResponse. If there is an existing subscription // with the given ID, it is overwritten. - Subscribe(id rpc.ID, encryptedParams EncryptedParamsLogSubscription) SystemError + Subscribe(ctx context.Context, id rpc.ID, encryptedParams EncryptedParamsLogSubscription) SystemError // Unsubscribe removes the log subscription with the given ID from the enclave. If there is no subscription with // the given ID, nothing is deleted. @@ -98,54 +99,54 @@ type Enclave interface { StopClient() SystemError // EstimateGas tries to estimate the gas needed to execute a specific transaction based on the pending state. - EstimateGas(encryptedParams EncryptedParamsEstimateGas) (*responses.Gas, SystemError) + EstimateGas(ctx context.Context, encryptedParams EncryptedParamsEstimateGas) (*responses.Gas, SystemError) // GetLogs returns all the logs matching the filter. - GetLogs(encryptedParams EncryptedParamsGetLogs) (*responses.Logs, SystemError) + GetLogs(ctx context.Context, encryptedParams EncryptedParamsGetLogs) (*responses.Logs, SystemError) // HealthCheck returns whether the enclave is in a healthy state - HealthCheck() (bool, SystemError) + HealthCheck(context.Context) (bool, SystemError) // GetBatch - retrieve a batch if existing within the enclave db. - GetBatch(hash L2BatchHash) (*ExtBatch, SystemError) + GetBatch(ctx context.Context, hash L2BatchHash) (*ExtBatch, SystemError) // GetBatchBySeqNo - retrieve batch by sequencer number if it's in the db. - GetBatchBySeqNo(seqNo uint64) (*ExtBatch, SystemError) + GetBatchBySeqNo(ctx context.Context, seqNo uint64) (*ExtBatch, SystemError) // GetRollupData - retrieve the first batch sequence and start time for a given rollup. - GetRollupData(hash L2RollupHash) (*PublicRollupMetadata, SystemError) + GetRollupData(ctx context.Context, hash L2RollupHash) (*PublicRollupMetadata, SystemError) // CreateBatch - creates a new head batch extending the previous one for the latest known L1 head if the node is // a sequencer. Will panic otherwise. - CreateBatch(skipIfEmpty bool) SystemError + CreateBatch(ctx context.Context, skipIfEmpty bool) SystemError // CreateRollup - will create a new rollup by going through the sequencer if the node is a sequencer // or panic otherwise. - CreateRollup(fromSeqNo uint64) (*ExtRollup, SystemError) + CreateRollup(ctx context.Context, fromSeqNo uint64) (*ExtRollup, SystemError) // DebugTraceTransaction returns the trace of a transaction - DebugTraceTransaction(hash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, SystemError) + DebugTraceTransaction(ctx context.Context, hash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, SystemError) // StreamL2Updates - will stream any new batches as they are created/detected // All will be queued in the channel that has been returned. StreamL2Updates() (chan StreamL2UpdatesResponse, func()) // DebugEventLogRelevancy returns the logs of a transaction - DebugEventLogRelevancy(hash gethcommon.Hash) (json.RawMessage, SystemError) + DebugEventLogRelevancy(ctx context.Context, hash gethcommon.Hash) (json.RawMessage, SystemError) } // EnclaveScan represents the methods that are used for data scanning in the enclave type EnclaveScan interface { // GetTotalContractCount returns the total number of contracts that have been deployed - GetTotalContractCount() (*big.Int, SystemError) + GetTotalContractCount(context.Context) (*big.Int, SystemError) // GetCustomQuery returns the data of a custom query - GetCustomQuery(encryptedParams EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, SystemError) + GetCustomQuery(ctx context.Context, encryptedParams EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, SystemError) // GetPublicTransactionData returns a list of public transaction data - GetPublicTransactionData(pagination *QueryPagination) (*TransactionListingResponse, SystemError) + GetPublicTransactionData(ctx context.Context, pagination *QueryPagination) (*TransactionListingResponse, SystemError) // EnclavePublicConfig returns network data that is known to the enclave but can be shared publicly - EnclavePublicConfig() (*EnclavePublicConfig, SystemError) + EnclavePublicConfig(context.Context) (*EnclavePublicConfig, SystemError) } // BlockSubmissionResponse is the response sent from the enclave back to the node after ingesting a block diff --git a/go/common/gethencoding/geth_encoding.go b/go/common/gethencoding/geth_encoding.go index 679f6a9323..812dbc527a 100644 --- a/go/common/gethencoding/geth_encoding.go +++ b/go/common/gethencoding/geth_encoding.go @@ -1,6 +1,7 @@ package gethencoding import ( + "context" "encoding/json" "fmt" "math/big" @@ -45,8 +46,8 @@ const ( // EncodingService handles conversion to Geth data structures type EncodingService interface { - CreateEthHeaderForBatch(h *common.BatchHeader) (*types.Header, error) - CreateEthBlockFromBatch(b *core.Batch) (*types.Block, error) + CreateEthHeaderForBatch(ctx context.Context, h *common.BatchHeader) (*types.Header, error) + CreateEthBlockFromBatch(ctx context.Context, b *core.Batch) (*types.Block, error) } type gethEncodingServiceImpl struct { @@ -307,11 +308,11 @@ func ExtractEthCall(param interface{}) (*gethapi.TransactionArgs, error) { // CreateEthHeaderForBatch - the EVM requires an Ethereum header. // We convert the Batch headers to Ethereum headers to be able to use the Geth EVM. // Special care must be taken to maintain a valid chain of these converted headers. -func (enc *gethEncodingServiceImpl) CreateEthHeaderForBatch(h *common.BatchHeader) (*types.Header, error) { +func (enc *gethEncodingServiceImpl) CreateEthHeaderForBatch(ctx context.Context, h *common.BatchHeader) (*types.Header, error) { // wrap in a caching layer - return common.GetCachedValue(enc.gethHeaderCache, enc.logger, h.Hash(), func(a any) (*types.Header, error) { + return common.GetCachedValue(ctx, enc.gethHeaderCache, enc.logger, h.Hash(), func(a any) (*types.Header, error) { // deterministically calculate the private randomness that will be exposed to the EVM - secret, err := enc.storage.FetchSecret() + secret, err := enc.storage.FetchSecret(ctx) if err != nil { enc.logger.Crit("Could not fetch shared secret. Exiting.", log.ErrKey, err) } @@ -322,7 +323,7 @@ func (enc *gethEncodingServiceImpl) CreateEthHeaderForBatch(h *common.BatchHeade convertedParentHash := common.GethGenesisParentHash if h.SequencerOrderNo.Uint64() > common.L2GenesisSeqNo { - convertedParentHash, err = enc.storage.FetchConvertedHash(h.ParentHash) + convertedParentHash, err = enc.storage.FetchConvertedHash(ctx, h.ParentHash) if err != nil { enc.logger.Error("Cannot find the converted value for the parent of", log.BatchSeqNoKey, h.SequencerOrderNo) return nil, err @@ -377,8 +378,8 @@ type localBlock struct { ReceivedFrom interface{} } -func (enc *gethEncodingServiceImpl) CreateEthBlockFromBatch(b *core.Batch) (*types.Block, error) { - blockHeader, err := enc.CreateEthHeaderForBatch(b.Header) +func (enc *gethEncodingServiceImpl) CreateEthBlockFromBatch(ctx context.Context, b *core.Batch) (*types.Block, error) { + blockHeader, err := enc.CreateEthHeaderForBatch(ctx, b.Header) if err != nil { return nil, fmt.Errorf("unable to create eth block from batch - %w", err) } diff --git a/go/common/gethutil/gethutil.go b/go/common/gethutil/gethutil.go index 346cfdc15a..d9c79f742a 100644 --- a/go/common/gethutil/gethutil.go +++ b/go/common/gethutil/gethutil.go @@ -2,6 +2,7 @@ package gethutil import ( "bytes" + "context" "fmt" "github.com/ten-protocol/go-ten/go/enclave/storage" @@ -18,8 +19,8 @@ var EmptyHash = gethcommon.Hash{} // LCA - returns the latest common ancestor of the 2 blocks or an error if no common ancestor is found // it also returns the blocks that became canonincal, and the once that are now the fork -func LCA(newCanonical *types.Block, oldCanonical *types.Block, resolver storage.BlockResolver) (*common.ChainFork, error) { - b, cp, ncp, err := internalLCA(newCanonical, oldCanonical, resolver, []common.L1BlockHash{}, []common.L1BlockHash{oldCanonical.Hash()}) +func LCA(ctx context.Context, newCanonical *types.Block, oldCanonical *types.Block, resolver storage.BlockResolver) (*common.ChainFork, error) { + b, cp, ncp, err := internalLCA(ctx, newCanonical, oldCanonical, resolver, []common.L1BlockHash{}, []common.L1BlockHash{oldCanonical.Hash()}) // remove the common ancestor if len(cp) > 0 { cp = cp[0 : len(cp)-1] @@ -36,7 +37,7 @@ func LCA(newCanonical *types.Block, oldCanonical *types.Block, resolver storage. }, err } -func internalLCA(newCanonical *types.Block, oldCanonical *types.Block, resolver storage.BlockResolver, canonicalPath []common.L1BlockHash, nonCanonicalPath []common.L1BlockHash) (*types.Block, []common.L1BlockHash, []common.L1BlockHash, error) { +func internalLCA(ctx context.Context, newCanonical *types.Block, oldCanonical *types.Block, resolver storage.BlockResolver, canonicalPath []common.L1BlockHash, nonCanonicalPath []common.L1BlockHash) (*types.Block, []common.L1BlockHash, []common.L1BlockHash, error) { if newCanonical.NumberU64() == common.L1GenesisHeight || oldCanonical.NumberU64() == common.L1GenesisHeight { return newCanonical, canonicalPath, nonCanonicalPath, nil } @@ -44,29 +45,29 @@ func internalLCA(newCanonical *types.Block, oldCanonical *types.Block, resolver return newCanonical, canonicalPath, nonCanonicalPath, nil } if newCanonical.NumberU64() > oldCanonical.NumberU64() { - p, err := resolver.FetchBlock(newCanonical.ParentHash()) + p, err := resolver.FetchBlock(ctx, newCanonical.ParentHash()) if err != nil { return nil, nil, nil, fmt.Errorf("could not retrieve parent block. Cause: %w", err) } - return internalLCA(p, oldCanonical, resolver, append(canonicalPath, p.Hash()), nonCanonicalPath) + return internalLCA(ctx, p, oldCanonical, resolver, append(canonicalPath, p.Hash()), nonCanonicalPath) } if oldCanonical.NumberU64() > newCanonical.NumberU64() { - p, err := resolver.FetchBlock(oldCanonical.ParentHash()) + p, err := resolver.FetchBlock(ctx, oldCanonical.ParentHash()) if err != nil { return nil, nil, nil, fmt.Errorf("could not retrieve parent block. Cause: %w", err) } - return internalLCA(newCanonical, p, resolver, canonicalPath, append(nonCanonicalPath, p.Hash())) + return internalLCA(ctx, newCanonical, p, resolver, canonicalPath, append(nonCanonicalPath, p.Hash())) } - parentBlockA, err := resolver.FetchBlock(newCanonical.ParentHash()) + parentBlockA, err := resolver.FetchBlock(ctx, newCanonical.ParentHash()) if err != nil { return nil, nil, nil, fmt.Errorf("could not retrieve parent block. Cause: %w", err) } - parentBlockB, err := resolver.FetchBlock(oldCanonical.ParentHash()) + parentBlockB, err := resolver.FetchBlock(ctx, oldCanonical.ParentHash()) if err != nil { return nil, nil, nil, fmt.Errorf("could not retrieve parent block. Cause: %w", err) } - return internalLCA(parentBlockA, parentBlockB, resolver, append(canonicalPath, parentBlockA.Hash()), append(nonCanonicalPath, parentBlockB.Hash())) + return internalLCA(ctx, parentBlockA, parentBlockB, resolver, append(canonicalPath, parentBlockA.Hash()), append(nonCanonicalPath, parentBlockB.Hash())) } diff --git a/go/common/host/host.go b/go/common/host/host.go index f016b1d95b..1a49f0888d 100644 --- a/go/common/host/host.go +++ b/go/common/host/host.go @@ -1,6 +1,8 @@ package host import ( + "context" + "github.com/ethereum/go-ethereum/core/types" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/config" @@ -17,7 +19,7 @@ type Host interface { // Start initializes the main loop of the host. Start() error // SubmitAndBroadcastTx submits an encrypted transaction to the enclave, and broadcasts it to the other hosts on the network. - SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) + SubmitAndBroadcastTx(ctx context.Context, encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) // SubscribeLogs feeds logs matching the encrypted log subscription to the matchedLogs channel. SubscribeLogs(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogs chan []byte) error // UnsubscribeLogs terminates a log subscription between the host and the enclave. @@ -26,7 +28,7 @@ type Host interface { Stop() error // HealthCheck returns the health status of the host + enclave + db - HealthCheck() (*HealthCheck, error) + HealthCheck(context.Context) (*HealthCheck, error) // ObscuroConfig returns the info of the Obscuro network ObscuroConfig() (*common.ObscuroNetworkInfo, error) diff --git a/go/common/host/services.go b/go/common/host/services.go index 2ee5bbee1b..03f66011ec 100644 --- a/go/common/host/services.go +++ b/go/common/host/services.go @@ -1,6 +1,7 @@ package host import ( + "context" "math/big" "github.com/ten-protocol/go-ten/go/responses" @@ -33,7 +34,7 @@ const ( type Service interface { Start() error Stop() error - HealthStatus() HealthStatus + HealthStatus(context.Context) HealthStatus } // P2P provides an interface for the host to interact with the P2P network @@ -125,7 +126,7 @@ type L2BatchRepository interface { // Subscribe will register a batch handler to receive new batches as they arrive Subscribe(handler L2BatchHandler) func() - FetchBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) + FetchBatchBySeqNo(background context.Context, seqNo *big.Int) (*common.ExtBatch, error) // AddBatch is used to notify the repository of a new batch, e.g. from the enclave when seq produces one or a rollup is consumed // Note: it is fine to add batches that the repo already has, it will just ignore them @@ -143,13 +144,13 @@ type L2BatchHandler interface { type EnclaveService interface { // LookupBatchBySeqNo is used to fetch batch data from the enclave - it is only used as a fallback for the sequencer // host if it's missing a batch (other host services should use L2Repo to fetch batch data) - LookupBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) + LookupBatchBySeqNo(ctx context.Context, seqNo *big.Int) (*common.ExtBatch, error) // GetEnclaveClient returns an enclave client // todo (@matt) we probably don't want to expose this GetEnclaveClient() common.Enclave // SubmitAndBroadcastTx submits an encrypted transaction to the enclave, and broadcasts it to other hosts on the network (in particular, to the sequencer) - SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) + SubmitAndBroadcastTx(ctx context.Context, encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) Subscribe(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription) error Unsubscribe(id rpc.ID) error diff --git a/go/common/subscription/new_heads_manager.go b/go/common/subscription/new_heads_manager.go index eded54d1b2..13e6b6671a 100644 --- a/go/common/subscription/new_heads_manager.go +++ b/go/common/subscription/new_heads_manager.go @@ -1,6 +1,7 @@ package subscription import ( + "context" "math/big" "sync" "sync/atomic" @@ -93,7 +94,7 @@ func (nhs *NewHeadsService) Stop() error { return nil } -func (nhs *NewHeadsService) HealthStatus() host.HealthStatus { +func (nhs *NewHeadsService) HealthStatus(context.Context) host.HealthStatus { return &host.BasicErrHealthStatus{} } diff --git a/go/config/enclave_config.go b/go/config/enclave_config.go index 68f92524f0..a181b7a695 100644 --- a/go/config/enclave_config.go +++ b/go/config/enclave_config.go @@ -6,6 +6,7 @@ import ( "os" "strconv" "strings" + "time" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ten-protocol/go-ten/go/common" @@ -69,6 +70,10 @@ type EnclaveConfig struct { BaseFee *big.Int GasBatchExecutionLimit uint64 GasLocalExecutionCapFlag uint64 + + // RPCTimeout - calls that are longer than this will be cancelled, to prevent resource starvation + // normally, the context is propagated from the host, but in some cases ( like the evm, we have to create a context) + RPCTimeout time.Duration } func NewConfigFromFlags(cliFlags map[string]*flag.TenFlag) (*EnclaveConfig, error) { @@ -154,7 +159,10 @@ func retrieveEnvFlags() (map[string]*flag.TenFlag, error) { } func newConfig(flags map[string]*flag.TenFlag) (*EnclaveConfig, error) { - cfg := &EnclaveConfig{} + cfg := &EnclaveConfig{ + // hardcoding for now + RPCTimeout: 5 * time.Second, + } nodeType, err := common.ToNodeType(flags[NodeTypeFlag].String()) if err != nil { diff --git a/go/enclave/components/attestation.go b/go/enclave/components/attestation.go index 645f4557f1..c6dc7e19c9 100644 --- a/go/enclave/components/attestation.go +++ b/go/enclave/components/attestation.go @@ -2,6 +2,7 @@ package components import ( "bytes" + "context" "crypto/sha256" "encoding/json" "fmt" @@ -20,14 +21,14 @@ type IDData struct { type AttestationProvider interface { // GetReport returns the verifiable attestation report - GetReport(pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) + GetReport(ctx context.Context, pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) // VerifyReport returns the embedded report data VerifyReport(att *common.AttestationReport) ([]byte, error) } type EgoAttestationProvider struct{} -func (e *EgoAttestationProvider) GetReport(pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) { +func (e *EgoAttestationProvider) GetReport(ctx context.Context, pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) { idHash, err := getIDHash(enclaveID, pubKey, hostAddress) if err != nil { return nil, err @@ -56,7 +57,7 @@ func (e *EgoAttestationProvider) VerifyReport(att *common.AttestationReport) ([] type DummyAttestationProvider struct{} -func (e *DummyAttestationProvider) GetReport(pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) { +func (e *DummyAttestationProvider) GetReport(ctx context.Context, pubKey []byte, enclaveID gethcommon.Address, hostAddress string) (*common.AttestationReport, error) { return &common.AttestationReport{ Report: []byte("MOCK REPORT"), PubKey: pubKey, diff --git a/go/enclave/components/batch_executor.go b/go/enclave/components/batch_executor.go index 78a7b0d9bc..83b900688e 100644 --- a/go/enclave/components/batch_executor.go +++ b/go/enclave/components/batch_executor.go @@ -2,12 +2,15 @@ package components import ( "bytes" + "context" "errors" "fmt" "math/big" "sort" "sync" + "github.com/ten-protocol/go-ten/go/config" + "github.com/ten-protocol/go-ten/go/common/gethencoding" "github.com/ten-protocol/go-ten/go/enclave/gas" @@ -36,6 +39,7 @@ var ErrNoTransactionsToProcess = fmt.Errorf("no transactions to process") // batchExecutor - the component responsible for executing batches type batchExecutor struct { storage storage.Storage + config config.EnclaveConfig gethEncodingService gethencoding.EncodingService crossChainProcessors *crosschain.Processors genesis *genesis.Genesis @@ -51,6 +55,7 @@ type batchExecutor struct { func NewBatchExecutor( storage storage.Storage, + config config.EnclaveConfig, gethEncodingService gethencoding.EncodingService, cc *crosschain.Processors, genesis *genesis.Genesis, @@ -61,6 +66,7 @@ func NewBatchExecutor( ) BatchExecutor { return &batchExecutor{ storage: storage, + config: config, gethEncodingService: gethEncodingService, crossChainProcessors: cc, genesis: genesis, @@ -75,10 +81,10 @@ func NewBatchExecutor( // filterTransactionsWithSufficientFunds - this function estimates hte l1 fees for the transaction in a given batch execution context. It does so by taking the price of the // pinned L1 block and using it as the cost per gas for the estimated gas of the calldata encoding of a transaction. It filters out any transactions that cannot afford to pay for their L1 // publishing cost. -func (executor *batchExecutor) filterTransactionsWithSufficientFunds(stateDB *state.StateDB, context *BatchExecutionContext) (common.L2PricedTransactions, common.L2PricedTransactions) { +func (executor *batchExecutor) filterTransactionsWithSufficientFunds(ctx context.Context, stateDB *state.StateDB, context *BatchExecutionContext) (common.L2PricedTransactions, common.L2PricedTransactions) { transactions := make(common.L2PricedTransactions, 0) freeTransactions := make(common.L2PricedTransactions, 0) - block, _ := executor.storage.FetchBlock(context.BlockPtr) + block, _ := executor.storage.FetchBlock(ctx, context.BlockPtr) for _, tx := range context.Transactions { sender, err := core.GetAuthenticatedSender(context.ChainConfig.ChainID.Int64(), tx) @@ -121,11 +127,11 @@ func (executor *batchExecutor) filterTransactionsWithSufficientFunds(stateDB *st return transactions, freeTransactions } -func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, failForEmptyBatch bool) (*ComputedBatch, error) { //nolint:gocognit +func (executor *batchExecutor) ComputeBatch(ctx context.Context, context *BatchExecutionContext, failForEmptyBatch bool) (*ComputedBatch, error) { //nolint:gocognit defer core.LogMethodDuration(executor.logger, measure.NewStopwatch(), "Batch context processed") // sanity check that the l1 block exists. We don't have to execute batches of forks. - block, err := executor.storage.FetchBlock(context.BlockPtr) + block, err := executor.storage.FetchBlock(ctx, context.BlockPtr) if errors.Is(err, errutil.ErrNotFound) { return nil, errutil.ErrBlockForBatchNotFound } else if err != nil { @@ -133,7 +139,7 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail } // These variables will be used to create the new batch - parent, err := executor.storage.FetchBatch(context.ParentPtr) + parent, err := executor.storage.FetchBatch(ctx, context.ParentPtr) if errors.Is(err, errutil.ErrNotFound) { executor.logger.Error(fmt.Sprintf("can't find parent batch %s. Seq %d", context.ParentPtr, context.SequencerNo)) return nil, errutil.ErrAncestorBatchNotFound @@ -145,7 +151,7 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail parentBlock := block if parent.Header.L1Proof != block.Hash() { var err error - parentBlock, err = executor.storage.FetchBlock(parent.Header.L1Proof) + parentBlock, err = executor.storage.FetchBlock(ctx, parent.Header.L1Proof) if err != nil { executor.logger.Error(fmt.Sprintf("Could not retrieve a proof for batch %s", parent.Hash()), log.ErrKey, err) return nil, err @@ -155,7 +161,7 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail // Create a new batch based on the fromBlock of inclusion of the previous, including all new transactions batch := core.DeterministicEmptyBatch(parent.Header, block, context.AtTime, context.SequencerNo, context.BaseFee, context.Creator) - stateDB, err := executor.storage.CreateStateDB(batch.Header.ParentHash) + stateDB, err := executor.storage.CreateStateDB(ctx, batch.Header.ParentHash) if err != nil { return nil, fmt.Errorf("could not create stateDB. Cause: %w", err) } @@ -164,13 +170,13 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail var messages common.CrossChainMessages var transfers common.ValueTransferEvents if context.SequencerNo.Int64() > int64(common.L2GenesisSeqNo+1) { - messages, transfers = executor.crossChainProcessors.Local.RetrieveInboundMessages(parentBlock, block, stateDB) + messages, transfers = executor.crossChainProcessors.Local.RetrieveInboundMessages(ctx, parentBlock, block, stateDB) } - crossChainTransactions := executor.crossChainProcessors.Local.CreateSyntheticTransactions(messages, stateDB) - executor.crossChainProcessors.Local.ExecuteValueTransfers(transfers, stateDB) + crossChainTransactions := executor.crossChainProcessors.Local.CreateSyntheticTransactions(ctx, messages, stateDB) + executor.crossChainProcessors.Local.ExecuteValueTransfers(ctx, transfers, stateDB) - transactionsToProcess, freeTransactions := executor.filterTransactionsWithSufficientFunds(stateDB, context) + transactionsToProcess, freeTransactions := executor.filterTransactionsWithSufficientFunds(ctx, stateDB, context) xchainTxs := make(common.L2PricedTransactions, 0) for _, xTx := range crossChainTransactions { @@ -183,14 +189,14 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail syntheticTransactions := append(xchainTxs, freeTransactions...) // fromTxIndex - Here we start from the 0 index. This will be the same for a validator. - successfulTxs, excludedTxs, txReceipts, err := executor.processTransactions(batch, 0, transactionsToProcess, stateDB, context.ChainConfig, false) + successfulTxs, excludedTxs, txReceipts, err := executor.processTransactions(ctx, batch, 0, transactionsToProcess, stateDB, context.ChainConfig, false) if err != nil { return nil, fmt.Errorf("could not process transactions. Cause: %w", err) } // fromTxIndex - Here we start from the len of the successful transactions; As long as we have the exact same successful transactions in a batch, // we will start from the same place. - ccSuccessfulTxs, _, ccReceipts, err := executor.processTransactions(batch, len(successfulTxs), syntheticTransactions, stateDB, context.ChainConfig, true) + ccSuccessfulTxs, _, ccReceipts, err := executor.processTransactions(ctx, batch, len(successfulTxs), syntheticTransactions, stateDB, context.ChainConfig, true) if err != nil { return nil, err } @@ -219,7 +225,7 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail copyBatch.Transactions = append(successfulTxs, freeTransactions.ToTransactions()...) copyBatch.ResetHash() - if err = executor.populateOutboundCrossChainData(©Batch, block, txReceipts); err != nil { + if err = executor.populateOutboundCrossChainData(ctx, ©Batch, block, txReceipts); err != nil { return nil, fmt.Errorf("failed adding cross chain data to batch. Cause: %w", err) } @@ -251,7 +257,7 @@ func (executor *batchExecutor) ComputeBatch(context *BatchExecutionContext, fail }, nil } -func (executor *batchExecutor) ExecuteBatch(batch *core.Batch) (types.Receipts, error) { +func (executor *batchExecutor) ExecuteBatch(ctx context.Context, batch *core.Batch) (types.Receipts, error) { defer core.LogMethodDuration(executor.logger, measure.NewStopwatch(), "Executed batch", log.BatchHashKey, batch.Hash()) // Validators recompute the entire batch using the same batch context @@ -259,7 +265,7 @@ func (executor *batchExecutor) ExecuteBatch(batch *core.Batch) (types.Receipts, // and the parent hash. This recomputed batch is then checked against the incoming batch. // If the sequencer has tampered with something the hash will not add up and validation will // produce an error. - cb, err := executor.ComputeBatch(&BatchExecutionContext{ + cb, err := executor.ComputeBatch(ctx, &BatchExecutionContext{ BlockPtr: batch.Header.L1Proof, ParentPtr: batch.Header.ParentHash, Transactions: batch.Transactions, @@ -300,6 +306,7 @@ func (vt ValueTransfers) EncodeIndex(index int, w *bytes.Buffer) { } func (executor *batchExecutor) CreateGenesisState( + ctx context.Context, blkHash common.L1BlockHash, timeNow uint64, coinbase gethcommon.Address, @@ -342,14 +349,14 @@ func (executor *batchExecutor) CreateGenesisState( return genesisBatch, deployTx, nil } -func (executor *batchExecutor) populateOutboundCrossChainData(batch *core.Batch, block *types.Block, receipts types.Receipts) error { - crossChainMessages, err := executor.crossChainProcessors.Local.ExtractOutboundMessages(receipts) +func (executor *batchExecutor) populateOutboundCrossChainData(ctx context.Context, batch *core.Batch, block *types.Block, receipts types.Receipts) error { + crossChainMessages, err := executor.crossChainProcessors.Local.ExtractOutboundMessages(ctx, receipts) if err != nil { executor.logger.Error("Failed extracting L2->L1 messages", log.ErrKey, err, log.CmpKey, log.CrossChainCmp) return fmt.Errorf("could not extract cross chain messages. Cause: %w", err) } - valueTransferMessages, err := executor.crossChainProcessors.Local.ExtractOutboundTransfers(receipts) + valueTransferMessages, err := executor.crossChainProcessors.Local.ExtractOutboundTransfers(ctx, receipts) if err != nil { executor.logger.Error("Failed extracting L2->L1 messages value transfers", log.ErrKey, err, log.CmpKey, log.CrossChainCmp) return fmt.Errorf("could not extract cross chain value transfers. Cause: %w", err) @@ -398,6 +405,7 @@ func (executor *batchExecutor) verifyInboundCrossChainTransactions(transactions } func (executor *batchExecutor) processTransactions( + ctx context.Context, batch *core.Batch, tCount int, txs common.L2PricedTransactions, @@ -409,12 +417,14 @@ func (executor *batchExecutor) processTransactions( var excludedTransactions []*common.L2Tx var txReceipts []*types.Receipt txResults := evm.ExecuteTransactions( + ctx, txs, stateDB, batch.Header, executor.storage, executor.gethEncodingService, cc, + executor.config, tCount, noBaseFee, executor.batchGasLimit, diff --git a/go/enclave/components/batch_registry.go b/go/enclave/components/batch_registry.go index 46f54247d4..a9bcd77659 100644 --- a/go/enclave/components/batch_registry.go +++ b/go/enclave/components/batch_registry.go @@ -1,6 +1,7 @@ package components import ( + "context" "errors" "fmt" "math/big" @@ -34,7 +35,7 @@ type batchRegistry struct { func NewBatchRegistry(storage storage.Storage, logger gethlog.Logger) BatchRegistry { var headBatchSeq *big.Int - headBatch, err := storage.FetchHeadBatch() + headBatch, err := storage.FetchHeadBatch(context.Background()) if err != nil { if errors.Is(err, errutil.ErrNotFound) { headBatchSeq = nil @@ -89,9 +90,9 @@ func (br *batchRegistry) HasGenesisBatch() (bool, error) { return br.headBatchSeq != nil, nil } -func (br *batchRegistry) BatchesAfter(batchSeqNo uint64, upToL1Height uint64, rollupLimiter limiters.RollupLimiter) ([]*core.Batch, []*types.Block, error) { +func (br *batchRegistry) BatchesAfter(ctx context.Context, batchSeqNo uint64, upToL1Height uint64, rollupLimiter limiters.RollupLimiter) ([]*core.Batch, []*types.Block, error) { // sanity check - headBatch, err := br.storage.FetchBatchBySeqNo(br.headBatchSeq.Uint64()) + headBatch, err := br.storage.FetchBatchBySeqNo(ctx, br.headBatchSeq.Uint64()) if err != nil { return nil, nil, err } @@ -106,7 +107,7 @@ func (br *batchRegistry) BatchesAfter(batchSeqNo uint64, upToL1Height uint64, ro currentBatchSeq := batchSeqNo var currentBlock *types.Block for currentBatchSeq <= headBatch.SeqNo().Uint64() { - batch, err := br.storage.FetchBatchBySeqNo(currentBatchSeq) + batch, err := br.storage.FetchBatchBySeqNo(ctx, currentBatchSeq) if err != nil { return nil, nil, fmt.Errorf("could not retrieve batch by sequence number %d. Cause: %w", currentBatchSeq, err) } @@ -114,7 +115,7 @@ func (br *batchRegistry) BatchesAfter(batchSeqNo uint64, upToL1Height uint64, ro // check the block height // if it's the same block as the previous batch there is no reason to check if currentBlock == nil || currentBlock.Hash() != batch.Header.L1Proof { - block, err := br.storage.FetchBlock(batch.Header.L1Proof) + block, err := br.storage.FetchBlock(ctx, batch.Header.L1Proof) if err != nil { return nil, nil, fmt.Errorf("could not retrieve block. Cause: %w", err) } @@ -153,15 +154,15 @@ func (br *batchRegistry) BatchesAfter(batchSeqNo uint64, upToL1Height uint64, ro return resultBatches, resultBlocks, nil } -func (br *batchRegistry) GetBatchStateAtHeight(blockNumber *gethrpc.BlockNumber) (*state.StateDB, error) { +func (br *batchRegistry) GetBatchStateAtHeight(ctx context.Context, blockNumber *gethrpc.BlockNumber) (*state.StateDB, error) { // We retrieve the batch of interest. - batch, err := br.GetBatchAtHeight(*blockNumber) + batch, err := br.GetBatchAtHeight(ctx, *blockNumber) if err != nil { return nil, err } // We get that of the chain at that height - blockchainState, err := br.storage.CreateStateDB(batch.Hash()) + blockchainState, err := br.storage.CreateStateDB(ctx, batch.Hash()) if err != nil { return nil, fmt.Errorf("could not create stateDB. Cause: %w", err) } @@ -173,27 +174,27 @@ func (br *batchRegistry) GetBatchStateAtHeight(blockNumber *gethrpc.BlockNumber) return blockchainState, err } -func (br *batchRegistry) GetBatchAtHeight(height gethrpc.BlockNumber) (*core.Batch, error) { +func (br *batchRegistry) GetBatchAtHeight(ctx context.Context, height gethrpc.BlockNumber) (*core.Batch, error) { if br.headBatchSeq == nil { return nil, fmt.Errorf("chain not initialised") } var batch *core.Batch switch height { case gethrpc.EarliestBlockNumber: - genesisBatch, err := br.storage.FetchBatchByHeight(0) + genesisBatch, err := br.storage.FetchBatchByHeight(ctx, 0) if err != nil { return nil, fmt.Errorf("could not retrieve genesis rollup. Cause: %w", err) } batch = genesisBatch // note: our API currently treats all these block statuses the same for obscuro batches case gethrpc.SafeBlockNumber, gethrpc.FinalizedBlockNumber, gethrpc.LatestBlockNumber, gethrpc.PendingBlockNumber: - headBatch, err := br.storage.FetchBatchBySeqNo(br.headBatchSeq.Uint64()) + headBatch, err := br.storage.FetchBatchBySeqNo(ctx, br.headBatchSeq.Uint64()) if err != nil { return nil, fmt.Errorf("batch with requested height %d was not found. Cause: %w", height, err) } batch = headBatch default: - maybeBatch, err := br.storage.FetchBatchByHeight(uint64(height)) + maybeBatch, err := br.storage.FetchBatchByHeight(ctx, uint64(height)) if err != nil { return nil, fmt.Errorf("batch with requested height %d could not be retrieved. Cause: %w", height, err) } diff --git a/go/enclave/components/block_processor.go b/go/enclave/components/block_processor.go index a80b893004..c80b518625 100644 --- a/go/enclave/components/block_processor.go +++ b/go/enclave/components/block_processor.go @@ -1,6 +1,7 @@ package components import ( + "context" "errors" "fmt" "time" @@ -35,7 +36,7 @@ type l1BlockProcessor struct { func NewBlockProcessor(storage storage.Storage, cc *crosschain.Processors, gasOracle gas.Oracle, logger gethlog.Logger) L1BlockProcessor { var l1BlockHash *common.L1BlockHash - head, err := storage.FetchHeadBlock() + head, err := storage.FetchHeadBlock(context.Background()) if err != nil { if !errors.Is(err, errutil.ErrNotFound) { logger.Crit("Cannot fetch head block", log.ErrKey, err) @@ -56,22 +57,22 @@ func NewBlockProcessor(storage storage.Storage, cc *crosschain.Processors, gasOr } } -func (bp *l1BlockProcessor) Process(br *common.BlockAndReceipts) (*BlockIngestionType, error) { +func (bp *l1BlockProcessor) Process(ctx context.Context, br *common.BlockAndReceipts) (*BlockIngestionType, error) { defer core.LogMethodDuration(bp.logger, measure.NewStopwatch(), "L1 block processed", log.BlockHashKey, br.Block.Hash()) - ingestion, err := bp.tryAndInsertBlock(br) + ingestion, err := bp.tryAndInsertBlock(ctx, br) if err != nil { return nil, err } if !ingestion.PreGenesis { // This requires block to be stored first ... but can permanently fail a block - err = bp.crossChainProcessors.Remote.StoreCrossChainMessages(br.Block, *br.Receipts) + err = bp.crossChainProcessors.Remote.StoreCrossChainMessages(ctx, br.Block, *br.Receipts) if err != nil { return nil, errors.New("failed to process cross chain messages") } - err = bp.crossChainProcessors.Remote.StoreCrossChainValueTransfers(br.Block, *br.Receipts) + err = bp.crossChainProcessors.Remote.StoreCrossChainValueTransfers(ctx, br.Block, *br.Receipts) if err != nil { return nil, fmt.Errorf("failed to process cross chain transfers. Cause: %w", err) } @@ -96,10 +97,10 @@ func (bp *l1BlockProcessor) HealthCheck() (bool, error) { return true, nil } -func (bp *l1BlockProcessor) tryAndInsertBlock(br *common.BlockAndReceipts) (*BlockIngestionType, error) { +func (bp *l1BlockProcessor) tryAndInsertBlock(ctx context.Context, br *common.BlockAndReceipts) (*BlockIngestionType, error) { block := br.Block - _, err := bp.storage.FetchBlock(block.Hash()) + _, err := bp.storage.FetchBlock(ctx, block.Hash()) if err == nil { return nil, errutil.ErrBlockAlreadyProcessed } @@ -109,7 +110,7 @@ func (bp *l1BlockProcessor) tryAndInsertBlock(br *common.BlockAndReceipts) (*Blo } // We insert the block into the L1 chain and store it. - ingestionType, err := bp.ingestBlock(block) + ingestionType, err := bp.ingestBlock(ctx, block) if err != nil { // Do not store the block if the L1 chain insertion failed return nil, err @@ -117,7 +118,7 @@ func (bp *l1BlockProcessor) tryAndInsertBlock(br *common.BlockAndReceipts) (*Blo bp.logger.Trace("Block inserted successfully", log.BlockHeightKey, block.NumberU64(), log.BlockHashKey, block.Hash(), "ingestionType", ingestionType) - err = bp.storage.StoreBlock(block, ingestionType.ChainFork) + err = bp.storage.StoreBlock(ctx, block, ingestionType.ChainFork) if err != nil { return nil, fmt.Errorf("1. could not store block. Cause: %w", err) } @@ -125,9 +126,9 @@ func (bp *l1BlockProcessor) tryAndInsertBlock(br *common.BlockAndReceipts) (*Blo return ingestionType, nil } -func (bp *l1BlockProcessor) ingestBlock(block *common.L1Block) (*BlockIngestionType, error) { +func (bp *l1BlockProcessor) ingestBlock(ctx context.Context, block *common.L1Block) (*BlockIngestionType, error) { // todo (#1056) - this is minimal L1 tracking/validation, and should be removed when we are using geth's blockchain or lightchain structures for validation - prevL1Head, err := bp.GetHead() + prevL1Head, err := bp.GetHead(ctx) if err != nil { if errors.Is(err, errutil.ErrNotFound) { // todo (@matt) - we should enforce that this block is a configured hash (e.g. the L1 management contract deployment block) @@ -137,7 +138,7 @@ func (bp *l1BlockProcessor) ingestBlock(block *common.L1Block) (*BlockIngestionT } // we do a basic sanity check, comparing the received block to the head block on the chain if block.ParentHash() != prevL1Head.Hash() { - chainFork, err := gethutil.LCA(block, prevL1Head, bp.storage) + chainFork, err := gethutil.LCA(ctx, block, prevL1Head, bp.storage) if err != nil { bp.logger.Trace("parent not found", "blkHeight", block.NumberU64(), log.BlockHashKey, block.Hash(), @@ -156,11 +157,11 @@ func (bp *l1BlockProcessor) ingestBlock(block *common.L1Block) (*BlockIngestionT return &BlockIngestionType{ChainFork: nil, PreGenesis: false}, nil } -func (bp *l1BlockProcessor) GetHead() (*common.L1Block, error) { +func (bp *l1BlockProcessor) GetHead(ctx context.Context) (*common.L1Block, error) { if bp.currentL1Head == nil { return nil, errutil.ErrNotFound } - return bp.storage.FetchBlock(*bp.currentL1Head) + return bp.storage.FetchBlock(ctx, *bp.currentL1Head) } func (bp *l1BlockProcessor) GetCrossChainContractAddress() *gethcommon.Address { diff --git a/go/enclave/components/consumer_test.go b/go/enclave/components/consumer_test.go index 799116a256..dc99fbc7f5 100644 --- a/go/enclave/components/consumer_test.go +++ b/go/enclave/components/consumer_test.go @@ -1,6 +1,7 @@ package components import ( + "context" "math/big" "testing" @@ -24,7 +25,7 @@ func TestInvalidBlocksAreRejected(t *testing.T) { for _, header := range invalidHeaders { loopHeader := header - _, err := blockConsumer.ingestBlock(types.NewBlock(&loopHeader, nil, nil, nil, &trie.StackTrie{})) + _, err := blockConsumer.ingestBlock(context.Background(), types.NewBlock(&loopHeader, nil, nil, nil, &trie.StackTrie{})) if err == nil { t.Errorf("expected block with invalid header to be rejected but was accepted") } diff --git a/go/enclave/components/interfaces.go b/go/enclave/components/interfaces.go index a877553422..1e507c7ad9 100644 --- a/go/enclave/components/interfaces.go +++ b/go/enclave/components/interfaces.go @@ -1,6 +1,7 @@ package components import ( + "context" "errors" "math/big" @@ -34,8 +35,8 @@ func (bit *BlockIngestionType) IsFork() bool { } type L1BlockProcessor interface { - Process(br *common.BlockAndReceipts) (*BlockIngestionType, error) - GetHead() (*common.L1Block, error) + Process(ctx context.Context, br *common.BlockAndReceipts) (*BlockIngestionType, error) + GetHead(context.Context) (*common.L1Block, error) GetCrossChainContractAddress() *gethcommon.Address HealthCheck() (bool, error) } @@ -69,28 +70,28 @@ type BatchExecutor interface { // Call with same BatchContext should always produce identical extBatch - idempotent // Should be safe to call in parallel // failForEmptyBatch bool is used to skip batch production - ComputeBatch(batchContext *BatchExecutionContext, failForEmptyBatch bool) (*ComputedBatch, error) + ComputeBatch(ctx context.Context, batchContext *BatchExecutionContext, failForEmptyBatch bool) (*ComputedBatch, error) // ExecuteBatch - executes the transactions and xchain messages, returns the receipts, and updates the stateDB - ExecuteBatch(*core.Batch) (types.Receipts, error) + ExecuteBatch(context.Context, *core.Batch) (types.Receipts, error) // CreateGenesisState - will create and commit the genesis state in the stateDB for the given block hash, // and uint64 timestamp representing the time now. In this genesis state is where one can // find preallocated funds for faucet. TODO - make this an option - CreateGenesisState(common.L1BlockHash, uint64, gethcommon.Address, *big.Int) (*core.Batch, *types.Transaction, error) + CreateGenesisState(context.Context, common.L1BlockHash, uint64, gethcommon.Address, *big.Int) (*core.Batch, *types.Transaction, error) } type BatchRegistry interface { // BatchesAfter - Given a hash, will return batches following it until the head batch and the l1 blocks referenced by those batches - BatchesAfter(batchSeqNo uint64, upToL1Height uint64, rollupLimiter limiters.RollupLimiter) ([]*core.Batch, []*types.Block, error) + BatchesAfter(ctx context.Context, batchSeqNo uint64, upToL1Height uint64, rollupLimiter limiters.RollupLimiter) ([]*core.Batch, []*types.Block, error) // GetBatchStateAtHeight - creates a stateDB that represents the state committed when // the batch with height matching the blockNumber was created and stored. - GetBatchStateAtHeight(blockNumber *gethrpc.BlockNumber) (*state.StateDB, error) + GetBatchStateAtHeight(ctx context.Context, blockNumber *gethrpc.BlockNumber) (*state.StateDB, error) // GetBatchAtHeight - same as `GetBatchStateAtHeight`, but instead returns the full batch // rather than its stateDB only. - GetBatchAtHeight(height gethrpc.BlockNumber) (*core.Batch, error) + GetBatchAtHeight(ctx context.Context, height gethrpc.BlockNumber) (*core.Batch, error) // SubscribeForExecutedBatches - register a callback for new batches SubscribeForExecutedBatches(func(*core.Batch, types.Receipts)) @@ -109,12 +110,12 @@ type BatchRegistry interface { type RollupProducer interface { // CreateInternalRollup - creates a rollup starting from the end of the last rollup that has been stored on the L1 - CreateInternalRollup(fromBatchNo uint64, upToL1Height uint64, limiter limiters.RollupLimiter) (*core.Rollup, error) + CreateInternalRollup(ctx context.Context, fromBatchNo uint64, upToL1Height uint64, limiter limiters.RollupLimiter) (*core.Rollup, error) } type RollupConsumer interface { // ProcessRollupsInBlock - extracts the rollup from the block's transactions // and verifies its integrity, saving and processing any batches that have // not been seen previously. - ProcessRollupsInBlock(b *common.BlockAndReceipts) error + ProcessRollupsInBlock(ctx context.Context, b *common.BlockAndReceipts) error } diff --git a/go/enclave/components/rollup_compression.go b/go/enclave/components/rollup_compression.go index 6769b728d8..ef2156693a 100644 --- a/go/enclave/components/rollup_compression.go +++ b/go/enclave/components/rollup_compression.go @@ -1,6 +1,7 @@ package components import ( + "context" "errors" "fmt" "math/big" @@ -95,8 +96,8 @@ type batchFromRollup struct { } // CreateExtRollup - creates a compressed and encrypted External rollup from the internal data structure -func (rc *RollupCompression) CreateExtRollup(r *core.Rollup) (*common.ExtRollup, error) { - header, err := rc.createRollupHeader(r) +func (rc *RollupCompression) CreateExtRollup(ctx context.Context, r *core.Rollup) (*common.ExtRollup, error) { + header, err := rc.createRollupHeader(ctx, r) if err != nil { return nil, err } @@ -122,7 +123,7 @@ func (rc *RollupCompression) CreateExtRollup(r *core.Rollup) (*common.ExtRollup, } // ProcessExtRollup - given an External rollup, responsible with checking and saving all batches found inside -func (rc *RollupCompression) ProcessExtRollup(rollup *common.ExtRollup) (*common.CalldataRollupHeader, error) { +func (rc *RollupCompression) ProcessExtRollup(ctx context.Context, rollup *common.ExtRollup) (*common.CalldataRollupHeader, error) { transactionsPerBatch := make([][]*common.L2Tx, 0) err := rc.decryptDecompressAndDeserialise(rollup.BatchPayloads, &transactionsPerBatch) if err != nil { @@ -138,13 +139,13 @@ func (rc *RollupCompression) ProcessExtRollup(rollup *common.ExtRollup) (*common // The recreation of batches is a 2-step process: // 1. calculate fields like: sequence, height, time, l1Proof, from the implicit and explicit information from the metadata - incompleteBatches, err := rc.createIncompleteBatches(calldataRollupHeader, transactionsPerBatch, rollup.Header.CompressionL1Head) + incompleteBatches, err := rc.createIncompleteBatches(ctx, calldataRollupHeader, transactionsPerBatch, rollup.Header.CompressionL1Head) if err != nil { return nil, err } // 2. execute each batch to be able to calculate the hash which is necessary for the next batch as it is the parent. - err = rc.executeAndSaveIncompleteBatches(calldataRollupHeader, incompleteBatches) + err = rc.executeAndSaveIncompleteBatches(ctx, calldataRollupHeader, incompleteBatches) if err != nil { return nil, err } @@ -153,7 +154,7 @@ func (rc *RollupCompression) ProcessExtRollup(rollup *common.ExtRollup) (*common } // the main logic that goes from a list of batches to the rollup header -func (rc *RollupCompression) createRollupHeader(rollup *core.Rollup) (*common.CalldataRollupHeader, error) { +func (rc *RollupCompression) createRollupHeader(ctx context.Context, rollup *core.Rollup) (*common.CalldataRollupHeader, error) { batches := rollup.Batches reorgs := make([]*common.BatchHeader, len(batches)) @@ -168,7 +169,7 @@ func (rc *RollupCompression) createRollupHeader(rollup *core.Rollup) (*common.Ca batchHeaders := make([]*common.BatchHeader, len(batches)) // create an efficient structure to determine whether a batch is canonical - reorgedBatches, err := rc.storage.FetchNonCanonicalBatchesBetween(batches[0].SeqNo().Uint64(), batches[len(batches)-1].SeqNo().Uint64()) + reorgedBatches, err := rc.storage.FetchNonCanonicalBatchesBetween(ctx, batches[0].SeqNo().Uint64(), batches[len(batches)-1].SeqNo().Uint64()) if err != nil { return nil, err } @@ -264,14 +265,14 @@ func (rc *RollupCompression) createRollupHeader(rollup *core.Rollup) (*common.Ca } // the main logic to recreate the batches from the header. The logical pair of: `createRollupHeader` -func (rc *RollupCompression) createIncompleteBatches(calldataRollupHeader *common.CalldataRollupHeader, transactionsPerBatch [][]*common.L2Tx, compressionL1Head common.L1BlockHash) ([]*batchFromRollup, error) { +func (rc *RollupCompression) createIncompleteBatches(ctx context.Context, calldataRollupHeader *common.CalldataRollupHeader, transactionsPerBatch [][]*common.L2Tx, compressionL1Head common.L1BlockHash) ([]*batchFromRollup, error) { incompleteBatches := make([]*batchFromRollup, len(transactionsPerBatch)) startAtSeq := calldataRollupHeader.FirstBatchSequence.Int64() currentHeight := calldataRollupHeader.FirstCanonBatchHeight.Int64() - 1 currentTime := int64(calldataRollupHeader.StartTime) - rollupL1Block, err := rc.storage.FetchBlock(compressionL1Head) + rollupL1Block, err := rc.storage.FetchBlock(ctx, compressionL1Head) if err != nil { return nil, fmt.Errorf("can't find the block used for compression. Cause: %w", err) } @@ -283,7 +284,7 @@ func (rc *RollupCompression) createIncompleteBatches(calldataRollupHeader *commo // a cache of the l1 blocks used by the current rollup, indexed by their height l1BlocksAtHeight := make(map[uint64]*types.Block) - err = rc.calcL1AncestorsOfHeight(big.NewInt(int64(slices.Min(l1Heights))), rollupL1Block, l1BlocksAtHeight) + err = rc.calcL1AncestorsOfHeight(ctx, big.NewInt(int64(slices.Min(l1Heights))), rollupL1Block, l1BlocksAtHeight) if err != nil { return nil, err } @@ -386,23 +387,23 @@ func (rc *RollupCompression) calculateL1HeightsFromDeltas(calldataRollupHeader * return l1Heights, nil } -func (rc *RollupCompression) calcL1AncestorsOfHeight(fromHeight *big.Int, toBlock *types.Block, path map[uint64]*types.Block) error { +func (rc *RollupCompression) calcL1AncestorsOfHeight(ctx context.Context, fromHeight *big.Int, toBlock *types.Block, path map[uint64]*types.Block) error { path[toBlock.NumberU64()] = toBlock if toBlock.NumberU64() == fromHeight.Uint64() { return nil } - p, err := rc.storage.FetchBlock(toBlock.ParentHash()) + p, err := rc.storage.FetchBlock(ctx, toBlock.ParentHash()) if err != nil { return err } - return rc.calcL1AncestorsOfHeight(fromHeight, p, path) + return rc.calcL1AncestorsOfHeight(ctx, fromHeight, p, path) } -func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeader *common.CalldataRollupHeader, incompleteBatches []*batchFromRollup) error { //nolint:gocognit +func (rc *RollupCompression) executeAndSaveIncompleteBatches(ctx context.Context, calldataRollupHeader *common.CalldataRollupHeader, incompleteBatches []*batchFromRollup) error { //nolint:gocognit parentHash := calldataRollupHeader.FirstCanonParentHash if calldataRollupHeader.FirstBatchSequence.Uint64() != common.L2GenesisSeqNo { - _, err := rc.storage.FetchBatch(parentHash) + _, err := rc.storage.FetchBatch(ctx, parentHash) if err != nil { rc.logger.Error("Could not find batch mentioned in the rollup. This should not happen.", log.ErrKey, err) return err @@ -411,7 +412,7 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade for _, incompleteBatch := range incompleteBatches { // check whether the batch is already stored in the database - b, err := rc.storage.FetchBatchBySeqNo(incompleteBatch.seqNo.Uint64()) + b, err := rc.storage.FetchBatchBySeqNo(ctx, incompleteBatch.seqNo.Uint64()) if err == nil { // chain to a parent only if the batch is not a reorg if incompleteBatch.header == nil { @@ -427,15 +428,16 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade // this batch was re-orged case incompleteBatch.header != nil: - convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(incompleteBatch.header) + convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(ctx, incompleteBatch.header) if err != nil { return err } - err = rc.storage.StoreBatch(&core.Batch{ - Header: incompleteBatch.header, - Transactions: incompleteBatch.transactions, - }, convertedHeader.Hash()) + err = rc.storage.StoreBatch(ctx, + &core.Batch{ + Header: incompleteBatch.header, + Transactions: incompleteBatch.transactions, + }, convertedHeader.Hash()) if err != nil { return err } @@ -443,6 +445,7 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade // handle genesis case incompleteBatch.seqNo.Uint64() == common.L2GenesisSeqNo: genBatch, _, err := rc.batchExecutor.CreateGenesisState( + ctx, incompleteBatch.l1Proof, incompleteBatch.time, calldataRollupHeader.Coinbase, @@ -457,16 +460,16 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade rc.logger.Crit("Rollup decompression failure. The check hashes don't match") }*/ - convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(genBatch.Header) + convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(ctx, genBatch.Header) if err != nil { return err } - err = rc.storage.StoreBatch(genBatch, convertedHeader.Hash()) + err = rc.storage.StoreBatch(ctx, genBatch, convertedHeader.Hash()) if err != nil { return err } - err = rc.storage.StoreExecutedBatch(genBatch, nil) + err = rc.storage.StoreExecutedBatch(ctx, genBatch, nil) if err != nil { return err } @@ -478,7 +481,9 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade default: // transforms the incompleteBatch into a BatchHeader by executing the transactions // and then the info can be used to fill in the parent - computedBatch, err := rc.computeBatch(incompleteBatch.l1Proof, + computedBatch, err := rc.computeBatch( + ctx, + incompleteBatch.l1Proof, parentHash, incompleteBatch.transactions, incompleteBatch.time, @@ -499,16 +504,16 @@ func (rc *RollupCompression) executeAndSaveIncompleteBatches(calldataRollupHeade return fmt.Errorf("cannot commit stateDB for incoming valid batch seq=%d. Cause: %w", incompleteBatch.seqNo, err) } - convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(computedBatch.Batch.Header) + convertedHeader, err := rc.gethEncodingService.CreateEthHeaderForBatch(ctx, computedBatch.Batch.Header) if err != nil { return err } - err = rc.storage.StoreBatch(computedBatch.Batch, convertedHeader.Hash()) + err = rc.storage.StoreBatch(ctx, computedBatch.Batch, convertedHeader.Hash()) if err != nil { return err } - err = rc.storage.StoreExecutedBatch(computedBatch.Batch, computedBatch.Receipts) + err = rc.storage.StoreExecutedBatch(ctx, computedBatch.Batch, computedBatch.Receipts) if err != nil { return err } @@ -553,6 +558,7 @@ func (rc *RollupCompression) decryptDecompressAndDeserialise(blob []byte, obj an } func (rc *RollupCompression) computeBatch( + ctx context.Context, BlockPtr common.L1BlockHash, ParentPtr common.L2BatchHash, Transactions common.L2Transactions, @@ -561,16 +567,18 @@ func (rc *RollupCompression) computeBatch( Coinbase gethcommon.Address, BaseFee *big.Int, ) (*ComputedBatch, error) { - return rc.batchExecutor.ComputeBatch(&BatchExecutionContext{ - BlockPtr: BlockPtr, - ParentPtr: ParentPtr, - Transactions: Transactions, - AtTime: AtTime, - Creator: Coinbase, - ChainConfig: rc.chainConfig, - SequencerNo: SequencerNo, - BaseFee: big.NewInt(0).Set(BaseFee), - }, false) + return rc.batchExecutor.ComputeBatch( + ctx, + &BatchExecutionContext{ + BlockPtr: BlockPtr, + ParentPtr: ParentPtr, + Transactions: Transactions, + AtTime: AtTime, + Creator: Coinbase, + ChainConfig: rc.chainConfig, + SequencerNo: SequencerNo, + BaseFee: big.NewInt(0).Set(BaseFee), + }, false) } func transformToByteArray(reorgs []*common.BatchHeader) ([][]byte, error) { diff --git a/go/enclave/components/rollup_consumer.go b/go/enclave/components/rollup_consumer.go index 3c57aaf11f..21c977e5d5 100644 --- a/go/enclave/components/rollup_consumer.go +++ b/go/enclave/components/rollup_consumer.go @@ -1,6 +1,7 @@ package components import ( + "context" "fmt" "github.com/ten-protocol/go-ten/go/enclave/core" @@ -46,7 +47,7 @@ func NewRollupConsumer( } } -func (rc *rollupConsumerImpl) ProcessRollupsInBlock(b *common.BlockAndReceipts) error { +func (rc *rollupConsumerImpl) ProcessRollupsInBlock(ctx context.Context, b *common.BlockAndReceipts) error { defer core.LogMethodDuration(rc.logger, measure.NewStopwatch(), "Rollup consumer processed block", log.BlockHashKey, b.Block.Hash()) rollups := rc.extractRollups(b) @@ -60,12 +61,12 @@ func (rc *rollupConsumerImpl) ProcessRollupsInBlock(b *common.BlockAndReceipts) } for _, rollup := range rollups { - l1CompressionBlock, err := rc.storage.FetchBlock(rollup.Header.CompressionL1Head) + l1CompressionBlock, err := rc.storage.FetchBlock(ctx, rollup.Header.CompressionL1Head) if err != nil { rc.logger.Warn("Can't process rollup because the l1 block used for compression is not available", "block_hash", rollup.Header.CompressionL1Head, log.RollupHashKey, rollup.Hash(), log.ErrKey, err) continue } - canonicalBlockByHeight, err := rc.storage.FetchCanonicaBlockByHeight(l1CompressionBlock.Number()) + canonicalBlockByHeight, err := rc.storage.FetchCanonicaBlockByHeight(ctx, l1CompressionBlock.Number()) if err != nil { return err } @@ -74,13 +75,13 @@ func (rc *rollupConsumerImpl) ProcessRollupsInBlock(b *common.BlockAndReceipts) continue } // read batch data from rollup, verify and store it - internalHeader, err := rc.rollupCompression.ProcessExtRollup(rollup) + internalHeader, err := rc.rollupCompression.ProcessExtRollup(ctx, rollup) if err != nil { rc.logger.Error("Failed processing rollup", log.RollupHashKey, rollup.Hash(), log.ErrKey, err) // todo - issue challenge as a validator return err } - if err := rc.storage.StoreRollup(rollup, internalHeader); err != nil { + if err := rc.storage.StoreRollup(ctx, rollup, internalHeader); err != nil { rc.logger.Error("Failed storing rollup", log.RollupHashKey, rollup.Hash(), log.ErrKey, err) return err } diff --git a/go/enclave/components/rollup_producer.go b/go/enclave/components/rollup_producer.go index df9981ace5..17bb301461 100644 --- a/go/enclave/components/rollup_producer.go +++ b/go/enclave/components/rollup_producer.go @@ -1,6 +1,7 @@ package components import ( + "context" "fmt" "math/big" @@ -35,8 +36,8 @@ func NewRollupProducer(enclaveID gethcommon.Address, storage storage.Storage, ba } } -func (re *rollupProducerImpl) CreateInternalRollup(fromBatchNo uint64, upToL1Height uint64, limiter limiters.RollupLimiter) (*core.Rollup, error) { - batches, blocks, err := re.batchRegistry.BatchesAfter(fromBatchNo, upToL1Height, limiter) +func (re *rollupProducerImpl) CreateInternalRollup(ctx context.Context, fromBatchNo uint64, upToL1Height uint64, limiter limiters.RollupLimiter) (*core.Rollup, error) { + batches, blocks, err := re.batchRegistry.BatchesAfter(ctx, fromBatchNo, upToL1Height, limiter) if err != nil { return nil, fmt.Errorf("could not fetch 'from' batch (seqNo=%d) for rollup: %w", fromBatchNo, err) } @@ -47,7 +48,7 @@ func (re *rollupProducerImpl) CreateInternalRollup(fromBatchNo uint64, upToL1Hei return nil, fmt.Errorf("no batches for rollup") } - block, err := re.storage.FetchCanonicaBlockByHeight(big.NewInt(int64(upToL1Height))) + block, err := re.storage.FetchCanonicaBlockByHeight(ctx, big.NewInt(int64(upToL1Height))) if err != nil { return nil, err } diff --git a/go/enclave/components/shared_secret_process.go b/go/enclave/components/shared_secret_process.go index a0a04fed69..82daff355c 100644 --- a/go/enclave/components/shared_secret_process.go +++ b/go/enclave/components/shared_secret_process.go @@ -1,6 +1,7 @@ package components import ( + "context" "fmt" gethcommon "github.com/ethereum/go-ethereum/common" @@ -33,7 +34,7 @@ func NewSharedSecretProcessor(mgmtcontractlib mgmtcontractlib.MgmtContractLib, a } // ProcessNetworkSecretMsgs we watch for all messages that are requesting or receiving the secret and we store the nodes attested keys -func (ssp *SharedSecretProcessor) ProcessNetworkSecretMsgs(br *common.BlockAndReceipts) []*common.ProducedSecretResponse { +func (ssp *SharedSecretProcessor) ProcessNetworkSecretMsgs(ctx context.Context, br *common.BlockAndReceipts) []*common.ProducedSecretResponse { var responses []*common.ProducedSecretResponse transactions := br.SuccessfulTransactions() block := br.Block @@ -43,7 +44,7 @@ func (ssp *SharedSecretProcessor) ProcessNetworkSecretMsgs(br *common.BlockAndRe // this transaction is for a node that has joined the network and needs to be sent the network secret if scrtReqTx, ok := t.(*ethadapter.L1RequestSecretTx); ok { ssp.logger.Info("Process shared secret request.", log.BlockHeightKey, block.Number(), log.BlockHashKey, block.Hash(), log.TxKey, tx.Hash()) - resp, err := ssp.processSecretRequest(scrtReqTx) + resp, err := ssp.processSecretRequest(ctx, scrtReqTx) if err != nil { ssp.logger.Error("Failed to process shared secret request.", log.ErrKey, err) continue @@ -61,7 +62,7 @@ func (ssp *SharedSecretProcessor) ProcessNetworkSecretMsgs(br *common.BlockAndRe ssp.logger.Error("Could not decode attestation report", log.ErrKey, err) } - err = ssp.storeAttestation(att) + err = ssp.storeAttestation(ctx, att) if err != nil { ssp.logger.Error("Could not store the attestation report.", log.ErrKey, err) } @@ -70,20 +71,20 @@ func (ssp *SharedSecretProcessor) ProcessNetworkSecretMsgs(br *common.BlockAndRe return responses } -func (ssp *SharedSecretProcessor) processSecretRequest(req *ethadapter.L1RequestSecretTx) (*common.ProducedSecretResponse, error) { +func (ssp *SharedSecretProcessor) processSecretRequest(ctx context.Context, req *ethadapter.L1RequestSecretTx) (*common.ProducedSecretResponse, error) { att, err := common.DecodeAttestation(req.Attestation) if err != nil { return nil, fmt.Errorf("failed to decode attestation - %w", err) } ssp.logger.Info("received attestation", "attestation", att) - secret, err := ssp.verifyAttestationAndEncryptSecret(att) + secret, err := ssp.verifyAttestationAndEncryptSecret(ctx, att) if err != nil { return nil, fmt.Errorf("secret request failed, no response will be published - %w", err) } // Store the attested key only if the attestation process succeeded. - err = ssp.storeAttestation(att) + err = ssp.storeAttestation(ctx, att) if err != nil { return nil, fmt.Errorf("could not store attestation, no response will be published. Cause: %w", err) } @@ -99,7 +100,7 @@ func (ssp *SharedSecretProcessor) processSecretRequest(req *ethadapter.L1Request } // ShareSecret verifies the request and if it trusts the report and the public key it will return the secret encrypted with that public key. -func (ssp *SharedSecretProcessor) verifyAttestationAndEncryptSecret(att *common.AttestationReport) (common.EncryptedSharedEnclaveSecret, error) { +func (ssp *SharedSecretProcessor) verifyAttestationAndEncryptSecret(ctx context.Context, att *common.AttestationReport) (common.EncryptedSharedEnclaveSecret, error) { // First we verify the attestation report has come from a valid obscuro enclave running in a verified TEE. data, err := ssp.attestationProvider.VerifyReport(att) if err != nil { @@ -111,7 +112,7 @@ func (ssp *SharedSecretProcessor) verifyAttestationAndEncryptSecret(att *common. } ssp.logger.Info(fmt.Sprintf("Successfully verified attestation and identity. Owner: %s", att.EnclaveID)) - secret, err := ssp.storage.FetchSecret() + secret, err := ssp.storage.FetchSecret(ctx) if err != nil { return nil, fmt.Errorf("could not retrieve secret; this should not happen. Cause: %w", err) } @@ -119,14 +120,14 @@ func (ssp *SharedSecretProcessor) verifyAttestationAndEncryptSecret(att *common. } // storeAttestation stores the attested keys of other nodes so we can decrypt their rollups -func (ssp *SharedSecretProcessor) storeAttestation(att *common.AttestationReport) error { +func (ssp *SharedSecretProcessor) storeAttestation(ctx context.Context, att *common.AttestationReport) error { ssp.logger.Info(fmt.Sprintf("Store attestation. Owner: %s", att.EnclaveID)) // Store the attestation key, err := gethcrypto.DecompressPubkey(att.PubKey) if err != nil { return fmt.Errorf("failed to parse public key %w", err) } - err = ssp.storage.StoreAttestedKey(att.EnclaveID, key) + err = ssp.storage.StoreAttestedKey(ctx, att.EnclaveID, key) if err != nil { return fmt.Errorf("could not store attested key. Cause: %w", err) } diff --git a/go/enclave/crosschain/block_message_extractor.go b/go/enclave/crosschain/block_message_extractor.go index 4009a27c1f..7937cb7927 100644 --- a/go/enclave/crosschain/block_message_extractor.go +++ b/go/enclave/crosschain/block_message_extractor.go @@ -1,6 +1,7 @@ package crosschain import ( + "context" "fmt" "github.com/ten-protocol/go-ten/go/enclave/core" @@ -36,7 +37,7 @@ func (m *blockMessageExtractor) Enabled() bool { return m.GetBusAddress().Big().Cmp(gethcommon.Big0) != 0 } -func (m *blockMessageExtractor) StoreCrossChainValueTransfers(block *common.L1Block, receipts common.L1Receipts) error { +func (m *blockMessageExtractor) StoreCrossChainValueTransfers(ctx context.Context, block *common.L1Block, receipts common.L1Receipts) error { defer core.LogMethodDuration(m.logger, measure.NewStopwatch(), "Block value transfer messages processed", log.BlockHashKey, block.Hash()) /*areReceiptsValid := common.VerifyReceiptHash(block, receipts) @@ -62,7 +63,7 @@ func (m *blockMessageExtractor) StoreCrossChainValueTransfers(block *common.L1Bl } m.logger.Trace("Storing value transfers for block", "nr", len(transfers), log.BlockHashKey, block.Hash()) - err = m.storage.StoreValueTransfers(block.Hash(), transfers) + err = m.storage.StoreValueTransfers(ctx, block.Hash(), transfers) if err != nil { m.logger.Crit("Unable to store the transfers", log.ErrKey, err) return err @@ -75,7 +76,7 @@ func (m *blockMessageExtractor) StoreCrossChainValueTransfers(block *common.L1Bl // The messages will be stored in DB storage for later usage. // block - the L1 block for which events are extracted. // receipts - all of the receipts for the corresponding block. This is validated. -func (m *blockMessageExtractor) StoreCrossChainMessages(block *common.L1Block, receipts common.L1Receipts) error { +func (m *blockMessageExtractor) StoreCrossChainMessages(ctx context.Context, block *common.L1Block, receipts common.L1Receipts) error { defer core.LogMethodDuration(m.logger, measure.NewStopwatch(), "Block cross chain messages processed", log.BlockHashKey, block.Hash()) if len(receipts) == 0 { @@ -91,7 +92,7 @@ func (m *blockMessageExtractor) StoreCrossChainMessages(block *common.L1Block, r if len(messages) > 0 { m.logger.Info(fmt.Sprintf("Storing %d messages for block", len(messages)), log.BlockHashKey, block.Hash()) - err = m.storage.StoreL1Messages(block.Hash(), messages) + err = m.storage.StoreL1Messages(ctx, block.Hash(), messages) if err != nil { m.logger.Crit("Unable to store the messages", log.ErrKey, err) return err diff --git a/go/enclave/crosschain/interfaces.go b/go/enclave/crosschain/interfaces.go index f382b35491..dfa0bc7788 100644 --- a/go/enclave/crosschain/interfaces.go +++ b/go/enclave/crosschain/interfaces.go @@ -1,6 +1,8 @@ package crosschain import ( + "context" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ten-protocol/go-ten/go/common" @@ -14,9 +16,9 @@ type ( type BlockMessageExtractor interface { // StoreCrossChainMessages - Verifies receipts belong to block and saves the relevant cross chain messages from the receipts - StoreCrossChainMessages(block *common.L1Block, receipts common.L1Receipts) error + StoreCrossChainMessages(ctx context.Context, block *common.L1Block, receipts common.L1Receipts) error - StoreCrossChainValueTransfers(block *common.L1Block, receipts common.L1Receipts) error + StoreCrossChainValueTransfers(ctx context.Context, block *common.L1Block, receipts common.L1Receipts) error // GetBusAddress - Returns the L1 message bus address. GetBusAddress() *common.L1Address @@ -43,13 +45,13 @@ type Manager interface { GenerateMessageBusDeployTx() (*common.L2Tx, error) // ExtractOutboundMessages - Finds relevant logs in the receipts and converts them to cross chain messages. - ExtractOutboundMessages(receipts common.L2Receipts) (common.CrossChainMessages, error) + ExtractOutboundMessages(ctx context.Context, receipts common.L2Receipts) (common.CrossChainMessages, error) - ExtractOutboundTransfers(receipts common.L2Receipts) (common.ValueTransferEvents, error) + ExtractOutboundTransfers(ctx context.Context, receipts common.L2Receipts) (common.ValueTransferEvents, error) - CreateSyntheticTransactions(messages common.CrossChainMessages, rollupState *state.StateDB) common.L2Transactions + CreateSyntheticTransactions(ctx context.Context, messages common.CrossChainMessages, rollupState *state.StateDB) common.L2Transactions - ExecuteValueTransfers(transfers common.ValueTransferEvents, rollupState *state.StateDB) + ExecuteValueTransfers(ctx context.Context, transfers common.ValueTransferEvents, rollupState *state.StateDB) - RetrieveInboundMessages(fromBlock *common.L1Block, toBlock *common.L1Block, rollupState *state.StateDB) (common.CrossChainMessages, common.ValueTransferEvents) + RetrieveInboundMessages(ctx context.Context, fromBlock *common.L1Block, toBlock *common.L1Block, rollupState *state.StateDB) (common.CrossChainMessages, common.ValueTransferEvents) } diff --git a/go/enclave/crosschain/message_bus_manager.go b/go/enclave/crosschain/message_bus_manager.go index 71056b3169..a0d999922f 100644 --- a/go/enclave/crosschain/message_bus_manager.go +++ b/go/enclave/crosschain/message_bus_manager.go @@ -2,6 +2,7 @@ package crosschain import ( "bytes" + "context" "fmt" "math/big" @@ -105,7 +106,7 @@ func (m *MessageBusManager) GenerateMessageBusDeployTx() (*common.L2Tx, error) { } // ExtractLocalMessages - Finds relevant logs in the receipts and converts them to cross chain messages. -func (m *MessageBusManager) ExtractOutboundMessages(receipts common.L2Receipts) (common.CrossChainMessages, error) { +func (m *MessageBusManager) ExtractOutboundMessages(ctx context.Context, receipts common.L2Receipts) (common.CrossChainMessages, error) { logs, err := filterLogsFromReceipts(receipts, m.messageBusAddress, &CrossChainEventID) if err != nil { m.logger.Error("Error extracting logs from L2 message bus!", log.ErrKey, err) @@ -121,8 +122,8 @@ func (m *MessageBusManager) ExtractOutboundMessages(receipts common.L2Receipts) return messages, nil } -// ExtractLocalMessages - Finds relevant logs in the receipts and converts them to cross chain messages. -func (m *MessageBusManager) ExtractOutboundTransfers(receipts common.L2Receipts) (common.ValueTransferEvents, error) { +// ExtractOutboundTransfers - Finds relevant logs in the receipts and converts them to cross chain messages. +func (m *MessageBusManager) ExtractOutboundTransfers(_ context.Context, receipts common.L2Receipts) (common.ValueTransferEvents, error) { logs, err := filterLogsFromReceipts(receipts, m.messageBusAddress, &ValueTransferEventID) if err != nil { m.logger.Error("Error extracting logs from L2 message bus!", log.ErrKey, err) @@ -142,13 +143,13 @@ func (m *MessageBusManager) ExtractOutboundTransfers(receipts common.L2Receipts) // todo (@stefan) - fix ordering of messages, currently it is irrelevant. // todo (@stefan) - do not extract messages below their consistency level. Irrelevant security wise. // todo (@stefan) - surface errors -func (m *MessageBusManager) RetrieveInboundMessages(fromBlock *common.L1Block, toBlock *common.L1Block, _ *state.StateDB) (common.CrossChainMessages, common.ValueTransferEvents) { +func (m *MessageBusManager) RetrieveInboundMessages(ctx context.Context, fromBlock *common.L1Block, toBlock *common.L1Block, _ *state.StateDB) (common.CrossChainMessages, common.ValueTransferEvents) { messages := make(common.CrossChainMessages, 0) transfers := make(common.ValueTransferEvents, 0) from := fromBlock.Hash() height := fromBlock.NumberU64() - if !m.storage.IsAncestor(toBlock, fromBlock) { + if !m.storage.IsAncestor(ctx, toBlock, fromBlock) { m.logger.Crit("Synthetic transactions can't be processed because the rollups are not on the same Ethereum fork. This should not happen.") } // Iterate through the blocks. @@ -160,12 +161,12 @@ func (m *MessageBusManager) RetrieveInboundMessages(fromBlock *common.L1Block, t m.logger.Trace(fmt.Sprintf("Looking for cross chain messages at block %s", b.Hash().Hex())) - messagesForBlock, err := m.storage.GetL1Messages(b.Hash()) + messagesForBlock, err := m.storage.GetL1Messages(ctx, b.Hash()) if err != nil { m.logger.Crit("Reading the key for the block failed with uncommon reason.", log.ErrKey, err) } - transfersForBlock, err := m.storage.GetL1Transfers(b.Hash()) + transfersForBlock, err := m.storage.GetL1Transfers(ctx, b.Hash()) if err != nil { m.logger.Crit("Unable to get L1 transfers for block that should be there.", log.ErrKey, err) } @@ -177,7 +178,7 @@ func (m *MessageBusManager) RetrieveInboundMessages(fromBlock *common.L1Block, t if b.NumberU64() < height { m.logger.Crit("block height is less than genesis height") } - p, err := m.storage.FetchBlock(b.ParentHash()) + p, err := m.storage.FetchBlock(ctx, b.ParentHash()) if err != nil { m.logger.Crit("Synthetic transactions can't be processed because the rollups are not on the same Ethereum fork") } @@ -193,14 +194,14 @@ func (m *MessageBusManager) RetrieveInboundMessages(fromBlock *common.L1Block, t return messages, transfers } -func (m *MessageBusManager) ExecuteValueTransfers(transfers common.ValueTransferEvents, rollupState *state.StateDB) { +func (m *MessageBusManager) ExecuteValueTransfers(ctx context.Context, transfers common.ValueTransferEvents, rollupState *state.StateDB) { for _, transfer := range transfers { rollupState.AddBalance(transfer.Receiver, transfer.Amount) } } // CreateSyntheticTransactions - generates transactions that the enclave should execute internally for the messages. -func (m *MessageBusManager) CreateSyntheticTransactions(messages common.CrossChainMessages, rollupState *state.StateDB) common.L2Transactions { +func (m *MessageBusManager) CreateSyntheticTransactions(ctx context.Context, messages common.CrossChainMessages, rollupState *state.StateDB) common.L2Transactions { // Get current nonce for this stateDB. // There can be forks thus we cannot trust the wallet. startingNonce := rollupState.GetNonce(m.GetOwner()) diff --git a/go/enclave/debugger/tracers.go b/go/enclave/debugger/tracers.go index 01acb11906..899d98ce5f 100644 --- a/go/enclave/debugger/tracers.go +++ b/go/enclave/debugger/tracers.go @@ -38,8 +38,8 @@ func New(chain l2chain.ObscuroChain, storage storage.Storage, config *params.Cha } } -func (d *Debugger) DebugEventLogRelevancy(txHash gethcommon.Hash) (json.RawMessage, error) { - logs, err := d.storage.DebugGetLogs(txHash) +func (d *Debugger) DebugEventLogRelevancy(ctx context.Context, txHash gethcommon.Hash) (json.RawMessage, error) { + logs, err := d.storage.DebugGetLogs(ctx, txHash) if err != nil { return nil, err } diff --git a/go/enclave/enclave.go b/go/enclave/enclave.go index ee73c7b773..42d62e44c9 100644 --- a/go/enclave/enclave.go +++ b/go/enclave/enclave.go @@ -139,7 +139,7 @@ func NewEnclave( } // attempt to fetch the enclave key from the database - enclaveKey, err := storage.GetEnclaveKey() + enclaveKey, err := storage.GetEnclaveKey(context.Background()) if err != nil { if !errors.Is(err, errutil.ErrNotFound) { logger.Crit("Failed to fetch enclave key", log.ErrKey, err) @@ -151,7 +151,7 @@ func NewEnclave( if err != nil { logger.Crit("Failed to generate enclave key.", log.ErrKey, err) } - err = storage.StoreEnclaveKey(enclaveKey) + err = storage.StoreEnclaveKey(context.Background(), enclaveKey) if err != nil { logger.Crit("Failed to store enclave key.", log.ErrKey, err) } @@ -168,7 +168,7 @@ func NewEnclave( gasOracle := gas.NewGasOracle() blockProcessor := components.NewBlockProcessor(storage, crossChainProcessors, gasOracle, logger) - batchExecutor := components.NewBatchExecutor(storage, gethEncodingService, crossChainProcessors, genesis, gasOracle, chainConfig, config.GasBatchExecutionLimit, logger) + batchExecutor := components.NewBatchExecutor(storage, *config, gethEncodingService, crossChainProcessors, genesis, gasOracle, chainConfig, config.GasBatchExecutionLimit, logger) sigVerifier, err := components.NewSignatureValidator(config.SequencerID, storage) registry := components.NewBatchRegistry(storage, logger) rProducer := components.NewRollupProducer(enclaveKey.EnclaveID(), storage, registry, logger) @@ -179,7 +179,7 @@ func NewEnclave( rConsumer := components.NewRollupConsumer(mgmtContractLib, registry, rollupCompression, storage, logger, sigVerifier) sharedSecretProcessor := components.NewSharedSecretProcessor(mgmtContractLib, attestationProvider, enclaveKey.EnclaveID(), storage, logger) - blockchain := ethchainadapter.NewEthChainAdapter(big.NewInt(config.ObscuroChainID), registry, storage, gethEncodingService, logger) + blockchain := ethchainadapter.NewEthChainAdapter(big.NewInt(config.ObscuroChainID), registry, storage, gethEncodingService, *config, logger) mempool, err := txpool.NewTxPool(blockchain, config.MinGasPrice, logger) if err != nil { logger.Crit("unable to init eth tx pool", log.ErrKey, err) @@ -217,6 +217,7 @@ func NewEnclave( chain := l2chain.NewChain( storage, + *config, gethEncodingService, chainConfig, genesis, @@ -228,7 +229,7 @@ func NewEnclave( subscriptionManager := events.NewSubscriptionManager(storage, config.ObscuroChainID, logger) // ensure cached chain state data is up-to-date using the persisted batch data - err = restoreStateDBCache(storage, registry, batchExecutor, genesis, logger) + err = restoreStateDBCache(context.Background(), storage, registry, batchExecutor, genesis, logger) if err != nil { logger.Crit("failed to resync L2 chain state DB after restart", log.ErrKey, err) } @@ -268,8 +269,8 @@ func NewEnclave( } } -func (e *enclaveImpl) GetBatch(hash common.L2BatchHash) (*common.ExtBatch, common.SystemError) { - batch, err := e.storage.FetchBatch(hash) +func (e *enclaveImpl) GetBatch(ctx context.Context, hash common.L2BatchHash) (*common.ExtBatch, common.SystemError) { + batch, err := e.storage.FetchBatch(ctx, hash) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("failed getting batch. Cause: %w", err)) } @@ -281,8 +282,8 @@ func (e *enclaveImpl) GetBatch(hash common.L2BatchHash) (*common.ExtBatch, commo return b, nil } -func (e *enclaveImpl) GetBatchBySeqNo(seqNo uint64) (*common.ExtBatch, common.SystemError) { - batch, err := e.storage.FetchBatchBySeqNo(seqNo) +func (e *enclaveImpl) GetBatchBySeqNo(ctx context.Context, seqNo uint64) (*common.ExtBatch, common.SystemError) { + batch, err := e.storage.FetchBatchBySeqNo(ctx, seqNo) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("failed getting batch. Cause: %w", err)) } @@ -294,8 +295,8 @@ func (e *enclaveImpl) GetBatchBySeqNo(seqNo uint64) (*common.ExtBatch, common.Sy return b, nil } -func (e *enclaveImpl) GetRollupData(hash common.L2RollupHash) (*common.PublicRollupMetadata, common.SystemError) { - rollupMetadata, err := e.storage.FetchRollupMetadata(hash) +func (e *enclaveImpl) GetRollupData(ctx context.Context, hash common.L2RollupHash) (*common.PublicRollupMetadata, common.SystemError) { + rollupMetadata, err := e.storage.FetchRollupMetadata(ctx, hash) if err != nil { return nil, err } @@ -307,12 +308,12 @@ func (e *enclaveImpl) GetRollupData(hash common.L2RollupHash) (*common.PublicRol } // Status is only implemented by the RPC wrapper -func (e *enclaveImpl) Status() (common.Status, common.SystemError) { +func (e *enclaveImpl) Status(ctx context.Context) (common.Status, common.SystemError) { if e.stopControl.IsStopping() { return common.Status{StatusCode: common.Unavailable}, responses.ToInternalError(fmt.Errorf("requested Status with the enclave stopping")) } - _, err := e.storage.FetchSecret() + _, err := e.storage.FetchSecret(ctx) if err != nil { if errors.Is(err, errutil.ErrNotFound) { return common.Status{StatusCode: common.AwaitingSecret, L2Head: _noHeadBatch}, nil @@ -320,7 +321,7 @@ func (e *enclaveImpl) Status() (common.Status, common.SystemError) { return common.Status{StatusCode: common.Unavailable}, responses.ToInternalError(err) } var l1HeadHash gethcommon.Hash - l1Head, err := e.l1BlockProcessor.GetHead() + l1Head, err := e.l1BlockProcessor.GetHead(ctx) if err != nil { // this might be normal while enclave is starting up, just send empty hash e.logger.Debug("failed to fetch L1 head block for status response", log.ErrKey, err) @@ -330,7 +331,7 @@ func (e *enclaveImpl) Status() (common.Status, common.SystemError) { // we use zero when there's no head batch yet, the first seq number is 1 l2HeadSeqNo := _noHeadBatch // this is the highest seq number that has been received and stored on the enclave (it may not have been executed) - currSeqNo, err := e.storage.FetchCurrentSequencerNo() + currSeqNo, err := e.storage.FetchCurrentSequencerNo(ctx) if err != nil { // this might be normal while enclave is starting up, just send empty hash e.logger.Debug("failed to fetch L2 head batch for status response", log.ErrKey, err) @@ -359,8 +360,8 @@ func (e *enclaveImpl) sendBatch(batch *core.Batch, outChannel chan common.Stream } // this function is only called when the executed batch is the new head -func (e *enclaveImpl) streamEventsForNewHeadBatch(batch *core.Batch, receipts types.Receipts, outChannel chan common.StreamL2UpdatesResponse) { - logs, err := e.subscriptionManager.GetSubscribedLogsForBatch(batch, receipts) +func (e *enclaveImpl) streamEventsForNewHeadBatch(ctx context.Context, batch *core.Batch, receipts types.Receipts, outChannel chan common.StreamL2UpdatesResponse) { + logs, err := e.subscriptionManager.GetSubscribedLogsForBatch(ctx, batch, receipts) e.logger.Debug("Stream Events for", log.BatchHashKey, batch.Hash(), "nr_events", len(logs)) if err != nil { e.logger.Error("Error while getting subscription logs", log.ErrKey, err) @@ -384,7 +385,7 @@ func (e *enclaveImpl) StreamL2Updates() (chan common.StreamL2UpdatesResponse, fu e.registry.SubscribeForExecutedBatches(func(batch *core.Batch, receipts types.Receipts) { e.sendBatch(batch, l2UpdatesChannel) if receipts != nil { - e.streamEventsForNewHeadBatch(batch, receipts, l2UpdatesChannel) + e.streamEventsForNewHeadBatch(context.Background(), batch, receipts, l2UpdatesChannel) } }) @@ -394,7 +395,7 @@ func (e *enclaveImpl) StreamL2Updates() (chan common.StreamL2UpdatesResponse, fu } // SubmitL1Block is used to update the enclave with an additional L1 block. -func (e *enclaveImpl) SubmitL1Block(block types.Block, receipts types.Receipts, _ bool) (*common.BlockSubmissionResponse, common.SystemError) { +func (e *enclaveImpl) SubmitL1Block(ctx context.Context, block common.L1Block, receipts common.L1Receipts, isLatest bool) (*common.BlockSubmissionResponse, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested SubmitL1Block with the enclave stopping")) } @@ -407,30 +408,30 @@ func (e *enclaveImpl) SubmitL1Block(block types.Block, receipts types.Receipts, // If the block and receipts do not match, reject the block. br, err := common.ParseBlockAndReceipts(&block, &receipts) if err != nil { - return nil, e.rejectBlockErr(fmt.Errorf("could not submit L1 block. Cause: %w", err)) + return nil, e.rejectBlockErr(ctx, fmt.Errorf("could not submit L1 block. Cause: %w", err)) } - result, err := e.ingestL1Block(br) + result, err := e.ingestL1Block(ctx, br) if err != nil { - return nil, e.rejectBlockErr(fmt.Errorf("could not submit L1 block. Cause: %w", err)) + return nil, e.rejectBlockErr(ctx, fmt.Errorf("could not submit L1 block. Cause: %w", err)) } if result.IsFork() { e.logger.Info(fmt.Sprintf("Detected fork at block %s with height %d", block.Hash(), block.Number())) } - err = e.service.OnL1Block(block, result) + err = e.service.OnL1Block(ctx, block, result) if err != nil { - return nil, e.rejectBlockErr(fmt.Errorf("could not submit L1 block. Cause: %w", err)) + return nil, e.rejectBlockErr(ctx, fmt.Errorf("could not submit L1 block. Cause: %w", err)) } - bsr := &common.BlockSubmissionResponse{ProducedSecretResponses: e.sharedSecretProcessor.ProcessNetworkSecretMsgs(br)} + bsr := &common.BlockSubmissionResponse{ProducedSecretResponses: e.sharedSecretProcessor.ProcessNetworkSecretMsgs(ctx, br)} return bsr, nil } -func (e *enclaveImpl) ingestL1Block(br *common.BlockAndReceipts) (*components.BlockIngestionType, error) { +func (e *enclaveImpl) ingestL1Block(ctx context.Context, br *common.BlockAndReceipts) (*components.BlockIngestionType, error) { e.logger.Info("Start ingesting block", log.BlockHashKey, br.Block.Hash()) - ingestion, err := e.l1BlockProcessor.Process(br) + ingestion, err := e.l1BlockProcessor.Process(ctx, br) if err != nil { // only warn for unexpected errors if errors.Is(err, errutil.ErrBlockAncestorNotFound) || errors.Is(err, errutil.ErrBlockAlreadyProcessed) { @@ -441,14 +442,14 @@ func (e *enclaveImpl) ingestL1Block(br *common.BlockAndReceipts) (*components.Bl return nil, err } - err = e.rollupConsumer.ProcessRollupsInBlock(br) + err = e.rollupConsumer.ProcessRollupsInBlock(ctx, br) if err != nil && !errors.Is(err, components.ErrDuplicateRollup) { e.logger.Error("Encountered error while processing l1 block", log.ErrKey, err) // Unsure what to do here; block has been stored } if ingestion.IsFork() { - err := e.service.OnL1Fork(ingestion.ChainFork) + err := e.service.OnL1Fork(ctx, ingestion.ChainFork) if err != nil { return nil, err } @@ -456,11 +457,11 @@ func (e *enclaveImpl) ingestL1Block(br *common.BlockAndReceipts) (*components.Bl return ingestion, nil } -func (e *enclaveImpl) SubmitTx(encryptedTxParams common.EncryptedTx) (*responses.RawTx, common.SystemError) { +func (e *enclaveImpl) SubmitTx(ctx context.Context, encryptedTxParams common.EncryptedTx) (*responses.RawTx, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested SubmitTx with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedTxParams, rpc.SubmitTxValidate, rpc.SubmitTxExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedTxParams, rpc.SubmitTxValidate, rpc.SubmitTxExecute) } func (e *enclaveImpl) Validator() nodetype.ObsValidator { @@ -479,7 +480,7 @@ func (e *enclaveImpl) Sequencer() nodetype.Sequencer { return sequencer } -func (e *enclaveImpl) SubmitBatch(extBatch *common.ExtBatch) common.SystemError { +func (e *enclaveImpl) SubmitBatch(ctx context.Context, extBatch *common.ExtBatch) common.SystemError { if e.stopControl.IsStopping() { return responses.ToInternalError(fmt.Errorf("requested SubmitBatch with the enclave stopping")) } @@ -489,7 +490,7 @@ func (e *enclaveImpl) SubmitBatch(extBatch *common.ExtBatch) common.SystemError e.logger.Info("Received new p2p batch", log.BatchHeightKey, extBatch.Header.Number, log.BatchHashKey, extBatch.Hash(), "l1", extBatch.Header.L1Proof) seqNo := extBatch.Header.SequencerOrderNo.Uint64() if seqNo > common.L2GenesisSeqNo+1 { - _, err := e.storage.FetchBatchBySeqNo(seqNo - 1) + _, err := e.storage.FetchBatchBySeqNo(ctx, seqNo-1) if err != nil { return responses.ToInternalError(fmt.Errorf("could not find previous batch with seq: %d", seqNo-1)) } @@ -506,7 +507,7 @@ func (e *enclaveImpl) SubmitBatch(extBatch *common.ExtBatch) common.SystemError } // calculate the converted hash, and store it in the db for chaining of the converted chain - convertedHeader, err := e.gethEncodingService.CreateEthHeaderForBatch(extBatch.Header) + convertedHeader, err := e.gethEncodingService.CreateEthHeaderForBatch(ctx, extBatch.Header) if err != nil { return err } @@ -515,12 +516,12 @@ func (e *enclaveImpl) SubmitBatch(extBatch *common.ExtBatch) common.SystemError defer e.mainMutex.Unlock() // if the signature is valid, then store the batch together with the converted hash - err = e.storage.StoreBatch(batch, convertedHeader.Hash()) + err = e.storage.StoreBatch(ctx, batch, convertedHeader.Hash()) if err != nil { return responses.ToInternalError(fmt.Errorf("could not store batch. Cause: %w", err)) } - err = e.Validator().ExecuteStoredBatches() + err = e.Validator().ExecuteStoredBatches(ctx) if err != nil { return responses.ToInternalError(fmt.Errorf("could not execute batches. Cause: %w", err)) } @@ -528,7 +529,7 @@ func (e *enclaveImpl) SubmitBatch(extBatch *common.ExtBatch) common.SystemError return nil } -func (e *enclaveImpl) CreateBatch(skipBatchIfEmpty bool) common.SystemError { +func (e *enclaveImpl) CreateBatch(ctx context.Context, skipBatchIfEmpty bool) common.SystemError { defer core.LogMethodDuration(e.logger, measure.NewStopwatch(), "CreateBatch call ended") if e.stopControl.IsStopping() { return responses.ToInternalError(fmt.Errorf("requested CreateBatch with the enclave stopping")) @@ -537,7 +538,7 @@ func (e *enclaveImpl) CreateBatch(skipBatchIfEmpty bool) common.SystemError { e.mainMutex.Lock() defer e.mainMutex.Unlock() - err := e.Sequencer().CreateBatch(skipBatchIfEmpty) + err := e.Sequencer().CreateBatch(ctx, skipBatchIfEmpty) if err != nil { return responses.ToInternalError(err) } @@ -545,7 +546,7 @@ func (e *enclaveImpl) CreateBatch(skipBatchIfEmpty bool) common.SystemError { return nil } -func (e *enclaveImpl) CreateRollup(fromSeqNo uint64) (*common.ExtRollup, common.SystemError) { +func (e *enclaveImpl) CreateRollup(ctx context.Context, fromSeqNo uint64) (*common.ExtRollup, common.SystemError) { defer core.LogMethodDuration(e.logger, measure.NewStopwatch(), "CreateRollup call ended") if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GenerateRollup with the enclave stopping")) @@ -559,7 +560,7 @@ func (e *enclaveImpl) CreateRollup(fromSeqNo uint64) (*common.ExtRollup, common. return nil, responses.ToInternalError(fmt.Errorf("not initialised yet")) } - rollup, err := e.Sequencer().CreateRollup(fromSeqNo) + rollup, err := e.Sequencer().CreateRollup(ctx, fromSeqNo) if err != nil { return nil, responses.ToInternalError(err) } @@ -568,39 +569,39 @@ func (e *enclaveImpl) CreateRollup(fromSeqNo uint64) (*common.ExtRollup, common. // ObsCall handles param decryption, validation and encryption // and requests the Rollup chain to execute the payload (eth_call) -func (e *enclaveImpl) ObsCall(encryptedParams common.EncryptedParamsCall) (*responses.Call, common.SystemError) { +func (e *enclaveImpl) ObsCall(ctx context.Context, encryptedParams common.EncryptedParamsCall) (*responses.Call, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested ObsCall with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.TenCallValidate, rpc.TenCallExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.TenCallValidate, rpc.TenCallExecute) } -func (e *enclaveImpl) GetTransactionCount(encryptedParams common.EncryptedParamsGetTxCount) (*responses.TxCount, common.SystemError) { +func (e *enclaveImpl) GetTransactionCount(ctx context.Context, encryptedParams common.EncryptedParamsGetTxCount) (*responses.TxCount, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetTransactionCount with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionCountValidate, rpc.GetTransactionCountExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionCountValidate, rpc.GetTransactionCountExecute) } -func (e *enclaveImpl) GetTransaction(encryptedParams common.EncryptedParamsGetTxByHash) (*responses.TxByHash, common.SystemError) { +func (e *enclaveImpl) GetTransaction(ctx context.Context, encryptedParams common.EncryptedParamsGetTxByHash) (*responses.TxByHash, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetTransaction with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionValidate, rpc.GetTransactionExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionValidate, rpc.GetTransactionExecute) } -func (e *enclaveImpl) GetTransactionReceipt(encryptedParams common.EncryptedParamsGetTxReceipt) (*responses.TxReceipt, common.SystemError) { +func (e *enclaveImpl) GetTransactionReceipt(ctx context.Context, encryptedParams common.EncryptedParamsGetTxReceipt) (*responses.TxReceipt, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetTransactionReceipt with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionReceiptValidate, rpc.GetTransactionReceiptExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetTransactionReceiptValidate, rpc.GetTransactionReceiptExecute) } -func (e *enclaveImpl) Attestation() (*common.AttestationReport, common.SystemError) { +func (e *enclaveImpl) Attestation(ctx context.Context) (*common.AttestationReport, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested ObsCall with the enclave stopping")) } @@ -608,7 +609,7 @@ func (e *enclaveImpl) Attestation() (*common.AttestationReport, common.SystemErr if e.enclaveKey == nil { return nil, responses.ToInternalError(fmt.Errorf("public key not initialized, we can't produce the attestation report")) } - report, err := e.attestationProvider.GetReport(e.enclaveKey.PublicKeyBytes(), e.enclaveKey.EnclaveID(), e.config.HostAddress) + report, err := e.attestationProvider.GetReport(ctx, e.enclaveKey.PublicKeyBytes(), e.enclaveKey.EnclaveID(), e.config.HostAddress) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("could not produce remote report. Cause %w", err)) } @@ -616,13 +617,13 @@ func (e *enclaveImpl) Attestation() (*common.AttestationReport, common.SystemErr } // GenerateSecret - the genesis enclave is responsible with generating the secret entropy -func (e *enclaveImpl) GenerateSecret() (common.EncryptedSharedEnclaveSecret, common.SystemError) { +func (e *enclaveImpl) GenerateSecret(ctx context.Context) (common.EncryptedSharedEnclaveSecret, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GenerateSecret with the enclave stopping")) } secret := crypto.GenerateEntropy(e.logger) - err := e.storage.StoreSecret(secret) + err := e.storage.StoreSecret(ctx, secret) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("could not store secret. Cause: %w", err)) } @@ -634,7 +635,7 @@ func (e *enclaveImpl) GenerateSecret() (common.EncryptedSharedEnclaveSecret, com } // InitEnclave - initialise an enclave with a seed received by another enclave -func (e *enclaveImpl) InitEnclave(s common.EncryptedSharedEnclaveSecret) common.SystemError { +func (e *enclaveImpl) InitEnclave(ctx context.Context, s common.EncryptedSharedEnclaveSecret) common.SystemError { if e.stopControl.IsStopping() { return responses.ToInternalError(fmt.Errorf("requested InitEnclave with the enclave stopping")) } @@ -643,7 +644,7 @@ func (e *enclaveImpl) InitEnclave(s common.EncryptedSharedEnclaveSecret) common. if err != nil { return responses.ToInternalError(err) } - err = e.storage.StoreSecret(*secret) + err = e.storage.StoreSecret(ctx, *secret) if err != nil { return responses.ToInternalError(fmt.Errorf("could not store secret. Cause: %w", err)) } @@ -651,34 +652,34 @@ func (e *enclaveImpl) InitEnclave(s common.EncryptedSharedEnclaveSecret) common. return nil } -func (e *enclaveImpl) EnclaveID() (common.EnclaveID, common.SystemError) { +func (e *enclaveImpl) EnclaveID(context.Context) (common.EnclaveID, common.SystemError) { return e.enclaveKey.EnclaveID(), nil } // GetBalance handles param decryption, validation and encryption // and requests the Rollup chain to execute the payload (eth_getBalance) -func (e *enclaveImpl) GetBalance(encryptedParams common.EncryptedParamsGetBalance) (*responses.Balance, common.SystemError) { +func (e *enclaveImpl) GetBalance(ctx context.Context, encryptedParams common.EncryptedParamsGetBalance) (*responses.Balance, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetBalance with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetBalanceValidate, rpc.GetBalanceExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetBalanceValidate, rpc.GetBalanceExecute) } // todo - needs to be encrypted -func (e *enclaveImpl) GetCode(address gethcommon.Address, batchHash *common.L2BatchHash) ([]byte, common.SystemError) { +func (e *enclaveImpl) GetCode(ctx context.Context, address gethcommon.Address, batchHash *gethcommon.Hash) ([]byte, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetCode with the enclave stopping")) } - stateDB, err := e.storage.CreateStateDB(*batchHash) + stateDB, err := e.storage.CreateStateDB(ctx, *batchHash) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("could not create stateDB. Cause: %w", err)) } return stateDB.GetCode(address), nil } -func (e *enclaveImpl) Subscribe(id gethrpc.ID, encryptedSubscription common.EncryptedParamsLogSubscription) common.SystemError { +func (e *enclaveImpl) Subscribe(ctx context.Context, id gethrpc.ID, encryptedSubscription common.EncryptedParamsLogSubscription) common.SystemError { if e.stopControl.IsStopping() { return responses.ToInternalError(fmt.Errorf("requested SubscribeForExecutedBatches with the enclave stopping")) } @@ -732,30 +733,30 @@ func (e *enclaveImpl) Stop() common.SystemError { // EstimateGas decrypts CallMsg data, runs the gas estimation for the data. // Using the callMsg.From Viewing Key, returns the encrypted gas estimation -func (e *enclaveImpl) EstimateGas(encryptedParams common.EncryptedParamsEstimateGas) (*responses.Gas, common.SystemError) { +func (e *enclaveImpl) EstimateGas(ctx context.Context, encryptedParams common.EncryptedParamsEstimateGas) (*responses.Gas, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested EstimateGas with the enclave stopping")) } defer core.LogMethodDuration(e.logger, measure.NewStopwatch(), "enclave.go:EstimateGas()") - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.EstimateGasValidate, rpc.EstimateGasExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.EstimateGasValidate, rpc.EstimateGasExecute) } -func (e *enclaveImpl) GetLogs(encryptedParams common.EncryptedParamsGetLogs) (*responses.Logs, common.SystemError) { +func (e *enclaveImpl) GetLogs(ctx context.Context, encryptedParams common.EncryptedParamsGetLogs) (*responses.Logs, common.SystemError) { if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetLogs with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetLogsValidate, rpc.GetLogsExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetLogsValidate, rpc.GetLogsExecute) } // HealthCheck returns whether the enclave is deemed healthy -func (e *enclaveImpl) HealthCheck() (bool, common.SystemError) { +func (e *enclaveImpl) HealthCheck(ctx context.Context) (bool, common.SystemError) { if e.stopControl.IsStopping() { return false, responses.ToInternalError(fmt.Errorf("requested HealthCheck with the enclave stopping")) } // check the storage health - storageHealthy, err := e.storage.HealthCheck() + storageHealthy, err := e.storage.HealthCheck(ctx) if err != nil { // simplest iteration, log the error and just return that it's not healthy e.logger.Info("HealthCheck failed for the enclave storage", log.ErrKey, err) @@ -780,7 +781,7 @@ func (e *enclaveImpl) HealthCheck() (bool, common.SystemError) { return storageHealthy && l1blockHealthy && l2batchHealthy, nil } -func (e *enclaveImpl) DebugTraceTransaction(txHash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, common.SystemError) { +func (e *enclaveImpl) DebugTraceTransaction(ctx context.Context, txHash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, common.SystemError) { // ensure the enclave is running if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested DebugTraceTransaction with the enclave stopping")) @@ -791,7 +792,7 @@ func (e *enclaveImpl) DebugTraceTransaction(txHash gethcommon.Hash, config *trac return nil, responses.ToInternalError(fmt.Errorf("debug namespace not enabled")) } - jsonMsg, err := e.debugger.DebugTraceTransaction(context.Background(), txHash, config) + jsonMsg, err := e.debugger.DebugTraceTransaction(ctx, txHash, config) if err != nil { if errors.Is(err, syserr.InternalError{}) { return nil, responses.ToInternalError(err) @@ -803,7 +804,7 @@ func (e *enclaveImpl) DebugTraceTransaction(txHash gethcommon.Hash, config *trac return jsonMsg, nil } -func (e *enclaveImpl) DebugEventLogRelevancy(txHash gethcommon.Hash) (json.RawMessage, common.SystemError) { +func (e *enclaveImpl) DebugEventLogRelevancy(ctx context.Context, txHash gethcommon.Hash) (json.RawMessage, common.SystemError) { // ensure the enclave is running if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested DebugEventLogRelevancy with the enclave stopping")) @@ -814,7 +815,7 @@ func (e *enclaveImpl) DebugEventLogRelevancy(txHash gethcommon.Hash) (json.RawMe return nil, responses.ToInternalError(fmt.Errorf("debug namespace not enabled")) } - jsonMsg, err := e.debugger.DebugEventLogRelevancy(txHash) + jsonMsg, err := e.debugger.DebugEventLogRelevancy(ctx, txHash) if err != nil { if errors.Is(err, syserr.InternalError{}) { return nil, responses.ToInternalError(err) @@ -826,37 +827,37 @@ func (e *enclaveImpl) DebugEventLogRelevancy(txHash gethcommon.Hash) (json.RawMe return jsonMsg, nil } -func (e *enclaveImpl) GetTotalContractCount() (*big.Int, common.SystemError) { +func (e *enclaveImpl) GetTotalContractCount(ctx context.Context) (*big.Int, common.SystemError) { // ensure the enclave is running if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetTotalContractCount with the enclave stopping")) } - return e.storage.GetContractCount() + return e.storage.GetContractCount(ctx) } -func (e *enclaveImpl) GetCustomQuery(encryptedParams common.EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, common.SystemError) { +func (e *enclaveImpl) GetCustomQuery(ctx context.Context, encryptedParams common.EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, common.SystemError) { // ensure the enclave is running if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetReceiptsByAddress with the enclave stopping")) } - return rpc.WithVKEncryption(e.rpcEncryptionManager, encryptedParams, rpc.GetCustomQueryValidate, rpc.GetCustomQueryExecute) + return rpc.WithVKEncryption(ctx, e.rpcEncryptionManager, encryptedParams, rpc.GetCustomQueryValidate, rpc.GetCustomQueryExecute) } -func (e *enclaveImpl) GetPublicTransactionData(pagination *common.QueryPagination) (*common.TransactionListingResponse, common.SystemError) { +func (e *enclaveImpl) GetPublicTransactionData(ctx context.Context, pagination *common.QueryPagination) (*common.TransactionListingResponse, common.SystemError) { // ensure the enclave is running if e.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested GetPublicTransactionData with the enclave stopping")) } - paginatedData, err := e.storage.GetPublicTransactionData(pagination) + paginatedData, err := e.storage.GetPublicTransactionData(ctx, pagination) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("unable to fetch data - %w", err)) } // Todo eventually make this a cacheable method - totalData, err := e.storage.GetPublicTransactionCount() + totalData, err := e.storage.GetPublicTransactionCount(ctx) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("unable to fetch data - %w", err)) } @@ -867,7 +868,7 @@ func (e *enclaveImpl) GetPublicTransactionData(pagination *common.QueryPaginatio }, nil } -func (e *enclaveImpl) EnclavePublicConfig() (*common.EnclavePublicConfig, common.SystemError) { +func (e *enclaveImpl) EnclavePublicConfig(context.Context) (*common.EnclavePublicConfig, common.SystemError) { address, systemError := e.crossChainProcessors.GetL2MessageBusAddress() if systemError != nil { return nil, systemError @@ -875,9 +876,9 @@ func (e *enclaveImpl) EnclavePublicConfig() (*common.EnclavePublicConfig, common return &common.EnclavePublicConfig{L2MessageBusAddress: address}, nil } -func (e *enclaveImpl) rejectBlockErr(cause error) *errutil.BlockRejectError { +func (e *enclaveImpl) rejectBlockErr(ctx context.Context, cause error) *errutil.BlockRejectError { var hash common.L1BlockHash - l1Head, err := e.l1BlockProcessor.GetHead() + l1Head, err := e.l1BlockProcessor.GetHead(ctx) // todo - handle error if err == nil { hash = l1Head.Hash() @@ -890,12 +891,12 @@ func (e *enclaveImpl) rejectBlockErr(cause error) *errutil.BlockRejectError { // this function looks at the batch chain and makes sure the resulting stateDB snapshots are available, replaying them if needed // (if there had been a clean shutdown and all stateDB data was persisted this should do nothing) -func restoreStateDBCache(storage storage.Storage, registry components.BatchRegistry, producer components.BatchExecutor, gen *genesis.Genesis, logger gethlog.Logger) error { +func restoreStateDBCache(ctx context.Context, storage storage.Storage, registry components.BatchRegistry, producer components.BatchExecutor, gen *genesis.Genesis, logger gethlog.Logger) error { if registry.HeadBatchSeq() == nil { // not initialised yet return nil } - batch, err := storage.FetchBatchBySeqNo(registry.HeadBatchSeq().Uint64()) + batch, err := storage.FetchBatchBySeqNo(ctx, registry.HeadBatchSeq().Uint64()) if err != nil { if errors.Is(err, errutil.ErrNotFound) { // there is no head batch, this is probably a new node - there is no state to rebuild @@ -904,9 +905,9 @@ func restoreStateDBCache(storage storage.Storage, registry components.BatchRegis } return fmt.Errorf("unexpected error fetching head batch to resync- %w", err) } - if !stateDBAvailableForBatch(storage, batch.Hash()) { + if !stateDBAvailableForBatch(ctx, storage, batch.Hash()) { logger.Info("state not available for latest batch after restart - rebuilding stateDB cache from batches") - err = replayBatchesToValidState(storage, registry, producer, gen, logger) + err = replayBatchesToValidState(ctx, storage, registry, producer, gen, logger) if err != nil { return fmt.Errorf("unable to replay batches to restore valid state - %w", err) } @@ -918,8 +919,8 @@ func restoreStateDBCache(storage storage.Storage, registry components.BatchRegis // batch in the chain and is used to query state at a certain height. // // This method checks if the stateDB data is available for a given batch hash (so it can be restored if not) -func stateDBAvailableForBatch(storage storage.Storage, hash common.L2BatchHash) bool { - _, err := storage.CreateStateDB(hash) +func stateDBAvailableForBatch(ctx context.Context, storage storage.Storage, hash common.L2BatchHash) bool { + _, err := storage.CreateStateDB(ctx, hash) return err == nil } @@ -927,24 +928,24 @@ func stateDBAvailableForBatch(storage storage.Storage, hash common.L2BatchHash) // 1. step backwards from head batch until we find a batch that is already in stateDB cache, builds list of batches to replay // 2. iterate that list of batches from the earliest, process the transactions to calculate and cache the stateDB // todo (#1416) - get unit test coverage around this (and L2 Chain code more widely, see ticket #1416 ) -func replayBatchesToValidState(storage storage.Storage, registry components.BatchRegistry, batchExecutor components.BatchExecutor, gen *genesis.Genesis, logger gethlog.Logger) error { +func replayBatchesToValidState(ctx context.Context, storage storage.Storage, registry components.BatchRegistry, batchExecutor components.BatchExecutor, gen *genesis.Genesis, logger gethlog.Logger) error { // this slice will be a stack of batches to replay as we walk backwards in search of latest valid state // todo - consider capping the size of this batch list using FIFO to avoid memory issues, and then repeating as necessary var batchesToReplay []*core.Batch // `batchToReplayFrom` variable will eventually be the latest batch for which we are able to produce a StateDB // - we will then set that as the head of the L2 so that this node can rebuild its missing state - batchToReplayFrom, err := storage.FetchBatchBySeqNo(registry.HeadBatchSeq().Uint64()) + batchToReplayFrom, err := storage.FetchBatchBySeqNo(ctx, registry.HeadBatchSeq().Uint64()) if err != nil { return fmt.Errorf("no head batch found in DB but expected to replay batches - %w", err) } // loop backwards building a slice of all batches that don't have cached stateDB data available - for !stateDBAvailableForBatch(storage, batchToReplayFrom.Hash()) { + for !stateDBAvailableForBatch(ctx, storage, batchToReplayFrom.Hash()) { batchesToReplay = append(batchesToReplay, batchToReplayFrom) if batchToReplayFrom.NumberU64() == 0 { // no more parents to check, replaying from genesis break } - batchToReplayFrom, err = storage.FetchBatch(batchToReplayFrom.Header.ParentHash) + batchToReplayFrom, err = storage.FetchBatch(ctx, batchToReplayFrom.Header.ParentHash) if err != nil { return fmt.Errorf("unable to fetch previous batch while rolling back to stable state - %w", err) } @@ -965,7 +966,7 @@ func replayBatchesToValidState(storage storage.Storage, registry components.Batc } // calculate the stateDB after this batch and store it in the cache - _, err := batchExecutor.ExecuteBatch(batch) + _, err := batchExecutor.ExecuteBatch(ctx, batch) if err != nil { return err } diff --git a/go/enclave/events/subscription_manager.go b/go/enclave/events/subscription_manager.go index 1d16f6f462..3af2250c56 100644 --- a/go/enclave/events/subscription_manager.go +++ b/go/enclave/events/subscription_manager.go @@ -1,6 +1,7 @@ package events import ( + "context" "encoding/json" "fmt" "sync" @@ -88,9 +89,9 @@ func (s *SubscriptionManager) RemoveSubscription(id gethrpc.ID) { } // FilterLogsForReceipt removes the logs that the sender of a transaction is not allowed to view -func FilterLogsForReceipt(receipt *types.Receipt, account *gethcommon.Address, storage storage.Storage) ([]*types.Log, error) { +func FilterLogsForReceipt(ctx context.Context, receipt *types.Receipt, account *gethcommon.Address, storage storage.Storage) ([]*types.Log, error) { filteredLogs := []*types.Log{} - stateDB, err := storage.CreateStateDB(receipt.BlockHash) + stateDB, err := storage.CreateStateDB(ctx, receipt.BlockHash) if err != nil { return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) } @@ -107,7 +108,7 @@ func FilterLogsForReceipt(receipt *types.Receipt, account *gethcommon.Address, s // GetSubscribedLogsForBatch - Retrieves and encrypts the logs for the batch in live mode. // The assumption is that this function is called synchronously after the batch is produced -func (s *SubscriptionManager) GetSubscribedLogsForBatch(batch *core.Batch, receipts types.Receipts) (common.EncryptedSubscriptionLogs, error) { +func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, batch *core.Batch, receipts types.Receipts) (common.EncryptedSubscriptionLogs, error) { s.subscriptionMutex.RLock() defer s.subscriptionMutex.RUnlock() @@ -129,7 +130,7 @@ func (s *SubscriptionManager) GetSubscribedLogsForBatch(batch *core.Batch, recei } // the stateDb is needed to extract the user addresses from the topics - stateDB, err := s.storage.CreateStateDB(batch.Hash()) + stateDB, err := s.storage.CreateStateDB(ctx, batch.Hash()) if err != nil { return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) } diff --git a/go/enclave/evm/chain_context.go b/go/enclave/evm/chain_context.go index 31941479d7..75e4b942e6 100644 --- a/go/enclave/evm/chain_context.go +++ b/go/enclave/evm/chain_context.go @@ -1,8 +1,10 @@ package evm import ( + "context" "errors" + "github.com/ten-protocol/go-ten/go/config" "github.com/ten-protocol/go-ten/go/enclave/storage" "github.com/ethereum/go-ethereum/common" @@ -17,14 +19,16 @@ import ( // ObscuroChainContext - basic implementation of the ChainContext needed for the EVM integration type ObscuroChainContext struct { storage storage.Storage + config config.EnclaveConfig gethEncodingService gethencoding.EncodingService logger gethlog.Logger } // NewObscuroChainContext returns a new instance of the ObscuroChainContext given a storage ( and logger ) -func NewObscuroChainContext(storage storage.Storage, gethEncodingService gethencoding.EncodingService, logger gethlog.Logger) *ObscuroChainContext { +func NewObscuroChainContext(storage storage.Storage, gethEncodingService gethencoding.EncodingService, config config.EnclaveConfig, logger gethlog.Logger) *ObscuroChainContext { return &ObscuroChainContext{ storage: storage, + config: config, gethEncodingService: gethEncodingService, logger: logger, } @@ -35,7 +39,10 @@ func (occ *ObscuroChainContext) Engine() consensus.Engine { } func (occ *ObscuroChainContext) GetHeader(hash common.Hash, _ uint64) *types.Header { - batch, err := occ.storage.FetchBatch(hash) + ctx, cancelCtx := context.WithTimeout(context.Background(), occ.config.RPCTimeout) + defer cancelCtx() + + batch, err := occ.storage.FetchBatch(ctx, hash) if err != nil { if errors.Is(err, errutil.ErrNotFound) { return nil @@ -43,7 +50,7 @@ func (occ *ObscuroChainContext) GetHeader(hash common.Hash, _ uint64) *types.Hea occ.logger.Crit("Could not retrieve rollup", log.ErrKey, err) } - h, err := occ.gethEncodingService.CreateEthHeaderForBatch(batch.Header) + h, err := occ.gethEncodingService.CreateEthHeaderForBatch(ctx, batch.Header) if err != nil { occ.logger.Crit("Could not convert to eth header", log.ErrKey, err) return nil diff --git a/go/enclave/evm/ethchainadapter/eth_chainadapter.go b/go/enclave/evm/ethchainadapter/eth_chainadapter.go index 7ac12e930f..c64529be46 100644 --- a/go/enclave/evm/ethchainadapter/eth_chainadapter.go +++ b/go/enclave/evm/ethchainadapter/eth_chainadapter.go @@ -1,6 +1,7 @@ package ethchainadapter import ( + "context" "math/big" gethcommon "github.com/ethereum/go-ethereum/common" @@ -10,6 +11,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ten-protocol/go-ten/go/common/gethencoding" "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/config" "github.com/ten-protocol/go-ten/go/enclave/components" "github.com/ten-protocol/go-ten/go/enclave/core" "github.com/ten-protocol/go-ten/go/enclave/storage" @@ -25,17 +27,19 @@ type EthChainAdapter struct { batchRegistry components.BatchRegistry gethEncoding gethencoding.EncodingService storage storage.Storage + config config.EnclaveConfig chainID *big.Int logger gethlog.Logger } // NewEthChainAdapter returns a new instance -func NewEthChainAdapter(chainID *big.Int, batchRegistry components.BatchRegistry, storage storage.Storage, gethEncoding gethencoding.EncodingService, logger gethlog.Logger) *EthChainAdapter { +func NewEthChainAdapter(chainID *big.Int, batchRegistry components.BatchRegistry, storage storage.Storage, gethEncoding gethencoding.EncodingService, config config.EnclaveConfig, logger gethlog.Logger) *EthChainAdapter { return &EthChainAdapter{ newHeadChan: make(chan gethcore.ChainHeadEvent), batchRegistry: batchRegistry, storage: storage, gethEncoding: gethEncoding, + config: config, chainID: chainID, logger: logger, } @@ -52,12 +56,15 @@ func (e *EthChainAdapter) CurrentBlock() *gethtypes.Header { if currentBatchSeqNo == nil { return nil } - currentBatch, err := e.storage.FetchBatchBySeqNo(currentBatchSeqNo.Uint64()) + ctx, cancelCtx := context.WithTimeout(context.Background(), e.config.RPCTimeout) + defer cancelCtx() + + currentBatch, err := e.storage.FetchBatchBySeqNo(ctx, currentBatchSeqNo.Uint64()) if err != nil { e.logger.Warn("unable to retrieve batch seq no", "currentBatchSeqNo", currentBatchSeqNo, log.ErrKey, err) return nil } - batch, err := e.gethEncoding.CreateEthHeaderForBatch(currentBatch.Header) + batch, err := e.gethEncoding.CreateEthHeaderForBatch(ctx, currentBatch.Header) if err != nil { e.logger.Warn("unable to convert batch to eth header ", "currentBatchSeqNo", currentBatchSeqNo, log.ErrKey, err) return nil @@ -85,9 +92,11 @@ func (e *EthChainAdapter) SubscribeChainHeadEvent(ch chan<- gethcore.ChainHeadEv // GetBlock retrieves a specific block, used during pool resets. func (e *EthChainAdapter) GetBlock(_ gethcommon.Hash, number uint64) *gethtypes.Block { var batch *core.Batch + ctx, cancelCtx := context.WithTimeout(context.Background(), e.config.RPCTimeout) + defer cancelCtx() // to avoid a costly select to the db, check whether the batches requested are the last ones which are cached - headBatch, err := e.storage.FetchBatchBySeqNo(e.batchRegistry.HeadBatchSeq().Uint64()) + headBatch, err := e.storage.FetchBatchBySeqNo(ctx, e.batchRegistry.HeadBatchSeq().Uint64()) if err != nil { e.logger.Error("unable to get head batch", log.ErrKey, err) return nil @@ -95,20 +104,20 @@ func (e *EthChainAdapter) GetBlock(_ gethcommon.Hash, number uint64) *gethtypes. if headBatch.Number().Uint64() == number { batch = headBatch } else if headBatch.Number().Uint64()-1 == number { - batch, err = e.storage.FetchBatch(headBatch.Header.ParentHash) + batch, err = e.storage.FetchBatch(ctx, headBatch.Header.ParentHash) if err != nil { e.logger.Error("unable to get parent of head batch", log.ErrKey, err, log.BatchHashKey, headBatch.Header.ParentHash) return nil } } else { - batch, err = e.storage.FetchBatchByHeight(number) + batch, err = e.storage.FetchBatchByHeight(ctx, number) if err != nil { e.logger.Error("unable to get batch by height", log.BatchHeightKey, number, log.ErrKey, err) return nil } } - nfromBatch, err := e.gethEncoding.CreateEthBlockFromBatch(batch) + nfromBatch, err := e.gethEncoding.CreateEthBlockFromBatch(ctx, batch) if err != nil { e.logger.Error("unable to convert batch to eth block", log.ErrKey, err) return nil @@ -127,7 +136,9 @@ func (e *EthChainAdapter) StateAt(root gethcommon.Hash) (*state.StateDB, error) } func (e *EthChainAdapter) IngestNewBlock(batch *core.Batch) error { - convertedBlock, err := e.gethEncoding.CreateEthBlockFromBatch(batch) + ctx, cancelCtx := context.WithTimeout(context.Background(), e.config.RPCTimeout) + defer cancelCtx() + convertedBlock, err := e.gethEncoding.CreateEthBlockFromBatch(ctx, batch) if err != nil { return err } diff --git a/go/enclave/evm/evm_facade.go b/go/enclave/evm/evm_facade.go index 74d3fa3096..a7702d66d2 100644 --- a/go/enclave/evm/evm_facade.go +++ b/go/enclave/evm/evm_facade.go @@ -1,13 +1,16 @@ package evm import ( + "context" "errors" "fmt" "math/big" + _ "unsafe" + + "github.com/ten-protocol/go-ten/go/config" // unsafe package imported in order to link to a private function in go-ethereum. // This allows us to customize the message generated from a signed transaction and inject custom gas logic. - _ "unsafe" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common/hexutil" @@ -34,24 +37,26 @@ import ( // header - the header of the rollup where this transaction will be included // fromTxIndex - for the receipts and events, the evm needs to know for each transaction the order in which it was executed in the block. func ExecuteTransactions( + ctx context.Context, txs common.L2PricedTransactions, s *state.StateDB, header *common.BatchHeader, storage storage.Storage, gethEncodingService gethencoding.EncodingService, chainConfig *params.ChainConfig, + config config.EnclaveConfig, fromTxIndex int, noBaseFee bool, batchGasLimit uint64, logger gethlog.Logger, ) map[common.TxHash]interface{} { // todo - return error - chain, vmCfg := initParams(storage, gethEncodingService, noBaseFee, logger) + chain, vmCfg := initParams(storage, gethEncodingService, config, noBaseFee, logger) gp := gethcore.GasPool(batchGasLimit) zero := uint64(0) usedGas := &zero result := map[common.TxHash]interface{}{} - ethHeader, err := gethEncodingService.CreateEthHeaderForBatch(header) + ethHeader, err := gethEncodingService.CreateEthHeaderForBatch(ctx, header) if err != nil { logger.Crit("Could not convert to eth header", log.ErrKey, err) return nil @@ -230,6 +235,7 @@ func logReceipt(r *types.Receipt, logger gethlog.Logger) { // ExecuteObsCall - executes the eth_call call func ExecuteObsCall( + ctx context.Context, msg *gethcore.Message, s *state.StateDB, header *common.BatchHeader, @@ -237,6 +243,7 @@ func ExecuteObsCall( gethEncodingService gethencoding.EncodingService, chainConfig *params.ChainConfig, gasEstimationCap uint64, + config config.EnclaveConfig, logger gethlog.Logger, ) (*gethcore.ExecutionResult, error) { noBaseFee := true @@ -248,9 +255,9 @@ func ExecuteObsCall( gp := gethcore.GasPool(gasEstimationCap) gp.SetGas(gasEstimationCap) - chain, vmCfg := initParams(storage, gethEncodingService, noBaseFee, nil) + chain, vmCfg := initParams(storage, gethEncodingService, config, noBaseFee, nil) - ethHeader, err := gethEncodingService.CreateEthHeaderForBatch(header) + ethHeader, err := gethEncodingService.CreateEthHeaderForBatch(ctx, header) if err != nil { return nil, err } @@ -285,11 +292,11 @@ func ExecuteObsCall( return result, nil } -func initParams(storage storage.Storage, gethEncodingService gethencoding.EncodingService, noBaseFee bool, l gethlog.Logger) (*ObscuroChainContext, vm.Config) { +func initParams(storage storage.Storage, gethEncodingService gethencoding.EncodingService, config config.EnclaveConfig, noBaseFee bool, l gethlog.Logger) (*ObscuroChainContext, vm.Config) { vmCfg := vm.Config{ NoBaseFee: noBaseFee, } - return NewObscuroChainContext(storage, gethEncodingService, l), vmCfg + return NewObscuroChainContext(storage, gethEncodingService, config, l), vmCfg } func newErrorWithReasonAndCode(err error) error { diff --git a/go/enclave/genesis/genesis_test.go b/go/enclave/genesis/genesis_test.go index 00ac8da2a8..825a2b5c23 100644 --- a/go/enclave/genesis/genesis_test.go +++ b/go/enclave/genesis/genesis_test.go @@ -4,6 +4,9 @@ import ( "fmt" "math/big" "testing" + "time" + + "github.com/ten-protocol/go-ten/go/config" "github.com/ten-protocol/go-ten/go/enclave/storage" "github.com/ten-protocol/go-ten/go/enclave/storage/init/sqlite" @@ -34,7 +37,7 @@ func TestDefaultGenesis(t *testing.T) { t.Fatal("unexpected number of accounts") } - backingDB, err := sqlite.CreateTemporarySQLiteDB("", "", testlog.Logger()) + backingDB, err := sqlite.CreateTemporarySQLiteDB("", "", config.EnclaveConfig{RPCTimeout: time.Second}, testlog.Logger()) if err != nil { t.Fatalf("unable to create temp db: %s", err) } @@ -77,7 +80,7 @@ func TestCustomGenesis(t *testing.T) { t.Fatal("unexpected number of accounts") } - backingDB, err := sqlite.CreateTemporarySQLiteDB("", "", testlog.Logger()) + backingDB, err := sqlite.CreateTemporarySQLiteDB("", "", config.EnclaveConfig{RPCTimeout: time.Second}, testlog.Logger()) if err != nil { t.Fatalf("unable to create temp db: %s", err) } diff --git a/go/enclave/l2chain/interfaces.go b/go/enclave/l2chain/interfaces.go index 37f06320cc..6c29fd25a3 100644 --- a/go/enclave/l2chain/interfaces.go +++ b/go/enclave/l2chain/interfaces.go @@ -1,6 +1,8 @@ package l2chain import ( + "context" + gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" gethcore "github.com/ethereum/go-ethereum/core" @@ -19,17 +21,17 @@ type ObscuroChain interface { // For Contracts - the address of the deployer. // Note - this might be subject to change if we implement a more flexible mechanism // todo - support BlockNumberOrHash - AccountOwner(address gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*gethcommon.Address, error) + AccountOwner(ctx context.Context, address gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*gethcommon.Address, error) // GetBalanceAtBlock - will return the balance of a specific address at the specific given block number (batch number). - GetBalanceAtBlock(accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*hexutil.Big, error) + GetBalanceAtBlock(ctx context.Context, accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*hexutil.Big, error) // ObsCall - The interface for executing eth_call RPC commands against obscuro. - ObsCall(apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) + ObsCall(ctx context.Context, apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) // ObsCallAtBlock - Execute eth_call RPC against obscuro for a specific block (batch) number. - ObsCallAtBlock(apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) + ObsCallAtBlock(ctx context.Context, apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) // GetChainStateAtTransaction - returns the stateDB after applying all the transactions in the batch leading to the desired transaction. - GetChainStateAtTransaction(batch *core.Batch, txIndex int, reexec uint64) (*gethcore.Message, vm.BlockContext, *state.StateDB, error) + GetChainStateAtTransaction(ctx context.Context, batch *core.Batch, txIndex int, reexec uint64) (*gethcore.Message, vm.BlockContext, *state.StateDB, error) } diff --git a/go/enclave/l2chain/l2_chain.go b/go/enclave/l2chain/l2_chain.go index 5765de720f..7fb9366de5 100644 --- a/go/enclave/l2chain/l2_chain.go +++ b/go/enclave/l2chain/l2_chain.go @@ -1,10 +1,13 @@ package l2chain import ( + "context" "errors" "fmt" "math/big" + "github.com/ten-protocol/go-ten/go/config" + "github.com/ten-protocol/go-ten/go/enclave/storage" gethcommon "github.com/ethereum/go-ethereum/common" @@ -27,8 +30,8 @@ import ( ) type obscuroChain struct { - chainConfig *params.ChainConfig - + chainConfig *params.ChainConfig + config config.EnclaveConfig storage storage.Storage gethEncodingService gethencoding.EncodingService genesis *genesis.Genesis @@ -41,6 +44,7 @@ type obscuroChain struct { func NewChain( storage storage.Storage, + config config.EnclaveConfig, gethEncodingService gethencoding.EncodingService, chainConfig *params.ChainConfig, genesis *genesis.Genesis, @@ -50,6 +54,7 @@ func NewChain( ) ObscuroChain { return &obscuroChain{ storage: storage, + config: config, gethEncodingService: gethEncodingService, chainConfig: chainConfig, logger: logger, @@ -59,9 +64,9 @@ func NewChain( } } -func (oc *obscuroChain) AccountOwner(address gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*gethcommon.Address, error) { +func (oc *obscuroChain) AccountOwner(ctx context.Context, address gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*gethcommon.Address, error) { // check if account is a contract - isContract, err := oc.isAccountContractAtBlock(address, blockNumber) + isContract, err := oc.isAccountContractAtBlock(ctx, address, blockNumber) if err != nil { return nil, err } @@ -70,11 +75,11 @@ func (oc *obscuroChain) AccountOwner(address gethcommon.Address, blockNumber *ge } // If the address is a contract, find the signer of the deploy transaction - txHash, err := oc.storage.GetContractCreationTx(address) + txHash, err := oc.storage.GetContractCreationTx(ctx, address) if err != nil { return nil, err } - transaction, _, _, _, err := oc.storage.GetTransaction(*txHash) //nolint:dogsled + transaction, _, _, _, err := oc.storage.GetTransaction(ctx, *txHash) //nolint:dogsled if err != nil { return nil, err } @@ -87,8 +92,8 @@ func (oc *obscuroChain) AccountOwner(address gethcommon.Address, blockNumber *ge return &sender, nil } -func (oc *obscuroChain) GetBalanceAtBlock(accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*hexutil.Big, error) { - chainState, err := oc.Registry.GetBatchStateAtHeight(blockNumber) +func (oc *obscuroChain) GetBalanceAtBlock(ctx context.Context, accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (*hexutil.Big, error) { + chainState, err := oc.Registry.GetBatchStateAtHeight(ctx, blockNumber) if err != nil { return nil, fmt.Errorf("unable to get blockchain state - %w", err) } @@ -96,8 +101,8 @@ func (oc *obscuroChain) GetBalanceAtBlock(accountAddr gethcommon.Address, blockN return (*hexutil.Big)(chainState.GetBalance(accountAddr)), nil } -func (oc *obscuroChain) ObsCall(apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) { - result, err := oc.ObsCallAtBlock(apiArgs, blockNumber) +func (oc *obscuroChain) ObsCall(ctx context.Context, apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) { + result, err := oc.ObsCallAtBlock(ctx, apiArgs, blockNumber) if err != nil { oc.logger.Info(fmt.Sprintf("Obs_Call: failed to execute contract %s.", apiArgs.To), log.CtrErrKey, err.Error()) return nil, err @@ -115,14 +120,14 @@ func (oc *obscuroChain) ObsCall(apiArgs *gethapi.TransactionArgs, blockNumber *g return result, nil } -func (oc *obscuroChain) ObsCallAtBlock(apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) { +func (oc *obscuroChain) ObsCallAtBlock(ctx context.Context, apiArgs *gethapi.TransactionArgs, blockNumber *gethrpc.BlockNumber) (*gethcore.ExecutionResult, error) { // fetch the chain state at given batch - blockState, err := oc.Registry.GetBatchStateAtHeight(blockNumber) + blockState, err := oc.Registry.GetBatchStateAtHeight(ctx, blockNumber) if err != nil { return nil, err } - batch, err := oc.Registry.GetBatchAtHeight(*blockNumber) + batch, err := oc.Registry.GetBatchAtHeight(ctx, *blockNumber) if err != nil { return nil, fmt.Errorf("unable to fetch head state batch. Cause: %w", err) } @@ -141,7 +146,7 @@ func (oc *obscuroChain) ObsCallAtBlock(apiArgs *gethapi.TransactionArgs, blockNu batch.Header.Root.Hex()) }}) - result, err := evm.ExecuteObsCall(callMsg, blockState, batch.Header, oc.storage, oc.gethEncodingService, oc.chainConfig, oc.gasEstimationCap, oc.logger) + result, err := evm.ExecuteObsCall(ctx, callMsg, blockState, batch.Header, oc.storage, oc.gethEncodingService, oc.chainConfig, oc.gasEstimationCap, oc.config, oc.logger) if err != nil { // also return the result as the result can be evaluated on some errors like ErrIntrinsicGas return result, err @@ -152,13 +157,13 @@ func (oc *obscuroChain) ObsCallAtBlock(apiArgs *gethapi.TransactionArgs, blockNu // GetChainStateAtTransaction Returns the state of the chain at certain block height after executing transactions up to the selected transaction // TODO make this cacheable -func (oc *obscuroChain) GetChainStateAtTransaction(batch *core.Batch, txIndex int, _ uint64) (*gethcore.Message, vm.BlockContext, *state.StateDB, error) { +func (oc *obscuroChain) GetChainStateAtTransaction(ctx context.Context, batch *core.Batch, txIndex int, _ uint64) (*gethcore.Message, vm.BlockContext, *state.StateDB, error) { // Short circuit if it's genesis batch. if batch.NumberU64() == 0 { return nil, vm.BlockContext{}, nil, errors.New("no transaction in genesis") } // Create the parent state database - parent, err := oc.Registry.GetBatchAtHeight(gethrpc.BlockNumber(batch.NumberU64() - 1)) + parent, err := oc.Registry.GetBatchAtHeight(ctx, gethrpc.BlockNumber(batch.NumberU64()-1)) if err != nil { return nil, vm.BlockContext{}, nil, fmt.Errorf("unable to fetch parent batch - %w", err) } @@ -166,7 +171,7 @@ func (oc *obscuroChain) GetChainStateAtTransaction(batch *core.Batch, txIndex in // Lookup the statedb of parent batch from the live database, // otherwise regenerate it on the flight. - statedb, err := oc.Registry.GetBatchStateAtHeight(&parentBlockNumber) + statedb, err := oc.Registry.GetBatchStateAtHeight(ctx, &parentBlockNumber) if err != nil { return nil, vm.BlockContext{}, nil, err } @@ -185,9 +190,9 @@ func (oc *obscuroChain) GetChainStateAtTransaction(batch *core.Batch, txIndex in } txContext := gethcore.NewEVMTxContext(msg) - chain := evm.NewObscuroChainContext(oc.storage, oc.gethEncodingService, oc.logger) + chain := evm.NewObscuroChainContext(oc.storage, oc.gethEncodingService, oc.config, oc.logger) - blockHeader, err := oc.gethEncodingService.CreateEthHeaderForBatch(batch.Header) + blockHeader, err := oc.gethEncodingService.CreateEthHeaderForBatch(ctx, batch.Header) if err != nil { return nil, vm.BlockContext{}, nil, fmt.Errorf("unable to convert batch header to eth header - %w", err) } @@ -209,8 +214,8 @@ func (oc *obscuroChain) GetChainStateAtTransaction(batch *core.Batch, txIndex in } // Returns whether the account is a contract -func (oc *obscuroChain) isAccountContractAtBlock(accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (bool, error) { - chainState, err := oc.Registry.GetBatchStateAtHeight(blockNumber) +func (oc *obscuroChain) isAccountContractAtBlock(ctx context.Context, accountAddr gethcommon.Address, blockNumber *gethrpc.BlockNumber) (bool, error) { + chainState, err := oc.Registry.GetBatchStateAtHeight(ctx, blockNumber) if err != nil { return false, fmt.Errorf("unable to get blockchain state - %w", err) } diff --git a/go/enclave/nodetype/interfaces.go b/go/enclave/nodetype/interfaces.go index f6d0ec051c..eb7844d229 100644 --- a/go/enclave/nodetype/interfaces.go +++ b/go/enclave/nodetype/interfaces.go @@ -1,6 +1,8 @@ package nodetype import ( + "context" + "github.com/ethereum/go-ethereum/core/types" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/enclave/components" @@ -16,28 +18,28 @@ type NodeType interface { SubmitTransaction(*common.L2Tx) error // OnL1Fork - logic to be performed when there is an L1 Fork - OnL1Fork(fork *common.ChainFork) error + OnL1Fork(ctx context.Context, fork *common.ChainFork) error // OnL1Block - performed after the block was processed - OnL1Block(block types.Block, result *components.BlockIngestionType) error + OnL1Block(ctx context.Context, block types.Block, result *components.BlockIngestionType) error Close() error } type Sequencer interface { // CreateBatch - creates a new head batch for the latest known L1 head block. - CreateBatch(skipBatchIfEmpty bool) error + CreateBatch(ctx context.Context, skipBatchIfEmpty bool) error // CreateRollup - creates a new rollup from the latest recorded rollup in the head l1 chain // and adds as many batches to it as possible. - CreateRollup(lastBatchNo uint64) (*common.ExtRollup, error) + CreateRollup(ctx context.Context, lastBatchNo uint64) (*common.ExtRollup, error) NodeType } type ObsValidator interface { // ExecuteStoredBatches - try to execute all stored by unexecuted batches - ExecuteStoredBatches() error + ExecuteStoredBatches(context.Context) error VerifySequencerSignature(*core.Batch) error diff --git a/go/enclave/nodetype/sequencer.go b/go/enclave/nodetype/sequencer.go index a5060f828c..780318c107 100644 --- a/go/enclave/nodetype/sequencer.go +++ b/go/enclave/nodetype/sequencer.go @@ -1,6 +1,7 @@ package nodetype import ( + "context" "errors" "fmt" "math/big" @@ -82,21 +83,21 @@ func NewSequencer(blockProcessor components.L1BlockProcessor, batchExecutor comp } } -func (s *sequencer) CreateBatch(skipBatchIfEmpty bool) error { +func (s *sequencer) CreateBatch(ctx context.Context, skipBatchIfEmpty bool) error { hasGenesis, err := s.batchRegistry.HasGenesisBatch() if err != nil { return fmt.Errorf("unknown genesis batch state. Cause: %w", err) } // L1 Head is only updated when isLatest: true - l1HeadBlock, err := s.blockProcessor.GetHead() + l1HeadBlock, err := s.blockProcessor.GetHead(ctx) if err != nil { return fmt.Errorf("failed retrieving l1 head. Cause: %w", err) } // the sequencer creates the initial genesis batch if one does not exist yet if !hasGenesis { - return s.createGenesisBatch(l1HeadBlock) + return s.createGenesisBatch(ctx, l1HeadBlock) } if running := s.mempool.Running(); !running { @@ -108,16 +109,17 @@ func (s *sequencer) CreateBatch(skipBatchIfEmpty bool) error { } } - return s.createNewHeadBatch(l1HeadBlock, skipBatchIfEmpty) + return s.createNewHeadBatch(ctx, l1HeadBlock, skipBatchIfEmpty) } // TODO - This is iffy, the producer commits the stateDB. The producer // should only create batches and stateDBs but not commit them to the database, // this is the responsibility of the sequencer. Refactor the code so genesis state // won't be committed by the producer. -func (s *sequencer) createGenesisBatch(block *common.L1Block) error { +func (s *sequencer) createGenesisBatch(ctx context.Context, block *common.L1Block) error { s.logger.Info("Initializing genesis state", log.BlockHashKey, block.Hash()) batch, msgBusTx, err := s.batchProducer.CreateGenesisState( + ctx, block.Hash(), uint64(time.Now().Unix()), s.settings.GasPaymentAddress, @@ -131,7 +133,7 @@ func (s *sequencer) createGenesisBatch(block *common.L1Block) error { return fmt.Errorf("failed signing created batch. Cause: %w", err) } - if err := s.StoreExecutedBatch(batch, nil); err != nil { + if err := s.StoreExecutedBatch(ctx, batch, nil); err != nil { return fmt.Errorf("1. failed storing batch. Cause: %w", err) } @@ -152,6 +154,7 @@ func (s *sequencer) createGenesisBatch(block *common.L1Block) error { time.Sleep(time.Second) // produce batch #2 which has the message bus and any other system contracts cb, err := s.produceBatch( + ctx, big.NewInt(0).Add(batch.Header.SequencerOrderNo, big.NewInt(1)), block.Hash(), batch.Hash(), @@ -180,22 +183,22 @@ func (s *sequencer) createGenesisBatch(block *common.L1Block) error { return nil } -func (s *sequencer) createNewHeadBatch(l1HeadBlock *common.L1Block, skipBatchIfEmpty bool) error { +func (s *sequencer) createNewHeadBatch(ctx context.Context, l1HeadBlock *common.L1Block, skipBatchIfEmpty bool) error { headBatchSeq := s.batchRegistry.HeadBatchSeq() if headBatchSeq == nil { headBatchSeq = big.NewInt(int64(common.L2GenesisSeqNo)) } - headBatch, err := s.storage.FetchBatchBySeqNo(headBatchSeq.Uint64()) + headBatch, err := s.storage.FetchBatchBySeqNo(ctx, headBatchSeq.Uint64()) if err != nil { return err } // todo - sanity check that the headBatch.Header.L1Proof is an ancestor of the l1HeadBlock - b, err := s.storage.FetchBlock(headBatch.Header.L1Proof) + b, err := s.storage.FetchBlock(ctx, headBatch.Header.L1Proof) if err != nil { return err } - if !s.storage.IsAncestor(l1HeadBlock, b) { + if !s.storage.IsAncestor(ctx, l1HeadBlock, b) { return fmt.Errorf("attempted to create batch on top of batch=%s. With l1 head=%s", headBatch.Hash(), l1HeadBlock.Hash()) } @@ -221,13 +224,13 @@ func (s *sequencer) createNewHeadBatch(l1HeadBlock *common.L1Block, skipBatchIfE } } - sequencerNo, err := s.storage.FetchCurrentSequencerNo() + sequencerNo, err := s.storage.FetchCurrentSequencerNo(ctx) if err != nil { return err } // todo - time is set only here; take from l1 block? - if _, err := s.produceBatch(sequencerNo.Add(sequencerNo, big.NewInt(1)), l1HeadBlock.Hash(), headBatch.Hash(), transactions, uint64(time.Now().Unix()), skipBatchIfEmpty); err != nil { + if _, err := s.produceBatch(ctx, sequencerNo.Add(sequencerNo, big.NewInt(1)), l1HeadBlock.Hash(), headBatch.Hash(), transactions, uint64(time.Now().Unix()), skipBatchIfEmpty); err != nil { if errors.Is(err, components.ErrNoTransactionsToProcess) { // skip batch production when there are no transactions to process // todo: this might be a useful event to track for metrics (skipping batch production because empty batch) @@ -241,6 +244,7 @@ func (s *sequencer) createNewHeadBatch(l1HeadBlock *common.L1Block, skipBatchIfE } func (s *sequencer) produceBatch( + ctx context.Context, sequencerNo *big.Int, l1Hash common.L1BlockHash, headBatch common.L2BatchHash, @@ -248,16 +252,17 @@ func (s *sequencer) produceBatch( batchTime uint64, failForEmptyBatch bool, ) (*components.ComputedBatch, error) { - cb, err := s.batchProducer.ComputeBatch(&components.BatchExecutionContext{ - BlockPtr: l1Hash, - ParentPtr: headBatch, - Transactions: transactions, - AtTime: batchTime, - Creator: s.settings.GasPaymentAddress, - BaseFee: s.settings.BaseFee, - ChainConfig: s.chainConfig, - SequencerNo: sequencerNo, - }, failForEmptyBatch) + cb, err := s.batchProducer.ComputeBatch(ctx, + &components.BatchExecutionContext{ + BlockPtr: l1Hash, + ParentPtr: headBatch, + Transactions: transactions, + AtTime: batchTime, + Creator: s.settings.GasPaymentAddress, + BaseFee: s.settings.BaseFee, + ChainConfig: s.chainConfig, + SequencerNo: sequencerNo, + }, failForEmptyBatch) if err != nil { return nil, fmt.Errorf("failed computing batch. Cause: %w", err) } @@ -270,7 +275,7 @@ func (s *sequencer) produceBatch( return nil, fmt.Errorf("failed signing created batch. Cause: %w", err) } - if err := s.StoreExecutedBatch(cb.Batch, cb.Receipts); err != nil { + if err := s.StoreExecutedBatch(ctx, cb.Batch, cb.Receipts); err != nil { return nil, fmt.Errorf("2. failed storing batch. Cause: %w", err) } @@ -288,25 +293,25 @@ func (s *sequencer) produceBatch( // StoreExecutedBatch - stores an executed batch in one go. This can be done for the sequencer because it is guaranteed // that all dependencies are in place for the execution to be successful. -func (s *sequencer) StoreExecutedBatch(batch *core.Batch, receipts types.Receipts) error { +func (s *sequencer) StoreExecutedBatch(ctx context.Context, batch *core.Batch, receipts types.Receipts) error { defer core.LogMethodDuration(s.logger, measure.NewStopwatch(), "Registry StoreBatch() exit", log.BatchHashKey, batch.Hash()) // Check if this batch is already stored. - if _, err := s.storage.FetchBatchHeader(batch.Hash()); err == nil { + if _, err := s.storage.FetchBatchHeader(ctx, batch.Hash()); err == nil { s.logger.Warn("Attempted to store batch twice! This indicates issues with the batch processing loop") return nil } - convertedHeader, err := s.gethEncoding.CreateEthHeaderForBatch(batch.Header) + convertedHeader, err := s.gethEncoding.CreateEthHeaderForBatch(ctx, batch.Header) if err != nil { return err } - if err := s.storage.StoreBatch(batch, convertedHeader.Hash()); err != nil { + if err := s.storage.StoreBatch(ctx, batch, convertedHeader.Hash()); err != nil { return fmt.Errorf("failed to store batch. Cause: %w", err) } - if err := s.storage.StoreExecutedBatch(batch, receipts); err != nil { + if err := s.storage.StoreExecutedBatch(ctx, batch, receipts); err != nil { return fmt.Errorf("failed to store batch. Cause: %w", err) } @@ -315,20 +320,20 @@ func (s *sequencer) StoreExecutedBatch(batch *core.Batch, receipts types.Receipt return nil } -func (s *sequencer) CreateRollup(lastBatchNo uint64) (*common.ExtRollup, error) { +func (s *sequencer) CreateRollup(ctx context.Context, lastBatchNo uint64) (*common.ExtRollup, error) { rollupLimiter := limiters.NewRollupLimiter(s.settings.MaxRollupSize) - currentL1Head, err := s.blockProcessor.GetHead() + currentL1Head, err := s.blockProcessor.GetHead(ctx) if err != nil { return nil, err } upToL1Height := currentL1Head.NumberU64() - RollupDelay - rollup, err := s.rollupProducer.CreateInternalRollup(lastBatchNo, upToL1Height, rollupLimiter) + rollup, err := s.rollupProducer.CreateInternalRollup(ctx, lastBatchNo, upToL1Height, rollupLimiter) if err != nil { return nil, err } - extRollup, err := s.rollupCompression.CreateExtRollup(rollup) + extRollup, err := s.rollupCompression.CreateExtRollup(ctx, rollup) if err != nil { return nil, fmt.Errorf("failed to compress rollup: %w", err) } @@ -341,12 +346,12 @@ func (s *sequencer) CreateRollup(lastBatchNo uint64) (*common.ExtRollup, error) return extRollup, nil } -func (s *sequencer) duplicateBatches(l1Head *types.Block, nonCanonicalL1Path []common.L1BlockHash) error { +func (s *sequencer) duplicateBatches(ctx context.Context, l1Head *types.Block, nonCanonicalL1Path []common.L1BlockHash) error { batchesToDuplicate := make([]*core.Batch, 0) // read the batches attached to these blocks for _, l1BlockHash := range nonCanonicalL1Path { - batches, err := s.storage.FetchBatchesByBlock(l1BlockHash) + batches, err := s.storage.FetchBatchesByBlock(ctx, l1BlockHash) if err != nil { if errors.Is(err, errutil.ErrNotFound) { continue @@ -373,13 +378,13 @@ func (s *sequencer) duplicateBatches(l1Head *types.Block, nonCanonicalL1Path []c if i > 0 && batchesToDuplicate[i].Header.ParentHash != batchesToDuplicate[i-1].Hash() { s.logger.Crit("the batches that must be duplicated are invalid") } - sequencerNo, err := s.storage.FetchCurrentSequencerNo() + sequencerNo, err := s.storage.FetchCurrentSequencerNo(ctx) if err != nil { return fmt.Errorf("could not fetch sequencer no. Cause %w", err) } sequencerNo = sequencerNo.Add(sequencerNo, big.NewInt(1)) // create the duplicate and store/broadcast it, recreate batch even if it was empty - cb, err := s.produceBatch(sequencerNo, l1Head.ParentHash(), currentHead, orphanBatch.Transactions, orphanBatch.Header.Time, false) + cb, err := s.produceBatch(ctx, sequencerNo, l1Head.ParentHash(), currentHead, orphanBatch.Transactions, orphanBatch.Header.Time, false) if err != nil { return fmt.Errorf("could not produce batch. Cause %w", err) } @@ -394,17 +399,17 @@ func (s *sequencer) SubmitTransaction(transaction *common.L2Tx) error { return s.mempool.Add(transaction) } -func (s *sequencer) OnL1Fork(fork *common.ChainFork) error { +func (s *sequencer) OnL1Fork(ctx context.Context, fork *common.ChainFork) error { if !fork.IsFork() { return nil } - err := s.duplicateBatches(fork.NewCanonical, fork.NonCanonicalPath) + err := s.duplicateBatches(ctx, fork.NewCanonical, fork.NonCanonicalPath) if err != nil { return fmt.Errorf("could not duplicate batches. Cause %w", err) } - rollup, err := s.storage.FetchReorgedRollup(fork.NonCanonicalPath) + rollup, err := s.storage.FetchReorgedRollup(ctx, fork.NonCanonicalPath) if err == nil { s.logger.Error("Reissue rollup", log.RollupHashKey, rollup) // todo - tudor - finalise the logic to reissue a rollup when the block used for compression was reorged @@ -437,7 +442,7 @@ func (s *sequencer) signRollup(rollup *common.ExtRollup) error { return nil } -func (s *sequencer) OnL1Block(_ types.Block, _ *components.BlockIngestionType) error { +func (s *sequencer) OnL1Block(ctx context.Context, block types.Block, result *components.BlockIngestionType) error { // nothing to do return nil } diff --git a/go/enclave/nodetype/validator.go b/go/enclave/nodetype/validator.go index b3ef5a93de..aed2be6cb7 100644 --- a/go/enclave/nodetype/validator.go +++ b/go/enclave/nodetype/validator.go @@ -1,6 +1,7 @@ package nodetype import ( + "context" "errors" "fmt" "math/big" @@ -66,7 +67,7 @@ func (val *obsValidator) SubmitTransaction(tx *common.L2Tx) error { return err } -func (val *obsValidator) OnL1Fork(_ *common.ChainFork) error { +func (val *obsValidator) OnL1Fork(ctx context.Context, fork *common.ChainFork) error { // nothing to do return nil } @@ -75,12 +76,12 @@ func (val *obsValidator) VerifySequencerSignature(b *core.Batch) error { return val.sigValidator.CheckSequencerSignature(b.Hash(), b.Header.Signature) } -func (val *obsValidator) ExecuteStoredBatches() error { +func (val *obsValidator) ExecuteStoredBatches(ctx context.Context) error { headBatchSeq := val.batchRegistry.HeadBatchSeq() if headBatchSeq == nil { headBatchSeq = big.NewInt(int64(common.L2GenesisSeqNo)) } - batches, err := val.storage.FetchCanonicalUnexecutedBatches(headBatchSeq) + batches, err := val.storage.FetchCanonicalUnexecutedBatches(ctx, headBatchSeq) if err != nil { if errors.Is(err, errutil.ErrNotFound) { return nil @@ -92,23 +93,23 @@ func (val *obsValidator) ExecuteStoredBatches() error { for _, batch := range batches { if batch.IsGenesis() { - if err = val.handleGenesis(batch); err != nil { + if err = val.handleGenesis(ctx, batch); err != nil { return err } } // check batch execution prerequisites - canExecute, err := val.executionPrerequisites(batch) + canExecute, err := val.executionPrerequisites(ctx, batch) if err != nil { return fmt.Errorf("could not determine the execution prerequisites for batch %s. Cause: %w", batch.Hash(), err) } if canExecute { - receipts, err := val.batchExecutor.ExecuteBatch(batch) + receipts, err := val.batchExecutor.ExecuteBatch(ctx, batch) if err != nil { return fmt.Errorf("could not execute batch %s. Cause: %w", batch.Hash(), err) } - err = val.storage.StoreExecutedBatch(batch, receipts) + err = val.storage.StoreExecutedBatch(ctx, batch, receipts) if err != nil { return fmt.Errorf("could not store executed batch %s. Cause: %w", batch.Hash(), err) } @@ -122,16 +123,16 @@ func (val *obsValidator) ExecuteStoredBatches() error { return nil } -func (val *obsValidator) executionPrerequisites(batch *core.Batch) (bool, error) { +func (val *obsValidator) executionPrerequisites(ctx context.Context, batch *core.Batch) (bool, error) { // 1.l1 block exists - block, err := val.storage.FetchBlock(batch.Header.L1Proof) + block, err := val.storage.FetchBlock(ctx, batch.Header.L1Proof) if err != nil && errors.Is(err, errutil.ErrNotFound) { val.logger.Info("Error fetching block", log.BlockHashKey, batch.Header.L1Proof, log.ErrKey, err) return false, err } // 2. parent was executed - parentExecuted, err := val.storage.BatchWasExecuted(batch.Header.ParentHash) + parentExecuted, err := val.storage.BatchWasExecuted(ctx, batch.Header.ParentHash) if err != nil { val.logger.Info("Error reading execution status of batch", log.BatchHashKey, batch.Header.ParentHash, log.ErrKey, err) return false, err @@ -140,8 +141,8 @@ func (val *obsValidator) executionPrerequisites(batch *core.Batch) (bool, error) return block != nil && parentExecuted, nil } -func (val *obsValidator) handleGenesis(batch *core.Batch) error { - genBatch, _, err := val.batchExecutor.CreateGenesisState(batch.Header.L1Proof, batch.Header.Time, batch.Header.Coinbase, batch.Header.BaseFee) +func (val *obsValidator) handleGenesis(ctx context.Context, batch *core.Batch) error { + genBatch, _, err := val.batchExecutor.CreateGenesisState(ctx, batch.Header.L1Proof, batch.Header.Time, batch.Header.Coinbase, batch.Header.BaseFee) if err != nil { return err } @@ -150,7 +151,7 @@ func (val *obsValidator) handleGenesis(batch *core.Batch) error { return fmt.Errorf("received invalid genesis batch") } - err = val.storage.StoreExecutedBatch(genBatch, nil) + err = val.storage.StoreExecutedBatch(ctx, genBatch, nil) if err != nil { return err } @@ -158,8 +159,8 @@ func (val *obsValidator) handleGenesis(batch *core.Batch) error { return nil } -func (val *obsValidator) OnL1Block(_ types.Block, _ *components.BlockIngestionType) error { - return val.ExecuteStoredBatches() +func (val *obsValidator) OnL1Block(ctx context.Context, block types.Block, result *components.BlockIngestionType) error { + return val.ExecuteStoredBatches(ctx) } func (val *obsValidator) Close() error { diff --git a/go/enclave/rpc/EstimateGas.go b/go/enclave/rpc/EstimateGas.go index f80c47aaf9..0c583da43e 100644 --- a/go/enclave/rpc/EstimateGas.go +++ b/go/enclave/rpc/EstimateGas.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "errors" "fmt" "math/big" @@ -61,7 +62,7 @@ func EstimateGasExecute(builder *CallBuilder[CallParamsWithBlock, hexutil.Uint64 txArgs := builder.Param.callParams blockNumber := builder.Param.block - block, err := rpc.l1BlockProcessor.GetHead() + block, err := rpc.l1BlockProcessor.GetHead(builder.ctx) if err != nil { return err } @@ -74,7 +75,7 @@ func EstimateGasExecute(builder *CallBuilder[CallParamsWithBlock, hexutil.Uint64 } headBatchSeq := rpc.registry.HeadBatchSeq() - batch, err := rpc.storage.FetchBatchBySeqNo(headBatchSeq.Uint64()) + batch, err := rpc.storage.FetchBatchBySeqNo(builder.ctx, headBatchSeq.Uint64()) if err != nil { return err } @@ -93,7 +94,7 @@ func EstimateGasExecute(builder *CallBuilder[CallParamsWithBlock, hexutil.Uint64 // TODO: Change to fixed time period quotes, rather than this. publishingGas = publishingGas.Mul(publishingGas, gethcommon.Big2) - executionGasEstimate, err := rpc.doEstimateGas(txArgs, blockNumber, rpc.config.GasLocalExecutionCapFlag) + executionGasEstimate, err := rpc.doEstimateGas(builder.ctx, txArgs, blockNumber, rpc.config.GasLocalExecutionCapFlag) if err != nil { err = fmt.Errorf("unable to estimate transaction - %w", err) @@ -115,7 +116,7 @@ func EstimateGasExecute(builder *CallBuilder[CallParamsWithBlock, hexutil.Uint64 // This is a copy of https://github.com/ethereum/go-ethereum/blob/master/internal/ethapi/api.go#L1055 // there's a high complexity to the method due to geth business rules (which is mimic'd here) // once the work of obscuro gas mechanics is established this method should be simplified -func (rpc *EncryptionManager) doEstimateGas(args *gethapi.TransactionArgs, blkNumber *gethrpc.BlockNumber, gasCap uint64) (hexutil.Uint64, common.SystemError) { //nolint: gocognit +func (rpc *EncryptionManager) doEstimateGas(ctx context.Context, args *gethapi.TransactionArgs, blkNumber *gethrpc.BlockNumber, gasCap uint64) (hexutil.Uint64, common.SystemError) { //nolint: gocognit // Binary search the gas requirement, as it may be higher than the amount used var ( //nolint: revive lo = params.TxGas - 1 @@ -157,7 +158,7 @@ func (rpc *EncryptionManager) doEstimateGas(args *gethapi.TransactionArgs, blkNu } // Recap the highest gas limit with account's available balance. if feeCap.BitLen() != 0 { //nolint:nestif - balance, err := rpc.chain.GetBalanceAtBlock(*args.From, blkNumber) + balance, err := rpc.chain.GetBalanceAtBlock(ctx, *args.From, blkNumber) if err != nil { return 0, fmt.Errorf("unable to fetch account balance - %w", err) } @@ -198,7 +199,7 @@ func (rpc *EncryptionManager) doEstimateGas(args *gethapi.TransactionArgs, blkNu // range here is skewed to favor the low side. mid = lo * 2 } - failed, _, err := rpc.isGasEnough(args, mid, blkNumber) + failed, _, err := rpc.isGasEnough(ctx, args, mid, blkNumber) // If the error is not nil(consensus error), it means the provided message // call or transaction will never be accepted no matter how much gas it is // assigned. Return the error directly, don't struggle any more. @@ -213,7 +214,7 @@ func (rpc *EncryptionManager) doEstimateGas(args *gethapi.TransactionArgs, blkNu } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { //nolint:nestif - failed, result, err := rpc.isGasEnough(args, hi, blkNumber) + failed, result, err := rpc.isGasEnough(ctx, args, hi, blkNumber) if err != nil { return 0, err } @@ -233,10 +234,10 @@ func (rpc *EncryptionManager) doEstimateGas(args *gethapi.TransactionArgs, blkNu // Create a helper to check if a gas allowance results in an executable transaction // isGasEnough returns whether the gaslimit should be raised, lowered, or if it was impossible to execute the message -func (rpc *EncryptionManager) isGasEnough(args *gethapi.TransactionArgs, gas uint64, blkNumber *gethrpc.BlockNumber) (bool, *gethcore.ExecutionResult, error) { +func (rpc *EncryptionManager) isGasEnough(ctx context.Context, args *gethapi.TransactionArgs, gas uint64, blkNumber *gethrpc.BlockNumber) (bool, *gethcore.ExecutionResult, error) { defer core.LogMethodDuration(rpc.logger, measure.NewStopwatch(), "enclave.go:IsGasEnough") args.Gas = (*hexutil.Uint64)(&gas) - result, err := rpc.chain.ObsCallAtBlock(args, blkNumber) + result, err := rpc.chain.ObsCallAtBlock(ctx, args, blkNumber) if err != nil { if errors.Is(err, gethcore.ErrIntrinsicGas) { return true, nil, nil // Special case, raise gas limit diff --git a/go/enclave/rpc/GetBalance.go b/go/enclave/rpc/GetBalance.go index 8bdc26fc14..4fae78d2ab 100644 --- a/go/enclave/rpc/GetBalance.go +++ b/go/enclave/rpc/GetBalance.go @@ -43,7 +43,7 @@ func GetBalanceValidate(reqParams []any, builder *CallBuilder[BalanceReq, hexuti } func GetBalanceExecute(builder *CallBuilder[BalanceReq, hexutil.Big], rpc *EncryptionManager) error { - acctOwner, err := rpc.chain.AccountOwner(*builder.Param.Addr, builder.Param.Block.BlockNumber) + acctOwner, err := rpc.chain.AccountOwner(builder.ctx, *builder.Param.Addr, builder.Param.Block.BlockNumber) if err != nil { return err } @@ -55,7 +55,7 @@ func GetBalanceExecute(builder *CallBuilder[BalanceReq, hexutil.Big], rpc *Encry return nil } - balance, err := rpc.chain.GetBalanceAtBlock(*builder.Param.Addr, builder.Param.Block.BlockNumber) + balance, err := rpc.chain.GetBalanceAtBlock(builder.ctx, *builder.Param.Addr, builder.Param.Block.BlockNumber) if err != nil { return fmt.Errorf("unable to get balance - %w", err) } diff --git a/go/enclave/rpc/GetCustomQuery.go b/go/enclave/rpc/GetCustomQuery.go index 77b21a44fd..a280b4c33d 100644 --- a/go/enclave/rpc/GetCustomQuery.go +++ b/go/enclave/rpc/GetCustomQuery.go @@ -31,12 +31,12 @@ func GetCustomQueryExecute(builder *CallBuilder[common.PrivateCustomQueryListTra return nil //nolint:nilerr } - encryptReceipts, err := rpc.storage.GetReceiptsPerAddress(&builder.Param.Address, &builder.Param.Pagination) + encryptReceipts, err := rpc.storage.GetReceiptsPerAddress(builder.ctx, &builder.Param.Address, &builder.Param.Pagination) if err != nil { return fmt.Errorf("GetReceiptsPerAddress - %w", err) } - receiptsCount, err := rpc.storage.GetReceiptsPerAddressCount(&builder.Param.Address) + receiptsCount, err := rpc.storage.GetReceiptsPerAddressCount(builder.ctx, &builder.Param.Address) if err != nil { return fmt.Errorf("GetReceiptsPerAddressCount - %w", err) } diff --git a/go/enclave/rpc/GetLogs.go b/go/enclave/rpc/GetLogs.go index b4798600e2..3d938883d0 100644 --- a/go/enclave/rpc/GetLogs.go +++ b/go/enclave/rpc/GetLogs.go @@ -51,7 +51,7 @@ func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], from := filter.FromBlock if from != nil && from.Int64() < 0 { - batch, err := rpc.storage.FetchBatchBySeqNo(rpc.registry.HeadBatchSeq().Uint64()) + batch, err := rpc.storage.FetchBatchBySeqNo(builder.ctx, rpc.registry.HeadBatchSeq().Uint64()) if err != nil { // system error return fmt.Errorf("could not retrieve head batch. Cause: %w", err) @@ -61,7 +61,7 @@ func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], // Set from to the height of the block hash if from == nil && filter.BlockHash != nil { - batch, err := rpc.storage.FetchBatchHeader(*filter.BlockHash) + batch, err := rpc.storage.FetchBatchHeader(builder.ctx, *filter.BlockHash) if err != nil { if errors.Is(err, errutil.ErrNotFound) { builder.Status = NotFound @@ -84,7 +84,7 @@ func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], } // We retrieve the relevant logs that match the filter. - filteredLogs, err := rpc.storage.FilterLogs(builder.VK.AccountAddress, from, to, nil, filter.Addresses, filter.Topics) + filteredLogs, err := rpc.storage.FilterLogs(builder.ctx, builder.VK.AccountAddress, from, to, nil, filter.Addresses, filter.Topics) if err != nil { if errors.Is(err, syserr.InternalError{}) { return err diff --git a/go/enclave/rpc/GetTransaction.go b/go/enclave/rpc/GetTransaction.go index a06080985f..6690b99efd 100644 --- a/go/enclave/rpc/GetTransaction.go +++ b/go/enclave/rpc/GetTransaction.go @@ -32,7 +32,7 @@ func GetTransactionValidate(reqParams []any, builder *CallBuilder[gethcommon.Has func GetTransactionExecute(builder *CallBuilder[gethcommon.Hash, RpcTransaction], rpc *EncryptionManager) error { // Unlike in the Geth impl, we do not try and retrieve unconfirmed transactions from the mempool. - tx, blockHash, blockNumber, index, err := rpc.storage.GetTransaction(*builder.Param) + tx, blockHash, blockNumber, index, err := rpc.storage.GetTransaction(builder.ctx, *builder.Param) if err != nil { if errors.Is(err, errutil.ErrNotFound) { builder.Status = NotFound diff --git a/go/enclave/rpc/GetTransactionCount.go b/go/enclave/rpc/GetTransactionCount.go index 80c57e4824..10018d2cd9 100644 --- a/go/enclave/rpc/GetTransactionCount.go +++ b/go/enclave/rpc/GetTransactionCount.go @@ -31,7 +31,7 @@ func GetTransactionCountValidate(reqParams []any, builder *CallBuilder[uint64, s } // todo - support BlockNumberOrHash - b, err := rpc.registry.GetBatchAtHeight(*tag.BlockNumber) + b, err := rpc.registry.GetBatchAtHeight(builder.ctx, *tag.BlockNumber) if err != nil { builder.Err = fmt.Errorf("cant retrieve batch for tag. Cause: %w", err) return nil @@ -52,11 +52,11 @@ func GetTransactionCountExecute(builder *CallBuilder[uint64, string], rpc *Encry } var nonce uint64 - l2Head, err := rpc.storage.FetchBatchBySeqNo(*builder.Param) + l2Head, err := rpc.storage.FetchBatchBySeqNo(builder.ctx, *builder.Param) if err == nil { // todo - we should return an error when head state is not available, but for current test situations with race // conditions we allow it to return zero while head state is uninitialized - s, err := rpc.storage.CreateStateDB(l2Head.Hash()) + s, err := rpc.storage.CreateStateDB(builder.ctx, l2Head.Hash()) if err != nil { return err } diff --git a/go/enclave/rpc/GetTransactionReceipt.go b/go/enclave/rpc/GetTransactionReceipt.go index 8a2342fcfc..6dabe91f3c 100644 --- a/go/enclave/rpc/GetTransactionReceipt.go +++ b/go/enclave/rpc/GetTransactionReceipt.go @@ -39,7 +39,7 @@ func GetTransactionReceiptExecute(builder *CallBuilder[gethcommon.Hash, map[stri // todo - optimise these calls. This can be done with a single sql rpc.logger.Trace("Get receipt for ", log.TxKey, txHash) // We retrieve the transaction. - tx, blockHash, number, txIndex, err := rpc.storage.GetTransaction(txHash) //nolint:dogsled + tx, blockHash, number, txIndex, err := rpc.storage.GetTransaction(builder.ctx, txHash) //nolint:dogsled if err != nil { rpc.logger.Trace("error getting tx ", log.TxKey, txHash, log.ErrKey, err) if errors.Is(err, errutil.ErrNotFound) { @@ -62,7 +62,7 @@ func GetTransactionReceiptExecute(builder *CallBuilder[gethcommon.Hash, map[stri } // We retrieve the transaction receipt. - txReceipt, err := rpc.storage.GetTransactionReceipt(txHash) + txReceipt, err := rpc.storage.GetTransactionReceipt(builder.ctx, txHash) if err != nil { rpc.logger.Trace("error getting tx receipt", log.TxKey, txHash, log.ErrKey, err) if errors.Is(err, errutil.ErrNotFound) { @@ -74,7 +74,7 @@ func GetTransactionReceiptExecute(builder *CallBuilder[gethcommon.Hash, map[stri } // We filter out irrelevant logs. - txReceipt.Logs, err = events.FilterLogsForReceipt(txReceipt, &txSigner, rpc.storage) + txReceipt.Logs, err = events.FilterLogsForReceipt(builder.ctx, txReceipt, &txSigner, rpc.storage) if err != nil { rpc.logger.Error("error filter logs ", log.TxKey, txHash, log.ErrKey, err) // this is a system error diff --git a/go/enclave/rpc/TenEthCall.go b/go/enclave/rpc/TenEthCall.go index a530b8f194..cacaa518de 100644 --- a/go/enclave/rpc/TenEthCall.go +++ b/go/enclave/rpc/TenEthCall.go @@ -49,7 +49,7 @@ func TenCallExecute(builder *CallBuilder[CallParamsWithBlock, string], rpc *Encr apiArgs := builder.Param.callParams blkNumber := builder.Param.block - execResult, err := rpc.chain.ObsCall(apiArgs, blkNumber) + execResult, err := rpc.chain.ObsCall(builder.ctx, apiArgs, blkNumber) if err != nil { rpc.logger.Debug("Failed eth_call.", log.ErrKey, err) diff --git a/go/enclave/rpc/vk_utils.go b/go/enclave/rpc/vk_utils.go index 9d46af4733..ea1229ef28 100644 --- a/go/enclave/rpc/vk_utils.go +++ b/go/enclave/rpc/vk_utils.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "encoding/json" "errors" "fmt" @@ -27,6 +28,7 @@ const ( // CallBuilder - builder used during processing of an RPC request, which is a multi-step process type CallBuilder[P any, R any] struct { + ctx context.Context Param *P // value calculated during phase 1 to be used during the execution phase VK *vkhandler.AuthenticatedViewingKey // the vk accompanying the request From *gethcommon.Address // extracted from the request @@ -44,6 +46,7 @@ type CallBuilder[P any, R any] struct { // e.g. - "getTransaction" or "getBalance" have to perform authorisation // "Ten_call" , "Estimate_Gas" - have to authenticate the "From" - which will be used by the EVM func WithVKEncryption[P any, R any]( + ctx context.Context, encManager *EncryptionManager, encReq []byte, // encrypted request that contains a signed viewing key validate func([]any, *CallBuilder[P, R], *EncryptionManager) error, @@ -71,7 +74,7 @@ func WithVKEncryption[P any, R any]( } // 4. Call the function that knows how to validate the request - builder := &CallBuilder[P, R]{Status: NotSet, VK: vk} + builder := &CallBuilder[P, R]{Status: NotSet, VK: vk, ctx: ctx} err = validate(decodedRequest.Params, builder, encManager) if err != nil { diff --git a/go/enclave/rpc_server.go b/go/enclave/rpc_server.go index a504481a8e..8947add528 100644 --- a/go/enclave/rpc_server.go +++ b/go/enclave/rpc_server.go @@ -63,8 +63,8 @@ func (s *RPCServer) StartServer() error { } // Status returns the current status of the RPCServer as an enum value (see common.Status for details) -func (s *RPCServer) Status(context.Context, *generated.StatusRequest) (*generated.StatusResponse, error) { - status, sysError := s.enclave.Status() +func (s *RPCServer) Status(ctx context.Context, _ *generated.StatusRequest) (*generated.StatusResponse, error) { + status, sysError := s.enclave.Status(ctx) if sysError != nil { s.logger.Error("Enclave error on Status", log.ErrKey, sysError) } @@ -80,8 +80,8 @@ func (s *RPCServer) Status(context.Context, *generated.StatusRequest) (*generate }, nil } -func (s *RPCServer) Attestation(context.Context, *generated.AttestationRequest) (*generated.AttestationResponse, error) { - attestation, sysError := s.enclave.Attestation() +func (s *RPCServer) Attestation(ctx context.Context, _ *generated.AttestationRequest) (*generated.AttestationResponse, error) { + attestation, sysError := s.enclave.Attestation(ctx) if sysError != nil { s.logger.Error("Error getting attestation", log.ErrKey, sysError) return &generated.AttestationResponse{SystemError: toRPCError(sysError)}, nil @@ -90,8 +90,8 @@ func (s *RPCServer) Attestation(context.Context, *generated.AttestationRequest) return &generated.AttestationResponse{AttestationReportMsg: &msg}, nil } -func (s *RPCServer) GenerateSecret(context.Context, *generated.GenerateSecretRequest) (*generated.GenerateSecretResponse, error) { - secret, sysError := s.enclave.GenerateSecret() +func (s *RPCServer) GenerateSecret(ctx context.Context, _ *generated.GenerateSecretRequest) (*generated.GenerateSecretResponse, error) { + secret, sysError := s.enclave.GenerateSecret(ctx) if sysError != nil { s.logger.Error("Error generating secret", log.ErrKey, sysError) return &generated.GenerateSecretResponse{SystemError: toRPCError(sysError)}, nil @@ -99,16 +99,16 @@ func (s *RPCServer) GenerateSecret(context.Context, *generated.GenerateSecretReq return &generated.GenerateSecretResponse{EncryptedSharedEnclaveSecret: secret}, nil } -func (s *RPCServer) InitEnclave(_ context.Context, request *generated.InitEnclaveRequest) (*generated.InitEnclaveResponse, error) { - sysError := s.enclave.InitEnclave(request.EncryptedSharedEnclaveSecret) +func (s *RPCServer) InitEnclave(ctx context.Context, request *generated.InitEnclaveRequest) (*generated.InitEnclaveResponse, error) { + sysError := s.enclave.InitEnclave(ctx, request.EncryptedSharedEnclaveSecret) if sysError != nil { s.logger.Error("Error initialising the enclave", log.ErrKey, sysError) } return &generated.InitEnclaveResponse{SystemError: toRPCError(sysError)}, nil } -func (s *RPCServer) EnclaveID(_ context.Context, _ *generated.EnclaveIDRequest) (*generated.EnclaveIDResponse, error) { - id, sysError := s.enclave.EnclaveID() +func (s *RPCServer) EnclaveID(ctx context.Context, _ *generated.EnclaveIDRequest) (*generated.EnclaveIDResponse, error) { + id, sysError := s.enclave.EnclaveID(ctx) if sysError != nil { s.logger.Error("Error getting enclave ID", log.ErrKey, sysError) return &generated.EnclaveIDResponse{SystemError: toRPCError(sysError)}, nil @@ -116,7 +116,7 @@ func (s *RPCServer) EnclaveID(_ context.Context, _ *generated.EnclaveIDRequest) return &generated.EnclaveIDResponse{EnclaveID: id.Bytes()}, nil } -func (s *RPCServer) SubmitL1Block(_ context.Context, request *generated.SubmitBlockRequest) (*generated.SubmitBlockResponse, error) { +func (s *RPCServer) SubmitL1Block(ctx context.Context, request *generated.SubmitBlockRequest) (*generated.SubmitBlockResponse, error) { bl, err := s.decodeBlock(request.EncodedBlock) if err != nil { s.logger.Error("Error decoding block", log.ErrKey, err) @@ -127,7 +127,7 @@ func (s *RPCServer) SubmitL1Block(_ context.Context, request *generated.SubmitBl s.logger.Error("Error decoding receipts", log.ErrKey, err) return nil, err } - blockSubmissionResponse, err := s.enclave.SubmitL1Block(bl, receipts, request.IsLatest) + blockSubmissionResponse, err := s.enclave.SubmitL1Block(ctx, bl, receipts, request.IsLatest) if err != nil { var rejErr *errutil.BlockRejectError isReject := errors.As(err, &rejErr) @@ -152,8 +152,8 @@ func (s *RPCServer) SubmitL1Block(_ context.Context, request *generated.SubmitBl return &generated.SubmitBlockResponse{BlockSubmissionResponse: msg}, nil } -func (s *RPCServer) SubmitTx(_ context.Context, request *generated.SubmitTxRequest) (*generated.SubmitTxResponse, error) { - enclaveResponse, sysError := s.enclave.SubmitTx(request.EncryptedTx) +func (s *RPCServer) SubmitTx(ctx context.Context, request *generated.SubmitTxRequest) (*generated.SubmitTxResponse, error) { + enclaveResponse, sysError := s.enclave.SubmitTx(ctx, request.EncryptedTx) if sysError != nil { s.logger.Error("Error submitting tx", log.ErrKey, sysError) return &generated.SubmitTxResponse{SystemError: toRPCError(sysError)}, nil @@ -161,17 +161,17 @@ func (s *RPCServer) SubmitTx(_ context.Context, request *generated.SubmitTxReque return &generated.SubmitTxResponse{EncodedEnclaveResponse: enclaveResponse.Encode()}, nil } -func (s *RPCServer) SubmitBatch(_ context.Context, request *generated.SubmitBatchRequest) (*generated.SubmitBatchResponse, error) { +func (s *RPCServer) SubmitBatch(ctx context.Context, request *generated.SubmitBatchRequest) (*generated.SubmitBatchResponse, error) { batch := rpc.FromExtBatchMsg(request.Batch) - sysError := s.enclave.SubmitBatch(batch) + sysError := s.enclave.SubmitBatch(ctx, batch) if sysError != nil { s.logger.Error("Error submitting batch", log.ErrKey, sysError) } return &generated.SubmitBatchResponse{SystemError: toRPCError(sysError)}, nil } -func (s *RPCServer) ObsCall(_ context.Context, request *generated.ObsCallRequest) (*generated.ObsCallResponse, error) { - enclaveResp, sysError := s.enclave.ObsCall(request.EncryptedParams) +func (s *RPCServer) ObsCall(ctx context.Context, request *generated.ObsCallRequest) (*generated.ObsCallResponse, error) { + enclaveResp, sysError := s.enclave.ObsCall(ctx, request.EncryptedParams) if sysError != nil { s.logger.Error("Error calling ObsCall", log.ErrKey, sysError) return &generated.ObsCallResponse{SystemError: toRPCError(sysError)}, nil @@ -179,8 +179,8 @@ func (s *RPCServer) ObsCall(_ context.Context, request *generated.ObsCallRequest return &generated.ObsCallResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) GetTransactionCount(_ context.Context, request *generated.GetTransactionCountRequest) (*generated.GetTransactionCountResponse, error) { - enclaveResp, sysError := s.enclave.GetTransactionCount(request.EncryptedParams) +func (s *RPCServer) GetTransactionCount(ctx context.Context, request *generated.GetTransactionCountRequest) (*generated.GetTransactionCountResponse, error) { + enclaveResp, sysError := s.enclave.GetTransactionCount(ctx, request.EncryptedParams) if sysError != nil { s.logger.Error("Error tx count", log.ErrKey, sysError) return &generated.GetTransactionCountResponse{SystemError: toRPCError(sysError)}, nil @@ -194,8 +194,8 @@ func (s *RPCServer) Stop(context.Context, *generated.StopRequest) (*generated.St return &generated.StopResponse{SystemError: toRPCError(s.enclave.Stop())}, nil } -func (s *RPCServer) GetTransaction(_ context.Context, request *generated.GetTransactionRequest) (*generated.GetTransactionResponse, error) { - enclaveResp, sysError := s.enclave.GetTransaction(request.EncryptedParams) +func (s *RPCServer) GetTransaction(ctx context.Context, request *generated.GetTransactionRequest) (*generated.GetTransactionResponse, error) { + enclaveResp, sysError := s.enclave.GetTransaction(ctx, request.EncryptedParams) if sysError != nil { s.logger.Error("Error get tx", log.ErrKey, sysError) return &generated.GetTransactionResponse{SystemError: toRPCError(sysError)}, nil @@ -203,8 +203,8 @@ func (s *RPCServer) GetTransaction(_ context.Context, request *generated.GetTran return &generated.GetTransactionResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) GetTransactionReceipt(_ context.Context, request *generated.GetTransactionReceiptRequest) (*generated.GetTransactionReceiptResponse, error) { - enclaveResponse, sysError := s.enclave.GetTransactionReceipt(request.EncryptedParams) +func (s *RPCServer) GetTransactionReceipt(ctx context.Context, request *generated.GetTransactionReceiptRequest) (*generated.GetTransactionReceiptResponse, error) { + enclaveResponse, sysError := s.enclave.GetTransactionReceipt(ctx, request.EncryptedParams) if sysError != nil { s.logger.Error("Error getting tx receipt", log.ErrKey, sysError) return &generated.GetTransactionReceiptResponse{SystemError: toRPCError(sysError)}, nil @@ -212,8 +212,8 @@ func (s *RPCServer) GetTransactionReceipt(_ context.Context, request *generated. return &generated.GetTransactionReceiptResponse{EncodedEnclaveResponse: enclaveResponse.Encode()}, nil } -func (s *RPCServer) GetBalance(_ context.Context, request *generated.GetBalanceRequest) (*generated.GetBalanceResponse, error) { - enclaveResp, sysError := s.enclave.GetBalance(request.EncryptedParams) +func (s *RPCServer) GetBalance(ctx context.Context, request *generated.GetBalanceRequest) (*generated.GetBalanceResponse, error) { + enclaveResp, sysError := s.enclave.GetBalance(ctx, request.EncryptedParams) if sysError != nil { s.logger.Error("Error getting balance", log.ErrKey, sysError) return &generated.GetBalanceResponse{SystemError: toRPCError(sysError)}, nil @@ -221,11 +221,11 @@ func (s *RPCServer) GetBalance(_ context.Context, request *generated.GetBalanceR return &generated.GetBalanceResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) GetCode(_ context.Context, request *generated.GetCodeRequest) (*generated.GetCodeResponse, error) { +func (s *RPCServer) GetCode(ctx context.Context, request *generated.GetCodeRequest) (*generated.GetCodeResponse, error) { address := gethcommon.BytesToAddress(request.Address) rollupHash := gethcommon.BytesToHash(request.RollupHash) - code, sysError := s.enclave.GetCode(address, &rollupHash) + code, sysError := s.enclave.GetCode(ctx, address, &rollupHash) if sysError != nil { s.logger.Error("Error getting code", log.ErrKey, sysError) return &generated.GetCodeResponse{SystemError: toRPCError(sysError)}, nil @@ -233,8 +233,8 @@ func (s *RPCServer) GetCode(_ context.Context, request *generated.GetCodeRequest return &generated.GetCodeResponse{Code: code}, nil } -func (s *RPCServer) Subscribe(_ context.Context, req *generated.SubscribeRequest) (*generated.SubscribeResponse, error) { - sysError := s.enclave.Subscribe(gethrpc.ID(req.Id), req.EncryptedSubscription) +func (s *RPCServer) Subscribe(ctx context.Context, req *generated.SubscribeRequest) (*generated.SubscribeResponse, error) { + sysError := s.enclave.Subscribe(ctx, gethrpc.ID(req.Id), req.EncryptedSubscription) if sysError != nil { s.logger.Error("Error subscribing", log.ErrKey, sysError) } @@ -249,8 +249,8 @@ func (s *RPCServer) Unsubscribe(_ context.Context, req *generated.UnsubscribeReq return &generated.UnsubscribeResponse{SystemError: toRPCError(sysError)}, nil } -func (s *RPCServer) EstimateGas(_ context.Context, req *generated.EstimateGasRequest) (*generated.EstimateGasResponse, error) { - enclaveResp, sysError := s.enclave.EstimateGas(req.EncryptedParams) +func (s *RPCServer) EstimateGas(ctx context.Context, req *generated.EstimateGasRequest) (*generated.EstimateGasResponse, error) { + enclaveResp, sysError := s.enclave.EstimateGas(ctx, req.EncryptedParams) if sysError != nil { s.logger.Error("Error estimating gas", log.ErrKey, sysError) return &generated.EstimateGasResponse{SystemError: toRPCError(sysError)}, nil @@ -258,8 +258,8 @@ func (s *RPCServer) EstimateGas(_ context.Context, req *generated.EstimateGasReq return &generated.EstimateGasResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) GetLogs(_ context.Context, req *generated.GetLogsRequest) (*generated.GetLogsResponse, error) { - enclaveResp, sysError := s.enclave.GetLogs(req.EncryptedParams) +func (s *RPCServer) GetLogs(ctx context.Context, req *generated.GetLogsRequest) (*generated.GetLogsResponse, error) { + enclaveResp, sysError := s.enclave.GetLogs(ctx, req.EncryptedParams) if sysError != nil { s.logger.Error("Error getting logs", log.ErrKey, sysError) return &generated.GetLogsResponse{SystemError: toRPCError(sysError)}, nil @@ -267,21 +267,21 @@ func (s *RPCServer) GetLogs(_ context.Context, req *generated.GetLogsRequest) (* return &generated.GetLogsResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) HealthCheck(_ context.Context, _ *generated.EmptyArgs) (*generated.HealthCheckResponse, error) { - healthy, sysError := s.enclave.HealthCheck() +func (s *RPCServer) HealthCheck(ctx context.Context, _ *generated.EmptyArgs) (*generated.HealthCheckResponse, error) { + healthy, sysError := s.enclave.HealthCheck(ctx) if sysError != nil { return &generated.HealthCheckResponse{SystemError: toRPCError(sysError)}, nil } return &generated.HealthCheckResponse{Status: healthy}, nil } -func (s *RPCServer) CreateRollup(_ context.Context, req *generated.CreateRollupRequest) (*generated.CreateRollupResponse, error) { +func (s *RPCServer) CreateRollup(ctx context.Context, req *generated.CreateRollupRequest) (*generated.CreateRollupResponse, error) { var fromSeqNo uint64 = 1 if req.FromSequenceNumber != nil && *req.FromSequenceNumber > common.L2GenesisSeqNo { fromSeqNo = *req.FromSequenceNumber } - rollup, sysError := s.enclave.CreateRollup(fromSeqNo) + rollup, sysError := s.enclave.CreateRollup(ctx, fromSeqNo) if sysError != nil { s.logger.Error("Error creating rollup", log.ErrKey, sysError) } @@ -294,15 +294,15 @@ func (s *RPCServer) CreateRollup(_ context.Context, req *generated.CreateRollupR }, nil } -func (s *RPCServer) CreateBatch(_ context.Context, r *generated.CreateBatchRequest) (*generated.CreateBatchResponse, error) { - sysError := s.enclave.CreateBatch(r.SkipIfEmpty) +func (s *RPCServer) CreateBatch(ctx context.Context, r *generated.CreateBatchRequest) (*generated.CreateBatchResponse, error) { + sysError := s.enclave.CreateBatch(ctx, r.SkipIfEmpty) if sysError != nil { s.logger.Error("Error creating batch", log.ErrKey, sysError) } return &generated.CreateBatchResponse{}, sysError } -func (s *RPCServer) DebugTraceTransaction(_ context.Context, req *generated.DebugTraceTransactionRequest) (*generated.DebugTraceTransactionResponse, error) { +func (s *RPCServer) DebugTraceTransaction(ctx context.Context, req *generated.DebugTraceTransactionRequest) (*generated.DebugTraceTransactionResponse, error) { txHash := gethcommon.BytesToHash(req.TxHash) var config tracers.TraceConfig @@ -315,12 +315,12 @@ func (s *RPCServer) DebugTraceTransaction(_ context.Context, req *generated.Debu }, nil } - traceTx, sysError := s.enclave.DebugTraceTransaction(txHash, &config) + traceTx, sysError := s.enclave.DebugTraceTransaction(ctx, txHash, &config) return &generated.DebugTraceTransactionResponse{Msg: string(traceTx), SystemError: toRPCError(sysError)}, nil } -func (s *RPCServer) GetBatch(_ context.Context, request *generated.GetBatchRequest) (*generated.GetBatchResponse, error) { - batch, err := s.enclave.GetBatch(gethcommon.BytesToHash(request.KnownHead)) +func (s *RPCServer) GetBatch(ctx context.Context, request *generated.GetBatchRequest) (*generated.GetBatchResponse, error) { + batch, err := s.enclave.GetBatch(ctx, gethcommon.BytesToHash(request.KnownHead)) if err != nil { s.logger.Error("Error getting batch", log.ErrKey, err) // todo do we want to exit here or return the usual response @@ -341,8 +341,8 @@ func (s *RPCServer) GetBatch(_ context.Context, request *generated.GetBatchReque }, err } -func (s *RPCServer) GetBatchBySeqNo(_ context.Context, request *generated.GetBatchBySeqNoRequest) (*generated.GetBatchResponse, error) { - batch, err := s.enclave.GetBatchBySeqNo(request.SeqNo) +func (s *RPCServer) GetBatchBySeqNo(ctx context.Context, request *generated.GetBatchBySeqNoRequest) (*generated.GetBatchResponse, error) { + batch, err := s.enclave.GetBatchBySeqNo(ctx, request.SeqNo) if err != nil { s.logger.Error("Error getting batch by seq", log.ErrKey, err) // todo do we want to exit here or return the usual response @@ -363,8 +363,8 @@ func (s *RPCServer) GetBatchBySeqNo(_ context.Context, request *generated.GetBat }, err } -func (s *RPCServer) GetRollupData(_ context.Context, request *generated.GetRollupDataRequest) (*generated.GetRollupDataResponse, error) { - rollupMetadata, sysError := s.enclave.GetRollupData(gethcommon.BytesToHash(request.Hash)) +func (s *RPCServer) GetRollupData(ctx context.Context, request *generated.GetRollupDataRequest) (*generated.GetRollupDataResponse, error) { + rollupMetadata, sysError := s.enclave.GetRollupData(ctx, gethcommon.BytesToHash(request.Hash)) if sysError != nil { s.logger.Error("Error fetching rollup metadata", log.ErrKey, sysError) return nil, sysError @@ -406,10 +406,10 @@ func (s *RPCServer) StreamL2Updates(_ *generated.StreamL2UpdatesRequest, stream return nil } -func (s *RPCServer) DebugEventLogRelevancy(_ context.Context, req *generated.DebugEventLogRelevancyRequest) (*generated.DebugEventLogRelevancyResponse, error) { +func (s *RPCServer) DebugEventLogRelevancy(ctx context.Context, req *generated.DebugEventLogRelevancyRequest) (*generated.DebugEventLogRelevancyResponse, error) { txHash := gethcommon.BytesToHash(req.TxHash) - logs, sysError := s.enclave.DebugEventLogRelevancy(txHash) + logs, sysError := s.enclave.DebugEventLogRelevancy(ctx, txHash) if sysError != nil { s.logger.Error("Error debugging event relevancy", log.ErrKey, sysError) } @@ -417,8 +417,8 @@ func (s *RPCServer) DebugEventLogRelevancy(_ context.Context, req *generated.Deb return &generated.DebugEventLogRelevancyResponse{Msg: string(logs), SystemError: toRPCError(sysError)}, nil } -func (s *RPCServer) GetTotalContractCount(_ context.Context, _ *generated.GetTotalContractCountRequest) (*generated.GetTotalContractCountResponse, error) { - count, sysError := s.enclave.GetTotalContractCount() +func (s *RPCServer) GetTotalContractCount(ctx context.Context, _ *generated.GetTotalContractCountRequest) (*generated.GetTotalContractCountResponse, error) { + count, sysError := s.enclave.GetTotalContractCount(ctx) if sysError != nil { s.logger.Error("Error GetTotalContractCount", log.ErrKey, sysError) } @@ -433,8 +433,8 @@ func (s *RPCServer) GetTotalContractCount(_ context.Context, _ *generated.GetTot }, nil } -func (s *RPCServer) GetReceiptsByAddress(_ context.Context, req *generated.GetReceiptsByAddressRequest) (*generated.GetReceiptsByAddressResponse, error) { - enclaveResp, sysError := s.enclave.GetCustomQuery(req.EncryptedParams) +func (s *RPCServer) GetReceiptsByAddress(ctx context.Context, req *generated.GetReceiptsByAddressRequest) (*generated.GetReceiptsByAddressResponse, error) { + enclaveResp, sysError := s.enclave.GetCustomQuery(ctx, req.EncryptedParams) if sysError != nil { s.logger.Error("Error getting receipt", log.ErrKey, sysError) return &generated.GetReceiptsByAddressResponse{SystemError: toRPCError(sysError)}, nil @@ -442,8 +442,8 @@ func (s *RPCServer) GetReceiptsByAddress(_ context.Context, req *generated.GetRe return &generated.GetReceiptsByAddressResponse{EncodedEnclaveResponse: enclaveResp.Encode()}, nil } -func (s *RPCServer) GetPublicTransactionData(_ context.Context, req *generated.GetPublicTransactionDataRequest) (*generated.GetPublicTransactionDataResponse, error) { - publicTxData, sysError := s.enclave.GetPublicTransactionData(&common.QueryPagination{ +func (s *RPCServer) GetPublicTransactionData(ctx context.Context, req *generated.GetPublicTransactionDataRequest) (*generated.GetPublicTransactionDataResponse, error) { + publicTxData, sysError := s.enclave.GetPublicTransactionData(ctx, &common.QueryPagination{ Offset: uint64(req.Pagination.GetOffset()), Size: uint(req.Pagination.GetSize()), }) @@ -462,8 +462,8 @@ func (s *RPCServer) GetPublicTransactionData(_ context.Context, req *generated.G return &generated.GetPublicTransactionDataResponse{PublicTransactionData: marshal}, nil } -func (s *RPCServer) EnclavePublicConfig(_ context.Context, _ *generated.EnclavePublicConfigRequest) (*generated.EnclavePublicConfigResponse, error) { - enclaveCfg, sysError := s.enclave.EnclavePublicConfig() +func (s *RPCServer) EnclavePublicConfig(ctx context.Context, _ *generated.EnclavePublicConfigRequest) (*generated.EnclavePublicConfigResponse, error) { + enclaveCfg, sysError := s.enclave.EnclavePublicConfig(ctx) if sysError != nil { s.logger.Error("Error getting message bus address", log.ErrKey, sysError) return &generated.EnclavePublicConfigResponse{SystemError: toRPCError(sysError)}, nil diff --git a/go/enclave/storage/db_init.go b/go/enclave/storage/db_init.go index bde5f2e063..853143272d 100644 --- a/go/enclave/storage/db_init.go +++ b/go/enclave/storage/db_init.go @@ -21,7 +21,7 @@ func CreateDBFromConfig(cfg *config.EnclaveConfig, logger gethlog.Logger) (encla if cfg.UseInMemoryDB { logger.Info("UseInMemoryDB flag is true, data will not be persisted. Creating in-memory database...") // this creates a temporary sqlite sqldb - return sqlite.CreateTemporarySQLiteDB(cfg.HostID.String(), "mode=memory&cache=shared&_foreign_keys=on", logger) + return sqlite.CreateTemporarySQLiteDB(cfg.HostID.String(), "mode=memory&cache=shared&_foreign_keys=on", *cfg, logger) } if !cfg.WillAttest { @@ -29,7 +29,7 @@ func CreateDBFromConfig(cfg *config.EnclaveConfig, logger gethlog.Logger) (encla logger.Warn("Attestation is disabled, using a basic sqlite DB for persistence") // when we want to test persistence after node restart the SqliteDBPath should be set // (if empty string then a temp sqldb file will be created for the lifetime of the enclave) - return sqlite.CreateTemporarySQLiteDB(cfg.SqliteDBPath, "_foreign_keys=on", logger) + return sqlite.CreateTemporarySQLiteDB(cfg.SqliteDBPath, "_foreign_keys=on", *cfg, logger) } // persistent and with attestation means connecting to edgeless DB in a trusted enclave from a secure enclave @@ -62,5 +62,5 @@ func getEdgelessDB(cfg *config.EnclaveConfig, logger gethlog.Logger) (enclavedb. return nil, fmt.Errorf("failed to prepare EdgelessDB connection - EdgelessDBHost was not set on enclave config") } dbConfig := edgelessdb.Config{Host: cfg.EdgelessDBHost} - return edgelessdb.Connector(&dbConfig, logger) + return edgelessdb.Connector(&dbConfig, *cfg, logger) } diff --git a/go/enclave/storage/enclavedb/batch.go b/go/enclave/storage/enclavedb/batch.go index e0ce4cac1b..a93938db47 100644 --- a/go/enclave/storage/enclavedb/batch.go +++ b/go/enclave/storage/enclavedb/batch.go @@ -2,6 +2,7 @@ package enclavedb import ( "bytes" + "context" "crypto/sha256" "database/sql" "errors" @@ -49,7 +50,7 @@ const ( ) // WriteBatchAndTransactions - persists the batch and the transactions -func WriteBatchAndTransactions(dbtx DBTransaction, batch *core.Batch, convertedHash gethcommon.Hash) error { +func WriteBatchAndTransactions(ctx context.Context, dbtx DBTransaction, batch *core.Batch, convertedHash gethcommon.Hash) error { // todo - optimize for reorgs batchBodyID := batch.SeqNo().Uint64() @@ -70,7 +71,7 @@ func WriteBatchAndTransactions(dbtx DBTransaction, batch *core.Batch, convertedH } var isCanon bool - err = dbtx.GetDB().QueryRow(isCanonQuery, truncTo16(batch.Header.L1Proof)).Scan(&isCanon) + err = dbtx.GetDB().QueryRowContext(ctx, isCanonQuery, truncTo16(batch.Header.L1Proof)).Scan(&isCanon) if err != nil { // if the block is not found, we assume it is non-canonical // fmt.Printf("IsCanon %s err: %s\n", batch.Header.L1Proof, err) @@ -123,7 +124,7 @@ func WriteBatchAndTransactions(dbtx DBTransaction, batch *core.Batch, convertedH } // WriteBatchExecution - insert all receipts to the db -func WriteBatchExecution(dbtx DBTransaction, seqNo *big.Int, receipts []*types.Receipt) error { +func WriteBatchExecution(ctx context.Context, dbtx DBTransaction, seqNo *big.Int, receipts []*types.Receipt) error { dbtx.ExecuteSQL(updateBatchExecuted, seqNo.Uint64()) args := make([]any, 0) @@ -157,39 +158,39 @@ func executedTransactionID(batchHash *common.L2BatchHash, txHash *common.L2TxHas return truncTo16(sha256.Sum256(execTxID)) } -func ReadBatchBySeqNo(db *sql.DB, seqNo uint64) (*core.Batch, error) { - return fetchBatch(db, " where sequence=?", seqNo) +func ReadBatchBySeqNo(ctx context.Context, db *sql.DB, seqNo uint64) (*core.Batch, error) { + return fetchBatch(ctx, db, " where sequence=?", seqNo) } -func ReadBatchByHash(db *sql.DB, hash common.L2BatchHash) (*core.Batch, error) { - return fetchBatch(db, " where b.hash=?", truncTo16(hash)) +func ReadBatchByHash(ctx context.Context, db *sql.DB, hash common.L2BatchHash) (*core.Batch, error) { + return fetchBatch(ctx, db, " where b.hash=?", truncTo16(hash)) } -func ReadCanonicalBatchByHeight(db *sql.DB, height uint64) (*core.Batch, error) { - return fetchBatch(db, " where b.height=? and is_canonical=true", height) +func ReadCanonicalBatchByHeight(ctx context.Context, db *sql.DB, height uint64) (*core.Batch, error) { + return fetchBatch(ctx, db, " where b.height=? and is_canonical=true", height) } -func ReadNonCanonicalBatches(db *sql.DB, startAtSeq uint64, endSeq uint64) ([]*core.Batch, error) { - return fetchBatches(db, " where b.sequence>=? and b.sequence <=? and b.is_canonical=false order by b.sequence", startAtSeq, endSeq) +func ReadNonCanonicalBatches(ctx context.Context, db *sql.DB, startAtSeq uint64, endSeq uint64) ([]*core.Batch, error) { + return fetchBatches(ctx, db, " where b.sequence>=? and b.sequence <=? and b.is_canonical=false order by b.sequence", startAtSeq, endSeq) } -func ReadBatchHeader(db *sql.DB, hash gethcommon.Hash) (*common.BatchHeader, error) { - return fetchBatchHeader(db, " where hash=?", truncTo16(hash)) +func ReadBatchHeader(ctx context.Context, db *sql.DB, hash gethcommon.Hash) (*common.BatchHeader, error) { + return fetchBatchHeader(ctx, db, " where hash=?", truncTo16(hash)) } // todo - is there a better way to write this query? -func ReadCurrentHeadBatch(db *sql.DB) (*core.Batch, error) { - return fetchBatch(db, " where b.is_canonical=true and b.is_executed=true and b.height=(select max(b1.height) from batch b1 where b1.is_canonical=true and b1.is_executed=true)") +func ReadCurrentHeadBatch(ctx context.Context, db *sql.DB) (*core.Batch, error) { + return fetchBatch(ctx, db, " where b.is_canonical=true and b.is_executed=true and b.height=(select max(b1.height) from batch b1 where b1.is_canonical=true and b1.is_executed=true)") } -func ReadBatchesByBlock(db *sql.DB, hash common.L1BlockHash) ([]*core.Batch, error) { - return fetchBatches(db, " where b.l1_proof=? order by b.sequence", truncTo16(hash)) +func ReadBatchesByBlock(ctx context.Context, db *sql.DB, hash common.L1BlockHash) ([]*core.Batch, error) { + return fetchBatches(ctx, db, " where b.l1_proof=? order by b.sequence", truncTo16(hash)) } -func ReadCurrentSequencerNo(db *sql.DB) (*big.Int, error) { +func ReadCurrentSequencerNo(ctx context.Context, db *sql.DB) (*big.Int, error) { var seq sql.NullInt64 query := "select max(sequence) from batch" - err := db.QueryRow(query).Scan(&seq) + err := db.QueryRowContext(ctx, query).Scan(&seq) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -203,20 +204,20 @@ func ReadCurrentSequencerNo(db *sql.DB) (*big.Int, error) { return big.NewInt(seq.Int64), nil } -func ReadHeadBatchForBlock(db *sql.DB, l1Hash common.L1BlockHash) (*core.Batch, error) { +func ReadHeadBatchForBlock(ctx context.Context, db *sql.DB, l1Hash common.L1BlockHash) (*core.Batch, error) { query := " where b.is_canonical=true and b.is_executed=true and b.height=(select max(b1.height) from batch b1 where b1.is_canonical=true and b1.is_executed=true and b1.l1_proof=?)" - return fetchBatch(db, query, truncTo16(l1Hash)) + return fetchBatch(ctx, db, query, truncTo16(l1Hash)) } -func fetchBatch(db *sql.DB, whereQuery string, args ...any) (*core.Batch, error) { +func fetchBatch(ctx context.Context, db *sql.DB, whereQuery string, args ...any) (*core.Batch, error) { var header string var body []byte query := selectBatch + " " + whereQuery var err error if len(args) > 0 { - err = db.QueryRow(query, args...).Scan(&header, &body) + err = db.QueryRowContext(ctx, query, args...).Scan(&header, &body) } else { - err = db.QueryRow(query).Scan(&header, &body) + err = db.QueryRowContext(ctx, query).Scan(&header, &body) } if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -242,10 +243,10 @@ func fetchBatch(db *sql.DB, whereQuery string, args ...any) (*core.Batch, error) return &b, nil } -func fetchBatches(db *sql.DB, whereQuery string, args ...any) ([]*core.Batch, error) { +func fetchBatches(ctx context.Context, db *sql.DB, whereQuery string, args ...any) ([]*core.Batch, error) { result := make([]*core.Batch, 0) - rows, err := db.Query(selectBatch+" "+whereQuery, args...) + rows, err := db.QueryContext(ctx, selectBatch+" "+whereQuery, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -282,14 +283,14 @@ func fetchBatches(db *sql.DB, whereQuery string, args ...any) ([]*core.Batch, er return result, nil } -func fetchBatchHeader(db *sql.DB, whereQuery string, args ...any) (*common.BatchHeader, error) { +func fetchBatchHeader(ctx context.Context, db *sql.DB, whereQuery string, args ...any) (*common.BatchHeader, error) { var header string query := selectHeader + " " + whereQuery var err error if len(args) > 0 { - err = db.QueryRow(query, args...).Scan(&header) + err = db.QueryRowContext(ctx, query, args...).Scan(&header) } else { - err = db.QueryRow(query).Scan(&header) + err = db.QueryRowContext(ctx, query).Scan(&header) } if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -306,11 +307,11 @@ func fetchBatchHeader(db *sql.DB, whereQuery string, args ...any) (*common.Batch return h, nil } -func selectReceipts(db *sql.DB, config *params.ChainConfig, query string, args ...any) (types.Receipts, error) { +func selectReceipts(ctx context.Context, db *sql.DB, config *params.ChainConfig, query string, args ...any) (types.Receipts, error) { var allReceipts types.Receipts // where batch=? - rows, err := db.Query(queryReceipts+" "+query, args...) + rows, err := db.QueryContext(ctx, queryReceipts+" "+query, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -362,12 +363,12 @@ func selectReceipts(db *sql.DB, config *params.ChainConfig, query string, args . // The current implementation populates these metadata fields by reading the receipts' // corresponding block body, so if the block body is not found it will return nil even // if the receipt itself is stored. -func ReadReceiptsByBatchHash(db *sql.DB, hash common.L2BatchHash, config *params.ChainConfig) (types.Receipts, error) { - return selectReceipts(db, config, "where batch.hash = ?", truncTo16(hash)) +func ReadReceiptsByBatchHash(ctx context.Context, db *sql.DB, hash common.L2BatchHash, config *params.ChainConfig) (types.Receipts, error) { + return selectReceipts(ctx, db, config, "where batch.hash = ?", truncTo16(hash)) } -func ReadReceipt(db *sql.DB, hash common.L2TxHash, config *params.ChainConfig) (*types.Receipt, error) { - row := db.QueryRow(queryReceipts+" where tx=?", truncTo16(hash)) +func ReadReceipt(ctx context.Context, db *sql.DB, hash common.L2TxHash, config *params.ChainConfig) (*types.Receipt, error) { + row := db.QueryRowContext(ctx, queryReceipts+" where tx=?", truncTo16(hash)) // receipt, tx, batch, height var receiptData []byte var txData []byte @@ -402,8 +403,8 @@ func ReadReceipt(db *sql.DB, hash common.L2TxHash, config *params.ChainConfig) ( return receipts[0], nil } -func ReadTransaction(db *sql.DB, txHash gethcommon.Hash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) { - row := db.QueryRow(selectTxQuery, truncTo16(txHash)) +func ReadTransaction(ctx context.Context, db *sql.DB, txHash gethcommon.Hash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) { + row := db.QueryRowContext(ctx, selectTxQuery, truncTo16(txHash)) // tx, batch, height, idx var txData []byte @@ -427,8 +428,8 @@ func ReadTransaction(db *sql.DB, txHash gethcommon.Hash) (*types.Transaction, co return tx, batch, height, idx, nil } -func GetContractCreationTx(db *sql.DB, address gethcommon.Address) (*gethcommon.Hash, error) { - row := db.QueryRow(selectContractCreationTx, address.Bytes()) +func GetContractCreationTx(ctx context.Context, db *sql.DB, address gethcommon.Address) (*gethcommon.Hash, error) { + row := db.QueryRowContext(ctx, selectContractCreationTx, address.Bytes()) var txHashBytes []byte err := row.Scan(&txHashBytes) @@ -444,8 +445,8 @@ func GetContractCreationTx(db *sql.DB, address gethcommon.Address) (*gethcommon. return &txHash, nil } -func ReadContractCreationCount(db *sql.DB) (*big.Int, error) { - row := db.QueryRow(selectTotalCreatedContracts) +func ReadContractCreationCount(ctx context.Context, db *sql.DB) (*big.Int, error) { + row := db.QueryRowContext(ctx, selectTotalCreatedContracts) var count int64 err := row.Scan(&count) @@ -456,12 +457,12 @@ func ReadContractCreationCount(db *sql.DB) (*big.Int, error) { return big.NewInt(count), nil } -func ReadUnexecutedBatches(db *sql.DB, from *big.Int) ([]*core.Batch, error) { - return fetchBatches(db, "where is_executed=false and is_canonical=true and sequence >= ? order by b.sequence", from.Uint64()) +func ReadUnexecutedBatches(ctx context.Context, db *sql.DB, from *big.Int) ([]*core.Batch, error) { + return fetchBatches(ctx, db, "where is_executed=false and is_canonical=true and sequence >= ? order by b.sequence", from.Uint64()) } -func BatchWasExecuted(db *sql.DB, hash common.L2BatchHash) (bool, error) { - row := db.QueryRow(queryBatchWasExecuted, truncTo16(hash)) +func BatchWasExecuted(ctx context.Context, db *sql.DB, hash common.L2BatchHash) (bool, error) { + row := db.QueryRowContext(ctx, queryBatchWasExecuted, truncTo16(hash)) var result bool err := row.Scan(&result) @@ -476,12 +477,12 @@ func BatchWasExecuted(db *sql.DB, hash common.L2BatchHash) (bool, error) { return result, nil } -func GetReceiptsPerAddress(db *sql.DB, config *params.ChainConfig, address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) { - return selectReceipts(db, config, "where tx.sender_address = ? ORDER BY height DESC LIMIT ? OFFSET ? ", address.Bytes(), pagination.Size, pagination.Offset) +func GetReceiptsPerAddress(ctx context.Context, db *sql.DB, config *params.ChainConfig, address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) { + return selectReceipts(ctx, db, config, "where tx.sender_address = ? ORDER BY height DESC LIMIT ? OFFSET ? ", address.Bytes(), pagination.Size, pagination.Offset) } -func GetReceiptsPerAddressCount(db *sql.DB, address *gethcommon.Address) (uint64, error) { - row := db.QueryRow(queryReceiptsCount+" where tx.sender_address = ?", address.Bytes()) +func GetReceiptsPerAddressCount(ctx context.Context, db *sql.DB, address *gethcommon.Address) (uint64, error) { + row := db.QueryRowContext(ctx, queryReceiptsCount+" where tx.sender_address = ?", address.Bytes()) var count uint64 err := row.Scan(&count) @@ -492,14 +493,14 @@ func GetReceiptsPerAddressCount(db *sql.DB, address *gethcommon.Address) (uint64 return count, nil } -func GetPublicTransactionData(db *sql.DB, pagination *common.QueryPagination) ([]common.PublicTransaction, error) { - return selectPublicTxsBySender(db, " ORDER BY height DESC LIMIT ? OFFSET ? ", pagination.Size, pagination.Offset) +func GetPublicTransactionData(ctx context.Context, db *sql.DB, pagination *common.QueryPagination) ([]common.PublicTransaction, error) { + return selectPublicTxsBySender(ctx, db, " ORDER BY height DESC LIMIT ? OFFSET ? ", pagination.Size, pagination.Offset) } -func selectPublicTxsBySender(db *sql.DB, query string, args ...any) ([]common.PublicTransaction, error) { +func selectPublicTxsBySender(ctx context.Context, db *sql.DB, query string, args ...any) ([]common.PublicTransaction, error) { var publicTxs []common.PublicTransaction - rows, err := db.Query(queryTxList+" "+query, args...) + rows, err := db.QueryContext(ctx, queryTxList+" "+query, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -536,8 +537,8 @@ func selectPublicTxsBySender(db *sql.DB, query string, args ...any) ([]common.Pu return publicTxs, nil } -func GetPublicTransactionCount(db *sql.DB) (uint64, error) { - row := db.QueryRow(queryTxCountList) +func GetPublicTransactionCount(ctx context.Context, db *sql.DB) (uint64, error) { + row := db.QueryRowContext(ctx, queryTxCountList) var count uint64 err := row.Scan(&count) @@ -548,11 +549,11 @@ func GetPublicTransactionCount(db *sql.DB) (uint64, error) { return count, nil } -func FetchConvertedBatchHash(db *sql.DB, seqNo uint64) (gethcommon.Hash, error) { +func FetchConvertedBatchHash(ctx context.Context, db *sql.DB, seqNo uint64) (gethcommon.Hash, error) { var hash []byte query := "select converted_hash from batch where sequence=?" - err := db.QueryRow(query, seqNo).Scan(&hash) + err := db.QueryRowContext(ctx, query, seqNo).Scan(&hash) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error diff --git a/go/enclave/storage/enclavedb/block.go b/go/enclave/storage/enclavedb/block.go index 2832ddf985..e4b605ab65 100644 --- a/go/enclave/storage/enclavedb/block.go +++ b/go/enclave/storage/enclavedb/block.go @@ -2,6 +2,7 @@ package enclavedb import ( "bytes" + "context" "database/sql" "errors" "fmt" @@ -31,7 +32,7 @@ const ( updateCanonicalBatches = "update batch set is_canonical=? where l1_proof in " ) -func WriteBlock(dbtx DBTransaction, b *types.Header) error { +func WriteBlock(ctx context.Context, dbtx DBTransaction, b *types.Header) error { header, err := rlp.EncodeToBytes(b) if err != nil { return fmt.Errorf("could not encode block header. Cause: %w", err) @@ -51,16 +52,16 @@ func WriteBlock(dbtx DBTransaction, b *types.Header) error { return nil } -func UpdateCanonicalBlocks(dbtx DBTransaction, canonical []common.L1BlockHash, nonCanonical []common.L1BlockHash) { +func UpdateCanonicalBlocks(ctx context.Context, dbtx DBTransaction, canonical []common.L1BlockHash, nonCanonical []common.L1BlockHash) { if len(nonCanonical) > 0 { - updateCanonicalValue(dbtx, false, nonCanonical) + updateCanonicalValue(ctx, dbtx, false, nonCanonical) } if len(canonical) > 0 { - updateCanonicalValue(dbtx, true, canonical) + updateCanonicalValue(ctx, dbtx, true, canonical) } } -func updateCanonicalValue(dbtx DBTransaction, isCanonical bool, values []common.L1BlockHash) { +func updateCanonicalValue(ctx context.Context, dbtx DBTransaction, isCanonical bool, values []common.L1BlockHash) { argPlaceholders := strings.Repeat("?,", len(values)) argPlaceholders = argPlaceholders[0 : len(argPlaceholders)-1] // remove trailing comma @@ -77,19 +78,19 @@ func updateCanonicalValue(dbtx DBTransaction, isCanonical bool, values []common. } // todo - remove this. For now creates a "block" but without a body. -func FetchBlock(db *sql.DB, hash common.L1BlockHash) (*types.Block, error) { - return fetchBlock(db, " where hash=?", truncTo16(hash)) +func FetchBlock(ctx context.Context, db *sql.DB, hash common.L1BlockHash) (*types.Block, error) { + return fetchBlock(ctx, db, " where hash=?", truncTo16(hash)) } -func FetchHeadBlock(db *sql.DB) (*types.Block, error) { - return fetchBlock(db, "where is_canonical=true and height=(select max(b.height) from block b where is_canonical=true)") +func FetchHeadBlock(ctx context.Context, db *sql.DB) (*types.Block, error) { + return fetchBlock(ctx, db, "where is_canonical=true and height=(select max(b.height) from block b where is_canonical=true)") } -func FetchBlockHeaderByHeight(db *sql.DB, height *big.Int) (*types.Header, error) { - return fetchBlockHeader(db, "where is_canonical=true and height=?", height.Int64()) +func FetchBlockHeaderByHeight(ctx context.Context, db *sql.DB, height *big.Int) (*types.Header, error) { + return fetchBlockHeader(ctx, db, "where is_canonical=true and height=?", height.Int64()) } -func WriteL1Messages[T any](db *sql.DB, blockHash common.L1BlockHash, messages []T, isValueTransfer bool) error { +func WriteL1Messages[T any](ctx context.Context, db *sql.DB, blockHash common.L1BlockHash, messages []T, isValueTransfer bool) error { insert := l1msgInsert + strings.Repeat(l1msgValue+",", len(messages)) insert = insert[0 : len(insert)-1] // remove trailing comma @@ -105,16 +106,16 @@ func WriteL1Messages[T any](db *sql.DB, blockHash common.L1BlockHash, messages [ args = append(args, isValueTransfer) } if len(messages) > 0 { - _, err := db.Exec(insert, args...) + _, err := db.ExecContext(ctx, insert, args...) return err } return nil } -func FetchL1Messages[T any](db *sql.DB, blockHash common.L1BlockHash, isTransfer bool) ([]T, error) { +func FetchL1Messages[T any](ctx context.Context, db *sql.DB, blockHash common.L1BlockHash, isTransfer bool) ([]T, error) { var result []T query := selectL1Msg + " where block = ? and is_transfer = ?" - rows, err := db.Query(query, truncTo16(blockHash), isTransfer) + rows, err := db.QueryContext(ctx, query, truncTo16(blockHash), isTransfer) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -142,7 +143,7 @@ func FetchL1Messages[T any](db *sql.DB, blockHash common.L1BlockHash, isTransfer return result, nil } -func WriteRollup(dbtx DBTransaction, rollup *common.RollupHeader, internalHeader *common.CalldataRollupHeader) error { +func WriteRollup(ctx context.Context, dbtx DBTransaction, rollup *common.RollupHeader, internalHeader *common.CalldataRollupHeader) error { // Write the encoded header data, err := rlp.EncodeToBytes(rollup) if err != nil { @@ -159,7 +160,7 @@ func WriteRollup(dbtx DBTransaction, rollup *common.RollupHeader, internalHeader return nil } -func FetchReorgedRollup(db *sql.DB, reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) { +func FetchReorgedRollup(ctx context.Context, db *sql.DB, reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) { argPlaceholders := strings.Repeat("?,", len(reorgedBlocks)) argPlaceholders = argPlaceholders[0 : len(argPlaceholders)-1] // remove trailing comma @@ -170,7 +171,7 @@ func FetchReorgedRollup(db *sql.DB, reorgedBlocks []common.L1BlockHash) (*common args = append(args, truncTo16(value)) } rollup := new(common.L2BatchHash) - err := db.QueryRow(query, args...).Scan(&rollup) + err := db.QueryRowContext(ctx, query, args...).Scan(&rollup) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -181,12 +182,12 @@ func FetchReorgedRollup(db *sql.DB, reorgedBlocks []common.L1BlockHash) (*common return rollup, nil } -func FetchRollupMetadata(db *sql.DB, hash common.L2RollupHash) (*common.PublicRollupMetadata, error) { +func FetchRollupMetadata(ctx context.Context, db *sql.DB, hash common.L2RollupHash) (*common.PublicRollupMetadata, error) { var startSeq int64 var startTime uint64 rollup := new(common.PublicRollupMetadata) - err := db.QueryRow(rollupSelectMetadata, truncTo16(hash)).Scan(&startSeq, &startTime) + err := db.QueryRowContext(ctx, rollupSelectMetadata, truncTo16(hash)).Scan(&startSeq, &startTime) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errutil.ErrNotFound @@ -198,14 +199,14 @@ func FetchRollupMetadata(db *sql.DB, hash common.L2RollupHash) (*common.PublicRo return rollup, nil } -func fetchBlockHeader(db *sql.DB, whereQuery string, args ...any) (*types.Header, error) { +func fetchBlockHeader(ctx context.Context, db *sql.DB, whereQuery string, args ...any) (*types.Header, error) { var header string query := selectBlockHeader + " " + whereQuery var err error if len(args) > 0 { - err = db.QueryRow(query, args...).Scan(&header) + err = db.QueryRowContext(ctx, query, args...).Scan(&header) } else { - err = db.QueryRow(query).Scan(&header) + err = db.QueryRowContext(ctx, query).Scan(&header) } if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -222,8 +223,8 @@ func fetchBlockHeader(db *sql.DB, whereQuery string, args ...any) (*types.Header return h, nil } -func fetchBlock(db *sql.DB, whereQuery string, args ...any) (*types.Block, error) { - h, err := fetchBlockHeader(db, whereQuery, args...) +func fetchBlock(ctx context.Context, db *sql.DB, whereQuery string, args ...any) (*types.Block, error) { + h, err := fetchBlockHeader(ctx, db, whereQuery, args...) if err != nil { return nil, err } diff --git a/go/enclave/storage/enclavedb/config.go b/go/enclave/storage/enclavedb/config.go index 615cdaf0fa..d261ccd2b9 100644 --- a/go/enclave/storage/enclavedb/config.go +++ b/go/enclave/storage/enclavedb/config.go @@ -1,6 +1,7 @@ package enclavedb import ( + "context" "database/sql" "errors" @@ -19,42 +20,42 @@ const ( attSelect = "select ky from attestation_key where party=?" ) -func WriteConfigToBatch(dbtx DBTransaction, key string, value any) { +func WriteConfigToBatch(ctx context.Context, dbtx DBTransaction, key string, value any) { dbtx.ExecuteSQL(cfgInsert, key, value) } -func WriteConfigToTx(dbtx *sql.Tx, key string, value any) (sql.Result, error) { +func WriteConfigToTx(ctx context.Context, dbtx *sql.Tx, key string, value any) (sql.Result, error) { return dbtx.Exec(cfgInsert, key, value) } -func WriteConfig(db *sql.DB, key string, value []byte) (sql.Result, error) { - return db.Exec(cfgInsert, key, value) +func WriteConfig(ctx context.Context, db *sql.DB, key string, value []byte) (sql.Result, error) { + return db.ExecContext(ctx, cfgInsert, key, value) } -func UpdateConfigToBatch(dbtx DBTransaction, key string, value []byte) { +func UpdateConfigToBatch(ctx context.Context, dbtx DBTransaction, key string, value []byte) { dbtx.ExecuteSQL(cfgUpdate, key, value) } -func UpdateConfig(db *sql.DB, key string, value []byte) (sql.Result, error) { - return db.Exec(cfgUpdate, key, value) +func UpdateConfig(ctx context.Context, db *sql.DB, key string, value []byte) (sql.Result, error) { + return db.ExecContext(ctx, cfgUpdate, key, value) } -func FetchConfig(db *sql.DB, key string) ([]byte, error) { - return readSingleRow(db, cfgSelect, key) +func FetchConfig(ctx context.Context, db *sql.DB, key string) ([]byte, error) { + return readSingleRow(ctx, db, cfgSelect, key) } -func WriteAttKey(db *sql.DB, party common.Address, key []byte) (sql.Result, error) { - return db.Exec(attInsert, party.Bytes(), key) +func WriteAttKey(ctx context.Context, db *sql.DB, party common.Address, key []byte) (sql.Result, error) { + return db.ExecContext(ctx, attInsert, party.Bytes(), key) } -func FetchAttKey(db *sql.DB, party common.Address) ([]byte, error) { - return readSingleRow(db, attSelect, party.Bytes()) +func FetchAttKey(ctx context.Context, db *sql.DB, party common.Address) ([]byte, error) { + return readSingleRow(ctx, db, attSelect, party.Bytes()) } -func readSingleRow(db *sql.DB, query string, v any) ([]byte, error) { +func readSingleRow(ctx context.Context, db *sql.DB, query string, v any) ([]byte, error) { var res []byte - err := db.QueryRow(query, v).Scan(&res) + err := db.QueryRowContext(ctx, query, v).Scan(&res) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error diff --git a/go/enclave/storage/enclavedb/db_transaction.go b/go/enclave/storage/enclavedb/db_transaction.go index 0b9bdbecfc..372988ba86 100644 --- a/go/enclave/storage/enclavedb/db_transaction.go +++ b/go/enclave/storage/enclavedb/db_transaction.go @@ -1,8 +1,10 @@ package enclavedb import ( + "context" "database/sql" "fmt" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" @@ -23,6 +25,7 @@ type statement struct { } type dbTransaction struct { + timeout time.Duration db EnclaveDB writes []keyvalue statements []statement @@ -62,7 +65,13 @@ func (b *dbTransaction) ValueSize() int { // Write executes a batch statement with all the updates func (b *dbTransaction) Write() error { - tx, err := b.db.BeginTx() + ctx, cancelCtx := context.WithTimeout(context.Background(), b.timeout) + defer cancelCtx() + return b.WriteCtx(ctx) +} + +func (b *dbTransaction) WriteCtx(ctx context.Context) error { + tx, err := b.db.BeginTx(ctx) if err != nil { return fmt.Errorf("failed to create batch transaction - %w", err) } @@ -80,12 +89,12 @@ func (b *dbTransaction) Write() error { } } - err = PutKeyValues(tx, updateKeys, updateValues) + err = PutKeyValues(ctx, tx, updateKeys, updateValues) if err != nil { return fmt.Errorf("failed to put key/value. Cause %w", err) } - err = DeleteKeys(tx, deletes) + err = DeleteKeys(ctx, tx, deletes) if err != nil { return fmt.Errorf("failed to delete keys. Cause %w", err) } diff --git a/go/enclave/storage/enclavedb/enclave_sql_db.go b/go/enclave/storage/enclavedb/enclave_sql_db.go index fcfa2852c4..602ed0b18f 100644 --- a/go/enclave/storage/enclavedb/enclave_sql_db.go +++ b/go/enclave/storage/enclavedb/enclave_sql_db.go @@ -1,10 +1,13 @@ package enclavedb import ( + "context" "database/sql" "errors" "fmt" + "github.com/ten-protocol/go-ten/go/config" + "github.com/ethereum/go-ethereum/ethdb" gethlog "github.com/ethereum/go-ethereum/log" ) @@ -13,6 +16,7 @@ import ( // should not be used directly outside the db package type enclaveDB struct { sqldb *sql.DB + config config.EnclaveConfig logger gethlog.Logger } @@ -51,32 +55,40 @@ func (sqlDB *enclaveDB) NewSnapshot() (ethdb.Snapshot, error) { panic("implement me") } -func NewEnclaveDB(db *sql.DB, logger gethlog.Logger) (EnclaveDB, error) { - return &enclaveDB{sqldb: db, logger: logger}, nil +func NewEnclaveDB(db *sql.DB, config config.EnclaveConfig, logger gethlog.Logger) (EnclaveDB, error) { + return &enclaveDB{sqldb: db, config: config, logger: logger}, nil } func (sqlDB *enclaveDB) GetSQLDB() *sql.DB { return sqlDB.sqldb } -func (sqlDB *enclaveDB) BeginTx() (*sql.Tx, error) { - return sqlDB.sqldb.Begin() +func (sqlDB *enclaveDB) BeginTx(ctx context.Context) (*sql.Tx, error) { + return sqlDB.sqldb.BeginTx(ctx, nil) } func (sqlDB *enclaveDB) Has(key []byte) (bool, error) { - return Has(sqlDB.sqldb, key) + ctx, cancelCtx := context.WithTimeout(context.Background(), sqlDB.config.RPCTimeout) + defer cancelCtx() + return Has(ctx, sqlDB.sqldb, key) } func (sqlDB *enclaveDB) Get(key []byte) ([]byte, error) { - return Get(sqlDB.sqldb, key) + ctx, cancelCtx := context.WithTimeout(context.Background(), sqlDB.config.RPCTimeout) + defer cancelCtx() + return Get(ctx, sqlDB.sqldb, key) } func (sqlDB *enclaveDB) Put(key []byte, value []byte) error { - return Put(sqlDB.sqldb, key, value) + ctx, cancelCtx := context.WithTimeout(context.Background(), sqlDB.config.RPCTimeout) + defer cancelCtx() + return Put(ctx, sqlDB.sqldb, key, value) } func (sqlDB *enclaveDB) Delete(key []byte) error { - return Delete(sqlDB.sqldb, key) + ctx, cancelCtx := context.WithTimeout(context.Background(), sqlDB.config.RPCTimeout) + defer cancelCtx() + return Delete(ctx, sqlDB.sqldb, key) } func (sqlDB *enclaveDB) Close() error { @@ -88,18 +100,21 @@ func (sqlDB *enclaveDB) Close() error { func (sqlDB *enclaveDB) NewDBTransaction() *dbTransaction { return &dbTransaction{ - db: sqlDB, + timeout: sqlDB.config.RPCTimeout, + db: sqlDB, } } func (sqlDB *enclaveDB) NewBatch() ethdb.Batch { return &dbTransaction{ - db: sqlDB, + timeout: sqlDB.config.RPCTimeout, + db: sqlDB, } } func (sqlDB *enclaveDB) NewIterator(prefix []byte, start []byte) ethdb.Iterator { - return NewIterator(sqlDB.sqldb, prefix, start) + // we can't use a timeout context here, because the cleanup function must be called + return NewIterator(context.Background(), sqlDB.sqldb, prefix, start) } func (sqlDB *enclaveDB) Stat(_ string) (string, error) { diff --git a/go/enclave/storage/enclavedb/enclave_sql_db_test.go b/go/enclave/storage/enclavedb/enclave_sql_db_test.go index 750386b9e5..957cf8c7b8 100644 --- a/go/enclave/storage/enclavedb/enclave_sql_db_test.go +++ b/go/enclave/storage/enclavedb/enclave_sql_db_test.go @@ -4,6 +4,9 @@ import ( "database/sql" "path/filepath" "testing" + "time" + + "github.com/ten-protocol/go-ten/go/config" "github.com/ten-protocol/go-ten/integration/common/testlog" @@ -118,7 +121,7 @@ func createDB(t *testing.T) ethdb.Database { lite := setupSQLite(t) _, err := lite.Exec(createKVTable) failIfError(t, err, "Failed to create key-value table in test db") - s, err := NewEnclaveDB(lite, testlog.Logger()) + s, err := NewEnclaveDB(lite, config.EnclaveConfig{RPCTimeout: time.Second}, testlog.Logger()) failIfError(t, err, "Failed to create SQLEthDatabase for test") return s } diff --git a/go/enclave/storage/enclavedb/events.go b/go/enclave/storage/enclavedb/events.go index aa0ccefaf2..910078fdd8 100644 --- a/go/enclave/storage/enclavedb/events.go +++ b/go/enclave/storage/enclavedb/events.go @@ -1,6 +1,7 @@ package enclavedb import ( + "context" "database/sql" "fmt" "math/big" @@ -22,12 +23,12 @@ const ( orderBy = " order by b.height, tx.idx asc" ) -func StoreEventLogs(dbtx DBTransaction, receipts []*types.Receipt, stateDB *state.StateDB) error { +func StoreEventLogs(ctx context.Context, dbtx DBTransaction, receipts []*types.Receipt, stateDB *state.StateDB) error { var args []any totalLogs := 0 for _, receipt := range receipts { for _, l := range receipt.Logs { - logArgs, err := logDBValues(dbtx.GetDB(), l, receipt, stateDB) + logArgs, err := logDBValues(ctx, dbtx.GetDB(), l, receipt, stateDB) if err != nil { return err } @@ -49,7 +50,7 @@ func StoreEventLogs(dbtx DBTransaction, receipts []*types.Receipt, stateDB *stat // The other 4 topics are set by the programmer // According to the data relevancy rules, an event is relevant to accounts referenced directly in topics // If the event is not referring any user address, it is considered a "lifecycle event", and is relevant to everyone -func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *state.StateDB) ([]any, error) { +func logDBValues(ctx context.Context, db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *state.StateDB) ([]any, error) { // The topics are stored in an array with a maximum of 5 entries, but usually less var t0, t1, t2, t3, t4 []byte @@ -72,7 +73,7 @@ func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *stat // if yes, then mark it as relevant for that account if n > 1 { t1 = l.Topics[1].Bytes() - isUserAccount, addr1, err = isEndUserAccount(db, l.Topics[1], stateDB) + isUserAccount, addr1, err = isEndUserAccount(ctx, db, l.Topics[1], stateDB) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *stat } if n > 2 { t2 = l.Topics[2].Bytes() - isUserAccount, addr2, err = isEndUserAccount(db, l.Topics[2], stateDB) + isUserAccount, addr2, err = isEndUserAccount(ctx, db, l.Topics[2], stateDB) if err != nil { return nil, err } @@ -94,7 +95,7 @@ func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *stat } if n > 3 { t3 = l.Topics[3].Bytes() - isUserAccount, addr3, err = isEndUserAccount(db, l.Topics[3], stateDB) + isUserAccount, addr3, err = isEndUserAccount(ctx, db, l.Topics[3], stateDB) if err != nil { return nil, err } @@ -105,7 +106,7 @@ func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *stat } if n > 4 { t4 = l.Topics[4].Bytes() - isUserAccount, addr4, err = isEndUserAccount(db, l.Topics[4], stateDB) + isUserAccount, addr4, err = isEndUserAccount(ctx, db, l.Topics[4], stateDB) if err != nil { return nil, err } @@ -130,6 +131,7 @@ func logDBValues(db *sql.DB, l *types.Log, receipt *types.Receipt, stateDB *stat } func FilterLogs( + ctx context.Context, db *sql.DB, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, @@ -176,10 +178,10 @@ func FilterLogs( } } - return loadLogs(db, requestingAccount, query, queryParams) + return loadLogs(ctx, db, requestingAccount, query, queryParams) } -func DebugGetLogs(db *sql.DB, txHash common.TxHash) ([]*tracers.DebugLogs, error) { +func DebugGetLogs(ctx context.Context, db *sql.DB, txHash common.TxHash) ([]*tracers.DebugLogs, error) { var queryParams []any query := baseDebugEventsQuerySelect + " " + baseEventsJoin + "AND tx.hash = ?" @@ -188,7 +190,7 @@ func DebugGetLogs(db *sql.DB, txHash common.TxHash) ([]*tracers.DebugLogs, error result := make([]*tracers.DebugLogs, 0) - rows, err := db.Query(query, queryParams...) + rows, err := db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err } @@ -259,7 +261,7 @@ func bytesToAddress(b []byte) *gethcommon.Address { // forcing its events to become permanently private (this is not implemented for now) // // todo - find a more efficient way -func isEndUserAccount(db *sql.DB, topic gethcommon.Hash, stateDB *state.StateDB) (bool, *gethcommon.Address, error) { +func isEndUserAccount(ctx context.Context, db *sql.DB, topic gethcommon.Hash, stateDB *state.StateDB) (bool, *gethcommon.Address, error) { potentialAddr := common.ExtractPotentialAddress(topic) if potentialAddr == nil { return false, nil, nil @@ -268,7 +270,7 @@ func isEndUserAccount(db *sql.DB, topic gethcommon.Hash, stateDB *state.StateDB) // Check the database if there are already entries for this address var count int query := "select count(*) from events where rel_address1=? OR rel_address2=? OR rel_address3=? OR rel_address4=?" - err := db.QueryRow(query, addrBytes, addrBytes, addrBytes, addrBytes).Scan(&count) + err := db.QueryRowContext(ctx, query, addrBytes, addrBytes, addrBytes, addrBytes).Scan(&count) if err != nil { // exit here return false, nil, err @@ -292,7 +294,7 @@ func isEndUserAccount(db *sql.DB, topic gethcommon.Hash, stateDB *state.StateDB) // utility function that knows how to load relevant logs from the database // todo always pass in the actual batch hashes because of reorgs, or make sure to clean up log entries from discarded batches -func loadLogs(db *sql.DB, requestingAccount *gethcommon.Address, whereCondition string, whereParams []any) ([]*types.Log, error) { +func loadLogs(ctx context.Context, db *sql.DB, requestingAccount *gethcommon.Address, whereCondition string, whereParams []any) ([]*types.Log, error) { if requestingAccount == nil { return nil, fmt.Errorf("logs can only be requested for an account") } @@ -315,7 +317,7 @@ func loadLogs(db *sql.DB, requestingAccount *gethcommon.Address, whereCondition query += orderBy - rows, err := db.Query(query, queryParams...) + rows, err := db.QueryContext(ctx, query, queryParams...) if err != nil { return nil, err } diff --git a/go/enclave/storage/enclavedb/interfaces.go b/go/enclave/storage/enclavedb/interfaces.go index ab6cafe54d..a8cc0a34d0 100644 --- a/go/enclave/storage/enclavedb/interfaces.go +++ b/go/enclave/storage/enclavedb/interfaces.go @@ -1,6 +1,7 @@ package enclavedb import ( + "context" "database/sql" "github.com/ethereum/go-ethereum/ethdb" @@ -14,7 +15,7 @@ type EnclaveDB interface { ethdb.Database GetSQLDB() *sql.DB NewDBTransaction() *dbTransaction - BeginTx() (*sql.Tx, error) + BeginTx(context.Context) (*sql.Tx, error) } // DBTransaction - represents a database transaction implemented unusually. diff --git a/go/enclave/storage/enclavedb/keyvalue.go b/go/enclave/storage/enclavedb/keyvalue.go index 568f16cb03..cb78491501 100644 --- a/go/enclave/storage/enclavedb/keyvalue.go +++ b/go/enclave/storage/enclavedb/keyvalue.go @@ -1,6 +1,7 @@ package enclavedb import ( + "context" "database/sql" "errors" "fmt" @@ -20,8 +21,8 @@ const ( searchQry = `select * from keyvalue where substring(keyvalue.ky, 1, ?) = ? and keyvalue.ky >= ? order by keyvalue.ky asc` ) -func Has(db *sql.DB, key []byte) (bool, error) { - err := db.QueryRow(getQry, key).Scan() +func Has(ctx context.Context, db *sql.DB, key []byte) (bool, error) { + err := db.QueryRowContext(ctx, getQry, key).Scan() if err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil @@ -31,10 +32,10 @@ func Has(db *sql.DB, key []byte) (bool, error) { return true, nil } -func Get(db *sql.DB, key []byte) ([]byte, error) { +func Get(ctx context.Context, db *sql.DB, key []byte) ([]byte, error) { var res []byte - err := db.QueryRow(getQry, key).Scan(&res) + err := db.QueryRowContext(ctx, getQry, key).Scan(&res) if err != nil { if errors.Is(err, sql.ErrNoRows) { // make sure the error is converted to obscuro-wide not found error @@ -45,12 +46,12 @@ func Get(db *sql.DB, key []byte) ([]byte, error) { return res, nil } -func Put(db *sql.DB, key []byte, value []byte) error { - _, err := db.Exec(putQry, key, value) +func Put(ctx context.Context, db *sql.DB, key []byte, value []byte) error { + _, err := db.ExecContext(ctx, putQry, key, value) return err } -func PutKeyValues(tx *sql.Tx, keys [][]byte, vals [][]byte) error { +func PutKeyValues(ctx context.Context, tx *sql.Tx, keys [][]byte, vals [][]byte) error { if len(keys) != len(vals) { return fmt.Errorf("invalid command. should not happen") } @@ -73,14 +74,14 @@ func PutKeyValues(tx *sql.Tx, keys [][]byte, vals [][]byte) error { return nil } -func Delete(db *sql.DB, key []byte) error { - _, err := db.Exec(delQry, key) +func Delete(ctx context.Context, db *sql.DB, key []byte) error { + _, err := db.ExecContext(ctx, delQry, key) return err } -func DeleteKeys(db *sql.Tx, keys [][]byte) error { +func DeleteKeys(ctx context.Context, db *sql.Tx, keys [][]byte) error { for _, del := range keys { - _, err := db.Exec(delQry, del) + _, err := db.ExecContext(ctx, delQry, del) if err != nil { return err } @@ -88,11 +89,11 @@ func DeleteKeys(db *sql.Tx, keys [][]byte) error { return nil } -func NewIterator(db *sql.DB, prefix []byte, start []byte) ethdb.Iterator { +func NewIterator(ctx context.Context, db *sql.DB, prefix []byte, start []byte) ethdb.Iterator { pr := prefix st := append(prefix, start...) // iterator clean-up handles closing this rows iterator - rows, err := db.Query(searchQry, len(pr), pr, st) + rows, err := db.QueryContext(ctx, searchQry, len(pr), pr, st) if err != nil { return &iterator{ err: fmt.Errorf("failed to get rows, iter will be empty, %w", err), diff --git a/go/enclave/storage/init/edgelessdb/edgelessdb.go b/go/enclave/storage/init/edgelessdb/edgelessdb.go index 9c0a2fb785..09aaa69ba9 100644 --- a/go/enclave/storage/init/edgelessdb/edgelessdb.go +++ b/go/enclave/storage/init/edgelessdb/edgelessdb.go @@ -24,6 +24,8 @@ import ( "strings" "time" + "github.com/ten-protocol/go-ten/go/config" + "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/go/enclave/storage/init/migration" @@ -128,7 +130,7 @@ type Credentials struct { UserKeyPEM string // db user private key, generated in our enclave } -func Connector(edbCfg *Config, logger gethlog.Logger) (enclavedb.EnclaveDB, error) { +func Connector(edbCfg *Config, config config.EnclaveConfig, logger gethlog.Logger) (enclavedb.EnclaveDB, error) { // rather than fail immediately if EdgelessDB is not available yet we wait up for `edgelessDBStartTimeout` for it to be available err := waitForEdgelessDBToStart(edbCfg.Host, logger) if err != nil { @@ -158,7 +160,7 @@ func Connector(edbCfg *Config, logger gethlog.Logger) (enclavedb.EnclaveDB, erro } // wrap it in our eth-compatible key-value store layer - return enclavedb.NewEnclaveDB(sqlDB, logger) + return enclavedb.NewEnclaveDB(sqlDB, config, logger) } func waitForEdgelessDBToStart(edbHost string, logger gethlog.Logger) error { diff --git a/go/enclave/storage/init/migration/db_migration.go b/go/enclave/storage/init/migration/db_migration.go index be56fe9433..742cc6001d 100644 --- a/go/enclave/storage/init/migration/db_migration.go +++ b/go/enclave/storage/init/migration/db_migration.go @@ -1,6 +1,7 @@ package migration import ( + "context" "database/sql" "embed" "errors" @@ -27,7 +28,7 @@ func DBMigration(db *sql.DB, sqlFiles embed.FS, logger gethlog.Logger) error { maxMigration := int64(len(migrationFiles)) var maxDB int64 - config, err := enclavedb.FetchConfig(db, currentMigrationVersionKey) + config, err := enclavedb.FetchConfig(context.Background(), db, currentMigrationVersionKey) if err != nil { // first time there is no entry, so 001 was executed already ( triggered at launch/manifest time ) if errors.Is(err, errutil.ErrNotFound) { @@ -66,7 +67,7 @@ func executeMigration(db *sql.DB, content string, migrationOrder int64) error { return err } - _, err = enclavedb.WriteConfigToTx(tx, currentMigrationVersionKey, big.NewInt(migrationOrder).Bytes()) + _, err = enclavedb.WriteConfigToTx(context.Background(), tx, currentMigrationVersionKey, big.NewInt(migrationOrder).Bytes()) if err != nil { return err } diff --git a/go/enclave/storage/init/sqlite/sqlite.go b/go/enclave/storage/init/sqlite/sqlite.go index f566de1766..75fa453250 100644 --- a/go/enclave/storage/init/sqlite/sqlite.go +++ b/go/enclave/storage/init/sqlite/sqlite.go @@ -8,6 +8,8 @@ import ( "path/filepath" "strings" + "github.com/ten-protocol/go-ten/go/config" + "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/go/enclave/storage/init/migration" @@ -30,7 +32,7 @@ var sqlFiles embed.FS // CreateTemporarySQLiteDB if dbPath is empty will use a random throwaway temp file, // otherwise dbPath is a filepath for the sqldb file, allows for tests that care about persistence between restarts -func CreateTemporarySQLiteDB(dbPath string, dbOptions string, logger gethlog.Logger) (enclavedb.EnclaveDB, error) { +func CreateTemporarySQLiteDB(dbPath string, dbOptions string, config config.EnclaveConfig, logger gethlog.Logger) (enclavedb.EnclaveDB, error) { initialsed := false if dbPath == "" { @@ -76,7 +78,7 @@ func CreateTemporarySQLiteDB(dbPath string, dbOptions string, logger gethlog.Log logger.Info(fmt.Sprintf("Opened %s sqlite db file at %s", description, dbPath)) - return enclavedb.NewEnclaveDB(db, logger) + return enclavedb.NewEnclaveDB(db, config, logger) } func initialiseDB(db *sql.DB) error { diff --git a/go/enclave/storage/interfaces.go b/go/enclave/storage/interfaces.go index 032a367866..3bbb00560f 100644 --- a/go/enclave/storage/interfaces.go +++ b/go/enclave/storage/interfaces.go @@ -1,6 +1,7 @@ package storage import ( + "context" "crypto/ecdsa" "io" "math/big" @@ -20,103 +21,103 @@ import ( // BlockResolver stores new blocks and returns information on existing blocks type BlockResolver interface { // FetchBlock returns the L1 Block with the given hash. - FetchBlock(blockHash common.L1BlockHash) (*types.Block, error) + FetchBlock(ctx context.Context, blockHash common.L1BlockHash) (*types.Block, error) // FetchCanonicaBlockByHeight - self explanatory - FetchCanonicaBlockByHeight(height *big.Int) (*types.Block, error) + FetchCanonicaBlockByHeight(ctx context.Context, height *big.Int) (*types.Block, error) // FetchHeadBlock - returns the head of the current chain. - FetchHeadBlock() (*types.Block, error) + FetchHeadBlock(ctx context.Context) (*types.Block, error) // StoreBlock persists the L1 Block and updates the canonical ancestors if there was a fork - StoreBlock(block *types.Block, fork *common.ChainFork) error + StoreBlock(ctx context.Context, block *types.Block, fork *common.ChainFork) error // IsAncestor returns true if maybeAncestor is an ancestor of the L1 Block, and false otherwise - IsAncestor(block *types.Block, maybeAncestor *types.Block) bool + IsAncestor(ctx context.Context, block *types.Block, maybeAncestor *types.Block) bool // IsBlockAncestor returns true if maybeAncestor is an ancestor of the L1 Block, and false otherwise // Takes into consideration that the Block to verify might be on a branch we haven't received yet // todo (low priority) - this is super confusing, analyze the usage - IsBlockAncestor(block *types.Block, maybeAncestor common.L1BlockHash) bool + IsBlockAncestor(ctx context.Context, block *types.Block, maybeAncestor common.L1BlockHash) bool } type BatchResolver interface { // FetchBatch returns the batch with the given hash. - FetchBatch(hash common.L2BatchHash) (*core.Batch, error) + FetchBatch(ctx context.Context, hash common.L2BatchHash) (*core.Batch, error) // FetchBatchHeader returns the batch header with the given hash. - FetchBatchHeader(hash common.L2BatchHash) (*common.BatchHeader, error) + FetchBatchHeader(ctx context.Context, hash common.L2BatchHash) (*common.BatchHeader, error) // FetchBatchByHeight returns the batch on the canonical chain with the given height. - FetchBatchByHeight(height uint64) (*core.Batch, error) + FetchBatchByHeight(ctx context.Context, height uint64) (*core.Batch, error) // FetchBatchBySeqNo returns the batch with the given seq number. - FetchBatchBySeqNo(seqNum uint64) (*core.Batch, error) + FetchBatchBySeqNo(ctx context.Context, seqNum uint64) (*core.Batch, error) // FetchHeadBatch returns the current head batch of the canonical chain. - FetchHeadBatch() (*core.Batch, error) + FetchHeadBatch(ctx context.Context) (*core.Batch, error) // FetchCurrentSequencerNo returns the sequencer number - FetchCurrentSequencerNo() (*big.Int, error) + FetchCurrentSequencerNo(ctx context.Context) (*big.Int, error) // FetchBatchesByBlock returns all batches with the block hash as the L1 proof - FetchBatchesByBlock(common.L1BlockHash) ([]*core.Batch, error) + FetchBatchesByBlock(ctx context.Context, hash common.L1BlockHash) ([]*core.Batch, error) // FetchNonCanonicalBatchesBetween - returns all reorged batches between the sequences - FetchNonCanonicalBatchesBetween(startSeq uint64, endSeq uint64) ([]*core.Batch, error) + FetchNonCanonicalBatchesBetween(ctx context.Context, startSeq uint64, endSeq uint64) ([]*core.Batch, error) // FetchCanonicalUnexecutedBatches - return the list of the unexecuted batches that are canonical - FetchCanonicalUnexecutedBatches(*big.Int) ([]*core.Batch, error) + FetchCanonicalUnexecutedBatches(context.Context, *big.Int) ([]*core.Batch, error) - FetchConvertedHash(hash common.L2BatchHash) (gethcommon.Hash, error) + FetchConvertedHash(ctx context.Context, hash common.L2BatchHash) (gethcommon.Hash, error) // BatchWasExecuted - return true if the batch was executed - BatchWasExecuted(hash common.L2BatchHash) (bool, error) + BatchWasExecuted(ctx context.Context, hash common.L2BatchHash) (bool, error) // FetchHeadBatchForBlock returns the hash of the head batch at a given L1 block. - FetchHeadBatchForBlock(blockHash common.L1BlockHash) (*core.Batch, error) + FetchHeadBatchForBlock(ctx context.Context, blockHash common.L1BlockHash) (*core.Batch, error) // StoreBatch stores an un-executed batch. - StoreBatch(batch *core.Batch, convertedHash gethcommon.Hash) error + StoreBatch(ctx context.Context, batch *core.Batch, convertedHash gethcommon.Hash) error // StoreExecutedBatch - store the batch after it was executed - StoreExecutedBatch(batch *core.Batch, receipts []*types.Receipt) error + StoreExecutedBatch(ctx context.Context, batch *core.Batch, receipts []*types.Receipt) error // StoreRollup - StoreRollup(rollup *common.ExtRollup, header *common.CalldataRollupHeader) error - FetchRollupMetadata(hash common.L2RollupHash) (*common.PublicRollupMetadata, error) - FetchReorgedRollup(reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) + StoreRollup(ctx context.Context, rollup *common.ExtRollup, header *common.CalldataRollupHeader) error + FetchRollupMetadata(ctx context.Context, hash common.L2RollupHash) (*common.PublicRollupMetadata, error) + FetchReorgedRollup(ctx context.Context, reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) } type GethStateDB interface { // CreateStateDB creates a database that can be used to execute transactions - CreateStateDB(hash common.L2BatchHash) (*state.StateDB, error) + CreateStateDB(ctx context.Context, hash common.L2BatchHash) (*state.StateDB, error) // EmptyStateDB creates the original empty StateDB EmptyStateDB() (*state.StateDB, error) } type SharedSecretStorage interface { // FetchSecret returns the enclave's secret. - FetchSecret() (*crypto.SharedEnclaveSecret, error) + FetchSecret(ctx context.Context) (*crypto.SharedEnclaveSecret, error) // StoreSecret stores a secret in the enclave - StoreSecret(secret crypto.SharedEnclaveSecret) error + StoreSecret(ctx context.Context, secret crypto.SharedEnclaveSecret) error } type TransactionStorage interface { // GetTransaction - returns the positional metadata of the tx by hash - GetTransaction(txHash common.L2TxHash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) + GetTransaction(ctx context.Context, txHash common.L2TxHash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) // GetTransactionReceipt - returns the receipt of a tx by tx hash - GetTransactionReceipt(txHash common.L2TxHash) (*types.Receipt, error) + GetTransactionReceipt(ctx context.Context, txHash common.L2TxHash) (*types.Receipt, error) // GetReceiptsByBatchHash retrieves the receipts for all transactions in a given rollup. - GetReceiptsByBatchHash(hash common.L2BatchHash) (types.Receipts, error) + GetReceiptsByBatchHash(ctx context.Context, hash common.L2BatchHash) (types.Receipts, error) // GetContractCreationTx returns the hash of the tx that created a contract - GetContractCreationTx(address gethcommon.Address) (*gethcommon.Hash, error) + GetContractCreationTx(ctx context.Context, address gethcommon.Address) (*gethcommon.Hash, error) } type AttestationStorage interface { // FetchAttestedKey returns the public key of an attested aggregator - FetchAttestedKey(aggregator gethcommon.Address) (*ecdsa.PublicKey, error) + FetchAttestedKey(ctx context.Context, aggregator gethcommon.Address) (*ecdsa.PublicKey, error) // StoreAttestedKey - store the public key of an attested aggregator - StoreAttestedKey(aggregator gethcommon.Address, key *ecdsa.PublicKey) error + StoreAttestedKey(ctx context.Context, aggregator gethcommon.Address, key *ecdsa.PublicKey) error } type CrossChainMessagesStorage interface { - StoreL1Messages(blockHash common.L1BlockHash, messages common.CrossChainMessages) error - GetL1Messages(blockHash common.L1BlockHash) (common.CrossChainMessages, error) + StoreL1Messages(ctx context.Context, blockHash common.L1BlockHash, messages common.CrossChainMessages) error + GetL1Messages(ctx context.Context, blockHash common.L1BlockHash) (common.CrossChainMessages, error) - StoreValueTransfers(blockHash common.L1BlockHash, transfers common.ValueTransferEvents) error - GetL1Transfers(blockHash common.L1BlockHash) (common.ValueTransferEvents, error) + StoreValueTransfers(ctx context.Context, blockHash common.L1BlockHash, transfers common.ValueTransferEvents) error + GetL1Transfers(ctx context.Context, blockHash common.L1BlockHash) (common.ValueTransferEvents, error) } type EnclaveKeyStorage interface { - StoreEnclaveKey(enclaveKey *crypto.EnclaveKey) error - GetEnclaveKey() (*crypto.EnclaveKey, error) + StoreEnclaveKey(ctx context.Context, enclaveKey *crypto.EnclaveKey) error + GetEnclaveKey(ctx context.Context) (*crypto.EnclaveKey, error) } // Storage is the enclave's interface for interacting with the enclave's datastore @@ -133,15 +134,15 @@ type Storage interface { io.Closer // HealthCheck returns whether the storage is deemed healthy or not - HealthCheck() (bool, error) + HealthCheck(ctx context.Context) (bool, error) // FilterLogs - applies the properties the relevancy checks for the requestingAccount to all the stored log events // nil values will be ignored. Make sure to set all fields to the right values before calling this function // the blockHash should always be nil. - FilterLogs(requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, blockHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash) ([]*types.Log, error) + FilterLogs(ctx context.Context, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, blockHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash) ([]*types.Log, error) // DebugGetLogs returns logs for a given tx hash without any constraints - should only be used for debug purposes - DebugGetLogs(txHash common.TxHash) ([]*tracers.DebugLogs, error) + DebugGetLogs(ctx context.Context, txHash common.TxHash) ([]*tracers.DebugLogs, error) // TrieDB - return the underlying trie database TrieDB() *trie.Database @@ -151,10 +152,10 @@ type Storage interface { } type ScanStorage interface { - GetContractCount() (*big.Int, error) - GetReceiptsPerAddress(address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) - GetPublicTransactionData(pagination *common.QueryPagination) ([]common.PublicTransaction, error) - GetPublicTransactionCount() (uint64, error) + GetContractCount(ctx context.Context) (*big.Int, error) + GetReceiptsPerAddress(ctx context.Context, address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) + GetPublicTransactionData(ctx context.Context, pagination *common.QueryPagination) ([]common.PublicTransaction, error) + GetPublicTransactionCount(ctx context.Context) (uint64, error) - GetReceiptsPerAddressCount(addr *gethcommon.Address) (uint64, error) + GetReceiptsPerAddressCount(ctx context.Context, addr *gethcommon.Address) (uint64, error) } diff --git a/go/enclave/storage/storage.go b/go/enclave/storage/storage.go index cfe48bad3c..f1c02e1b48 100644 --- a/go/enclave/storage/storage.go +++ b/go/enclave/storage/storage.go @@ -2,6 +2,7 @@ package storage import ( "bytes" + "context" "crypto/ecdsa" "errors" "fmt" @@ -127,20 +128,20 @@ func (s *storageImpl) Close() error { return s.db.GetSQLDB().Close() } -func (s *storageImpl) FetchHeadBatch() (*core.Batch, error) { +func (s *storageImpl) FetchHeadBatch(ctx context.Context) (*core.Batch, error) { defer s.logDuration("FetchHeadBatch", measure.NewStopwatch()) - return enclavedb.ReadCurrentHeadBatch(s.db.GetSQLDB()) + return enclavedb.ReadCurrentHeadBatch(ctx, s.db.GetSQLDB()) } -func (s *storageImpl) FetchCurrentSequencerNo() (*big.Int, error) { +func (s *storageImpl) FetchCurrentSequencerNo(ctx context.Context) (*big.Int, error) { defer s.logDuration("FetchCurrentSequencerNo", measure.NewStopwatch()) - return enclavedb.ReadCurrentSequencerNo(s.db.GetSQLDB()) + return enclavedb.ReadCurrentSequencerNo(ctx, s.db.GetSQLDB()) } -func (s *storageImpl) FetchBatch(hash common.L2BatchHash) (*core.Batch, error) { +func (s *storageImpl) FetchBatch(ctx context.Context, hash common.L2BatchHash) (*core.Batch, error) { defer s.logDuration("FetchBatch", measure.NewStopwatch()) - seqNo, err := common.GetCachedValue(s.seqCacheByHash, s.logger, hash, func(v any) (*big.Int, error) { - batch, err := enclavedb.ReadBatchByHash(s.db.GetSQLDB(), v.(common.L2BatchHash)) + seqNo, err := common.GetCachedValue(ctx, s.seqCacheByHash, s.logger, hash, func(v any) (*big.Int, error) { + batch, err := enclavedb.ReadBatchByHash(ctx, s.db.GetSQLDB(), v.(common.L2BatchHash)) if err != nil { return nil, err } @@ -149,32 +150,32 @@ func (s *storageImpl) FetchBatch(hash common.L2BatchHash) (*core.Batch, error) { if err != nil { return nil, err } - return s.FetchBatchBySeqNo(seqNo.Uint64()) + return s.FetchBatchBySeqNo(ctx, seqNo.Uint64()) } -func (s *storageImpl) FetchConvertedHash(hash common.L2BatchHash) (gethcommon.Hash, error) { +func (s *storageImpl) FetchConvertedHash(ctx context.Context, hash common.L2BatchHash) (gethcommon.Hash, error) { defer s.logDuration("FetchConvertedHash", measure.NewStopwatch()) - batch, err := s.FetchBatch(hash) + batch, err := s.FetchBatch(ctx, hash) if err != nil { return gethcommon.Hash{}, err } - return enclavedb.FetchConvertedBatchHash(s.db.GetSQLDB(), batch.Header.SequencerOrderNo.Uint64()) + return enclavedb.FetchConvertedBatchHash(ctx, s.db.GetSQLDB(), batch.Header.SequencerOrderNo.Uint64()) } -func (s *storageImpl) FetchBatchHeader(hash common.L2BatchHash) (*common.BatchHeader, error) { +func (s *storageImpl) FetchBatchHeader(ctx context.Context, hash common.L2BatchHash) (*common.BatchHeader, error) { defer s.logDuration("FetchBatchHeader", measure.NewStopwatch()) - b, err := s.FetchBatch(hash) + b, err := s.FetchBatch(ctx, hash) if err != nil { return nil, err } return b.Header, nil } -func (s *storageImpl) FetchBatchByHeight(height uint64) (*core.Batch, error) { +func (s *storageImpl) FetchBatchByHeight(ctx context.Context, height uint64) (*core.Batch, error) { defer s.logDuration("FetchBatchByHeight", measure.NewStopwatch()) // the key is (height+1), because for some reason it doesn't like a key of 0 - seqNo, err := common.GetCachedValue(s.seqCacheByHeight, s.logger, height+1, func(h any) (*big.Int, error) { - batch, err := enclavedb.ReadCanonicalBatchByHeight(s.db.GetSQLDB(), height) + seqNo, err := common.GetCachedValue(ctx, s.seqCacheByHeight, s.logger, height+1, func(h any) (*big.Int, error) { + batch, err := enclavedb.ReadCanonicalBatchByHeight(ctx, s.db.GetSQLDB(), height) if err != nil { return nil, err } @@ -183,73 +184,73 @@ func (s *storageImpl) FetchBatchByHeight(height uint64) (*core.Batch, error) { if err != nil { return nil, err } - return s.FetchBatchBySeqNo(seqNo.Uint64()) + return s.FetchBatchBySeqNo(ctx, seqNo.Uint64()) } -func (s *storageImpl) FetchNonCanonicalBatchesBetween(startSeq uint64, endSeq uint64) ([]*core.Batch, error) { +func (s *storageImpl) FetchNonCanonicalBatchesBetween(ctx context.Context, startSeq uint64, endSeq uint64) ([]*core.Batch, error) { defer s.logDuration("FetchNonCanonicalBatchesBetween", measure.NewStopwatch()) - return enclavedb.ReadNonCanonicalBatches(s.db.GetSQLDB(), startSeq, endSeq) + return enclavedb.ReadNonCanonicalBatches(ctx, s.db.GetSQLDB(), startSeq, endSeq) } -func (s *storageImpl) StoreBlock(b *types.Block, chainFork *common.ChainFork) error { +func (s *storageImpl) StoreBlock(ctx context.Context, b *types.Block, chainFork *common.ChainFork) error { defer s.logDuration("StoreBlock", measure.NewStopwatch()) dbTransaction := s.db.NewDBTransaction() if chainFork != nil && chainFork.IsFork() { s.logger.Info(fmt.Sprintf("Fork. %s", chainFork)) - enclavedb.UpdateCanonicalBlocks(dbTransaction, chainFork.CanonicalPath, chainFork.NonCanonicalPath) + enclavedb.UpdateCanonicalBlocks(ctx, dbTransaction, chainFork.CanonicalPath, chainFork.NonCanonicalPath) } // In case there were any batches inserted before this block was received - enclavedb.UpdateCanonicalBlocks(dbTransaction, []common.L1BlockHash{b.Hash()}, nil) + enclavedb.UpdateCanonicalBlocks(ctx, dbTransaction, []common.L1BlockHash{b.Hash()}, nil) - if err := enclavedb.WriteBlock(dbTransaction, b.Header()); err != nil { + if err := enclavedb.WriteBlock(ctx, dbTransaction, b.Header()); err != nil { return fmt.Errorf("2. could not store block %s. Cause: %w", b.Hash(), err) } - if err := dbTransaction.Write(); err != nil { + if err := dbTransaction.WriteCtx(ctx); err != nil { return fmt.Errorf("3. could not store block %s. Cause: %w", b.Hash(), err) } - common.CacheValue(s.blockCache, s.logger, b.Hash(), b) + common.CacheValue(ctx, s.blockCache, s.logger, b.Hash(), b) return nil } -func (s *storageImpl) FetchBlock(blockHash common.L1BlockHash) (*types.Block, error) { +func (s *storageImpl) FetchBlock(ctx context.Context, blockHash common.L1BlockHash) (*types.Block, error) { defer s.logDuration("FetchBlock", measure.NewStopwatch()) - return common.GetCachedValue(s.blockCache, s.logger, blockHash, func(hash any) (*types.Block, error) { - return enclavedb.FetchBlock(s.db.GetSQLDB(), hash.(common.L1BlockHash)) + return common.GetCachedValue(ctx, s.blockCache, s.logger, blockHash, func(hash any) (*types.Block, error) { + return enclavedb.FetchBlock(ctx, s.db.GetSQLDB(), hash.(common.L1BlockHash)) }) } -func (s *storageImpl) FetchCanonicaBlockByHeight(height *big.Int) (*types.Block, error) { +func (s *storageImpl) FetchCanonicaBlockByHeight(ctx context.Context, height *big.Int) (*types.Block, error) { defer s.logDuration("FetchCanonicaBlockByHeight", measure.NewStopwatch()) - header, err := enclavedb.FetchBlockHeaderByHeight(s.db.GetSQLDB(), height) + header, err := enclavedb.FetchBlockHeaderByHeight(ctx, s.db.GetSQLDB(), height) if err != nil { return nil, err } - return s.FetchBlock(header.Hash()) + return s.FetchBlock(ctx, header.Hash()) } -func (s *storageImpl) FetchHeadBlock() (*types.Block, error) { +func (s *storageImpl) FetchHeadBlock(ctx context.Context) (*types.Block, error) { defer s.logDuration("FetchHeadBlock", measure.NewStopwatch()) - return enclavedb.FetchHeadBlock(s.db.GetSQLDB()) + return enclavedb.FetchHeadBlock(ctx, s.db.GetSQLDB()) } -func (s *storageImpl) StoreSecret(secret crypto.SharedEnclaveSecret) error { +func (s *storageImpl) StoreSecret(ctx context.Context, secret crypto.SharedEnclaveSecret) error { defer s.logDuration("StoreSecret", measure.NewStopwatch()) enc, err := rlp.EncodeToBytes(secret) if err != nil { return fmt.Errorf("could not encode shared secret. Cause: %w", err) } - _, err = enclavedb.WriteConfig(s.db.GetSQLDB(), masterSeedCfg, enc) + _, err = enclavedb.WriteConfig(ctx, s.db.GetSQLDB(), masterSeedCfg, enc) if err != nil { return fmt.Errorf("could not shared secret in DB. Cause: %w", err) } return nil } -func (s *storageImpl) FetchSecret() (*crypto.SharedEnclaveSecret, error) { +func (s *storageImpl) FetchSecret(ctx context.Context) (*crypto.SharedEnclaveSecret, error) { defer s.logDuration("FetchSecret", measure.NewStopwatch()) if s.cachedSharedSecret != nil { @@ -258,7 +259,7 @@ func (s *storageImpl) FetchSecret() (*crypto.SharedEnclaveSecret, error) { var ss crypto.SharedEnclaveSecret - cfg, err := enclavedb.FetchConfig(s.db.GetSQLDB(), masterSeedCfg) + cfg, err := enclavedb.FetchConfig(ctx, s.db.GetSQLDB(), masterSeedCfg) if err != nil { return nil, err } @@ -270,7 +271,7 @@ func (s *storageImpl) FetchSecret() (*crypto.SharedEnclaveSecret, error) { return s.cachedSharedSecret, nil } -func (s *storageImpl) IsAncestor(block *types.Block, maybeAncestor *types.Block) bool { +func (s *storageImpl) IsAncestor(ctx context.Context, block *types.Block, maybeAncestor *types.Block) bool { defer s.logDuration("IsAncestor", measure.NewStopwatch()) if bytes.Equal(maybeAncestor.Hash().Bytes(), block.Hash().Bytes()) { return true @@ -280,27 +281,27 @@ func (s *storageImpl) IsAncestor(block *types.Block, maybeAncestor *types.Block) return false } - p, err := s.FetchBlock(block.ParentHash()) + p, err := s.FetchBlock(ctx, block.ParentHash()) if err != nil { s.logger.Debug("Could not find block with hash", log.BlockHashKey, block.ParentHash(), log.ErrKey, err) return false } - return s.IsAncestor(p, maybeAncestor) + return s.IsAncestor(ctx, p, maybeAncestor) } -func (s *storageImpl) IsBlockAncestor(block *types.Block, maybeAncestor common.L1BlockHash) bool { +func (s *storageImpl) IsBlockAncestor(ctx context.Context, block *types.Block, maybeAncestor common.L1BlockHash) bool { defer s.logDuration("IsBlockAncestor", measure.NewStopwatch()) - resolvedBlock, err := s.FetchBlock(maybeAncestor) + resolvedBlock, err := s.FetchBlock(ctx, maybeAncestor) if err != nil { return false } - return s.IsAncestor(block, resolvedBlock) + return s.IsAncestor(ctx, block, resolvedBlock) } -func (s *storageImpl) HealthCheck() (bool, error) { +func (s *storageImpl) HealthCheck(ctx context.Context) (bool, error) { defer s.logDuration("HealthCheck", measure.NewStopwatch()) - headBatch, err := s.FetchHeadBatch() + headBatch, err := s.FetchHeadBatch(ctx) if err != nil { return false, err } @@ -312,14 +313,14 @@ func (s *storageImpl) HealthCheck() (bool, error) { return true, nil } -func (s *storageImpl) FetchHeadBatchForBlock(blockHash common.L1BlockHash) (*core.Batch, error) { +func (s *storageImpl) FetchHeadBatchForBlock(ctx context.Context, blockHash common.L1BlockHash) (*core.Batch, error) { defer s.logDuration("FetchHeadBatchForBlock", measure.NewStopwatch()) - return enclavedb.ReadHeadBatchForBlock(s.db.GetSQLDB(), blockHash) + return enclavedb.ReadHeadBatchForBlock(ctx, s.db.GetSQLDB(), blockHash) } -func (s *storageImpl) CreateStateDB(batchHash common.L2BatchHash) (*state.StateDB, error) { +func (s *storageImpl) CreateStateDB(ctx context.Context, batchHash common.L2BatchHash) (*state.StateDB, error) { defer s.logDuration("CreateStateDB", measure.NewStopwatch()) - batch, err := s.FetchBatch(batchHash) + batch, err := s.FetchBatch(ctx, batchHash) if err != nil { return nil, err } @@ -341,29 +342,29 @@ func (s *storageImpl) EmptyStateDB() (*state.StateDB, error) { } // GetReceiptsByBatchHash retrieves the receipts for all transactions in a given batch. -func (s *storageImpl) GetReceiptsByBatchHash(hash gethcommon.Hash) (types.Receipts, error) { +func (s *storageImpl) GetReceiptsByBatchHash(ctx context.Context, hash gethcommon.Hash) (types.Receipts, error) { defer s.logDuration("GetReceiptsByBatchHash", measure.NewStopwatch()) - return enclavedb.ReadReceiptsByBatchHash(s.db.GetSQLDB(), hash, s.chainConfig) + return enclavedb.ReadReceiptsByBatchHash(ctx, s.db.GetSQLDB(), hash, s.chainConfig) } -func (s *storageImpl) GetTransaction(txHash gethcommon.Hash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) { +func (s *storageImpl) GetTransaction(ctx context.Context, txHash gethcommon.Hash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) { defer s.logDuration("GetTransaction", measure.NewStopwatch()) - return enclavedb.ReadTransaction(s.db.GetSQLDB(), txHash) + return enclavedb.ReadTransaction(ctx, s.db.GetSQLDB(), txHash) } -func (s *storageImpl) GetContractCreationTx(address gethcommon.Address) (*gethcommon.Hash, error) { +func (s *storageImpl) GetContractCreationTx(ctx context.Context, address gethcommon.Address) (*gethcommon.Hash, error) { defer s.logDuration("GetContractCreationTx", measure.NewStopwatch()) - return enclavedb.GetContractCreationTx(s.db.GetSQLDB(), address) + return enclavedb.GetContractCreationTx(ctx, s.db.GetSQLDB(), address) } -func (s *storageImpl) GetTransactionReceipt(txHash gethcommon.Hash) (*types.Receipt, error) { +func (s *storageImpl) GetTransactionReceipt(ctx context.Context, txHash gethcommon.Hash) (*types.Receipt, error) { defer s.logDuration("GetTransactionReceipt", measure.NewStopwatch()) - return enclavedb.ReadReceipt(s.db.GetSQLDB(), txHash, s.chainConfig) + return enclavedb.ReadReceipt(ctx, s.db.GetSQLDB(), txHash, s.chainConfig) } -func (s *storageImpl) FetchAttestedKey(address gethcommon.Address) (*ecdsa.PublicKey, error) { +func (s *storageImpl) FetchAttestedKey(ctx context.Context, address gethcommon.Address) (*ecdsa.PublicKey, error) { defer s.logDuration("FetchAttestedKey", measure.NewStopwatch()) - key, err := enclavedb.FetchAttKey(s.db.GetSQLDB(), address) + key, err := enclavedb.FetchAttKey(ctx, s.db.GetSQLDB(), address) if err != nil { return nil, fmt.Errorf("could not retrieve attestation key for address %s. Cause: %w", address, err) } @@ -376,16 +377,16 @@ func (s *storageImpl) FetchAttestedKey(address gethcommon.Address) (*ecdsa.Publi return publicKey, nil } -func (s *storageImpl) StoreAttestedKey(aggregator gethcommon.Address, key *ecdsa.PublicKey) error { +func (s *storageImpl) StoreAttestedKey(ctx context.Context, aggregator gethcommon.Address, key *ecdsa.PublicKey) error { defer s.logDuration("StoreAttestedKey", measure.NewStopwatch()) - _, err := enclavedb.WriteAttKey(s.db.GetSQLDB(), aggregator, gethcrypto.CompressPubkey(key)) + _, err := enclavedb.WriteAttKey(ctx, s.db.GetSQLDB(), aggregator, gethcrypto.CompressPubkey(key)) return err } -func (s *storageImpl) FetchBatchBySeqNo(seqNum uint64) (*core.Batch, error) { +func (s *storageImpl) FetchBatchBySeqNo(ctx context.Context, seqNum uint64) (*core.Batch, error) { defer s.logDuration("FetchBatchBySeqNo", measure.NewStopwatch()) - b, err := common.GetCachedValue(s.batchCacheBySeqNo, s.logger, seqNum, func(seq any) (*core.Batch, error) { - return enclavedb.ReadBatchBySeqNo(s.db.GetSQLDB(), seqNum) + b, err := common.GetCachedValue(ctx, s.batchCacheBySeqNo, s.logger, seqNum, func(seq any) (*core.Batch, error) { + return enclavedb.ReadBatchBySeqNo(ctx, s.db.GetSQLDB(), seqNum) }) if err == nil && b == nil { return nil, fmt.Errorf("not found") @@ -393,15 +394,15 @@ func (s *storageImpl) FetchBatchBySeqNo(seqNum uint64) (*core.Batch, error) { return b, err } -func (s *storageImpl) FetchBatchesByBlock(block common.L1BlockHash) ([]*core.Batch, error) { +func (s *storageImpl) FetchBatchesByBlock(ctx context.Context, block common.L1BlockHash) ([]*core.Batch, error) { defer s.logDuration("FetchBatchesByBlock", measure.NewStopwatch()) - return enclavedb.ReadBatchesByBlock(s.db.GetSQLDB(), block) + return enclavedb.ReadBatchesByBlock(ctx, s.db.GetSQLDB(), block) } -func (s *storageImpl) StoreBatch(batch *core.Batch, convertedHash gethcommon.Hash) error { +func (s *storageImpl) StoreBatch(ctx context.Context, batch *core.Batch, convertedHash gethcommon.Hash) error { defer s.logDuration("StoreBatch", measure.NewStopwatch()) // sanity check that this is not overlapping - existingBatchWithSameSequence, _ := s.FetchBatchBySeqNo(batch.SeqNo().Uint64()) + existingBatchWithSameSequence, _ := s.FetchBatchBySeqNo(ctx, batch.SeqNo().Uint64()) if existingBatchWithSameSequence != nil && existingBatchWithSameSequence.Hash() != batch.Hash() { // todo - tudor - remove the Critical before production, and return a challenge s.logger.Crit(fmt.Sprintf("Conflicting batches for the same sequence %d: (previous) %+v != (incoming) %+v", batch.SeqNo(), existingBatchWithSameSequence.Header, batch.Header)) @@ -416,25 +417,25 @@ func (s *storageImpl) StoreBatch(batch *core.Batch, convertedHash gethcommon.Has dbTx := s.db.NewDBTransaction() s.logger.Trace("write batch", log.BatchHashKey, batch.Hash(), "l1Proof", batch.Header.L1Proof, log.BatchSeqNoKey, batch.SeqNo()) - if err := enclavedb.WriteBatchAndTransactions(dbTx, batch, convertedHash); err != nil { + if err := enclavedb.WriteBatchAndTransactions(ctx, dbTx, batch, convertedHash); err != nil { return fmt.Errorf("could not write batch. Cause: %w", err) } - if err := dbTx.Write(); err != nil { + if err := dbTx.WriteCtx(ctx); err != nil { return fmt.Errorf("could not commit batch %w", err) } - common.CacheValue(s.batchCacheBySeqNo, s.logger, batch.SeqNo().Uint64(), batch) - common.CacheValue(s.seqCacheByHash, s.logger, batch.Hash(), batch.SeqNo()) + common.CacheValue(ctx, s.batchCacheBySeqNo, s.logger, batch.SeqNo().Uint64(), batch) + common.CacheValue(ctx, s.seqCacheByHash, s.logger, batch.Hash(), batch.SeqNo()) // note: the key is (height+1), because for some reason it doesn't like a key of 0 // should always contain the canonical batch because the cache is overwritten by each new batch after a reorg - common.CacheValue(s.seqCacheByHeight, s.logger, batch.NumberU64()+1, batch.SeqNo()) + common.CacheValue(ctx, s.seqCacheByHeight, s.logger, batch.NumberU64()+1, batch.SeqNo()) return nil } -func (s *storageImpl) StoreExecutedBatch(batch *core.Batch, receipts []*types.Receipt) error { +func (s *storageImpl) StoreExecutedBatch(ctx context.Context, batch *core.Batch, receipts []*types.Receipt) error { defer s.logDuration("StoreExecutedBatch", measure.NewStopwatch()) - executed, err := enclavedb.BatchWasExecuted(s.db.GetSQLDB(), batch.Hash()) + executed, err := enclavedb.BatchWasExecuted(ctx, s.db.GetSQLDB(), batch.Hash()) if err != nil { return err } @@ -444,63 +445,63 @@ func (s *storageImpl) StoreExecutedBatch(batch *core.Batch, receipts []*types.Re } dbTx := s.db.NewDBTransaction() - if err := enclavedb.WriteBatchExecution(dbTx, batch.SeqNo(), receipts); err != nil { + if err := enclavedb.WriteBatchExecution(ctx, dbTx, batch.SeqNo(), receipts); err != nil { return fmt.Errorf("could not write transaction receipts. Cause: %w", err) } if batch.Number().Int64() > 1 { - stateDB, err := s.CreateStateDB(batch.Header.ParentHash) + stateDB, err := s.CreateStateDB(ctx, batch.Header.ParentHash) if err != nil { return fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) } - err = enclavedb.StoreEventLogs(dbTx, receipts, stateDB) + err = enclavedb.StoreEventLogs(ctx, dbTx, receipts, stateDB) if err != nil { return fmt.Errorf("could not save logs %w", err) } } - if err = dbTx.Write(); err != nil { + if err = dbTx.WriteCtx(ctx); err != nil { return fmt.Errorf("could not commit batch %w", err) } return nil } -func (s *storageImpl) StoreValueTransfers(blockHash common.L1BlockHash, transfers common.ValueTransferEvents) error { - return enclavedb.WriteL1Messages(s.db.GetSQLDB(), blockHash, transfers, true) +func (s *storageImpl) StoreValueTransfers(ctx context.Context, blockHash common.L1BlockHash, transfers common.ValueTransferEvents) error { + return enclavedb.WriteL1Messages(ctx, s.db.GetSQLDB(), blockHash, transfers, true) } -func (s *storageImpl) StoreL1Messages(blockHash common.L1BlockHash, messages common.CrossChainMessages) error { +func (s *storageImpl) StoreL1Messages(ctx context.Context, blockHash common.L1BlockHash, messages common.CrossChainMessages) error { defer s.logDuration("StoreL1Messages", measure.NewStopwatch()) - return enclavedb.WriteL1Messages(s.db.GetSQLDB(), blockHash, messages, false) + return enclavedb.WriteL1Messages(ctx, s.db.GetSQLDB(), blockHash, messages, false) } -func (s *storageImpl) GetL1Messages(blockHash common.L1BlockHash) (common.CrossChainMessages, error) { +func (s *storageImpl) GetL1Messages(ctx context.Context, blockHash common.L1BlockHash) (common.CrossChainMessages, error) { defer s.logDuration("GetL1Messages", measure.NewStopwatch()) - return enclavedb.FetchL1Messages[common.CrossChainMessage](s.db.GetSQLDB(), blockHash, false) + return enclavedb.FetchL1Messages[common.CrossChainMessage](ctx, s.db.GetSQLDB(), blockHash, false) } -func (s *storageImpl) GetL1Transfers(blockHash common.L1BlockHash) (common.ValueTransferEvents, error) { - return enclavedb.FetchL1Messages[common.ValueTransferEvent](s.db.GetSQLDB(), blockHash, true) +func (s *storageImpl) GetL1Transfers(ctx context.Context, blockHash common.L1BlockHash) (common.ValueTransferEvents, error) { + return enclavedb.FetchL1Messages[common.ValueTransferEvent](ctx, s.db.GetSQLDB(), blockHash, true) } const enclaveKeyKey = "ek" -func (s *storageImpl) StoreEnclaveKey(enclaveKey *crypto.EnclaveKey) error { +func (s *storageImpl) StoreEnclaveKey(ctx context.Context, enclaveKey *crypto.EnclaveKey) error { defer s.logDuration("StoreEnclaveKey", measure.NewStopwatch()) if enclaveKey == nil { return errors.New("enclaveKey cannot be nil") } keyBytes := gethcrypto.FromECDSA(enclaveKey.PrivateKey()) - _, err := enclavedb.WriteConfig(s.db.GetSQLDB(), enclaveKeyKey, keyBytes) + _, err := enclavedb.WriteConfig(ctx, s.db.GetSQLDB(), enclaveKeyKey, keyBytes) return err } -func (s *storageImpl) GetEnclaveKey() (*crypto.EnclaveKey, error) { +func (s *storageImpl) GetEnclaveKey(ctx context.Context) (*crypto.EnclaveKey, error) { defer s.logDuration("GetEnclaveKey", measure.NewStopwatch()) - keyBytes, err := enclavedb.FetchConfig(s.db.GetSQLDB(), enclaveKeyKey) + keyBytes, err := enclavedb.FetchConfig(ctx, s.db.GetSQLDB(), enclaveKeyKey) if err != nil { return nil, err } @@ -511,34 +512,35 @@ func (s *storageImpl) GetEnclaveKey() (*crypto.EnclaveKey, error) { return crypto.NewEnclaveKey(ecdsaKey), nil } -func (s *storageImpl) StoreRollup(rollup *common.ExtRollup, internalHeader *common.CalldataRollupHeader) error { +func (s *storageImpl) StoreRollup(ctx context.Context, rollup *common.ExtRollup, internalHeader *common.CalldataRollupHeader) error { defer s.logDuration("StoreRollup", measure.NewStopwatch()) dbBatch := s.db.NewDBTransaction() - if err := enclavedb.WriteRollup(dbBatch, rollup.Header, internalHeader); err != nil { + if err := enclavedb.WriteRollup(ctx, dbBatch, rollup.Header, internalHeader); err != nil { return fmt.Errorf("could not write rollup. Cause: %w", err) } - if err := dbBatch.Write(); err != nil { + if err := dbBatch.WriteCtx(ctx); err != nil { return fmt.Errorf("could not write rollup to storage. Cause: %w", err) } return nil } -func (s *storageImpl) FetchReorgedRollup(reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) { - return enclavedb.FetchReorgedRollup(s.db.GetSQLDB(), reorgedBlocks) +func (s *storageImpl) FetchReorgedRollup(ctx context.Context, reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) { + return enclavedb.FetchReorgedRollup(ctx, s.db.GetSQLDB(), reorgedBlocks) } -func (s *storageImpl) FetchRollupMetadata(hash common.L2RollupHash) (*common.PublicRollupMetadata, error) { - return enclavedb.FetchRollupMetadata(s.db.GetSQLDB(), hash) +func (s *storageImpl) FetchRollupMetadata(ctx context.Context, hash common.L2RollupHash) (*common.PublicRollupMetadata, error) { + return enclavedb.FetchRollupMetadata(ctx, s.db.GetSQLDB(), hash) } -func (s *storageImpl) DebugGetLogs(txHash common.TxHash) ([]*tracers.DebugLogs, error) { +func (s *storageImpl) DebugGetLogs(ctx context.Context, txHash common.TxHash) ([]*tracers.DebugLogs, error) { defer s.logDuration("DebugGetLogs", measure.NewStopwatch()) - return enclavedb.DebugGetLogs(s.db.GetSQLDB(), txHash) + return enclavedb.DebugGetLogs(ctx, s.db.GetSQLDB(), txHash) } func (s *storageImpl) FilterLogs( + ctx context.Context, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, blockHash *common.L2BatchHash, @@ -546,42 +548,42 @@ func (s *storageImpl) FilterLogs( topics [][]gethcommon.Hash, ) ([]*types.Log, error) { defer s.logDuration("FilterLogs", measure.NewStopwatch()) - return enclavedb.FilterLogs(s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics) + return enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics) } -func (s *storageImpl) GetContractCount() (*big.Int, error) { +func (s *storageImpl) GetContractCount(ctx context.Context) (*big.Int, error) { defer s.logDuration("GetContractCount", measure.NewStopwatch()) - return enclavedb.ReadContractCreationCount(s.db.GetSQLDB()) + return enclavedb.ReadContractCreationCount(ctx, s.db.GetSQLDB()) } -func (s *storageImpl) FetchCanonicalUnexecutedBatches(from *big.Int) ([]*core.Batch, error) { +func (s *storageImpl) FetchCanonicalUnexecutedBatches(ctx context.Context, from *big.Int) ([]*core.Batch, error) { defer s.logDuration("FetchCanonicalUnexecutedBatches", measure.NewStopwatch()) - return enclavedb.ReadUnexecutedBatches(s.db.GetSQLDB(), from) + return enclavedb.ReadUnexecutedBatches(ctx, s.db.GetSQLDB(), from) } -func (s *storageImpl) BatchWasExecuted(hash common.L2BatchHash) (bool, error) { +func (s *storageImpl) BatchWasExecuted(ctx context.Context, hash common.L2BatchHash) (bool, error) { defer s.logDuration("BatchWasExecuted", measure.NewStopwatch()) - return enclavedb.BatchWasExecuted(s.db.GetSQLDB(), hash) + return enclavedb.BatchWasExecuted(ctx, s.db.GetSQLDB(), hash) } -func (s *storageImpl) GetReceiptsPerAddress(address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) { +func (s *storageImpl) GetReceiptsPerAddress(ctx context.Context, address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) { defer s.logDuration("GetReceiptsPerAddress", measure.NewStopwatch()) - return enclavedb.GetReceiptsPerAddress(s.db.GetSQLDB(), s.chainConfig, address, pagination) + return enclavedb.GetReceiptsPerAddress(ctx, s.db.GetSQLDB(), s.chainConfig, address, pagination) } -func (s *storageImpl) GetReceiptsPerAddressCount(address *gethcommon.Address) (uint64, error) { +func (s *storageImpl) GetReceiptsPerAddressCount(ctx context.Context, address *gethcommon.Address) (uint64, error) { defer s.logDuration("GetReceiptsPerAddressCount", measure.NewStopwatch()) - return enclavedb.GetReceiptsPerAddressCount(s.db.GetSQLDB(), address) + return enclavedb.GetReceiptsPerAddressCount(ctx, s.db.GetSQLDB(), address) } -func (s *storageImpl) GetPublicTransactionData(pagination *common.QueryPagination) ([]common.PublicTransaction, error) { +func (s *storageImpl) GetPublicTransactionData(ctx context.Context, pagination *common.QueryPagination) ([]common.PublicTransaction, error) { defer s.logDuration("GetPublicTransactionData", measure.NewStopwatch()) - return enclavedb.GetPublicTransactionData(s.db.GetSQLDB(), pagination) + return enclavedb.GetPublicTransactionData(ctx, s.db.GetSQLDB(), pagination) } -func (s *storageImpl) GetPublicTransactionCount() (uint64, error) { +func (s *storageImpl) GetPublicTransactionCount(ctx context.Context) (uint64, error) { defer s.logDuration("GetPublicTransactionCount", measure.NewStopwatch()) - return enclavedb.GetPublicTransactionCount(s.db.GetSQLDB()) + return enclavedb.GetPublicTransactionCount(ctx, s.db.GetSQLDB()) } func (s *storageImpl) logDuration(method string, stopWatch *measure.Stopwatch) { diff --git a/go/ethadapter/interface.go b/go/ethadapter/interface.go index 45f3d0e56f..281f1e6abf 100644 --- a/go/ethadapter/interface.go +++ b/go/ethadapter/interface.go @@ -2,7 +2,6 @@ package ethadapter import ( "context" - "errors" "math/big" "github.com/ten-protocol/go-ten/go/common" @@ -14,9 +13,6 @@ import ( "github.com/ethereum/go-ethereum/ethclient" ) -// ErrSubscriptionNotSupported return from BlockListener subscription if client doesn't support streaming (in-mem simulation) -var ErrSubscriptionNotSupported = errors.New("block subscription not supported") - // EthClient defines the interface for RPC communications with the ethereum nodes // todo (#1617) - some of these methods are composed calls that should be decoupled in the future (ie: BlocksBetween or IsBlockAncestor) type EthClient interface { diff --git a/go/host/enclave/guardian.go b/go/host/enclave/guardian.go index 8584a1dc43..5f1970a337 100644 --- a/go/host/enclave/guardian.go +++ b/go/host/enclave/guardian.go @@ -1,6 +1,7 @@ package enclave import ( + "context" "fmt" "math/big" "strings" @@ -102,7 +103,7 @@ func (g *Guardian) Start() error { // Identify the enclave before starting (the enclave generates its ID immediately at startup) // (retry until we get the enclave ID or the host is stopping) for g.enclaveID == nil && !g.hostInterrupter.IsStopping() { - enclID, err := g.enclaveClient.EnclaveID() + enclID, err := g.enclaveClient.EnclaveID(context.Background()) if err != nil { g.logger.Warn("could not get enclave ID", log.ErrKey, err) time.Sleep(_retryInterval) @@ -149,7 +150,7 @@ func (g *Guardian) Stop() error { return nil } -func (g *Guardian) HealthStatus() host.HealthStatus { +func (g *Guardian) HealthStatus(context.Context) host.HealthStatus { // todo (@matt) do proper health status based on enclave state errMsg := "" if !g.hostInterrupter.IsStopping() { @@ -196,7 +197,8 @@ func (g *Guardian) HandleBatch(batch *common.ExtBatch) { if g.hostData.IsSequencer || !g.state.IsUpToDate() { return // ignore batches until we're up-to-date } - err := g.submitL2Batch(batch) + // todo - @matt - does it make sense to use a timeout context? + err := g.submitL2Batch(context.Background(), batch) if err != nil { g.logger.Error("Error submitting batch to enclave", log.ErrKey, err) } @@ -209,7 +211,7 @@ func (g *Guardian) HandleTransaction(tx common.EncryptedTx) { g.logger.Info("Enclave is not ready yet, dropping transaction.") return // ignore transactions when enclave unavailable } - resp, sysError := g.enclaveClient.SubmitTx(tx) + resp, sysError := g.enclaveClient.SubmitTx(context.Background(), tx) if sysError != nil { g.logger.Warn("could not submit transaction due to sysError", log.ErrKey, sysError) return @@ -266,7 +268,7 @@ func (g *Guardian) mainLoop() { } func (g *Guardian) checkEnclaveStatus() { - s, err := g.enclaveClient.Status() + s, err := g.enclaveClient.Status(context.Background()) if err != nil { g.logger.Error("Could not get enclave status", log.ErrKey, err) // we record this as a disconnection, we can't get any more info from the enclave about status currently @@ -282,7 +284,7 @@ func (g *Guardian) provideSecret() error { // instead of requesting a secret, we generate one and broadcast it return g.generateAndBroadcastSecret() } - att, err := g.enclaveClient.Attestation() + att, err := g.enclaveClient.Attestation(context.Background()) if err != nil { return fmt.Errorf("could not retrieve attestation from enclave. Cause: %w", err) } @@ -306,7 +308,7 @@ func (g *Guardian) provideSecret() error { secretRespTxs, _, _ := g.sl.L1Publisher().ExtractObscuroRelevantTransactions(nextBlock) for _, scrt := range secretRespTxs { if scrt.RequesterID.Hex() == g.enclaveID.Hex() { - err = g.enclaveClient.InitEnclave(scrt.Secret) + err = g.enclaveClient.InitEnclave(context.Background(), scrt.Secret) if err != nil { g.logger.Error("Could not initialize enclave with received secret response", log.ErrKey, err) continue // try the next secret response in the block if there are more @@ -333,7 +335,7 @@ func (g *Guardian) provideSecret() error { func (g *Guardian) generateAndBroadcastSecret() error { g.logger.Info("Node is genesis node. Publishing secret to L1 management contract.") // Create the shared secret and submit it to the management contract for storage - attestation, err := g.enclaveClient.Attestation() + attestation, err := g.enclaveClient.Attestation(context.Background()) if err != nil { return fmt.Errorf("could not retrieve attestation from enclave. Cause: %w", err) } @@ -341,7 +343,7 @@ func (g *Guardian) generateAndBroadcastSecret() error { return fmt.Errorf("genesis enclave has ID %s, but its enclave produced an attestation using ID %s", g.enclaveID.Hex(), attestation.EnclaveID.Hex()) } - secret, err := g.enclaveClient.GenerateSecret() + secret, err := g.enclaveClient.GenerateSecret(context.Background()) if err != nil { return fmt.Errorf("could not generate secret. Cause: %w", err) } @@ -394,12 +396,12 @@ func (g *Guardian) catchupWithL2() error { nextHead := prevHead.Add(prevHead, big.NewInt(1)) g.logger.Trace("fetching next batch", log.BatchSeqNoKey, nextHead) - batch, err := g.sl.L2Repo().FetchBatchBySeqNo(nextHead) + batch, err := g.sl.L2Repo().FetchBatchBySeqNo(context.Background(), nextHead) if err != nil { return errors.Wrap(err, "could not fetch next L2 batch") } - err = g.submitL2Batch(batch) + err = g.submitL2Batch(context.Background(), batch) if err != nil { return err } @@ -421,7 +423,7 @@ func (g *Guardian) submitL1Block(block *common.L1Block, isLatest bool) (bool, er g.submitDataLock.Unlock() // lock must be released before returning return false, fmt.Errorf("could not fetch obscuro receipts for block=%s - %w", block.Hash(), err) } - resp, err := g.enclaveClient.SubmitL1Block(*block, receipts, isLatest) + resp, err := g.enclaveClient.SubmitL1Block(context.Background(), *block, receipts, isLatest) g.submitDataLock.Unlock() // lock is only guarding the enclave call, so we can release it now if err != nil { if strings.Contains(err.Error(), errutil.ErrBlockAlreadyProcessed.Error()) { @@ -469,7 +471,7 @@ func (g *Guardian) processL1BlockTransactions(block *common.L1Block) { g.logger.Error("Could not decode rollup.", log.ErrKey, err) } - metaData, err := g.enclaveClient.GetRollupData(r.Header.Hash()) + metaData, err := g.enclaveClient.GetRollupData(context.Background(), r.Header.Hash()) if err != nil { g.logger.Error("Could not fetch rollup metadata from enclave.", log.ErrKey, err) } else { @@ -517,9 +519,9 @@ func (g *Guardian) publishSharedSecretResponses(scrtResponses []*common.Produced return nil } -func (g *Guardian) submitL2Batch(batch *common.ExtBatch) error { +func (g *Guardian) submitL2Batch(ctx context.Context, batch *common.ExtBatch) error { g.submitDataLock.Lock() - err := g.enclaveClient.SubmitBatch(batch) + err := g.enclaveClient.SubmitBatch(ctx, batch) g.submitDataLock.Unlock() if err != nil { // something went wrong, return error and let the main loop check status and try again when appropriate @@ -555,7 +557,7 @@ func (g *Guardian) periodicBatchProduction() { // if maxBatchInterval is set higher than batchInterval then we are happy to skip creating batches when there is no data // (up to a maximum time of maxBatchInterval) skipBatchIfEmpty := g.maxBatchInterval > g.batchInterval && time.Since(g.lastBatchCreated) < g.maxBatchInterval - err := g.enclaveClient.CreateBatch(skipBatchIfEmpty) + err := g.enclaveClient.CreateBatch(context.Background(), skipBatchIfEmpty) if err != nil { g.logger.Error("Unable to produce batch", log.ErrKey, err) } @@ -609,7 +611,7 @@ func (g *Guardian) periodicRollupProduction() { sizeExceeded := estimatedRunningRollupSize >= g.maxRollupSize if timeExpired || sizeExceeded { g.logger.Info("Trigger rollup production.", "timeExpired", timeExpired, "sizeExceeded", sizeExceeded) - producedRollup, err := g.enclaveClient.CreateRollup(fromBatch) + producedRollup, err := g.enclaveClient.CreateRollup(context.Background(), fromBatch) if err != nil { g.logger.Error("Unable to create rollup", log.BatchSeqNoKey, fromBatch, log.ErrKey, err) continue @@ -688,7 +690,7 @@ func (g *Guardian) calculateNonRolledupBatchesSize(seqNo uint64) (uint64, error) currentNo := seqNo for { - batch, err := g.sl.L2Repo().FetchBatchBySeqNo(big.NewInt(int64(currentNo))) + batch, err := g.sl.L2Repo().FetchBatchBySeqNo(context.TODO(), big.NewInt(int64(currentNo))) if err != nil { if errors.Is(err, errutil.ErrNotFound) { break // no more batches diff --git a/go/host/enclave/service.go b/go/host/enclave/service.go index 1c836a2ad8..cc4edc82f4 100644 --- a/go/host/enclave/service.go +++ b/go/host/enclave/service.go @@ -1,6 +1,7 @@ package enclave import ( + "context" "fmt" "math/big" "sync/atomic" @@ -65,7 +66,7 @@ func (e *Service) Stop() error { return nil } -func (e *Service) HealthStatus() host.HealthStatus { +func (e *Service) HealthStatus(ctx context.Context) host.HealthStatus { if !e.running.Load() { return &host.BasicErrHealthStatus{ErrMsg: "not running"} } @@ -74,7 +75,7 @@ func (e *Service) HealthStatus() host.HealthStatus { for i, guardian := range e.enclaveGuardians { // check the enclave health, which in turn checks the DB health - enclaveHealthy, err := guardian.enclaveClient.HealthCheck() + enclaveHealthy, err := guardian.enclaveClient.HealthCheck(ctx) if err != nil { errors = append(errors, fmt.Errorf("unable to HealthCheck enclave[%d] - %w", i, err)) } else if !enclaveHealthy { @@ -90,9 +91,9 @@ func (e *Service) HealthStatus() host.HealthStatus { return &host.GroupErrsHealthStatus{Errors: errors} } -func (e *Service) HealthyGuardian() *Guardian { +func (e *Service) HealthyGuardian(ctx context.Context) *Guardian { for _, guardian := range e.enclaveGuardians { - if guardian.HealthStatus().OK() { + if guardian.HealthStatus(ctx).OK() { return guardian } } @@ -101,14 +102,14 @@ func (e *Service) HealthyGuardian() *Guardian { // LookupBatchBySeqNo is used to fetch batch data from the enclave - it is only used as a fallback for the sequencer // host if it's missing a batch (other host services should use L2Repo to fetch batch data) -func (e *Service) LookupBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) { - hg := e.HealthyGuardian() +func (e *Service) LookupBatchBySeqNo(ctx context.Context, seqNo *big.Int) (*common.ExtBatch, error) { + hg := e.HealthyGuardian(ctx) state := hg.GetEnclaveState() if state.GetEnclaveL2Head().Cmp(seqNo) < 0 { return nil, errutil.ErrNotFound } client := hg.GetEnclaveClient() - return client.GetBatchBySeqNo(seqNo.Uint64()) + return client.GetBatchBySeqNo(ctx, seqNo.Uint64()) } func (e *Service) GetEnclaveClient() common.Enclave { @@ -117,10 +118,10 @@ func (e *Service) GetEnclaveClient() common.Enclave { return e.enclaveGuardians[0].GetEnclaveClient() } -func (e *Service) SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) { +func (e *Service) SubmitAndBroadcastTx(ctx context.Context, encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) { encryptedTx := common.EncryptedTx(encryptedParams) - enclaveResponse, sysError := e.GetEnclaveClient().SubmitTx(encryptedTx) + enclaveResponse, sysError := e.GetEnclaveClient().SubmitTx(ctx, encryptedTx) if sysError != nil { e.logger.Warn("Could not submit transaction due to sysError.", log.ErrKey, sysError) return nil, sysError @@ -141,7 +142,7 @@ func (e *Service) SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSen } func (e *Service) Subscribe(id rpc.ID, encryptedParams common.EncryptedParamsLogSubscription) error { - return e.GetEnclaveClient().Subscribe(id, encryptedParams) + return e.GetEnclaveClient().Subscribe(context.Background(), id, encryptedParams) } func (e *Service) Unsubscribe(id rpc.ID) error { diff --git a/go/host/events/logs.go b/go/host/events/logs.go index 9099811654..b0f25a246d 100644 --- a/go/host/events/logs.go +++ b/go/host/events/logs.go @@ -1,6 +1,7 @@ package events import ( + "context" "sync" "github.com/pkg/errors" @@ -42,7 +43,7 @@ func (l *LogEventManager) Stop() error { return nil } -func (l *LogEventManager) HealthStatus() host.HealthStatus { +func (l *LogEventManager) HealthStatus(context.Context) host.HealthStatus { // always healthy for now return &host.BasicErrHealthStatus{ErrMsg: ""} } diff --git a/go/host/host.go b/go/host/host.go index e846e09420..867140058b 100644 --- a/go/host/host.go +++ b/go/host/host.go @@ -1,6 +1,7 @@ package host import ( + "context" "encoding/json" "fmt" @@ -155,11 +156,11 @@ func (h *host) EnclaveClient() common.Enclave { return h.services.Enclaves().GetEnclaveClient() } -func (h *host) SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) { +func (h *host) SubmitAndBroadcastTx(ctx context.Context, encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) { if h.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested SubmitAndBroadcastTx with the host stopping")) } - return h.services.Enclaves().SubmitAndBroadcastTx(encryptedParams) + return h.services.Enclaves().SubmitAndBroadcastTx(ctx, encryptedParams) } func (h *host) SubscribeLogs(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogsCh chan []byte) error { @@ -198,7 +199,7 @@ func (h *host) Stop() error { } // HealthCheck returns whether the host, enclave and DB are healthy -func (h *host) HealthCheck() (*hostcommon.HealthCheck, error) { +func (h *host) HealthCheck(ctx context.Context) (*hostcommon.HealthCheck, error) { if h.stopControl.IsStopping() { return nil, responses.ToInternalError(fmt.Errorf("requested HealthCheck with the host stopping")) } @@ -207,7 +208,7 @@ func (h *host) HealthCheck() (*hostcommon.HealthCheck, error) { // loop through all registered services and collect their health statuses for name, service := range h.services.All() { - status := service.HealthStatus() + status := service.HealthStatus(ctx) if !status.OK() { healthErrors = append(healthErrors, fmt.Sprintf("[%s] not healthy - %s", name, status.Message())) } @@ -222,7 +223,7 @@ func (h *host) HealthCheck() (*hostcommon.HealthCheck, error) { // ObscuroConfig returns info on the Obscuro network func (h *host) ObscuroConfig() (*common.ObscuroNetworkInfo, error) { if h.l2MessageBusAddress == nil { - publicCfg, err := h.EnclaveClient().EnclavePublicConfig() + publicCfg, err := h.EnclaveClient().EnclavePublicConfig(context.Background()) if err != nil { return nil, responses.ToInternalError(fmt.Errorf("unable to get L2 message bus address - %w", err)) } diff --git a/go/host/l1/blockrepository.go b/go/host/l1/blockrepository.go index 5ccc5885df..bd0a2ad4fc 100644 --- a/go/host/l1/blockrepository.go +++ b/go/host/l1/blockrepository.go @@ -1,6 +1,7 @@ package l1 import ( + "context" "errors" "fmt" "math/big" @@ -63,7 +64,7 @@ func (r *Repository) Stop() error { return nil } -func (r *Repository) HealthStatus() host.HealthStatus { +func (r *Repository) HealthStatus(context.Context) host.HealthStatus { // todo (@matt) do proper health status based on last received block or something errMsg := "" if !r.running.Load() { diff --git a/go/host/l1/publisher.go b/go/host/l1/publisher.go index de821798fb..69f6610dd6 100644 --- a/go/host/l1/publisher.go +++ b/go/host/l1/publisher.go @@ -97,7 +97,7 @@ func (p *Publisher) Stop() error { return nil } -func (p *Publisher) HealthStatus() host.HealthStatus { +func (p *Publisher) HealthStatus(context.Context) host.HealthStatus { // todo (@matt) do proper health status based on failed transactions or something errMsg := "" if p.hostStopper.IsStopping() { diff --git a/go/host/l2/batchrepository.go b/go/host/l2/batchrepository.go index ed7cfee642..27ed9fff7a 100644 --- a/go/host/l2/batchrepository.go +++ b/go/host/l2/batchrepository.go @@ -1,6 +1,7 @@ package l2 import ( + "context" "errors" "fmt" "math/big" @@ -83,7 +84,7 @@ func (r *Repository) Stop() error { return nil } -func (r *Repository) HealthStatus() host.HealthStatus { +func (r *Repository) HealthStatus(context.Context) host.HealthStatus { // todo (@matt) do proper health status based on last received batch or something errMsg := "" if !r.running.Load() { @@ -148,13 +149,13 @@ func (r *Repository) Subscribe(handler host.L2BatchHandler) func() { return r.batchSubscribers.Subscribe(handler) } -func (r *Repository) FetchBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) { +func (r *Repository) FetchBatchBySeqNo(ctx context.Context, seqNo *big.Int) (*common.ExtBatch, error) { b, err := r.storage.FetchBatchBySeqNo(seqNo.Uint64()) if err != nil { if errors.Is(err, errutil.ErrNotFound) && seqNo.Cmp(r.latestBatchSeqNo) < 0 { if r.isSequencer { // sequencer does not request batches from peers, it checks if its enclave has the batch - return r.fetchBatchFallbackToEnclave(seqNo) + return r.fetchBatchFallbackToEnclave(ctx, seqNo) } // we haven't seen this batch before, but it is older than the latest batch we have seen so far // Request missing batches from peers (the batches from any response will be added asynchronously, so @@ -189,8 +190,8 @@ func (r *Repository) AddBatch(batch *common.ExtBatch) error { return nil } -func (r *Repository) fetchBatchFallbackToEnclave(seqNo *big.Int) (*common.ExtBatch, error) { - b, err := r.sl.Enclaves().LookupBatchBySeqNo(seqNo) +func (r *Repository) fetchBatchFallbackToEnclave(ctx context.Context, seqNo *big.Int) (*common.ExtBatch, error) { + b, err := r.sl.Enclaves().LookupBatchBySeqNo(ctx, seqNo) if err != nil { return nil, err } diff --git a/go/host/p2p/p2p.go b/go/host/p2p/p2p.go index 876efc0ec3..e18ce9069f 100644 --- a/go/host/p2p/p2p.go +++ b/go/host/p2p/p2p.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "fmt" "io" "math/big" @@ -140,7 +141,7 @@ func (p *Service) Stop() error { return nil } -func (p *Service) HealthStatus() host.HealthStatus { +func (p *Service) HealthStatus(context.Context) host.HealthStatus { msg := "" if err := p.verifyHealth(); err != nil { msg = err.Error() diff --git a/go/host/rpc/clientapi/client_api_debug.go b/go/host/rpc/clientapi/client_api_debug.go index 5af064d386..8299d44924 100644 --- a/go/host/rpc/clientapi/client_api_debug.go +++ b/go/host/rpc/clientapi/client_api_debug.go @@ -22,8 +22,8 @@ func NewNetworkDebug(host host.Host) *NetworkDebug { // TraceTransaction returns the structured logs created during the execution of EVM // and returns them as a JSON object. -func (api *NetworkDebug) TraceTransaction(_ context.Context, hash gethcommon.Hash, config *tracers.TraceConfig) (interface{}, error) { - response, err := api.host.EnclaveClient().DebugTraceTransaction(hash, config) +func (api *NetworkDebug) TraceTransaction(ctx context.Context, hash gethcommon.Hash, config *tracers.TraceConfig) (interface{}, error) { + response, err := api.host.EnclaveClient().DebugTraceTransaction(ctx, hash, config) if err != nil { return "", err } @@ -31,8 +31,8 @@ func (api *NetworkDebug) TraceTransaction(_ context.Context, hash gethcommon.Has } // EventLogRelevancy returns the events for a given transactions and the revelancy params -func (api *NetworkDebug) EventLogRelevancy(_ context.Context, hash gethcommon.Hash) (interface{}, error) { - response, err := api.host.EnclaveClient().DebugEventLogRelevancy(hash) +func (api *NetworkDebug) EventLogRelevancy(ctx context.Context, hash gethcommon.Hash) (interface{}, error) { + response, err := api.host.EnclaveClient().DebugEventLogRelevancy(ctx, hash) if err != nil { return "", err } diff --git a/go/host/rpc/clientapi/client_api_eth.go b/go/host/rpc/clientapi/client_api_eth.go index 9da97805af..703da1c40b 100644 --- a/go/host/rpc/clientapi/client_api_eth.go +++ b/go/host/rpc/clientapi/client_api_eth.go @@ -82,8 +82,8 @@ func (api *EthereumAPI) GasPrice(context.Context) (*hexutil.Big, error) { // GetBalance returns the address's balance on the Obscuro network, encrypted with the viewing key corresponding to the // `address` field and encoded as hex. -func (api *EthereumAPI) GetBalance(_ context.Context, encryptedParams common.EncryptedParamsGetBalance) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().GetBalance(encryptedParams) +func (api *EthereumAPI) GetBalance(ctx context.Context, encryptedParams common.EncryptedParamsGetBalance) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().GetBalance(ctx, encryptedParams) if sysError != nil { return api.handleSysError("GetBalance", sysError) } @@ -92,8 +92,8 @@ func (api *EthereumAPI) GetBalance(_ context.Context, encryptedParams common.Enc // Call returns the result of executing the smart contract as a user, encrypted with the viewing key corresponding to // the `from` field and encoded as hex. -func (api *EthereumAPI) Call(_ context.Context, encryptedParams common.EncryptedParamsCall) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().ObsCall(encryptedParams) +func (api *EthereumAPI) Call(ctx context.Context, encryptedParams common.EncryptedParamsCall) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().ObsCall(ctx, encryptedParams) if sysError != nil { return api.handleSysError("Call", sysError) } @@ -102,8 +102,8 @@ func (api *EthereumAPI) Call(_ context.Context, encryptedParams common.Encrypted // GetTransactionReceipt returns the transaction receipt for the given transaction hash, encrypted with the viewing key // corresponding to the original transaction submitter and encoded as hex, or nil if no matching transaction exists. -func (api *EthereumAPI) GetTransactionReceipt(_ context.Context, encryptedParams common.EncryptedParamsGetTxReceipt) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().GetTransactionReceipt(encryptedParams) +func (api *EthereumAPI) GetTransactionReceipt(ctx context.Context, encryptedParams common.EncryptedParamsGetTxReceipt) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().GetTransactionReceipt(ctx, encryptedParams) if sysError != nil { return api.handleSysError("GetTransactionReceipt", sysError) } @@ -111,8 +111,8 @@ func (api *EthereumAPI) GetTransactionReceipt(_ context.Context, encryptedParams } // EstimateGas requests the enclave the gas estimation based on the callMsg supplied params (encrypted) -func (api *EthereumAPI) EstimateGas(_ context.Context, encryptedParams common.EncryptedParamsEstimateGas) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().EstimateGas(encryptedParams) +func (api *EthereumAPI) EstimateGas(ctx context.Context, encryptedParams common.EncryptedParamsEstimateGas) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().EstimateGas(ctx, encryptedParams) if sysError != nil { return api.handleSysError("EstimateGas", sysError) } @@ -120,8 +120,8 @@ func (api *EthereumAPI) EstimateGas(_ context.Context, encryptedParams common.En } // SendRawTransaction sends the encrypted transaction. -func (api *EthereumAPI) SendRawTransaction(_ context.Context, encryptedParams common.EncryptedParamsSendRawTx) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.SubmitAndBroadcastTx(encryptedParams) +func (api *EthereumAPI) SendRawTransaction(ctx context.Context, encryptedParams common.EncryptedParamsSendRawTx) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.SubmitAndBroadcastTx(ctx, encryptedParams) if sysError != nil { return api.handleSysError("SendRawTransaction", sysError) } @@ -130,7 +130,7 @@ func (api *EthereumAPI) SendRawTransaction(_ context.Context, encryptedParams co // GetCode returns the code stored at the given address in the state for the given batch height or batch hash. // todo (#1620) - instead of converting the block number of hash client-side, do it on the enclave -func (api *EthereumAPI) GetCode(_ context.Context, address gethcommon.Address, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { +func (api *EthereumAPI) GetCode(ctx context.Context, address gethcommon.Address, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { var batchHash *gethcommon.Hash // requested a number @@ -151,7 +151,7 @@ func (api *EthereumAPI) GetCode(_ context.Context, address gethcommon.Address, b return nil, errors.New("invalid arguments; neither batch height nor batch hash specified") } - code, sysError := api.host.EnclaveClient().GetCode(address, batchHash) + code, sysError := api.host.EnclaveClient().GetCode(ctx, address, batchHash) if sysError != nil { api.logger.Error(fmt.Sprintf("Enclave System Error. Function %s", "GetCode"), log.ErrKey, sysError) return nil, fmt.Errorf(responses.InternalErrMsg) @@ -160,8 +160,8 @@ func (api *EthereumAPI) GetCode(_ context.Context, address gethcommon.Address, b return code, nil } -func (api *EthereumAPI) GetTransactionCount(_ context.Context, encryptedParams common.EncryptedParamsGetTxCount) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().GetTransactionCount(encryptedParams) +func (api *EthereumAPI) GetTransactionCount(ctx context.Context, encryptedParams common.EncryptedParamsGetTxCount) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().GetTransactionCount(ctx, encryptedParams) if sysError != nil { return api.handleSysError("GetTransactionCount", sysError) } @@ -170,8 +170,8 @@ func (api *EthereumAPI) GetTransactionCount(_ context.Context, encryptedParams c // GetTransactionByHash returns the transaction with the given hash, encrypted with the viewing key corresponding to the // `from` field and encoded as hex, or nil if no matching transaction exists. -func (api *EthereumAPI) GetTransactionByHash(_ context.Context, encryptedParams common.EncryptedParamsGetTxByHash) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().GetTransaction(encryptedParams) +func (api *EthereumAPI) GetTransactionByHash(ctx context.Context, encryptedParams common.EncryptedParamsGetTxByHash) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().GetTransaction(ctx, encryptedParams) if sysError != nil { return api.handleSysError("GetTransactionByHash", sysError) } @@ -179,8 +179,8 @@ func (api *EthereumAPI) GetTransactionByHash(_ context.Context, encryptedParams } // GetStorageAt is a reused method for listing the users transactions -func (api *EthereumAPI) GetStorageAt(_ context.Context, encryptedParams common.EncryptedParamsGetStorageAt) (*responses.Receipts, error) { - return api.host.EnclaveClient().GetCustomQuery(encryptedParams) +func (api *EthereumAPI) GetStorageAt(ctx context.Context, encryptedParams common.EncryptedParamsGetStorageAt) (*responses.Receipts, error) { + return api.host.EnclaveClient().GetCustomQuery(ctx, encryptedParams) } // FeeHistory is a placeholder for an RPC method required by MetaMask/Remix. diff --git a/go/host/rpc/clientapi/client_api_filter.go b/go/host/rpc/clientapi/client_api_filter.go index c0677a5ca1..7af780cda1 100644 --- a/go/host/rpc/clientapi/client_api_filter.go +++ b/go/host/rpc/clientapi/client_api_filter.go @@ -67,8 +67,8 @@ func (api *FilterAPI) Logs(ctx context.Context, encryptedParams common.Encrypted } // GetLogs returns the logs matching the filter. -func (api *FilterAPI) GetLogs(_ context.Context, encryptedParams common.EncryptedParamsGetLogs) (responses.EnclaveResponse, error) { - enclaveResponse, sysError := api.host.EnclaveClient().GetLogs(encryptedParams) +func (api *FilterAPI) GetLogs(ctx context.Context, encryptedParams common.EncryptedParamsGetLogs) (responses.EnclaveResponse, error) { + enclaveResponse, sysError := api.host.EnclaveClient().GetLogs(ctx, encryptedParams) if sysError != nil { return api.handleSysError("GetLogs", sysError) } diff --git a/go/host/rpc/clientapi/client_api_obscuro.go b/go/host/rpc/clientapi/client_api_obscuro.go index 2b0db66784..3762826fbe 100644 --- a/go/host/rpc/clientapi/client_api_obscuro.go +++ b/go/host/rpc/clientapi/client_api_obscuro.go @@ -1,6 +1,8 @@ package clientapi import ( + "context" + gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/host" @@ -18,8 +20,8 @@ func NewObscuroAPI(host host.Host) *ObscuroAPI { } // Health returns the health status of obscuro host + enclave + db -func (api *ObscuroAPI) Health() (*host.HealthCheck, error) { - return api.host.HealthCheck() +func (api *ObscuroAPI) Health(ctx context.Context) (*host.HealthCheck, error) { + return api.host.HealthCheck(ctx) } // Config returns the config status of obscuro host + enclave + db diff --git a/go/host/rpc/clientapi/client_api_scan.go b/go/host/rpc/clientapi/client_api_scan.go index c1124b5156..7bbde8b2cb 100644 --- a/go/host/rpc/clientapi/client_api_scan.go +++ b/go/host/rpc/clientapi/client_api_scan.go @@ -1,6 +1,7 @@ package clientapi import ( + "context" "math/big" gethcommon "github.com/ethereum/go-ethereum/common" @@ -24,8 +25,8 @@ func NewScanAPI(host host.Host, logger log.Logger) *ScanAPI { } // GetTotalContractCount returns the number of recorded contracts on the network. -func (s *ScanAPI) GetTotalContractCount() (*big.Int, error) { - return s.host.EnclaveClient().GetTotalContractCount() +func (s *ScanAPI) GetTotalContractCount(ctx context.Context) (*big.Int, error) { + return s.host.EnclaveClient().GetTotalContractCount(ctx) } // GetTotalTxCount returns the number of recorded transactions on the network. @@ -79,8 +80,8 @@ func (s *ScanAPI) GetLatestRollupHeader() (*common.RollupHeader, error) { } // GetPublicTransactionData returns a paginated list of transaction data -func (s *ScanAPI) GetPublicTransactionData(pagination *common.QueryPagination) (*common.TransactionListingResponse, error) { - return s.host.EnclaveClient().GetPublicTransactionData(pagination) +func (s *ScanAPI) GetPublicTransactionData(ctx context.Context, pagination *common.QueryPagination) (*common.TransactionListingResponse, error) { + return s.host.EnclaveClient().GetPublicTransactionData(ctx, pagination) } // GetBlockListing returns a paginated list of blocks that include rollups diff --git a/go/host/rpc/enclaverpc/enclave_client.go b/go/host/rpc/enclaverpc/enclave_client.go index 7ff31de0bb..4fa0bc32e9 100644 --- a/go/host/rpc/enclaverpc/enclave_client.go +++ b/go/host/rpc/enclaverpc/enclave_client.go @@ -10,7 +10,6 @@ import ( "github.com/ten-protocol/go-ten/go/enclave/core" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/log" @@ -80,12 +79,12 @@ func (c *Client) StopClient() common.SystemError { return c.connection.Close() } -func (c *Client) Status() (common.Status, common.SystemError) { +func (c *Client) Status(ctx context.Context) (common.Status, common.SystemError) { if c.connection.GetState() != connectivity.Ready { return common.Status{StatusCode: common.Unavailable}, syserr.NewInternalError(fmt.Errorf("RPC connection is not ready")) } - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.Status(timeoutCtx, &generated.StatusRequest{}) @@ -103,8 +102,8 @@ func (c *Client) Status() (common.Status, common.SystemError) { }, nil } -func (c *Client) Attestation() (*common.AttestationReport, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) Attestation(ctx context.Context) (*common.AttestationReport, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.Attestation(timeoutCtx, &generated.AttestationRequest{}) @@ -117,8 +116,8 @@ func (c *Client) Attestation() (*common.AttestationReport, common.SystemError) { return rpc.FromAttestationReportMsg(response.AttestationReportMsg), nil } -func (c *Client) GenerateSecret() (common.EncryptedSharedEnclaveSecret, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GenerateSecret(ctx context.Context) (common.EncryptedSharedEnclaveSecret, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GenerateSecret(timeoutCtx, &generated.GenerateSecretRequest{}) @@ -132,8 +131,8 @@ func (c *Client) GenerateSecret() (common.EncryptedSharedEnclaveSecret, common.S return response.EncryptedSharedEnclaveSecret, nil } -func (c *Client) InitEnclave(secret common.EncryptedSharedEnclaveSecret) common.SystemError { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) InitEnclave(ctx context.Context, secret common.EncryptedSharedEnclaveSecret) common.SystemError { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.InitEnclave(timeoutCtx, &generated.InitEnclaveRequest{EncryptedSharedEnclaveSecret: secret}) @@ -146,8 +145,8 @@ func (c *Client) InitEnclave(secret common.EncryptedSharedEnclaveSecret) common. return nil } -func (c *Client) EnclaveID() (common.EnclaveID, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) EnclaveID(ctx context.Context) (common.EnclaveID, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.EnclaveID(timeoutCtx, &generated.EnclaveIDRequest{}) @@ -160,8 +159,8 @@ func (c *Client) EnclaveID() (common.EnclaveID, common.SystemError) { return common.EnclaveID(response.EnclaveID), nil } -func (c *Client) SubmitL1Block(block types.Block, receipts types.Receipts, isLatest bool) (*common.BlockSubmissionResponse, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) SubmitL1Block(ctx context.Context, block common.L1Block, receipts common.L1Receipts, isLatest bool) (*common.BlockSubmissionResponse, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() var buffer bytes.Buffer @@ -186,8 +185,8 @@ func (c *Client) SubmitL1Block(block types.Block, receipts types.Receipts, isLat return blockSubmissionResponse, nil } -func (c *Client) SubmitTx(tx common.EncryptedTx) (*responses.RawTx, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) SubmitTx(ctx context.Context, tx common.EncryptedTx) (*responses.RawTx, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.SubmitTx(timeoutCtx, &generated.SubmitTxRequest{EncryptedTx: tx}) @@ -201,10 +200,10 @@ func (c *Client) SubmitTx(tx common.EncryptedTx) (*responses.RawTx, common.Syste return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) SubmitBatch(batch *common.ExtBatch) common.SystemError { +func (c *Client) SubmitBatch(ctx context.Context, batch *common.ExtBatch) common.SystemError { defer core.LogMethodDuration(c.logger, measure.NewStopwatch(), "SubmitBatch rpc call") - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() batchMsg := rpc.ToExtBatchMsg(batch) @@ -219,8 +218,8 @@ func (c *Client) SubmitBatch(batch *common.ExtBatch) common.SystemError { return nil } -func (c *Client) ObsCall(encryptedParams common.EncryptedParamsCall) (*responses.Call, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) ObsCall(ctx context.Context, encryptedParams common.EncryptedParamsCall) (*responses.Call, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.ObsCall(timeoutCtx, &generated.ObsCallRequest{ @@ -236,8 +235,8 @@ func (c *Client) ObsCall(encryptedParams common.EncryptedParamsCall) (*responses return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetTransactionCount(encryptedParams common.EncryptedParamsGetTxCount) (*responses.TxCount, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetTransactionCount(ctx context.Context, encryptedParams common.EncryptedParamsGetTxCount) (*responses.TxCount, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetTransactionCount(timeoutCtx, &generated.GetTransactionCountRequest{EncryptedParams: encryptedParams}) @@ -267,8 +266,8 @@ func (c *Client) Stop() common.SystemError { return nil } -func (c *Client) GetTransaction(encryptedParams common.EncryptedParamsGetTxByHash) (*responses.TxByHash, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetTransaction(ctx context.Context, encryptedParams common.EncryptedParamsGetTxByHash) (*responses.TxByHash, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetTransaction(timeoutCtx, &generated.GetTransactionRequest{EncryptedParams: encryptedParams}) @@ -282,8 +281,8 @@ func (c *Client) GetTransaction(encryptedParams common.EncryptedParamsGetTxByHas return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetTransactionReceipt(encryptedParams common.EncryptedParamsGetTxReceipt) (*responses.TxReceipt, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetTransactionReceipt(ctx context.Context, encryptedParams common.EncryptedParamsGetTxReceipt) (*responses.TxReceipt, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetTransactionReceipt(timeoutCtx, &generated.GetTransactionReceiptRequest{EncryptedParams: encryptedParams}) @@ -297,8 +296,8 @@ func (c *Client) GetTransactionReceipt(encryptedParams common.EncryptedParamsGet return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetBalance(encryptedParams common.EncryptedParamsGetBalance) (*responses.Balance, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetBalance(ctx context.Context, encryptedParams common.EncryptedParamsGetBalance) (*responses.Balance, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetBalance(timeoutCtx, &generated.GetBalanceRequest{ @@ -314,8 +313,8 @@ func (c *Client) GetBalance(encryptedParams common.EncryptedParamsGetBalance) (* return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetCode(address gethcommon.Address, batchHash *gethcommon.Hash) ([]byte, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetCode(ctx context.Context, address gethcommon.Address, batchHash *gethcommon.Hash) ([]byte, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetCode(timeoutCtx, &generated.GetCodeRequest{ @@ -332,8 +331,8 @@ func (c *Client) GetCode(address gethcommon.Address, batchHash *gethcommon.Hash) return response.Code, nil } -func (c *Client) Subscribe(id gethrpc.ID, encryptedParams common.EncryptedParamsLogSubscription) common.SystemError { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) Subscribe(ctx context.Context, id gethrpc.ID, encryptedParams common.EncryptedParamsLogSubscription) common.SystemError { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.Subscribe(timeoutCtx, &generated.SubscribeRequest{ @@ -365,8 +364,8 @@ func (c *Client) Unsubscribe(id gethrpc.ID) common.SystemError { return nil } -func (c *Client) EstimateGas(encryptedParams common.EncryptedParamsEstimateGas) (*responses.Gas, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) EstimateGas(ctx context.Context, encryptedParams common.EncryptedParamsEstimateGas) (*responses.Gas, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.EstimateGas(timeoutCtx, &generated.EstimateGasRequest{ @@ -382,8 +381,8 @@ func (c *Client) EstimateGas(encryptedParams common.EncryptedParamsEstimateGas) return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetLogs(encryptedParams common.EncryptedParamsGetLogs) (*responses.Logs, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetLogs(ctx context.Context, encryptedParams common.EncryptedParamsGetLogs) (*responses.Logs, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetLogs(timeoutCtx, &generated.GetLogsRequest{ @@ -399,8 +398,8 @@ func (c *Client) GetLogs(encryptedParams common.EncryptedParamsGetLogs) (*respon return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) HealthCheck() (bool, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) HealthCheck(ctx context.Context) (bool, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.HealthCheck(timeoutCtx, &generated.EmptyArgs{}) @@ -413,10 +412,10 @@ func (c *Client) HealthCheck() (bool, common.SystemError) { return response.Status, nil } -func (c *Client) CreateBatch(skipIfEmpty bool) common.SystemError { +func (c *Client) CreateBatch(ctx context.Context, skipIfEmpty bool) common.SystemError { defer core.LogMethodDuration(c.logger, measure.NewStopwatch(), "CreateBatch rpc call") - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.CreateBatch(timeoutCtx, &generated.CreateBatchRequest{SkipIfEmpty: skipIfEmpty}) @@ -429,10 +428,10 @@ func (c *Client) CreateBatch(skipIfEmpty bool) common.SystemError { return err } -func (c *Client) CreateRollup(fromSeqNo uint64) (*common.ExtRollup, common.SystemError) { +func (c *Client) CreateRollup(ctx context.Context, fromSeqNo uint64) (*common.ExtRollup, common.SystemError) { defer core.LogMethodDuration(c.logger, measure.NewStopwatch(), "CreateRollup rpc call") - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.CreateRollup(timeoutCtx, &generated.CreateRollupRequest{ @@ -448,8 +447,8 @@ func (c *Client) CreateRollup(fromSeqNo uint64) (*common.ExtRollup, common.Syste return rpc.FromExtRollupMsg(response.Msg), nil } -func (c *Client) DebugTraceTransaction(hash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) DebugTraceTransaction(ctx context.Context, hash gethcommon.Hash, config *tracers.TraceConfig) (json.RawMessage, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() confBytes, err := json.Marshal(config) @@ -471,8 +470,8 @@ func (c *Client) DebugTraceTransaction(hash gethcommon.Hash, config *tracers.Tra return json.RawMessage(response.Msg), nil } -func (c *Client) GetBatch(hash common.L2BatchHash) (*common.ExtBatch, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetBatch(ctx context.Context, hash common.L2BatchHash) (*common.ExtBatch, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() batchMsg, err := c.protoClient.GetBatch(timeoutCtx, &generated.GetBatchRequest{KnownHead: hash.Bytes()}) @@ -483,8 +482,8 @@ func (c *Client) GetBatch(hash common.L2BatchHash) (*common.ExtBatch, common.Sys return common.DecodeExtBatch(batchMsg.Batch) } -func (c *Client) GetBatchBySeqNo(seqNo uint64) (*common.ExtBatch, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetBatchBySeqNo(ctx context.Context, seqNo uint64) (*common.ExtBatch, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() batchMsg, err := c.protoClient.GetBatchBySeqNo(timeoutCtx, &generated.GetBatchBySeqNoRequest{SeqNo: seqNo}) @@ -495,8 +494,8 @@ func (c *Client) GetBatchBySeqNo(seqNo uint64) (*common.ExtBatch, common.SystemE return common.DecodeExtBatch(batchMsg.Batch) } -func (c *Client) GetRollupData(hash common.L2RollupHash) (*common.PublicRollupMetadata, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetRollupData(ctx context.Context, hash common.L2RollupHash) (*common.PublicRollupMetadata, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetRollupData(timeoutCtx, &generated.GetRollupDataRequest{Hash: hash.Bytes()}) @@ -552,8 +551,8 @@ func (c *Client) StreamL2Updates() (chan common.StreamL2UpdatesResponse, func()) return batchChan, cancel } -func (c *Client) DebugEventLogRelevancy(hash gethcommon.Hash) (json.RawMessage, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) DebugEventLogRelevancy(ctx context.Context, hash gethcommon.Hash) (json.RawMessage, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.DebugEventLogRelevancy(timeoutCtx, &generated.DebugEventLogRelevancyRequest{ @@ -568,8 +567,8 @@ func (c *Client) DebugEventLogRelevancy(hash gethcommon.Hash) (json.RawMessage, return json.RawMessage(response.Msg), nil } -func (c *Client) GetTotalContractCount() (*big.Int, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetTotalContractCount(ctx context.Context) (*big.Int, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetTotalContractCount(timeoutCtx, &generated.GetTotalContractCountRequest{}) @@ -582,8 +581,8 @@ func (c *Client) GetTotalContractCount() (*big.Int, common.SystemError) { return big.NewInt(response.Count), nil } -func (c *Client) GetCustomQuery(encryptedParams common.EncryptedParamsGetStorageAt) (*responses.Receipts, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetCustomQuery(ctx context.Context, encryptedParams common.EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetReceiptsByAddress(timeoutCtx, &generated.GetReceiptsByAddressRequest{ @@ -599,8 +598,8 @@ func (c *Client) GetCustomQuery(encryptedParams common.EncryptedParamsGetStorage return responses.ToEnclaveResponse(response.EncodedEnclaveResponse), nil } -func (c *Client) GetPublicTransactionData(pagination *common.QueryPagination) (*common.TransactionListingResponse, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) GetPublicTransactionData(ctx context.Context, pagination *common.QueryPagination) (*common.TransactionListingResponse, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.GetPublicTransactionData(timeoutCtx, &generated.GetPublicTransactionDataRequest{ @@ -625,8 +624,8 @@ func (c *Client) GetPublicTransactionData(pagination *common.QueryPagination) (* return &result, nil } -func (c *Client) EnclavePublicConfig() (*common.EnclavePublicConfig, common.SystemError) { - timeoutCtx, cancel := context.WithTimeout(context.Background(), c.enclaveRPCTimeout) +func (c *Client) EnclavePublicConfig(ctx context.Context) (*common.EnclavePublicConfig, common.SystemError) { + timeoutCtx, cancel := context.WithTimeout(ctx, c.enclaveRPCTimeout) defer cancel() response, err := c.protoClient.EnclavePublicConfig(timeoutCtx, &generated.EnclavePublicConfigRequest{}) diff --git a/integration/common/constants.go b/integration/common/constants.go index 9c802797c0..e50ec3daf8 100644 --- a/integration/common/constants.go +++ b/integration/common/constants.go @@ -2,6 +2,7 @@ package common import ( "math/big" + "time" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/params" @@ -90,5 +91,6 @@ func DefaultEnclaveConfig() *config.EnclaveConfig { // whilst the usage is small. Should be ok since execution is paid for anyway. GasLocalExecutionCapFlag: 300_000_000_000, GasBatchExecutionLimit: 300_000_000_000, + RPCTimeout: 5 * time.Second, } } diff --git a/integration/ethereummock/db.go b/integration/ethereummock/db.go index 6f806ab2e4..517a4efb8b 100644 --- a/integration/ethereummock/db.go +++ b/integration/ethereummock/db.go @@ -2,6 +2,7 @@ package ethereummock import ( "bytes" + "context" "math/big" "sync" @@ -23,11 +24,11 @@ type blockResolverInMem struct { m sync.RWMutex } -func (n *blockResolverInMem) FetchCanonicaBlockByHeight(_ *big.Int) (*types.Block, error) { +func (n *blockResolverInMem) FetchCanonicaBlockByHeight(_ context.Context, _ *big.Int) (*types.Block, error) { panic("implement me") } -func (n *blockResolverInMem) Proof(_ *core.Rollup) (*types.Block, error) { +func (n *blockResolverInMem) Proof(_ context.Context, _ *core.Rollup) (*types.Block, error) { panic("implement me") } @@ -38,14 +39,14 @@ func NewResolver() storage.BlockResolver { } } -func (n *blockResolverInMem) StoreBlock(block *types.Block, _ *common.ChainFork) error { +func (n *blockResolverInMem) StoreBlock(_ context.Context, block *types.Block, _ *common.ChainFork) error { n.m.Lock() defer n.m.Unlock() n.blockCache[block.Hash()] = block return nil } -func (n *blockResolverInMem) FetchBlock(hash common.L1BlockHash) (*types.Block, error) { +func (n *blockResolverInMem) FetchBlock(_ context.Context, hash common.L1BlockHash) (*types.Block, error) { n.m.RLock() defer n.m.RUnlock() block, f := n.blockCache[hash] @@ -56,7 +57,7 @@ func (n *blockResolverInMem) FetchBlock(hash common.L1BlockHash) (*types.Block, return block, nil } -func (n *blockResolverInMem) FetchHeadBlock() (*types.Block, error) { +func (n *blockResolverInMem) FetchHeadBlock(_ context.Context) (*types.Block, error) { n.m.RLock() defer n.m.RUnlock() var max *types.Block @@ -72,11 +73,11 @@ func (n *blockResolverInMem) FetchHeadBlock() (*types.Block, error) { return max, nil } -func (n *blockResolverInMem) ParentBlock(b *types.Block) (*types.Block, error) { - return n.FetchBlock(b.Header().ParentHash) +func (n *blockResolverInMem) ParentBlock(ctx context.Context, b *types.Block) (*types.Block, error) { + return n.FetchBlock(ctx, b.Header().ParentHash) } -func (n *blockResolverInMem) IsAncestor(block *types.Block, maybeAncestor *types.Block) bool { +func (n *blockResolverInMem) IsAncestor(ctx context.Context, block *types.Block, maybeAncestor *types.Block) bool { if bytes.Equal(maybeAncestor.Hash().Bytes(), block.Hash().Bytes()) { return true } @@ -85,15 +86,15 @@ func (n *blockResolverInMem) IsAncestor(block *types.Block, maybeAncestor *types return false } - p, err := n.ParentBlock(block) + p, err := n.ParentBlock(ctx, block) if err != nil { return false } - return n.IsAncestor(p, maybeAncestor) + return n.IsAncestor(ctx, p, maybeAncestor) } -func (n *blockResolverInMem) IsBlockAncestor(block *types.Block, maybeAncestor common.L1BlockHash) bool { +func (n *blockResolverInMem) IsBlockAncestor(ctx context.Context, block *types.Block, maybeAncestor common.L1BlockHash) bool { if bytes.Equal(maybeAncestor.Bytes(), block.Hash().Bytes()) { return true } @@ -106,20 +107,20 @@ func (n *blockResolverInMem) IsBlockAncestor(block *types.Block, maybeAncestor c return false } - resolvedBlock, err := n.FetchBlock(maybeAncestor) + resolvedBlock, err := n.FetchBlock(ctx, maybeAncestor) if err == nil { if resolvedBlock.NumberU64() >= block.NumberU64() { return false } } - p, err := n.ParentBlock(block) + p, err := n.ParentBlock(ctx, block) if err != nil { // todo (@tudor) - if error is not `errutil.ErrNotFound`, throw return false } - return n.IsBlockAncestor(p, maybeAncestor) + return n.IsBlockAncestor(ctx, p, maybeAncestor) } // The cache of included transactions @@ -152,6 +153,7 @@ func (n *txDBInMem) AddTxs(b *types.Block, newMap map[common.TxHash]*types.Trans // removeCommittedTransactions returns a copy of `mempool` where all transactions that are exactly `committedBlocks` // deep have been removed. func (m *Node) removeCommittedTransactions( + ctx context.Context, cb *types.Block, mempool []*types.Transaction, resolver storage.BlockResolver, @@ -169,7 +171,7 @@ func (m *Node) removeCommittedTransactions( break } - p, err := resolver.FetchBlock(b.ParentHash()) + p, err := resolver.FetchBlock(ctx, b.ParentHash()) if err != nil { m.logger.Crit("Could not retrieve parent block.", log.ErrKey, err) } diff --git a/integration/ethereummock/mock_l1_network.go b/integration/ethereummock/mock_l1_network.go index f6dcd7613e..eac6ade118 100644 --- a/integration/ethereummock/mock_l1_network.go +++ b/integration/ethereummock/mock_l1_network.go @@ -1,6 +1,7 @@ package ethereummock import ( + "context" "fmt" "time" @@ -106,7 +107,7 @@ func printBlock(b *types.Block, m *Node) string { txs = append(txs, fmt.Sprintf("deposit(%d=%d)", to, l1Tx.Amount)) } } - p, err := m.Resolver.FetchBlock(b.ParentHash()) + p, err := m.Resolver.FetchBlock(context.Background(), b.ParentHash()) if err != nil { testlog.Logger().Crit("Should not happen. Could not retrieve parent", log.ErrKey, err) } diff --git a/integration/ethereummock/node.go b/integration/ethereummock/node.go index 1531adffc3..cb41bac416 100644 --- a/integration/ethereummock/node.go +++ b/integration/ethereummock/node.go @@ -167,7 +167,7 @@ func (m *Node) BlockListener() (chan *types.Header, ethereum.Subscription) { } func (m *Node) BlockNumber() (uint64, error) { - blk, err := m.Resolver.FetchHeadBlock() + blk, err := m.Resolver.FetchHeadBlock(context.Background()) if err != nil { if errors.Is(err, errutil.ErrNotFound) { return 0, ethereum.NotFound @@ -182,7 +182,7 @@ func (m *Node) BlockByNumber(n *big.Int) (*types.Block, error) { return MockGenesisBlock, nil } // TODO this should be a method in the resolver - blk, err := m.Resolver.FetchHeadBlock() + blk, err := m.Resolver.FetchHeadBlock(context.Background()) if err != nil { if errors.Is(err, errutil.ErrNotFound) { return nil, ethereum.NotFound @@ -194,7 +194,7 @@ func (m *Node) BlockByNumber(n *big.Int) (*types.Block, error) { return blk, nil } - blk, err = m.Resolver.FetchBlock(blk.ParentHash()) + blk, err = m.Resolver.FetchBlock(context.Background(), blk.ParentHash()) if err != nil { return nil, fmt.Errorf("could not retrieve parent for block in chain. Cause: %w", err) } @@ -203,7 +203,7 @@ func (m *Node) BlockByNumber(n *big.Int) (*types.Block, error) { } func (m *Node) BlockByHash(id gethcommon.Hash) (*types.Block, error) { - blk, err := m.Resolver.FetchBlock(id) + blk, err := m.Resolver.FetchBlock(context.Background(), id) if err != nil { return nil, fmt.Errorf("block could not be retrieved. Cause: %w", err) } @@ -211,7 +211,7 @@ func (m *Node) BlockByHash(id gethcommon.Hash) (*types.Block, error) { } func (m *Node) FetchHeadBlock() (*types.Block, error) { - block, err := m.Resolver.FetchHeadBlock() + block, err := m.Resolver.FetchHeadBlock(context.Background()) if err != nil { return nil, fmt.Errorf("could not retrieve head block. Cause: %w", err) } @@ -225,7 +225,7 @@ func (m *Node) Info() ethadapter.Info { } func (m *Node) IsBlockAncestor(block *types.Block, proof common.L1BlockHash) bool { - return m.Resolver.IsBlockAncestor(block, proof) + return m.Resolver.IsBlockAncestor(context.Background(), block, proof) } func (m *Node) BalanceAt(gethcommon.Address, *big.Int) (*big.Int, error) { @@ -260,7 +260,7 @@ func (m *Node) Start() { go m.startMining() } - err := m.Resolver.StoreBlock(MockGenesisBlock, nil) + err := m.Resolver.StoreBlock(context.Background(), MockGenesisBlock, nil) if err != nil { m.logger.Crit("Failed to store block") } @@ -269,7 +269,7 @@ func (m *Node) Start() { for { select { case p2pb := <-m.p2pCh: // Received from peers - _, err := m.Resolver.FetchBlock(p2pb.Hash()) + _, err := m.Resolver.FetchBlock(context.Background(), p2pb.Hash()) // only process blocks if they haven't been processed before if err != nil { if errors.Is(err, errutil.ErrNotFound) { @@ -282,7 +282,7 @@ func (m *Node) Start() { case mb := <-m.miningCh: // Received from the local mining head = m.processBlock(mb, head) if bytes.Equal(head.Hash().Bytes(), mb.Hash().Bytes()) { // Ignore the locally produced block if someone else found one already - p, err := m.Resolver.FetchBlock(mb.ParentHash()) + p, err := m.Resolver.FetchBlock(context.Background(), mb.ParentHash()) if err != nil { panic(fmt.Errorf("could not retrieve parent. Cause: %w", err)) } @@ -305,11 +305,11 @@ func (m *Node) Start() { } func (m *Node) processBlock(b *types.Block, head *types.Block) *types.Block { - err := m.Resolver.StoreBlock(b, nil) + err := m.Resolver.StoreBlock(context.Background(), b, nil) if err != nil { m.logger.Crit("Failed to store block. Cause: %w", err) } - _, err = m.Resolver.FetchBlock(b.Header().ParentHash) + _, err = m.Resolver.FetchBlock(context.Background(), b.Header().ParentHash) // only proceed if the parent is available if err != nil { if errors.Is(err, errutil.ErrNotFound) { @@ -325,9 +325,9 @@ func (m *Node) processBlock(b *types.Block, head *types.Block) *types.Block { } // Check for Reorgs - if !m.Resolver.IsAncestor(b, head) { + if !m.Resolver.IsAncestor(context.Background(), b, head) { m.stats.L1Reorg(m.l2ID) - fork, err := gethutil.LCA(head, b, m.Resolver) + fork, err := gethutil.LCA(context.Background(), head, b, m.Resolver) if err != nil { panic(err) } @@ -418,7 +418,7 @@ func (m *Node) startMining() { // A new canonical block was found. Start a new round based on that block. // remove transactions that are already considered committed - mempool = m.removeCommittedTransactions(canonicalBlock, mempool, m.Resolver, m.db) + mempool = m.removeCommittedTransactions(context.Background(), canonicalBlock, mempool, m.Resolver, m.db) // notify the existing mining go routine to stop mining atomic.StoreInt32(interrupt, 1) @@ -474,7 +474,7 @@ func (m *Node) BlocksBetween(blockA *types.Block, blockB *types.Block) []*types. if bytes.Equal(tempBlock.Hash().Bytes(), blockA.Hash().Bytes()) { break } - tempBlock, err = m.Resolver.FetchBlock(tempBlock.ParentHash()) + tempBlock, err = m.Resolver.FetchBlock(context.Background(), tempBlock.ParentHash()) if err != nil { panic(fmt.Errorf("could not retrieve parent block. Cause: %w", err)) } diff --git a/integration/ethereummock/utils.go b/integration/ethereummock/utils.go index 7e2ae1e02a..326f478155 100644 --- a/integration/ethereummock/utils.go +++ b/integration/ethereummock/utils.go @@ -1,6 +1,7 @@ package ethereummock import ( + "context" "fmt" "github.com/ten-protocol/go-ten/go/enclave/storage" @@ -26,7 +27,7 @@ func allIncludedTransactions(b *types.Block, r storage.BlockResolver, db TxDB) m return makeMap(b.Transactions()) } newMap := make(map[common.TxHash]*types.Transaction) - p, err := r.FetchBlock(b.ParentHash()) + p, err := r.FetchBlock(context.Background(), b.ParentHash()) if err != nil { panic(fmt.Errorf("should not happen. Could not retrieve parent. Cause: %w", err)) } diff --git a/integration/simulation/devnetwork/node.go b/integration/simulation/devnetwork/node.go index 3cffcfb53a..cb938e95c0 100644 --- a/integration/simulation/devnetwork/node.go +++ b/integration/simulation/devnetwork/node.go @@ -2,6 +2,7 @@ package devnetwork import ( "fmt" + "time" "github.com/ten-protocol/go-ten/lib/gethfork/node" @@ -215,6 +216,7 @@ func (n *InMemNodeOperator) createEnclaveContainer(idx int) *enclavecontainer.En GasBatchExecutionLimit: defaultCfg.GasBatchExecutionLimit, GasLocalExecutionCapFlag: defaultCfg.GasLocalExecutionCapFlag, GasPaymentAddress: defaultCfg.GasPaymentAddress, + RPCTimeout: 5 * time.Second, } return enclavecontainer.NewEnclaveContainerWithLogger(enclaveConfig, enclaveLogger) } diff --git a/integration/simulation/network/network_utils.go b/integration/simulation/network/network_utils.go index 9b6c3f1aa3..3564cff4f0 100644 --- a/integration/simulation/network/network_utils.go +++ b/integration/simulation/network/network_utils.go @@ -97,6 +97,7 @@ func createInMemObscuroNode( BaseFee: big.NewInt(1), // todo @siliev:: fix test transaction builders so this can be different GasLocalExecutionCapFlag: params.MaxGasLimit / 2, GasBatchExecutionLimit: params.MaxGasLimit / 2, + RPCTimeout: 5 * time.Second, } enclaveLogger := testlog.Logger().New(log.NodeIDKey, id, log.CmpKey, log.EnclaveCmp) diff --git a/integration/simulation/p2p/in_mem_obscuro_client.go b/integration/simulation/p2p/in_mem_obscuro_client.go index 6db8332e1c..bc94156a5b 100644 --- a/integration/simulation/p2p/in_mem_obscuro_client.go +++ b/integration/simulation/p2p/in_mem_obscuro_client.go @@ -381,7 +381,7 @@ func (c *inMemObscuroClient) getPublicTransactionData(result interface{}, args [ return fmt.Errorf("first arg to %s is of type %T, expected type int", rpc.GetPublicTransactionData, args[0]) } - txs, err := c.tenScanAPI.GetPublicTransactionData(pagination) + txs, err := c.tenScanAPI.GetPublicTransactionData(context.Background(), pagination) if err != nil { return fmt.Errorf("`%s` call failed. Cause: %w", rpc.GetPublicTransactionData, err) } diff --git a/integration/simulation/p2p/mock_l2_network.go b/integration/simulation/p2p/mock_l2_network.go index cc562e3c1b..7dbba0670c 100644 --- a/integration/simulation/p2p/mock_l2_network.go +++ b/integration/simulation/p2p/mock_l2_network.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "math/big" "strconv" "sync/atomic" @@ -119,7 +120,7 @@ func (n *MockP2P) Stop() error { return nil } -func (n *MockP2P) HealthStatus() host.HealthStatus { +func (n *MockP2P) HealthStatus(context.Context) host.HealthStatus { return &host.BasicErrHealthStatus{ErrMsg: ""} } diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index eeb2fb9408..357e38ca18 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "time" @@ -31,6 +32,10 @@ const ( longCacheTTL = 5 * time.Hour shortCacheTTL = 1 * time.Minute + + // hardcoding the maximum time for an RPC request + // this value will be propagated to the node and enclave and all the operations + maximumRPCCallDuration = 5 * time.Second ) var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") @@ -58,6 +63,9 @@ type CacheCfg struct { } func UnauthenticatedTenRPCCall[R any](ctx context.Context, w *Services, cfg *CacheCfg, method string, args ...any) (*R, error) { + if ctx == nil { + return nil, errors.New("invalid call. nil Context") + } audit(w, "RPC start method=%s args=%v", method, args) requestStartTime := time.Now() cacheArgs := []any{method} @@ -67,11 +75,12 @@ func UnauthenticatedTenRPCCall[R any](ctx context.Context, w *Services, cfg *Cac return withPlainRPCConnection(w, func(client *rpc.Client) (*R, error) { var resp *R var err error - if ctx == nil { - err = client.Call(&resp, method, args...) - } else { - err = client.CallContext(ctx, &resp, method, args...) - } + + // wrap the context with a timeout to prevent long executions + timeoutContext, cancelCtx := context.WithTimeout(ctx, maximumRPCCallDuration) + defer cancelCtx() + + err = client.CallContext(timeoutContext, &resp, method, args...) return resp, err }) }) @@ -114,7 +123,12 @@ func ExecAuthRPC[R any](ctx context.Context, w *Services, cfg *ExecCfg, method s if cfg.adjustArgs != nil { adjustedArgs = cfg.adjustArgs(acct) } - err := rpcClient.CallContext(ctx, &result, method, adjustedArgs...) + + // wrap the context with a timeout to prevent long executions + timeoutContext, cancelCtx := context.WithTimeout(ctx, maximumRPCCallDuration) + defer cancelCtx() + + err := rpcClient.CallContext(timeoutContext, &result, method, adjustedArgs...) return result, err }) if err != nil {