diff --git a/go.mod b/go.mod index e1be7176b9..06516ba0ce 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.21 replace github.com/docker/docker => github.com/docker/docker v20.10.3-0.20220224222438-c78f6963a1c0+incompatible require ( + github.com/Microsoft/go-winio v0.6.1 github.com/andybalholm/brotli v1.1.0 + github.com/deckarep/golang-set/v2 v2.6.0 github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v25.0.4+incompatible github.com/docker/go-connections v0.5.0 @@ -17,14 +19,18 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/go-kit/kit v0.13.0 github.com/go-sql-driver/mysql v1.8.0 + github.com/gofrs/flock v0.8.1 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/holiman/uint256 v1.2.4 + github.com/jolestar/go-commons-pool/v2 v2.1.2 github.com/mattn/go-sqlite3 v1.14.22 github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 github.com/pkg/errors v0.9.1 + github.com/rs/cors v1.10.1 github.com/sanity-io/litter v1.5.5 github.com/status-im/keycard-go v0.3.2 github.com/stretchr/testify v1.9.0 @@ -43,7 +49,6 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/DataDog/zstd v1.5.5 // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect github.com/VictoriaMetrics/fastcache v1.12.2 // indirect github.com/allegro/bigcache v1.2.1 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -65,7 +70,6 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/crate-crypto/go-kzg-4844 v0.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/deckarep/golang-set/v2 v2.6.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/distribution/reference v0.5.0 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect @@ -86,9 +90,7 @@ require ( github.com/go-playground/validator/v10 v10.19.0 // indirect github.com/go-stack/stack v1.8.1 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/glog v1.2.0 // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -105,11 +107,8 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/mitchellh/pointerstructure v1.2.1 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -127,8 +126,6 @@ require ( github.com/prometheus/procfs v0.13.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - github.com/rs/cors v1.10.1 // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shirou/gopsutil v3.21.11+incompatible // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.2 // indirect diff --git a/go.sum b/go.sum index 4f122dede7..f1251d5bde 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/fjl/memsize v0.0.2 h1:27txuSD9or+NZlnOWdKUxeBzTAUkWCVh+4Gf2dWFOzA= github.com/fjl/memsize v0.0.2/go.mod h1:VvhXpOYNQvB+uIk2RvXzuaQtkQJzzIx6lSBe1xv7hi0= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -260,6 +262,8 @@ github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0Gqw github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/jolestar/go-commons-pool/v2 v2.1.2 h1:E+XGo58F23t7HtZiC/W6jzO2Ux2IccSH/yx4nD+J1CM= +github.com/jolestar/go-commons-pool/v2 v2.1.2/go.mod h1:r4NYccrkS5UqP1YQI1COyTZ9UjPJAAGTUxzcsK1kqhY= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -309,7 +313,6 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -322,7 +325,6 @@ github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/pointerstructure v1.2.1 h1:ZhBBeX8tSlRpu/FFhXH4RC4OJzFlqsQhoHZAz4x7TIw= @@ -398,6 +400,7 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99 github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -577,7 +580,6 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/go/common/log_events.go b/go/common/log_events.go index fd01b0357c..962762dbcf 100644 --- a/go/common/log_events.go +++ b/go/common/log_events.go @@ -1,7 +1,15 @@ package common import ( + "encoding/json" + "errors" + "fmt" + "math/big" + "strings" + + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ten-protocol/go-ten/go/common/viewingkey" @@ -14,7 +22,7 @@ type LogSubscription struct { ViewingKey *viewingkey.RPCSignedViewingKey // A subscriber-defined filter to apply to the stream of logs. - Filter *filters.FilterCriteria + Filter *FilterCriteriaJSON } // IDAndEncLog pairs an encrypted log with the ID of the subscription that generated it. @@ -36,6 +44,177 @@ type FilterCriteriaJSON struct { BlockHash *common.Hash `json:"blockHash"` FromBlock *rpc.BlockNumber `json:"fromBlock"` ToBlock *rpc.BlockNumber `json:"toBlock"` - Addresses interface{} `json:"address"` - Topics []interface{} `json:"topics"` + Addresses []common.Address `json:"addresses"` + Topics [][]common.Hash `json:"topics"` +} + +func FromCriteria(crit FilterCriteria) FilterCriteriaJSON { + var from *rpc.BlockNumber + if crit.FromBlock != nil { + f := (rpc.BlockNumber)(crit.FromBlock.Int64()) + from = &f + } + + var to *rpc.BlockNumber + if crit.ToBlock != nil { + t := (rpc.BlockNumber)(crit.ToBlock.Int64()) + to = &t + } + + return FilterCriteriaJSON{ + BlockHash: crit.BlockHash, + FromBlock: from, + ToBlock: to, + Addresses: crit.Addresses, + Topics: crit.Topics, + } +} + +func ToCriteria(jsonCriteria FilterCriteriaJSON) filters.FilterCriteria { + var from *big.Int + if jsonCriteria.FromBlock != nil { + from = big.NewInt(jsonCriteria.FromBlock.Int64()) + } + var to *big.Int + if jsonCriteria.ToBlock != nil { + to = big.NewInt(jsonCriteria.ToBlock.Int64()) + } + + return filters.FilterCriteria{ + BlockHash: jsonCriteria.BlockHash, + FromBlock: from, + ToBlock: to, + Addresses: jsonCriteria.Addresses, + Topics: jsonCriteria.Topics, + } +} + +var errInvalidTopic = errors.New("invalid topic(s)") + +// FilterCriteria represents a request to create a new filter. +// Same as ethereum.FilterQuery but with UnmarshalJSON() method. +// duplicated from geth to tweak the unmarshalling +type FilterCriteria ethereum.FilterQuery + +// UnmarshalJSON sets *args fields with given data. +func (args *FilterCriteria) UnmarshalJSON(data []byte) error { + type input struct { + BlockHash *common.Hash `json:"blockHash"` + FromBlock *rpc.BlockNumber `json:"fromBlock"` + ToBlock *rpc.BlockNumber `json:"toBlock"` + Addresses interface{} `json:"address"` + Topics []interface{} `json:"topics"` + } + + var raw input + if err := json.Unmarshal(data, &raw); err != nil { + // tweak to handle the case when an empty array is passed in by javascript libraries + if strings.Contains(err.Error(), "cannot unmarshal array") { + return nil + } + return err + } + + if raw.BlockHash != nil { + if raw.FromBlock != nil || raw.ToBlock != nil { + // BlockHash is mutually exclusive with FromBlock/ToBlock criteria + return errors.New("cannot specify both BlockHash and FromBlock/ToBlock, choose one or the other") + } + args.BlockHash = raw.BlockHash + } else { + if raw.FromBlock != nil { + args.FromBlock = big.NewInt(raw.FromBlock.Int64()) + } + + if raw.ToBlock != nil { + args.ToBlock = big.NewInt(raw.ToBlock.Int64()) + } + } + + args.Addresses = []common.Address{} + + if raw.Addresses != nil { + // raw.Address can contain a single address or an array of addresses + switch rawAddr := raw.Addresses.(type) { + case []interface{}: + for i, addr := range rawAddr { + if strAddr, ok := addr.(string); ok { + addr, err := decodeAddress(strAddr) + if err != nil { + return fmt.Errorf("invalid address at index %d: %v", i, err) + } + args.Addresses = append(args.Addresses, addr) + } else { + return fmt.Errorf("non-string address at index %d", i) + } + } + case string: + addr, err := decodeAddress(rawAddr) + if err != nil { + return fmt.Errorf("invalid address: %v", err) + } + args.Addresses = []common.Address{addr} + default: + return errors.New("invalid addresses in query") + } + } + + // topics is an array consisting of strings and/or arrays of strings. + // JSON null values are converted to common.Hash{} and ignored by the filter manager. + if len(raw.Topics) > 0 { + args.Topics = make([][]common.Hash, len(raw.Topics)) + for i, t := range raw.Topics { + switch topic := t.(type) { + case nil: + // ignore topic when matching logs + + case string: + // match specific topic + top, err := decodeTopic(topic) + if err != nil { + return err + } + args.Topics[i] = []common.Hash{top} + + case []interface{}: + // or case e.g. [null, "topic0", "topic1"] + for _, rawTopic := range topic { + if rawTopic == nil { + // null component, match all + args.Topics[i] = nil + break + } + if topic, ok := rawTopic.(string); ok { + parsed, err := decodeTopic(topic) + if err != nil { + return err + } + args.Topics[i] = append(args.Topics[i], parsed) + } else { + return errInvalidTopic + } + } + default: + return errInvalidTopic + } + } + } + + return nil +} + +func decodeAddress(s string) (common.Address, error) { + b, err := hexutil.Decode(s) + if err == nil && len(b) != common.AddressLength { + err = fmt.Errorf("hex has invalid length %d after decoding; expected %d for address", len(b), common.AddressLength) + } + return common.BytesToAddress(b), err +} + +func decodeTopic(s string) (common.Hash, error) { + b, err := hexutil.Decode(s) + if err == nil && len(b) != common.HashLength { + err = fmt.Errorf("hex has invalid length %d after decoding; expected %d for topic", len(b), common.HashLength) + } + return common.BytesToHash(b), err } diff --git a/go/common/viewingkey/viewing_key.go b/go/common/viewingkey/viewing_key.go index 6d70a59977..93ad049999 100644 --- a/go/common/viewingkey/viewing_key.go +++ b/go/common/viewingkey/viewing_key.go @@ -42,7 +42,7 @@ func GenerateViewingKeyForWallet(wal wallet.Wallet) (*ViewingKey, error) { if err != nil { return nil, err } - encryptionToken := CalculateUserIDHex(crypto.CompressPubkey(viewingPrivateKeyECIES.PublicKey.ExportECDSA())) + encryptionToken := CalculateUserID(crypto.CompressPubkey(viewingPrivateKeyECIES.PublicKey.ExportECDSA())) messageToSign, err := GenerateMessage(encryptionToken, chainID, PersonalSignVersion, messageType) if err != nil { return nil, fmt.Errorf("failed to generate message for viewing key: %w", err) diff --git a/go/common/viewingkey/viewing_key_messages.go b/go/common/viewingkey/viewing_key_messages.go index 01ff4829e5..a0efcbc362 100644 --- a/go/common/viewingkey/viewing_key_messages.go +++ b/go/common/viewingkey/viewing_key_messages.go @@ -6,6 +6,8 @@ import ( "fmt" "math/big" + "github.com/status-im/keycard-go/hexutils" + "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/crypto" @@ -29,18 +31,13 @@ const ( EIP712EncryptionToken = "Encryption Token" EIP712DomainNameValue = "Ten" EIP712DomainVersionValue = "1.0" - UserIDHexLength = 40 + UserIDLength = 20 PersonalSignMessageFormat = "Token: %s on chain: %d version: %d" PersonalSignVersion = 1 ) -// EIP712EncryptionTokens is a list of all possible options for Encryption token name -var EIP712EncryptionTokens = [...]string{ - EIP712EncryptionToken, -} - type MessageGenerator interface { - generateMessage(encryptionToken string, chainID int64, version int) ([]byte, error) + generateMessage(encryptionToken []byte, chainID int64, version int) ([]byte, error) } type ( @@ -54,13 +51,13 @@ var messageGenerators = map[SignatureType]MessageGenerator{ } // GenerateMessage generates a message for the given encryptionToken, chainID, version and signatureType -func (p PersonalMessageGenerator) generateMessage(encryptionToken string, chainID int64, version int) ([]byte, error) { - return []byte(fmt.Sprintf(PersonalSignMessageFormat, encryptionToken, chainID, version)), nil +func (p PersonalMessageGenerator) generateMessage(encryptionToken []byte, chainID int64, version int) ([]byte, error) { + return []byte(fmt.Sprintf(PersonalSignMessageFormat, hexutils.BytesToHex(encryptionToken), chainID, version)), nil } -func (e EIP712MessageGenerator) generateMessage(encryptionToken string, chainID int64, _ int) ([]byte, error) { - if len(encryptionToken) != UserIDHexLength { - return nil, fmt.Errorf("userID hex length must be %d, received %d", UserIDHexLength, len(encryptionToken)) +func (e EIP712MessageGenerator) generateMessage(encryptionToken []byte, chainID int64, _ int) ([]byte, error) { + if len(encryptionToken) != UserIDLength { + return nil, fmt.Errorf("userID must be %d bytes, received %d", UserIDLength, len(encryptionToken)) } EIP712TypedData := createTypedDataForEIP712Message(encryptionToken, chainID) @@ -73,7 +70,7 @@ func (e EIP712MessageGenerator) generateMessage(encryptionToken string, chainID } // GenerateMessage generates a message for the given encryptionToken, chainID, version and signatureType -func GenerateMessage(encryptionToken string, chainID int64, version int, signatureType SignatureType) ([]byte, error) { +func GenerateMessage(encryptionToken []byte, chainID int64, version int, signatureType SignatureType) ([]byte, error) { generator, exists := messageGenerators[signatureType] if !exists { return nil, fmt.Errorf("unsupported signature type") @@ -144,8 +141,8 @@ func getBytesFromTypedData(typedData apitypes.TypedData) ([]byte, error) { } // createTypedDataForEIP712Message creates typed data for EIP712 message -func createTypedDataForEIP712Message(encryptionToken string, chainID int64) apitypes.TypedData { - encryptionToken = "0x" + encryptionToken +func createTypedDataForEIP712Message(encryptionToken []byte, chainID int64) apitypes.TypedData { + hexToken := hexutils.BytesToHex(encryptionToken) domain := apitypes.TypedDataDomain{ Name: EIP712DomainNameValue, @@ -154,7 +151,7 @@ func createTypedDataForEIP712Message(encryptionToken string, chainID int64) apit } message := map[string]interface{}{ - EIP712EncryptionToken: encryptionToken, + EIP712EncryptionToken: hexToken, } types := apitypes.Types{ diff --git a/go/common/viewingkey/viewing_key_signature.go b/go/common/viewingkey/viewing_key_signature.go index 6737b5c574..38eb822f95 100644 --- a/go/common/viewingkey/viewing_key_signature.go +++ b/go/common/viewingkey/viewing_key_signature.go @@ -13,7 +13,7 @@ import ( // SignatureChecker is an interface for checking // if signature is valid for provided encryptionToken and chainID and return singing address or nil if not valid type SignatureChecker interface { - CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) + CheckSignature(encryptionToken []byte, signature []byte, chainID int64) (*gethcommon.Address, error) } type ( @@ -22,7 +22,7 @@ type ( ) // CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid -func (psc PersonalSignChecker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { +func (psc PersonalSignChecker) CheckSignature(encryptionToken []byte, signature []byte, chainID int64) (*gethcommon.Address, error) { if len(signature) != 65 { return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) } @@ -51,7 +51,7 @@ func (psc PersonalSignChecker) CheckSignature(encryptionToken string, signature return nil, fmt.Errorf("signature verification failed") } -func (e EIP712Checker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { +func (e EIP712Checker) CheckSignature(encryptionToken []byte, signature []byte, chainID int64) (*gethcommon.Address, error) { if len(signature) != 65 { return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) } @@ -88,7 +88,7 @@ var signatureCheckers = map[SignatureType]SignatureChecker{ } // CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid -func CheckSignature(encryptionToken string, signature []byte, chainID int64, signatureType SignatureType) (*gethcommon.Address, error) { +func CheckSignature(encryptionToken []byte, signature []byte, chainID int64, signatureType SignatureType) (*gethcommon.Address, error) { checker, exists := signatureCheckers[signatureType] if !exists { return nil, fmt.Errorf("unsupported signature type") diff --git a/go/enclave/events/subscription_manager.go b/go/enclave/events/subscription_manager.go index fcaeaaeb2e..b49ff1f84e 100644 --- a/go/enclave/events/subscription_manager.go +++ b/go/enclave/events/subscription_manager.go @@ -3,7 +3,6 @@ package events import ( "encoding/json" "fmt" - "math/big" "sync" "github.com/ten-protocol/go-ten/go/enclave/vkhandler" @@ -19,7 +18,6 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" gethlog "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rlp" "github.com/ten-protocol/go-ten/go/common" ) @@ -61,7 +59,7 @@ func NewSubscriptionManager(storage storage.Storage, chainID int64, logger gethl // correctly. If there is an existing subscription with the given ID, it is overwritten. func (s *SubscriptionManager) AddSubscription(id gethrpc.ID, encodedSubscription []byte) error { subscription := &common.LogSubscription{} - if err := rlp.DecodeBytes(encodedSubscription, subscription); err != nil { + if err := json.Unmarshal(encodedSubscription, subscription); err != nil { return fmt.Errorf("could not decocde log subscription from RLP. Cause: %w", err) } @@ -238,15 +236,15 @@ func getUserAddrsFromLogTopics(log *types.Log, db *state.StateDB) []*gethcommon. // Lifted from eth/filters/filter.go in the go-ethereum repository. // filterLogs creates a slice of logs matching the given criteria. -func filterLogs(logs []*types.Log, fromBlock, toBlock *big.Int, addresses []gethcommon.Address, topics [][]gethcommon.Hash, logger gethlog.Logger) []*types.Log { //nolint:gocognit +func filterLogs(logs []*types.Log, fromBlock, toBlock *gethrpc.BlockNumber, addresses []gethcommon.Address, topics [][]gethcommon.Hash, logger gethlog.Logger) []*types.Log { //nolint:gocognit var ret []*types.Log Logs: for _, logItem := range logs { - if fromBlock != nil && fromBlock.Int64() >= 0 && fromBlock.Uint64() > logItem.BlockNumber { + if fromBlock != nil && fromBlock.Int64() >= 0 && fromBlock.Int64() > int64(logItem.BlockNumber) { logger.Debug("Skipping log ", "log", logItem, "reason", "In the past. The starting block num for filter is bigger than log") continue } - if toBlock != nil && toBlock.Int64() > 0 && toBlock.Uint64() < logItem.BlockNumber { + if toBlock != nil && toBlock.Int64() > 0 && toBlock.Int64() < int64(logItem.BlockNumber) { logger.Debug("Skipping log ", "log", logItem, "reason", "In the future. The ending block num for filter is smaller than log") continue } diff --git a/go/enclave/rpc/GetBalance.go b/go/enclave/rpc/GetBalance.go index de9b6a4f42..8bdc26fc14 100644 --- a/go/enclave/rpc/GetBalance.go +++ b/go/enclave/rpc/GetBalance.go @@ -3,6 +3,8 @@ package rpc import ( "fmt" + "github.com/status-im/keycard-go/hexutils" + "github.com/ethereum/go-ethereum/common" "github.com/ten-protocol/go-ten/lib/gethfork/rpc" @@ -48,7 +50,7 @@ func GetBalanceExecute(builder *CallBuilder[BalanceReq, hexutil.Big], rpc *Encry // authorise the call if acctOwner.Hex() != builder.VK.AccountAddress.Hex() { - rpc.logger.Debug("Unauthorised call", "address", acctOwner, "vk", builder.VK.AccountAddress, "userId", builder.VK.UserID) + rpc.logger.Debug("Unauthorised call", "address", acctOwner, "vk", builder.VK.AccountAddress, "userId", hexutils.BytesToHex(builder.VK.UserID)) builder.Status = NotAuthorised return nil } diff --git a/go/enclave/rpc/GetLogs.go b/go/enclave/rpc/GetLogs.go index 9ac1d2acbc..b4798600e2 100644 --- a/go/enclave/rpc/GetLogs.go +++ b/go/enclave/rpc/GetLogs.go @@ -5,39 +5,40 @@ import ( "errors" "fmt" + "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ethereum/go-ethereum/core/types" - gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ten-protocol/go-ten/go/common/syserr" ) func GetLogsValidate(reqParams []any, builder *CallBuilder[filters.FilterCriteria, []*types.Log], _ *EncryptionManager) error { - // Parameters are [Filter, Address] - if len(reqParams) != 2 { + // Parameters are [Filter] + if len(reqParams) != 1 { builder.Err = fmt.Errorf("unexpected number of parameters") return nil } - // We extract the arguments from the param bytes. - filter, forAddress, err := extractGetLogsParams(reqParams) + + serialised, err := json.Marshal(reqParams[0]) if err != nil { - builder.Err = err - return nil //nolint:nilerr + builder.Err = fmt.Errorf("invalid parameter %w", err) + return nil + } + var crit common.FilterCriteriaJSON + err = json.Unmarshal(serialised, &crit) + if err != nil { + builder.Err = fmt.Errorf("invalid parameter %w", err) + return nil } - builder.From = forAddress - builder.Param = filter + filter := common.ToCriteria(crit) + + builder.Param = &filter return nil } func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], rpc *EncryptionManager) error { //nolint:gocognit - err := authenticateFrom(builder.VK, builder.From) - if err != nil { - builder.Err = err - return nil //nolint:nilerr - } - filter := builder.Param // todo logic to check that the filter is valid // can't have both from and blockhash @@ -83,7 +84,7 @@ func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], } // We retrieve the relevant logs that match the filter. - filteredLogs, err := rpc.storage.FilterLogs(builder.From, from, to, nil, filter.Addresses, filter.Topics) + filteredLogs, err := rpc.storage.FilterLogs(builder.VK.AccountAddress, from, to, nil, filter.Addresses, filter.Topics) if err != nil { if errors.Is(err, syserr.InternalError{}) { return err @@ -95,28 +96,3 @@ func GetLogsExecute(builder *CallBuilder[filters.FilterCriteria, []*types.Log], builder.ReturnValue = &filteredLogs return nil } - -// Returns the params extracted from an eth_getLogs request. -func extractGetLogsParams(paramList []interface{}) (*filters.FilterCriteria, *gethcommon.Address, error) { - // We extract the first param, the filter for the logs. - // We marshal the filter criteria from a map to JSON, then back from JSON into a FilterCriteria. This is - // because the filter criteria arrives as a map, and there is no way to convert it to a map directly into a - // FilterCriteria. - filterJSON, err := json.Marshal(paramList[0]) - if err != nil { - return nil, nil, fmt.Errorf("could not marshal filter criteria to JSON. Cause: %w", err) - } - filter := filters.FilterCriteria{} - err = filter.UnmarshalJSON(filterJSON) - if err != nil { - return nil, nil, fmt.Errorf("could not unmarshal filter criteria from JSON. Cause: %w", err) - } - - // We extract the second param, the address the logs are for. - forAddressHex, ok := paramList[1].(string) - if !ok { - return nil, nil, fmt.Errorf("expected second argument in GetLogs request to be of type string, but got %T", paramList[0]) - } - forAddress := gethcommon.HexToAddress(forAddressHex) - return &filter, &forAddress, nil -} diff --git a/go/enclave/vkhandler/vk_handler.go b/go/enclave/vkhandler/vk_handler.go index 06d6d27335..e1ec0bf4f8 100644 --- a/go/enclave/vkhandler/vk_handler.go +++ b/go/enclave/vkhandler/vk_handler.go @@ -1,4 +1,4 @@ -package vkhandler +package vkhandler //nolint:typecheck import ( "crypto/rand" @@ -21,7 +21,7 @@ type AuthenticatedViewingKey struct { rpcVK *viewingkey.RPCSignedViewingKey AccountAddress *gethcommon.Address ecdsaKey *ecies.PublicKey - UserID string + UserID []byte } func VerifyViewingKey(rpcVK *viewingkey.RPCSignedViewingKey, chainID int64) (*AuthenticatedViewingKey, error) { @@ -48,7 +48,7 @@ func VerifyViewingKey(rpcVK *viewingkey.RPCSignedViewingKey, chainID int64) (*Au // checkViewingKeyAndRecoverAddress checks the signature and recovers the address from the viewing key func checkViewingKeyAndRecoverAddress(vk *AuthenticatedViewingKey, chainID int64) (*gethcommon.Address, error) { // get userID from viewingKey public key - userID := viewingkey.CalculateUserIDHex(vk.rpcVK.PublicKey) + userID := viewingkey.CalculateUserID(vk.rpcVK.PublicKey) vk.UserID = userID // check the signature and recover the address assuming the message was signed with EIP712 diff --git a/go/enclave/vkhandler/vk_handler_test.go b/go/enclave/vkhandler/vk_handler_test.go index e8065ef2e8..1e4704a0f9 100644 --- a/go/enclave/vkhandler/vk_handler_test.go +++ b/go/enclave/vkhandler/vk_handler_test.go @@ -18,14 +18,14 @@ const chainID = 443 // generateRandomUserKeys - // generates a random user private key and a random viewing key private key and returns the user private key, // the viewing key private key, the userID and the user address -func generateRandomUserKeys() (*ecdsa.PrivateKey, *ecdsa.PrivateKey, string, gethcommon.Address) { +func generateRandomUserKeys() (*ecdsa.PrivateKey, *ecdsa.PrivateKey, []byte, gethcommon.Address) { userPrivKey, err := crypto.GenerateKey() // user private key if err != nil { - return nil, nil, "", gethcommon.Address{} + return nil, nil, nil, gethcommon.Address{} } vkPrivKey, _ := crypto.GenerateKey() // viewingkey generated in the gateway if err != nil { - return nil, nil, "", gethcommon.Address{} + return nil, nil, nil, gethcommon.Address{} } // get the address from userPrivKey @@ -33,7 +33,7 @@ func generateRandomUserKeys() (*ecdsa.PrivateKey, *ecdsa.PrivateKey, string, get // get userID from viewingKey public key vkPubKeyBytes := crypto.CompressPubkey(ecies.ImportECDSAPublic(&vkPrivKey.PublicKey).ExportECDSA()) - userID := viewingkey.CalculateUserIDHex(vkPubKeyBytes) + userID := viewingkey.CalculateUserID(vkPubKeyBytes) return userPrivKey, vkPrivKey, userID, userAddress } diff --git a/go/obsclient/authclient.go b/go/obsclient/authclient.go index f417e3443f..6f4eec0495 100644 --- a/go/obsclient/authclient.go +++ b/go/obsclient/authclient.go @@ -9,7 +9,6 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/params" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/viewingkey" @@ -192,13 +191,13 @@ func (ac *AuthObsClient) BalanceAt(ctx context.Context, blockNumber *big.Int) (* return (*big.Int)(&result), err } -func (ac *AuthObsClient) SubscribeFilterLogs(ctx context.Context, filterCriteria filters.FilterCriteria, ch chan common.IDAndLog) (ethereum.Subscription, error) { +func (ac *AuthObsClient) SubscribeFilterLogs(ctx context.Context, filterCriteria common.FilterCriteria, ch chan common.IDAndLog) (ethereum.Subscription, error) { return ac.rpcClient.Subscribe(ctx, nil, rpc.SubscribeNamespace, ch, rpc.SubscriptionTypeLogs, filterCriteria) } -func (ac *AuthObsClient) GetLogs(ctx context.Context, filterCriteria common.FilterCriteriaJSON) ([]*types.Log, error) { +func (ac *AuthObsClient) GetLogs(ctx context.Context, filterCriteria common.FilterCriteria) ([]*types.Log, error) { var result responses.LogsType - err := ac.rpcClient.CallContext(ctx, &result, rpc.GetLogs, filterCriteria, ac.account) + err := ac.rpcClient.CallContext(ctx, &result, rpc.GetLogs, filterCriteria) if err != nil { return nil, err } diff --git a/go/rpc/client.go b/go/rpc/client.go index cb2297ac20..78e52fb28a 100644 --- a/go/rpc/client.go +++ b/go/rpc/client.go @@ -33,8 +33,6 @@ const ( GetTotalTxs = "tenscan_getTotalTransactions" Attestation = "tenscan_attestation" StopHost = "test_stopHost" - Subscribe = "eth_subscribe" - Unsubscribe = "eth_unsubscribe" SubscribeNamespace = "eth" SubscriptionTypeLogs = "logs" diff --git a/go/rpc/encrypted_client.go b/go/rpc/encrypted_client.go index 24050814f2..8b0d8635e8 100644 --- a/go/rpc/encrypted_client.go +++ b/go/rpc/encrypted_client.go @@ -12,8 +12,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/eth/filters" - "github.com/ethereum/go-ethereum/rlp" "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/go/common/log" @@ -70,6 +68,10 @@ func NewEncRPCClient(client Client, viewingKey *viewingkey.ViewingKey, logger ge return encClient, nil } +func (c *EncRPCClient) Client() Client { + return c.obscuroClient +} + // Call handles JSON rpc requests without a context - see CallContext for details func (c *EncRPCClient) Call(result interface{}, method string, args ...interface{}) error { return c.CallContext(nil, result, method, args...) //nolint:staticcheck @@ -108,8 +110,7 @@ func (c *EncRPCClient) Subscribe(ctx context.Context, _ interface{}, namespace s return nil, err } - // We use RLP instead of JSON marshaling here, as for some reason the filter criteria doesn't unmarshal correctly from JSON. - encodedLogSubscription, err := rlp.EncodeToBytes(logSubscription) + encodedLogSubscription, err := json.Marshal(logSubscription) if err != nil { return nil, err } @@ -181,20 +182,16 @@ func (c *EncRPCClient) createAuthenticatedLogSubscription(args []interface{}) (* // If there are less than two arguments, it means no filter criteria was passed. if len(args) < 2 { - logSubscription.Filter = &filters.FilterCriteria{} + logSubscription.Filter = &common.FilterCriteriaJSON{} return logSubscription, nil } - filterCriteria, ok := args[1].(filters.FilterCriteria) + filterCriteria, ok := args[1].(common.FilterCriteria) if !ok { return nil, fmt.Errorf("invalid subscription") } - // If we do not override a nil block hash to an empty one, RLP decoding will fail on the enclave side. - if filterCriteria.BlockHash == nil { - filterCriteria.BlockHash = &gethcommon.Hash{} - } - - logSubscription.Filter = &filterCriteria + fc := common.FromCriteria(filterCriteria) + logSubscription.Filter = &fc return logSubscription, nil } @@ -243,13 +240,13 @@ func (c *EncRPCClient) executeSensitiveCall(ctx context.Context, result interfac // EstimateGas and Call methods return EVM Errors that are json objects // and contain multiple keys that normally do not get serialized if method == EstimateGas || method == Call { - var result errutil.EVMSerialisableError - err = json.Unmarshal([]byte(decodedError.Error()), &result) + var evmErr errutil.EVMSerialisableError + err = json.Unmarshal([]byte(decodedError.Error()), &evmErr) if err != nil { - return err + return decodedError } // Return the evm user error. - return result + return evmErr } // Return the user error. diff --git a/go/rpc/network_client.go b/go/rpc/network_client.go index 22c9c4e64d..aa08c4274c 100644 --- a/go/rpc/network_client.go +++ b/go/rpc/network_client.go @@ -7,6 +7,7 @@ import ( "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" gethlog "github.com/ethereum/go-ethereum/log" ) @@ -16,9 +17,9 @@ const ( http = "http://" ) -// networkClient is a Client implementation that wraps Geth's rpc.Client to make calls to the obscuro node -type networkClient struct { - rpcClient *rpc.Client +// NetworkClient is a Client implementation that wraps Geth's rpc.Client to make calls to the obscuro node +type NetworkClient struct { + RpcClient *rpc.Client } // NewEncNetworkClient returns a network RPC client with Viewing Key encryption/decryption @@ -34,6 +35,14 @@ func NewEncNetworkClient(rpcAddress string, viewingKey *viewingkey.ViewingKey, l return encClient, nil } +func NewEncNetworkClientFromConn(connection *gethrpc.Client, viewingKey *viewingkey.ViewingKey, logger gethlog.Logger) (*EncRPCClient, error) { + encClient, err := NewEncRPCClient(&NetworkClient{RpcClient: connection}, viewingKey, logger) + if err != nil { + return nil, err + } + return encClient, nil +} + // NewNetworkClient returns a client that can make RPC calls to an Obscuro node func NewNetworkClient(address string) (Client, error) { if !strings.HasPrefix(address, http) && !strings.HasPrefix(address, ws) { @@ -45,25 +54,25 @@ func NewNetworkClient(address string) (Client, error) { return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", address, err) } - return &networkClient{ - rpcClient: rpcClient, + return &NetworkClient{ + RpcClient: rpcClient, }, nil } // Call handles JSON rpc requests, delegating to the geth RPC client // The result must be a pointer so that package json can unmarshal into it. You can also pass nil, in which case the result is ignored. -func (c *networkClient) Call(result interface{}, method string, args ...interface{}) error { - return c.rpcClient.Call(result, method, args...) +func (c *NetworkClient) Call(result interface{}, method string, args ...interface{}) error { + return c.RpcClient.Call(result, method, args...) } -func (c *networkClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { - return c.rpcClient.CallContext(ctx, result, method, args...) +func (c *NetworkClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + return c.RpcClient.CallContext(ctx, result, method, args...) } -func (c *networkClient) Subscribe(ctx context.Context, _ interface{}, namespace string, channel interface{}, args ...interface{}) (*rpc.ClientSubscription, error) { - return c.rpcClient.Subscribe(ctx, namespace, channel, args...) +func (c *NetworkClient) Subscribe(ctx context.Context, _ interface{}, namespace string, channel interface{}, args ...interface{}) (*gethrpc.ClientSubscription, error) { + return c.RpcClient.Subscribe(ctx, namespace, channel, args...) } -func (c *networkClient) Stop() { - c.rpcClient.Close() +func (c *NetworkClient) Stop() { + c.RpcClient.Close() } diff --git a/integration/constants.go b/integration/constants.go index 48b3bcfa29..e6c73d8224 100644 --- a/integration/constants.go +++ b/integration/constants.go @@ -12,11 +12,11 @@ const ( StartPortNetworkTests = 17000 StartPortSmartContractTests = 18000 StartPortContractDeployerTest1 = 19000 - StartPortContractDeployerTest2 = 20000 - StartPortWalletExtensionUnitTest = 21000 + StartPortContractDeployerTest2 = 21000 StartPortFaucetUnitTest = 22000 StartPortFaucetHTTPUnitTest = 23000 StartPortTenGatewayUnitTest = 24000 + StartPortWalletExtensionUnitTest = 25000 DefaultGethWSPortOffset = 100 DefaultGethAUTHPortOffset = 200 diff --git a/integration/networktest/env/network_setup.go b/integration/networktest/env/network_setup.go index 8f0b4df2ca..86d45bb8e7 100644 --- a/integration/networktest/env/network_setup.go +++ b/integration/networktest/env/network_setup.go @@ -8,8 +8,8 @@ import ( "github.com/ten-protocol/go-ten/integration" "github.com/ten-protocol/go-ten/integration/common/testlog" "github.com/ten-protocol/go-ten/integration/networktest" - gatewaycfg "github.com/ten-protocol/go-ten/tools/walletextension/config" - "github.com/ten-protocol/go-ten/tools/walletextension/container" + "github.com/ten-protocol/go-ten/tools/walletextension" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) const ( @@ -68,7 +68,7 @@ type TestnetEnvOption func(env *testnetEnv) type testnetEnv struct { testnetConnector *testnetConnector localTenGateway bool - tenGatewayContainer *container.WalletExtensionContainer + tenGatewayContainer *walletextension.Container logger gethlog.Logger } @@ -99,7 +99,7 @@ func (t *testnetEnv) startTenGateway() { validatorHTTP := validator[len("http://"):] // replace the last character with a 1 (expect it to be zero), this is good enough for these tests validatorWS := validatorHTTP[:len(validatorHTTP)-1] + "1" - cfg := gatewaycfg.Config{ + cfg := wecommon.Config{ WalletExtensionHost: "127.0.0.1", WalletExtensionPortHTTP: _gwHTTPPort, WalletExtensionPortWS: _gwWSPort, @@ -110,7 +110,7 @@ func (t *testnetEnv) startTenGateway() { DBType: "sqlite", TenChainID: integration.TenChainID, } - tenGWContainer := container.NewWalletExtensionContainerFromConfig(cfg, t.logger) + tenGWContainer := walletextension.NewContainerFromConfig(cfg, t.logger) go func() { fmt.Println("Starting Ten Gateway, HTTP Port:", _gwHTTPPort, "WS Port:", _gwWSPort) err := tenGWContainer.Start() diff --git a/integration/obscurogateway/errors_contract.go b/integration/obscurogateway/errors_contract.go index 105640d9ae..67e2c076ad 100644 --- a/integration/obscurogateway/errors_contract.go +++ b/integration/obscurogateway/errors_contract.go @@ -1,4 +1,4 @@ -package faucet +package obscurogateway import ( "strings" diff --git a/integration/obscurogateway/events_contract.go b/integration/obscurogateway/events_contract.go index 859472ca6d..2815f56f2b 100644 --- a/integration/obscurogateway/events_contract.go +++ b/integration/obscurogateway/events_contract.go @@ -1,4 +1,4 @@ -package faucet +package obscurogateway import ( "strings" diff --git a/integration/obscurogateway/gateway_user.go b/integration/obscurogateway/gateway_user.go index 3f425d351c..b37445c504 100644 --- a/integration/obscurogateway/gateway_user.go +++ b/integration/obscurogateway/gateway_user.go @@ -1,4 +1,4 @@ -package faucet +package obscurogateway import ( "context" @@ -22,7 +22,7 @@ type GatewayUser struct { tgClient *lib.TGLib } -func NewUser(wallets []wallet.Wallet, serverAddressHTTP string, serverAddressWS string) (*GatewayUser, error) { +func NewGatewayUser(wallets []wallet.Wallet, serverAddressHTTP string, serverAddressWS string) (*GatewayUser, error) { ogClient := lib.NewTenGatewayLibrary(serverAddressHTTP, serverAddressWS) // automatically join diff --git a/tools/walletextension/common/json.go b/integration/obscurogateway/json.go similarity index 96% rename from tools/walletextension/common/json.go rename to integration/obscurogateway/json.go index 2079aef3bc..41b8629f0c 100644 --- a/tools/walletextension/common/json.go +++ b/integration/obscurogateway/json.go @@ -1,4 +1,4 @@ -package common +package obscurogateway import ( "encoding/json" diff --git a/integration/obscurogateway/tengateway_test.go b/integration/obscurogateway/tengateway_test.go index 1dd7fc7d63..c8846dedfc 100644 --- a/integration/obscurogateway/tengateway_test.go +++ b/integration/obscurogateway/tengateway_test.go @@ -1,4 +1,4 @@ -package faucet +package obscurogateway import ( "bytes" @@ -12,8 +12,12 @@ import ( "testing" "time" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + + "github.com/ten-protocol/go-ten/tools/walletextension" + "github.com/go-kit/kit/transport/http/jsonrpc" - "github.com/ten-protocol/go-ten/go/rpc" + tenrpc "github.com/ten-protocol/go-ten/go/rpc" log2 "github.com/ten-protocol/go-ten/go/common/log" @@ -38,8 +42,6 @@ import ( "github.com/ten-protocol/go-ten/integration/ethereummock" "github.com/ten-protocol/go-ten/integration/simulation/network" "github.com/ten-protocol/go-ten/integration/simulation/params" - "github.com/ten-protocol/go-ten/tools/walletextension/config" - "github.com/ten-protocol/go-ten/tools/walletextension/container" "github.com/ten-protocol/go-ten/tools/walletextension/lib" "github.com/valyala/fasthttp" ) @@ -61,7 +63,7 @@ func TestTenGateway(t *testing.T) { startPort := integration.StartPortTenGatewayUnitTest createTenNetwork(t, startPort) - tenGatewayConf := config.Config{ + tenGatewayConf := wecommon.Config{ WalletExtensionHost: "127.0.0.1", WalletExtensionPortHTTP: startPort + integration.DefaultTenGatewayHTTPPortOffset, WalletExtensionPortWS: startPort + integration.DefaultTenGatewayWSPortOffset, @@ -74,7 +76,7 @@ func TestTenGateway(t *testing.T) { StoreIncomingTxs: true, } - tenGwContainer := container.NewWalletExtensionContainerFromConfig(tenGatewayConf, testlog.Logger()) + tenGwContainer := walletextension.NewContainerFromConfig(tenGatewayConf, testlog.Logger()) go func() { err := tenGwContainer.Start() if err != nil { @@ -122,15 +124,18 @@ func TestTenGateway(t *testing.T) { } func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { - user0, err := NewUser([]wallet.Wallet{w}, httpURL, wsURL) + user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) testlog.Logger().Info("Created user with encryption token", "t", user0.tgClient.UserID()) - user1, err := NewUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + _, err = user0.HTTPClient.ChainID(context.Background()) + require.NoError(t, err) + + user1, err := NewGatewayUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) testlog.Logger().Info("Created user with encryption token", "t", user1.tgClient.UserID()) - user2, err := NewUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + user2, err := NewGatewayUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) testlog.Logger().Info("Created user with encryption token", "t", user2.tgClient.UserID()) @@ -154,12 +159,7 @@ func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wal require.NoError(t, err) // Print balances of all registered accounts to check if all accounts have funds - balances, err := user0.GetUserAccountsBalances() - require.NoError(t, err) - for _, balance := range balances { - require.NotZero(t, balance.Uint64()) - } - balances, err = user1.GetUserAccountsBalances() + balances, err := user1.GetUserAccountsBalances() require.NoError(t, err) for _, balance := range balances { require.NotZero(t, balance.Uint64()) @@ -190,6 +190,9 @@ func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wal contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user0.HTTPClient, signedTx.Hash(), time.Minute) require.NoError(t, err) + _, err = user0.HTTPClient.CodeAt(context.Background(), contractReceipt.ContractAddress, big.NewInt(int64(rpc.LatestBlockNumber))) + require.NoError(t, err) + // check if value was changed in the smart contract with the interactions above pack, _ := eventsContractABI.Pack("message2") result, err := user1.HTTPClient.CallContract(context.Background(), ethereum.CallMsg{ @@ -210,9 +213,12 @@ func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wal var user0logs []types.Log var user1logs []types.Log var user2logs []types.Log - subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user0.WSClient, &user0logs) - subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user1.WSClient, &user1logs) - subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user2.WSClient, &user2logs) + _, err = subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user0.WSClient, &user0logs) + require.NoError(t, err) + _, err = subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user1.WSClient, &user1logs) + require.NoError(t, err) + _, err = subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user2.WSClient, &user2logs) + require.NoError(t, err) // user1 calls setMessage and setMessage2 on deployed smart contract with the account // that was registered as the first in TG @@ -265,13 +271,21 @@ func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wal assert.Equal(t, 3, len(user1logs)) // user2 should see three events (two lifecycle events - same as user0) and event with his interaction with setMessage assert.Equal(t, 3, len(user2logs)) + + _, err = user0.HTTPClient.FilterLogs(context.TODO(), ethereum.FilterQuery{ + Addresses: []gethcommon.Address{contractReceipt.ContractAddress}, + FromBlock: big.NewInt(0), + ToBlock: big.NewInt(10000), + Topics: nil, + }) + require.NoError(t, err) } func testSubscriptionTopics(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { - user0, err := NewUser([]wallet.Wallet{w}, httpURL, wsURL) + user0, err := NewGatewayUser([]wallet.Wallet{w}, httpURL, wsURL) require.NoError(t, err) - user1, err := NewUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + user1, err := NewGatewayUser([]wallet.Wallet{datagenerator.RandomWallet(integration.TenChainID), datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) // register all the accounts for that user @@ -319,6 +333,12 @@ func testSubscriptionTopics(t *testing.T, httpURL, wsURL string, w wallet.Wallet contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user0.HTTPClient, signedTx.Hash(), time.Minute) require.NoError(t, err) + tx, _, err := user0.HTTPClient.TransactionByHash(context.Background(), signedTx.Hash()) + if err != nil { + return + } + require.Equal(t, signedTx.Hash(), tx.Hash()) + // user0 subscribes to all events from that smart contract, user1 only an event with a topic of his first account var user0logs []types.Log var user1logs []types.Log @@ -326,8 +346,10 @@ func testSubscriptionTopics(t *testing.T, httpURL, wsURL string, w wallet.Wallet t1 := gethcommon.BytesToHash(user1.Wallets[1].Address().Bytes()) topics = append(topics, nil) topics = append(topics, []gethcommon.Hash{t1}) - subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user0.WSClient, &user0logs) - subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, topics, user1.WSClient, &user1logs) + _, err = subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user0.WSClient, &user0logs) + require.NoError(t, err) + _, err = subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, topics, user1.WSClient, &user1logs) + require.NoError(t, err) // user0 calls setMessage on deployed smart contract with the account twice and expects two events _, err = integrationCommon.InteractWithSmartContract(user0.HTTPClient, user0.Wallets[0], eventsContractABI, "setMessage", "user0Event1", contractReceipt.ContractAddress) @@ -394,28 +416,41 @@ func testErrorHandling(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { // make requests to geth for comparison for _, req := range []string{ + `{"jsonrpc":"2.0","method":"eth_getLogs","params":[[]],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_getLogs","params":[{"topics":[]}],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_getLogs","params":[{"fromBlock":"0x387","topics":["0xc6d8c0af6d21f291e7c359603aa97e0ed500f04db6e983b9fce75a91c6b8da6b"]}],"id":1}`, + //`{"jsonrpc":"2.0","method":"eth_subscribe","params":["logs"],"id":1}`, + //`{"jsonrpc":"2.0","method":"eth_subscribe","params":["logs",{"topics":[]}],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_blockNumber","params":[],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_blockNumber","params": [],"id":1}`, // test caching + `{"jsonrpc":"2.0","method":"eth_gasPrice","params": [],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_gasPrice","params": [],"id":1}`, // test caching + `{"jsonrpc":"2.0","method":"eth_getBlockByNumber","params": ["latest", false],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_feeHistory","params":[1, "latest", [50]],"id":1}`, `{"jsonrpc":"2.0","method":"eth_getBalance","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1}`, `{"jsonrpc":"2.0","method":"eth_getBalance","params":[],"id":1}`, - `{"jsonrpc":"2.0","method":"eth_getgetget","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1}`, + //`{"jsonrpc":"2.0","method":"eth_getgetget","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1}`, `{"method":"eth_getBalance","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1}`, `{"jsonrpc":"2.0","method":"eth_getBalance","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1,"extra":"extra_field"}`, `{"jsonrpc":"2.0","method":"eth_sendTransaction","params":[["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "0x1234"]],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_getTransactionByHash","params":["0x0000000000000000000000000000000000000000000000000000000000000000"],"id":1}`, } { // ensure the geth request is issued correctly (should return 200 ok with jsonRPCError) _, response, err := httputil.PostDataJSON(ogClient.HTTP(), []byte(req)) require.NoError(t, err) + fmt.Printf("Resp: %s", response) // unmarshall the response to JSONRPCMessage - jsonRPCError := wecommon.JSONRPCMessage{} + jsonRPCError := JSONRPCMessage{} err = json.Unmarshal(response, &jsonRPCError) - require.NoError(t, err) + require.NoError(t, err, req, response) // repeat the process for the gateway _, response, err = httputil.PostDataJSON(fmt.Sprintf("http://localhost:%d", integration.StartPortTenGatewayUnitTest), []byte(req)) require.NoError(t, err) // we only care about format - jsonRPCError = wecommon.JSONRPCMessage{} + jsonRPCError = JSONRPCMessage{} err = json.Unmarshal(response, &jsonRPCError) require.NoError(t, err) } @@ -473,7 +508,7 @@ func testErrorsRevertedArePassed(t *testing.T, httpURL, wsURL string, w wallet.W // convert error to WE error errBytes, err := json.Marshal(err) require.NoError(t, err) - weError := wecommon.JSONError{} + weError := JSONError{} err = json.Unmarshal(errBytes, &weError) require.NoError(t, err) require.Equal(t, "execution reverted: Forced require", weError.Message) @@ -502,9 +537,12 @@ func testErrorsRevertedArePassed(t *testing.T, httpURL, wsURL string, w wallet.W func testUnsubscribe(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { // create a user with multiple accounts - user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + user, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + require.NoError(t, err) + testlog.Logger().Info("Created user with encryption token", "t", user.tgClient.UserID()) + + _, err = user.HTTPClient.ChainID(context.Background()) require.NoError(t, err) - testlog.Logger().Info("Created user with encryption token: %s\n", user.tgClient.UserID()) // register all the accounts for the user err = user.RegisterAccounts() @@ -529,11 +567,12 @@ func testUnsubscribe(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute) require.NoError(t, err) - testlog.Logger().Info("Deployed contract address: ", contractReceipt.ContractAddress) + testlog.Logger().Info("Deployed contract address: ", "addr", contractReceipt.ContractAddress) // subscribe to an event var userLogs []types.Log - subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + subscription, err := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + require.NoError(t, err) // make an action that will trigger events _, err = integrationCommon.InteractWithSmartContract(user.HTTPClient, user.Wallets[0], eventsContractABI, "setMessage", "foo", contractReceipt.ContractAddress) @@ -554,9 +593,12 @@ func testUnsubscribe(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { func testClosingConnectionWhileSubscribed(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { // create a user with multiple accounts - user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + user, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + require.NoError(t, err) + testlog.Logger().Info("Created user with encryption token", "t", user.tgClient.UserID()) + + _, err = user.HTTPClient.ChainID(context.Background()) require.NoError(t, err) - testlog.Logger().Info("Created user with encryption token: %s\n", user.tgClient.UserID()) // register all the accounts for the user err = user.RegisterAccounts() @@ -581,11 +623,12 @@ func testClosingConnectionWhileSubscribed(t *testing.T, httpURL, wsURL string, w contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user.HTTPClient, signedTx.Hash(), time.Minute) require.NoError(t, err) - testlog.Logger().Info("Deployed contract address: ", contractReceipt.ContractAddress) + testlog.Logger().Info("Deployed contract address: ", "addr", contractReceipt.ContractAddress) // subscribe to an event var userLogs []types.Log - subscription := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + subscription, err := subscribeToEvents([]gethcommon.Address{contractReceipt.ContractAddress}, nil, user.WSClient, &userLogs) + require.NoError(t, err) // Close the websocket connection and make sure nothing breaks, but user does not receive events user.WSClient.Close() @@ -613,7 +656,7 @@ func testClosingConnectionWhileSubscribed(t *testing.T, httpURL, wsURL string, w } func testDifferentMessagesOnRegister(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { - user, err := NewUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + user, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) testlog.Logger().Info("Created user with encryption token: %s\n", user.tgClient.UserID()) @@ -627,19 +670,19 @@ func testDifferentMessagesOnRegister(t *testing.T, httpURL, wsURL string, w wall } func testInvokeNonSensitiveMethod(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { - user, err := NewUser([]wallet.Wallet{w}, httpURL, wsURL) + user, err := NewGatewayUser([]wallet.Wallet{w}, httpURL, wsURL) require.NoError(t, err) // call one of the non-sensitive methods with unauthenticated user // and make sure gateway is not complaining about not having viewing keys - respBody := makeHTTPEthJSONReq(httpURL, rpc.ChainID, user.tgClient.UserID(), nil) - if strings.Contains(string(respBody), fmt.Sprintf("method %s cannot be called with an unauthorised client - no signed viewing keys found", rpc.ChainID)) { - t.Errorf("sensitive method called without authenticating viewingkeys and did fail because of it: %s", rpc.ChainID) + respBody := makeHTTPEthJSONReq(httpURL, tenrpc.ChainID, user.tgClient.UserID(), nil) + if strings.Contains(string(respBody), fmt.Sprintf("method %s cannot be called with an unauthorised client - no signed viewing keys found", tenrpc.ChainID)) { + t.Errorf("sensitive method called without authenticating viewingkeys and did fail because of it: %s", tenrpc.ChainID) } } func testGetStorageAtForReturningUserID(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { - user, err := NewUser([]wallet.Wallet{w}, httpURL, wsURL) + user, err := NewGatewayUser([]wallet.Wallet{w}, httpURL, wsURL) require.NoError(t, err) type JSONResponse struct { @@ -648,24 +691,24 @@ func testGetStorageAtForReturningUserID(t *testing.T, httpURL, wsURL string, w w var response JSONResponse // make a request to GetStorageAt with correct parameters to get userID that exists in the database - respBody := makeHTTPEthJSONReq(httpURL, rpc.GetStorageAt, user.tgClient.UserID(), []interface{}{"getUserID", "0", nil}) + respBody := makeHTTPEthJSONReq(httpURL, tenrpc.GetStorageAt, user.tgClient.UserID(), []interface{}{wecommon.GetStorageAtUserIDRequestMethodName, "0", nil}) if err = json.Unmarshal(respBody, &response); err != nil { t.Error("Unable to unmarshal response") } - if response.Result != user.tgClient.UserID() { + if !bytes.Equal(gethcommon.FromHex(response.Result), user.tgClient.UserIDBytes()) { t.Errorf("Wrong UserID returned. Expected: %s, received: %s", user.tgClient.UserID(), response.Result) } // make a request to GetStorageAt with correct parameters to get userID, but with wrong userID - respBody2 := makeHTTPEthJSONReq(httpURL, rpc.GetStorageAt, "invalid_user_id", []interface{}{"getUserID", "0", nil}) - if !strings.Contains(string(respBody2), "method eth_getStorageAt cannot be called with an unauthorised client - no signed viewing keys found") { - t.Error("eth_getStorageAt did not respond with error: method eth_getStorageAt cannot be called with an unauthorised client - no signed viewing keys found") + respBody2 := makeHTTPEthJSONReq(httpURL, tenrpc.GetStorageAt, "0x0000000000000000000000000000000000000001", []interface{}{wecommon.GetStorageAtUserIDRequestMethodName, "0", nil}) + if !strings.Contains(string(respBody2), "not found") { + t.Error("eth_getStorageAt did not respond with not found error") } // make a request to GetStorageAt with wrong parameters to get userID, but correct userID - respBody3 := makeHTTPEthJSONReq(httpURL, rpc.GetStorageAt, user.tgClient.UserID(), []interface{}{"abc", "0", nil}) - if !strings.Contains(string(respBody3), "method eth_getStorageAt cannot be called with an unauthorised client - no signed viewing keys found") { - t.Error("eth_getStorageAt did not respond with error: no signed viewing keys found") + respBody3 := makeHTTPEthJSONReq(httpURL, tenrpc.GetStorageAt, user.tgClient.UserID(), []interface{}{"0x0000000000000000000000000000000000000001", "0", nil}) + if !strings.Contains(string(respBody3), "illegal access") { + t.Error("eth_getStorageAt did not respond with error: illegal access") } } @@ -759,7 +802,7 @@ func createTenNetwork(t *testing.T, startPort int) { func waitServerIsReady(serverAddr string) error { for now := time.Now(); time.Since(now) < 30*time.Second; time.Sleep(500 * time.Millisecond) { - statusCode, _, err := fasthttp.Get(nil, fmt.Sprintf("%s/health/", serverAddr)) + statusCode, _, err := fasthttp.Get(nil, fmt.Sprintf("%s/v1/health/", serverAddr)) if err != nil { // give it time to boot up if strings.Contains(err.Error(), "connection") { @@ -778,13 +821,13 @@ func waitServerIsReady(serverAddr string) error { func getFeeAndGas(client *ethclient.Client, wallet wallet.Wallet, legacyTx *types.LegacyTx) error { tx := types.NewTx(legacyTx) - history, err := client.FeeHistory(context.Background(), 1, nil, []float64{}) + history, err := client.FeeHistory(context.Background(), 1, nil, nil) if err != nil || len(history.BaseFee) == 0 { return err } estimate, err := client.EstimateGas(context.Background(), ethereum.CallMsg{ - From: wallet.Address(), + // From: wallet.Address(), To: tx.To(), Value: tx.Value(), Data: tx.Data(), @@ -825,19 +868,20 @@ func transferETHToAddress(client *ethclient.Client, wallet wallet.Wallet, toAddr return integrationCommon.AwaitReceiptEth(context.Background(), client, signedTx.Hash(), 30*time.Second) } -func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) ethereum.Subscription { +func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Hash, client *ethclient.Client, logs *[]types.Log) (ethereum.Subscription, error) { // Make a subscription filterQuery := ethereum.FilterQuery{ Addresses: addresses, - FromBlock: big.NewInt(0), // todo (@ziga) - without those we get errors - fix that and make them configurable - ToBlock: big.NewInt(10000), - Topics: topics, + FromBlock: big.NewInt(2), + // ToBlock: big.NewInt(10000), + Topics: topics, } logsCh := make(chan types.Log) subscription, err := client.SubscribeFilterLogs(context.Background(), filterQuery, logsCh) if err != nil { - testlog.Logger().Info("Failed to subscribe to filter logs: %v", log2.ErrKey, err) + testlog.Logger().Info("Failed to subscribe to filter logs", log2.ErrKey, err) + return nil, err } // Listen for logs in a goroutine @@ -845,7 +889,7 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has for { select { case err := <-subscription.Err(): - testlog.Logger().Info("Error from logs subscription: %v", log2.ErrKey, err) + testlog.Logger().Info("Error from logs subscription", log2.ErrKey, err) return case log := <-logsCh: // append logs to be visible from the main thread @@ -854,5 +898,5 @@ func subscribeToEvents(addresses []gethcommon.Address, topics [][]gethcommon.Has } }() - return subscription + return subscription, nil } diff --git a/integration/simulation/devnetwork/dev_network.go b/integration/simulation/devnetwork/dev_network.go index 3ac1be6060..25cd70f86b 100644 --- a/integration/simulation/devnetwork/dev_network.go +++ b/integration/simulation/devnetwork/dev_network.go @@ -7,10 +7,11 @@ import ( "sync" "time" + "github.com/ten-protocol/go-ten/tools/walletextension" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/integration/common/testlog" "github.com/ten-protocol/go-ten/integration/simulation/network" - gatewaycfg "github.com/ten-protocol/go-ten/tools/walletextension/config" - "github.com/ten-protocol/go-ten/tools/walletextension/container" "github.com/ten-protocol/go-ten/go/ethadapter" @@ -57,7 +58,7 @@ type InMemDevNetwork struct { tenConfig *TenConfig tenSequencer *InMemNodeOperator tenValidators []*InMemNodeOperator - tenGatewayContainer *container.WalletExtensionContainer + tenGatewayContainer *walletextension.Container faucet userwallet.User faucetLock sync.Mutex @@ -192,7 +193,7 @@ func (s *InMemDevNetwork) startTenGateway() { validatorWS := validator.HostRPCWSAddress() // remove ws:// prefix for the gateway config validatorWS = validatorWS[len("ws://"):] - cfg := gatewaycfg.Config{ + cfg := wecommon.Config{ WalletExtensionHost: "127.0.0.1", WalletExtensionPortHTTP: _gwHTTPPort, WalletExtensionPortWS: _gwWSPort, @@ -203,7 +204,7 @@ func (s *InMemDevNetwork) startTenGateway() { DBType: "sqlite", TenChainID: integration.TenChainID, } - tenGWContainer := container.NewWalletExtensionContainerFromConfig(cfg, s.logger) + tenGWContainer := walletextension.NewContainerFromConfig(cfg, s.logger) go func() { fmt.Println("Starting Ten Gateway, HTTP Port:", _gwHTTPPort, "WS Port:", _gwWSPort) err := tenGWContainer.Start() diff --git a/integration/simulation/simulation.go b/integration/simulation/simulation.go index d2ab17b953..68dc37e895 100644 --- a/integration/simulation/simulation.go +++ b/integration/simulation/simulation.go @@ -11,7 +11,6 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/filters" gethparams "github.com/ethereum/go-ethereum/params" "github.com/ten-protocol/go-ten/contracts/generated/MessageBus" "github.com/ten-protocol/go-ten/go/common" @@ -192,7 +191,7 @@ func (s *Simulation) trackLogs() { channel := make(chan common.IDAndLog, 1000) // To exercise the filtering mechanism, we subscribe for HOC events only, ignoring POC events. - hocFilter := filters.FilterCriteria{ + hocFilter := common.FilterCriteria{ Addresses: []gethcommon.Address{gethcommon.HexToAddress("0x" + testcommon.HOCAddr)}, } sub, err := client.SubscribeFilterLogs(context.Background(), hocFilter, channel) diff --git a/integration/simulation/validate_chain.go b/integration/simulation/validate_chain.go index c7010f1747..5fa4809065 100644 --- a/integration/simulation/validate_chain.go +++ b/integration/simulation/validate_chain.go @@ -627,7 +627,7 @@ func checkSubscribedLogs(t *testing.T, owner string, channel chan common.IDAndLo func checkSnapshotLogs(t *testing.T, client *obsclient.AuthObsClient) int { // To exercise the filtering mechanism, we get a snapshot for HOC events only, ignoring POC events. - hocFilter := common.FilterCriteriaJSON{ + hocFilter := common.FilterCriteria{ Addresses: []gethcommon.Address{gethcommon.HexToAddress("0x" + testcommon.HOCAddr)}, } logs, err := client.GetLogs(context.Background(), hocFilter) diff --git a/lib/gethfork/rpc/client.go b/lib/gethfork/rpc/client.go index de7fd5396a..805f375441 100644 --- a/lib/gethfork/rpc/client.go +++ b/lib/gethfork/rpc/client.go @@ -76,7 +76,7 @@ type BatchElem struct { // Client represents a connection to an RPC server. type Client struct { - UserID string + UserID []byte idgen func() ID // for subscriptions isHTTP bool // connection type: http, ws or ipc services *serviceRegistry diff --git a/lib/gethfork/rpc/client_opt.go b/lib/gethfork/rpc/client_opt.go index 0eae2e6134..32435e8efb 100644 --- a/lib/gethfork/rpc/client_opt.go +++ b/lib/gethfork/rpc/client_opt.go @@ -28,7 +28,7 @@ type ClientOption interface { } type clientConfig struct { - UserID string + UserID []byte // HTTP settings httpClient *http.Client httpHeaders http.Header diff --git a/lib/gethfork/rpc/handler.go b/lib/gethfork/rpc/handler.go index 856ef8b2c4..8836f3c6b5 100644 --- a/lib/gethfork/rpc/handler.go +++ b/lib/gethfork/rpc/handler.go @@ -25,6 +25,8 @@ import ( "sync" "time" + "github.com/status-im/keycard-go/hexutils" + "github.com/ethereum/go-ethereum/log" ) @@ -65,7 +67,7 @@ type handler struct { subLock sync.Mutex serverSubs map[ID]*Subscription - UserID string + UserID []byte } type callProc struct { @@ -73,7 +75,7 @@ type callProc struct { notifiers []*Notifier } -func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int, userID string) *handler { +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int, userID []byte) *handler { rootCtx, cancelRoot := context.WithCancel(connCtx) h := &handler{ reg: reg, @@ -386,6 +388,10 @@ func (h *handler) startCallProc(fn func(*callProc)) { ctx, cancel := context.WithCancel(h.rootCtx) defer h.callWG.Done() defer cancel() + // handle the case when normal rpc calls are made over a ws connection + if ctx.Value(GWTokenKey{}) == nil { + ctx = context.WithValue(ctx, GWTokenKey{}, hexutils.BytesToHex(h.UserID)) + } fn(&callProc{ctx: ctx}) }() } diff --git a/lib/gethfork/rpc/inproc.go b/lib/gethfork/rpc/inproc.go index 2a5d400b19..835825334b 100644 --- a/lib/gethfork/rpc/inproc.go +++ b/lib/gethfork/rpc/inproc.go @@ -27,7 +27,7 @@ func DialInProc(handler *Server) *Client { cfg := new(clientConfig) c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() - go handler.ServeCodec(NewCodec(p1), 0, "") + go handler.ServeCodec(NewCodec(p1), 0, nil) return NewCodec(p2), nil }) return c diff --git a/lib/gethfork/rpc/ipc.go b/lib/gethfork/rpc/ipc.go index 9db95dc467..5f45a4cb07 100644 --- a/lib/gethfork/rpc/ipc.go +++ b/lib/gethfork/rpc/ipc.go @@ -35,7 +35,7 @@ func (s *Server) ServeListener(l net.Listener) error { return err } log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr()) - go s.ServeCodec(NewCodec(conn), 0, "") + go s.ServeCodec(NewCodec(conn), 0, nil) } } diff --git a/lib/gethfork/rpc/server.go b/lib/gethfork/rpc/server.go index e0b96ad53f..686afe5a02 100644 --- a/lib/gethfork/rpc/server.go +++ b/lib/gethfork/rpc/server.go @@ -18,7 +18,6 @@ package rpc import ( "context" - "fmt" "io" "sync" "sync/atomic" @@ -103,7 +102,7 @@ func (s *Server) RegisterName(name string, receiver interface{}) error { // server is stopped. In either case the codec is closed. // // Note that codec options are no longer supported. -func (s *Server) ServeCodec(codec ServerCodec, _ CodecOption, userID string) { +func (s *Server) ServeCodec(codec ServerCodec, _ CodecOption, userID []byte) { defer codec.close() if !s.trackCodec(codec) { @@ -149,7 +148,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { return } - h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit, "") + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit, nil) h.allowSubscribe = false defer h.close(io.EOF, nil) @@ -157,7 +156,6 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { if err != nil { if err != io.EOF { resp := errorMessage(&invalidMessageError{"parse error"}) - fmt.Printf(">> Parse error %s. requests: %v\n", err, reqs) codec.writeJSON(ctx, resp, true) } return diff --git a/lib/gethfork/rpc/subscription.go b/lib/gethfork/rpc/subscription.go index c7fbc6c8c2..ca432190fb 100644 --- a/lib/gethfork/rpc/subscription.go +++ b/lib/gethfork/rpc/subscription.go @@ -101,7 +101,7 @@ func NotifierFromContext(ctx context.Context) (*Notifier, bool) { // Server callbacks use the notifier to send notifications. type Notifier struct { h *handler - UserID string // added by TEN + UserID []byte // added by TEN namespace string mu sync.Mutex diff --git a/lib/gethfork/rpc/websocket.go b/lib/gethfork/rpc/websocket.go index 605931c47d..9db7e19e40 100644 --- a/lib/gethfork/rpc/websocket.go +++ b/lib/gethfork/rpc/websocket.go @@ -27,6 +27,10 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/common" + + "github.com/ten-protocol/go-ten/go/common/viewingkey" + mapset "github.com/deckarep/golang-set/v2" "github.com/ethereum/go-ethereum/log" "github.com/gorilla/websocket" @@ -65,12 +69,16 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { }) } -func extractUserID(ctx context.Context) string { +func extractUserID(ctx context.Context) []byte { token, ok := ctx.Value(GWTokenKey{}).(string) if !ok { - return "" + return nil + } + userID := common.FromHex(token) + if len(userID) != viewingkey.UserIDLength { + return nil } - return token + return userID } // wsHandshakeValidator returns a handler that verifies the origin during the diff --git a/testnet/launcher/eth2network/docker.go b/testnet/launcher/eth2network/docker.go index 65685833f5..144e13dbf4 100644 --- a/testnet/launcher/eth2network/docker.go +++ b/testnet/launcher/eth2network/docker.go @@ -50,7 +50,7 @@ func (n *Eth2Network) Start() error { } func (n *Eth2Network) IsReady() error { - timeout := 10 * time.Minute + timeout := 20 * time.Minute // this can be reduced when we no longer download the ethereum binaries interval := 2 * time.Second var dial *ethclient.Client var err error diff --git a/testnet/launcher/gateway/docker.go b/testnet/launcher/gateway/docker.go index 1827516cf9..3bde82b592 100644 --- a/testnet/launcher/gateway/docker.go +++ b/testnet/launcher/gateway/docker.go @@ -43,7 +43,7 @@ func (n *DockerGateway) IsReady() error { interval := time.Second return retry.Do(func() error { - statusCode, _, err := fasthttp.Get(nil, fmt.Sprintf("http://127.0.0.1:%d/health/", n.cfg.gatewayHTTPPort)) + statusCode, _, err := fasthttp.Get(nil, fmt.Sprintf("http://127.0.0.1:%d/v1/health/", n.cfg.gatewayHTTPPort)) if err != nil { return err } diff --git a/tools/tenscan/frontend/pages/_app.tsx b/tools/tenscan/frontend/pages/_app.tsx index 287ecf806d..e8c8d12077 100644 --- a/tools/tenscan/frontend/pages/_app.tsx +++ b/tools/tenscan/frontend/pages/_app.tsx @@ -69,7 +69,7 @@ export default function App({ Component, pageProps }: AppProps) { ogTwitterImage={siteMetadata.siteLogo} ogType={"website"} > - + diff --git a/tools/tenscan/frontend/src/components/layouts/header.tsx b/tools/tenscan/frontend/src/components/layouts/header.tsx index a04de0fea8..3a2e59b40a 100644 --- a/tools/tenscan/frontend/src/components/layouts/header.tsx +++ b/tools/tenscan/frontend/src/components/layouts/header.tsx @@ -14,7 +14,7 @@ export default function Header() {
Logo 1 { - filteredAccounts, err := m.filterAccounts(rpcReq, accounts) - if err != nil { - return nil, err - } - // return filtered clients if we found any - if len(filteredAccounts) > 0 { - accounts = filteredAccounts - } - } - // create clients for all accounts if we didn't find any clients that match the filter or if no topics were provided - return m.createClientsForAccounts(accounts, userPrivateKey) -} - -// filterClients checks if any of the accounts match the filter criteria and returns those accounts -func (m *AccountManager) filterAccounts(rpcReq *wecommon.RPCRequest, accounts []wecommon.AccountDB) ([]wecommon.AccountDB, error) { - var filteredAccounts []wecommon.AccountDB - filterCriteriaJSON, err := json.Marshal(rpcReq.Params[1]) - if err != nil { - return nil, fmt.Errorf("could not marshal filter criteria to JSON. Cause: %w", err) - } - filterCriteria := filters.FilterCriteria{} - if string(filterCriteriaJSON) != emptyFilterCriteria { - err = filterCriteria.UnmarshalJSON(filterCriteriaJSON) - if err != nil { - return nil, fmt.Errorf("could not unmarshal filter criteria from the following JSON: `%s`. Cause: %w", string(filterCriteriaJSON), err) - } - } - - for _, topicCondition := range filterCriteria.Topics { - for _, topic := range topicCondition { - potentialAddr := common.ExtractPotentialAddress(topic) - m.logger.Info(fmt.Sprintf("Potential address (%s) found for the request %s", potentialAddr, rpcReq)) - if potentialAddr != nil { - for _, account := range accounts { - // if we find a match, we append the account to the list of filtered accounts - if bytes.Equal(account.AccountAddress, potentialAddr.Bytes()) { - filteredAccounts = append(filteredAccounts, account) - } - } - } - } - } - - return filteredAccounts, nil -} - -// createClientsForAllAccounts creates ws clients for all accounts for given user and returns them -func (m *AccountManager) createClientsForAccounts(accounts []wecommon.AccountDB, userPrivateKey []byte) ([]rpc.Client, error) { - clients := make([]rpc.Client, 0, len(accounts)) - for _, account := range accounts { - encClient, err := wecommon.CreateEncClient(m.hostRPCBindAddrWS, account.AccountAddress, userPrivateKey, account.Signature, viewingkey.SignatureType(account.SignatureType), m.logger) - if err != nil { - m.logger.Error(fmt.Errorf("error creating new client, %w", err).Error()) - continue - } - clients = append(clients, encClient) - } - return clients, nil -} - -// todo - better way -const notAuthorised = "not authorised" - -var platformAuthorisedCalls = map[string]bool{ - rpc.GetBalance: true, - // rpc.GetCode, //todo - rpc.GetTransactionCount: true, - rpc.GetTransactionReceipt: true, - rpc.GetLogs: true, -} - -func (m *AccountManager) executeCall(rpcReq *wecommon.RPCRequest, rpcResp *interface{}) error { - m.accountsMutex.RLock() - defer m.accountsMutex.RUnlock() - // for Ten RPC requests, it is important we know the sender account for the viewing key encryption/decryption - suggestedClient := m.suggestAccountClient(rpcReq, m.accountClientsHTTP) - - switch { - case suggestedClient != nil: // use the suggested client if there is one - // todo (@ziga) - if we have a suggested client, should we still loop through the other clients if it fails? - // The call data guessing won't often be wrong but there could be edge-cases there - return submitCall(suggestedClient, rpcReq, rpcResp) - - case len(m.accountClientsHTTP) > 0: // try registered clients until there's a successful execution - m.logger.Info(fmt.Sprintf("appropriate client not found, attempting request with up to %d clients", len(m.accountClientsHTTP))) - var err error - for _, client := range m.accountClientsHTTP { - err = submitCall(client, rpcReq, rpcResp) - if err == nil { - // request didn't fail, we don't need to continue trying the other clients - return nil - } - // platform calls return a standard error for calls that are not authorised. - // any other error can be returned early - if platformAuthorisedCalls[rpcReq.Method] && err.Error() != notAuthorised { - return err - } - } - // every attempt errored - return err - - default: // no clients registered, use the unauthenticated one - if rpc.IsSensitiveMethod(rpcReq.Method) { - return fmt.Errorf(ErrNoViewingKey, rpcReq.Method) - } - return m.unauthedClient.Call(rpcResp, rpcReq.Method, rpcReq.Params...) - } -} - -// suggestAccountClient works through various methods to try and guess which available client to use for a request, returns nil if none found -func (m *AccountManager) suggestAccountClient(req *wecommon.RPCRequest, accClients map[gethcommon.Address]*rpc.EncRPCClient) *rpc.EncRPCClient { - if len(accClients) == 1 { - for _, client := range accClients { - // return the first (and only) client - return client - } - } - switch req.Method { - case rpc.Call, rpc.EstimateGas: - return m.handleEthCall(req, accClients) - case rpc.GetBalance: - return extractAddress(0, req.Params, accClients) - case rpc.GetLogs: - return extractAddress(1, req.Params, accClients) - case rpc.GetTransactionCount: - return extractAddress(0, req.Params, accClients) - default: - return nil - } -} - -func extractAddress(pos int, params []interface{}, accClients map[gethcommon.Address]*rpc.EncRPCClient) *rpc.EncRPCClient { - if len(params) < pos+1 { - return nil - } - requestedAddress, err := gethencoding.ExtractAddress(params[pos]) - if err == nil { - return accClients[*requestedAddress] - } - return nil -} - -func (m *AccountManager) handleEthCall(req *wecommon.RPCRequest, accClients map[gethcommon.Address]*rpc.EncRPCClient) *rpc.EncRPCClient { - paramsMap, err := parseParams(req.Params) - if err != nil { - // no further info to deduce calling client - return nil - } - // check if request params had a "from" address and if we had a client for that address - fromClient, found := checkForFromField(paramsMap, accClients) - if found { - return fromClient - } - - // Otherwise, we search the `data` field for an address matching a registered viewing key. - addr, err := searchDataFieldForAccount(paramsMap, accClients) - if err == nil { - return accClients[*addr] - } - return nil -} - -// Many eth RPC requests provide params as first argument in a json map with similar fields (e.g. a `from` field) -func parseParams(args []interface{}) (map[string]interface{}, error) { - if len(args) == 0 { - return nil, fmt.Errorf("no params found to unmarshal") - } - - // only interested in trying first arg - params, ok := args[0].(map[string]interface{}) - if !ok { - callParamsJSON, ok := args[0].([]byte) - if !ok { - return nil, fmt.Errorf("first arg was not a byte array") - } - - err := json.Unmarshal(callParamsJSON, ¶ms) - if err != nil { - return nil, fmt.Errorf("first arg couldn't be unmarshaled into a params map") - } - } - - return params, nil -} - -func checkForFromField(paramsMap map[string]interface{}, accClients map[gethcommon.Address]*rpc.EncRPCClient) (*rpc.EncRPCClient, bool) { - fromVal, found := paramsMap[wecommon.JSONKeyFrom] - if !found { - return nil, false - } - - fromStr, ok := fromVal.(string) - if !ok { - return nil, false - } - - fromAddr := gethcommon.HexToAddress(fromStr) - client, found := accClients[fromAddr] - return client, found -} - -// Extracts the arguments from the request's `data` field. If any of them, after removing padding, match the viewing -// key address, we return that address. Otherwise, we return nil. -func searchDataFieldForAccount(callParams map[string]interface{}, accClients map[gethcommon.Address]*rpc.EncRPCClient) (*gethcommon.Address, error) { - // We ensure that the `data` field is present. - data := callParams[wecommon.JSONKeyData] - if data == nil { - return nil, fmt.Errorf("eth_call request did not have its `data` field set") - } - dataString, ok := data.(string) - if !ok { - return nil, fmt.Errorf("eth_call request's `data` field was not of the expected type `string`") - } - - // We check that the data field is long enough before removing the leading "0x" (1 bytes/2 chars) and the method ID - // (4 bytes/8 chars). - if len(dataString) < 10 { - return nil, fmt.Errorf("data field is not long enough - no known account found in data bytes") - } - dataString = dataString[10:] - - // We split up the arguments in the `data` field. - var dataArgs []string - for i := 0; i < len(dataString); i += ethCallPaddedArgLen { - if i+ethCallPaddedArgLen > len(dataString) { - break - } - dataArgs = append(dataArgs, dataString[i:i+ethCallPaddedArgLen]) - } - - // We iterate over the arguments, looking for an argument that matches a viewing key address - for _, dataArg := range dataArgs { - // If the argument doesn't have the correct padding, it's not an address. - if !strings.HasPrefix(dataArg, ethCallAddrPadding) { - continue - } - - maybeAddress := gethcommon.HexToAddress(dataArg[len(ethCallAddrPadding):]) - if _, ok := accClients[maybeAddress]; ok { - return &maybeAddress, nil - } - } - - return nil, fmt.Errorf("no known account found in data bytes") -} - -func submitCall(client *rpc.EncRPCClient, req *wecommon.RPCRequest, resp *interface{}) error { - if req.Method == rpc.Call || req.Method == rpc.EstimateGas { - // Never modify the original request, as it might be reused. - req = req.Clone() - - // Any method using an ethereum.CallMsg is a sensitive method that requires a viewing key lookup but the 'from' field is not mandatory - // and is often not included from metamask etc. So we ensure it is populated here. - account := client.Account() - var err error - req.Params, err = setFromFieldIfMissing(req.Params, *account) - if err != nil { - return err - } - } - - if req.Method == rpc.GetLogs { - // Never modify the original request, as it might be reused. - req = req.Clone() - - // We add the account to the list of arguments, so we know which account to use to filter the logs and encrypt - // the result. - req.Params = append(req.Params, client.Account().Hex()) - } - - return client.Call(resp, req.Method, req.Params...) -} - -// The enclave requires the `from` field to be set so that it can encrypt the response, but sources like MetaMask often -// don't set it. So we check whether it's present; if absent, we walk through the arguments in the request's `data` -// field, and if any of the arguments match our viewing key address, we set the `from` field to that address. -func setFromFieldIfMissing(args []interface{}, account gethcommon.Address) ([]interface{}, error) { - if len(args) == 0 { - return nil, fmt.Errorf("no params found to unmarshal") - } - - callMsg, err := gethencoding.ExtractEthCallMapString(args[0]) - if err != nil { - return nil, fmt.Errorf("unable to marshal callMsg - %w", err) - } - - // We only modify `eth_call` requests where the `from` field is not set. - if callMsg[gethencoding.CallFieldFrom] != gethcommon.HexToAddress("0x0").Hex() { - return args, nil - } - - // override the existing args - callMsg[gethencoding.CallFieldFrom] = account.Hex() - - // do not modify other existing arguments - request := []interface{}{callMsg} - for i := 1; i < len(args); i++ { - request = append(request, args[i]) - } - - return request, nil -} diff --git a/tools/walletextension/accountmanager/account_manager_test.go b/tools/walletextension/accountmanager/account_manager_test.go deleted file mode 100644 index 23cd3737cd..0000000000 --- a/tools/walletextension/accountmanager/account_manager_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package accountmanager - -import ( - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ten-protocol/go-ten/go/rpc" -) - -const ( - dataFieldPrefix = "0xmethodID" - padding = "000000000000000000000000" - viewingKeyAddressHex = "71C7656EC7ab88b098defB751B7401B5f6d8976F" - viewingKeyAddressHexPadded = padding + viewingKeyAddressHex - otherAddressHexPadded = padding + "71C7656EC7ab88b098defB751B7401B5f6d8976E" // Differs only in the final byte. -) - -var ( - viewingKeyAddressOne = common.HexToAddress("0x" + viewingKeyAddressHex) - viewingKeyAddressTwo = common.HexToAddress("0x71C7656EC7ab88b098defB751B7401B5f6d8976D") // Not in the data field. - accClients = map[common.Address]*rpc.EncRPCClient{ - viewingKeyAddressOne: nil, - viewingKeyAddressTwo: nil, - } -) - -func TestCanSearchDataFieldForFrom(t *testing.T) { - callParams := map[string]interface{}{"data": dataFieldPrefix + otherAddressHexPadded + viewingKeyAddressHexPadded} - address, err := searchDataFieldForAccount(callParams, accClients) - if err != nil { - t.Fatalf("did not expect an error but got %s", err) - } - if *address != viewingKeyAddressOne { - t.Fatal("did not find correct viewing key address in `data` field") - } -} - -func TestCanSearchDataFieldWhenHasUnexpectedLength(t *testing.T) { - incorrectLengthArg := "arg2xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" // Only 31 bytes. - callParams := map[string]interface{}{"data": dataFieldPrefix + otherAddressHexPadded + viewingKeyAddressHexPadded + incorrectLengthArg} - address, err := searchDataFieldForAccount(callParams, accClients) - if err != nil { - t.Fatalf("did not expect an error but got %s", err) - } - if *address != viewingKeyAddressOne { - t.Fatal("did not find correct viewing key address in `data` field") - } -} - -func TestErrorsWhenDataFieldIsMissing(t *testing.T) { - _, err := searchDataFieldForAccount(make(map[string]interface{}), accClients) - - if err == nil { - t.Fatal("`data` field was missing but not error was thrown") - } -} - -func TestDataFieldTooShort(t *testing.T) { - callParams := map[string]interface{}{"data": "tooshort"} - address, err := searchDataFieldForAccount(callParams, accClients) - if err == nil { - t.Fatal("expected an error but got none") - } - if address != nil { - t.Fatal("`data` field was too short but address was found anyway") - } -} diff --git a/tools/walletextension/api/server.go b/tools/walletextension/api/server.go index ae1fac8252..3552888ae0 100644 --- a/tools/walletextension/api/server.go +++ b/tools/walletextension/api/server.go @@ -1,16 +1,10 @@ package api import ( - "context" "embed" "fmt" "io/fs" "net/http" - "time" - - "github.com/ten-protocol/go-ten/lib/gethfork/node" - - "github.com/ten-protocol/go-ten/tools/walletextension/common" ) //go:embed all:static @@ -20,67 +14,11 @@ const ( staticDir = "static" ) -// Server is a wrapper for the http server -type Server struct { - server *http.Server -} - -// Start starts the server in its own goroutine and returns an error chan where errors can be monitored -func (s *Server) Start() chan error { - errChan := make(chan error) - go func() { - // start the server and serve any errors over the channel - errChan <- s.server.ListenAndServe() - }() - return errChan -} - -// Stop synchronously stops the server -func (s *Server) Stop() error { - return s.server.Shutdown(context.Background()) -} - -// NewHTTPServer returns the HTTP server for the WE -func NewHTTPServer(address string, routes []node.Route) *Server { - return &Server{ - server: createHTTPServer(address, routes), - } -} - -// NewWSServer returns the WS server for the WE -func NewWSServer(address string, routes []node.Route) *Server { - return &Server{ - server: createWSServer(address, routes), - } -} - -func createHTTPServer(address string, routes []node.Route) *http.Server { - serveMux := http.NewServeMux() - - // Handles Ethereum JSON-RPC requests received over HTTP. - for _, route := range routes { - serveMux.HandleFunc(route.Name, route.Func) - } - +func StaticFilesHandler(prefix string) http.Handler { // Serves the web assets for the management of viewing keys. - noPrefixStaticFiles, err := fs.Sub(staticFiles, staticDir) + fileSystem, err := fs.Sub(staticFiles, staticDir) if err != nil { panic(fmt.Sprintf("could not serve static files. Cause: %s", err)) } - serveMux.Handle(common.PathObscuroGateway, http.StripPrefix(common.PathObscuroGateway, http.FileServer(http.FS(noPrefixStaticFiles)))) - - // Creates the actual http server with a ReadHeaderTimeout to avoid Potential Slowloris Attack - server := &http.Server{Addr: address, Handler: serveMux, ReadHeaderTimeout: common.ReaderHeadTimeout} - return server -} - -func createWSServer(address string, routes []node.Route) *http.Server { - serveMux := http.NewServeMux() - - // Handles Ethereum JSON-RPC requests received over HTTP. - for _, route := range routes { - serveMux.HandleFunc(route.Name, route.Func) - } - - return &http.Server{Addr: address, Handler: serveMux, ReadHeaderTimeout: 10 * time.Second} + return http.StripPrefix(prefix, http.FileServer(http.FS(fileSystem))) } diff --git a/tools/walletextension/api/utils.go b/tools/walletextension/api/utils.go deleted file mode 100644 index 4ca99ac44d..0000000000 --- a/tools/walletextension/api/utils.go +++ /dev/null @@ -1,152 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "strings" - - gethlog "github.com/ethereum/go-ethereum/log" - "github.com/ten-protocol/go-ten/go/common/errutil" - "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/userconn" -) - -func parseRequest(body []byte) (*common.RPCRequest, error) { - // We unmarshal the JSON request - var reqJSONMap map[string]json.RawMessage - err := json.Unmarshal(body, &reqJSONMap) - if err != nil { - return nil, fmt.Errorf("could not unmarshal JSON-RPC request body to JSON: %s. "+ - "If you're trying to generate a viewing key, visit %s", err, common.PathViewingKeys) - } - - reqID := reqJSONMap[common.JSONKeyID] - var method string - err = json.Unmarshal(reqJSONMap[common.JSONKeyMethod], &method) - if err != nil { - return nil, fmt.Errorf("could not unmarshal method string from JSON-RPC request body: %s ; %w", string(body), err) - } - - // we extract the params into a JSON list - var params []interface{} - // params key is optional in JSON-RPC request - _, exists := reqJSONMap[common.JSONKeyParams] - if exists { - err = json.Unmarshal(reqJSONMap[common.JSONKeyParams], ¶ms) - if err != nil { - return nil, fmt.Errorf("could not unmarshal params list from JSON-RPC request body: %s ; %w", string(body), err) - } - } else { - params = []interface{}{} - } - - return &common.RPCRequest{ - ID: reqID, - Method: method, - Params: params, - }, nil -} - -func getQueryParameter(params map[string]string, selectedParameter string) (string, error) { - value, exists := params[selectedParameter] - if !exists { - return "", fmt.Errorf("parameter '%s' is not in the query params", selectedParameter) - } - - return value, nil -} - -// getUserID returns userID from query params / url of the URL -// it always first tries to get userID from a query parameter `u` or `token` (`u` parameter will become deprecated) -// if it fails to get userID from a query parameter it tries to get it from the URL and it needs position as the second parameter -func getUserID(conn userconn.UserConn, userIDPosition int) (string, error) { - // try getting userID (`token`) from query parameters and return it if successful - userID, err := getQueryParameter(conn.ReadRequestParams(), common.EncryptedTokenQueryParameter) - if err == nil { - if len(userID) != common.MessageUserIDLen { - return "", fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen)) - } - return userID, err - } - - // try getting userID(`u`) from query parameters and return it if successful - userID, err = getQueryParameter(conn.ReadRequestParams(), common.UserQueryParameter) - if err == nil { - if len(userID) != common.MessageUserIDLen { - return "", fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen)) - } - return userID, err - } - - // Alternatively, try to get it from URL path - // This is a temporary hack to work around hardhat bug which causes hardhat to ignore query parameters. - // It is unsafe because https encrypts query parameters, - // but not URL itself and will be removed once hardhat bug is resolved. - path := conn.GetHTTPRequest().URL.Path - path = strings.Trim(path, "/") - parts := strings.Split(path, "/") - - // our URLs, which require userID, have following pattern: // - // userID can be only on second or third position - if len(parts) != userIDPosition+1 { - return "", fmt.Errorf("URL structure of the request looks wrong") - } - userID = parts[userIDPosition] - - // Check if userID has the correct length - if len(userID) != common.MessageUserIDLen { - return "", fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen)) - } - - return userID, nil -} - -func handleEthError(req *common.RPCRequest, conn userconn.UserConn, logger gethlog.Logger, err error) { - var method string - id := json.RawMessage("1") - if req != nil { - method = req.Method - id = req.ID - } - - errjson := &common.JSONError{ - Code: 0, - Message: err.Error(), - Data: nil, - } - - jsonRPRCError := common.JSONRPCMessage{ - Version: "2.0", - ID: id, - Method: method, - Params: nil, - Error: errjson, - Result: nil, - } - - if evmError, ok := err.(errutil.EVMSerialisableError); ok { //nolint: errorlint - jsonRPRCError.Error.Data = evmError.Reason - jsonRPRCError.Error.Code = evmError.ErrorCode() - } - - errBytes, err := json.Marshal(jsonRPRCError) - if err != nil { - logger.Error("unable to marshal error - %w", log.ErrKey, err) - return - } - - logger.Info(fmt.Sprintf("Forwarding %s error response from Obscuro node: %s", method, errBytes)) - - if err = conn.WriteResponse(errBytes); err != nil { - logger.Error("unable to write response back", log.ErrKey, err) - } -} - -func handleError(conn userconn.UserConn, logger gethlog.Logger, err error) { - logger.Warn("error processing request - Forwarding response to user", log.ErrKey, err) - - if err = conn.WriteResponse([]byte(err.Error())); err != nil { - logger.Error("unable to write response back", log.ErrKey, err) - } -} diff --git a/tools/walletextension/cache/RistrettoCache.go b/tools/walletextension/cache/RistrettoCache.go index af417115b7..a72ca5f606 100644 --- a/tools/walletextension/cache/RistrettoCache.go +++ b/tools/walletextension/cache/RistrettoCache.go @@ -9,19 +9,19 @@ import ( ) const ( - numCounters = 1e7 // number of keys to track frequency of (10M). - maxCost = 1 << 30 // maximum cost of cache (1GB). - bufferItems = 64 // number of keys per Get buffer. - defaultConst = 1 // default cost of cache. + numCounters = 1e7 // number of keys to track frequency of (10M). + maxCost = 1_000_000 // 1 million entries + bufferItems = 64 // number of keys per Get buffer. + defaultCost = 1 // default cost of cache. ) -type RistrettoCache struct { +type ristrettoCache struct { cache *ristretto.Cache quit chan struct{} } -// NewRistrettoCache returns a new RistrettoCache. -func NewRistrettoCache(logger log.Logger) (*RistrettoCache, error) { +// NewRistrettoCache returns a new ristrettoCache. +func NewRistrettoCache(logger log.Logger) (Cache, error) { cache, err := ristretto.NewCache(&ristretto.Config{ NumCounters: numCounters, MaxCost: maxCost, @@ -32,7 +32,7 @@ func NewRistrettoCache(logger log.Logger) (*RistrettoCache, error) { return nil, err } - c := &RistrettoCache{ + c := &ristrettoCache{ cache: cache, quit: make(chan struct{}), } @@ -44,29 +44,21 @@ func NewRistrettoCache(logger log.Logger) (*RistrettoCache, error) { } // Set adds the key and value to the cache. -func (c *RistrettoCache) Set(key string, value map[string]interface{}, ttl time.Duration) bool { - return c.cache.SetWithTTL(key, value, defaultConst, ttl) +func (c *ristrettoCache) Set(key []byte, value any, ttl time.Duration) bool { + return c.cache.SetWithTTL(key, value, defaultCost, ttl) } // Get returns the value for the given key if it exists. -func (c *RistrettoCache) Get(key string) (value map[string]interface{}, ok bool) { - item, found := c.cache.Get(key) - if !found { - return nil, false - } - - // Assuming the item is stored as a map[string]interface{}, otherwise you need to type assert to the correct type. - value, ok = item.(map[string]interface{}) - if !ok { - // The item isn't of type map[string]interface{} - return nil, false - } +func (c *ristrettoCache) Get(key []byte) (value any, ok bool) { + return c.cache.Get(key) +} - return value, true +func (c *ristrettoCache) Remove(key []byte) { + c.cache.Del(key) } // startMetricsLogging starts logging cache metrics every hour. -func (c *RistrettoCache) startMetricsLogging(logger log.Logger) { +func (c *ristrettoCache) startMetricsLogging(logger log.Logger) { ticker := time.NewTicker(1 * time.Hour) for { select { diff --git a/tools/walletextension/cache/cache.go b/tools/walletextension/cache/cache.go index e0886abf9e..c080779305 100644 --- a/tools/walletextension/cache/cache.go +++ b/tools/walletextension/cache/cache.go @@ -1,111 +1,17 @@ package cache import ( - "crypto/sha256" - "encoding/json" - "fmt" "time" "github.com/ethereum/go-ethereum/log" - - "github.com/ten-protocol/go-ten/tools/walletextension/common" -) - -const ( - longCacheTTL = 5 * time.Hour - shortCacheTTL = 1 * time.Second ) -// Define a struct to hold the cache TTL and auth requirement -type RPCMethodCacheConfig struct { - CacheTTL time.Duration - RequiresAuth bool -} - -// CacheableRPCMethods is a map of Ethereum JSON-RPC methods that can be cached and their TTL -var cacheableRPCMethods = map[string]RPCMethodCacheConfig{ - // Ethereum JSON-RPC methods that can be cached long time - "eth_getBlockByNumber": {longCacheTTL, false}, - "eth_getBlockByHash": {longCacheTTL, false}, - //"eth_getTransactionByHash": {longCacheTTL, true}, - "eth_chainId": {longCacheTTL, false}, - - // Ethereum JSON-RPC methods that can be cached short time - "eth_blockNumber": {shortCacheTTL, false}, - "eth_getCode": {shortCacheTTL, true}, - // "eth_getBalance": {longCacheTTL, true},// excluded for test: gen_cor_059 - //"eth_getTransactionReceipt": {shortCacheTTL, true}, - "eth_call": {shortCacheTTL, true}, - "eth_gasPrice": {shortCacheTTL, false}, - // "eth_getTransactionCount": {longCacheTTL, true}, // excluded for test: gen_cor_009 - "eth_estimateGas": {shortCacheTTL, true}, - "eth_feeHistory": {shortCacheTTL, false}, -} - type Cache interface { - Set(key string, value map[string]interface{}, ttl time.Duration) bool - Get(key string) (value map[string]interface{}, ok bool) + Set(key []byte, value any, ttl time.Duration) bool + Get(key []byte) (value any, ok bool) + Remove(key []byte) } func NewCache(logger log.Logger) (Cache, error) { return NewRistrettoCache(logger) } - -// IsCacheable checks if the given RPC request is cacheable and returns the cache key and TTL -func IsCacheable(key *common.RPCRequest, encryptionToken string) (bool, string, time.Duration) { - if key == nil || key.Method == "" { - return false, "", 0 - } - - // Check if the method is cacheable - methodCacheConfig, isCacheable := cacheableRPCMethods[key.Method] - - // If method does not need to be authenticated, we can don't need to cache it per user - if !methodCacheConfig.RequiresAuth { - encryptionToken = "" - } - - if isCacheable { - // method is cacheable - select cache key and ttl - switch key.Method { - case "eth_getCode", "eth_getBalance", "eth_estimateGas", "eth_call": - if len(key.Params) == 1 || len(key.Params) == 2 && (key.Params[1] == "latest" || key.Params[1] == "pending") { - return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL - } - // in this case, we have a fixed block number, and we can cache the result for a long time - return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), longCacheTTL - case "eth_feeHistory": - if len(key.Params) == 2 || len(key.Params) == 3 && (key.Params[2] == "latest" || key.Params[2] == "pending") { - return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL - } - // in this case, we have a fixed block number, and we can cache the result for a long time - return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), longCacheTTL - - default: - return true, GenerateCacheKey(key.Method, encryptionToken, key.Params...), methodCacheConfig.CacheTTL - } - } - - // method is not cacheable - return false, "", 0 -} - -// GenerateCacheKey generates a cache key for the given method, encryptionToken and parameters -// encryptionToken is used to generate a unique cache key for each user and empty string should be used for public data -func GenerateCacheKey(method string, encryptionToken string, params ...interface{}) string { - // Serialize parameters - paramBytes, err := json.Marshal(params) - if err != nil { - return "" - } - - // Concatenate method name and parameters - rawKey := method + encryptionToken + string(paramBytes) - - // Optional: Apply hashing - hasher := sha256.New() - hasher.Write([]byte(rawKey)) - hashedKey := fmt.Sprintf("%x", hasher.Sum(nil)) - - return hashedKey -} diff --git a/tools/walletextension/cache/cache_test.go b/tools/walletextension/cache/cache_test.go deleted file mode 100644 index 8c59d533fd..0000000000 --- a/tools/walletextension/cache/cache_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package cache - -import ( - "reflect" - "testing" - "time" - - "github.com/ethereum/go-ethereum/log" - - "github.com/ten-protocol/go-ten/tools/walletextension/common" -) - -var tests = map[string]func(t *testing.T){ - "testCacheableMethods": testCacheableMethods, - "testNonCacheableMethods": testNonCacheableMethods, - "testMethodsWithLatestOrPendingParameter": testMethodsWithLatestOrPendingParameter, -} - -var cacheTests = map[string]func(cache Cache, t *testing.T){ - "testResultsAreCached": testResultsAreCached, - "testCacheTTL": testCacheTTL, - "testCachingAuthenticatedMethods": testCachingAuthenticatedMethods, - "testCachingNonAuthenticatedMethods": testCachingNonAuthenticatedMethods, -} - -var ( - nonCacheableMethods = []string{"eth_sendrawtransaction", "eth_sendtransaction", "join", "authenticate"} - encryptionToken = "test" - encryptionToken2 = "not-test" -) - -func TestGatewayCaching(t *testing.T) { - for name, test := range tests { - t.Run(name, func(t *testing.T) { - test(t) - }) - } - - // cache tests - for name, test := range cacheTests { - t.Run(name, func(t *testing.T) { - logger := log.New() - cache, err := NewCache(logger) - if err != nil { - t.Errorf("failed to create cache: %v", err) - } - test(cache, t) - }) - } -} - -// testCacheableMethods tests if the cacheable methods are cacheable -func testCacheableMethods(t *testing.T) { - for method := range cacheableRPCMethods { - key := &common.RPCRequest{Method: method} - isCacheable, _, _ := IsCacheable(key, encryptionToken) - if isCacheable != true { - t.Errorf("method %s should be cacheable", method) - } - } -} - -// testNonCacheableMethods tests if the non-cacheable methods are not cacheable -func testNonCacheableMethods(t *testing.T) { - for _, method := range nonCacheableMethods { - key := &common.RPCRequest{Method: method} - isCacheable, _, _ := IsCacheable(key, encryptionToken) - if isCacheable == true { - t.Errorf("method %s should not be cacheable", method) - } - } -} - -// testMethodsWithLatestOrPendingParameter tests if the methods with latest or pending parameter are cacheable -func testMethodsWithLatestOrPendingParameter(t *testing.T) { - methods := []string{"eth_getCode", "eth_estimateGas", "eth_call"} - for _, method := range methods { - key := &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "latest"}} - _, _, ttl := IsCacheable(key, encryptionToken) - if ttl != shortCacheTTL { - t.Errorf("method %s with latest parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) - } - - key = &common.RPCRequest{Method: method, Params: []interface{}{"0x123", "pending"}} - _, _, ttl = IsCacheable(key, encryptionToken) - if ttl != shortCacheTTL { - t.Errorf("method %s with pending parameter should have TTL of %s, but %s received", method, shortCacheTTL, ttl) - } - } -} - -// testResultsAreCached tests if the results are cached as expected -func testResultsAreCached(cache Cache, t *testing.T) { - // prepare a cacheable request and imaginary response - req := &common.RPCRequest{Method: "eth_getBlockByNumber", Params: []interface{}{"0x123"}} - res := map[string]interface{}{"result": "block"} - isCacheable, key, ttl := IsCacheable(req, encryptionToken) - if !isCacheable { - t.Errorf("method %s should be cacheable", req.Method) - } - // set the response in the cache with a TTL - if !cache.Set(key, res, ttl) { - t.Errorf("failed to set value in cache for %s", req) - } - - time.Sleep(50 * time.Millisecond) // wait for the cache to be set - value, ok := cache.Get(key) - if !ok { - t.Errorf("failed to get cached value for %s", req) - } - - if !reflect.DeepEqual(value, res) { - t.Errorf("expected %v, got %v", res, value) - } -} - -// testCacheTTL tests if the cache TTL is working as expected -func testCacheTTL(cache Cache, t *testing.T) { - req := &common.RPCRequest{Method: "eth_blockNumber", Params: []interface{}{"0x123"}} - res := map[string]interface{}{"result": "100"} - isCacheable, key, ttl := IsCacheable(req, encryptionToken) - - if !isCacheable { - t.Errorf("method %s should be cacheable", req.Method) - } - - if ttl != shortCacheTTL { - t.Errorf("method %s should have TTL of %s, but %s received", req.Method, shortCacheTTL, ttl) - } - - // set the response in the cache with a TTL - if !cache.Set(key, res, ttl) { - t.Errorf("failed to set value in cache for %s", req) - } - time.Sleep(50 * time.Millisecond) // wait for the cache to be set - - // check if the value is in the cache - value, ok := cache.Get(key) - if !ok { - t.Errorf("failed to get cached value for %s", req) - } - - if !reflect.DeepEqual(value, res) { - t.Errorf("expected %v, got %v", res, value) - } - - // sleep for the TTL to expire - time.Sleep(shortCacheTTL + 100*time.Millisecond) - _, ok = cache.Get(key) - if ok { - t.Errorf("value should not be in the cache after TTL") - } -} - -func testCachingAuthenticatedMethods(cache Cache, t *testing.T) { - // eth_getTransactionByHash - authMethods := []string{ - //"eth_getTransactionByHash", - "eth_getCode", - //"eth_getTransactionReceipt", - "eth_call", - "eth_estimateGas", - } - for _, method := range authMethods { - req := &common.RPCRequest{Method: method, Params: []interface{}{"0x123"}} - res := map[string]interface{}{"result": "transaction"} - - // store the response in cache for the first user using encryptionToken - isCacheable, key, ttl := IsCacheable(req, encryptionToken) - - if !isCacheable { - t.Errorf("method %s should be cacheable", req.Method) - } - - // set the response in the cache with a TTL - if !cache.Set(key, res, ttl) { - t.Errorf("failed to set value in cache for %s", req) - } - time.Sleep(50 * time.Millisecond) // wait for the cache to be set - - // check if the value is in the cache - value, ok := cache.Get(key) - if !ok { - t.Errorf("failed to get cached value for %s", req) - } - - // for the first error we should have the value in cache - if !reflect.DeepEqual(value, res) { - t.Errorf("expected %v, got %v", res, value) - } - - // now check with the second user asking for the same request, but with a different encryptionToken - _, key2, _ := IsCacheable(req, encryptionToken2) - - _, okSecondUser := cache.Get(key2) - if okSecondUser { - t.Errorf("another user should not see a value the first user cached %s", req) - } - } -} - -func testCachingNonAuthenticatedMethods(cache Cache, t *testing.T) { - // eth_getTransactionByHash - nonAuthMethods := []string{ - "eth_getBlockByNumber", - "eth_getBlockByHash", - "eth_chainId", - "eth_blockNumber", - "eth_gasPrice", - "eth_feeHistory", - } - - for _, method := range nonAuthMethods { - req := &common.RPCRequest{Method: method, Params: []interface{}{"0x123"}} - res := map[string]interface{}{"result": "transaction"} - - // store the response in cache for the first user using encryptionToken - isCacheable, key, ttl := IsCacheable(req, encryptionToken) - - if !isCacheable { - t.Errorf("method %s should be cacheable", req.Method) - } - - // set the response in the cache with a TTL - if !cache.Set(key, res, ttl) { - t.Errorf("failed to set value in cache for %s", req) - } - time.Sleep(50 * time.Millisecond) // wait for the cache to be set - - // check if the value is in the cache - value, ok := cache.Get(key) - if !ok { - t.Errorf("failed to get cached value for %s", req) - } - - // for the first error we should have the value in cache - if !reflect.DeepEqual(value, res) { - t.Errorf("expected %v, got %v", res, value) - } - - // now check with the second user asking for the same request, but with a different encryptionToken - _, key2, _ := IsCacheable(req, encryptionToken2) - - _, okSecondUser := cache.Get(key2) - if !okSecondUser { - t.Errorf("another user should see a value the first user cached %s", req) - } - } -} diff --git a/tools/walletextension/common/common.go b/tools/walletextension/common/common.go index 8a13de5a3d..59cb4b8a4b 100644 --- a/tools/walletextension/common/common.go +++ b/tools/walletextension/common/common.go @@ -1,11 +1,12 @@ package common import ( - "encoding/hex" "encoding/json" "fmt" "os" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ten-protocol/go-ten/go/common/log" @@ -34,13 +35,8 @@ func BytesToPrivateKey(keyBytes []byte) (*ecies.PrivateKey, error) { return eciesPrivateKey, nil } -// GetUserIDbyte converts userID from string to correct byte format -func GetUserIDbyte(userID string) ([]byte, error) { - return hex.DecodeString(userID) -} - func CreateEncClient( - hostRPCBindAddr string, + conn *gethrpc.Client, addressBytes []byte, privateKeyBytes []byte, signature []byte, @@ -61,7 +57,7 @@ func CreateEncClient( SignatureWithAccountKey: signature, SignatureType: signatureType, } - encClient, err := rpc.NewEncNetworkClient(hostRPCBindAddr, vk, logger) + encClient, err := rpc.NewEncNetworkClientFromConn(conn, vk, logger) if err != nil { return nil, fmt.Errorf("unable to create EncRPCClient: %w", err) } diff --git a/tools/walletextension/config/config.go b/tools/walletextension/common/config.go similarity index 97% rename from tools/walletextension/config/config.go rename to tools/walletextension/common/config.go index 9e423ca476..98092f8c7b 100644 --- a/tools/walletextension/config/config.go +++ b/tools/walletextension/common/config.go @@ -1,4 +1,4 @@ -package config +package common // Config contains the configuration required by the WalletExtension. type Config struct { diff --git a/tools/walletextension/common/constants.go b/tools/walletextension/common/constants.go index 55d1da0ac4..aa81df8b64 100644 --- a/tools/walletextension/common/constants.go +++ b/tools/walletextension/common/constants.go @@ -1,58 +1,39 @@ package common -import ( - "time" -) - const ( Localhost = "127.0.0.1" JSONKeyAddress = "address" - JSONKeyData = "data" - JSONKeyErr = "error" - JSONKeyFrom = "from" JSONKeyID = "id" JSONKeyMethod = "method" JSONKeyParams = "params" - JSONKeyResult = "result" - JSONKeyRoot = "root" JSONKeyRPCVersion = "jsonrpc" JSONKeySignature = "signature" - JSONKeySubscription = "subscription" - JSONKeyCode = "code" - JSONKeyMessage = "message" JSONKeyType = "type" JSONKeyEncryptionToken = "encryptionToken" JSONKeyFormats = "formats" ) const ( - PathRoot = "/" + PathStatic = "/static/" PathReady = "/ready/" - PathViewingKeys = "/viewingkeys/" PathJoin = "/join/" PathGetMessage = "/getmessage/" PathAuthenticate = "/authenticate/" PathQuery = "/query/" PathRevoke = "/revoke/" - PathObscuroGateway = "/" PathHealth = "/health/" PathNetworkHealth = "/network-health/" WSProtocol = "ws://" HTTPProtocol = "http://" - DefaultUser = "defaultUser" - UserQueryParameter = "u" EncryptedTokenQueryParameter = "token" AddressQueryParameter = "a" MessageUserIDLen = 40 EthereumAddressLen = 42 - GetStorageAtUserIDRequestMethodName = "getUserID" + GetStorageAtUserIDRequestMethodName = "0x0000000000000000000000000000000000000000" SuccessMsg = "success" APIVersion1 = "/v1" - MethodEthSubscription = "eth_subscription" PathVersion = "/version/" DeduplicationBufferSize = 20 DefaultGatewayAuthMessageType = "EIP712" ) - -var ReaderHeadTimeout = 10 * time.Second diff --git a/tools/walletextension/common/types.go b/tools/walletextension/common/db_types.go similarity index 100% rename from tools/walletextension/common/types.go rename to tools/walletextension/common/db_types.go diff --git a/tools/walletextension/common/responses.go b/tools/walletextension/common/responses.go deleted file mode 100644 index 65b6d77c2f..0000000000 --- a/tools/walletextension/common/responses.go +++ /dev/null @@ -1,29 +0,0 @@ -package common - -import ( - "errors" - - gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" -) - -func CraftErrorResponse(err error) map[string]interface{} { - errMap := make(map[string]interface{}) - respMap := make(map[string]interface{}) - - respMap[JSONKeyErr] = errMap - errMap[JSONKeyMessage] = err.Error() - - var e gethrpc.Error - ok := errors.As(err, &e) - if ok { - errMap[JSONKeyCode] = e.ErrorCode() - } - - var de gethrpc.DataError - ok = errors.As(err, &de) - if ok { - errMap[JSONKeyData] = de.ErrorData() - } - - return respMap -} diff --git a/tools/walletextension/container/walletextension_container.go b/tools/walletextension/container/walletextension_container.go deleted file mode 100644 index 1dccaab400..0000000000 --- a/tools/walletextension/container/walletextension_container.go +++ /dev/null @@ -1,209 +0,0 @@ -package container - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "net/http" - "os" - - "github.com/ten-protocol/go-ten/go/common/viewingkey" - - "github.com/ethereum/go-ethereum/common" - - "github.com/ethereum/go-ethereum/crypto" - - "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/go/common/stopcontrol" - "github.com/ten-protocol/go-ten/go/rpc" - "github.com/ten-protocol/go-ten/tools/walletextension" - "github.com/ten-protocol/go-ten/tools/walletextension/api" - "github.com/ten-protocol/go-ten/tools/walletextension/config" - "github.com/ten-protocol/go-ten/tools/walletextension/storage" - "github.com/ten-protocol/go-ten/tools/walletextension/useraccountmanager" - - gethlog "github.com/ethereum/go-ethereum/log" - wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" -) - -type WalletExtensionContainer struct { - hostAddr string - userAccountManager *useraccountmanager.UserAccountManager - storage storage.Storage - stopControl *stopcontrol.StopControl - logger gethlog.Logger - walletExt *walletextension.WalletExtension - httpServer *api.Server - wsServer *api.Server -} - -func NewWalletExtensionContainerFromConfig(config config.Config, logger gethlog.Logger) *WalletExtensionContainer { - // create the account manager with a single unauthenticated connection - hostRPCBindAddrWS := wecommon.WSProtocol + config.NodeRPCWebsocketAddress - hostRPCBindAddrHTTP := wecommon.HTTPProtocol + config.NodeRPCHTTPAddress - unAuthedClient, err := rpc.NewNetworkClient(hostRPCBindAddrHTTP) - if err != nil { - logger.Crit("unable to create temporary client for request ", log.ErrKey, err) - os.Exit(1) - } - - // start the database - databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride) - if err != nil { - logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) - os.Exit(1) - } - userAccountManager := useraccountmanager.NewUserAccountManager(unAuthedClient, logger, databaseStorage, hostRPCBindAddrHTTP, hostRPCBindAddrWS) - - // add default user (when no UserID is provided in the query parameter - for WE endpoints) - defaultUserAccountManager := userAccountManager.AddAndReturnAccountManager(hex.EncodeToString([]byte(wecommon.DefaultUser))) - - // add default user to the database (temporary fix before removing wallet extension endpoints) - accountPrivateKey, err := crypto.GenerateKey() - if err != nil { - logger.Error("Unable to generate key pair for default user", log.ErrKey, err) - os.Exit(1) - } - - // get all users and their private keys from the database - allUsers, err := databaseStorage.GetAllUsers() - if err != nil { - logger.Error(fmt.Errorf("error getting all users from database, %w", err).Error()) - os.Exit(1) - } - - // iterate over users create accountManagers and add all defaultUserAccounts to them per user - for _, user := range allUsers { - userAccountManager.AddAndReturnAccountManager(hex.EncodeToString(user.UserID)) - logger.Info(fmt.Sprintf("account manager added for user: %s", hex.EncodeToString(user.UserID))) - - // to ensure backwards compatibility we want to load clients for the default user - // TODO @ziga - this code needs to be removed when removing old wallet extension endpoints - if bytes.Equal(user.UserID, []byte(wecommon.DefaultUser)) { - accounts, err := databaseStorage.GetAccounts(user.UserID) - if err != nil { - logger.Error(fmt.Errorf("error getting accounts for user: %s, %w", hex.EncodeToString(user.UserID), err).Error()) - os.Exit(1) - } - for _, account := range accounts { - encClient, err := wecommon.CreateEncClient(hostRPCBindAddrWS, account.AccountAddress, user.PrivateKey, account.Signature, viewingkey.SignatureType(account.SignatureType), logger) - if err != nil { - logger.Error(fmt.Errorf("error creating new client, %w", err).Error()) - os.Exit(1) - } - - // add a client to default user - defaultUserAccountManager.AddClient(common.BytesToAddress(account.AccountAddress), encClient) - } - } - } - // TODO @ziga - remove this when removing wallet extension endpoints - err = databaseStorage.AddUser([]byte(wecommon.DefaultUser), crypto.FromECDSA(accountPrivateKey)) - if err != nil { - logger.Error("Unable to save default user to the database", log.ErrKey, err) - os.Exit(1) - } - - // captures version in the env vars - version := os.Getenv("OBSCURO_GATEWAY_VERSION") - if version == "" { - version = "dev" - } - - stopControl := stopcontrol.New() - walletExt := walletextension.New(hostRPCBindAddrHTTP, hostRPCBindAddrWS, &userAccountManager, databaseStorage, stopControl, version, logger, &config) - httpRoutes := api.NewHTTPRoutes(walletExt) - httpServer := api.NewHTTPServer(fmt.Sprintf("%s:%d", config.WalletExtensionHost, config.WalletExtensionPortHTTP), httpRoutes) - - wsRoutes := api.NewWSRoutes(walletExt) - wsServer := api.NewWSServer(fmt.Sprintf("%s:%d", config.WalletExtensionHost, config.WalletExtensionPortWS), wsRoutes) - return NewWalletExtensionContainer( - hostRPCBindAddrWS, - walletExt, - &userAccountManager, - databaseStorage, - stopControl, - httpServer, - wsServer, - logger, - ) -} - -func NewWalletExtensionContainer( - hostAddr string, - walletExt *walletextension.WalletExtension, - userAccountManager *useraccountmanager.UserAccountManager, - storage storage.Storage, - stopControl *stopcontrol.StopControl, - httpServer *api.Server, - wsServer *api.Server, - logger gethlog.Logger, -) *WalletExtensionContainer { - return &WalletExtensionContainer{ - hostAddr: hostAddr, - walletExt: walletExt, - userAccountManager: userAccountManager, - storage: storage, - stopControl: stopControl, - httpServer: httpServer, - wsServer: wsServer, - logger: logger, - } -} - -// Start starts the wallet extension container -func (w *WalletExtensionContainer) Start() error { - httpErrChan := w.httpServer.Start() - wsErrChan := w.wsServer.Start() - - // Start a goroutine for handling HTTP and WS server errors - go func() { - for { - select { - case err := <-httpErrChan: - if errors.Is(err, http.ErrServerClosed) { - err = w.Stop() // Stop the container when the HTTP server is closed - if err != nil { - fmt.Printf("failed to stop gracefully - %s\n", err) - os.Exit(1) - } - } else { - // for other errors, we just log them - w.logger.Error("HTTP server error: %v", err) - } - case err := <-wsErrChan: - if errors.Is(err, http.ErrServerClosed) { - err = w.Stop() // Stop the container when the WS server is closed - if err != nil { - fmt.Printf("failed to stop gracefully - %s\n", err) - os.Exit(1) - } - } else { - // for other errors, we just log them - w.logger.Error("HTTP server error: %v", err) - } - case <-w.stopControl.Done(): - return // Exit the goroutine when stop signal is received - } - } - }() - return nil -} - -func (w *WalletExtensionContainer) Stop() error { - w.stopControl.Stop() - - err := w.httpServer.Stop() - if err != nil { - w.logger.Warn("could not shut down wallet extension", log.ErrKey, err) - } - - err = w.wsServer.Stop() - if err != nil { - w.logger.Warn("could not shut down wallet extension", log.ErrKey, err) - } - - // todo (@pedro) correctly surface shutdown errors - return nil -} diff --git a/tools/walletextension/frontend/next.config.js b/tools/walletextension/frontend/next.config.js index 5d5086f7cb..04472f7ff7 100644 --- a/tools/walletextension/frontend/next.config.js +++ b/tools/walletextension/frontend/next.config.js @@ -7,6 +7,8 @@ const nextConfig = { images: { unoptimized: true, }, + // base path for static files should be "" in development but "/static" in production + basePath: process.env.NODE_ENV === "development" ? "" : "/static", }; module.exports = nextConfig; diff --git a/tools/walletextension/frontend/src/components/layouts/header.tsx b/tools/walletextension/frontend/src/components/layouts/header.tsx index a04de0fea8..3a2e59b40a 100644 --- a/tools/walletextension/frontend/src/components/layouts/header.tsx +++ b/tools/walletextension/frontend/src/components/layouts/header.tsx @@ -14,7 +14,7 @@ export default function Header() {
Logo - + diff --git a/tools/walletextension/frontend/src/routes/index.ts b/tools/walletextension/frontend/src/routes/index.ts index f3ac4438f7..c53c3bc501 100644 --- a/tools/walletextension/frontend/src/routes/index.ts +++ b/tools/walletextension/frontend/src/routes/index.ts @@ -8,7 +8,7 @@ export const apiRoutes = { authenticate: `/${tenGatewayVersion}/authenticate/`, queryAccountToken: `/${tenGatewayVersion}/query/`, revoke: `/${tenGatewayVersion}/revoke/`, - version: `/version/`, + version: `/${tenGatewayVersion}/version/`, // **** INFO **** getHealthStatus: `/${tenGatewayVersion}/network-health/`, diff --git a/tools/walletextension/httpapi/README.MD b/tools/walletextension/httpapi/README.MD new file mode 100644 index 0000000000..90a8e7e823 --- /dev/null +++ b/tools/walletextension/httpapi/README.MD @@ -0,0 +1 @@ +todo - the content of this package should be moved to rpcapi to avoid implementing the low-level http logic \ No newline at end of file diff --git a/tools/walletextension/api/routes.go b/tools/walletextension/httpapi/routes.go similarity index 66% rename from tools/walletextension/api/routes.go rename to tools/walletextension/httpapi/routes.go index dad70da1a5..4f05e5b62f 100644 --- a/tools/walletextension/api/routes.go +++ b/tools/walletextension/httpapi/routes.go @@ -1,4 +1,4 @@ -package api +package httpapi import ( "encoding/hex" @@ -6,27 +6,24 @@ import ( "fmt" "net/http" + "github.com/status-im/keycard-go/hexutils" + "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/lib/gethfork/node" + "github.com/ten-protocol/go-ten/tools/walletextension/rpcapi" "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/go/common/httputil" - "github.com/ten-protocol/go-ten/go/rpc" - "github.com/ten-protocol/go-ten/tools/walletextension" "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/userconn" ) // NewHTTPRoutes returns the http specific routes -func NewHTTPRoutes(walletExt *walletextension.WalletExtension) []node.Route { +// todo - move these to the rpc framework. +func NewHTTPRoutes(walletExt *rpcapi.Services) []node.Route { return []node.Route{ { - Name: common.APIVersion1 + common.PathRoot, - Func: httpHandler(walletExt, ethRequestHandler), - }, - { - Name: common.PathReady, + Name: common.APIVersion1 + common.PathReady, Func: httpHandler(walletExt, readyRequestHandler), }, { @@ -50,23 +47,23 @@ func NewHTTPRoutes(walletExt *walletextension.WalletExtension) []node.Route { Func: httpHandler(walletExt, revokeRequestHandler), }, { - Name: common.PathHealth, + Name: common.APIVersion1 + common.PathHealth, Func: httpHandler(walletExt, healthRequestHandler), }, { - Name: common.PathNetworkHealth, + Name: common.APIVersion1 + common.PathNetworkHealth, Func: httpHandler(walletExt, networkHealthRequestHandler), }, { - Name: common.PathVersion, + Name: common.APIVersion1 + common.PathVersion, Func: httpHandler(walletExt, versionRequestHandler), }, } } func httpHandler( - walletExt *walletextension.WalletExtension, - fun func(walletExt *walletextension.WalletExtension, conn userconn.UserConn), + walletExt *rpcapi.Services, + fun func(walletExt *rpcapi.Services, conn UserConn), ) func(resp http.ResponseWriter, req *http.Request) { return func(resp http.ResponseWriter, req *http.Request) { httpRequestHandler(walletExt, resp, req, fun) @@ -74,113 +71,23 @@ func httpHandler( } // Overall request handler for http requests -func httpRequestHandler(walletExt *walletextension.WalletExtension, resp http.ResponseWriter, req *http.Request, fun func(walletExt *walletextension.WalletExtension, conn userconn.UserConn)) { +func httpRequestHandler(walletExt *rpcapi.Services, resp http.ResponseWriter, req *http.Request, fun func(walletExt *rpcapi.Services, conn UserConn)) { if walletExt.IsStopping() { return } if httputil.EnableCORS(resp, req) { return } - userConn := userconn.NewUserConnHTTP(resp, req, walletExt.Logger()) + userConn := NewUserConnHTTP(resp, req, walletExt.Logger()) fun(walletExt, userConn) } -// NewWSRoutes returns the WS specific routes -func NewWSRoutes(walletExt *walletextension.WalletExtension) []node.Route { - return []node.Route{ - { - Name: common.PathRoot, - Func: wsHandler(walletExt, ethRequestHandler), - }, - { - Name: common.PathReady, - Func: wsHandler(walletExt, readyRequestHandler), - }, - } -} - -func wsHandler( - walletExt *walletextension.WalletExtension, - fun func(walletExt *walletextension.WalletExtension, conn userconn.UserConn), -) func(resp http.ResponseWriter, req *http.Request) { - return func(resp http.ResponseWriter, req *http.Request) { - wsRequestHandler(walletExt, resp, req, fun) - } -} - -// Overall request handler for WS requests -func wsRequestHandler(walletExt *walletextension.WalletExtension, resp http.ResponseWriter, req *http.Request, fun func(walletExt *walletextension.WalletExtension, conn userconn.UserConn)) { - if walletExt.IsStopping() { - return - } - - userConn, err := userconn.NewUserConnWS(resp, req, walletExt.Logger()) - if err != nil { - return - } - // We handle requests in a loop until the connection is closed on the client side. - for !userConn.IsClosed() { - fun(walletExt, userConn) - } -} - -// ethRequestHandler parses the user eth request, passes it on to the WE to proxy it and processes the response -func ethRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { - body, err := conn.ReadRequest() - if err != nil { - handleEthError(nil, conn, walletExt.Logger(), fmt.Errorf("error reading request - %w", err)) - return - } - - request, err := parseRequest(body) - if err != nil { - handleError(conn, walletExt.Logger(), err) - return - } - walletExt.Logger().Debug("REQUEST", "method", request.Method, "body", string(body)) - - if request.Method == rpc.Subscribe && !conn.SupportsSubscriptions() { - handleError(conn, walletExt.Logger(), fmt.Errorf("received an %s request but the connection does not support subscriptions", rpc.Subscribe)) - return - } - - // Get userID - // TODO: @ziga - after removing old wallet extension endpoints we should prevent users doing anything without valid encryption token - hexUserID, err := getUserID(conn, 1) - if err != nil || !walletExt.UserExists(hexUserID) { - walletExt.Logger().Info("user not found in the query params: %w. Using the default user", log.ErrKey, err) - hexUserID = hex.EncodeToString([]byte(common.DefaultUser)) // todo (@ziga) - this can be removed once old WE endpoints are removed - } - - if len(hexUserID) < 3 { - handleError(conn, walletExt.Logger(), fmt.Errorf("encryption token length is incorrect")) - return - } - - // todo (@pedro) remove this conn dependency - response, err := walletExt.ProxyEthRequest(request, conn, hexUserID) - if err != nil { - handleEthError(request, conn, walletExt.Logger(), err) - return - } - - rpcResponse, err := json.Marshal(response) - if err != nil { - handleEthError(request, conn, walletExt.Logger(), err) - return - } - - walletExt.Logger().Info(fmt.Sprintf("Forwarding %s response from Obscuro node: %s", request.Method, rpcResponse)) - if err = conn.WriteResponse(rpcResponse); err != nil { - walletExt.Logger().Error("error writing success response", log.ErrKey, err) - } -} - // readyRequestHandler is used to check whether the server is ready -func readyRequestHandler(_ *walletextension.WalletExtension, _ userconn.UserConn) {} +func readyRequestHandler(_ *rpcapi.Services, _ UserConn) {} // This function handles request to /join endpoint. It is responsible to create new user (new key-pair) and store it to the db -func joinRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func joinRequestHandler(walletExt *rpcapi.Services, conn UserConn) { + // audit() // todo (@ziga) add protection against DDOS attacks _, err := conn.ReadRequest() if err != nil { @@ -189,14 +96,14 @@ func joinRequestHandler(walletExt *walletextension.WalletExtension, conn usercon } // generate new key-pair and store it in the database - hexUserID, err := walletExt.GenerateAndStoreNewUser() + userID, err := walletExt.GenerateAndStoreNewUser() if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal Error")) walletExt.Logger().Error("error creating new user", log.ErrKey, err) } // write hex encoded userID in the response - err = conn.WriteResponse([]byte(hexUserID)) + err = conn.WriteResponse([]byte(hexutils.BytesToHex(userID))) if err != nil { walletExt.Logger().Error("error writing success response", log.ErrKey, err) } @@ -205,7 +112,7 @@ func joinRequestHandler(walletExt *walletextension.WalletExtension, conn usercon // This function handles request to /authenticate endpoint. // In the request we receive message, signature and address in JSON as request body and userID and address as query parameters // We then check if message is in correct format and if signature is valid. If all checks pass we save address and signature against userID -func authenticateRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func authenticateRequestHandler(walletExt *rpcapi.Services, conn UserConn) { // read the request body, err := conn.ReadRequest() if err != nil { @@ -248,17 +155,17 @@ func authenticateRequestHandler(walletExt *walletextension.WalletExtension, conn } // read userID from query params - hexUserID, err := getUserID(conn, 2) + userID, err := getUserID(conn) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("malformed query: 'u' required - representing encryption token - %w", err)) return } // check signature and add address and signature for that user - err = walletExt.AddAddressToUser(hexUserID, address, signature, messageType) + err = walletExt.AddAddressToUser(userID, address, signature, messageType) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) - walletExt.Logger().Error(fmt.Sprintf("error adding address: %s to user: %s with signature: %s", address, hexUserID, signature)) + walletExt.Logger().Error(fmt.Sprintf("error adding address: %s to user: %s with signature: %s", address, userID, signature)) return } err = conn.WriteResponse([]byte(common.SuccessMsg)) @@ -270,7 +177,7 @@ func authenticateRequestHandler(walletExt *walletextension.WalletExtension, conn // This function handles request to /query endpoint. // In the query parameters address and userID are required. We check if provided address is registered for given userID // and return true/false in json response -func queryRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func queryRequestHandler(walletExt *rpcapi.Services, conn UserConn) { // read the request _, err := conn.ReadRequest() if err != nil { @@ -278,7 +185,7 @@ func queryRequestHandler(walletExt *walletextension.WalletExtension, conn userco return } - hexUserID, err := getUserID(conn, 2) + userID, err := getUserID(conn) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("user ('u') not found in query parameters")) walletExt.Logger().Info("user not found in the query params", log.ErrKey, err) @@ -297,10 +204,10 @@ func queryRequestHandler(walletExt *walletextension.WalletExtension, conn userco } // check if this account is registered with given user - found, err := walletExt.UserHasAccount(hexUserID, address) + found, err := walletExt.UserHasAccount(userID, address) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) - walletExt.Logger().Error("error during checking if account exists for user", "hexUserID", hexUserID, log.ErrKey, err) + walletExt.Logger().Error("error during checking if account exists for user", "userID", userID, log.ErrKey, err) } // create and write the response @@ -322,7 +229,7 @@ func queryRequestHandler(walletExt *walletextension.WalletExtension, conn userco // This function handles request to /revoke endpoint. // It requires userID as query parameter and deletes given user and all associated viewing keys -func revokeRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func revokeRequestHandler(walletExt *rpcapi.Services, conn UserConn) { // read the request _, err := conn.ReadRequest() if err != nil { @@ -330,7 +237,7 @@ func revokeRequestHandler(walletExt *walletextension.WalletExtension, conn userc return } - hexUserID, err := getUserID(conn, 2) + userID, err := getUserID(conn) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("user ('u') not found in query parameters")) walletExt.Logger().Info("user not found in the query params", log.ErrKey, err) @@ -338,10 +245,10 @@ func revokeRequestHandler(walletExt *walletextension.WalletExtension, conn userc } // delete user and accounts associated with it from the database - err = walletExt.DeleteUser(hexUserID) + err = walletExt.DeleteUser(userID) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) - walletExt.Logger().Error("unable to delete user", "hexUserID", hexUserID, log.ErrKey, err) + walletExt.Logger().Error("unable to delete user", "userID", userID, log.ErrKey, err) return } @@ -352,7 +259,7 @@ func revokeRequestHandler(walletExt *walletextension.WalletExtension, conn userc } // Handles request to /health endpoint. -func healthRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func healthRequestHandler(walletExt *rpcapi.Services, conn UserConn) { // read the request _, err := conn.ReadRequest() if err != nil { @@ -368,7 +275,7 @@ func healthRequestHandler(walletExt *walletextension.WalletExtension, conn userc } // Handles request to /network-health endpoint. -func networkHealthRequestHandler(walletExt *walletextension.WalletExtension, userConn userconn.UserConn) { +func networkHealthRequestHandler(walletExt *rpcapi.Services, userConn UserConn) { // read the request _, err := userConn.ReadRequest() if err != nil { @@ -394,7 +301,7 @@ func networkHealthRequestHandler(walletExt *walletextension.WalletExtension, use } // Handles request to /version endpoint. -func versionRequestHandler(walletExt *walletextension.WalletExtension, userConn userconn.UserConn) { +func versionRequestHandler(walletExt *rpcapi.Services, userConn UserConn) { // read the request _, err := userConn.ReadRequest() if err != nil { @@ -409,7 +316,7 @@ func versionRequestHandler(walletExt *walletextension.WalletExtension, userConn } // getMessageRequestHandler handles request to /get-message endpoint. -func getMessageRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { +func getMessageRequestHandler(walletExt *rpcapi.Services, conn UserConn) { // read the request body, err := conn.ReadRequest() if err != nil { @@ -451,7 +358,12 @@ func getMessageRequestHandler(walletExt *walletextension.WalletExtension, conn u } } - message, err := walletExt.GenerateUserMessageToSign(encryptionToken.(string), formatsSlice) + userID := hexutils.HexToBytes(encryptionToken.(string)) + if len(userID) != viewingkey.UserIDLength { + return + } + + message, err := walletExt.GenerateUserMessageToSign(userID, formatsSlice) if err != nil { handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) walletExt.Logger().Error("error getting message", log.ErrKey, err) diff --git a/tools/walletextension/httpapi/user_conn.go b/tools/walletextension/httpapi/user_conn.go new file mode 100644 index 0000000000..2244adcfbd --- /dev/null +++ b/tools/walletextension/httpapi/user_conn.go @@ -0,0 +1,72 @@ +package httpapi + +import ( + "fmt" + "io" + "net/http" + "net/url" + + gethlog "github.com/ethereum/go-ethereum/log" +) + +// UserConn represents a connection to a user. +type UserConn interface { + ReadRequest() ([]byte, error) + ReadRequestParams() map[string]string + WriteResponse(msg []byte) error + SupportsSubscriptions() bool + IsClosed() bool + GetHTTPRequest() *http.Request +} + +// Represents a user's connection over HTTP. +type userConnHTTP struct { + resp http.ResponseWriter + req *http.Request + logger gethlog.Logger +} + +func NewUserConnHTTP(resp http.ResponseWriter, req *http.Request, logger gethlog.Logger) UserConn { + return &userConnHTTP{resp: resp, req: req, logger: logger} +} + +func (h *userConnHTTP) ReadRequest() ([]byte, error) { + body, err := io.ReadAll(h.req.Body) + if err != nil { + return nil, fmt.Errorf("could not read request body: %w", err) + } + return body, nil +} + +func (h *userConnHTTP) WriteResponse(msg []byte) error { + _, err := h.resp.Write(msg) + if err != nil { + return fmt.Errorf("could not write response: %w", err) + } + return nil +} + +func (h *userConnHTTP) SupportsSubscriptions() bool { + return false +} + +func (h *userConnHTTP) IsClosed() bool { + return false +} + +func (h *userConnHTTP) ReadRequestParams() map[string]string { + return getQueryParams(h.req.URL.Query()) +} + +func (h *userConnHTTP) GetHTTPRequest() *http.Request { + return h.req +} + +func getQueryParams(query url.Values) map[string]string { + params := make(map[string]string) + queryParams := query + for key, value := range queryParams { + params[key] = value[0] + } + return params +} diff --git a/tools/walletextension/httpapi/utils.go b/tools/walletextension/httpapi/utils.go new file mode 100644 index 0000000000..cc7c50fe1a --- /dev/null +++ b/tools/walletextension/httpapi/utils.go @@ -0,0 +1,43 @@ +package httpapi + +import ( + "fmt" + + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/status-im/keycard-go/hexutils" + "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +func getQueryParameter(params map[string]string, selectedParameter string) (string, error) { + value, exists := params[selectedParameter] + if !exists { + return "", fmt.Errorf("parameter '%s' is not in the query params", selectedParameter) + } + + return value, nil +} + +// getUserID returns userID from query params / url of the URL +// it always first tries to get userID from a query parameter `u` or `token` (`u` parameter will become deprecated) +// if it fails to get userID from a query parameter it tries to get it from the URL and it needs position as the second parameter +func getUserID(conn UserConn) ([]byte, error) { + // try getting userID (`token`) from query parameters and return it if successful + userID, err := getQueryParameter(conn.ReadRequestParams(), common.EncryptedTokenQueryParameter) + if err == nil { + if len(userID) != common.MessageUserIDLen { + return nil, fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLen)) + } + return hexutils.HexToBytes(userID), err + } + + return nil, fmt.Errorf("missing token field") +} + +func handleError(conn UserConn, logger gethlog.Logger, err error) { + logger.Warn("error processing request - Forwarding response to user", log.ErrKey, err) + + if err = conn.WriteResponse([]byte(err.Error())); err != nil { + logger.Error("unable to write response back", log.ErrKey, err) + } +} diff --git a/tools/walletextension/lib/client_lib.go b/tools/walletextension/lib/client_lib.go index 42be5c4eb0..3e63a1a22e 100644 --- a/tools/walletextension/lib/client_lib.go +++ b/tools/walletextension/lib/client_lib.go @@ -9,6 +9,8 @@ import ( "net/http" "strings" + "github.com/status-im/keycard-go/hexutils" + "github.com/ten-protocol/go-ten/integration" gethcommon "github.com/ethereum/go-ethereum/common" @@ -31,7 +33,11 @@ func NewTenGatewayLibrary(httpURL, wsURL string) *TGLib { } func (o *TGLib) UserID() string { - return string(o.userID) + return hexutils.BytesToHex(o.userID) +} + +func (o *TGLib) UserIDBytes() []byte { + return o.userID } func (o *TGLib) Join() error { @@ -40,13 +46,13 @@ func (o *TGLib) Join() error { if err != nil || statusCode != 200 { return fmt.Errorf(fmt.Sprintf("Failed to get userID. Status code: %d, err: %s", statusCode, err)) } - o.userID = userID + o.userID = hexutils.HexToBytes(string(userID)) return nil } func (o *TGLib) RegisterAccount(pk *ecdsa.PrivateKey, addr gethcommon.Address) error { // create the registration message - message, err := viewingkey.GenerateMessage(string(o.userID), integration.TenChainID, 1, viewingkey.EIP712Signature) + message, err := viewingkey.GenerateMessage(o.userID, integration.TenChainID, 1, viewingkey.EIP712Signature) if err != nil { return err } @@ -68,7 +74,7 @@ func (o *TGLib) RegisterAccount(pk *ecdsa.PrivateKey, addr gethcommon.Address) e req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - o.httpURL+"/v1/authenticate/?token="+string(o.userID), + o.httpURL+"/v1/authenticate/?token="+hexutils.BytesToHex(o.userID), strings.NewReader(payload), ) if err != nil { @@ -96,7 +102,7 @@ func (o *TGLib) RegisterAccount(pk *ecdsa.PrivateKey, addr gethcommon.Address) e func (o *TGLib) RegisterAccountPersonalSign(pk *ecdsa.PrivateKey, addr gethcommon.Address) error { // create the registration message - message, err := viewingkey.GenerateMessage(string(o.userID), integration.TenChainID, viewingkey.PersonalSignVersion, viewingkey.PersonalSign) + message, err := viewingkey.GenerateMessage(o.userID, integration.TenChainID, viewingkey.PersonalSignVersion, viewingkey.PersonalSign) if err != nil { return err } @@ -118,7 +124,7 @@ func (o *TGLib) RegisterAccountPersonalSign(pk *ecdsa.PrivateKey, addr gethcommo req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - o.httpURL+"/v1/authenticate/?token="+string(o.userID), + o.httpURL+"/v1/authenticate/?token="+hexutils.BytesToHex(o.userID), strings.NewReader(payload), ) if err != nil { @@ -145,9 +151,9 @@ func (o *TGLib) RegisterAccountPersonalSign(pk *ecdsa.PrivateKey, addr gethcommo } func (o *TGLib) HTTP() string { - return fmt.Sprintf("%s/v1/?token=%s", o.httpURL, o.userID) + return fmt.Sprintf("%s/v1/?token=%s", o.httpURL, hexutils.BytesToHex(o.userID)) } func (o *TGLib) WS() string { - return fmt.Sprintf("%s/v1/?token=%s", o.wsURL, o.userID) + return fmt.Sprintf("%s/v1/?token=%s", o.wsURL, hexutils.BytesToHex(o.userID)) } diff --git a/tools/walletextension/main/cli.go b/tools/walletextension/main/cli.go index b84b854d90..dbbe72a9a0 100644 --- a/tools/walletextension/main/cli.go +++ b/tools/walletextension/main/cli.go @@ -4,7 +4,7 @@ import ( "flag" "fmt" - "github.com/ten-protocol/go-ten/tools/walletextension/config" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) const ( @@ -61,7 +61,7 @@ const ( storeIncomingTxsUsage = "Flag to enable storing incoming transactions in the database for debugging purposes. Default: true" ) -func parseCLIArgs() config.Config { +func parseCLIArgs() wecommon.Config { walletExtensionHost := flag.String(walletExtensionHostName, walletExtensionHostDefault, walletExtensionHostUsage) walletExtensionPort := flag.Int(walletExtensionPortName, walletExtensionPortDefault, walletExtensionPortUsage) walletExtensionPortWS := flag.Int(walletExtensionPortWSName, walletExtensionPortWSDefault, walletExtensionPortWSUsage) @@ -77,7 +77,7 @@ func parseCLIArgs() config.Config { storeIncomingTransactions := flag.Bool(storeIncomingTxs, storeIncomingTxsDefault, storeIncomingTxsUsage) flag.Parse() - return config.Config{ + return wecommon.Config{ WalletExtensionHost: *walletExtensionHost, WalletExtensionPortHTTP: *walletExtensionPort, WalletExtensionPortWS: *walletExtensionPortWS, diff --git a/tools/walletextension/main/main.go b/tools/walletextension/main/main.go index b909f43368..4b200968c0 100644 --- a/tools/walletextension/main/main.go +++ b/tools/walletextension/main/main.go @@ -7,9 +7,10 @@ import ( "os" "time" + "github.com/ten-protocol/go-ten/tools/walletextension" + "github.com/ten-protocol/go-ten/go/common/log" "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/container" gethlog "github.com/ethereum/go-ethereum/log" ) @@ -58,7 +59,7 @@ func main() { } logger := log.New(log.WalletExtCmp, int(logLvl), config.LogPath) - walletExtContainer := container.NewWalletExtensionContainerFromConfig(config, logger) + walletExtContainer := walletextension.NewContainerFromConfig(config, logger) // Start the wallet extension. err := walletExtContainer.Start() diff --git a/tools/walletextension/rpcapi/blockchain_api.go b/tools/walletextension/rpcapi/blockchain_api.go new file mode 100644 index 0000000000..e9e1d66e3f --- /dev/null +++ b/tools/walletextension/rpcapi/blockchain_api.go @@ -0,0 +1,261 @@ +package rpcapi + +import ( + "context" + "encoding/json" + "time" + + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ten-protocol/go-ten/go/common/gethapi" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +type BlockChainAPI struct { + we *Services +} + +func NewBlockChainAPI(we *Services) *BlockChainAPI { + return &BlockChainAPI{we} +} + +func (api *BlockChainAPI) ChainId() *hexutil.Big { //nolint:stylecheck + chainID, _ := UnauthenticatedTenRPCCall[hexutil.Big](context.Background(), api.we, &CacheCfg{TTL: longCacheTTL}, "eth_chainId") + return chainID +} + +func (api *BlockChainAPI) BlockNumber() hexutil.Uint64 { + nr, err := UnauthenticatedTenRPCCall[hexutil.Uint64](context.Background(), api.we, &CacheCfg{TTL: shortCacheTTL}, "eth_blockNumber") + if err != nil { + return hexutil.Uint64(0) + } + return *nr +} + +func (api *BlockChainAPI) GetBalance(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (*hexutil.Big, error) { + return ExecAuthRPC[hexutil.Big]( + ctx, + api.we, + &ExecCfg{ + cacheCfg: &CacheCfg{ + TTLCallback: func() time.Duration { + return cacheTTLBlockNumberOrHash(blockNrOrHash) + }, + }, + account: &address, + tryUntilAuthorised: true, // the user can request the balance of a contract account + }, + "eth_getBalance", + address, + blockNrOrHash, + ) +} + +// Result structs for GetProof +type AccountResult struct { + Address common.Address `json:"address"` + AccountProof []string `json:"accountProof"` + Balance *hexutil.Big `json:"balance"` + CodeHash common.Hash `json:"codeHash"` + Nonce hexutil.Uint64 `json:"nonce"` + StorageHash common.Hash `json:"storageHash"` + StorageProof []StorageResult `json:"storageProof"` +} + +type StorageResult struct { + Key string `json:"key"` + Value *hexutil.Big `json:"value"` + Proof []string `json:"proof"` +} + +func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNrOrHash rpc.BlockNumberOrHash) (*AccountResult, error) { + return nil, rpcNotImplemented +} + +func (api *BlockChainAPI) GetHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (map[string]interface{}, error) { + resp, err := UnauthenticatedTenRPCCall[map[string]interface{}](ctx, api.we, &CacheCfg{TTLCallback: func() time.Duration { + return cacheTTLBlockNumber(number) + }}, "eth_getHeaderByNumber", number) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (api *BlockChainAPI) GetHeaderByHash(ctx context.Context, hash common.Hash) map[string]interface{} { + resp, _ := UnauthenticatedTenRPCCall[map[string]interface{}](ctx, api.we, &CacheCfg{TTL: longCacheTTL}, "eth_getHeaderByHash", hash) + if resp == nil { + return nil + } + return *resp +} + +func (api *BlockChainAPI) GetBlockByNumber(ctx context.Context, number rpc.BlockNumber, fullTx bool) (map[string]interface{}, error) { + resp, err := UnauthenticatedTenRPCCall[map[string]interface{}]( + ctx, + api.we, + &CacheCfg{ + TTLCallback: func() time.Duration { + return cacheTTLBlockNumber(number) + }, + }, "eth_getBlockByNumber", number, fullTx) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (api *BlockChainAPI) GetBlockByHash(ctx context.Context, hash common.Hash, fullTx bool) (map[string]interface{}, error) { + resp, err := UnauthenticatedTenRPCCall[map[string]interface{}](ctx, api.we, &CacheCfg{TTL: longCacheTTL}, "eth_getBlockByHash", hash, fullTx) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (api *BlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + // todo - must be authenticated + resp, err := UnauthenticatedTenRPCCall[hexutil.Bytes]( + ctx, + api.we, + &CacheCfg{ + TTLCallback: func() time.Duration { + return cacheTTLBlockNumberOrHash(blockNrOrHash) + }, + }, + "eth_getCode", + address, + blockNrOrHash, + ) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (api *BlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, hexKey string, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + // GetStorageAt is repurposed to return the userID + if address.Hex() == wecommon.GetStorageAtUserIDRequestMethodName { + userID, err := extractUserID(ctx, api.we) + if err != nil { + return nil, err + } + + _, err = getUser(userID, api.we) + if err != nil { + return nil, err + } + return userID, nil + } + + resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &ExecCfg{account: &address}, "eth_getStorageAt", address, hexKey, blockNrOrHash) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (s *BlockChainAPI) GetBlockReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]map[string]interface{}, error) { + return nil, rpcNotImplemented +} + +type OverrideAccount struct { + Nonce *hexutil.Uint64 `json:"nonce"` + Code *hexutil.Bytes `json:"code"` + Balance **hexutil.Big `json:"balance"` + State *map[common.Hash]common.Hash `json:"state"` + StateDiff *map[common.Hash]common.Hash `json:"stateDiff"` +} +type ( + StateOverride map[common.Address]OverrideAccount + BlockOverrides struct { + Number *hexutil.Big + Difficulty *hexutil.Big + Time *hexutil.Uint64 + GasLimit *hexutil.Uint64 + Coinbase *common.Address + Random *common.Hash + BaseFee *hexutil.Big + } +) + +func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, blockOverrides *BlockOverrides) (hexutil.Bytes, error) { + resp, err := ExecAuthRPC[hexutil.Bytes](ctx, api.we, &ExecCfg{ + cacheCfg: &CacheCfg{ + TTLCallback: func() time.Duration { + return cacheTTLBlockNumberOrHash(blockNrOrHash) + }, + }, + computeFromCallback: func(user *GWUser) *common.Address { + return searchFromAndData(user.GetAllAddresses(), args) + }, + adjustArgs: func(acct *GWAccount) []any { + argsClone := populateFrom(acct, args) + return []any{argsClone, blockNrOrHash, overrides, blockOverrides} + }, + tryAll: true, + }, "eth_call", args, blockNrOrHash, overrides, blockOverrides) + if resp == nil { + return nil, err + } + return *resp, err +} + +func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) { + resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{ + cacheCfg: &CacheCfg{ + TTLCallback: func() time.Duration { + if blockNrOrHash != nil { + return cacheTTLBlockNumberOrHash(*blockNrOrHash) + } + return shortCacheTTL + }, + }, + computeFromCallback: func(user *GWUser) *common.Address { + return searchFromAndData(user.GetAllAddresses(), args) + }, + adjustArgs: func(acct *GWAccount) []any { + argsClone := populateFrom(acct, args) + return []any{argsClone, blockNrOrHash, overrides} + }, + // is this a security risk? + tryAll: true, + }, "eth_estimateGas", args, blockNrOrHash, overrides) + if resp == nil { + return 0, err + } + return *resp, err +} + +func populateFrom(acct *GWAccount, args gethapi.TransactionArgs) gethapi.TransactionArgs { + // clone the args + argsClone := cloneArgs(args) + // set the from + if args.From == nil || args.From.Hex() == (common.Address{}).Hex() { + argsClone.From = acct.address + } + return argsClone +} + +func cloneArgs(args gethapi.TransactionArgs) gethapi.TransactionArgs { + serialised, _ := json.Marshal(args) + var argsClone gethapi.TransactionArgs + err := json.Unmarshal(serialised, &argsClone) + if err != nil { + return gethapi.TransactionArgs{} + } + return argsClone +} + +type accessListResult struct { + Accesslist *types.AccessList `json:"accessList"` + Error string `json:"error,omitempty"` + GasUsed hexutil.Uint64 `json:"gasUsed"` +} + +func (s *BlockChainAPI) CreateAccessList(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash) (*accessListResult, error) { + return nil, rpcNotImplemented +} diff --git a/tools/walletextension/rpcapi/debug_api.go b/tools/walletextension/rpcapi/debug_api.go new file mode 100644 index 0000000000..fc0452611d --- /dev/null +++ b/tools/walletextension/rpcapi/debug_api.go @@ -0,0 +1,55 @@ +package rpcapi + +import ( + "context" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +type DebugAPI struct { + we *Services +} + +func NewDebugAPI(we *Services) *DebugAPI { + return &DebugAPI{we} +} + +func (api *DebugAPI) GetRawHeader(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + return nil, rpcNotImplemented +} + +func (api *DebugAPI) GetRawBlock(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + return nil, rpcNotImplemented +} + +func (api *DebugAPI) GetRawReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]hexutil.Bytes, error) { + return nil, rpcNotImplemented +} + +func (s *DebugAPI) GetRawTransaction(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { + return nil, rpcNotImplemented +} + +func (api *DebugAPI) PrintBlock(ctx context.Context, number uint64) (string, error) { + return "", rpcNotImplemented +} + +func (api *DebugAPI) ChaindbProperty(property string) (string, error) { + return "", rpcNotImplemented +} + +func (api *DebugAPI) ChaindbCompact() error { + return rpcNotImplemented +} + +func (api *DebugAPI) SetHead(number hexutil.Uint64) { + // not implemented +} + +// EventLogRelevancy - specific to Ten - todo +func (api *DebugAPI) EventLogRelevancy(_ context.Context, _ common.Hash) (interface{}, error) { + // todo + return nil, rpcNotImplemented +} diff --git a/tools/walletextension/subscriptions/deduplication_circular_buffer.go b/tools/walletextension/rpcapi/deduplication_circular_buffer.go similarity index 98% rename from tools/walletextension/subscriptions/deduplication_circular_buffer.go rename to tools/walletextension/rpcapi/deduplication_circular_buffer.go index 52a2e23cfe..89b2f2614c 100644 --- a/tools/walletextension/subscriptions/deduplication_circular_buffer.go +++ b/tools/walletextension/rpcapi/deduplication_circular_buffer.go @@ -1,4 +1,4 @@ -package subscriptions +package rpcapi import "github.com/ethereum/go-ethereum/common" diff --git a/tools/walletextension/rpcapi/ethereum_api.go b/tools/walletextension/rpcapi/ethereum_api.go new file mode 100644 index 0000000000..742c5c71c9 --- /dev/null +++ b/tools/walletextension/rpcapi/ethereum_api.go @@ -0,0 +1,52 @@ +package rpcapi + +import ( + "context" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +type EthereumAPI struct { + we *Services +} + +func NewEthereumAPI(we *Services, +) *EthereumAPI { + return &EthereumAPI{we} +} + +func (api *EthereumAPI) GasPrice(ctx context.Context) (*hexutil.Big, error) { + return UnauthenticatedTenRPCCall[hexutil.Big](ctx, api.we, &CacheCfg{TTL: shortCacheTTL}, "eth_gasPrice") +} + +func (api *EthereumAPI) MaxPriorityFeePerGas(ctx context.Context) (*hexutil.Big, error) { + return UnauthenticatedTenRPCCall[hexutil.Big](ctx, api.we, &CacheCfg{TTL: shortCacheTTL}, "eth_maxPriorityFeePerGas") +} + +type FeeHistoryResult struct { + OldestBlock *hexutil.Big `json:"oldestBlock"` + Reward [][]*hexutil.Big `json:"reward,omitempty"` + BaseFee []*hexutil.Big `json:"baseFeePerGas,omitempty"` + GasUsedRatio []float64 `json:"gasUsedRatio"` +} + +func (api *EthereumAPI) FeeHistory(ctx context.Context, blockCount math.HexOrDecimal64, lastBlock rpc.BlockNumber, rewardPercentiles []float64) (*FeeHistoryResult, error) { + return UnauthenticatedTenRPCCall[FeeHistoryResult]( + ctx, + api.we, + &CacheCfg{TTLCallback: func() time.Duration { + return cacheTTLBlockNumber(lastBlock) + }}, + "eth_feeHistory", + blockCount, + lastBlock, + rewardPercentiles, + ) +} + +func (api *EthereumAPI) Syncing() (interface{}, error) { + return nil, rpcNotImplemented +} diff --git a/tools/walletextension/rpcapi/filter_api.go b/tools/walletextension/rpcapi/filter_api.go new file mode 100644 index 0000000000..247d235b05 --- /dev/null +++ b/tools/walletextension/rpcapi/filter_api.go @@ -0,0 +1,248 @@ +package rpcapi + +import ( + "context" + "fmt" + "reflect" + "sync/atomic" + "time" + + pool "github.com/jolestar/go-commons-pool/v2" + tenrpc "github.com/ten-protocol/go-ten/go/rpc" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ten-protocol/go-ten/go/common" + + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +type FilterAPI struct { + we *Services +} + +func NewFilterAPI(we *Services) *FilterAPI { + return &FilterAPI{we: we} +} + +func (api *FilterAPI) NewPendingTransactionFilter(_ *bool) rpc.ID { + return "not supported" +} + +func (api *FilterAPI) NewPendingTransactions(ctx context.Context, fullTx *bool) (*rpc.Subscription, error) { + return nil, fmt.Errorf("not supported") +} + +func (api *FilterAPI) NewBlockFilter() rpc.ID { + // not implemented + return "" +} + +func (api *FilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, error) { + return nil, rpcNotImplemented +} + +func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rpc.Subscription, error) { + audit(api.we, "start Logs subscription %v", crit) + subNotifier, user, err := getUserAndNotifier(ctx, api) + if err != nil { + return nil, err + } + + // determine the accounts to use for the backend subscriptions + candidateAddresses := user.GetAllAddresses() + if len(candidateAddresses) > 1 { + candidateAddresses = searchForAddressInFilterCriteria(crit, user.GetAllAddresses()) + // when we can't determine which addresses to use based on the criteria, use all of them + if len(candidateAddresses) == 0 { + candidateAddresses = user.GetAllAddresses() + } + } + + inputChannels := make([]chan common.IDAndLog, 0) + backendSubscriptions := make([]*rpc.ClientSubscription, 0) + connections := make([]*tenrpc.EncRPCClient, 0) + for _, address := range candidateAddresses { + rpcWSClient, err := connectWS(user.accounts[*address], api.we.Logger()) + if err != nil { + return nil, err + } + connections = append(connections, rpcWSClient) + + inCh := make(chan common.IDAndLog) + backendSubscription, err := rpcWSClient.Subscribe(ctx, nil, "eth", inCh, "logs", crit) + if err != nil { + fmt.Printf("could not connect to backend %s", err) + return nil, err + } + + inputChannels = append(inputChannels, inCh) + backendSubscriptions = append(backendSubscriptions, backendSubscription) + } + + dedupeBuffer := NewCircularBuffer(wecommon.DeduplicationBufferSize) + subscription := subNotifier.CreateSubscription() + + unsubscribed := atomic.Bool{} + go forwardAndDedupe(inputChannels, backendSubscriptions, subscription, subNotifier, &unsubscribed, func(data common.IDAndLog) *types.Log { + uniqueLogKey := LogKey{ + BlockHash: data.Log.BlockHash, + TxHash: data.Log.TxHash, + Index: data.Log.Index, + } + + if !dedupeBuffer.Contains(uniqueLogKey) { + dedupeBuffer.Push(uniqueLogKey) + return data.Log + } + return nil + }) + + go handleUnsubscribe(subscription, backendSubscriptions, connections, api.we.rpcWSConnPool, &unsubscribed) + + return subscription, err +} + +func getUserAndNotifier(ctx context.Context, api *FilterAPI) (*rpc.Notifier, *GWUser, error) { + subNotifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, nil, fmt.Errorf("creation of subscriptions is not supported") + } + + // todo - we might want to allow access to public logs + if len(subNotifier.UserID) == 0 { + return nil, nil, fmt.Errorf("illegal access") + } + + user, err := getUser(subNotifier.UserID, api.we) + if err != nil { + return nil, nil, fmt.Errorf("illegal access: %s, %w", subNotifier.UserID, err) + } + return subNotifier, user, nil +} + +func searchForAddressInFilterCriteria(filterCriteria common.FilterCriteria, possibleAddresses []*gethcommon.Address) []*gethcommon.Address { + result := make([]*gethcommon.Address, 0) + addrMap := toMap(possibleAddresses) + for _, topicCondition := range filterCriteria.Topics { + for _, topic := range topicCondition { + potentialAddr := common.ExtractPotentialAddress(topic) + if potentialAddr != nil && addrMap[*potentialAddr] != nil { + result = append(result, potentialAddr) + } + } + } + return result +} + +// forwardAndDedupe - reads messages from the input channels, and forwards them to the notifier only if they are new +func forwardAndDedupe[R any, T any](inputChannels []chan R, _ []*rpc.ClientSubscription, outSub *rpc.Subscription, notifier *rpc.Notifier, unsubscribed *atomic.Bool, toForward func(elem R) *T) { + inputCases := make([]reflect.SelectCase, len(inputChannels)+1) + + // create a ticker to handle cleanup + inputCases[0] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(time.NewTicker(10 * time.Second).C), + } + + // create a select "case" for each input channel + for i, ch := range inputChannels { + inputCases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} + } + + unclosedInputChannels := len(inputCases) + for unclosedInputChannels > 0 { + chosen, value, ok := reflect.Select(inputCases) + if !ok { + // The chosen channel has been closed, so zero out the channel to disable the case + inputCases[chosen].Chan = reflect.ValueOf(nil) + unclosedInputChannels-- + continue + } + + switch v := value.Interface().(type) { + case time.Time: + // exit the loop to avoid a goroutine loop + if unsubscribed.Load() { + return + } + case R: + valueToSubmit := toForward(v) + if valueToSubmit != nil { + err := notifier.Notify(outSub.ID, *valueToSubmit) + if err != nil { + return + } + } + default: + // unexpected element received + continue + } + } +} + +func handleUnsubscribe(connectionSub *rpc.Subscription, backendSubscriptions []*rpc.ClientSubscription, connections []*tenrpc.EncRPCClient, p *pool.ObjectPool, unsubscribed *atomic.Bool) { + <-connectionSub.Err() + unsubscribed.Store(true) + for _, backendSub := range backendSubscriptions { + backendSub.Unsubscribe() + } + for _, connection := range connections { + _ = returnConn(p, connection) + } +} + +func (api *FilterAPI) NewFilter(crit common.FilterCriteria) (rpc.ID, error) { + return rpc.NewID(), rpcNotImplemented +} + +func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ([]*types.Log, error) { + logs, err := ExecAuthRPC[[]*types.Log]( + ctx, + api.we, + &ExecCfg{ + cacheCfg: &CacheCfg{ + TTLCallback: func() time.Duration { + // when the toBlock is not specified, the request is open-ended + if crit.ToBlock != nil && crit.ToBlock.Int64() > 0 { + return longCacheTTL + } + return shortCacheTTL + }, + }, + tryUntilAuthorised: true, + adjustArgs: func(acct *GWAccount) []any { + // convert to something serializable + return []any{common.FromCriteria(crit)} + }, + }, + "eth_getLogs", + crit, + ) + if logs != nil { + return *logs, err + } + return nil, err +} + +func (api *FilterAPI) UninstallFilter(id rpc.ID) bool { + // not implemented + return false +} + +func (api *FilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*types.Log, error) { + //txRec, err := ExecAuthRPC[[]*types.Log](ctx, api.we, "GetFilterLogs", ExecCfg{account: args.From}, id) + //if txRec != nil { + // return *txRec, err + //} + //return common.Hash{}, err + + // not implemented + return nil, nil +} + +func (api *FilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) { + return nil, rpcNotImplemented +} diff --git a/tools/walletextension/rpcapi/from_tx_args.go b/tools/walletextension/rpcapi/from_tx_args.go new file mode 100644 index 0000000000..996e0b667d --- /dev/null +++ b/tools/walletextension/rpcapi/from_tx_args.go @@ -0,0 +1,66 @@ +package rpcapi + +import ( + "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/status-im/keycard-go/hexutils" + "github.com/ten-protocol/go-ten/go/common/gethapi" +) + +func searchFromAndData(possibleAddresses []*common.Address, args gethapi.TransactionArgs) *common.Address { + if args.From != nil { + return args.From + } + + if args.Data == nil { + return nil + } + + // since the "from" field is not mandatory, we try to find a matching address in the data field + addressesMap := toMap(possibleAddresses) + return searchDataFieldForAccount(addressesMap, *args.Data) +} + +func searchDataFieldForAccount(addressesMap map[common.Address]*common.Address, data []byte) *common.Address { + hexEncodedData := hexutils.BytesToHex(data) + + // We check that the data field is long enough before removing the leading "0x" (1 bytes/2 chars) and the method ID + // (4 bytes/8 chars). + if len(hexEncodedData) < 10 { + return nil + } + hexEncodedData = hexEncodedData[10:] + + // We split up the arguments in the `data` field. + var dataArgs []string + for i := 0; i < len(hexEncodedData); i += ethCallPaddedArgLen { + if i+ethCallPaddedArgLen > len(hexEncodedData) { + break + } + dataArgs = append(dataArgs, hexEncodedData[i:i+ethCallPaddedArgLen]) + } + + // We iterate over the arguments, looking for an argument that matches a viewing key address + for _, dataArg := range dataArgs { + // If the argument doesn't have the correct padding, it's not an address. + if !strings.HasPrefix(dataArg, ethCallAddrPadding) { + continue + } + + maybeAddress := common.HexToAddress(dataArg[len(ethCallAddrPadding):]) + if _, ok := addressesMap[maybeAddress]; ok { + return &maybeAddress + } + } + + return nil +} + +func toMap(possibleAddresses []*common.Address) map[common.Address]*common.Address { + addresses := map[common.Address]*common.Address{} + for i := range possibleAddresses { + addresses[*possibleAddresses[i]] = possibleAddresses[i] + } + return addresses +} diff --git a/tools/walletextension/rpcapi/gw_user.go b/tools/walletextension/rpcapi/gw_user.go new file mode 100644 index 0000000000..abd65df0fa --- /dev/null +++ b/tools/walletextension/rpcapi/gw_user.go @@ -0,0 +1,62 @@ +package rpcapi + +import ( + "fmt" + + "github.com/status-im/keycard-go/hexutils" + "github.com/ten-protocol/go-ten/go/common/viewingkey" + + "github.com/ethereum/go-ethereum/common" +) + +var userCacheKeyPrefix = []byte{0x0, 0x1, 0x2, 0x3} + +type GWAccount struct { + user *GWUser + address *common.Address + signature []byte + signatureType viewingkey.SignatureType +} + +type GWUser struct { + userID []byte + services *Services + accounts map[common.Address]*GWAccount + userKey []byte +} + +func (u GWUser) GetAllAddresses() []*common.Address { + accts := make([]*common.Address, 0) + for _, acc := range u.accounts { + accts = append(accts, acc.address) + } + return accts +} + +func userCacheKey(userID []byte) []byte { + var key []byte + key = append(key, userCacheKeyPrefix...) + key = append(key, userID...) + return key +} + +func getUser(userID []byte, s *Services) (*GWUser, error) { + return withCache(s.Cache, &CacheCfg{TTL: longCacheTTL}, userCacheKey(userID), func() (*GWUser, error) { + result := GWUser{userID: userID, services: s, accounts: map[common.Address]*GWAccount{}} + userPrivateKey, err := s.Storage.GetUserPrivateKey(userID) + if err != nil { + return nil, fmt.Errorf("user %s not found. %w", hexutils.BytesToHex(userID), err) + } + result.userKey = userPrivateKey + allAccounts, err := s.Storage.GetAccounts(userID) + if err != nil { + return nil, err + } + + for _, account := range allAccounts { + address := common.BytesToAddress(account.AccountAddress) + result.accounts[address] = &GWAccount{user: &result, address: &address, signature: account.Signature, signatureType: viewingkey.SignatureType(uint8(account.SignatureType))} + } + return &result, nil + }) +} diff --git a/tools/walletextension/rpcapi/transaction_api.go b/tools/walletextension/rpcapi/transaction_api.go new file mode 100644 index 0000000000..537bdbc8f4 --- /dev/null +++ b/tools/walletextension/rpcapi/transaction_api.go @@ -0,0 +1,145 @@ +package rpcapi + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ten-protocol/go-ten/go/common/gethapi" + "github.com/ten-protocol/go-ten/go/enclave/rpc" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +type TransactionAPI struct { + we *Services +} + +func NewTransactionAPI(we *Services) *TransactionAPI { + return &TransactionAPI{we} +} + +func (s *TransactionAPI) GetBlockTransactionCountByNumber(ctx context.Context, blockNr gethrpc.BlockNumber) *hexutil.Uint { + count, err := UnauthenticatedTenRPCCall[hexutil.Uint](ctx, s.we, &CacheCfg{TTLCallback: func() time.Duration { + return cacheTTLBlockNumber(blockNr) + }}, "eth_getBlockTransactionCountByNumber", blockNr) + if err != nil { + return nil + } + return count +} + +func (s *TransactionAPI) GetBlockTransactionCountByHash(ctx context.Context, blockHash common.Hash) *hexutil.Uint { + count, err := UnauthenticatedTenRPCCall[hexutil.Uint](ctx, s.we, &CacheCfg{TTL: longCacheTTL}, "eth_getBlockTransactionCountByHash", blockHash) + if err != nil { + return nil + } + return count +} + +func (s *TransactionAPI) GetTransactionByBlockNumberAndIndex(ctx context.Context, blockNr gethrpc.BlockNumber, index hexutil.Uint) *rpc.RpcTransaction { + // not implemented + return nil +} + +func (s *TransactionAPI) GetTransactionByBlockHashAndIndex(ctx context.Context, blockHash common.Hash, index hexutil.Uint) *rpc.RpcTransaction { + // not implemented + return nil +} + +func (s *TransactionAPI) GetRawTransactionByBlockNumberAndIndex(ctx context.Context, blockNr gethrpc.BlockNumber, index hexutil.Uint) hexutil.Bytes { + // not implemented + return nil +} + +func (s *TransactionAPI) GetRawTransactionByBlockHashAndIndex(ctx context.Context, blockHash common.Hash, index hexutil.Uint) hexutil.Bytes { + // not implemented + return nil +} + +func (s *TransactionAPI) GetTransactionCount(ctx context.Context, address common.Address, blockNrOrHash gethrpc.BlockNumberOrHash) (*hexutil.Uint64, error) { + return ExecAuthRPC[hexutil.Uint64](ctx, s.we, &ExecCfg{account: &address}, "eth_getTransactionCount", address, blockNrOrHash) +} + +func (s *TransactionAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) (*rpc.RpcTransaction, error) { + return ExecAuthRPC[rpc.RpcTransaction](ctx, s.we, &ExecCfg{tryAll: true}, "eth_getTransactionByHash", hash) +} + +func (s *TransactionAPI) GetRawTransactionByHash(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { + tx, err := ExecAuthRPC[hexutil.Bytes](ctx, s.we, &ExecCfg{tryAll: true}, "eth_getRawTransactionByHash", hash) + if tx != nil { + return *tx, err + } + return nil, err +} + +func (s *TransactionAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { + txRec, err := ExecAuthRPC[map[string]interface{}](ctx, s.we, &ExecCfg{tryUntilAuthorised: true}, "eth_getTransactionReceipt", hash) + if err != nil { + return nil, err + } + if txRec == nil { + return nil, err + } + return *txRec, err +} + +func (s *TransactionAPI) SendTransaction(ctx context.Context, args gethapi.TransactionArgs) (common.Hash, error) { + txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{account: args.From}, "eth_sendTransaction", args) + if err != nil { + return common.Hash{}, err + } + userIDBytes, _ := extractUserID(ctx, s.we) + if s.we.Config.StoreIncomingTxs && len(userIDBytes) > 10 { + tx, err := json.Marshal(args) + if err != nil { + s.we.Logger().Error("error marshalling transaction: %s", err) + return *txRec, nil + } + err = s.we.Storage.StoreTransaction(string(tx), userIDBytes) + if err != nil { + s.we.Logger().Error("error storing transaction in the database: %s", err) + return *txRec, nil + } + } + return *txRec, err +} + +type SignTransactionResult struct { + Raw hexutil.Bytes `json:"raw"` + Tx *types.Transaction `json:"tx"` +} + +func (s *TransactionAPI) FillTransaction(ctx context.Context, args gethapi.TransactionArgs) (*SignTransactionResult, error) { + return nil, rpcNotImplemented +} + +func (s *TransactionAPI) SendRawTransaction(ctx context.Context, input hexutil.Bytes) (common.Hash, error) { + txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{tryAll: true}, "eth_sendRawTransaction", input) + if err != nil { + return common.Hash{}, err + } + userIDBytes, err := extractUserID(ctx, s.we) + if s.we.Config.StoreIncomingTxs && len(userIDBytes) > 10 { + err = s.we.Storage.StoreTransaction(input.String(), userIDBytes) + if err != nil { + s.we.Logger().Error(fmt.Errorf("error storing transaction in the database: %w", err).Error()) + } + } + return *txRec, err +} + +func (s *TransactionAPI) PendingTransactions() ([]*rpc.RpcTransaction, error) { + return nil, rpcNotImplemented +} + +func (s *TransactionAPI) Resend(ctx context.Context, sendArgs gethapi.TransactionArgs, gasPrice *hexutil.Big, gasLimit *hexutil.Uint64) (common.Hash, error) { + txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &ExecCfg{account: sendArgs.From}, "eth_resend", sendArgs, gasPrice, gasLimit) + if txRec != nil { + return *txRec, err + } + return common.Hash{}, err +} diff --git a/tools/walletextension/rpcapi/txpool_api.go b/tools/walletextension/rpcapi/txpool_api.go new file mode 100644 index 0000000000..16ec3f3b76 --- /dev/null +++ b/tools/walletextension/rpcapi/txpool_api.go @@ -0,0 +1,35 @@ +package rpcapi + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + rpc2 "github.com/ten-protocol/go-ten/go/enclave/rpc" +) + +type TxPoolAPI struct { + we *Services +} + +func NewTxPoolAPI(we *Services) *TxPoolAPI { + return &TxPoolAPI{we} +} + +func (s *TxPoolAPI) Content() map[string]map[string]map[string]*rpc2.RpcTransaction { + // not implemented + return nil +} + +func (s *TxPoolAPI) ContentFrom(_ common.Address) map[string]map[string]*rpc2.RpcTransaction { + // not implemented + return nil +} + +func (s *TxPoolAPI) Status() map[string]hexutil.Uint { + // not implemented + return nil +} + +func (s *TxPoolAPI) Inspect() map[string]map[string]map[string]string { + // not implemented + return nil +} diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go new file mode 100644 index 0000000000..a0f24c1c74 --- /dev/null +++ b/tools/walletextension/rpcapi/utils.go @@ -0,0 +1,273 @@ +package rpcapi + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "time" + + gethlog "github.com/ethereum/go-ethereum/log" + pool "github.com/jolestar/go-commons-pool/v2" + tenrpc "github.com/ten-protocol/go-ten/go/rpc" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + + "github.com/ten-protocol/go-ten/go/common/viewingkey" + + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + + "github.com/status-im/keycard-go/hexutils" + + "github.com/ten-protocol/go-ten/tools/walletextension/cache" + + "github.com/ethereum/go-ethereum/common" +) + +const ( + ethCallPaddedArgLen = 64 + ethCallAddrPadding = "000000000000000000000000" + + notAuthorised = "not authorised" + + longCacheTTL = 5 * time.Hour + shortCacheTTL = 100 * time.Millisecond +) + +var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") + +type ExecCfg struct { + account *common.Address + computeFromCallback func(user *GWUser) *common.Address + tryAll bool + tryUntilAuthorised bool + adjustArgs func(acct *GWAccount) []any + cacheCfg *CacheCfg +} + +type CacheCfg struct { + // ResetWhenNewBlock bool todo + TTL time.Duration + // logic based on block + // todo - handle block in the future + TTLCallback func() time.Duration +} + +func UnauthenticatedTenRPCCall[R any](ctx context.Context, w *Services, cfg *CacheCfg, method string, args ...any) (*R, error) { + audit(w, "RPC start method=%s args=%v", method, args) + requestStartTime := time.Now() + cacheArgs := []any{method} + cacheArgs = append(cacheArgs, args...) + + res, err := withCache(w.Cache, cfg, generateCacheKey(cacheArgs), func() (*R, error) { + var resp *R + unauthedRPC, err := w.UnauthenticatedClient() + if err != nil { + return nil, err + } + if ctx == nil { + err = unauthedRPC.Call(&resp, method, args...) + } else { + err = unauthedRPC.CallContext(ctx, &resp, method, args...) + } + return resp, err + }) + audit(w, "RPC call. method=%s args=%v result=%s error=%s time=%d", method, args, res, err, time.Since(requestStartTime).Milliseconds()) + return res, err +} + +func ExecAuthRPC[R any](ctx context.Context, w *Services, cfg *ExecCfg, method string, args ...any) (*R, error) { + audit(w, "RPC start method=%s args=%v", method, args) + requestStartTime := time.Now() + userID, err := extractUserID(ctx, w) + if err != nil { + return nil, err + } + + user, err := getUser(userID, w) + if err != nil { + return nil, err + } + + cacheArgs := []any{userID, method} + cacheArgs = append(cacheArgs, args...) + + res, err := withCache(w.Cache, cfg.cacheCfg, generateCacheKey(cacheArgs), func() (*R, error) { + // determine candidate "from" + candidateAccts, err := getCandidateAccounts(user, w, cfg) + if err != nil { + return nil, err + } + if len(candidateAccts) == 0 { + return nil, fmt.Errorf("illegal access") + } + + var rpcErr error + for i := range candidateAccts { + acct := candidateAccts[i] + result, err := withHTTPRPCConnection(w, acct, func(rpcClient *tenrpc.EncRPCClient) (*R, error) { + var result *R + adjustedArgs := args + if cfg.adjustArgs != nil { + adjustedArgs = cfg.adjustArgs(acct) + } + err := rpcClient.CallContext(ctx, &result, method, adjustedArgs...) + return result, err + }) + if err != nil { + // for calls where we know the expected error we can return early + if cfg.tryUntilAuthorised && err.Error() != notAuthorised { + return nil, err + } + rpcErr = err + continue + } + return result, nil + } + return nil, rpcErr + }) + + audit(w, "RPC call. uid=%s, method=%s args=%v result=%s error=%s time=%d", hexutils.BytesToHex(userID), method, args, res, err, time.Since(requestStartTime).Milliseconds()) + return res, err +} + +func getCandidateAccounts(user *GWUser, _ *Services, cfg *ExecCfg) ([]*GWAccount, error) { + candidateAccts := make([]*GWAccount, 0) + // for users with multiple accounts try to determine a candidate account based on the available information + switch { + case cfg.account != nil: + acc := user.accounts[*cfg.account] + if acc != nil { + candidateAccts = append(candidateAccts, acc) + return candidateAccts, nil + } + + case cfg.computeFromCallback != nil: + suggestedAddress := cfg.computeFromCallback(user) + if suggestedAddress != nil { + acc := user.accounts[*suggestedAddress] + if acc != nil { + candidateAccts = append(candidateAccts, acc) + return candidateAccts, nil + } + } + } + + if cfg.tryAll || cfg.tryUntilAuthorised { + for _, acc := range user.accounts { + candidateAccts = append(candidateAccts, acc) + } + } + + return candidateAccts, nil +} + +func extractUserID(ctx context.Context, _ *Services) ([]byte, error) { + token, ok := ctx.Value(rpc.GWTokenKey{}).(string) + if !ok { + return nil, fmt.Errorf("invalid userid: %s", ctx.Value(rpc.GWTokenKey{})) + } + userID := common.FromHex(token) + if len(userID) != viewingkey.UserIDLength { + return nil, fmt.Errorf("invalid userid: %s", token) + } + return userID, nil +} + +// generateCacheKey generates a cache key for the given method, encryptionToken and parameters +// encryptionToken is used to generate a unique cache key for each user and empty string should be used for public data +func generateCacheKey(params []any) []byte { + // Serialize parameters + rawKey, err := json.Marshal(params) + if err != nil { + return nil + } + + // Optional: Apply hashing + hasher := sha256.New() + hasher.Write(rawKey) + + return hasher.Sum(nil) +} + +func withCache[R any](cache cache.Cache, cfg *CacheCfg, cacheKey []byte, onCacheMiss func() (*R, error)) (*R, error) { + if cfg == nil { + return onCacheMiss() + } + + cacheTTL := cfg.TTL + if cfg.TTLCallback != nil { + cacheTTL = cfg.TTLCallback() + } + isCacheable := cacheTTL > 0 + + if isCacheable { + if cachedValue, ok := cache.Get(cacheKey); ok { + // cloning? + returnValue, ok := cachedValue.(*R) + if !ok { + return nil, fmt.Errorf("unexpected error. Invalid format cached. %v", cachedValue) + } + return returnValue, nil + } + } + + result, err := onCacheMiss() + + // cache only non-nil values + if isCacheable && err == nil && result != nil { + cache.Set(cacheKey, result, cacheTTL) + } + + return result, err +} + +func audit(services *Services, msg string, params ...any) { + if services.Config.VerboseFlag { + services.FileLogger.Info(fmt.Sprintf(msg, params...)) + } +} + +func cacheTTLBlockNumberOrHash(blockNrOrHash rpc.BlockNumberOrHash) time.Duration { + if blockNrOrHash.BlockNumber != nil && blockNrOrHash.BlockNumber.Int64() <= 0 { + return shortCacheTTL + } + return longCacheTTL +} + +func cacheTTLBlockNumber(lastBlock rpc.BlockNumber) time.Duration { + if lastBlock > 0 { + return longCacheTTL + } + return shortCacheTTL +} + +func connectWS(account *GWAccount, logger gethlog.Logger) (*tenrpc.EncRPCClient, error) { + return conn(account.user.services.rpcWSConnPool, account, logger) +} + +func conn(p *pool.ObjectPool, account *GWAccount, logger gethlog.Logger) (*tenrpc.EncRPCClient, error) { + connectionObj, err := p.BorrowObject(context.Background()) + if err != nil { + return nil, fmt.Errorf("cannot fetch rpc connection to backend node %w", err) + } + conn := connectionObj.(*rpc.Client) + encClient, err := wecommon.CreateEncClient(conn, account.address.Bytes(), account.user.userKey, account.signature, account.signatureType, logger) + if err != nil { + return nil, fmt.Errorf("error creating new client, %w", err) + } + return encClient, nil +} + +func returnConn(p *pool.ObjectPool, conn *tenrpc.EncRPCClient) error { + c := conn.Client().(*tenrpc.NetworkClient).RpcClient + return p.ReturnObject(context.Background(), c) +} + +func withHTTPRPCConnection[R any](w *Services, acct *GWAccount, execute func(*tenrpc.EncRPCClient) (*R, error)) (*R, error) { + rpcClient, err := conn(acct.user.services.rpcHTTPConnPool, acct, w.logger) + if err != nil { + return nil, fmt.Errorf("could not connect to backed. Cause: %w", err) + } + defer returnConn(w.rpcHTTPConnPool, rpcClient) + return execute(rpcClient) +} diff --git a/tools/walletextension/rpcapi/wallet_extension.go b/tools/walletextension/rpcapi/wallet_extension.go new file mode 100644 index 0000000000..3b27bbe750 --- /dev/null +++ b/tools/walletextension/rpcapi/wallet_extension.go @@ -0,0 +1,247 @@ +package rpcapi + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "time" + + pool "github.com/jolestar/go-commons-pool/v2" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + + "github.com/status-im/keycard-go/hexutils" + + "github.com/ten-protocol/go-ten/tools/walletextension/cache" + + "github.com/ten-protocol/go-ten/go/obsclient" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/stopcontrol" + "github.com/ten-protocol/go-ten/go/common/viewingkey" + "github.com/ten-protocol/go-ten/go/rpc" + "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage" +) + +// Services handles the various business logic for the api endpoints +type Services struct { + HostAddrHTTP string // The HTTP address on which the Ten host can be reached + HostAddrWS string // The WS address on which the Ten host can be reached + Storage storage.Storage + logger gethlog.Logger + FileLogger gethlog.Logger + stopControl *stopcontrol.StopControl + version string + tenClient *obsclient.ObsClient + Cache cache.Cache + // the OG maintains a connection pool of rpc connections to underlying nodes + rpcHTTPConnPool *pool.ObjectPool + rpcWSConnPool *pool.ObjectPool + Config *common.Config +} + +func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, config *common.Config) *Services { + rpcClient, err := rpc.NewNetworkClient(hostAddrHTTP) + if err != nil { + logger.Error(fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrHTTP, err).Error()) + panic(err) + } + newTenClient := obsclient.NewObsClient(rpcClient) + newFileLogger := common.NewFileLogger() + newGatewayCache, err := cache.NewCache(logger) + if err != nil { + logger.Error(fmt.Errorf("could not create cache. Cause: %w", err).Error()) + panic(err) + } + + factoryHTTP := pool.NewPooledObjectFactory( + func(context.Context) (interface{}, error) { + rpcClient, err := gethrpc.Dial(hostAddrHTTP) + if err != nil { + return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrHTTP, err) + } + return rpcClient, nil + }, func(ctx context.Context, object *pool.PooledObject) error { + client := object.Object.(*gethrpc.Client) + client.Close() + return nil + }, nil, nil, nil) + + factoryWS := pool.NewPooledObjectFactory( + func(context.Context) (interface{}, error) { + rpcClient, err := gethrpc.Dial(hostAddrWS) + if err != nil { + return nil, fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrWS, err) + } + return rpcClient, nil + }, func(ctx context.Context, object *pool.PooledObject) error { + client := object.Object.(*gethrpc.Client) + client.Close() + return nil + }, nil, nil, nil) + + cfg := pool.NewDefaultPoolConfig() + cfg.MaxTotal = 100 + cfg.MaxTotal = 50 + + return &Services{ + HostAddrHTTP: hostAddrHTTP, + HostAddrWS: hostAddrWS, + Storage: storage, + logger: logger, + FileLogger: newFileLogger, + stopControl: stopControl, + version: version, + tenClient: newTenClient, + Cache: newGatewayCache, + rpcHTTPConnPool: pool.NewObjectPool(context.Background(), factoryHTTP, cfg), + rpcWSConnPool: pool.NewObjectPool(context.Background(), factoryWS, cfg), + Config: config, + } +} + +// IsStopping returns whether the WE is stopping +func (w *Services) IsStopping() bool { + return w.stopControl.IsStopping() +} + +// Logger returns the WE set logger +func (w *Services) Logger() gethlog.Logger { + return w.logger +} + +// GenerateAndStoreNewUser generates new key-pair and userID, stores it in the database and returns hex encoded userID and error +func (w *Services) GenerateAndStoreNewUser() ([]byte, error) { + requestStartTime := time.Now() + // generate new key-pair + viewingKeyPrivate, err := crypto.GenerateKey() + viewingPrivateKeyEcies := ecies.ImportECDSA(viewingKeyPrivate) + if err != nil { + w.Logger().Error(fmt.Sprintf("could not generate new keypair: %s", err)) + return nil, err + } + + // create UserID and store it in the database with the private key + userID := viewingkey.CalculateUserID(common.PrivateKeyToCompressedPubKey(viewingPrivateKeyEcies)) + err = w.Storage.AddUser(userID, crypto.FromECDSA(viewingPrivateKeyEcies.ExportECDSA())) + if err != nil { + w.Logger().Error(fmt.Sprintf("failed to save user to the database: %s", err)) + return nil, err + } + + requestEndTime := time.Now() + duration := requestEndTime.Sub(requestStartTime) + audit(w, "Storing new userID: %s, duration: %d ", hexutils.BytesToHex(userID), duration.Milliseconds()) + return userID, nil +} + +// AddAddressToUser checks if a message is in correct format and if signature is valid. If all checks pass we save address and signature against userID +func (w *Services) AddAddressToUser(userID []byte, address string, signature []byte, signatureType viewingkey.SignatureType) error { + requestStartTime := time.Now() + addressFromMessage := gethcommon.HexToAddress(address) + // check if a message was signed by the correct address and if the signature is valid + recoveredAddress, err := viewingkey.CheckSignature(userID, signature, int64(w.Config.TenChainID), signatureType) + if err != nil { + return fmt.Errorf("signature is not valid: %w", err) + } + + if recoveredAddress.Hex() != addressFromMessage.Hex() { + return fmt.Errorf("invalid request. Signature doesn't match address") + } + + // register the account for that viewing key + err = w.Storage.AddAccount(userID, addressFromMessage.Bytes(), signature, signatureType) + if err != nil { + w.Logger().Error(fmt.Errorf("error while storing account (%s) for user (%s): %w", addressFromMessage.Hex(), userID, err).Error()) + return err + } + + w.Cache.Remove(userCacheKey(userID)) + audit(w, "Storing new address for user: %s, address: %s, duration: %d ", hexutils.BytesToHex(userID), address, time.Since(requestStartTime).Milliseconds()) + return nil +} + +// UserHasAccount checks if provided account exist in the database for given userID +func (w *Services) UserHasAccount(userID []byte, address string) (bool, error) { + audit(w, "Checking if user has account: %s, address: %s", hexutils.BytesToHex(userID), address) + addressBytes, err := hex.DecodeString(address[2:]) // remove 0x prefix from address + if err != nil { + w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", address[2:], err).Error()) + return false, err + } + + // todo - this can be optimised and done in the database if we will have users with large number of accounts + // get all the accounts for the selected user + accounts, err := w.Storage.GetAccounts(userID) + if err != nil { + w.Logger().Error(fmt.Errorf("error getting accounts for user (%s), %w", userID, err).Error()) + return false, err + } + + // check if any of the account matches given account + found := false + for _, account := range accounts { + if bytes.Equal(account.AccountAddress, addressBytes) { + found = true + } + } + return found, nil +} + +// DeleteUser deletes user and accounts associated with user from the database for given userID +func (w *Services) DeleteUser(userID []byte) error { + audit(w, "Deleting user: %s", hexutils.BytesToHex(userID)) + + err := w.Storage.DeleteUser(userID) + if err != nil { + w.Logger().Error(fmt.Errorf("error deleting user (%s), %w", userID, err).Error()) + return err + } + w.Cache.Remove(userCacheKey(userID)) + return nil +} + +func (w *Services) UserExists(userID []byte) bool { + audit(w, "Checking if user exists: %s", userID) + // Check if user exists and don't log error if user doesn't exist, because we expect this to happen in case of + // user revoking encryption token or using different testnet. + // todo add a counter here in the future + key, err := w.Storage.GetUserPrivateKey(userID) + if err != nil { + return false + } + + return len(key) > 0 +} + +func (w *Services) Version() string { + return w.version +} + +func (w *Services) GetTenNodeHealthStatus() (bool, error) { + return w.tenClient.Health() +} + +func (w *Services) UnauthenticatedClient() (rpc.Client, error) { + return rpc.NewNetworkClient(w.HostAddrHTTP) +} + +func (w *Services) GenerateUserMessageToSign(encryptionToken []byte, formatsSlice []string) (string, error) { + // Check if the formats are valid + for _, format := range formatsSlice { + if _, exists := viewingkey.SignatureTypeMap[format]; !exists { + return "", fmt.Errorf("invalid format: %s", format) + } + } + + messageFormat := viewingkey.GetBestFormat(formatsSlice) + message, err := viewingkey.GenerateMessage(encryptionToken, int64(w.Config.TenChainID), viewingkey.PersonalSignVersion, messageFormat) + if err != nil { + return "", fmt.Errorf("error generating message: %w", err) + } + return string(message), nil +} diff --git a/tools/walletextension/subscriptions/subscriptions.go b/tools/walletextension/subscriptions/subscriptions.go deleted file mode 100644 index 9405a524ce..0000000000 --- a/tools/walletextension/subscriptions/subscriptions.go +++ /dev/null @@ -1,200 +0,0 @@ -package subscriptions - -import ( - "context" - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/ethereum/go-ethereum/eth/filters" - - "github.com/go-kit/kit/transport/http/jsonrpc" - - gethlog "github.com/ethereum/go-ethereum/log" - "github.com/ten-protocol/go-ten/go/common" - "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/go/rpc" - gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" - wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/userconn" -) - -type SubscriptionManager struct { - subscriptionMappings map[string][]*gethrpc.ClientSubscription - logger gethlog.Logger - mu sync.Mutex -} - -func New(logger gethlog.Logger) *SubscriptionManager { - return &SubscriptionManager{ - subscriptionMappings: make(map[string][]*gethrpc.ClientSubscription), - logger: logger, - } -} - -// HandleNewSubscriptions subscribes to an event with all the clients provided. -// Doing this is necessary because we have relevancy rule, and we want to subscribe sometimes with all clients to get all the events -func (sm *SubscriptionManager) HandleNewSubscriptions(clients []rpc.Client, criteria filters.FilterCriteria, resp *interface{}, userConn userconn.UserConn) error { - sm.logger.Info(fmt.Sprintf("Subscribing to event %s with %d clients", criteria, len(clients))) - - // create subscriptionID which will enable user to unsubscribe from all subscriptions - userSubscriptionID := gethrpc.NewID() - - // create a common channel for subscriptions from all accounts - funnelMultipleAccountsChan := make(chan common.IDAndLog) - - // read from a multiple accounts channel and write results to userConn - go readFromChannelAndWriteToUserConn(funnelMultipleAccountsChan, userConn, userSubscriptionID, sm.logger) - - // iterate over all clients and subscribe for each of them - for _, client := range clients { - subscription, err := client.Subscribe(context.Background(), resp, rpc.SubscribeNamespace, funnelMultipleAccountsChan, rpc.SubscriptionTypeLogs, criteria) - if err != nil { - return fmt.Errorf("could not subscrbie for logs with params %v. Cause: %w", criteria, err) - } - sm.UpdateSubscriptionMapping(string(userSubscriptionID), subscription) - - // We periodically check if the websocket is closed, and terminate the subscription. - // TODO: Check if it will be much more efficient to create just one go routine for all clients together - go sm.checkIfUserConnIsClosedAndUnsubscribe(userConn, subscription, string(userSubscriptionID)) - } - - // We return subscriptionID with resp interface. We want to use userSubscriptionID to allow unsubscribing - *resp = userSubscriptionID - return nil -} - -func readFromChannelAndWriteToUserConn(channel chan common.IDAndLog, userConn userconn.UserConn, userSubscriptionID gethrpc.ID, logger gethlog.Logger) { - buffer := NewCircularBuffer(wecommon.DeduplicationBufferSize) - for data := range channel { - // create unique identifier for current log - uniqueLogKey := LogKey{ - BlockHash: data.Log.BlockHash, - TxHash: data.Log.TxHash, - Index: data.Log.Index, - } - - // check if the current event is a duplicate (and skip it if it is) - if buffer.Contains(uniqueLogKey) { - continue - } - - jsonResponse, err := prepareLogResponse(data, userSubscriptionID) - if err != nil { - logger.Error("could not marshal log response to JSON on subscription.", log.SubIDKey, data.SubID, log.ErrKey, err) - continue - } - - // the current log is unique, and we want to add it to our buffer and proceed with forwarding to the user - buffer.Push(uniqueLogKey) - - logger.Trace(fmt.Sprintf("Forwarding log from Obscuro node: %s", jsonResponse), log.SubIDKey, data.SubID) - err = userConn.WriteResponse(jsonResponse) - if err != nil { - logger.Error("could not write the JSON log to the websocket on subscription %", log.SubIDKey, data.SubID, log.ErrKey, err) - continue - } - } -} - -func (sm *SubscriptionManager) unsubscribeAndRemove(userSubscriptionID string, subscription *gethrpc.ClientSubscription) { - sm.mu.Lock() - defer sm.mu.Unlock() - - subscription.Unsubscribe() - - subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] - if !exists { - sm.logger.Error("subscription that needs to be removed is not present in subscriptionMappings for userSubscriptionID: %s", userSubscriptionID) - return - } - - for i, s := range subscriptions { - if s != subscription { - continue - } - - // Remove the subscription from the slice - lastIndex := len(subscriptions) - 1 - subscriptions[i] = subscriptions[lastIndex] - subscriptions = subscriptions[:lastIndex] - - // If the slice is empty, delete the key from the map - if len(subscriptions) == 0 { - delete(sm.subscriptionMappings, userSubscriptionID) - } else { - sm.subscriptionMappings[userSubscriptionID] = subscriptions - } - break - } -} - -func (sm *SubscriptionManager) checkIfUserConnIsClosedAndUnsubscribe(userConn userconn.UserConn, subscription *gethrpc.ClientSubscription, userSubscriptionID string) { - for !userConn.IsClosed() { - time.Sleep(100 * time.Millisecond) - } - - sm.unsubscribeAndRemove(userSubscriptionID, subscription) -} - -func (sm *SubscriptionManager) UpdateSubscriptionMapping(userSubscriptionID string, subscription *gethrpc.ClientSubscription) { - // Ensure there is no concurrent map writes - sm.mu.Lock() - defer sm.mu.Unlock() - - // Check if the userSubscriptionID already exists in the map - subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] - - // If it doesn't exist, create a new slice for it - if !exists { - subscriptions = []*gethrpc.ClientSubscription{} - } - - // Check if the subscription is already in the slice, if not, add it - subscriptionExists := false - for _, sub := range subscriptions { - if sub == subscription { - subscriptionExists = true - break - } - } - - if !subscriptionExists { - sm.subscriptionMappings[userSubscriptionID] = append(subscriptions, subscription) - } -} - -// Formats the log to be sent as an Eth JSON-RPC response. -func prepareLogResponse(idAndLog common.IDAndLog, userSubscriptionID gethrpc.ID) ([]byte, error) { - paramsMap := make(map[string]interface{}) - paramsMap[wecommon.JSONKeySubscription] = userSubscriptionID - paramsMap[wecommon.JSONKeyResult] = idAndLog.Log - - respMap := make(map[string]interface{}) - respMap[wecommon.JSONKeyRPCVersion] = jsonrpc.Version - respMap[wecommon.JSONKeyMethod] = wecommon.MethodEthSubscription - respMap[wecommon.JSONKeyParams] = paramsMap - - jsonResponse, err := json.Marshal(respMap) - if err != nil { - return nil, fmt.Errorf("could not marshal log response to JSON. Cause: %w", err) - } - return jsonResponse, nil -} - -func (sm *SubscriptionManager) HandleUnsubscribe(userSubscriptionID string, rpcResp *interface{}) { - subscriptions, exists := sm.subscriptionMappings[userSubscriptionID] - if !exists { - *rpcResp = false - return - } - - sm.mu.Lock() - defer sm.mu.Unlock() - for _, sub := range subscriptions { - sub.Unsubscribe() - } - delete(sm.subscriptionMappings, userSubscriptionID) - *rpcResp = true -} diff --git a/tools/walletextension/useraccountmanager/user_account_manager.go b/tools/walletextension/useraccountmanager/user_account_manager.go deleted file mode 100644 index 664c2cd994..0000000000 --- a/tools/walletextension/useraccountmanager/user_account_manager.go +++ /dev/null @@ -1,135 +0,0 @@ -package useraccountmanager - -import ( - "encoding/hex" - "fmt" - "sync" - - "github.com/ten-protocol/go-ten/go/common/viewingkey" - - "github.com/ethereum/go-ethereum/common" - gethlog "github.com/ethereum/go-ethereum/log" - "github.com/ten-protocol/go-ten/go/rpc" - "github.com/ten-protocol/go-ten/tools/walletextension/accountmanager" - wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/storage" -) - -// UserAccountManager is a struct that stores one account manager per user and other required data -type UserAccountManager struct { - userAccountManager map[string]*accountmanager.AccountManager - unauthenticatedClient rpc.Client - storage storage.Storage - hostRPCBinAddrHTTP string - hostRPCBinAddrWS string - logger gethlog.Logger - mu sync.Mutex -} - -func NewUserAccountManager(unauthenticatedClient rpc.Client, logger gethlog.Logger, storage storage.Storage, hostRPCBindAddrHTTP string, hostRPCBindAddrWS string) UserAccountManager { - return UserAccountManager{ - userAccountManager: make(map[string]*accountmanager.AccountManager), - unauthenticatedClient: unauthenticatedClient, - storage: storage, - hostRPCBinAddrHTTP: hostRPCBindAddrHTTP, - hostRPCBinAddrWS: hostRPCBindAddrWS, - logger: logger, - } -} - -// AddAndReturnAccountManager adds new UserAccountManager if it doesn't exist and returns it, if UserAccountManager already exists for that user just return it -func (m *UserAccountManager) AddAndReturnAccountManager(userID string) *accountmanager.AccountManager { - m.mu.Lock() - defer m.mu.Unlock() - existingUserAccountManager, exists := m.userAccountManager[userID] - if exists { - return existingUserAccountManager - } - newAccountManager := accountmanager.NewAccountManager(userID, m.unauthenticatedClient, m.hostRPCBinAddrWS, m.storage, m.logger) - m.userAccountManager[userID] = newAccountManager - return newAccountManager -} - -// GetUserAccountManager retrieves the UserAccountManager associated with the given userID. -// it returns the UserAccountManager and nil error if one exists. -// before returning it checks the database and creates all missing clients for that userID -// (we are not loading all of them at startup to limit the number of established connections) -// If a UserAccountManager does not exist for the userID, it returns nil and an error. -func (m *UserAccountManager) GetUserAccountManager(userID string) (*accountmanager.AccountManager, error) { - m.mu.Lock() - defer m.mu.Unlock() - userAccManager, exists := m.userAccountManager[userID] - if !exists { - return nil, fmt.Errorf("UserAccountManager doesn't exist for user: %s", userID) - } - - // we have userAccountManager as expected. - // now we need to create all clients that don't exist there yet - addressesWithClients := userAccManager.GetAllAddressesWithClients() - - // get all addresses for current userID - userIDbytes, err := hex.DecodeString(userID) - if err != nil { - return nil, err - } - - // log that we don't have a storage, but still return existing userAccountManager - // this should never happen, but is useful for tests - if m.storage == nil { - m.logger.Error("storage is nil in UserAccountManager") - return userAccManager, nil - } - - databaseAccounts, err := m.storage.GetAccounts(userIDbytes) - if err != nil { - return nil, err - } - - userPrivateKey, err := m.storage.GetUserPrivateKey(userIDbytes) - if err != nil { - return nil, err - } - - for _, account := range databaseAccounts { - addressHexString := common.BytesToAddress(account.AccountAddress).Hex() - // check if a client for the current address already exists (and skip it if it does) - if addressAlreadyExists(addressHexString, addressesWithClients) { - continue - } - - // create a new client - encClient, err := wecommon.CreateEncClient(m.hostRPCBinAddrWS, account.AccountAddress, userPrivateKey, account.Signature, viewingkey.SignatureType(account.SignatureType), m.logger) - if err != nil { - m.logger.Error(fmt.Errorf("error creating new client, %w", err).Error()) - } - - // add a client to requested userAccountManager - userAccManager.AddClient(common.BytesToAddress(account.AccountAddress), encClient) - addressesWithClients = append(addressesWithClients, addressHexString) - } - - return userAccManager, nil -} - -// DeleteUserAccountManager removes the UserAccountManager associated with the given userID. -// It returns an error if no UserAccountManager exists for that userID. -func (m *UserAccountManager) DeleteUserAccountManager(userID string) error { - m.mu.Lock() - defer m.mu.Unlock() - _, exists := m.userAccountManager[userID] - if !exists { - return fmt.Errorf("no UserAccountManager exists for userID %s", userID) - } - delete(m.userAccountManager, userID) - return nil -} - -// addressAlreadyExists is a helper function to check if an address is already present in a list of existing addresses -func addressAlreadyExists(str string, list []string) bool { - for _, v := range list { - if v == str { - return true - } - } - return false -} diff --git a/tools/walletextension/useraccountmanager/user_account_manager_test.go b/tools/walletextension/useraccountmanager/user_account_manager_test.go deleted file mode 100644 index 38c94ed85d..0000000000 --- a/tools/walletextension/useraccountmanager/user_account_manager_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package useraccountmanager - -import ( - "testing" - - "github.com/ethereum/go-ethereum/log" - "github.com/ten-protocol/go-ten/go/rpc" -) - -func TestAddingAndGettingUserAccountManagers(t *testing.T) { - unauthedClient, _ := rpc.NewNetworkClient("ws://test") - userAccountManager := NewUserAccountManager(unauthedClient, log.New(), nil, "http://test", "ws://test") - userID1 := "4A6F686E20446F65" - userID2 := "7A65746F65A2676F" - - // Test adding and getting account manager for userID1 - userAccountManager.AddAndReturnAccountManager(userID1) - accManager1, err := userAccountManager.GetUserAccountManager(userID1) - if err != nil { - t.Fatal(err) - } - // We should get error if we try to get Account manager for User2 - _, err = userAccountManager.GetUserAccountManager(userID2) - if err == nil { - t.Fatal("expecting error when trying to get AccountManager for user that doesn't exist.") - } - - // After trying to add new AccountManager for the same user we should get the same instance (not overriding old one) - userAccountManager.AddAndReturnAccountManager(userID1) - accManager1New, err := userAccountManager.GetUserAccountManager(userID1) - if err != nil { - t.Fatal(err) - } - - if accManager1 != accManager1New { - t.Fatal("AccountManagers are not the same after adding new account manager for the same userID") - } - - // We get a new instance of AccountManager when we add it for a new user - userAccountManager.AddAndReturnAccountManager(userID2) - accManager2, err := userAccountManager.GetUserAccountManager(userID2) - if err != nil { - t.Fatal(err) - } - - if accManager1 == accManager2 { - t.Fatal("AccountManagers are the same for two different users") - } -} - -func TestDeletingUserAccountManagers(t *testing.T) { - unauthedClient, _ := rpc.NewNetworkClient("ws://test") - userAccountManager := NewUserAccountManager(unauthedClient, log.New(), nil, "", "") - userID := "user1" - - // Add an account manager for the user - userAccountManager.AddAndReturnAccountManager(userID) - - // Test deleting user account manager - err := userAccountManager.DeleteUserAccountManager(userID) - if err != nil { - t.Fatal(err) - } - - // After deleting, we should get an error if we try to get the user's account manager - _, err = userAccountManager.GetUserAccountManager(userID) - if err == nil { - t.Fatal("expected an error after trying to get a deleted account manager") - } - - // Trying to delete an account manager that doesn't exist should return an error - err = userAccountManager.DeleteUserAccountManager("nonexistentUser") - if err == nil { - t.Fatal("expected an error after trying to delete an account manager that doesn't exist") - } -} diff --git a/tools/walletextension/userconn/user_conn.go b/tools/walletextension/userconn/user_conn.go deleted file mode 100644 index 4acb9ebf2f..0000000000 --- a/tools/walletextension/userconn/user_conn.go +++ /dev/null @@ -1,147 +0,0 @@ -package userconn - -import ( - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - - "github.com/ten-protocol/go-ten/go/common/log" - - gethlog "github.com/ethereum/go-ethereum/log" - - "github.com/gorilla/websocket" -) - -var upgrader = websocket.Upgrader{} // Used to upgrade connections to websocket connections. - -// UserConn represents a connection to a user. -type UserConn interface { - ReadRequest() ([]byte, error) - ReadRequestParams() map[string]string - WriteResponse(msg []byte) error - SupportsSubscriptions() bool - IsClosed() bool - GetHTTPRequest() *http.Request -} - -// Represents a user's connection over HTTP. -type userConnHTTP struct { - resp http.ResponseWriter - req *http.Request - logger gethlog.Logger -} - -// Represents a user's connection websockets. -type userConnWS struct { - conn *websocket.Conn - isClosed bool - logger gethlog.Logger - req *http.Request - mu sync.Mutex -} - -func NewUserConnHTTP(resp http.ResponseWriter, req *http.Request, logger gethlog.Logger) UserConn { - return &userConnHTTP{resp: resp, req: req, logger: logger} -} - -func NewUserConnWS(resp http.ResponseWriter, req *http.Request, logger gethlog.Logger) (UserConn, error) { - // We search all the request's headers. If there's a websocket upgrade header, we upgrade to a websocket connection. - conn, err := upgrader.Upgrade(resp, req, nil) - if err != nil { - err = fmt.Errorf("unable to upgrade to websocket connection - %w", err) - _, _ = resp.Write([]byte(err.Error())) - logger.Error("unable to upgrade to websocket connection", log.ErrKey, err) - return nil, err - } - - return &userConnWS{ - conn: conn, - logger: logger, - req: req, - }, nil -} - -func (h *userConnHTTP) ReadRequest() ([]byte, error) { - body, err := io.ReadAll(h.req.Body) - if err != nil { - return nil, fmt.Errorf("could not read request body: %w", err) - } - return body, nil -} - -func (h *userConnHTTP) WriteResponse(msg []byte) error { - _, err := h.resp.Write(msg) - if err != nil { - return fmt.Errorf("could not write response: %w", err) - } - return nil -} - -func (h *userConnHTTP) SupportsSubscriptions() bool { - return false -} - -func (h *userConnHTTP) IsClosed() bool { - return false -} - -func (h *userConnHTTP) ReadRequestParams() map[string]string { - return getQueryParams(h.req.URL.Query()) -} - -func (h *userConnHTTP) GetHTTPRequest() *http.Request { - return h.req -} - -func (w *userConnWS) ReadRequest() ([]byte, error) { - _, msg, err := w.conn.ReadMessage() - if err != nil { - if websocket.IsCloseError(err) { - w.isClosed = true - } - return nil, fmt.Errorf("could not read request: %w", err) - } - return msg, nil -} - -func (w *userConnWS) WriteResponse(msg []byte) error { - w.mu.Lock() - defer w.mu.Unlock() - - err := w.conn.WriteMessage(websocket.TextMessage, msg) - if err != nil { - if websocket.IsCloseError(err) || strings.Contains(string(msg), "EOF") { - w.isClosed = true - } - return fmt.Errorf("could not write response: %w", err) - } - return nil -} - -func (w *userConnWS) SupportsSubscriptions() bool { - return true -} - -func (w *userConnWS) IsClosed() bool { - return w.isClosed -} - -func (w *userConnWS) ReadRequestParams() map[string]string { - return getQueryParams(w.req.URL.Query()) -} - -func (w *userConnWS) GetHTTPRequest() *http.Request { - return w.req -} - -func getQueryParams(query url.Values) map[string]string { - params := make(map[string]string) - queryParams := query - for key, value := range queryParams { - params[key] = value[0] - } - return params -} diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go deleted file mode 100644 index 64ceffe2d7..0000000000 --- a/tools/walletextension/wallet_extension.go +++ /dev/null @@ -1,419 +0,0 @@ -package walletextension - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "time" - - "github.com/ten-protocol/go-ten/tools/walletextension/cache" - - "github.com/ten-protocol/go-ten/tools/walletextension/accountmanager" - - "github.com/ten-protocol/go-ten/tools/walletextension/config" - - "github.com/ten-protocol/go-ten/go/common/log" - "github.com/ten-protocol/go-ten/go/obsclient" - - "github.com/ten-protocol/go-ten/tools/walletextension/useraccountmanager" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/go-kit/kit/transport/http/jsonrpc" - "github.com/ten-protocol/go-ten/go/common/stopcontrol" - "github.com/ten-protocol/go-ten/go/common/viewingkey" - "github.com/ten-protocol/go-ten/go/rpc" - "github.com/ten-protocol/go-ten/tools/walletextension/common" - "github.com/ten-protocol/go-ten/tools/walletextension/storage" - "github.com/ten-protocol/go-ten/tools/walletextension/userconn" - - gethcommon "github.com/ethereum/go-ethereum/common" - gethlog "github.com/ethereum/go-ethereum/log" -) - -// WalletExtension handles the management of viewing keys and the forwarding of Ethereum JSON-RPC requests. -type WalletExtension struct { - hostAddrHTTP string // The HTTP address on which the Ten host can be reached - hostAddrWS string // The WS address on which the Ten host can be reached - userAccountManager *useraccountmanager.UserAccountManager - storage storage.Storage - logger gethlog.Logger - fileLogger gethlog.Logger - stopControl *stopcontrol.StopControl - version string - config *config.Config - tenClient *obsclient.ObsClient - cache cache.Cache -} - -func New( - hostAddrHTTP string, - hostAddrWS string, - userAccountManager *useraccountmanager.UserAccountManager, - storage storage.Storage, - stopControl *stopcontrol.StopControl, - version string, - logger gethlog.Logger, - config *config.Config, -) *WalletExtension { - rpcClient, err := rpc.NewNetworkClient(hostAddrHTTP) - if err != nil { - logger.Error(fmt.Errorf("could not create RPC client on %s. Cause: %w", hostAddrHTTP, err).Error()) - panic(err) - } - newTenClient := obsclient.NewObsClient(rpcClient) - newFileLogger := common.NewFileLogger() - newGatewayCache, err := cache.NewCache(logger) - if err != nil { - logger.Error(fmt.Errorf("could not create cache. Cause: %w", err).Error()) - panic(err) - } - - return &WalletExtension{ - hostAddrHTTP: hostAddrHTTP, - hostAddrWS: hostAddrWS, - userAccountManager: userAccountManager, - storage: storage, - logger: logger, - fileLogger: newFileLogger, - stopControl: stopControl, - version: version, - config: config, - tenClient: newTenClient, - cache: newGatewayCache, - } -} - -// IsStopping returns whether the WE is stopping -func (w *WalletExtension) IsStopping() bool { - return w.stopControl.IsStopping() -} - -// Logger returns the WE set logger -func (w *WalletExtension) Logger() gethlog.Logger { - return w.logger -} - -// ProxyEthRequest proxys an incoming user request to the enclave -func (w *WalletExtension) ProxyEthRequest(request *common.RPCRequest, conn userconn.UserConn, hexUserID string) (map[string]interface{}, error) { - response := map[string]interface{}{} - // all responses must contain the request id. Both successful and unsuccessful. - response[common.JSONKeyRPCVersion] = jsonrpc.Version - response[common.JSONKeyID] = request.ID - - // start measuring time for request - requestStartTime := time.Now() - - // Check if the request is in the cache - isCacheable, key, ttl := cache.IsCacheable(request, hexUserID) - - // in case of cache hit return the response from the cache - if isCacheable { - if value, ok := w.cache.Get(key); ok { - // do a shallow copy of the map to avoid concurrent map iteration and map write - returnValue := make(map[string]interface{}) - for k, v := range value { - returnValue[k] = v - } - - requestEndTime := time.Now() - duration := requestEndTime.Sub(requestStartTime) - // adjust requestID - returnValue[common.JSONKeyID] = request.ID - w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, returnValue, duration.Milliseconds())) - return returnValue, nil - } - } - - // proxyRequest will find the correct client to proxy the request (or try them all if appropriate) - var rpcResp interface{} - - // wallet extension can override the GetStorageAt to retrieve the current userID - if request.Method == rpc.GetStorageAt { - if interceptedResponse := w.getStorageAtInterceptor(request, hexUserID); interceptedResponse != nil { - w.logger.Info("interception successful for getStorageAt, returning userID response") - requestEndTime := time.Now() - duration := requestEndTime.Sub(requestStartTime) - w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, interceptedResponse, duration.Milliseconds())) - return interceptedResponse, nil - } - } - - // check if user is sending a new transaction and if we should store it in the database for debugging purposes - if request.Method == rpc.SendRawTransaction && w.config.StoreIncomingTxs { - userIDBytes, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", hexUserID[2:], err).Error()) - return nil, errors.New("error decoding userID. It should be in hex format") - } - err = w.storage.StoreTransaction(request.Params[0].(string), userIDBytes) - if err != nil { - w.Logger().Error(fmt.Errorf("error storing transaction in the database: %w", err).Error()) - return nil, err - } - } - - // get account manager for current user (if there is no users in the query parameters - use defaultUser for WE endpoints) - selectedAccountManager, err := w.userAccountManager.GetUserAccountManager(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error getting accountManager for user (%s), %w", hexUserID, err).Error()) - return nil, err - } - - err = selectedAccountManager.ProxyRequest(request, &rpcResp, conn) - if err != nil { - return nil, err - } - - response[common.JSONKeyResult] = rpcResp - - if rpcResp != nil { - // todo (@ziga) - fix this upstream on the decode - // https://github.com/ethereum/EIPs/blob/master/EIPS/eip-658.md - adjustStateRoot(rpcResp, response) - } - - requestEndTime := time.Now() - duration := requestEndTime.Sub(requestStartTime) - w.fileLogger.Info(fmt.Sprintf("Request method: %s, request params: %s, encryptionToken of sender: %s, response: %s, duration: %d ", request.Method, request.Params, hexUserID, response, duration.Milliseconds())) - - // if the request is cacheable, store the response in the cache - if isCacheable { - w.cache.Set(key, response, ttl) - } - - return response, nil -} - -// GenerateAndStoreNewUser generates new key-pair and userID, stores it in the database and returns hex encoded userID and error -func (w *WalletExtension) GenerateAndStoreNewUser() (string, error) { - requestStartTime := time.Now() - // generate new key-pair - viewingKeyPrivate, err := crypto.GenerateKey() - viewingPrivateKeyEcies := ecies.ImportECDSA(viewingKeyPrivate) - if err != nil { - w.Logger().Error(fmt.Sprintf("could not generate new keypair: %s", err)) - return "", err - } - - // create UserID and store it in the database with the private key - userID := viewingkey.CalculateUserID(common.PrivateKeyToCompressedPubKey(viewingPrivateKeyEcies)) - err = w.storage.AddUser(userID, crypto.FromECDSA(viewingPrivateKeyEcies.ExportECDSA())) - if err != nil { - w.Logger().Error(fmt.Sprintf("failed to save user to the database: %s", err)) - return "", err - } - - hexUserID := hex.EncodeToString(userID) - - w.userAccountManager.AddAndReturnAccountManager(hexUserID) - requestEndTime := time.Now() - duration := requestEndTime.Sub(requestStartTime) - w.fileLogger.Info(fmt.Sprintf("Storing new userID: %s, duration: %d ", hexUserID, duration.Milliseconds())) - return hexUserID, nil -} - -// AddAddressToUser checks if a message is in correct format and if signature is valid. If all checks pass we save address and signature against userID -func (w *WalletExtension) AddAddressToUser(hexUserID string, address string, signature []byte, signatureType viewingkey.SignatureType) error { - requestStartTime := time.Now() - addressFromMessage := gethcommon.HexToAddress(address) - // check if a message was signed by the correct address and if the signature is valid - _, err := viewingkey.CheckSignature(hexUserID, signature, int64(w.config.TenChainID), signatureType) - if err != nil { - return fmt.Errorf("signature is not valid: %w", err) - } - - // register the account for that viewing key - userIDBytes, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", hexUserID[2:], err).Error()) - return errors.New("error decoding userID. It should be in hex format") - } - err = w.storage.AddAccount(userIDBytes, addressFromMessage.Bytes(), signature, signatureType) - if err != nil { - w.Logger().Error(fmt.Errorf("error while storing account (%s) for user (%s): %w", addressFromMessage.Hex(), hexUserID, err).Error()) - return err - } - - // Get account manager for current userID (and create it if it doesn't exist) - privateKeyBytes, err := w.storage.GetUserPrivateKey(userIDBytes) - if err != nil { - w.Logger().Error(fmt.Errorf("error getting private key for user: (%s), %w", hexUserID, err).Error()) - } - - accManager := w.userAccountManager.AddAndReturnAccountManager(hexUserID) - - encClient, err := common.CreateEncClient(w.hostAddrHTTP, addressFromMessage.Bytes(), privateKeyBytes, signature, signatureType, w.Logger()) - if err != nil { - w.Logger().Error(fmt.Errorf("error creating encrypted client for user: (%s), %w", hexUserID, err).Error()) - return fmt.Errorf("error creating encrypted client for user: (%s), %w", hexUserID, err) - } - - accManager.AddClient(addressFromMessage, encClient) - requestEndTime := time.Now() - duration := requestEndTime.Sub(requestStartTime) - w.fileLogger.Info(fmt.Sprintf("Storing new address for user: %s, address: %s, duration: %d ", hexUserID, address, duration.Milliseconds())) - return nil -} - -// UserHasAccount checks if provided account exist in the database for given userID -func (w *WalletExtension) UserHasAccount(hexUserID string, address string) (bool, error) { - w.fileLogger.Info(fmt.Sprintf("Checkinf if user has account: %s, address: %s", hexUserID, address)) - userIDBytes, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", hexUserID[2:], err).Error()) - return false, err - } - - addressBytes, err := hex.DecodeString(address[2:]) // remove 0x prefix from address - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", address[2:], err).Error()) - return false, err - } - - // todo - this can be optimised and done in the database if we will have users with large number of accounts - // get all the accounts for the selected user - accounts, err := w.storage.GetAccounts(userIDBytes) - if err != nil { - w.Logger().Error(fmt.Errorf("error getting accounts for user (%s), %w", hexUserID, err).Error()) - return false, err - } - - // check if any of the account matches given account - found := false - for _, account := range accounts { - if bytes.Equal(account.AccountAddress, addressBytes) { - found = true - } - } - return found, nil -} - -// DeleteUser deletes user and accounts associated with user from the database for given userID -func (w *WalletExtension) DeleteUser(hexUserID string) error { - w.fileLogger.Info(fmt.Sprintf("Deleting user: %s", hexUserID)) - userIDBytes, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", hexUserID, err).Error()) - return err - } - - err = w.storage.DeleteUser(userIDBytes) - if err != nil { - w.Logger().Error(fmt.Errorf("error deleting user (%s), %w", hexUserID, err).Error()) - return err - } - - // Delete UserAccountManager for user that revoked userID - err = w.userAccountManager.DeleteUserAccountManager(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error deleting UserAccointManager for user (%s), %w", hexUserID, err).Error()) - } - - return nil -} - -func (w *WalletExtension) UserExists(hexUserID string) bool { - w.fileLogger.Info(fmt.Sprintf("Checking if user exists: %s", hexUserID)) - userIDBytes, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.Logger().Error(fmt.Errorf("error decoding string (%s), %w", hexUserID, err).Error()) - return false - } - - // Check if user exists and don't log error if user doesn't exist, because we expect this to happen in case of - // user revoking encryption token or using different testnet. - // todo add a counter here in the future - key, err := w.storage.GetUserPrivateKey(userIDBytes) - if err != nil { - return false - } - - return len(key) > 0 -} - -func adjustStateRoot(rpcResp interface{}, respMap map[string]interface{}) { - if resultMap, ok := rpcResp.(map[string]interface{}); ok { - if val, foundRoot := resultMap[common.JSONKeyRoot]; foundRoot { - if val == "0x" { - respMap[common.JSONKeyResult].(map[string]interface{})[common.JSONKeyRoot] = nil - } - } - } -} - -// getStorageAtInterceptor checks if the parameters for getStorageAt are set to values that require interception -// and return response or nil if the gateway should forward the request to the node. -func (w *WalletExtension) getStorageAtInterceptor(request *common.RPCRequest, hexUserID string) map[string]interface{} { - // check if parameters are correct, and we can intercept a request, otherwise return nil - if w.checkParametersForInterceptedGetStorageAt(request.Params) { - // check if userID in the parameters is also in our database - userID, err := common.GetUserIDbyte(hexUserID) - if err != nil { - w.logger.Warn("GetStorageAt called with appropriate parameters to return userID, but not found in the database: ", "userId", hexUserID) - return nil - } - - // check if we have default user (we don't want to send userID of it out) - if hexUserID == hex.EncodeToString([]byte(common.DefaultUser)) { - response := map[string]interface{}{} - response[common.JSONKeyRPCVersion] = jsonrpc.Version - response[common.JSONKeyID] = request.ID - response[common.JSONKeyResult] = fmt.Sprintf(accountmanager.ErrNoViewingKey, "eth_getStorageAt") - return response - } - - _, err = w.storage.GetUserPrivateKey(userID) - if err != nil { - w.logger.Info("Trying to get userID, but it is not present in our database: ", log.ErrKey, err) - return nil - } - response := map[string]interface{}{} - response[common.JSONKeyRPCVersion] = jsonrpc.Version - response[common.JSONKeyID] = request.ID - response[common.JSONKeyResult] = hexUserID - return response - } - w.logger.Info(fmt.Sprintf("parameters used in the request do not match requited parameters for interception: %s", request.Params)) - - return nil -} - -// checkParametersForInterceptedGetStorageAt checks -// if parameters for getStorageAt are in the correct format to intercept the function -func (w *WalletExtension) checkParametersForInterceptedGetStorageAt(params []interface{}) bool { - if len(params) != 3 { - w.logger.Info(fmt.Sprintf("getStorageAt expects 3 parameters, but %d received", len(params))) - return false - } - - if methodName, ok := params[0].(string); ok { - return methodName == common.GetStorageAtUserIDRequestMethodName - } - return false -} - -func (w *WalletExtension) Version() string { - return w.version -} - -func (w *WalletExtension) GetTenNodeHealthStatus() (bool, error) { - return w.tenClient.Health() -} - -func (w *WalletExtension) GenerateUserMessageToSign(encryptionToken string, formatsSlice []string) (string, error) { - // Check if the formats are valid - for _, format := range formatsSlice { - if _, exists := viewingkey.SignatureTypeMap[format]; !exists { - return "", fmt.Errorf("invalid format: %s", format) - } - } - - messageFormat := viewingkey.GetBestFormat(formatsSlice) - message, err := viewingkey.GenerateMessage(encryptionToken, int64(w.config.TenChainID), viewingkey.PersonalSignVersion, messageFormat) - if err != nil { - return "", fmt.Errorf("error generating message: %w", err) - } - return string(message), nil -} diff --git a/tools/walletextension/walletextension_container.go b/tools/walletextension/walletextension_container.go new file mode 100644 index 0000000000..505832a98e --- /dev/null +++ b/tools/walletextension/walletextension_container.go @@ -0,0 +1,136 @@ +package walletextension + +import ( + "net/http" + "os" + "time" + + "github.com/ten-protocol/go-ten/tools/walletextension/api" + + "github.com/ten-protocol/go-ten/tools/walletextension/httpapi" + + "github.com/ten-protocol/go-ten/tools/walletextension/rpcapi" + + "github.com/ten-protocol/go-ten/lib/gethfork/node" + + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/common/stopcontrol" + gethrpc "github.com/ten-protocol/go-ten/lib/gethfork/rpc" + wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage" +) + +type Container struct { + stopControl *stopcontrol.StopControl + logger gethlog.Logger + rpcServer node.Server +} + +func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Container { + // create the account manager with a single unauthenticated connection + hostRPCBindAddrWS := wecommon.WSProtocol + config.NodeRPCWebsocketAddress + hostRPCBindAddrHTTP := wecommon.HTTPProtocol + config.NodeRPCHTTPAddress + // start the database + databaseStorage, err := storage.New(config.DBType, config.DBConnectionURL, config.DBPathOverride) + if err != nil { + logger.Crit("unable to create database to store viewing keys ", log.ErrKey, err) + os.Exit(1) + } + + // captures version in the env vars + version := os.Getenv("OBSCURO_GATEWAY_VERSION") + if version == "" { + version = "dev" + } + + stopControl := stopcontrol.New() + walletExt := rpcapi.NewServices(hostRPCBindAddrHTTP, hostRPCBindAddrWS, databaseStorage, stopControl, version, logger, &config) + cfg := &node.RPCConfig{ + EnableHTTP: true, + HTTPPort: config.WalletExtensionPortHTTP, + EnableWs: true, + WsPort: config.WalletExtensionPortWS, + WsPath: wecommon.APIVersion1 + "/", + HTTPPath: wecommon.APIVersion1 + "/", + Host: config.WalletExtensionHost, + } + rpcServer := node.NewServer(cfg, logger) + + rpcServer.RegisterRoutes(httpapi.NewHTTPRoutes(walletExt)) + + // register all RPC endpoints exposed by a typical Geth node + rpcServer.RegisterAPIs([]gethrpc.API{ + { + Namespace: "eth", + Service: rpcapi.NewEthereumAPI(walletExt), + }, { + Namespace: "eth", + Service: rpcapi.NewBlockChainAPI(walletExt), + }, { + Namespace: "eth", + Service: rpcapi.NewTransactionAPI(walletExt), + }, { + Namespace: "txpool", + Service: rpcapi.NewTxPoolAPI(walletExt), + }, { + Namespace: "debug", + Service: rpcapi.NewDebugAPI(walletExt), + }, { + Namespace: "eth", + Service: rpcapi.NewFilterAPI(walletExt), + }, + }) + + // register the static files + // todo - remove this when the frontend is no longer served from the enclave + staticHandler := api.StaticFilesHandler(wecommon.PathStatic) + rpcServer.RegisterRoutes([]node.Route{{ + Name: wecommon.PathStatic, + Func: func(resp http.ResponseWriter, req *http.Request) { + staticHandler.ServeHTTP(resp, req) + }, + }}) + + return NewWalletExtensionContainer( + stopControl, + rpcServer, + logger, + ) +} + +func NewWalletExtensionContainer( + stopControl *stopcontrol.StopControl, + rpcServer node.Server, + logger gethlog.Logger, +) *Container { + return &Container{ + stopControl: stopControl, + rpcServer: rpcServer, + logger: logger, + } +} + +// Start starts the wallet extension container +func (w *Container) Start() error { + err := w.rpcServer.Start() + if err != nil { + return err + } + return nil +} + +func (w *Container) Stop() error { + w.stopControl.Stop() + + if w.rpcServer != nil { + // rpc server cannot be stopped synchronously as it will kill current request + go func() { + // make sure it's not killing the connection before returning the response + time.Sleep(time.Second) // todo review this sleep + w.rpcServer.Stop() + }() + } + + return nil +}