diff --git a/go/enclave/events/subscription_manager.go b/go/enclave/events/subscription_manager.go index 27e6515d41..5641d09b62 100644 --- a/go/enclave/events/subscription_manager.go +++ b/go/enclave/events/subscription_manager.go @@ -11,24 +11,15 @@ import ( "github.com/ten-protocol/go-ten/go/enclave/vkhandler" gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" - "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/go/enclave/core" "github.com/ten-protocol/go-ten/go/enclave/storage" - gethcommon "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/common" ) -const ( - // The leading zero bytes in a hash indicating that it is possibly an address, since it only has 20 bytes of data. - zeroBytesHex = "000000000000000000000000" -) - type logSubscription struct { Subscription *common.LogSubscription // Handles the viewing key encryption @@ -92,24 +83,6 @@ func (s *SubscriptionManager) RemoveSubscription(id gethrpc.ID) { delete(s.subscriptions, id) } -// FilterLogsForReceipt removes the logs that the sender of a transaction is not allowed to view -func FilterLogsForReceipt(ctx context.Context, receipt *types.Receipt, account *gethcommon.Address, registry components.BatchRegistry) ([]*types.Log, error) { - var filteredLogs []*types.Log - stateDB, err := registry.GetBatchState(ctx, &receipt.BlockHash) - if err != nil { - return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) - } - - for _, logItem := range receipt.Logs { - userAddrs := getUserAddrsFromLogTopics(logItem, stateDB) - if isRelevant(account, userAddrs) { - filteredLogs = append(filteredLogs, logItem) - } - } - - return filteredLogs, nil -} - // GetSubscribedLogsForBatch - Retrieves and encrypts the logs for the batch in live mode. // The assumption is that this function is called synchronously after the batch is produced func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, batch *core.Batch, receipts types.Receipts) (common.EncryptedSubscriptionLogs, error) { @@ -121,47 +94,17 @@ func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, bat return nil, nil } + h := batch.Hash() relevantLogsPerSubscription := map[gethrpc.ID][]*types.Log{} - // extract the logs from all receipts - var allLogs []*types.Log - for _, receipt := range receipts { - allLogs = append(allLogs, receipt.Logs...) - } - - if len(allLogs) == 0 { + if len(receipts) == 0 { return nil, nil } - // the stateDb is needed to extract the user addresses from the topics - h := batch.Hash() - stateDB, err := s.registry.GetBatchState(ctx, &h) - if err != nil { - return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) - } - - // cache for the user addresses extracted from the individual logs - // this is an expensive operation so we are doing it lazy, and caching the result - userAddrsForLog := map[*types.Log][]*gethcommon.Address{} - for id, sub := range s.subscriptions { - // first filter the logs - filteredLogs := filterLogs(allLogs, sub.Subscription.Filter.FromBlock, sub.Subscription.Filter.ToBlock, sub.Subscription.Filter.Addresses, sub.Subscription.Filter.Topics, s.logger) - - // the account requesting the logs is retrieved from the Viewing Key - requestingAccount := sub.ViewingKeyEncryptor.AccountAddress - relevantLogsForSub := []*types.Log{} - for _, logItem := range filteredLogs { - userAddrs, f := userAddrsForLog[logItem] - if !f { - userAddrs = getUserAddrsFromLogTopics(logItem, stateDB) - userAddrsForLog[logItem] = userAddrs - } - relevant := isRelevant(requestingAccount, userAddrs) - if relevant { - relevantLogsForSub = append(relevantLogsForSub, logItem) - } - s.logger.Debug("Subscription", log.SubIDKey, id, "acc", requestingAccount, "log", logItem, "extr_addr", userAddrs, "relev", relevant) + relevantLogsForSub, err := s.storage.FilterLogs(ctx, sub.ViewingKeyEncryptor.AccountAddress, nil, nil, &h, sub.Subscription.Filter.Addresses, sub.Subscription.Filter.Topics) + if err != nil { + return nil, err } if len(relevantLogsForSub) > 0 { relevantLogsPerSubscription[id] = relevantLogsForSub @@ -172,19 +115,6 @@ func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, bat return s.encryptLogs(relevantLogsPerSubscription) } -func isRelevant(sub *gethcommon.Address, userAddrs []*gethcommon.Address) bool { - // If there are no user addresses, this is a lifecycle event, and is therefore relevant to everyone. - if len(userAddrs) == 0 { - return true - } - for _, addr := range userAddrs { - if *addr == *sub { - return true - } - } - return false -} - // Encrypts each log with the appropriate viewing key. func (s *SubscriptionManager) encryptLogs(logsByID map[gethrpc.ID][]*types.Log) (map[gethrpc.ID][]byte, error) { encryptedLogsByID := map[gethrpc.ID][]byte{} @@ -210,85 +140,3 @@ func (s *SubscriptionManager) encryptLogs(logsByID map[gethrpc.ID][]*types.Log) return encryptedLogsByID, nil } - -// Of the log's topics, returns those that are (potentially) user addresses. A topic is considered a user address if: -// - It has 12 leading zero bytes (since addresses are 20 bytes long, while hashes are 32) -// - It has a non-zero nonce (to prevent accidental or malicious creation of the address matching a given topic, -// forcing its events to become permanently private -// - It does not have associated code (meaning it's a smart-contract address) -func getUserAddrsFromLogTopics(log *types.Log, db *state.StateDB) []*gethcommon.Address { - var userAddrs []*gethcommon.Address - - // We skip over the first topic, which is always the hash of the event. - for _, topic := range log.Topics[1:len(log.Topics)] { - if topic.Hex()[2:len(zeroBytesHex)+2] != zeroBytesHex { - continue - } - - potentialAddr := gethcommon.BytesToAddress(topic.Bytes()) - - // A user address must have a non-zero nonce. This prevents accidental or malicious sending of funds to an - // address matching a topic, forcing its events to become permanently private. - if db.GetNonce(potentialAddr) != 0 { - // If the address has code, it's a smart contract address instead. - if db.GetCode(potentialAddr) == nil { - userAddrs = append(userAddrs, &potentialAddr) - } - } - } - - return userAddrs -} - -// Lifted from eth/filters/filter.go in the go-ethereum repository. -// filterLogs creates a slice of logs matching the given criteria. -func filterLogs(logs []*types.Log, fromBlock, toBlock *gethrpc.BlockNumber, addresses []gethcommon.Address, topics [][]gethcommon.Hash, logger gethlog.Logger) []*types.Log { //nolint:gocognit - var ret []*types.Log -Logs: - for _, logItem := range logs { - if fromBlock != nil && fromBlock.Int64() >= 0 && fromBlock.Int64() > int64(logItem.BlockNumber) { - logger.Debug("Skipping log ", "log", logItem, "reason", "In the past. The starting block num for filter is bigger than log") - continue - } - if toBlock != nil && toBlock.Int64() > 0 && toBlock.Int64() < int64(logItem.BlockNumber) { - logger.Debug("Skipping log ", "log", logItem, "reason", "In the future. The ending block num for filter is smaller than log") - continue - } - - if len(addresses) > 0 && !includes(addresses, logItem.Address) { - logger.Debug("Skipping log ", "log", logItem, "reason", "The contract address of the log is not an address of interest") - continue - } - // If the to filtered topics is greater than the amount of topics in logs, skip. - if len(topics) > len(logItem.Topics) { - logger.Debug("Skipping log ", "log", logItem, "reason", "Insufficient topics. The log has less topics than the required one to satisfy the query") - continue - } - for i, sub := range topics { - match := len(sub) == 0 // empty rule set == wildcard - for _, topic := range sub { - if logItem.Topics[i] == topic { - match = true - break - } - } - if !match { - logger.Debug("Skipping log ", "log", logItem, "reason", "Topics do not match.") - continue Logs - } - } - ret = append(ret, logItem) - } - return ret -} - -// Lifted from eth/filters/filter.go in the go-ethereum repository. -func includes(addresses []gethcommon.Address, a gethcommon.Address) bool { - for _, addr := range addresses { - if addr == a { - return true - } - } - - return false -} diff --git a/go/enclave/rpc/GetTransactionReceipt.go b/go/enclave/rpc/GetTransactionReceipt.go index 88c55f967d..6eb4d2ae79 100644 --- a/go/enclave/rpc/GetTransactionReceipt.go +++ b/go/enclave/rpc/GetTransactionReceipt.go @@ -14,7 +14,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/go/enclave/events" ) func GetTransactionReceiptValidate(reqParams []any, builder *CallBuilder[gethcommon.Hash, map[string]interface{}], _ *EncryptionManager) error { @@ -73,8 +72,8 @@ func GetTransactionReceiptExecute(builder *CallBuilder[gethcommon.Hash, map[stri return fmt.Errorf("could not retrieve transaction receipt in eth_getTransactionReceipt request. Cause: %w", err) } - // We filter out irrelevant logs. - txReceipt.Logs, err = events.FilterLogsForReceipt(builder.ctx, txReceipt, &txSigner, rpc.registry) + // We only keep the logs that the requester is allowed to see + txReceipt.Logs, err = rpc.storage.FilterLogsForReceipt(builder.ctx, &txSigner, txReceipt.TxHash) if err != nil { rpc.logger.Error("error filter logs ", log.TxKey, txHash, log.ErrKey, err) // this is a system error diff --git a/go/enclave/storage/enclavedb/events.go b/go/enclave/storage/enclavedb/events.go index 65b36d002e..518b921ba9 100644 --- a/go/enclave/storage/enclavedb/events.go +++ b/go/enclave/storage/enclavedb/events.go @@ -106,17 +106,15 @@ func WriteEventLog(ctx context.Context, dbTX *sql.Tx, eventTypeId uint64, userTo return err } -func FilterLogs( - ctx context.Context, - db *sql.DB, - requestingAccount *gethcommon.Address, - fromBlock, toBlock *big.Int, - batchHash *common.L2BatchHash, - addresses []gethcommon.Address, - topics [][]gethcommon.Hash, -) ([]*types.Log, error) { +func FilterLogs(ctx context.Context, db *sql.DB, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, batchHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash, txHash *gethcommon.Hash) ([]*types.Log, error) { queryParams := []any{} query := "" + + if txHash != nil { + query += " AND tx.hash = ? " + queryParams = append(queryParams, txHash.Bytes()) + } + if batchHash != nil { query += " AND b.hash = ? " queryParams = append(queryParams, batchHash.Bytes()) diff --git a/go/enclave/storage/interfaces.go b/go/enclave/storage/interfaces.go index 5bf5107be8..c029eef4a4 100644 --- a/go/enclave/storage/interfaces.go +++ b/go/enclave/storage/interfaces.go @@ -143,6 +143,8 @@ type Storage interface { // the blockHash should always be nil. FilterLogs(ctx context.Context, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, blockHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash) ([]*types.Log, error) + FilterLogsForReceipt(ctx context.Context, requestingAccount *gethcommon.Address, TxHash gethcommon.Hash) ([]*types.Log, error) + // DebugGetLogs returns logs for a given tx hash without any constraints - should only be used for debug purposes DebugGetLogs(ctx context.Context, txHash common.TxHash) ([]*tracers.DebugLogs, error) diff --git a/go/enclave/storage/storage.go b/go/enclave/storage/storage.go index 6d372f3f0f..74b264ef05 100644 --- a/go/enclave/storage/storage.go +++ b/go/enclave/storage/storage.go @@ -736,6 +736,23 @@ func (s *storageImpl) DebugGetLogs(ctx context.Context, txHash common.TxHash) ([ return enclavedb.DebugGetLogs(ctx, s.db.GetSQLDB(), txHash) } +func (s *storageImpl) FilterLogsForReceipt(ctx context.Context, requestingAccount *gethcommon.Address, txHash gethcommon.Hash) ([]*types.Log, error) { + defer s.logDuration("FilterLogs", measure.NewStopwatch()) + logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, nil, nil, nil, nil, nil, &txHash) + if err != nil { + return nil, err + } + // the database returns an unsorted list of event logs. + // we have to perform the sorting programatically + sort.Slice(logs, func(i, j int) bool { + if logs[i].BlockNumber == logs[j].BlockNumber { + return logs[i].Index < logs[j].Index + } + return logs[i].BlockNumber < logs[j].BlockNumber + }) + return logs, nil +} + func (s *storageImpl) FilterLogs( ctx context.Context, requestingAccount *gethcommon.Address, @@ -745,7 +762,7 @@ func (s *storageImpl) FilterLogs( topics [][]gethcommon.Hash, ) ([]*types.Log, error) { defer s.logDuration("FilterLogs", measure.NewStopwatch()) - logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics) + logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics, nil) if err != nil { return nil, err }