From 8ce194423351fd2b978eb36d7fc27543b396f509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=BDiga=20Kokelj?= Date: Mon, 4 Nov 2024 12:21:44 +0100 Subject: [PATCH] Use cosmosDB in the gateway and encrypt data with encryption key generated inside the enclave (#2104) --- .../manual-deploy-obscuro-gateway.yml | 2 +- go.mod | 5 + go.sum | 21 ++ tools/walletextension/common/common.go | 12 + tools/walletextension/common/config.go | 1 + tools/walletextension/common/db_types.go | 15 +- tools/walletextension/enclave.Dockerfile | 6 + .../walletextension/encryption/encryption.go | 65 +++++ .../encryption/encryption_test.go | 115 +++++++++ tools/walletextension/main/cli.go | 6 + tools/walletextension/main/enclave.json | 12 +- tools/walletextension/rpcapi/gw_user.go | 39 ++- .../walletextension/rpcapi/transaction_api.go | 22 -- .../rpcapi/wallet_extension.go | 7 +- .../storage/database/cosmosdb/cosmosdb.go | 229 ++++++++++++++++++ .../storage/database/mariadb/001_init.sql | 15 -- .../mariadb/002_store_incoming_txs.sql | 11 - .../mariadb/003_add_signature_type.sql | 1 - .../storage/database/mariadb/mariadb.go | 179 -------------- .../storage/database/sqlite/sqlite.go | 175 +++++-------- tools/walletextension/storage/storage.go | 17 +- tools/walletextension/storage/storage_test.go | 144 +++++++---- .../storage/storage_with_cache.go | 90 +++++++ .../walletextension_container.go | 16 +- 24 files changed, 766 insertions(+), 439 deletions(-) create mode 100644 tools/walletextension/encryption/encryption.go create mode 100644 tools/walletextension/encryption/encryption_test.go create mode 100644 tools/walletextension/storage/database/cosmosdb/cosmosdb.go delete mode 100644 tools/walletextension/storage/database/mariadb/001_init.sql delete mode 100644 tools/walletextension/storage/database/mariadb/002_store_incoming_txs.sql delete mode 100644 tools/walletextension/storage/database/mariadb/003_add_signature_type.sql delete mode 100644 tools/walletextension/storage/database/mariadb/mariadb.go create mode 100644 tools/walletextension/storage/storage_with_cache.go diff --git a/.github/workflows/manual-deploy-obscuro-gateway.yml b/.github/workflows/manual-deploy-obscuro-gateway.yml index 4fee72c3c5..b04124670a 100644 --- a/.github/workflows/manual-deploy-obscuro-gateway.yml +++ b/.github/workflows/manual-deploy-obscuro-gateway.yml @@ -296,5 +296,5 @@ jobs: "${{ env.DOCKER_BUILD_TAG_GATEWAY }}" \ ego run /home/ten/go-ten/tools/walletextension/main/main \ -host=0.0.0.0 -port=80 -portWS=81 -nodeHost="${{ env.L2_RPC_URL_VALIDATOR }}" -verbose=true \ - -logPath=sys_out -dbType=mariaDB -dbConnectionURL="obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb" \ + -logPath=sys_out -dbType=cosmosDB -dbConnectionURL="${{ secrets.COSMOS_DB_CONNECTION_STRING }}" \ -rateLimitUserComputeTime="${{ env.GATEWAY_RATE_LIMIT_USER_COMPUTE_TIME }}" -rateLimitWindow="${{ env.GATEWAY_RATE_LIMIT_WINDOW }}" -maxConcurrentRequestsPerUser="${{ env.GATEWAY_MAX_CONCURRENT_REQUESTS_PER_USER }}" ' diff --git a/go.mod b/go.mod index fd84e92bd0..e51164fef3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21.11 replace github.com/docker/docker => github.com/docker/docker v20.10.3-0.20220224222438-c78f6963a1c0+incompatible require ( + github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.1.0 github.com/FantasyJony/openzeppelin-merkle-tree-go v1.1.3 github.com/Microsoft/go-winio v0.6.2 github.com/andybalholm/brotli v1.1.1 @@ -55,6 +56,9 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/DataDog/zstd v1.5.6 // indirect github.com/VictoriaMetrics/fastcache v1.12.2 // indirect @@ -102,6 +106,7 @@ require ( github.com/go-playground/validator/v10 v10.22.1 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/glog v1.2.2 // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/gorilla/mux v1.8.1 // indirect diff --git a/go.sum b/go.sum index 3f0eed559d..8599f67082 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,25 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= +github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.1.0 h1:c726lgbwpwFBuj+Fyrwuh/vUilqFo+hUAOUNjsKj5DI= +github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.1.0/go.mod h1:WzFGxuepAtZIZtQbz8/WviJycLMKJHpaEAqcXONxlag= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/DataDog/zstd v1.5.6 h1:LbEglqepa/ipmmQJUDnSsfvA8e8IStVcGaFWDuxvGOY= github.com/DataDog/zstd v1.5.6/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/FantasyJony/openzeppelin-merkle-tree-go v1.1.3 h1:KzMvCFet0baw6uJnxTE/His8YeRgaxlASd4/ISuTvzI= github.com/FantasyJony/openzeppelin-merkle-tree-go v1.1.3/go.mod h1:OiwyYqbtMkQH+VzA4b8lI+qHnExJy0fIdz+59/8nFes= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/DataDog/zstd v1.5.5 h1:oWf5W7GtOLgp6bciQYDmhHHjdhYkALu6S/5Ni9ZgSvQ= +github.com/DataDog/zstd v1.5.5/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjCM7NQbSmF7WI= @@ -159,6 +173,11 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= +github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -279,6 +298,8 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/tools/walletextension/common/common.go b/tools/walletextension/common/common.go index 94ac8be729..838d2a37d5 100644 --- a/tools/walletextension/common/common.go +++ b/tools/walletextension/common/common.go @@ -1,6 +1,7 @@ package common import ( + "crypto/rand" "encoding/json" "fmt" @@ -15,6 +16,8 @@ import ( gethlog "github.com/ethereum/go-ethereum/log" ) +const EncryptionKeySize = 32 + // PrivateKeyToCompressedPubKey converts *ecies.PrivateKey to compressed PubKey ([]byte with length 33) func PrivateKeyToCompressedPubKey(prvKey *ecies.PrivateKey) []byte { ecdsaPublicKey := prvKey.PublicKey.ExportECDSA() @@ -76,3 +79,12 @@ func (r *RPCRequest) Clone() *RPCRequest { Params: r.Params, } } + +func GenerateRandomKey() ([]byte, error) { + key := make([]byte, EncryptionKeySize) + _, err := rand.Read(key) + if err != nil { + return nil, err + } + return key, nil +} diff --git a/tools/walletextension/common/config.go b/tools/walletextension/common/config.go index 26f43cf459..799d92f36b 100644 --- a/tools/walletextension/common/config.go +++ b/tools/walletextension/common/config.go @@ -19,4 +19,5 @@ type Config struct { RateLimitUserComputeTime time.Duration RateLimitWindow time.Duration RateLimitMaxConcurrentRequests int + Debug bool } diff --git a/tools/walletextension/common/db_types.go b/tools/walletextension/common/db_types.go index cde690813f..04993bbd75 100644 --- a/tools/walletextension/common/db_types.go +++ b/tools/walletextension/common/db_types.go @@ -1,12 +1,13 @@ package common -type AccountDB struct { - AccountAddress []byte - Signature []byte - SignatureType int +type GWUserDB struct { + UserId []byte `json:"userId"` + PrivateKey []byte `json:"privateKey"` + Accounts []GWAccountDB `json:"accounts"` } -type UserDB struct { - UserID []byte - PrivateKey []byte +type GWAccountDB struct { + AccountAddress []byte `json:"accountAddress"` + Signature []byte `json:"signature"` + SignatureType int `json:"signatureType"` } diff --git a/tools/walletextension/enclave.Dockerfile b/tools/walletextension/enclave.Dockerfile index e0a8b77504..33673df659 100644 --- a/tools/walletextension/enclave.Dockerfile +++ b/tools/walletextension/enclave.Dockerfile @@ -10,6 +10,12 @@ FROM ghcr.io/edgelesssys/ego-dev:v1.5.3 AS build-base +# Install ca-certificates package and update it +RUN apt-get update && apt-get install -y \ + ca-certificates \ + && update-ca-certificates + + # setup container data structure RUN mkdir -p /home/ten/go-ten diff --git a/tools/walletextension/encryption/encryption.go b/tools/walletextension/encryption/encryption.go new file mode 100644 index 0000000000..c4e8de7c22 --- /dev/null +++ b/tools/walletextension/encryption/encryption.go @@ -0,0 +1,65 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "io" +) + +// Encryptor provides AES-GCM encryption/decryption with the following characteristics: +// - Uses AES-256-GCM (Galois/Counter Mode) with a 32-byte key +// - Generates a random 12-byte nonce for each encryption operation using crypto/rand +// - The nonce is prepended to the ciphertext output from Encrypt() and is generated +// using crypto/rand.Reader for cryptographically secure random values +// +// Additionally provides HMAC-SHA256 hashing functionality: +// - Uses the same 32-byte key as the encryption operations +// - Generates a 32-byte (256-bit) message authentication code +// - Suitable for creating secure message digests and verifying data integrity +type Encryptor struct { + gcm cipher.AEAD + key []byte +} + +func NewEncryptor(key []byte) (*Encryptor, error) { + if len(key) != 32 { + return nil, fmt.Errorf("key must be 32 bytes long") + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + return &Encryptor{gcm: gcm, key: key}, nil +} + +func (e *Encryptor) Encrypt(plaintext []byte) ([]byte, error) { + nonce := make([]byte, e.gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + return e.gcm.Seal(nonce, nonce, plaintext, nil), nil +} + +func (e *Encryptor) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext) < e.gcm.NonceSize() { + return nil, errors.New("ciphertext too short") + } + nonce, ciphertext := ciphertext[:e.gcm.NonceSize()], ciphertext[e.gcm.NonceSize():] + return e.gcm.Open(nil, nonce, ciphertext, nil) +} + +func (e *Encryptor) HashWithHMAC(data []byte) []byte { + h := hmac.New(sha256.New, e.key) + h.Write(data) + return h.Sum(nil) +} diff --git a/tools/walletextension/encryption/encryption_test.go b/tools/walletextension/encryption/encryption_test.go new file mode 100644 index 0000000000..8395fd6da9 --- /dev/null +++ b/tools/walletextension/encryption/encryption_test.go @@ -0,0 +1,115 @@ +package encryption + +import ( + "bytes" + "crypto/rand" + "testing" +) + +func TestNewEncryptor(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := rand.Read(key) + if err != nil { + t.Fatalf("Failed to generate random key: %v", err) + } + + encryptor, err := NewEncryptor(key) + if err != nil { + t.Fatalf("NewEncryptor failed: %v", err) + } + + if encryptor == nil { + t.Fatal("NewEncryptor returned nil") + } +} + +func TestEncryptDecrypt(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := rand.Read(key) + if err != nil { + t.Fatalf("Failed to generate random key: %v", err) + } + + encryptor, err := NewEncryptor(key) + if err != nil { + t.Fatalf("NewEncryptor failed: %v", err) + } + + plaintext := []byte("Hello, World!") + + ciphertext, err := encryptor.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + decrypted, err := encryptor.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatalf("Decrypted text does not match original plaintext") + } +} + +func TestEncryptDecryptEmptyString(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := rand.Read(key) + if err != nil { + t.Fatalf("Failed to generate random key: %v", err) + } + + encryptor, err := NewEncryptor(key) + if err != nil { + t.Fatalf("NewEncryptor failed: %v", err) + } + + plaintext := []byte("") + + ciphertext, err := encryptor.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encryption of empty string failed: %v", err) + } + + decrypted, err := encryptor.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Decryption of empty string failed: %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatalf("Decrypted empty string does not match original") + } +} + +func TestDecryptInvalidCiphertext(t *testing.T) { + key := make([]byte, 32) // 256-bit key + _, err := rand.Read(key) + if err != nil { + t.Fatalf("Failed to generate random key: %v", err) + } + + encryptor, err := NewEncryptor(key) + if err != nil { + t.Fatalf("NewEncryptor failed: %v", err) + } + + invalidCiphertext := []byte("This is not a valid ciphertext") + + _, err = encryptor.Decrypt(invalidCiphertext) + if err == nil { + t.Fatal("Decryption of invalid ciphertext should have failed, but didn't") + } +} + +func TestNewEncryptorInvalidKeySize(t *testing.T) { + invalidKey := make([]byte, 31) // Invalid key size (not 16, 24, or 32 bytes) + _, err := rand.Read(invalidKey) + if err != nil { + t.Fatalf("Failed to generate random key: %v", err) + } + + _, err = NewEncryptor(invalidKey) + if err == nil { + t.Fatal("NewEncryptor should have failed with invalid key size, but didn't") + } +} diff --git a/tools/walletextension/main/cli.go b/tools/walletextension/main/cli.go index e1c4befd4d..1d604af99f 100644 --- a/tools/walletextension/main/cli.go +++ b/tools/walletextension/main/cli.go @@ -72,6 +72,10 @@ const ( rateLimitMaxConcurrentRequestsName = "maxConcurrentRequestsPerUser" rateLimitMaxConcurrentRequestsDefault = 3 rateLimitMaxConcurrentRequestsUsage = "Number of concurrent requests allowed per user. Default: 3" + + debugFlagName = "debug" + debugFlagDefault = false + debugFlagUsage = "Flag to enable debug mode" ) func parseCLIArgs() wecommon.Config { @@ -91,6 +95,7 @@ func parseCLIArgs() wecommon.Config { rateLimitUserComputeTime := flag.Duration(rateLimitUserComputeTimeName, rateLimitUserComputeTimeDefault, rateLimitUserComputeTimeUsage) rateLimitWindow := flag.Duration(rateLimitWindowName, rateLimitWindowDefault, rateLimitWindowUsage) rateLimitMaxConcurrentRequests := flag.Int(rateLimitMaxConcurrentRequestsName, rateLimitMaxConcurrentRequestsDefault, rateLimitMaxConcurrentRequestsUsage) + debugFlag := flag.Bool(debugFlagName, debugFlagDefault, debugFlagUsage) flag.Parse() return wecommon.Config{ @@ -109,5 +114,6 @@ func parseCLIArgs() wecommon.Config { RateLimitUserComputeTime: *rateLimitUserComputeTime, RateLimitWindow: *rateLimitWindow, RateLimitMaxConcurrentRequests: *rateLimitMaxConcurrentRequests, + Debug: *debugFlag, } } diff --git a/tools/walletextension/main/enclave.json b/tools/walletextension/main/enclave.json index 24304b82c5..c4bf5eb33b 100644 --- a/tools/walletextension/main/enclave.json +++ b/tools/walletextension/main/enclave.json @@ -14,16 +14,8 @@ ], "files": [ { - "source": "../storage/database/mariadb/001_init.sql", - "target": "/home/ten/go-ten/tools/walletextension/storage/database/mariadb/001_init.sql" - }, - { - "source": "../storage/database/mariadb/002_store_incoming_txs.sql", - "target": "/home/ten/go-ten/tools/walletextension/storage/database/mariadb/002_store_incoming_txs.sql" - }, - { - "source": "../storage/database/mariadb/003_add_signature_type.sql", - "target": "/home/ten/go-ten/tools/walletextension/storage/database/mariadb/003_add_signature_type.sql" + "source": "/etc/ssl/certs/ca-certificates.crt", + "target": "/etc/ssl/certs/ca-certificates.crt" } ] } \ No newline at end of file diff --git a/tools/walletextension/rpcapi/gw_user.go b/tools/walletextension/rpcapi/gw_user.go index 302a488557..73117eb1f4 100644 --- a/tools/walletextension/rpcapi/gw_user.go +++ b/tools/walletextension/rpcapi/gw_user.go @@ -7,6 +7,7 @@ import ( "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ethereum/go-ethereum/common" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) var userCacheKeyPrefix = []byte{0x0, 0x1, 0x2, 0x3} @@ -33,6 +34,28 @@ func (u GWUser) GetAllAddresses() []*common.Address { return accts } +func gwUserFromDB(userDB wecommon.GWUserDB, s *Services) (*GWUser, error) { + result := &GWUser{ + userID: userDB.UserId, + services: s, + accounts: make(map[common.Address]*GWAccount), + userKey: userDB.PrivateKey, + } + + for _, accountDB := range userDB.Accounts { + address := common.BytesToAddress(accountDB.AccountAddress) + gwAccount := &GWAccount{ + user: result, + address: &address, + signature: accountDB.Signature, + signatureType: viewingkey.SignatureType(accountDB.SignatureType), + } + result.accounts[address] = gwAccount + } + + return result, nil +} + func userCacheKey(userID []byte) []byte { var key []byte key = append(key, userCacheKeyPrefix...) @@ -42,21 +65,11 @@ func userCacheKey(userID []byte) []byte { func getUser(userID []byte, s *Services) (*GWUser, error) { return withCache(s.Cache, &CacheCfg{CacheType: LongLiving}, userCacheKey(userID), func() (*GWUser, error) { - result := GWUser{userID: userID, services: s, accounts: map[common.Address]*GWAccount{}} - userPrivateKey, err := s.Storage.GetUserPrivateKey(userID) + user, err := s.Storage.GetUser(userID) if err != nil { return nil, fmt.Errorf("user %s not found. %w", hexutils.BytesToHex(userID), err) } - result.userKey = userPrivateKey - allAccounts, err := s.Storage.GetAccounts(userID) - if err != nil { - return nil, err - } - - for _, account := range allAccounts { - address := common.BytesToAddress(account.AccountAddress) - result.accounts[address] = &GWAccount{user: &result, address: &address, signature: account.Signature, signatureType: viewingkey.SignatureType(uint8(account.SignatureType))} - } - return &result, nil + result, err := gwUserFromDB(user, s) + return result, err }) } diff --git a/tools/walletextension/rpcapi/transaction_api.go b/tools/walletextension/rpcapi/transaction_api.go index a0a8baedd4..ffa978145e 100644 --- a/tools/walletextension/rpcapi/transaction_api.go +++ b/tools/walletextension/rpcapi/transaction_api.go @@ -2,8 +2,6 @@ package rpcapi import ( "context" - "encoding/json" - "fmt" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -105,19 +103,6 @@ func (s *TransactionAPI) SendTransaction(ctx context.Context, args gethapi.Trans if err != nil { return common.Hash{}, err } - userIDBytes, _ := extractUserID(ctx, s.we) - if s.we.Config.StoreIncomingTxs && len(userIDBytes) > 10 { - tx, err := json.Marshal(args) - if err != nil { - s.we.Logger().Error("error marshalling transaction: %s", err) - return *txRec, nil - } - err = s.we.Storage.StoreTransaction(string(tx), userIDBytes) - if err != nil { - s.we.Logger().Error("error storing transaction in the database: %s", err) - return *txRec, nil - } - } return *txRec, err } @@ -135,13 +120,6 @@ func (s *TransactionAPI) SendRawTransaction(ctx context.Context, input hexutil.B if err != nil { return common.Hash{}, err } - userIDBytes, err := extractUserID(ctx, s.we) - if s.we.Config.StoreIncomingTxs && len(userIDBytes) > 10 { - err = s.we.Storage.StoreTransaction(input.String(), userIDBytes) - if err != nil { - s.we.Logger().Error(fmt.Errorf("error storing transaction in the database: %w", err).Error()) - } - } return *txRec, err } diff --git a/tools/walletextension/rpcapi/wallet_extension.go b/tools/walletextension/rpcapi/wallet_extension.go index 060ea5370f..9e93885369 100644 --- a/tools/walletextension/rpcapi/wallet_extension.go +++ b/tools/walletextension/rpcapi/wallet_extension.go @@ -228,11 +228,12 @@ func (w *Services) UserHasAccount(userID []byte, address string) (bool, error) { // todo - this can be optimised and done in the database if we will have users with large number of accounts // get all the accounts for the selected user - accounts, err := w.Storage.GetAccounts(userID) + user, err := w.Storage.GetUser(userID) if err != nil { w.Logger().Error(fmt.Errorf("error getting accounts for user (%s), %w", userID, err).Error()) return false, err } + accounts := user.Accounts // check if any of the account matches given account found := false @@ -262,12 +263,12 @@ func (w *Services) UserExists(userID []byte) bool { // Check if user exists and don't log error if user doesn't exist, because we expect this to happen in case of // user revoking encryption token or using different testnet. // todo add a counter here in the future - key, err := w.Storage.GetUserPrivateKey(userID) + users, err := w.Storage.GetUser(userID) if err != nil { return false } - return len(key) > 0 + return len(users.PrivateKey) > 0 } func (w *Services) Version() string { diff --git a/tools/walletextension/storage/database/cosmosdb/cosmosdb.go b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go new file mode 100644 index 0000000000..e8296f2523 --- /dev/null +++ b/tools/walletextension/storage/database/cosmosdb/cosmosdb.go @@ -0,0 +1,229 @@ +package cosmosdb + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + + "github.com/ten-protocol/go-ten/go/common/viewingkey" + + "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" + "github.com/ten-protocol/go-ten/go/common/errutil" + "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/encryption" +) + +/* +This is a CosmosDB implementation of the Storage interface. + +We need to make sure we have a CosmosDB account and a database created before using this. + +Quick summary of the CosmosDB setup: +- Create a CosmosDB account (Azure Cosmos DB for NoSQL) +- account name should follow the format: -gateway-cosmosdb +- use Serverless capacity mode for testnets +- go to "Data Explorer" in the CosmosDB account and create new database named "gatewayDB" +- inside the database create a container named "users" with partition key of "/id" +- to get your connection string go to settings -> keys -> primary connection string + +*/ + +// CosmosDB struct represents the CosmosDB storage implementation +type CosmosDB struct { + client *azcosmos.Client + usersContainer *azcosmos.ContainerClient + encryptor encryption.Encryptor +} + +// EncryptedDocument struct is used to store encrypted user data in CosmosDB +// We use this structure to add an extra layer of security by encrypting the actual user data +// The 'ID' field is used as the document ID and partition key in CosmosDB +// The 'Data' field contains the base64-encoded encrypted user data +type EncryptedDocument struct { + ID string `json:"id"` + Data []byte `json:"data"` +} + +// Constants for the CosmosDB database and container names +const ( + DATABASE_NAME = "gatewayDB" + USERS_CONTAINER_NAME = "users" +) + +func NewCosmosDB(connectionString string, encryptionKey []byte) (*CosmosDB, error) { + // Create encryptor + encryptor, err := encryption.NewEncryptor(encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create encryptor: %w", err) + } + + client, err := azcosmos.NewClientFromConnectionString(connectionString, nil) + if err != nil { + return nil, fmt.Errorf("failed to create CosmosDB client: %w", err) + } + + // Create database if it doesn't exist + ctx := context.Background() + _, err = client.CreateDatabase(ctx, azcosmos.DatabaseProperties{ID: DATABASE_NAME}, nil) + if err != nil && !strings.Contains(err.Error(), "Conflict") { + return nil, fmt.Errorf("failed to create database: %w", err) + } + + // Create container client for users container + usersContainer, err := client.NewContainer(DATABASE_NAME, USERS_CONTAINER_NAME) + if err != nil { + return nil, fmt.Errorf("failed to create users container: %w", err) + } + + return &CosmosDB{ + client: client, + usersContainer: usersContainer, + encryptor: *encryptor, + }, nil +} + +func (c *CosmosDB) AddUser(userID []byte, privateKey []byte) error { + user := common.GWUserDB{ + UserId: userID, + PrivateKey: privateKey, + Accounts: []common.GWAccountDB{}, + } + userJSON, err := json.Marshal(user) + if err != nil { + return fmt.Errorf("failed to marshal user: %w", err) + } + + ciphertext, err := c.encryptor.Encrypt(userJSON) + if err != nil { + return fmt.Errorf("failed to encrypt user data: %w", err) + } + + key := c.encryptor.HashWithHMAC(userID) + keyString := hex.EncodeToString(key) + + // Create an EncryptedDocument struct to store in CosmosDB + doc := EncryptedDocument{ + ID: keyString, + Data: ciphertext, + } + + docJSON, err := json.Marshal(doc) + if err != nil { + return fmt.Errorf("failed to marshal document: %w", err) + } + + partitionKey := azcosmos.NewPartitionKeyString(keyString) + ctx := context.Background() + _, err = c.usersContainer.CreateItem(ctx, partitionKey, docJSON, nil) + if err != nil { + return fmt.Errorf("failed to create item: %w", err) + } + return nil +} + +func (c *CosmosDB) DeleteUser(userID []byte) error { + key := c.encryptor.HashWithHMAC(userID) + keyString := hex.EncodeToString(key) + partitionKey := azcosmos.NewPartitionKeyString(keyString) + ctx := context.Background() + + _, err := c.usersContainer.DeleteItem(ctx, partitionKey, keyString, nil) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + return nil +} + +func (c *CosmosDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { + key := c.encryptor.HashWithHMAC(userID) + keyString := hex.EncodeToString(key) + partitionKey := azcosmos.NewPartitionKeyString(keyString) + ctx := context.Background() + + itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + var doc EncryptedDocument + err = json.Unmarshal(itemResponse.Value, &doc) + if err != nil { + return fmt.Errorf("failed to unmarshal document: %w", err) + } + + data, err := c.encryptor.Decrypt(doc.Data) + if err != nil { + return fmt.Errorf("failed to decrypt data: %w", err) + } + + var user common.GWUserDB + err = json.Unmarshal(data, &user) + if err != nil { + return fmt.Errorf("failed to unmarshal user data: %w", err) + } + + // Add the new account + newAccount := common.GWAccountDB{ + AccountAddress: accountAddress, + Signature: signature, + SignatureType: int(signatureType), + } + user.Accounts = append(user.Accounts, newAccount) + + userJSON, err := json.Marshal(user) + if err != nil { + return fmt.Errorf("error marshaling updated user: %w", err) + } + + ciphertext, err := c.encryptor.Encrypt(userJSON) + if err != nil { + return fmt.Errorf("failed to encrypt updated user data: %w", err) + } + + // Update the document + doc.Data = ciphertext + + docJSON, err := json.Marshal(doc) + if err != nil { + return fmt.Errorf("failed to marshal updated document: %w", err) + } + + // Replace the item in the container + _, err = c.usersContainer.ReplaceItem(ctx, partitionKey, keyString, docJSON, nil) + if err != nil { + return fmt.Errorf("failed to update user with new account: %w", err) + } + return nil +} + +func (c *CosmosDB) GetUser(userID []byte) (common.GWUserDB, error) { + key := c.encryptor.HashWithHMAC(userID) + keyString := hex.EncodeToString(key) + partitionKey := azcosmos.NewPartitionKeyString(keyString) + ctx := context.Background() + + itemResponse, err := c.usersContainer.ReadItem(ctx, partitionKey, keyString, nil) + if err != nil { + return common.GWUserDB{}, errutil.ErrNotFound + } + + var doc EncryptedDocument + err = json.Unmarshal(itemResponse.Value, &doc) + if err != nil { + return common.GWUserDB{}, fmt.Errorf("failed to unmarshal document: %w", err) + } + + data, err := c.encryptor.Decrypt(doc.Data) + if err != nil { + return common.GWUserDB{}, fmt.Errorf("failed to decrypt data: %w", err) + } + + var user common.GWUserDB + err = json.Unmarshal(data, &user) + if err != nil { + return common.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) + } + return user, nil +} diff --git a/tools/walletextension/storage/database/mariadb/001_init.sql b/tools/walletextension/storage/database/mariadb/001_init.sql deleted file mode 100644 index 95a4e61513..0000000000 --- a/tools/walletextension/storage/database/mariadb/001_init.sql +++ /dev/null @@ -1,15 +0,0 @@ -/* - This is a migration file for MariaDB and is executed when the Gateway is started to make sure the database schema is up to date. -*/ - -CREATE TABLE IF NOT EXISTS ogdb.users ( - user_id varbinary(20) PRIMARY KEY, - private_key varbinary(32) -); - -CREATE TABLE IF NOT EXISTS ogdb.accounts ( - user_id varbinary(20), - account_address varbinary(20), - signature varbinary(65), - FOREIGN KEY(user_id) REFERENCES ogdb.users(user_id) ON DELETE CASCADE -); diff --git a/tools/walletextension/storage/database/mariadb/002_store_incoming_txs.sql b/tools/walletextension/storage/database/mariadb/002_store_incoming_txs.sql deleted file mode 100644 index c5dbc2af0e..0000000000 --- a/tools/walletextension/storage/database/mariadb/002_store_incoming_txs.sql +++ /dev/null @@ -1,11 +0,0 @@ -/* - This is a migration file for MariaDB that creates transactions table for storing incoming transactions -*/ - -CREATE TABLE IF NOT EXISTS ogdb.transactions ( - id INT AUTO_INCREMENT PRIMARY KEY, - user_id varbinary(20), - tx_hash TEXT, - tx TEXT, - tx_time DATETIME DEFAULT CURRENT_TIMESTAMP -); \ No newline at end of file diff --git a/tools/walletextension/storage/database/mariadb/003_add_signature_type.sql b/tools/walletextension/storage/database/mariadb/003_add_signature_type.sql deleted file mode 100644 index f04ffb6f23..0000000000 --- a/tools/walletextension/storage/database/mariadb/003_add_signature_type.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE ogdb.accounts ADD COLUMN IF NOT EXISTS signature_type INT DEFAULT 0; \ No newline at end of file diff --git a/tools/walletextension/storage/database/mariadb/mariadb.go b/tools/walletextension/storage/database/mariadb/mariadb.go deleted file mode 100644 index a377970ba2..0000000000 --- a/tools/walletextension/storage/database/mariadb/mariadb.go +++ /dev/null @@ -1,179 +0,0 @@ -package mariadb - -import ( - "database/sql" - "encoding/hex" - "fmt" - "path/filepath" - "runtime" - - "github.com/ten-protocol/go-ten/go/common/storage" - - "github.com/ten-protocol/go-ten/go/common/viewingkey" - - "github.com/ethereum/go-ethereum/crypto" - - _ "github.com/go-sql-driver/mysql" // Importing MariaDB driver - "github.com/ten-protocol/go-ten/go/common/errutil" - "github.com/ten-protocol/go-ten/tools/walletextension/common" -) - -type MariaDB struct { - db *sql.DB -} - -// NewMariaDB creates a new MariaDB connection instance -func NewMariaDB(dbURL string) (*MariaDB, error) { - db, err := sql.Open("mysql", dbURL+"?multiStatements=true") - if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - // get the path to the migrations (they are always in the same directory as file containing connection function) - _, filename, _, ok := runtime.Caller(0) - if !ok { - return nil, fmt.Errorf("failed to get current directory") - } - migrationsDir := filepath.Dir(filename) - - if err = storage.ApplyMigrations(db, migrationsDir); err != nil { - return nil, err - } - - return &MariaDB{db: db}, nil -} - -func (m *MariaDB) AddUser(userID []byte, privateKey []byte) error { - stmt, err := m.db.Prepare("REPLACE INTO users(user_id, private_key) VALUES (?, ?)") - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(userID, privateKey) - if err != nil { - return err - } - - return nil -} - -func (m *MariaDB) DeleteUser(userID []byte) error { - stmt, err := m.db.Prepare("DELETE FROM users WHERE user_id = ?") - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(userID) - if err != nil { - return err - } - - return nil -} - -func (m *MariaDB) GetUserPrivateKey(userID []byte) ([]byte, error) { - var privateKey []byte - err := m.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey) - if err != nil { - if err == sql.ErrNoRows { - // No rows found for the given userID - return nil, errutil.ErrNotFound - } - return nil, err - } - - return privateKey, nil -} - -func (m *MariaDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { - stmt, err := m.db.Prepare("INSERT INTO accounts(user_id, account_address, signature, signature_type) VALUES (?, ?, ?, ?)") - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(userID, accountAddress, signature, int(signatureType)) - if err != nil { - return err - } - - return nil -} - -func (m *MariaDB) GetAccounts(userID []byte) ([]common.AccountDB, error) { - rows, err := m.db.Query("SELECT account_address, signature, signature_type FROM accounts WHERE user_id = ?", userID) - if err != nil { - return nil, err - } - defer rows.Close() - - var accounts []common.AccountDB - for rows.Next() { - var account common.AccountDB - if err := rows.Scan(&account.AccountAddress, &account.Signature, &account.SignatureType); err != nil { - return nil, err - } - accounts = append(accounts, account) - } - if err := rows.Err(); err != nil { - return nil, err - } - - return accounts, nil -} - -func (m *MariaDB) GetAllUsers() ([]common.UserDB, error) { - rows, err := m.db.Query("SELECT user_id, private_key FROM users") - if err != nil { - return nil, err - } - defer rows.Close() - - var users []common.UserDB - for rows.Next() { - var user common.UserDB - err = rows.Scan(&user.UserID, &user.PrivateKey) - if err != nil { - return nil, err - } - users = append(users, user) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return users, nil -} - -func (m *MariaDB) StoreTransaction(rawTx string, userID []byte) error { - stmt, err := m.db.Prepare("INSERT INTO transactions(user_id, tx_hash, tx) VALUES (?, ?, ?)") - if err != nil { - return err - } - defer stmt.Close() - - // Validate rawTx length and get the txHash - txHash := "" - if len(rawTx) < 3 { - fmt.Println("Invalid rawTx: ", rawTx) - } else { - // Decode the hex string to bytes, excluding the '0x' prefix - rawTxBytes, err := hex.DecodeString(rawTx[2:]) - if err != nil { - fmt.Println("Error decoding rawTx: ", err) - } else { - // Compute Keccak-256 hash - txHash = crypto.Keccak256Hash(rawTxBytes).Hex() - } - } - - _, err = stmt.Exec(userID, txHash, rawTx) - if err != nil { - return err - } - - return nil -} diff --git a/tools/walletextension/storage/database/sqlite/sqlite.go b/tools/walletextension/storage/database/sqlite/sqlite.go index a39abfd65c..f08086bf5b 100644 --- a/tools/walletextension/storage/database/sqlite/sqlite.go +++ b/tools/walletextension/storage/database/sqlite/sqlite.go @@ -1,20 +1,24 @@ package sqlite +/* + SQLite database implementation of the Storage interface + + SQLite is used for local deployments and testing without the need for a cloud database. + To make sure to see similar behaviour as in production using CosmosDB we use SQLite database in a similar way as comosDB (as key-value database). +*/ import ( "database/sql" - "encoding/hex" + "encoding/json" "fmt" "os" "path/filepath" "github.com/ten-protocol/go-ten/go/common/viewingkey" - - "github.com/ethereum/go-ethereum/crypto" + "github.com/ten-protocol/go-ten/tools/walletextension/common" _ "github.com/mattn/go-sqlite3" // sqlite driver for sql.Open() obscurocommon "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/errutil" - common "github.com/ten-protocol/go-ten/tools/walletextension/common" ) type Database struct { @@ -41,50 +45,38 @@ func NewSqliteDatabase(dbPath string) (*Database, error) { return nil, err } - // create users table + // Modify the users table to store the entire GWUserDB as JSON _, err = db.Exec(`CREATE TABLE IF NOT EXISTS users ( - user_id binary(20) PRIMARY KEY, - private_key binary(32) - );`) - if err != nil { - return nil, err - } - - // create accounts table - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS accounts ( - user_id binary(20), - account_address binary(20), - signature binary(65), - signature_type int, - FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE + id TEXT PRIMARY KEY, + user_data TEXT );`) if err != nil { return nil, err } - // create transactions table - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS transactions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id binary(20), - tx_hash TEXT, - tx TEXT, - tx_time TEXT DEFAULT (datetime('now')) -) ;`) - if err != nil { - return nil, err - } + // Remove the accounts table as it will be stored within the user_data JSON return &Database{db: db}, nil } func (s *Database) AddUser(userID []byte, privateKey []byte) error { - stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(user_id, private_key) VALUES (?, ?)") + user := common.GWUserDB{ + UserId: userID, + PrivateKey: privateKey, + Accounts: []common.GWAccountDB{}, + } + userJSON, err := json.Marshal(user) + if err != nil { + return err + } + + stmt, err := s.db.Prepare("INSERT OR REPLACE INTO users(id, user_data) VALUES (?, ?)") if err != nil { return err } defer stmt.Close() - _, err = stmt.Exec(userID, privateKey) + _, err = stmt.Exec(string(user.UserId), string(userJSON)) if err != nil { return err } @@ -93,93 +85,77 @@ func (s *Database) AddUser(userID []byte, privateKey []byte) error { } func (s *Database) DeleteUser(userID []byte) error { - stmt, err := s.db.Prepare("DELETE FROM users WHERE user_id = ?") + stmt, err := s.db.Prepare("DELETE FROM users WHERE id = ?") if err != nil { return err } defer stmt.Close() - _, err = stmt.Exec(userID) + _, err = stmt.Exec(string(userID)) if err != nil { - return err + return fmt.Errorf("failed to delete user: %w", err) } return nil } -func (s *Database) GetUserPrivateKey(userID []byte) ([]byte, error) { - var privateKey []byte - err := s.db.QueryRow("SELECT private_key FROM users WHERE user_id = ?", userID).Scan(&privateKey) +func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { + var userDataJSON string + err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) if err != nil { - if err == sql.ErrNoRows { - // No rows found for the given userID - return nil, errutil.ErrNotFound - } - return nil, err + return fmt.Errorf("failed to get user: %w", err) } - return privateKey, nil -} - -func (s *Database) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { - stmt, err := s.db.Prepare("INSERT INTO accounts(user_id, account_address, signature, signature_type) VALUES (?, ?, ?, ?)") + var user common.GWUserDB + err = json.Unmarshal([]byte(userDataJSON), &user) if err != nil { - return err + return fmt.Errorf("failed to unmarshal user data: %w", err) } - defer stmt.Close() - _, err = stmt.Exec(userID, accountAddress, signature, int(signatureType)) - if err != nil { - return err + newAccount := common.GWAccountDB{ + AccountAddress: accountAddress, + Signature: signature, + SignatureType: int(signatureType), } - return nil -} + user.Accounts = append(user.Accounts, newAccount) -func (s *Database) GetAccounts(userID []byte) ([]common.AccountDB, error) { - rows, err := s.db.Query("SELECT account_address, signature, signature_type FROM accounts WHERE user_id = ?", userID) + updatedUserJSON, err := json.Marshal(user) if err != nil { - return nil, err + return fmt.Errorf("error marshaling updated user: %w", err) } - defer rows.Close() - var accounts []common.AccountDB - for rows.Next() { - var account common.AccountDB - if err := rows.Scan(&account.AccountAddress, &account.Signature, &account.SignatureType); err != nil { - return nil, err - } - accounts = append(accounts, account) + stmt, err := s.db.Prepare("UPDATE users SET user_data = ? WHERE id = ?") + if err != nil { + return err } - if err := rows.Err(); err != nil { - return nil, err + defer stmt.Close() + + _, err = stmt.Exec(string(updatedUserJSON), string(userID)) + if err != nil { + return fmt.Errorf("failed to update user with new account: %w", err) } - return accounts, nil + return nil } -func (s *Database) GetAllUsers() ([]common.UserDB, error) { - rows, err := s.db.Query("SELECT user_id, private_key FROM users") +func (s *Database) GetUser(userID []byte) (common.GWUserDB, error) { + var userDataJSON string + err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON) if err != nil { - return nil, err - } - defer rows.Close() - - var users []common.UserDB - for rows.Next() { - var user common.UserDB - err = rows.Scan(&user.UserID, &user.PrivateKey) - if err != nil { - return nil, err + if err == sql.ErrNoRows { + return common.GWUserDB{}, fmt.Errorf("failed to get user: %w", errutil.ErrNotFound) } - users = append(users, user) + return common.GWUserDB{}, fmt.Errorf("failed to get user: %w", err) } - if err = rows.Err(); err != nil { - return nil, err + var user common.GWUserDB + err = json.Unmarshal([]byte(userDataJSON), &user) + if err != nil { + return common.GWUserDB{}, fmt.Errorf("failed to unmarshal user data: %w", err) } - return users, nil + return user, nil } func createOrLoad(dbPath string) (string, error) { @@ -203,32 +179,3 @@ func createOrLoad(dbPath string) (string, error) { return dbPath, nil } - -func (s *Database) StoreTransaction(rawTx string, userID []byte) error { - stmt, err := s.db.Prepare("INSERT INTO transactions(user_id, tx_hash, tx) VALUES (?, ?, ?)") - if err != nil { - return err - } - defer stmt.Close() - - txHash := "" - if len(rawTx) < 3 { - fmt.Println("Invalid rawTx: ", rawTx) - } else { - // Decode the hex string to bytes, excluding the '0x' prefix - rawTxBytes, err := hex.DecodeString(rawTx[2:]) - if err != nil { - fmt.Println("Error decoding rawTx: ", err) - } else { - // Compute Keccak-256 hash - txHash = crypto.Keccak256Hash(rawTxBytes).Hex() - } - } - - _, err = stmt.Exec(userID, txHash, rawTx) - if err != nil { - return err - } - - return nil -} diff --git a/tools/walletextension/storage/storage.go b/tools/walletextension/storage/storage.go index 78efcd1a39..cb1fb6304b 100644 --- a/tools/walletextension/storage/storage.go +++ b/tools/walletextension/storage/storage.go @@ -5,28 +5,25 @@ import ( "github.com/ten-protocol/go-ten/go/common/viewingkey" - "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/mariadb" - "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" - "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/cosmosdb" + "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/sqlite" ) type Storage interface { AddUser(userID []byte, privateKey []byte) error DeleteUser(userID []byte) error - GetUserPrivateKey(userID []byte) ([]byte, error) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error - GetAccounts(userID []byte) ([]common.AccountDB, error) - GetAllUsers() ([]common.UserDB, error) - StoreTransaction(rawTx string, userID []byte) error + GetUser(userID []byte) (common.GWUserDB, error) } -func New(dbType string, dbConnectionURL, dbPath string) (Storage, error) { +func New(dbType string, dbConnectionURL, dbPath string, randomKey []byte) (Storage, error) { switch dbType { - case "mariaDB": - return mariadb.NewMariaDB(dbConnectionURL) case "sqlite": return sqlite.NewSqliteDatabase(dbPath) + case "cosmosDB": + return cosmosdb.NewCosmosDB(dbConnectionURL, randomKey) } + return nil, fmt.Errorf("unknown db %s", dbType) } diff --git a/tools/walletextension/storage/storage_test.go b/tools/walletextension/storage/storage_test.go index 4bdc6e1cee..2053fc37c5 100644 --- a/tools/walletextension/storage/storage_test.go +++ b/tools/walletextension/storage/storage_test.go @@ -2,6 +2,7 @@ package storage import ( "bytes" + "crypto/rand" "errors" "testing" @@ -9,21 +10,24 @@ import ( "github.com/stretchr/testify/require" "github.com/ten-protocol/go-ten/go/common/errutil" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) var tests = map[string]func(storage Storage, t *testing.T){ - "testAddAndGetUser": testAddAndGetUser, - "testAddAndGetAccounts": testAddAndGetAccounts, - "testDeleteUser": testDeleteUser, - "testGetAllUsers": testGetAllUsers, - "testStoringNewTx": testStoringNewTx, + "testAddAndGetUser": testAddAndGetUser, + "testAddAccounts": testAddAccounts, + "testDeleteUser": testDeleteUser, + "testGetUser": testGetUser, } -func TestSQLiteGatewayDB(t *testing.T) { +func TestGatewayStorage(t *testing.T) { + randomKey, err := wecommon.GenerateRandomKey() + require.NoError(t, err) + for name, test := range tests { t.Run(name, func(t *testing.T) { - // storage, err := New("mariaDB", "obscurouser:password@tcp(127.0.0.1:3306)/ogdb", "") allows to run tests against a local instance of MariaDB - storage, err := New("sqlite", "", "") + storage, err := New("sqlite", "", "", randomKey) + // storage, err := New("cosmosDB", "", "", randomKey) require.NoError(t, err) test(storage, t) @@ -32,61 +36,88 @@ func TestSQLiteGatewayDB(t *testing.T) { } func testAddAndGetUser(storage Storage, t *testing.T) { - userID := []byte("userID") - privateKey := []byte("privateKey") + // Generate random user ID and private key + userID := make([]byte, 20) + _, err := rand.Read(userID) + if err != nil { + t.Fatal(err) + } + privateKey := make([]byte, 32) + _, err = rand.Read(privateKey) + if err != nil { + t.Fatal(err) + } - err := storage.AddUser(userID, privateKey) + // Add user to storage + err = storage.AddUser(userID, privateKey) if err != nil { t.Fatal(err) } - returnedPrivateKey, err := storage.GetUserPrivateKey(userID) + // Retrieve user's private key from storage + user, err := storage.GetUser(userID) if err != nil { t.Fatal(err) } - if !bytes.Equal(returnedPrivateKey, privateKey) { - t.Errorf("privateKey mismatch: got %v, want %v", returnedPrivateKey, privateKey) + // Check if retrieved private key matches the original + if !bytes.Equal(user.PrivateKey, privateKey) { + t.Errorf("privateKey mismatch: got %v, want %v", user.PrivateKey, privateKey) } } -func testAddAndGetAccounts(storage Storage, t *testing.T) { - userID := []byte("userID") - privateKey := []byte("privateKey") - accountAddress1 := []byte("accountAddress1") - signature1 := []byte("signature1") - +func testAddAccounts(storage Storage, t *testing.T) { + // Generate random user ID, private key, and account details + userID := make([]byte, 20) + rand.Read(userID) + privateKey := make([]byte, 32) + rand.Read(privateKey) + accountAddress1 := make([]byte, 20) + rand.Read(accountAddress1) + signature1 := make([]byte, 65) + rand.Read(signature1) + + // Add a new user to the storage err := storage.AddUser(userID, privateKey) if err != nil { t.Fatal(err) } + // Add the first account for the user err = storage.AddAccount(userID, accountAddress1, signature1, viewingkey.EIP712Signature) if err != nil { t.Fatal(err) } - accountAddress2 := []byte("accountAddress2") - signature2 := []byte("signature2") + // Generate details for a second account + accountAddress2 := make([]byte, 20) + rand.Read(accountAddress2) + signature2 := make([]byte, 65) + rand.Read(signature2) + // Add the second account for the user err = storage.AddAccount(userID, accountAddress2, signature2, viewingkey.EIP712Signature) if err != nil { t.Fatal(err) } - accounts, err := storage.GetAccounts(userID) + // Retrieve all accounts for the user + user, err := storage.GetUser(userID) if err != nil { t.Fatal(err) } - if len(accounts) != 2 { - t.Errorf("Expected 2 accounts, got %d", len(accounts)) + // Check if the correct number of accounts were retrieved + if len(user.Accounts) != 2 { + t.Errorf("Expected 2 accounts, got %d", len(user.Accounts)) } + // Flags to check if both accounts are found foundAccount1 := false foundAccount2 := false - for _, account := range accounts { + // Iterate through retrieved accounts and check if they match the added accounts + for _, account := range user.Accounts { if bytes.Equal(account.AccountAddress, accountAddress1) && bytes.Equal(account.Signature, signature1) { foundAccount1 = true } @@ -95,6 +126,7 @@ func testAddAndGetAccounts(storage Storage, t *testing.T) { } } + // Verify that both accounts were found if !foundAccount1 { t.Errorf("Account 1 was not found in the result") } @@ -105,55 +137,65 @@ func testAddAndGetAccounts(storage Storage, t *testing.T) { } func testDeleteUser(storage Storage, t *testing.T) { - userID := []byte("testDeleteUserID") - privateKey := []byte("testDeleteUserPrivateKey") + // Generate random user ID and private key + userID := make([]byte, 20) + rand.Read(userID) + privateKey := make([]byte, 32) + rand.Read(privateKey) + // Add user to storage err := storage.AddUser(userID, privateKey) if err != nil { t.Fatal(err) } + // Delete the user err = storage.DeleteUser(userID) if err != nil { t.Fatal(err) } - _, err = storage.GetUserPrivateKey(userID) + // Attempt to retrieve the deleted user's private key + // This should fail with a "not found" error + _, err = storage.GetUser(userID) if err == nil || !errors.Is(err, errutil.ErrNotFound) { - t.Fatal("Expected error when getting deleted user, but got none") + t.Fatal("Expected 'not found' error when getting deleted user, but got none or different error") } } -func testGetAllUsers(storage Storage, t *testing.T) { - initialUsers, err := storage.GetAllUsers() - if err != nil { - t.Fatal(err) - } - - userID := []byte("getAllUsersTestID") - privateKey := []byte("getAllUsersTestPrivateKey") +func testGetUser(storage Storage, t *testing.T) { + // Generate random user ID and private key + userID := make([]byte, 20) + rand.Read(userID) + privateKey := make([]byte, 32) + rand.Read(privateKey) - err = storage.AddUser(userID, privateKey) + // Add user to storage + err := storage.AddUser(userID, privateKey) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to add user: %v", err) } - afterInsertUsers, err := storage.GetAllUsers() + // Get user from storage + user, err := storage.GetUser(userID) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to get user: %v", err) } - if len(afterInsertUsers) != len(initialUsers)+1 { - t.Errorf("Expected user count to increase by 1. Got %d initially and %d after insert", len(initialUsers), len(afterInsertUsers)) + // Check if retrieved user matches the added user + if !bytes.Equal(user.UserId, userID) { + t.Errorf("Retrieved user ID does not match. Expected %x, got %x", userID, user.UserId) } -} -func testStoringNewTx(storage Storage, t *testing.T) { - userID := []byte("userID") - rawTransaction := "0x0123456789" + if !bytes.Equal(user.PrivateKey, privateKey) { + t.Errorf("Retrieved private key does not match. Expected %x, got %x", privateKey, user.PrivateKey) + } - err := storage.StoreTransaction(rawTransaction, userID) - if err != nil { - t.Fatal(err) + // Try to get a non-existent user + nonExistentUserID := make([]byte, 20) + rand.Read(nonExistentUserID) + _, err = storage.GetUser(nonExistentUserID) + if err == nil { + t.Error("Expected error when getting non-existent user, but got none") } } diff --git a/tools/walletextension/storage/storage_with_cache.go b/tools/walletextension/storage/storage_with_cache.go new file mode 100644 index 0000000000..7c83d75423 --- /dev/null +++ b/tools/walletextension/storage/storage_with_cache.go @@ -0,0 +1,90 @@ +package storage + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/viewingkey" + "github.com/ten-protocol/go-ten/tools/walletextension/cache" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +// StorageWithCache implements the Storage interface with caching +type StorageWithCache struct { + storage Storage + cache cache.Cache + mu sync.RWMutex +} + +// NewStorageWithCache creates a new StorageWithCache instance +func NewStorageWithCache(storage Storage, logger log.Logger) (*StorageWithCache, error) { + c, err := cache.NewCache(logger) + if err != nil { + return nil, err + } + return &StorageWithCache{ + storage: storage, + cache: c, + }, nil +} + +// AddUser adds a new user and invalidates the cache for the userID +func (s *StorageWithCache) AddUser(userID []byte, privateKey []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.storage.AddUser(userID, privateKey) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + +// DeleteUser deletes a user and invalidates the cache for the userID +func (s *StorageWithCache) DeleteUser(userID []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.storage.DeleteUser(userID) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + +// AddAccount adds an account to a user and invalidates the cache for the userID +func (s *StorageWithCache) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.storage.AddAccount(userID, accountAddress, signature, signatureType) + if err != nil { + return err + } + s.cache.Remove(userID) + return nil +} + +// GetUser retrieves a user from the cache or underlying storage +func (s *StorageWithCache) GetUser(userID []byte) (wecommon.GWUserDB, error) { + s.mu.RLock() + // Check if the user is in the cache + if cachedUser, found := s.cache.Get(userID); found { + s.mu.RUnlock() + return cachedUser.(wecommon.GWUserDB), nil + } + s.mu.RUnlock() + + // If not in cache, retrieve from storage + user, err := s.storage.GetUser(userID) + if err != nil { + return wecommon.GWUserDB{}, err + } + + // Store the retrieved user in the cache + s.mu.Lock() + s.cache.Set(userID, user, 5*time.Minute) + s.mu.Unlock() + + return user, nil +} diff --git a/tools/walletextension/walletextension_container.go b/tools/walletextension/walletextension_container.go index f6499ec6c7..a306e3a99a 100644 --- a/tools/walletextension/walletextension_container.go +++ b/tools/walletextension/walletextension_container.go @@ -32,8 +32,20 @@ func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Cont // create the account manager with a single unauthenticated connection hostRPCBindAddrWS := wecommon.WSProtocol + config.NodeRPCWebsocketAddress hostRPCBindAddrHTTP := wecommon.HTTPProtocol + config.NodeRPCHTTPAddress - // start the database - databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride) + + // Database encryption key handling + // TODO: Check if encryption key is already sealed and unseal it and generate new one if not (part of the next PR) + // TODO: We should have a mechanism to get the key from an enclave that already runs (part of the next PR) + // TODO: Move this to a separate file along with key exchange logic (part of the next PR) + + encryptionKey, err := wecommon.GenerateRandomKey() + if err != nil { + logger.Crit("unable to generate random encryption key", log.ErrKey, err) + os.Exit(1) + } + + // start the database with the encryption key + databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride, encryptionKey) if err != nil { logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) os.Exit(1)