diff --git a/go/enclave/events/subscription_manager.go b/go/enclave/events/subscription_manager.go index 5641d09b62..27e6515d41 100644 --- a/go/enclave/events/subscription_manager.go +++ b/go/enclave/events/subscription_manager.go @@ -11,15 +11,24 @@ import ( "github.com/ten-protocol/go-ten/go/enclave/vkhandler" gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/enclave/core" "github.com/ten-protocol/go-ten/go/enclave/storage" + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/common" ) +const ( + // The leading zero bytes in a hash indicating that it is possibly an address, since it only has 20 bytes of data. + zeroBytesHex = "000000000000000000000000" +) + type logSubscription struct { Subscription *common.LogSubscription // Handles the viewing key encryption @@ -83,6 +92,24 @@ func (s *SubscriptionManager) RemoveSubscription(id gethrpc.ID) { delete(s.subscriptions, id) } +// FilterLogsForReceipt removes the logs that the sender of a transaction is not allowed to view +func FilterLogsForReceipt(ctx context.Context, receipt *types.Receipt, account *gethcommon.Address, registry components.BatchRegistry) ([]*types.Log, error) { + var filteredLogs []*types.Log + stateDB, err := registry.GetBatchState(ctx, &receipt.BlockHash) + if err != nil { + return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) + } + + for _, logItem := range receipt.Logs { + userAddrs := getUserAddrsFromLogTopics(logItem, stateDB) + if isRelevant(account, userAddrs) { + filteredLogs = append(filteredLogs, logItem) + } + } + + return filteredLogs, nil +} + // GetSubscribedLogsForBatch - Retrieves and encrypts the logs for the batch in live mode. // The assumption is that this function is called synchronously after the batch is produced func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, batch *core.Batch, receipts types.Receipts) (common.EncryptedSubscriptionLogs, error) { @@ -94,17 +121,47 @@ func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, bat return nil, nil } - h := batch.Hash() relevantLogsPerSubscription := map[gethrpc.ID][]*types.Log{} - if len(receipts) == 0 { + // extract the logs from all receipts + var allLogs []*types.Log + for _, receipt := range receipts { + allLogs = append(allLogs, receipt.Logs...) + } + + if len(allLogs) == 0 { return nil, nil } + // the stateDb is needed to extract the user addresses from the topics + h := batch.Hash() + stateDB, err := s.registry.GetBatchState(ctx, &h) + if err != nil { + return nil, fmt.Errorf("could not create state DB to filter logs. Cause: %w", err) + } + + // cache for the user addresses extracted from the individual logs + // this is an expensive operation so we are doing it lazy, and caching the result + userAddrsForLog := map[*types.Log][]*gethcommon.Address{} + for id, sub := range s.subscriptions { - relevantLogsForSub, err := s.storage.FilterLogs(ctx, sub.ViewingKeyEncryptor.AccountAddress, nil, nil, &h, sub.Subscription.Filter.Addresses, sub.Subscription.Filter.Topics) - if err != nil { - return nil, err + // first filter the logs + filteredLogs := filterLogs(allLogs, sub.Subscription.Filter.FromBlock, sub.Subscription.Filter.ToBlock, sub.Subscription.Filter.Addresses, sub.Subscription.Filter.Topics, s.logger) + + // the account requesting the logs is retrieved from the Viewing Key + requestingAccount := sub.ViewingKeyEncryptor.AccountAddress + relevantLogsForSub := []*types.Log{} + for _, logItem := range filteredLogs { + userAddrs, f := userAddrsForLog[logItem] + if !f { + userAddrs = getUserAddrsFromLogTopics(logItem, stateDB) + userAddrsForLog[logItem] = userAddrs + } + relevant := isRelevant(requestingAccount, userAddrs) + if relevant { + relevantLogsForSub = append(relevantLogsForSub, logItem) + } + s.logger.Debug("Subscription", log.SubIDKey, id, "acc", requestingAccount, "log", logItem, "extr_addr", userAddrs, "relev", relevant) } if len(relevantLogsForSub) > 0 { relevantLogsPerSubscription[id] = relevantLogsForSub @@ -115,6 +172,19 @@ func (s *SubscriptionManager) GetSubscribedLogsForBatch(ctx context.Context, bat return s.encryptLogs(relevantLogsPerSubscription) } +func isRelevant(sub *gethcommon.Address, userAddrs []*gethcommon.Address) bool { + // If there are no user addresses, this is a lifecycle event, and is therefore relevant to everyone. + if len(userAddrs) == 0 { + return true + } + for _, addr := range userAddrs { + if *addr == *sub { + return true + } + } + return false +} + // Encrypts each log with the appropriate viewing key. func (s *SubscriptionManager) encryptLogs(logsByID map[gethrpc.ID][]*types.Log) (map[gethrpc.ID][]byte, error) { encryptedLogsByID := map[gethrpc.ID][]byte{} @@ -140,3 +210,85 @@ func (s *SubscriptionManager) encryptLogs(logsByID map[gethrpc.ID][]*types.Log) return encryptedLogsByID, nil } + +// Of the log's topics, returns those that are (potentially) user addresses. A topic is considered a user address if: +// - It has 12 leading zero bytes (since addresses are 20 bytes long, while hashes are 32) +// - It has a non-zero nonce (to prevent accidental or malicious creation of the address matching a given topic, +// forcing its events to become permanently private +// - It does not have associated code (meaning it's a smart-contract address) +func getUserAddrsFromLogTopics(log *types.Log, db *state.StateDB) []*gethcommon.Address { + var userAddrs []*gethcommon.Address + + // We skip over the first topic, which is always the hash of the event. + for _, topic := range log.Topics[1:len(log.Topics)] { + if topic.Hex()[2:len(zeroBytesHex)+2] != zeroBytesHex { + continue + } + + potentialAddr := gethcommon.BytesToAddress(topic.Bytes()) + + // A user address must have a non-zero nonce. This prevents accidental or malicious sending of funds to an + // address matching a topic, forcing its events to become permanently private. + if db.GetNonce(potentialAddr) != 0 { + // If the address has code, it's a smart contract address instead. + if db.GetCode(potentialAddr) == nil { + userAddrs = append(userAddrs, &potentialAddr) + } + } + } + + return userAddrs +} + +// Lifted from eth/filters/filter.go in the go-ethereum repository. +// filterLogs creates a slice of logs matching the given criteria. +func filterLogs(logs []*types.Log, fromBlock, toBlock *gethrpc.BlockNumber, addresses []gethcommon.Address, topics [][]gethcommon.Hash, logger gethlog.Logger) []*types.Log { //nolint:gocognit + var ret []*types.Log +Logs: + for _, logItem := range logs { + if fromBlock != nil && fromBlock.Int64() >= 0 && fromBlock.Int64() > int64(logItem.BlockNumber) { + logger.Debug("Skipping log ", "log", logItem, "reason", "In the past. The starting block num for filter is bigger than log") + continue + } + if toBlock != nil && toBlock.Int64() > 0 && toBlock.Int64() < int64(logItem.BlockNumber) { + logger.Debug("Skipping log ", "log", logItem, "reason", "In the future. The ending block num for filter is smaller than log") + continue + } + + if len(addresses) > 0 && !includes(addresses, logItem.Address) { + logger.Debug("Skipping log ", "log", logItem, "reason", "The contract address of the log is not an address of interest") + continue + } + // If the to filtered topics is greater than the amount of topics in logs, skip. + if len(topics) > len(logItem.Topics) { + logger.Debug("Skipping log ", "log", logItem, "reason", "Insufficient topics. The log has less topics than the required one to satisfy the query") + continue + } + for i, sub := range topics { + match := len(sub) == 0 // empty rule set == wildcard + for _, topic := range sub { + if logItem.Topics[i] == topic { + match = true + break + } + } + if !match { + logger.Debug("Skipping log ", "log", logItem, "reason", "Topics do not match.") + continue Logs + } + } + ret = append(ret, logItem) + } + return ret +} + +// Lifted from eth/filters/filter.go in the go-ethereum repository. +func includes(addresses []gethcommon.Address, a gethcommon.Address) bool { + for _, addr := range addresses { + if addr == a { + return true + } + } + + return false +} diff --git a/go/enclave/nodetype/common.go b/go/enclave/nodetype/common.go index c414f63c81..bff4a94447 100644 --- a/go/enclave/nodetype/common.go +++ b/go/enclave/nodetype/common.go @@ -20,7 +20,7 @@ func ExportCrossChainData(ctx context.Context, storage storage.Storage, fromSeqN return nil, errutil.ErrCrossChainBundleNoBatches } - // todo - siliev - all those fetches need to be atomic + //todo - siliev - all those fetches need to be atomic header, err := storage.FetchHeadBatchHeader(ctx) if err != nil { return nil, err diff --git a/go/enclave/rpc/GetTransactionReceipt.go b/go/enclave/rpc/GetTransactionReceipt.go index 6eb4d2ae79..88c55f967d 100644 --- a/go/enclave/rpc/GetTransactionReceipt.go +++ b/go/enclave/rpc/GetTransactionReceipt.go @@ -14,6 +14,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/enclave/events" ) func GetTransactionReceiptValidate(reqParams []any, builder *CallBuilder[gethcommon.Hash, map[string]interface{}], _ *EncryptionManager) error { @@ -72,8 +73,8 @@ func GetTransactionReceiptExecute(builder *CallBuilder[gethcommon.Hash, map[stri return fmt.Errorf("could not retrieve transaction receipt in eth_getTransactionReceipt request. Cause: %w", err) } - // We only keep the logs that the requester is allowed to see - txReceipt.Logs, err = rpc.storage.FilterLogsForReceipt(builder.ctx, &txSigner, txReceipt.TxHash) + // We filter out irrelevant logs. + txReceipt.Logs, err = events.FilterLogsForReceipt(builder.ctx, txReceipt, &txSigner, rpc.registry) if err != nil { rpc.logger.Error("error filter logs ", log.TxKey, txHash, log.ErrKey, err) // this is a system error diff --git a/go/enclave/storage/enclavedb/events.go b/go/enclave/storage/enclavedb/events.go index 3bb0a8519f..f5a505bbc8 100644 --- a/go/enclave/storage/enclavedb/events.go +++ b/go/enclave/storage/enclavedb/events.go @@ -89,15 +89,17 @@ func WriteEventLog(ctx context.Context, dbTX *sql.Tx, eventTypeId uint64, userTo return err } -func FilterLogs(ctx context.Context, db *sql.DB, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, batchHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash, txHash *gethcommon.Hash) ([]*types.Log, error) { +func FilterLogs( + ctx context.Context, + db *sql.DB, + requestingAccount *gethcommon.Address, + fromBlock, toBlock *big.Int, + batchHash *common.L2BatchHash, + addresses []gethcommon.Address, + topics [][]gethcommon.Hash, +) ([]*types.Log, error) { queryParams := []any{} query := "" - - if txHash != nil { - query += " AND tx.hash = ? " - queryParams = append(queryParams, txHash.Bytes()) - } - if batchHash != nil { query += " AND b.hash = ? " queryParams = append(queryParams, batchHash.Bytes()) diff --git a/go/enclave/storage/interfaces.go b/go/enclave/storage/interfaces.go index ac650b3107..fb7ac7b73a 100644 --- a/go/enclave/storage/interfaces.go +++ b/go/enclave/storage/interfaces.go @@ -141,8 +141,6 @@ type Storage interface { // the blockHash should always be nil. FilterLogs(ctx context.Context, requestingAccount *gethcommon.Address, fromBlock, toBlock *big.Int, blockHash *common.L2BatchHash, addresses []gethcommon.Address, topics [][]gethcommon.Hash) ([]*types.Log, error) - FilterLogsForReceipt(ctx context.Context, requestingAccount *gethcommon.Address, TxHash gethcommon.Hash) ([]*types.Log, error) - // DebugGetLogs returns logs for a given tx hash without any constraints - should only be used for debug purposes DebugGetLogs(ctx context.Context, txHash common.TxHash) ([]*tracers.DebugLogs, error) diff --git a/go/enclave/storage/storage.go b/go/enclave/storage/storage.go index 947d982fb0..244ffd7c5d 100644 --- a/go/enclave/storage/storage.go +++ b/go/enclave/storage/storage.go @@ -741,23 +741,6 @@ func (s *storageImpl) DebugGetLogs(ctx context.Context, txHash common.TxHash) ([ return enclavedb.DebugGetLogs(ctx, s.db.GetSQLDB(), txHash) } -func (s *storageImpl) FilterLogsForReceipt(ctx context.Context, requestingAccount *gethcommon.Address, txHash gethcommon.Hash) ([]*types.Log, error) { - defer s.logDuration("FilterLogs", measure.NewStopwatch()) - logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, nil, nil, nil, nil, nil, &txHash) - if err != nil { - return nil, err - } - // the database returns an unsorted list of event logs. - // we have to perform the sorting programatically - sort.Slice(logs, func(i, j int) bool { - if logs[i].BlockNumber == logs[j].BlockNumber { - return logs[i].Index < logs[j].Index - } - return logs[i].BlockNumber < logs[j].BlockNumber - }) - return logs, nil -} - func (s *storageImpl) FilterLogs( ctx context.Context, requestingAccount *gethcommon.Address, @@ -767,7 +750,7 @@ func (s *storageImpl) FilterLogs( topics [][]gethcommon.Hash, ) ([]*types.Log, error) { defer s.logDuration("FilterLogs", measure.NewStopwatch()) - logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics, nil) + logs, err := enclavedb.FilterLogs(ctx, s.db.GetSQLDB(), requestingAccount, fromBlock, toBlock, blockHash, addresses, topics) if err != nil { return nil, err }