Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add aws iam authentication for mysql #1867

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions internal/datastore/mysql/common/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package common

import (
"context"

"github.com/go-sql-driver/mysql"

log "github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/pkg/datastore"
)

// MaybeAddCredentialsProviderHook adds a hook that retrieves the configuration from the CredentialsProvider if the given credentialsProvider is not nil
func MaybeAddCredentialsProviderHook(dbConfig *mysql.Config, credentialsProvider datastore.CredentialsProvider) error {
j-white marked this conversation as resolved.
Show resolved Hide resolved
if credentialsProvider == nil {
// a noop for a nil CredentialsProvider
return nil
}

log.Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider")

if credentialsProvider.IsCleartextToken() {
// we must transmit the token over the connection, and not a hash
dbConfig.AllowCleartextPasswords = true

// log a warning if we don't detect TLS to be enabled
if dbConfig.TLSConfig == "false" || dbConfig.TLS == nil {
log.Warn().Msg("Tokens originating from credential provider are sent in cleartext. We recommend enabling TLS for the connection.")
}
}

// add a before connect callback to trigger the token retrieval from the credentials provider
return dbConfig.Apply(mysql.BeforeConnect(func(ctx context.Context, config *mysql.Config) error {
var err error
config.User, config.Passwd, err = credentialsProvider.Get(ctx, config.Addr, config.User)
return err
}))
}
19 changes: 18 additions & 1 deletion internal/datastore/mysql/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"sync/atomic"
"time"

mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common"

sq "github.com/Masterminds/squirrel"
"github.com/dlmiddlecote/sqlstats"
"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -118,7 +120,22 @@ func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Dat
return nil, errors.New("error in NewMySQLDatastore: connection URI for MySQL datastore must include `parseTime=true` as a query parameter; see https://spicedb.dev/d/parse-time-mysql for more details")
}

connector, err := mysql.MySQLDriver{}.OpenConnector(uri)
// Setup the credentials provider
var credentialsProvider datastore.CredentialsProvider
if config.credentialsProviderName != "" {
credentialsProvider, err = datastore.NewCredentialsProvider(ctx, config.credentialsProviderName)
if err != nil {
return nil, err
}
}

err = mysqlCommon.MaybeAddCredentialsProviderHook(parsedURI, credentialsProvider)
if err != nil {
return nil, err
}

// Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider
connector, err := mysql.NewConnector(parsedURI)
j-white marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, "NewMySQLDatastore: failed to create connector", err, uri)
}
Expand Down
17 changes: 17 additions & 0 deletions internal/datastore/mysql/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,23 @@ func TestMySQLMigrationsWithPrefix(t *testing.T) {
req.NoError(rows.Err())
}

func TestMySQLWithAWSIAMCredentialsProvider(t *testing.T) {
j-white marked this conversation as resolved.
Show resolved Hide resolved
// set up the environment, so we don't make any external calls to AWS
t.Setenv("AWS_CONFIG_FILE", "file_not_exists")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists")
t.Setenv("AWS_ENDPOINT_URL", "http://169.254.169.254/aws")
t.Setenv("AWS_ACCESS_KEY", "access_key")
t.Setenv("AWS_SECRET_KEY", "secret_key")
t.Setenv("AWS_REGION", "us-east-1")

// initialize the datastore using the AWS IAM credentials provider, and point it to a database that does not exist
_, err := NewMySQLDatastore(context.Background(), "root:password@(localhost:1234)/mysql?parseTime=True&tls=skip-verify", CredentialsProviderName("aws-iam"))

// we expect the connection attempt to fail
// which means that the credentials provider was wired and called successfully before making the connection attempt
require.ErrorContains(t, err, ":1234: connect: connection refused")
}

func datastoreDB(t *testing.T, migrate bool) *sql.DB {
var databaseURI string
testdatastore.RunMySQLForTestingWithOptions(t, testdatastore.MySQLTesterOptions{MigrateForNewDatastore: migrate}, "").NewDatastore(t, func(engine, uri string) datastore.Datastore {
Expand Down
16 changes: 14 additions & 2 deletions internal/datastore/mysql/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
"fmt"
"strings"

mysqlCommon "github.com/authzed/spicedb/internal/datastore/mysql/common"

"github.com/authzed/spicedb/pkg/datastore"

"github.com/authzed/spicedb/internal/datastore/common"

sq "github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -35,16 +39,24 @@ type MySQLDriver struct {
//
// URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2...
// See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html
func NewMySQLDriverFromDSN(url string, tablePrefix string) (*MySQLDriver, error) {
func NewMySQLDriverFromDSN(url string, tablePrefix string, credentialsProvider datastore.CredentialsProvider) (*MySQLDriver, error) {
dbConfig, err := sqlDriver.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}

db, err := sql.Open("mysql", dbConfig.FormatDSN())
err = mysqlCommon.MaybeAddCredentialsProviderHook(dbConfig, credentialsProvider)
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}

// Call NewConnector with the existing parsed configuration to preserve the BeforeConnect added by the CredentialsProvider
connector, err := sqlDriver.NewConnector(dbConfig)
j-white marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}

db := sql.OpenDB(connector)
err = sqlDriver.SetLogger(&log.Logger)
if err != nil {
return nil, fmt.Errorf("unable to set logging to mysql driver: %w", err)
Expand Down
11 changes: 11 additions & 0 deletions internal/datastore/mysql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
defaultEnablePrometheusStats = false
defaultMaxRetries = 8
defaultGCEnabled = true
defaultCredentialsProviderName = ""
)

type mysqlOptions struct {
Expand All @@ -40,6 +41,7 @@ type mysqlOptions struct {
maxRetries uint8
lockWaitTimeoutSeconds *uint8
gcEnabled bool
credentialsProviderName string
}

// Option provides the facility to configure how clients within the
Expand All @@ -61,6 +63,7 @@ func generateConfig(options []Option) (mysqlOptions, error) {
enablePrometheusStats: defaultEnablePrometheusStats,
maxRetries: defaultMaxRetries,
gcEnabled: defaultGCEnabled,
credentialsProviderName: defaultCredentialsProviderName,
}

for _, option := range options {
Expand Down Expand Up @@ -236,3 +239,11 @@ func GCMaxOperationTime(time time.Duration) Option {
mo.gcMaxOperationTime = time
}
}

// CredentialsProviderName is the name of the CredentialsProvider implementation to use
// for dynamically retrieving the datastore credentials at runtime
//
// Empty by default.
func CredentialsProviderName(credentialsProviderName string) Option {
return func(mo *mysqlOptions) { mo.credentialsProviderName = credentialsProviderName }
}
2 changes: 1 addition & 1 deletion internal/datastore/postgres/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvid

if credentialsProvider != nil {
log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider")
connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, connConfig.Host, connConfig.Port, connConfig.User)
connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", connConfig.Host, connConfig.Port), connConfig.User)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions internal/datastore/postgres/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ func MigrationPhase(phase string) Option {
return func(po *postgresOptions) { po.migrationPhase = phase }
}

// CredentialsProviderName is the name of the CredentialsProvider implementation to use
// for dynamically retrieving the datastore credentials at runtime
//
// Empty by default.
func CredentialsProviderName(credentialsProviderName string) Option {
return func(po *postgresOptions) { po.credentialsProviderName = credentialsProviderName }
}
2 changes: 1 addition & 1 deletion internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func newPostgresDatastore(
if credentialsProvider != nil {
// add before connect callbacks to trigger the token
getToken := func(ctx context.Context, config *pgx.ConnConfig) error {
config.User, config.Password, err = credentialsProvider.Get(ctx, config.Host, config.Port, config.User)
config.User, config.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", config.Host, config.Port), config.User)
return err
}
readPoolConfig.BeforeConnect = getToken
Expand Down
2 changes: 1 addition & 1 deletion internal/testserver/datastore/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (mb *mysqlTester) NewDatabase(t testing.TB) string {
}

func (mb *mysqlTester) runMigrate(t testing.TB, dsn string) {
driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix)
driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix, datastore.NoCredentialsProvider)
require.NoError(t, err, "failed to create migration driver: %s", err)
err = migrations.Manager.Run(context.Background(), driver, migrate.Head, migrate.LiveRun)
require.NoError(t, err, "failed to run migration: %s", err)
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e
mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics),
mysql.MaxRetries(uint8(opts.MaxRetries)),
mysql.OverrideLockWaitTimeout(1),
mysql.CredentialsProviderName(opts.CredentialsProviderName),
}
return mysql.NewMySQLDatastore(ctx, opts.URI, mysqlOpts...)
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,17 @@ func migrateRun(cmd *cobra.Command, args []string) error {
log.Ctx(cmd.Context()).Fatal().Msg(fmt.Sprintf("unable to get table prefix: %s", err))
}

migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix)
var credentialsProvider datastore.CredentialsProvider
credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name")
if credentialsProviderName != "" {
var err error
credentialsProvider, err = datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName)
if err != nil {
return err
}
}

migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix, credentialsProvider)
if err != nil {
return fmt.Errorf("unable to create migration driver for %s: %w", datastoreEngine, err)
}
Expand Down
15 changes: 12 additions & 3 deletions pkg/datastore/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ import (
type CredentialsProvider interface {
// Name returns the name of the provider
Name() string
// IsCleartextToken returns true if the token returned represents a token (rather than a password) that must be sent in cleartext to the datastore, or false otherwise.
// This may be used to configure the datastore options to avoid sending a hash of the token instead of its value.
// Note that it is always recommended that communication channel be encrypted.
IsCleartextToken() bool
// Get returns the username and password to use when connecting to the underlying datastore
Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error)
Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error)
}

var NoCredentialsProvider CredentialsProvider = nil
Expand Down Expand Up @@ -74,8 +78,13 @@ func (d awsIamCredentialsProvider) Name() string {
return AWSIAMCredentialProvider
}

func (d awsIamCredentialsProvider) Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) {
dbEndpoint := fmt.Sprintf("%s:%d", dbHostname, dbPort)
func (d awsIamCredentialsProvider) IsCleartextToken() bool {
j-white marked this conversation as resolved.
Show resolved Hide resolved
// The AWS IAM token can be of an arbitrary length and must not be hashed or truncated by the datastore driver
// See https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
return true
}

func (d awsIamCredentialsProvider) Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error) {
authToken, err := rdsauth.BuildAuthToken(ctx, dbEndpoint, d.awsSdkConfig.Region, dbUser, d.awsSdkConfig.Credentials)
if err != nil {
log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB")
Expand Down
4 changes: 3 additions & 1 deletion pkg/datastore/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func TestAWSIAMCredentialsProvider(t *testing.T) {
require.NotNil(t, credentialsProvider)
require.NoError(t, err)

username, password, err := credentialsProvider.Get(context.Background(), "some-hostname", 5432, "some-user")
require.True(t, credentialsProvider.IsCleartextToken(), "AWS IAM tokens should be communicated in cleartext")

username, password, err := credentialsProvider.Get(context.Background(), "some-hostname:5432", "some-user")
require.NoError(t, err)
require.Equal(t, "some-user", username)
require.Containsf(t, password, "X-Amz-Algorithm", "signed token should contain algorithm attribute")
Expand Down
Loading