From 72380c58f3d2f9cf74904150e03ef404fab4bfe6 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Sat, 13 Apr 2024 15:56:01 -0400 Subject: [PATCH] add aws iam authentication for mysql --- .../datastore/mysql/common/credentials.go | 37 +++++++++++++++++++ internal/datastore/mysql/datastore.go | 19 +++++++++- internal/datastore/mysql/datastore_test.go | 17 +++++++++ internal/datastore/mysql/migrations/driver.go | 16 +++++++- internal/datastore/mysql/options.go | 11 ++++++ .../datastore/postgres/migrations/driver.go | 2 +- internal/datastore/postgres/options.go | 4 ++ internal/datastore/postgres/postgres.go | 2 +- internal/testserver/datastore/mysql.go | 2 +- pkg/cmd/datastore/datastore.go | 1 + pkg/cmd/migrate.go | 12 +++++- pkg/datastore/credentials.go | 15 ++++++-- pkg/datastore/credentials_test.go | 4 +- 13 files changed, 131 insertions(+), 11 deletions(-) create mode 100644 internal/datastore/mysql/common/credentials.go diff --git a/internal/datastore/mysql/common/credentials.go b/internal/datastore/mysql/common/credentials.go new file mode 100644 index 0000000000..a4b19cc1d8 --- /dev/null +++ b/internal/datastore/mysql/common/credentials.go @@ -0,0 +1,37 @@ +package common + +import ( + "context" + + "github.com/go-sql-driver/mysql" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" +) + +// MaybeAddCredentialsProviderHook adds a hook that retrieves the configuration from the CredentialsProvider if the given credentialsProvider is not nil +func MaybeAddCredentialsProviderHook(dbConfig *mysql.Config, credentialsProvider datastore.CredentialsProvider) error { + if credentialsProvider == nil { + // a noop for a nil CredentialsProvider + return nil + } + + log.Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider") + + if credentialsProvider.IsCleartextToken() { + // we must transmit the token over the connection, and not a hash + dbConfig.AllowCleartextPasswords = true + + // log a warning if we don't detect TLS to be enabled + if dbConfig.TLSConfig == "false" || dbConfig.TLS == nil { + log.Warn().Msg("Tokens originating from credential provider are sent in cleartext. We recommend enabling TLS for the connection.") + } + } + + // add a before connect callback to trigger the token retrieval from the credentials provider + return dbConfig.Apply(mysql.BeforeConnect(func(ctx context.Context, config *mysql.Config) error { + var err error + config.User, config.Passwd, err = credentialsProvider.Get(ctx, config.Addr, config.User) + return err + })) +} diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 0bb9b88b18..7153336215 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -10,6 +10,8 @@ import ( "sync/atomic" "time" + mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common" + sq "github.com/Masterminds/squirrel" "github.com/dlmiddlecote/sqlstats" "github.com/go-sql-driver/mysql" @@ -118,7 +120,22 @@ func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Dat return nil, errors.New("error in NewMySQLDatastore: connection URI for MySQL datastore must include `parseTime=true` as a query parameter; see https://spicedb.dev/d/parse-time-mysql for more details") } - connector, err := mysql.MySQLDriver{}.OpenConnector(uri) + // Setup the credentials provider + var credentialsProvider datastore.CredentialsProvider + if config.credentialsProviderName != "" { + credentialsProvider, err = datastore.NewCredentialsProvider(ctx, config.credentialsProviderName) + if err != nil { + return nil, err + } + } + + err = mysqlCommon.MaybeAddCredentialsProviderHook(parsedURI, credentialsProvider) + if err != nil { + return nil, err + } + + // Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider + connector, err := mysql.NewConnector(parsedURI) if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, "NewMySQLDatastore: failed to create connector", err, uri) } diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 808fd4e671..2d421aeee6 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -711,6 +711,23 @@ func TestMySQLMigrationsWithPrefix(t *testing.T) { req.NoError(rows.Err()) } +func TestMySQLWithAWSIAMCredentialsProvider(t *testing.T) { + // set up the environment, so we don't make any external calls to AWS + t.Setenv("AWS_CONFIG_FILE", "file_not_exists") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists") + t.Setenv("AWS_ENDPOINT_URL", "http://169.254.169.254/aws") + t.Setenv("AWS_ACCESS_KEY", "access_key") + t.Setenv("AWS_SECRET_KEY", "secret_key") + t.Setenv("AWS_REGION", "us-east-1") + + // initialize the datastore using the AWS IAM credentials provider, and point it to a database that does not exist + _, err := NewMySQLDatastore(context.Background(), "root:password@(localhost:1234)/mysql?parseTime=True&tls=skip-verify", CredentialsProviderName("aws-iam")) + + // we expect the connection attempt to fail + // which means that the credentials provider was wired and called successfully before making the connection attempt + require.ErrorContains(t, err, ":1234: connect: connection refused") +} + func datastoreDB(t *testing.T, migrate bool) *sql.DB { var databaseURI string testdatastore.RunMySQLForTestingWithOptions(t, testdatastore.MySQLTesterOptions{MigrateForNewDatastore: migrate}, "").NewDatastore(t, func(engine, uri string) datastore.Datastore { diff --git a/internal/datastore/mysql/migrations/driver.go b/internal/datastore/mysql/migrations/driver.go index 3c57eb91b0..e09bf828ba 100644 --- a/internal/datastore/mysql/migrations/driver.go +++ b/internal/datastore/mysql/migrations/driver.go @@ -7,6 +7,10 @@ import ( "fmt" "strings" + mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/internal/datastore/common" sq "github.com/Masterminds/squirrel" @@ -35,16 +39,24 @@ type MySQLDriver struct { // // URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2... // See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html -func NewMySQLDriverFromDSN(url string, tablePrefix string) (*MySQLDriver, error) { +func NewMySQLDriverFromDSN(url string, tablePrefix string, credentialsProvider datastore.CredentialsProvider) (*MySQLDriver, error) { dbConfig, err := sqlDriver.ParseDSN(url) if err != nil { return nil, fmt.Errorf(errUnableToInstantiate, err) } - db, err := sql.Open("mysql", dbConfig.FormatDSN()) + err = mysqlCommon.MaybeAddCredentialsProviderHook(dbConfig, credentialsProvider) if err != nil { return nil, fmt.Errorf(errUnableToInstantiate, err) } + + // Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider + connector, err := sqlDriver.NewConnector(dbConfig) + if err != nil { + return nil, fmt.Errorf(errUnableToInstantiate, err) + } + + db := sql.OpenDB(connector) err = sqlDriver.SetLogger(&log.Logger) if err != nil { return nil, fmt.Errorf("unable to set logging to mysql driver: %w", err) diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index e75d070b8a..00e7858a58 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -21,6 +21,7 @@ const ( defaultEnablePrometheusStats = false defaultMaxRetries = 8 defaultGCEnabled = true + defaultCredentialsProviderName = "" ) type mysqlOptions struct { @@ -40,6 +41,7 @@ type mysqlOptions struct { maxRetries uint8 lockWaitTimeoutSeconds *uint8 gcEnabled bool + credentialsProviderName string } // Option provides the facility to configure how clients within the @@ -61,6 +63,7 @@ func generateConfig(options []Option) (mysqlOptions, error) { enablePrometheusStats: defaultEnablePrometheusStats, maxRetries: defaultMaxRetries, gcEnabled: defaultGCEnabled, + credentialsProviderName: defaultCredentialsProviderName, } for _, option := range options { @@ -236,3 +239,11 @@ func GCMaxOperationTime(time time.Duration) Option { mo.gcMaxOperationTime = time } } + +// CredentialsProviderName is the name of the CredentialsProvider implementation to use +// for dynamically retrieving the datastore credentials at runtime +// +// Empty by default. +func CredentialsProviderName(credentialsProviderName string) Option { + return func(mo *mysqlOptions) { mo.credentialsProviderName = credentialsProviderName } +} diff --git a/internal/datastore/postgres/migrations/driver.go b/internal/datastore/postgres/migrations/driver.go index 9fe11d2164..e56efa0b00 100644 --- a/internal/datastore/postgres/migrations/driver.go +++ b/internal/datastore/postgres/migrations/driver.go @@ -42,7 +42,7 @@ func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvid if credentialsProvider != nil { log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider") - connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, connConfig.Host, connConfig.Port, connConfig.User) + connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", connConfig.Host, connConfig.Port), connConfig.User) if err != nil { return nil, err } diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index bc3ea935ec..0eba2ab3e4 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -337,6 +337,10 @@ func MigrationPhase(phase string) Option { return func(po *postgresOptions) { po.migrationPhase = phase } } +// CredentialsProviderName is the name of the CredentialsProvider implementation to use +// for dynamically retrieving the datastore credentials at runtime +// +// Empty by default. func CredentialsProviderName(credentialsProviderName string) Option { return func(po *postgresOptions) { po.credentialsProviderName = credentialsProviderName } } diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index b77b939b37..4258f4dc2e 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -175,7 +175,7 @@ func newPostgresDatastore( if credentialsProvider != nil { // add before connect callbacks to trigger the token getToken := func(ctx context.Context, config *pgx.ConnConfig) error { - config.User, config.Password, err = credentialsProvider.Get(ctx, config.Host, config.Port, config.User) + config.User, config.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", config.Host, config.Port), config.User) return err } readPoolConfig.BeforeConnect = getToken diff --git a/internal/testserver/datastore/mysql.go b/internal/testserver/datastore/mysql.go index f7752061bf..5b1571d5eb 100644 --- a/internal/testserver/datastore/mysql.go +++ b/internal/testserver/datastore/mysql.go @@ -116,7 +116,7 @@ func (mb *mysqlTester) NewDatabase(t testing.TB) string { } func (mb *mysqlTester) runMigrate(t testing.TB, dsn string) { - driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix) + driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix, datastore.NoCredentialsProvider) require.NoError(t, err, "failed to create migration driver: %s", err) err = migrations.Manager.Run(context.Background(), driver, migrate.Head, migrate.LiveRun) require.NoError(t, err, "failed to run migration: %s", err) diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index ad637a727b..c35590d135 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -459,6 +459,7 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), mysql.MaxRetries(uint8(opts.MaxRetries)), mysql.OverrideLockWaitTimeout(1), + mysql.CredentialsProviderName(opts.CredentialsProviderName), } return mysql.NewMySQLDatastore(ctx, opts.URI, mysqlOpts...) } diff --git a/pkg/cmd/migrate.go b/pkg/cmd/migrate.go index 16a30592b8..c0e00e68af 100644 --- a/pkg/cmd/migrate.go +++ b/pkg/cmd/migrate.go @@ -98,7 +98,17 @@ func migrateRun(cmd *cobra.Command, args []string) error { log.Ctx(cmd.Context()).Fatal().Msg(fmt.Sprintf("unable to get table prefix: %s", err)) } - migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix) + var credentialsProvider datastore.CredentialsProvider + credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name") + if credentialsProviderName != "" { + var err error + credentialsProvider, err = datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName) + if err != nil { + return err + } + } + + migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix, credentialsProvider) if err != nil { return fmt.Errorf("unable to create migration driver for %s: %w", datastoreEngine, err) } diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go index ee71fc309c..1c557ad818 100644 --- a/pkg/datastore/credentials.go +++ b/pkg/datastore/credentials.go @@ -18,8 +18,12 @@ import ( type CredentialsProvider interface { // Name returns the name of the provider Name() string + // IsCleartextToken returns true if the token returned represents a token (rather than a password) that must be sent in cleartext to the datastore, or false otherwise. + // This may be used to configure the datastore options to avoid sending a hash of the token instead of its value. + // Note that it is always recommended that communication channel be encrypted. + IsCleartextToken() bool // Get returns the username and password to use when connecting to the underlying datastore - Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) + Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error) } var NoCredentialsProvider CredentialsProvider = nil @@ -74,8 +78,13 @@ func (d awsIamCredentialsProvider) Name() string { return AWSIAMCredentialProvider } -func (d awsIamCredentialsProvider) Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) { - dbEndpoint := fmt.Sprintf("%s:%d", dbHostname, dbPort) +func (d awsIamCredentialsProvider) IsCleartextToken() bool { + // The AWS IAM token can be of an arbitrary length and must not be hashed or truncated by the datastore driver + // See https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html + return true +} + +func (d awsIamCredentialsProvider) Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error) { authToken, err := rdsauth.BuildAuthToken(ctx, dbEndpoint, d.awsSdkConfig.Region, dbUser, d.awsSdkConfig.Credentials) if err != nil { log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB") diff --git a/pkg/datastore/credentials_test.go b/pkg/datastore/credentials_test.go index 076df2dd0b..770ef50ee4 100644 --- a/pkg/datastore/credentials_test.go +++ b/pkg/datastore/credentials_test.go @@ -31,7 +31,9 @@ func TestAWSIAMCredentialsProvider(t *testing.T) { require.NotNil(t, credentialsProvider) require.NoError(t, err) - username, password, err := credentialsProvider.Get(context.Background(), "some-hostname", 5432, "some-user") + require.True(t, credentialsProvider.IsCleartextToken(), "AWS IAM tokens should be communicated in cleartext") + + username, password, err := credentialsProvider.Get(context.Background(), "some-hostname:5432", "some-user") require.NoError(t, err) require.Equal(t, "some-user", username) require.Containsf(t, password, "X-Amz-Algorithm", "signed token should contain algorithm attribute")