diff --git a/tools/walletextension/cache/RistrettoCache.go b/tools/walletextension/cache/RistrettoCache.go new file mode 100644 index 0000000000..af417115b7 --- /dev/null +++ b/tools/walletextension/cache/RistrettoCache.go @@ -0,0 +1,82 @@ +package cache + +import ( + "time" + + "github.com/ethereum/go-ethereum/log" + + "github.com/dgraph-io/ristretto" +) + +const ( + numCounters = 1e7 // number of keys to track frequency of (10M). + maxCost = 1 << 30 // maximum cost of cache (1GB). + bufferItems = 64 // number of keys per Get buffer. + defaultConst = 1 // default cost of cache. +) + +type RistrettoCache struct { + cache *ristretto.Cache + quit chan struct{} +} + +// NewRistrettoCache returns a new RistrettoCache. +func NewRistrettoCache(logger log.Logger) (*RistrettoCache, error) { + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: numCounters, + MaxCost: maxCost, + BufferItems: bufferItems, + Metrics: true, + }) + if err != nil { + return nil, err + } + + c := &RistrettoCache{ + cache: cache, + quit: make(chan struct{}), + } + + // Start the metrics logging + go c.startMetricsLogging(logger) + + return c, nil +} + +// Set adds the key and value to the cache. +func (c *RistrettoCache) Set(key string, value map[string]interface{}, ttl time.Duration) bool { + return c.cache.SetWithTTL(key, value, defaultConst, ttl) +} + +// Get returns the value for the given key if it exists. +func (c *RistrettoCache) Get(key string) (value map[string]interface{}, ok bool) { + item, found := c.cache.Get(key) + if !found { + return nil, false + } + + // Assuming the item is stored as a map[string]interface{}, otherwise you need to type assert to the correct type. + value, ok = item.(map[string]interface{}) + if !ok { + // The item isn't of type map[string]interface{} + return nil, false + } + + return value, true +} + +// startMetricsLogging starts logging cache metrics every hour. +func (c *RistrettoCache) startMetricsLogging(logger log.Logger) { + ticker := time.NewTicker(1 * time.Hour) + for { + select { + case <-ticker.C: + metrics := c.cache.Metrics + logger.Info("Cache metrics: Hits: %d, Misses: %d, Cost Added: %d\n", + metrics.Hits(), metrics.Misses(), metrics.CostAdded()) + case <-c.quit: + ticker.Stop() + return + } + } +} diff --git a/tools/walletextension/cache/cache.go b/tools/walletextension/cache/cache.go new file mode 100644 index 0000000000..66e8b35f63 --- /dev/null +++ b/tools/walletextension/cache/cache.go @@ -0,0 +1,98 @@ +package cache + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/log" + + "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +const ( + longCacheTTL = 5 * time.Hour + shortCacheTTL = 1 * time.Second +) + +// CacheableRPCMethods is a map of Ethereum JSON-RPC methods that can be cached and their TTL +var cacheableRPCMethods = map[string]time.Duration{ + // Ethereum JSON-RPC methods that can be cached long time + "eth_getBlockByNumber": longCacheTTL, + "eth_getBlockByHash": longCacheTTL, + "eth_getTransactionByHash": longCacheTTL, + "eth_chainId": longCacheTTL, + + // Ethereum JSON-RPC methods that can be cached short time + "eth_blockNumber": shortCacheTTL, + "eth_getCode": shortCacheTTL, + "eth_getBalance": shortCacheTTL, + "eth_getTransactionReceipt": shortCacheTTL, + "eth_call": shortCacheTTL, + "eth_gasPrice": shortCacheTTL, + "eth_getTransactionCount": shortCacheTTL, + "eth_estimateGas": shortCacheTTL, + "eth_feeHistory": shortCacheTTL, +} + +type Cache interface { + Set(key string, value map[string]interface{}, ttl time.Duration) bool + Get(key string) (value map[string]interface{}, ok bool) +} + +func NewCache(logger log.Logger) (Cache, error) { + return NewRistrettoCache(logger) +} + +// IsCacheable checks if the given RPC request is cacheable and returns the cache key and TTL +func IsCacheable(key *common.RPCRequest) (bool, string, time.Duration) { + if key == nil || key.Method == "" { + return false, "", 0 + } + + // Check if the method is cacheable + ttl, isCacheable := cacheableRPCMethods[key.Method] + + if isCacheable { + // method is cacheable - select cache key + switch key.Method { + case "eth_getCode", "eth_getBalance", "eth_getTransactionCount", "eth_estimateGas", "eth_call": + if len(key.Params) == 1 || len(key.Params) == 2 && (key.Params[1] == "latest" || key.Params[1] == "pending") { + return true, GenerateCacheKey(key.Method, key.Params...), ttl + } + // in this case, we have a fixed block number, and we can cache the result for a long time + return true, GenerateCacheKey(key.Method, key.Params...), longCacheTTL + case "eth_feeHistory": + if len(key.Params) == 2 || len(key.Params) == 3 && (key.Params[2] == "latest" || key.Params[2] == "pending") { + return true, GenerateCacheKey(key.Method, key.Params...), ttl + } + // in this case, we have a fixed block number, and we can cache the result for a long time + return true, GenerateCacheKey(key.Method, key.Params...), longCacheTTL + default: + return true, GenerateCacheKey(key.Method, key.Params...), ttl + } + } + + // method is not cacheable + return false, "", 0 +} + +// GenerateCacheKey generates a cache key for the given method and parameters +func GenerateCacheKey(method string, params ...interface{}) string { + // Serialize parameters + paramBytes, err := json.Marshal(params) + if err != nil { + return "" + } + + // Concatenate method name and parameters + rawKey := method + string(paramBytes) + + // Optional: Apply hashing + hasher := sha256.New() + hasher.Write([]byte(rawKey)) + hashedKey := fmt.Sprintf("%x", hasher.Sum(nil)) + + return hashedKey +} diff --git a/tools/walletextension/cache/cache_test.go b/tools/walletextension/cache/cache_test.go new file mode 100644 index 0000000000..f4bb05d941 --- /dev/null +++ b/tools/walletextension/cache/cache_test.go @@ -0,0 +1,147 @@ +package cache + +import ( + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/log" + + "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +var tests = map[string]func(t *testing.T){ + "testCacheableMethods": testCacheableMethods, + "testNonCacheableMethods": testNonCacheableMethods, + "testMethodsWithLatestOrPendingParameter": testMethodsWithLatestOrPendingParameter, +} + +var cacheTests = map[string]func(cache Cache, t *testing.T){ + "testResultsAreCached": testResultsAreCached, + "testCacheTTL": testCacheTTL, +} + +var nonCacheableMethods = []string{"eth_sendrawtransaction", "eth_sendtransaction", "join", "authenticate"} + +func TestGatewayCaching(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + test(t) + }) + } + + // cache tests + for name, test := range cacheTests { + t.Run(name, func(t *testing.T) { + logger := log.New() + cache, err := NewCache(logger) + if err != nil { + t.Errorf("failed to create cache: %v", err) + } + test(cache, t) + }) + } +} + +// testCacheableMethods tests if the cacheable methods are cacheable +func testCacheableMethods(t *testing.T) { + for method := range cacheableRPCMethods { + key := &common.RPCRequest{Method: method} + isCacheable, _, _ := IsCacheable(key) + if isCacheable != true { + t.Errorf("method %s should be cacheable", method) + } + } +} + +// testNonCacheableMethods tests if the non-cacheable methods are not cacheable +func testNonCacheableMethods(t *testing.T) { + for _, method := range nonCacheableMethods { + key := &common.RPCRequest{Method: method} + isCacheable, _, _ := IsCacheable(key) + if isCacheable == true { + t.Errorf("method %s should not be cacheable", method) + } + } +} + +// testMethodsWithLatestOrPendingParameter tests if the methods with latest or pending parameter are cacheable +func testMethodsWithLatestOrPendingParameter(t *testing.T) { + methods := []string{"eth_getCode", "eth_getBalance", "eth_getTransactionCount", "eth_estimateGas", "eth_call"} + for _, method := range methods { + key := &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "latest"}} + _, _, ttl := IsCacheable(key) + if ttl != shortCacheTTL { + t.Errorf("method %s with latest parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) + } + + key = &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "pending"}} + _, _, ttl = IsCacheable(key) + if ttl != shortCacheTTL { + t.Errorf("method %s with pending parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) + } + } +} + +// testResultsAreCached tests if the results are cached as expected +func testResultsAreCached(cache Cache, t *testing.T) { + // prepare a cacheable request and imaginary response + req := &common.RPCRequest{Method: "eth_getBlockByNumber", Params: []interface{}{"0x123"}} + res := map[string]interface{}{"result": "block"} + isCacheable, key, ttl := IsCacheable(req) + if !isCacheable { + t.Errorf("method %s should be cacheable", req.Method) + } + // set the response in the cache with a TTL + if !cache.Set(key, res, ttl) { + t.Errorf("failed to set value in cache for %s", req) + } + + time.Sleep(50 * time.Millisecond) // wait for the cache to be set + value, ok := cache.Get(key) + if !ok { + t.Errorf("failed to get cached value for %s", req) + } + + if !reflect.DeepEqual(value, res) { + t.Errorf("expected %v, got %v", res, value) + } +} + +// testCacheTTL tests if the cache TTL is working as expected +func testCacheTTL(cache Cache, t *testing.T) { + req := &common.RPCRequest{Method: "eth_getBalance", Params: []interface{}{"0x123"}} + res := map[string]interface{}{"result": "100"} + isCacheable, key, ttl := IsCacheable(req) + + if !isCacheable { + t.Errorf("method %s should be cacheable", req.Method) + } + + if ttl != shortCacheTTL { + t.Errorf("method %s should have TTL of %s, but %s received", req.Method, shortCacheTTL, ttl) + } + + // set the response in the cache with a TTL + if !cache.Set(key, res, ttl) { + t.Errorf("failed to set value in cache for %s", req) + } + time.Sleep(50 * time.Millisecond) // wait for the cache to be set + + // check if the value is in the cache + value, ok := cache.Get(key) + if !ok { + t.Errorf("failed to get cached value for %s", req) + } + + if !reflect.DeepEqual(value, res) { + t.Errorf("expected %v, got %v", res, value) + } + + // sleep for the TTL to expire + time.Sleep(shortCacheTTL + 100*time.Millisecond) + _, ok = cache.Get(key) + if ok { + t.Errorf("value should not be in the cache after TTL") + } +} diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index 237e79c980..68f2917ab0 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -7,6 +7,8 @@ import ( "fmt" "time" + "github.com/ten-protocol/go-ten/tools/walletextension/cache" + "github.com/ten-protocol/go-ten/tools/walletextension/accountmanager" "github.com/ten-protocol/go-ten/tools/walletextension/config" @@ -43,6 +45,7 @@ type WalletExtension struct { version string config *config.Config tenClient *obsclient.ObsClient + cache cache.Cache } func New( @@ -62,6 +65,12 @@ func New( } newTenClient := obsclient.NewObsClient(rpcClient) newFileLogger := common.NewFileLogger() + newGatewayCache, err := cache.NewCache(logger) + if err != nil { + logger.Error(fmt.Errorf("could not create cache. Cause: %w", err).Error()) + panic(err) + } + return &WalletExtension{ hostAddrHTTP: hostAddrHTTP, hostAddrWS: hostAddrWS, @@ -74,6 +83,7 @@ func New( version: version, config: config, tenClient: newTenClient, + cache: newGatewayCache, } } @@ -92,6 +102,19 @@ func (w *WalletExtension) ProxyEthRequest(request *common.RPCRequest, conn userc // start measuring time for request requestStartTime := time.Now() + // Check if the request is in the cache + isCacheable, key, ttl := cache.IsCacheable(request) + + // in case of cache hit return the response from the cache + if isCacheable { + if value, ok := w.cache.Get(key); ok { + requestEndTime := time.Now() + duration := requestEndTime.Sub(requestStartTime) + w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, value, duration.Milliseconds())) + return value, nil + } + } + response := map[string]interface{}{} // all responses must contain the request id. Both successful and unsuccessful. response[common.JSONKeyRPCVersion] = jsonrpc.Version @@ -140,6 +163,11 @@ func (w *WalletExtension) ProxyEthRequest(request *common.RPCRequest, conn userc duration := requestEndTime.Sub(requestStartTime) w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, response, duration.Milliseconds())) + // if the request is cacheable, store the response in the cache + if isCacheable { + w.cache.Set(key, response, ttl) + } + return response, nil }