diff --git a/aws/config.go b/aws/config.go new file mode 100644 index 00000000..6eb38a8e --- /dev/null +++ b/aws/config.go @@ -0,0 +1,36 @@ +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" +) + +func GetAWSConfig(accessKey, secretAccessKey, region, endpointURL string) (*aws.Config, error) { + createClient := func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if endpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: endpointURL, + SigningRegion: region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + } + customResolver := aws.EndpointResolverWithOptionsFunc(createClient) + + cfg, errCfg := config.LoadDefaultConfig(context.Background(), + config.WithRegion(region), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretAccessKey, "")), + config.WithEndpointResolverWithOptions(customResolver), + config.WithRetryMode(aws.RetryModeStandard), + ) + if errCfg != nil { + return nil, errCfg + } + return &cfg, nil +} diff --git a/aws/kms/client.go b/aws/kms/client.go new file mode 100644 index 00000000..559e910a --- /dev/null +++ b/aws/kms/client.go @@ -0,0 +1,19 @@ +package kms + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" +) + +func NewKMSClient(ctx context.Context, region string) (*kms.Client, error) { + config, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + c := kms.NewFromConfig(config) + return c, nil +} diff --git a/aws/kms/get_public_key.go b/aws/kms/get_public_key.go new file mode 100644 index 00000000..665a5b42 --- /dev/null +++ b/aws/kms/get_public_key.go @@ -0,0 +1,46 @@ +package kms + +import ( + "context" + "crypto/ecdsa" + "encoding/asn1" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/ethereum/go-ethereum/crypto" +) + +type asn1EcPublicKey struct { + EcPublicKeyInfo asn1EcPublicKeyInfo + PublicKey asn1.BitString +} + +type asn1EcPublicKeyInfo struct { + Algorithm asn1.ObjectIdentifier + Parameters asn1.ObjectIdentifier +} + +// GetECDSAPublicKey retrieves the ECDSA public key for a KMS key +// It assumes the key is set up with `ECC_SECG_P256K1` key spec and `SIGN_VERIFY` key usage +func GetECDSAPublicKey(ctx context.Context, svc *kms.Client, keyId string) (*ecdsa.PublicKey, error) { + getPubKeyOutput, err := svc.GetPublicKey(ctx, &kms.GetPublicKeyInput{ + KeyId: aws.String(keyId), + }) + if err != nil { + return nil, fmt.Errorf("failed to get public key for KeyId=%s: %w", keyId, err) + } + + var asn1pubk asn1EcPublicKey + _, err = asn1.Unmarshal(getPubKeyOutput.PublicKey, &asn1pubk) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err) + } + + pubkey, err := crypto.UnmarshalPubkey(asn1pubk.PublicKey.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err) + } + + return pubkey, nil +} diff --git a/aws/kms/get_public_key_test.go b/aws/kms/get_public_key_test.go new file mode 100644 index 00000000..7c9b5eba --- /dev/null +++ b/aws/kms/get_public_key_test.go @@ -0,0 +1,68 @@ +package kms_test + +import ( + "context" + "fmt" + "os" + "testing" + + eigenkms "github.com/Layr-Labs/eigensdk-go/aws/kms" + "github.com/Layr-Labs/eigensdk-go/testutils" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/assert" + "github.com/testcontainers/testcontainers-go" +) + +var ( + mappedLocalstackPort string + keyMetadata *types.KeyMetadata + container testcontainers.Container +) + +func TestMain(m *testing.M) { + err := setup() + if err != nil { + fmt.Println("Error setting up test environment:", err) + teardown() + os.Exit(1) + } + exitCode := m.Run() + teardown() + os.Exit(exitCode) +} + +func setup() error { + var err error + container, err = testutils.StartLocalstackContainer("get_public_key_test") + if err != nil { + return err + } + mappedPort, err := container.MappedPort(context.Background(), testutils.LocalStackPort) + if err != nil { + return err + } + mappedLocalstackPort = string(mappedPort) + keyMetadata, err = testutils.CreateKMSKey(mappedLocalstackPort) + if err != nil { + return err + } + return nil +} + +func teardown() { + _ = container.Terminate(context.Background()) +} + +func TestGetPublicKey(t *testing.T) { + c, err := testutils.NewKMSClient(mappedLocalstackPort) + assert.Nil(t, err) + assert.NotNil(t, keyMetadata.KeyId) + pk, err := eigenkms.GetECDSAPublicKey(context.Background(), c, *keyMetadata.KeyId) + assert.Nil(t, err) + assert.NotNil(t, pk) + keyAddr := crypto.PubkeyToAddress(*pk) + t.Logf("Public key address: %s", keyAddr.String()) + assert.NotEqual(t, keyAddr, common.Address{0}) +} diff --git a/aws/kms/get_signature.go b/aws/kms/get_signature.go new file mode 100644 index 00000000..03e6b4df --- /dev/null +++ b/aws/kms/get_signature.go @@ -0,0 +1,40 @@ +package kms + +import ( + "context" + "encoding/asn1" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" +) + +type asn1EcSig struct { + R asn1.RawValue + S asn1.RawValue +} + +// GetECDSASignature retrieves the ECDSA signature for a message using a KMS key +func GetECDSASignature( + ctx context.Context, svc *kms.Client, keyId string, msg []byte, +) (r []byte, s []byte, err error) { + signInput := &kms.SignInput{ + KeyId: aws.String(keyId), + SigningAlgorithm: types.SigningAlgorithmSpecEcdsaSha256, + MessageType: types.MessageTypeDigest, + Message: msg, + } + + signOutput, err := svc.Sign(ctx, signInput) + if err != nil { + return nil, nil, err + } + + var sigAsn1 asn1EcSig + _, err = asn1.Unmarshal(signOutput.Signature, &sigAsn1) + if err != nil { + return nil, nil, err + } + + return sigAsn1.R.Bytes, sigAsn1.S.Bytes, nil +} diff --git a/cmd/egnaddrs/main_test.go b/cmd/egnaddrs/main_test.go index c4b2e923..0c3c81ea 100644 --- a/cmd/egnaddrs/main_test.go +++ b/cmd/egnaddrs/main_test.go @@ -2,12 +2,9 @@ package main import ( "context" - "os" - "path/filepath" "testing" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" + "github.com/Layr-Labs/eigensdk-go/testutils" ) const ( @@ -20,7 +17,10 @@ const ( func TestEgnAddrsWithServiceManagerFlag(t *testing.T) { - anvilC := startAnvilTestContainer() + anvilC, err := testutils.StartAnvilContainer(anvilStateFileName) + if err != nil { + t.Fatal(err) + } anvilEndpoint, err := anvilC.Endpoint(context.Background(), "") if err != nil { t.Error(err) @@ -35,7 +35,10 @@ func TestEgnAddrsWithServiceManagerFlag(t *testing.T) { func TestEgnAddrsWithRegistryCoordinatorFlag(t *testing.T) { - anvilC := startAnvilTestContainer() + anvilC, err := testutils.StartAnvilContainer(anvilStateFileName) + if err != nil { + t.Fatal(err) + } anvilEndpoint, err := anvilC.Endpoint(context.Background(), "") if err != nil { t.Error(err) @@ -47,34 +50,3 @@ func TestEgnAddrsWithRegistryCoordinatorFlag(t *testing.T) { // we just make sure it doesn't crash run(args) } - -func startAnvilTestContainer() testcontainers.Container { - integrationDir, err := os.Getwd() - if err != nil { - panic(err) - } - - ctx := context.Background() - req := testcontainers.ContainerRequest{ - Image: "ghcr.io/foundry-rs/foundry:latest", - Files: []testcontainers.ContainerFile{ - { - HostFilePath: filepath.Join(integrationDir, "test_data", anvilStateFileName), - ContainerFilePath: "/root/.anvil/state.json", - FileMode: 0644, // Adjust the FileMode according to your requirements - }, - }, - Entrypoint: []string{"anvil"}, - Cmd: []string{"--host", "0.0.0.0", "--load-state", "/root/.anvil/state.json"}, - ExposedPorts: []string{"8545/tcp"}, - WaitingFor: wait.ForLog("Listening on"), - } - anvilC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - }) - if err != nil { - panic(err) - } - return anvilC -} diff --git a/go.mod b/go.mod index 4e1c7520..ae2a84e1 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Layr-Labs/eigensdk-go go 1.21 require ( + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/consensys/gnark-crypto v0.12.1 github.com/ethereum/go-ethereum v1.14.0 github.com/google/uuid v1.6.0 @@ -17,7 +18,6 @@ require ( ) require ( - github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect @@ -47,6 +47,7 @@ require ( github.com/StackExchange/wmi v1.2.1 // indirect github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/config v1.27.11 + github.com/aws/aws-sdk-go-v2/service/kms v1.31.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6 github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.10.0 // indirect diff --git a/go.sum b/go.sum index a52fa98c..fbe2c4c7 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1x github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/kms v1.31.0 h1:yl7wcqbisxPzknJVfWTLnK83McUvXba+pz2+tPbIUmQ= +github.com/aws/aws-sdk-go-v2/service/kms v1.31.0/go.mod h1:2snWQJQUKsbN66vAawJuOGX7dr37pfOq9hb0tZDGIqQ= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6 h1:TIOEjw0i2yyhmhRry3Oeu9YtiiHWISZ6j/irS1W3gX4= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6/go.mod h1:3Ba++UwWd154xtP4FRX5pUK3Gt4up5sDHCve6kVfE+g= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= diff --git a/signerv2/kms_signer.go b/signerv2/kms_signer.go new file mode 100644 index 00000000..72dcb9f3 --- /dev/null +++ b/signerv2/kms_signer.go @@ -0,0 +1,107 @@ +package signerv2 + +import ( + "bytes" + "context" + "crypto/ecdsa" + "encoding/hex" + "errors" + "math/big" + + eigenkms "github.com/Layr-Labs/eigensdk-go/aws/kms" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" +) + +var secp256k1N = crypto.S256().Params().N +var secp256k1HalfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) + +func NewKMSSigner(ctx context.Context, svc *kms.Client, pk *ecdsa.PublicKey, keyId string, chainID *big.Int) SignerFn { + return func(ctx context.Context, address common.Address) (bind.SignerFn, error) { + return KMSSignerFn(ctx, svc, pk, keyId, chainID) + } +} + +// KMSSignerFn returns a SignerFn that uses a KMS key to sign transactions +// Heavily taken from https://github.com/welthee/go-ethereum-aws-kms-tx-signer +// It constructs R and S values from KMS, and constructs the recovery id (V) by trying to recover with both 0 and 1 values: +// ref: https://github.com/aws-samples/aws-kms-ethereum-accounts?tab=readme-ov-file#the-recovery-identifier-v +// +// Its V value is 0/1 instead of 27/28 because `types.LatestSignerForChainID` expects 0/1 which turns it into 27/28 +func KMSSignerFn(ctx context.Context, svc *kms.Client, pk *ecdsa.PublicKey, keyId string, chainID *big.Int) (bind.SignerFn, error) { + if chainID == nil { + return nil, errors.New("chainID is required") + } + if svc == nil { + return nil, errors.New("kms client is required") + } + if pk == nil { + return nil, errors.New("public key is required") + } + + pubKeyBytes := secp256k1.S256().Marshal(pk.X, pk.Y) + keyAddr := crypto.PubkeyToAddress(*pk) + signer := types.LatestSignerForChainID(chainID) + return func(address common.Address, tx *types.Transaction) (*types.Transaction, error) { + if address != keyAddr { + return nil, bind.ErrNotAuthorized + } + + txHashBytes := signer.Hash(tx).Bytes() + + rBytes, sBytes, err := eigenkms.GetECDSASignature(ctx, svc, keyId, txHashBytes) + if err != nil { + return nil, err + } + + // Adjust S value from signature according to Ethereum standard + sBigInt := new(big.Int).SetBytes(sBytes) + if sBigInt.Cmp(secp256k1HalfN) > 0 { + sBytes = new(big.Int).Sub(secp256k1N, sBigInt).Bytes() + } + + signature, err := getEthereumSignature(pubKeyBytes, txHashBytes, rBytes, sBytes) + if err != nil { + return nil, err + } + + return tx.WithSignature(signer, signature) + }, nil +} + +func getEthereumSignature(expectedPublicKeyBytes []byte, txHash []byte, r []byte, s []byte) ([]byte, error) { + rsSignature := append(adjustSignatureLength(r), adjustSignatureLength(s)...) + signature := append(rsSignature, []byte{0}...) + + recoveredPublicKeyBytes, err := crypto.Ecrecover(txHash, signature) + if err != nil { + return nil, err + } + + if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) { + signature = append(rsSignature, []byte{1}...) + recoveredPublicKeyBytes, err = crypto.Ecrecover(txHash, signature) + if err != nil { + return nil, err + } + + if hex.EncodeToString(recoveredPublicKeyBytes) != hex.EncodeToString(expectedPublicKeyBytes) { + return nil, errors.New("can not reconstruct public key from sig") + } + } + + return signature, nil +} + +func adjustSignatureLength(buffer []byte) []byte { + buffer = bytes.TrimLeft(buffer, "\x00") + for len(buffer) < 32 { + zeroBuf := []byte{0} + buffer = append(zeroBuf, buffer...) + } + return buffer +} diff --git a/signerv2/kms_signer_test.go b/signerv2/kms_signer_test.go new file mode 100644 index 00000000..c82390ac --- /dev/null +++ b/signerv2/kms_signer_test.go @@ -0,0 +1,122 @@ +package signerv2_test + +import ( + "context" + "fmt" + "math/big" + "os" + "testing" + + eigenkms "github.com/Layr-Labs/eigensdk-go/aws/kms" + "github.com/Layr-Labs/eigensdk-go/chainio/clients/eth" + "github.com/Layr-Labs/eigensdk-go/chainio/clients/wallet" + "github.com/Layr-Labs/eigensdk-go/chainio/txmgr" + "github.com/Layr-Labs/eigensdk-go/logging" + "github.com/Layr-Labs/eigensdk-go/signerv2" + "github.com/Layr-Labs/eigensdk-go/testutils" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/ethereum/go-ethereum/common" + gtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rpc" + "github.com/stretchr/testify/assert" + "github.com/testcontainers/testcontainers-go" +) + +var ( + mappedLocalstackPort string + keyMetadata *types.KeyMetadata + anvilEndpoint string + localstack testcontainers.Container + anvil testcontainers.Container + rpcClient *rpc.Client +) + +func TestMain(m *testing.M) { + err := setup() + if err != nil { + fmt.Println("Error setting up test environment:", err) + teardown() + os.Exit(1) + } + exitCode := m.Run() + teardown() + os.Exit(exitCode) +} + +func setup() error { + var err error + localstack, err = testutils.StartLocalstackContainer("kms_signer_test") + if err != nil { + return fmt.Errorf("failed to start Localstack container: %w", err) + } + mappedPort, err := localstack.MappedPort(context.Background(), testutils.LocalStackPort) + if err != nil { + return fmt.Errorf("failed to get mapped port: %w", err) + } + mappedLocalstackPort = string(mappedPort) + + anvil, err = testutils.StartAnvilContainer("") + if err != nil { + return fmt.Errorf("failed to start Anvil container: %w", err) + } + endpoint, err := anvil.Endpoint(context.Background(), "") + if err != nil { + return fmt.Errorf("failed to get Anvil endpoint: %w", err) + } + anvilEndpoint = fmt.Sprintf("http://%s", endpoint) + rpcClient, err = rpc.Dial(anvilEndpoint) + if err != nil { + return fmt.Errorf("failed to dial Anvil RPC: %w", err) + } + keyMetadata, err = testutils.CreateKMSKey(mappedLocalstackPort) + if err != nil { + return fmt.Errorf("failed to create KMS key: %w", err) + } + return nil +} + +func teardown() { + _ = localstack.Terminate(context.Background()) + _ = anvil.Terminate(context.Background()) +} + +func TestSendTransaction(t *testing.T) { + c, err := testutils.NewKMSClient(mappedLocalstackPort) + assert.Nil(t, err) + assert.NotNil(t, keyMetadata.KeyId) + pk, err := eigenkms.GetECDSAPublicKey(context.Background(), c, *keyMetadata.KeyId) + assert.Nil(t, err) + assert.NotNil(t, pk) + keyAddr := crypto.PubkeyToAddress(*pk) + t.Logf("Public key address: %s", keyAddr.String()) + assert.NotEqual(t, keyAddr, common.Address{0}) + err = rpcClient.CallContext(context.Background(), nil, "anvil_setBalance", keyAddr, "2_000_000_000_000_000_000") + assert.Nil(t, err) + + logger := &logging.NoopLogger{} + ethClient, err := eth.NewClient(anvilEndpoint) + assert.Nil(t, err) + chainID, err := ethClient.ChainID(context.Background()) + assert.Nil(t, err) + signer := signerv2.NewKMSSigner(context.Background(), c, pk, *keyMetadata.KeyId, chainID) + assert.Nil(t, err) + assert.NotNil(t, signer) + pkWallet, err := wallet.NewPrivateKeyWallet(ethClient, signer, keyAddr, logger) + assert.Nil(t, err) + assert.NotNil(t, pkWallet) + txMgr := txmgr.NewSimpleTxManager(pkWallet, ethClient, logger, keyAddr) + assert.NotNil(t, txMgr) + zeroAddr := common.HexToAddress("0x0") + receipt, err := txMgr.Send(context.Background(), gtypes.NewTx(>ypes.DynamicFeeTx{ + ChainID: chainID, + Nonce: 0, + To: &zeroAddr, + Value: big.NewInt(1_000_000_000_000_000_000), + })) + assert.Nil(t, err) + assert.NotNil(t, receipt) + balance, err := ethClient.BalanceAt(context.Background(), keyAddr, nil) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(999979000000000000), balance) +} diff --git a/testutils/anvil.go b/testutils/anvil.go new file mode 100644 index 00000000..0c84b1b7 --- /dev/null +++ b/testutils/anvil.go @@ -0,0 +1,42 @@ +package testutils + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +func StartAnvilContainer(anvilStateFileName string) (testcontainers.Container, error) { + integrationDir, err := os.Getwd() + if err != nil { + panic(err) + } + + ctx := context.Background() + req := testcontainers.ContainerRequest{ + Image: "ghcr.io/foundry-rs/foundry:latest", + Entrypoint: []string{"anvil"}, + Cmd: []string{"--host", "0.0.0.0", "--base-fee", "0", "--gas-price", "0"}, + ExposedPorts: []string{"8545/tcp"}, + WaitingFor: wait.ForLog("Listening on"), + } + if anvilStateFileName != "" { + fmt.Println("Starting Anvil container with state file: ", anvilStateFileName) + req.Cmd = append(req.Cmd, "--load-state", "/root/.anvil/state.json") + req.Files = []testcontainers.ContainerFile{ + { + HostFilePath: filepath.Join(integrationDir, "test_data", anvilStateFileName), + ContainerFilePath: "/root/.anvil/state.json", + FileMode: 0644, // Adjust the FileMode according to your requirements + }, + } + } + return testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) +} diff --git a/testutils/localstack.go b/testutils/localstack.go new file mode 100644 index 00000000..41e31049 --- /dev/null +++ b/testutils/localstack.go @@ -0,0 +1,55 @@ +package testutils + +import ( + "context" + "fmt" + + "github.com/Layr-Labs/eigensdk-go/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const LocalStackPort = "4566" + +func StartLocalstackContainer(name string) (testcontainers.Container, error) { + fmt.Println("Starting Localstack container") + req := testcontainers.ContainerRequest{ + Image: "localstack/localstack:latest", + Name: fmt.Sprintf("localstack-test-%s", name), + Env: map[string]string{"LOCALSTACK_HOST": fmt.Sprintf("localhost.localstack.cloud:%s", LocalStackPort)}, + ExposedPorts: []string{LocalStackPort}, + WaitingFor: wait.ForLog("Ready."), + AutoRemove: true, + } + return testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) +} + +func NewKMSClient(localStackPort string) (*kms.Client, error) { + cfg, err := aws.GetAWSConfig("localstack", "localstack", "us-east-1", fmt.Sprintf("http://127.0.0.1:%s", localStackPort)) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + c := kms.NewFromConfig(*cfg) + return c, nil +} + +func CreateKMSKey(localStackPort string) (*types.KeyMetadata, error) { + c, err := NewKMSClient(localStackPort) + if err != nil { + return nil, fmt.Errorf("failed to create KMS client: %w", err) + } + res, err := c.CreateKey(context.Background(), &kms.CreateKeyInput{ + KeySpec: types.KeySpecEccSecgP256k1, + KeyUsage: types.KeyUsageTypeSignVerify, + }) + if err != nil { + return nil, fmt.Errorf("failed to create key: %w", err) + } + return res.KeyMetadata, nil +}