diff --git a/go/common/viewingkey/viewing_key.go b/go/common/viewingkey/viewing_key.go index c46c285585..464b29ca81 100644 --- a/go/common/viewingkey/viewing_key.go +++ b/go/common/viewingkey/viewing_key.go @@ -3,58 +3,15 @@ package viewingkey import ( "crypto/ecdsa" "encoding/hex" - "errors" "fmt" - "math/big" "github.com/ethereum/go-ethereum/accounts" - "github.com/ethereum/go-ethereum/common/math" + gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/signer/core/apitypes" "github.com/ten-protocol/go-ten/go/wallet" - - gethcommon "github.com/ethereum/go-ethereum/common" -) - -// SignedMsgPrefix is the prefix added when signing the viewing key in MetaMask using the personal_sign -// API. Why is this needed? MetaMask has a security feature whereby if you ask it to sign something that looks like -// a transaction using the personal_sign API, it modifies the data being signed. The goal is to prevent hackers -// from asking a visitor to their website to personal_sign something that is actually a malicious transaction (e.g. -// theft of funds). By adding a prefix, the viewing key bytes no longer looks like a transaction hash, and thus get -// signed as-is. -const SignedMsgPrefix = "vk" - -const ( - EIP712Domain = "EIP712Domain" - EIP712Type = "Authentication" - EIP712DomainName = "name" - EIP712DomainVersion = "version" - EIP712DomainChainID = "chainId" - EIP712EncryptionToken = "Encryption Token" - EIP712DomainNameValue = "Ten" - EIP712DomainVersionValue = "1.0" - UserIDHexLength = 40 - PersonalSignMessageFormat = "Token: %s on chain: %d version:%d" ) -const ( - EIP712Signature SignatureType = 0 - PersonalSign SignatureType = 1 - Legacy SignatureType = 2 -) - -// EIP712EncryptionTokens is a list of all possible options for Encryption token name -var EIP712EncryptionTokens = [...]string{ - EIP712EncryptionToken, -} - -// PersonalSignMessageSupportedVersions is a list of supported versions for the personal sign message -var PersonalSignMessageSupportedVersions = []int{1} - -// SignatureType is used to differentiate between different signature types (string is used, because int is not RLP-serializable) -type SignatureType uint8 - // ViewingKey encapsulates the signed viewing key for an account for use in encrypted communication with an enclave. // It is the client-side perspective of the viewing key used for decrypting incoming traffic. type ViewingKey struct { @@ -75,105 +32,6 @@ type RPCSignedViewingKey struct { SignatureType SignatureType } -// SignatureChecker is an interface for checking -// if signature is valid for provided encryptionToken and chainID and return singing address or nil if not valid -type SignatureChecker interface { - CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) -} - -type ( - PersonalSignChecker struct{} - EIP712Checker struct{} - LegacyChecker struct{} -) - -// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid -func (psc PersonalSignChecker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { - if len(signature) != 65 { - return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) - } - // We transform the V from 27/28 to 0/1. This same change is made in Geth internals, for legacy reasons to be able - // to recover the address: https://github.com/ethereum/go-ethereum/blob/55599ee95d4151a2502465e0afc7c47bd1acba77/internal/ethapi/api.go#L452-L459 - if signature[64] == 27 || signature[64] == 28 { - signature[64] -= 27 - } - - // create all possible hashes (for all the supported versions) of the message (needed for signature verification) - for _, version := range PersonalSignMessageSupportedVersions { - message := GeneratePersonalSignMessage(encryptionToken, chainID, version) - messageHash := accounts.TextHash([]byte(message)) - - // current signature is valid - return account address - address, err := CheckSignatureAndReturnAccountAddress(messageHash, signature) - if err == nil { - return address, nil - } - } - - return nil, fmt.Errorf("signature verification failed") -} - -func (e EIP712Checker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { - if len(signature) != 65 { - return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) - } - - rawDataOptions, err := GenerateAuthenticationEIP712RawDataOptions(encryptionToken, chainID) - if err != nil { - return nil, fmt.Errorf("cannot generate eip712 message. Cause %w", err) - } - - // We transform the V from 27/28 to 0/1. This same change is made in Geth internals, for legacy reasons to be able - // to recover the address: https://github.com/ethereum/go-ethereum/blob/55599ee95d4151a2502465e0afc7c47bd1acba77/internal/ethapi/api.go#L452-L459 - if signature[64] == 27 || signature[64] == 28 { - signature[64] -= 27 - } - - for _, rawData := range rawDataOptions { - // create a hash of structured message (needed for signature verification) - hashBytes := crypto.Keccak256(rawData) - - // current signature is valid - return account address - address, err := CheckSignatureAndReturnAccountAddress(hashBytes, signature) - if err == nil { - return address, nil - } - } - return nil, errors.New("EIP 712 signature verification failed") -} - -// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid -// todo (@ziga) Remove this method once old WE endpoints are removed -// encryptionToken is expected to be a public key and not encrypted token as with other signature types -// (since this is only temporary fix and legacy format will be removed soon) -func (lsc LegacyChecker) CheckSignature(encryptionToken string, signature []byte, _ int64) (*gethcommon.Address, error) { - publicKey := []byte(encryptionToken) - msgToSignLegacy := GenerateSignMessage(publicKey) - - recoveredAccountPublicKeyLegacy, err := crypto.SigToPub(accounts.TextHash([]byte(msgToSignLegacy)), signature) - if err != nil { - return nil, fmt.Errorf("failed to recover account public key from legacy signature: %w", err) - } - recoveredAccountAddressLegacy := crypto.PubkeyToAddress(*recoveredAccountPublicKeyLegacy) - return &recoveredAccountAddressLegacy, nil -} - -// SignatureChecker is a map of SignatureType to SignatureChecker -var signatureCheckers = map[SignatureType]SignatureChecker{ - PersonalSign: PersonalSignChecker{}, - EIP712Signature: EIP712Checker{}, - Legacy: LegacyChecker{}, -} - -// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid -func CheckSignature(encryptionToken string, signature []byte, chainID int64, signatureType SignatureType) (*gethcommon.Address, error) { - checker, exists := signatureCheckers[signatureType] - if !exists { - return nil, fmt.Errorf("unsupported signature type") - } - return checker.CheckSignature(encryptionToken, signature, chainID) -} - // GenerateViewingKeyForWallet takes an account wallet, generates a viewing key and signs the key with the acc's private key // uses the same method of signature handling as Metamask/geth // TODO @Ziga - update this method to use the new EIP-712 signature format / personal sign after the removal of the legacy format @@ -241,116 +99,3 @@ func Sign(userPrivKey *ecdsa.PrivateKey, vkPubKey []byte) ([]byte, error) { } return signature, nil } - -// GenerateSignMessage creates the message to be signed -// vkPubKey is expected to be a []byte("0x....") to create the signing message -// todo (@ziga) Remove this method once old WE endpoints are removed -func GenerateSignMessage(vkPubKey []byte) string { - return SignedMsgPrefix + hex.EncodeToString(vkPubKey) -} - -func GeneratePersonalSignMessage(encryptionToken string, chainID int64, version int) string { - return fmt.Sprintf(PersonalSignMessageFormat, encryptionToken, chainID, version) -} - -// getBytesFromTypedData creates EIP-712 compliant hash from typedData. -// It involves hashing the message with its structure, hashing domain separator, -// and then encoding both hashes with specific EIP-712 bytes to construct the final message format. -func getBytesFromTypedData(typedData apitypes.TypedData) ([]byte, error) { - typedDataHash, err := typedData.HashStruct(typedData.PrimaryType, typedData.Message) - if err != nil { - return nil, err - } - // Create the domain separator hash for EIP-712 message context - domainSeparator, err := typedData.HashStruct(EIP712Domain, typedData.Domain.Map()) - if err != nil { - return nil, err - } - // Prefix domain and message hashes with EIP-712 version and encoding bytes - rawData := append([]byte("\x19\x01"), append(domainSeparator, typedDataHash...)...) - return rawData, nil -} - -// GenerateAuthenticationEIP712RawDataOptions generates all the options or raw data messages (bytes) -// for an EIP-712 message used to authenticate an address with user -// (currently only one option is supported, but function leaves room for future expansion of options) -func GenerateAuthenticationEIP712RawDataOptions(userID string, chainID int64) ([][]byte, error) { - if len(userID) != UserIDHexLength { - return nil, fmt.Errorf("userID hex length must be %d, received %d", UserIDHexLength, len(userID)) - } - encryptionToken := "0x" + userID - - domain := apitypes.TypedDataDomain{ - Name: EIP712DomainNameValue, - Version: EIP712DomainVersionValue, - ChainId: (*math.HexOrDecimal256)(big.NewInt(chainID)), - } - - message := map[string]interface{}{ - EIP712EncryptionToken: encryptionToken, - } - - types := apitypes.Types{ - EIP712Domain: { - {Name: EIP712DomainName, Type: "string"}, - {Name: EIP712DomainVersion, Type: "string"}, - {Name: EIP712DomainChainID, Type: "uint256"}, - }, - EIP712Type: { - {Name: EIP712EncryptionToken, Type: "address"}, - }, - } - - newTypeElement := apitypes.TypedData{ - Types: types, - PrimaryType: EIP712Type, - Domain: domain, - Message: message, - } - - rawDataOptions := make([][]byte, 0) - rawData, err := getBytesFromTypedData(newTypeElement) - if err != nil { - return nil, err - } - rawDataOptions = append(rawDataOptions, rawData) - - return rawDataOptions, nil -} - -// CalculateUserIDHex CalculateUserID calculates userID from a public key -// (we truncate it, because we want it to have length 20) and encode to hex strings -func CalculateUserIDHex(publicKeyBytes []byte) string { - return hex.EncodeToString(CalculateUserID(publicKeyBytes)) -} - -// CalculateUserID calculates userID from a public key (we truncate it, because we want it to have length 20) -func CalculateUserID(publicKeyBytes []byte) []byte { - return crypto.Keccak256Hash(publicKeyBytes).Bytes()[:20] -} - -// CheckSignatureAndReturnAccountAddress checks if the signature is valid for hash of the message and checks if -// signer is an address provided to the function. -// It returns an address if the signature is valid and nil otherwise -func CheckSignatureAndReturnAccountAddress(hashBytes []byte, signature []byte) (*gethcommon.Address, error) { - pubKeyBytes, err := crypto.Ecrecover(hashBytes, signature) - if err != nil { - return nil, err - } - - pubKey, err := crypto.UnmarshalPubkey(pubKeyBytes) - if err != nil { - return nil, err - } - - r := new(big.Int).SetBytes(signature[:32]) - s := new(big.Int).SetBytes(signature[32:64]) - - // Verify the signature and return the result (all the checks above passed) - isSigValid := ecdsa.Verify(pubKey, hashBytes, r, s) - if isSigValid { - address := crypto.PubkeyToAddress(*pubKey) - return &address, nil - } - return nil, fmt.Errorf("invalid signature") -} diff --git a/go/common/viewingkey/viewing_key_messages.go b/go/common/viewingkey/viewing_key_messages.go new file mode 100644 index 0000000000..036a82a99b --- /dev/null +++ b/go/common/viewingkey/viewing_key_messages.go @@ -0,0 +1,223 @@ +package viewingkey + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/signer/core/apitypes" +) + +const ( + EIP712Signature SignatureType = 0 + PersonalSign SignatureType = 1 + Legacy SignatureType = 2 +) + +// SignatureType is used to differentiate between different signature types (string is used, because int is not RLP-serializable) +type SignatureType uint8 + +const ( + EIP712Domain = "EIP712Domain" + EIP712Type = "Authentication" + EIP712DomainName = "name" + EIP712DomainVersion = "version" + EIP712DomainChainID = "chainId" + EIP712EncryptionToken = "Encryption Token" + EIP712DomainNameValue = "Ten" + EIP712DomainVersionValue = "1.0" + UserIDHexLength = 40 + PersonalSignMessageFormat = "Token: %s on chain: %d version: %d" + SignedMsgPrefix = "vk" // prefix for legacy signed messages (remove when legacy signature type is removed) + PersonalSignVersion = 1 +) + +// EIP712EncryptionTokens is a list of all possible options for Encryption token name +var EIP712EncryptionTokens = [...]string{ + EIP712EncryptionToken, +} + +type MessageGenerator interface { + generateMessage(encryptionToken string, chainID int64, version int) ([]byte, error) +} + +type ( + PersonalMessageGenerator struct{} + EIP712MessageGenerator struct{} +) + +var messageGenerators = map[SignatureType]MessageGenerator{ + PersonalSign: PersonalMessageGenerator{}, + EIP712Signature: EIP712MessageGenerator{}, +} + +// GenerateMessage generates a message for the given encryptionToken, chainID, version and signatureType +func (p PersonalMessageGenerator) generateMessage(encryptionToken string, chainID int64, version int) ([]byte, error) { + return []byte(fmt.Sprintf(PersonalSignMessageFormat, encryptionToken, chainID, version)), nil +} + +func (e EIP712MessageGenerator) generateMessage(encryptionToken string, chainID int64, _ int) ([]byte, error) { + if len(encryptionToken) != UserIDHexLength { + return nil, fmt.Errorf("userID hex length must be %d, received %d", UserIDHexLength, len(encryptionToken)) + } + EIP712TypedData := createTypedDataForEIP712Message(encryptionToken, chainID) + + // add the JSON message to the list of messages + jsonData, err := json.Marshal(EIP712TypedData) + if err != nil { + return nil, err + } + return jsonData, nil +} + +// GenerateMessage generates a message for the given encryptionToken, chainID, version and signatureType +func GenerateMessage(encryptionToken string, chainID int64, version int, signatureType SignatureType) ([]byte, error) { + generator, exists := messageGenerators[signatureType] + if !exists { + return nil, fmt.Errorf("unsupported signature type") + } + return generator.generateMessage(encryptionToken, chainID, version) +} + +// MessageHash is an interface for getting the hash of the message +type MessageHash interface { + getMessageHash(message []byte) []byte +} + +type ( + PersonalMessageHash struct{} + EIP712MessageHash struct{} +) + +var messageHash = map[SignatureType]MessageHash{ + PersonalSign: PersonalMessageHash{}, + EIP712Signature: EIP712MessageHash{}, +} + +// getMessageHash returns the hash for the personal message +func (p PersonalMessageHash) getMessageHash(message []byte) []byte { + return accounts.TextHash(message) +} + +// getMessageHash returns the hash for the EIP712 message +func (E EIP712MessageHash) getMessageHash(message []byte) []byte { + var EIP712TypedData apitypes.TypedData + err := json.Unmarshal(message, &EIP712TypedData) + if err != nil { + return nil + } + + rawData, err := getBytesFromTypedData(EIP712TypedData) + if err != nil { + return nil + } + return crypto.Keccak256(rawData) +} + +// GetMessageHash returns the hash of the message based on the signature type +func GetMessageHash(message []byte, signatureType SignatureType) ([]byte, error) { + hashFunction, exists := messageHash[signatureType] + if !exists { + return nil, fmt.Errorf("unsupported signature type") + } + return hashFunction.getMessageHash(message), nil +} + +// GenerateSignMessage creates the message to be signed +// vkPubKey is expected to be a []byte("0x....") to create the signing message +// todo (@ziga) Remove this method once old WE endpoints are removed +func GenerateSignMessage(vkPubKey []byte) string { + return SignedMsgPrefix + hex.EncodeToString(vkPubKey) +} + +// getBytesFromTypedData creates EIP-712 compliant hash from typedData. +// It involves hashing the message with its structure, hashing domain separator, +// and then encoding both hashes with specific EIP-712 bytes to construct the final message format. +func getBytesFromTypedData(typedData apitypes.TypedData) ([]byte, error) { + typedDataHash, err := typedData.HashStruct(typedData.PrimaryType, typedData.Message) + if err != nil { + return nil, err + } + // Create the domain separator hash for EIP-712 message context + domainSeparator, err := typedData.HashStruct(EIP712Domain, typedData.Domain.Map()) + if err != nil { + return nil, err + } + // Prefix domain and message hashes with EIP-712 version and encoding bytes + rawData := append([]byte("\x19\x01"), append(domainSeparator, typedDataHash...)...) + return rawData, nil +} + +// createTypedDataForEIP712Message creates typed data for EIP712 message +func createTypedDataForEIP712Message(encryptionToken string, chainID int64) apitypes.TypedData { + encryptionToken = "0x" + encryptionToken + + domain := apitypes.TypedDataDomain{ + Name: EIP712DomainNameValue, + Version: EIP712DomainVersionValue, + ChainId: (*math.HexOrDecimal256)(big.NewInt(chainID)), + } + + message := map[string]interface{}{ + EIP712EncryptionToken: encryptionToken, + } + + types := apitypes.Types{ + EIP712Domain: { + {Name: EIP712DomainName, Type: "string"}, + {Name: EIP712DomainVersion, Type: "string"}, + {Name: EIP712DomainChainID, Type: "uint256"}, + }, + EIP712Type: { + {Name: EIP712EncryptionToken, Type: "address"}, + }, + } + + typedData := apitypes.TypedData{ + Types: types, + PrimaryType: EIP712Type, + Domain: domain, + Message: message, + } + return typedData +} + +// CalculateUserIDHex CalculateUserID calculates userID from a public key +// (we truncate it, because we want it to have length 20) and encode to hex strings +func CalculateUserIDHex(publicKeyBytes []byte) string { + return hex.EncodeToString(CalculateUserID(publicKeyBytes)) +} + +// CalculateUserID calculates userID from a public key (we truncate it, because we want it to have length 20) +func CalculateUserID(publicKeyBytes []byte) []byte { + return crypto.Keccak256Hash(publicKeyBytes).Bytes()[:20] +} + +// GetBestFormat returns the best format for a message based on available formats that are supported by the user +func GetBestFormat(formatsSlice []string) SignatureType { + // If "Personal" is the only format available, choose it + if len(formatsSlice) == 1 && formatsSlice[0] == "Personal" { + return PersonalSign + } + + // otherwise, choose EIP712 + return EIP712Signature +} + +func GetSignatureTypeString(expectedSignatureType SignatureType) string { + for key, value := range SignatureTypeMap { + if value == expectedSignatureType { + return key + } + } + return "" +} + +var SignatureTypeMap = map[string]SignatureType{ + "EIP712": EIP712Signature, + "Personal": PersonalSign, +} diff --git a/go/common/viewingkey/viewing_key_signature.go b/go/common/viewingkey/viewing_key_signature.go new file mode 100644 index 0000000000..9c48f0b3cc --- /dev/null +++ b/go/common/viewingkey/viewing_key_signature.go @@ -0,0 +1,142 @@ +package viewingkey + +import ( + "crypto/ecdsa" + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/accounts" + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" +) + +// SignatureChecker is an interface for checking +// if signature is valid for provided encryptionToken and chainID and return singing address or nil if not valid +type SignatureChecker interface { + CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) +} + +type ( + PersonalSignChecker struct{} + EIP712Checker struct{} + LegacyChecker struct{} +) + +// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid +func (psc PersonalSignChecker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { + if len(signature) != 65 { + return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) + } + // We transform the V from 27/28 to 0/1. This same change is made in Geth internals, for legacy reasons to be able + // to recover the address: https://github.com/ethereum/go-ethereum/blob/55599ee95d4151a2502465e0afc7c47bd1acba77/internal/ethapi/api.go#L452-L459 + if signature[64] == 27 || signature[64] == 28 { + signature[64] -= 27 + } + + msg, err := GenerateMessage(encryptionToken, chainID, PersonalSignVersion, PersonalSign) + if err != nil { + return nil, fmt.Errorf("cannot generate message. Cause %w", err) + } + + msgHash, err := GetMessageHash(msg, PersonalSign) + if err != nil { + return nil, fmt.Errorf("cannot generate message hash. Cause %w", err) + } + + // signature is valid - return account address + address, err := CheckSignatureAndReturnAccountAddress(msgHash, signature) + if err == nil { + return address, nil + } + + return nil, fmt.Errorf("signature verification failed") +} + +func (e EIP712Checker) CheckSignature(encryptionToken string, signature []byte, chainID int64) (*gethcommon.Address, error) { + if len(signature) != 65 { + return nil, fmt.Errorf("invalid signaure length: %d", len(signature)) + } + + msg, err := GenerateMessage(encryptionToken, chainID, 1, EIP712Signature) + if err != nil { + return nil, fmt.Errorf("cannot generate message. Cause %w", err) + } + + msgHash, err := GetMessageHash(msg, EIP712Signature) + if err != nil { + return nil, fmt.Errorf("cannot generate message hash. Cause %w", err) + } + + // We transform the V from 27/28 to 0/1. This same change is made in Geth internals, for legacy reasons to be able + // to recover the address: https://github.com/ethereum/go-ethereum/blob/55599ee95d4151a2502465e0afc7c47bd1acba77/internal/ethapi/api.go#L452-L459 + if signature[64] == 27 || signature[64] == 28 { + signature[64] -= 27 + } + + // current signature is valid - return account address + address, err := CheckSignatureAndReturnAccountAddress(msgHash, signature) + if err == nil { + return address, nil + } + + return nil, errors.New("EIP 712 signature verification failed") +} + +// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid +// todo (@ziga) Remove this method once old WE endpoints are removed +// encryptionToken is expected to be a public key and not encrypted token as with other signature types +// (since this is only temporary fix and legacy format will be removed soon) +func (lsc LegacyChecker) CheckSignature(encryptionToken string, signature []byte, _ int64) (*gethcommon.Address, error) { + publicKey := []byte(encryptionToken) + msgToSignLegacy := GenerateSignMessage(publicKey) + + recoveredAccountPublicKeyLegacy, err := crypto.SigToPub(accounts.TextHash([]byte(msgToSignLegacy)), signature) + if err != nil { + return nil, fmt.Errorf("failed to recover account public key from legacy signature: %w", err) + } + recoveredAccountAddressLegacy := crypto.PubkeyToAddress(*recoveredAccountPublicKeyLegacy) + return &recoveredAccountAddressLegacy, nil +} + +// SignatureChecker is a map of SignatureType to SignatureChecker +var signatureCheckers = map[SignatureType]SignatureChecker{ + PersonalSign: PersonalSignChecker{}, + EIP712Signature: EIP712Checker{}, + Legacy: LegacyChecker{}, +} + +// CheckSignature checks if signature is valid for provided encryptionToken and chainID and return address or nil if not valid +func CheckSignature(encryptionToken string, signature []byte, chainID int64, signatureType SignatureType) (*gethcommon.Address, error) { + checker, exists := signatureCheckers[signatureType] + if !exists { + return nil, fmt.Errorf("unsupported signature type") + } + return checker.CheckSignature(encryptionToken, signature, chainID) +} + +// CheckSignatureAndReturnAccountAddress checks if the signature is valid for hash of the message and checks if +// signer is an address provided to the function. +// It returns an address if the signature is valid and nil otherwise +func CheckSignatureAndReturnAccountAddress(hashBytes []byte, signature []byte) (*gethcommon.Address, error) { + pubKeyBytes, err := crypto.Ecrecover(hashBytes, signature) + if err != nil { + return nil, err + } + + pubKey, err := crypto.UnmarshalPubkey(pubKeyBytes) + if err != nil { + return nil, err + } + + r := new(big.Int).SetBytes(signature[:32]) + s := new(big.Int).SetBytes(signature[32:64]) + + // Verify the signature and return the result (all the checks above passed) + isSigValid := ecdsa.Verify(pubKey, hashBytes, r, s) + if isSigValid { + address := crypto.PubkeyToAddress(*pubKey) + return &address, nil + } + return nil, fmt.Errorf("invalid signature") +} diff --git a/go/enclave/vkhandler/vk_handler_test.go b/go/enclave/vkhandler/vk_handler_test.go index 7d66185c86..e8065ef2e8 100644 --- a/go/enclave/vkhandler/vk_handler_test.go +++ b/go/enclave/vkhandler/vk_handler_test.go @@ -7,7 +7,6 @@ import ( gethcommon "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/stretchr/testify/assert" @@ -49,13 +48,24 @@ func TestCheckSignature(t *testing.T) { userPrivKey, _, userID, userAddress := generateRandomUserKeys() // Generate all message types and create map with the corresponding signature type - // Test EIP712 message format - EIP712MessageDataOptions, err := viewingkey.GenerateAuthenticationEIP712RawDataOptions(userID, chainID) + EIP712Message, err := viewingkey.GenerateMessage(userID, chainID, 0, viewingkey.EIP712Signature) + if err != nil { + t.Fatalf(err.Error()) + } + EIP712MessageHash, err := viewingkey.GetMessageHash(EIP712Message, viewingkey.EIP712Signature) + if err != nil { + t.Fatalf(err.Error()) + } + + PersonalSignMessage, err := viewingkey.GenerateMessage(userID, chainID, viewingkey.PersonalSignVersion, viewingkey.PersonalSign) + if err != nil { + t.Fatalf(err.Error()) + } + + PersonalSignMessageHash, err := viewingkey.GetMessageHash(PersonalSignMessage, viewingkey.PersonalSign) if err != nil { t.Fatalf(err.Error()) } - EIP712MessageHash := crypto.Keccak256(EIP712MessageDataOptions[0]) - PersonalSignMessageHash := accounts.TextHash([]byte(viewingkey.GeneratePersonalSignMessage(userID, chainID, viewingkey.PersonalSignMessageSupportedVersions[0]))) messages := map[string]MessageWithSignatureType{ "EIP712MessageHash": { @@ -86,12 +96,24 @@ func TestVerifyViewingKey(t *testing.T) { userPrivKey, vkPrivKey, userID, userAddress := generateRandomUserKeys() // Generate all message types and create map with the corresponding signature type // Test EIP712 message format - EIP712MessageDataOptions, err := viewingkey.GenerateAuthenticationEIP712RawDataOptions(userID, chainID) + + EIP712Message, err := viewingkey.GenerateMessage(userID, chainID, viewingkey.PersonalSignVersion, viewingkey.EIP712Signature) + if err != nil { + t.Fatalf(err.Error()) + } + EIP712MessageHash, err := viewingkey.GetMessageHash(EIP712Message, viewingkey.EIP712Signature) + if err != nil { + t.Fatalf(err.Error()) + } + + PersonalSignMessage, err := viewingkey.GenerateMessage(userID, chainID, viewingkey.PersonalSignVersion, viewingkey.PersonalSign) + if err != nil { + t.Fatalf(err.Error()) + } + PersonalSignMessageHash, err := viewingkey.GetMessageHash(PersonalSignMessage, viewingkey.PersonalSign) if err != nil { t.Fatalf(err.Error()) } - EIP712MessageHash := crypto.Keccak256(EIP712MessageDataOptions[0]) - PersonalSignMessageHash := accounts.TextHash([]byte(viewingkey.GeneratePersonalSignMessage(userID, chainID, viewingkey.PersonalSignMessageSupportedVersions[0]))) messages := map[string]MessageWithSignatureType{ "EIP712MessageHash": { diff --git a/tools/walletextension/api/routes.go b/tools/walletextension/api/routes.go index 829dc69a38..6d38b70feb 100644 --- a/tools/walletextension/api/routes.go +++ b/tools/walletextension/api/routes.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/lib/gethfork/node" "github.com/ten-protocol/go-ten/go/common/log" @@ -42,6 +43,10 @@ func NewHTTPRoutes(walletExt *walletextension.WalletExtension) []node.Route { Name: common.APIVersion1 + common.PathJoin, Func: httpHandler(walletExt, joinRequestHandler), }, + { + Name: common.APIVersion1 + common.PathGetMessage, + Func: httpHandler(walletExt, getMessageRequestHandler), + }, { Name: common.APIVersion1 + common.PathAuthenticate, Func: httpHandler(walletExt, authenticateRequestHandler), @@ -295,7 +300,7 @@ func authenticateRequestHandler(walletExt *walletextension.WalletExtension, conn var reqJSONMap map[string]string err = json.Unmarshal(body, &reqJSONMap) if err != nil { - handleError(conn, walletExt.Logger(), fmt.Errorf("could not unmarshal address request - %w", err)) + handleError(conn, walletExt.Logger(), fmt.Errorf("could not unmarshal request body - %w", err)) return } @@ -313,14 +318,14 @@ func authenticateRequestHandler(walletExt *walletextension.WalletExtension, conn return } - // get optional type of the message that was signed + // get an optional type of the message that was signed messageTypeValue := common.DefaultGatewayAuthMessageType if typeFromRequest, ok := reqJSONMap[common.JSONKeyType]; ok && typeFromRequest != "" { messageTypeValue = typeFromRequest } - // check if message type is valid - messageType, ok := common.SignatureTypeMap[messageTypeValue] + // check if a message type is valid + messageType, ok := viewingkey.SignatureTypeMap[messageTypeValue] if !ok { handleError(conn, walletExt.Logger(), fmt.Errorf("invalid message type: %s", messageTypeValue)) } @@ -485,3 +490,80 @@ func versionRequestHandler(walletExt *walletextension.WalletExtension, userConn walletExt.Logger().Error("error writing success response", log.ErrKey, err) } } + +// getMessageRequestHandler handles request to /get-message endpoint. +func getMessageRequestHandler(walletExt *walletextension.WalletExtension, conn userconn.UserConn) { + // read the request + body, err := conn.ReadRequest() + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("error reading request: %w", err)) + return + } + + // read body of the request + var reqJSONMap map[string]interface{} + err = json.Unmarshal(body, &reqJSONMap) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("could not unmarshal address request - %w", err)) + return + } + + // get address from the request + encryptionToken, ok := reqJSONMap[common.JSONKeyEncryptionToken] + if !ok || len(encryptionToken.(string)) != common.MessageUserIDLen { + handleError(conn, walletExt.Logger(), fmt.Errorf("unable to read encryptionToken field from the request or it is not of correct length")) + return + } + + // get formats from the request, if present + var formatsSlice []string + if formatsInterface, ok := reqJSONMap[common.JSONKeyFormats]; ok { + formats, ok := formatsInterface.([]interface{}) + if !ok { + handleError(conn, walletExt.Logger(), fmt.Errorf("formats field is not an array")) + return + } + + for _, f := range formats { + formatStr, ok := f.(string) + if !ok { + handleError(conn, walletExt.Logger(), fmt.Errorf("format value is not a string")) + return + } + formatsSlice = append(formatsSlice, formatStr) + } + } + + message, err := walletExt.GenerateUserMessageToSign(encryptionToken.(string), formatsSlice) + if err != nil { + handleError(conn, walletExt.Logger(), fmt.Errorf("internal error")) + walletExt.Logger().Error("error getting message", log.ErrKey, err) + return + } + + // create the response structure + type JSONResponse struct { + Message string `json:"message"` + Type string `json:"type"` + } + + // get string representation of the message format + messageFormat := viewingkey.GetBestFormat(formatsSlice) + messageFormatString := viewingkey.GetSignatureTypeString(messageFormat) + + response := JSONResponse{ + Message: message, + Type: messageFormatString, + } + + responseBytes, err := json.Marshal(response) + if err != nil { + walletExt.Logger().Error("error marshaling JSON response", log.ErrKey, err) + return + } + + err = conn.WriteResponse(responseBytes) + if err != nil { + walletExt.Logger().Error("error writing success response", log.ErrKey, err) + } +} diff --git a/tools/walletextension/common/constants.go b/tools/walletextension/common/constants.go index 8e66ea55dd..6adbde8737 100644 --- a/tools/walletextension/common/constants.go +++ b/tools/walletextension/common/constants.go @@ -2,28 +2,28 @@ package common import ( "time" - - "github.com/ten-protocol/go-ten/go/common/viewingkey" ) const ( Localhost = "127.0.0.1" - JSONKeyAddress = "address" - JSONKeyData = "data" - JSONKeyErr = "error" - JSONKeyFrom = "from" - JSONKeyID = "id" - JSONKeyMethod = "method" - JSONKeyParams = "params" - JSONKeyResult = "result" - JSONKeyRoot = "root" - JSONKeyRPCVersion = "jsonrpc" - JSONKeySignature = "signature" - JSONKeySubscription = "subscription" - JSONKeyCode = "code" - JSONKeyMessage = "message" - JSONKeyType = "type" + JSONKeyAddress = "address" + JSONKeyData = "data" + JSONKeyErr = "error" + JSONKeyFrom = "from" + JSONKeyID = "id" + JSONKeyMethod = "method" + JSONKeyParams = "params" + JSONKeyResult = "result" + JSONKeyRoot = "root" + JSONKeyRPCVersion = "jsonrpc" + JSONKeySignature = "signature" + JSONKeySubscription = "subscription" + JSONKeyCode = "code" + JSONKeyMessage = "message" + JSONKeyType = "type" + JSONKeyEncryptionToken = "encryptionToken" + JSONKeyFormats = "formats" ) const ( @@ -33,6 +33,7 @@ const ( PathGenerateViewingKey = "/generateviewingkey/" PathSubmitViewingKey = "/submitviewingkey/" PathJoin = "/join/" + PathGetMessage = "/getmessage/" PathAuthenticate = "/authenticate/" PathQuery = "/query/" PathRevoke = "/revoke/" @@ -57,8 +58,3 @@ const ( ) var ReaderHeadTimeout = 10 * time.Second - -var SignatureTypeMap = map[string]viewingkey.SignatureType{ - "EIP712": viewingkey.EIP712Signature, - "Personal": viewingkey.PersonalSign, -} diff --git a/tools/walletextension/lib/client_lib.go b/tools/walletextension/lib/client_lib.go index 04accae2e4..42be5c4eb0 100644 --- a/tools/walletextension/lib/client_lib.go +++ b/tools/walletextension/lib/client_lib.go @@ -9,8 +9,6 @@ import ( "net/http" "strings" - "github.com/ethereum/go-ethereum/accounts" - "github.com/ten-protocol/go-ten/integration" gethcommon "github.com/ethereum/go-ethereum/common" @@ -48,15 +46,16 @@ func (o *TGLib) Join() error { func (o *TGLib) RegisterAccount(pk *ecdsa.PrivateKey, addr gethcommon.Address) error { // create the registration message - rawMessageOptions, err := viewingkey.GenerateAuthenticationEIP712RawDataOptions(string(o.userID), integration.TenChainID) + message, err := viewingkey.GenerateMessage(string(o.userID), integration.TenChainID, 1, viewingkey.EIP712Signature) if err != nil { return err } - if len(rawMessageOptions) == 0 { - return fmt.Errorf("GenerateAuthenticationEIP712RawDataOptions returned 0 options") + + messageHash, err := viewingkey.GetMessageHash(message, viewingkey.EIP712Signature) + if err != nil { + return fmt.Errorf("failed to get message hash: %w", err) } - messageHash := crypto.Keccak256(rawMessageOptions[0]) sig, err := crypto.Sign(messageHash, pk) if err != nil { return fmt.Errorf("failed to sign message: %w", err) @@ -97,8 +96,15 @@ func (o *TGLib) RegisterAccount(pk *ecdsa.PrivateKey, addr gethcommon.Address) e func (o *TGLib) RegisterAccountPersonalSign(pk *ecdsa.PrivateKey, addr gethcommon.Address) error { // create the registration message - personalSignMessage := viewingkey.GeneratePersonalSignMessage(string(o.userID), integration.TenChainID, 1) - messageHash := accounts.TextHash([]byte(personalSignMessage)) + message, err := viewingkey.GenerateMessage(string(o.userID), integration.TenChainID, viewingkey.PersonalSignVersion, viewingkey.PersonalSign) + if err != nil { + return err + } + + messageHash, err := viewingkey.GetMessageHash(message, viewingkey.PersonalSign) + if err != nil { + return fmt.Errorf("failed to get message hash: %w", err) + } sig, err := crypto.Sign(messageHash, pk) if err != nil { diff --git a/tools/walletextension/wallet_extension.go b/tools/walletextension/wallet_extension.go index cf0ee4da57..ac01209f96 100644 --- a/tools/walletextension/wallet_extension.go +++ b/tools/walletextension/wallet_extension.go @@ -473,3 +473,19 @@ func (w *WalletExtension) Version() string { func (w *WalletExtension) GetTenNodeHealthStatus() (bool, error) { return w.tenClient.Health() } + +func (w *WalletExtension) GenerateUserMessageToSign(encryptionToken string, formatsSlice []string) (string, error) { + // Check if the formats are valid + for _, format := range formatsSlice { + if _, exists := viewingkey.SignatureTypeMap[format]; !exists { + return "", fmt.Errorf("invalid format: %s", format) + } + } + + messageFormat := viewingkey.GetBestFormat(formatsSlice) + message, err := viewingkey.GenerateMessage(encryptionToken, int64(w.config.TenChainID), viewingkey.PersonalSignVersion, messageFormat) + if err != nil { + return "", fmt.Errorf("error generating message: %w", err) + } + return string(message), nil +}