-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
499 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package aws | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/config" | ||
"github.com/aws/aws-sdk-go-v2/credentials" | ||
) | ||
|
||
func GetAWSConfig(accessKey, secretAccessKey, region, endpointURL string) (*aws.Config, error) { | ||
createClient := func(service, region string, options ...interface{}) (aws.Endpoint, error) { | ||
if endpointURL != "" { | ||
return aws.Endpoint{ | ||
PartitionID: "aws", | ||
URL: endpointURL, | ||
SigningRegion: region, | ||
}, nil | ||
} | ||
|
||
// returning EndpointNotFoundError will allow the service to fallback to its default resolution | ||
return aws.Endpoint{}, &aws.EndpointNotFoundError{} | ||
} | ||
customResolver := aws.EndpointResolverWithOptionsFunc(createClient) | ||
|
||
cfg, errCfg := config.LoadDefaultConfig(context.Background(), | ||
config.WithRegion(region), | ||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretAccessKey, "")), | ||
config.WithEndpointResolverWithOptions(customResolver), | ||
config.WithRetryMode(aws.RetryModeStandard), | ||
) | ||
if errCfg != nil { | ||
return nil, errCfg | ||
} | ||
return &cfg, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package kms | ||
|
||
import ( | ||
"context" | ||
"crypto/ecdsa" | ||
"encoding/asn1" | ||
"fmt" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/kms" | ||
"github.com/ethereum/go-ethereum/crypto" | ||
) | ||
|
||
type asn1EcPublicKey struct { | ||
EcPublicKeyInfo asn1EcPublicKeyInfo | ||
PublicKey asn1.BitString | ||
} | ||
|
||
type asn1EcPublicKeyInfo struct { | ||
Algorithm asn1.ObjectIdentifier | ||
Parameters asn1.ObjectIdentifier | ||
} | ||
|
||
// GetPublicKey retrieves the ECDSA public key for a KMS key | ||
// It assumes the key is set up with `ECC_SECG_P256K1` key spec and `SIGN_VERIFY` key usage | ||
func GetPublicKey(ctx context.Context, svc *kms.Client, keyId string) (*ecdsa.PublicKey, error) { | ||
getPubKeyOutput, err := svc.GetPublicKey(ctx, &kms.GetPublicKeyInput{ | ||
KeyId: aws.String(keyId), | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to get public key for KeyId=%s: %w", keyId, err) | ||
} | ||
|
||
var asn1pubk asn1EcPublicKey | ||
_, err = asn1.Unmarshal(getPubKeyOutput.PublicKey, &asn1pubk) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err) | ||
} | ||
|
||
pubkey, err := crypto.UnmarshalPubkey(asn1pubk.PublicKey.Bytes) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err) | ||
} | ||
|
||
return pubkey, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package kms_test | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"os" | ||
"testing" | ||
|
||
eigenkms "github.com/Layr-Labs/eigensdk-go/aws/kms" | ||
"github.com/Layr-Labs/eigensdk-go/testutils" | ||
"github.com/aws/aws-sdk-go-v2/service/kms/types" | ||
"github.com/ethereum/go-ethereum/common" | ||
"github.com/ethereum/go-ethereum/crypto" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/testcontainers/testcontainers-go" | ||
) | ||
|
||
var ( | ||
mappedLocalstackPort string | ||
keyMetadata *types.KeyMetadata | ||
container testcontainers.Container | ||
) | ||
|
||
func TestMain(m *testing.M) { | ||
err := setup() | ||
if err != nil { | ||
fmt.Println("Error setting up test environment:", err) | ||
teardown() | ||
os.Exit(1) | ||
} | ||
exitCode := m.Run() | ||
teardown() | ||
os.Exit(exitCode) | ||
} | ||
|
||
func setup() error { | ||
var err error | ||
container, err = testutils.StartLocalstackContainer() | ||
if err != nil { | ||
return err | ||
} | ||
mappedPort, err := container.MappedPort(context.Background(), testutils.LocalStackPort) | ||
if err != nil { | ||
return err | ||
} | ||
mappedLocalstackPort = string(mappedPort) | ||
keyMetadata, err = testutils.CreateKMSKey(mappedLocalstackPort) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func teardown() { | ||
_ = container.Terminate(context.Background()) | ||
} | ||
|
||
func TestGetPublicKey(t *testing.T) { | ||
c, err := testutils.NewKMSClient(mappedLocalstackPort) | ||
assert.Nil(t, err) | ||
assert.NotNil(t, keyMetadata.KeyId) | ||
pk, err := eigenkms.GetPublicKey(context.Background(), c, *keyMetadata.KeyId) | ||
assert.Nil(t, err) | ||
assert.NotNil(t, pk) | ||
keyAddr := crypto.PubkeyToAddress(*pk) | ||
t.Logf("Public key address: %s", keyAddr.String()) | ||
assert.NotEqual(t, keyAddr, common.Address{0}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package kms | ||
|
||
import ( | ||
"context" | ||
"encoding/asn1" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/kms" | ||
"github.com/aws/aws-sdk-go-v2/service/kms/types" | ||
) | ||
|
||
type asn1EcSig struct { | ||
R asn1.RawValue | ||
S asn1.RawValue | ||
} | ||
|
||
// GetSignature retrieves the ECDSA signature for a message using a KMS key | ||
func GetSignature( | ||
ctx context.Context, svc *kms.Client, keyId string, msg []byte, | ||
) ([]byte, []byte, error) { | ||
signInput := &kms.SignInput{ | ||
KeyId: aws.String(keyId), | ||
SigningAlgorithm: types.SigningAlgorithmSpecEcdsaSha256, | ||
MessageType: types.MessageTypeDigest, | ||
Message: msg, | ||
} | ||
|
||
signOutput, err := svc.Sign(ctx, signInput) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
var sigAsn1 asn1EcSig | ||
_, err = asn1.Unmarshal(signOutput.Signature, &sigAsn1) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
return sigAsn1.R.Bytes, sigAsn1.S.Bytes, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.