Skip to content

Commit

Permalink
add aws iam authentication for mysql
Browse files Browse the repository at this point in the history
  • Loading branch information
j-white committed Apr 16, 2024
1 parent 3d7871a commit 77d32da
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 11 deletions.
22 changes: 21 additions & 1 deletion internal/datastore/mysql/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,27 @@ func newMySQLDatastore(ctx context.Context, uri string, options ...Option) (*Dat
return nil, errors.New("error in NewMySQLDatastore: connection URI for MySQL datastore must include `parseTime=true` as a query parameter; see https://spicedb.dev/d/parse-time-mysql for more details")
}

connector, err := mysql.MySQLDriver{}.OpenConnector(uri)
// Setup the credentials provider
var credentialsProvider datastore.CredentialsProvider
if config.credentialsProviderName != "" {
credentialsProvider, err = datastore.NewCredentialsProvider(ctx, config.credentialsProviderName)
if err != nil {
return nil, err
}
}

if credentialsProvider != nil {
// add a before connect callback to trigger the token retrieval from the credentials provider
err := parsedURI.Apply(mysql.BeforeConnect(func(ctx context.Context, config *mysql.Config) error {
config.User, config.Passwd, err = credentialsProvider.Get(ctx, config.Addr, config.User)
return err
}))
if err != nil {
return nil, err
}
}

connector, err := mysql.NewConnector(parsedURI)
if err != nil {
return nil, common.RedactAndLogSensitiveConnString(ctx, "NewMySQLDatastore: failed to create connector", err, uri)
}
Expand Down
19 changes: 17 additions & 2 deletions internal/datastore/mysql/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"strings"

"github.com/authzed/spicedb/pkg/datastore"

"github.com/authzed/spicedb/internal/datastore/common"

sq "github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -35,16 +37,29 @@ type MySQLDriver struct {
//
// URI: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2...
// See https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html
func NewMySQLDriverFromDSN(url string, tablePrefix string) (*MySQLDriver, error) {
func NewMySQLDriverFromDSN(url string, tablePrefix string, credentialsProvider datastore.CredentialsProvider) (*MySQLDriver, error) {
dbConfig, err := sqlDriver.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}

db, err := sql.Open("mysql", dbConfig.FormatDSN())
if credentialsProvider != nil {
// add a before connect callback to trigger the token retrieval from the credentials provider
err := dbConfig.Apply(sqlDriver.BeforeConnect(func(ctx context.Context, config *sqlDriver.Config) error {
config.User, config.Passwd, err = credentialsProvider.Get(ctx, config.Addr, config.User)
return err
}))
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}
}

connector, err := sqlDriver.NewConnector(dbConfig)
if err != nil {
return nil, fmt.Errorf(errUnableToInstantiate, err)
}

db := sql.OpenDB(connector)
err = sqlDriver.SetLogger(&log.Logger)
if err != nil {
return nil, fmt.Errorf("unable to set logging to mysql driver: %w", err)
Expand Down
11 changes: 11 additions & 0 deletions internal/datastore/mysql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
defaultEnablePrometheusStats = false
defaultMaxRetries = 8
defaultGCEnabled = true
defaultCredentialsProviderName = ""
)

type mysqlOptions struct {
Expand All @@ -40,6 +41,7 @@ type mysqlOptions struct {
maxRetries uint8
lockWaitTimeoutSeconds *uint8
gcEnabled bool
credentialsProviderName string
}

// Option provides the facility to configure how clients within the
Expand All @@ -61,6 +63,7 @@ func generateConfig(options []Option) (mysqlOptions, error) {
enablePrometheusStats: defaultEnablePrometheusStats,
maxRetries: defaultMaxRetries,
gcEnabled: defaultGCEnabled,
credentialsProviderName: defaultCredentialsProviderName,
}

for _, option := range options {
Expand Down Expand Up @@ -236,3 +239,11 @@ func GCMaxOperationTime(time time.Duration) Option {
mo.gcMaxOperationTime = time
}
}

// CredentialsProviderName the name of the CredentialsProvider implementation to use
// for dynamically retrieving the datastore credentials at runtime
//
// Disabled by default.
func CredentialsProviderName(credentialsProviderName string) Option {
return func(mo *mysqlOptions) { mo.credentialsProviderName = credentialsProviderName }
}
2 changes: 1 addition & 1 deletion internal/datastore/postgres/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvid

if credentialsProvider != nil {
log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider")
connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, connConfig.Host, connConfig.Port, connConfig.User)
connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", connConfig.Host, connConfig.Port), connConfig.User)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions internal/datastore/postgres/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ func MigrationPhase(phase string) Option {
return func(po *postgresOptions) { po.migrationPhase = phase }
}

// CredentialsProviderName the name of the CredentialsProvider implementation to use
// for dynamically retrieving the datastore credentials at runtime
//
// Disabled by default.
func CredentialsProviderName(credentialsProviderName string) Option {
return func(po *postgresOptions) { po.credentialsProviderName = credentialsProviderName }
}
2 changes: 1 addition & 1 deletion internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func newPostgresDatastore(
if credentialsProvider != nil {
// add before connect callbacks to trigger the token
getToken := func(ctx context.Context, config *pgx.ConnConfig) error {
config.User, config.Password, err = credentialsProvider.Get(ctx, config.Host, config.Port, config.User)
config.User, config.Password, err = credentialsProvider.Get(ctx, fmt.Sprintf("%s:%d", config.Host, config.Port), config.User)
return err
}
readPoolConfig.BeforeConnect = getToken
Expand Down
2 changes: 1 addition & 1 deletion internal/testserver/datastore/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (mb *mysqlTester) NewDatabase(t testing.TB) string {
}

func (mb *mysqlTester) runMigrate(t testing.TB, dsn string) {
driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix)
driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix, datastore.NoCredentialsProvider)
require.NoError(t, err, "failed to create migration driver: %s", err)
err = migrations.Manager.Run(context.Background(), driver, migrate.Head, migrate.LiveRun)
require.NoError(t, err, "failed to run migration: %s", err)
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ func newMySQLDatastore(ctx context.Context, opts Config) (datastore.Datastore, e
mysql.WithEnablePrometheusStats(opts.EnableDatastoreMetrics),
mysql.MaxRetries(uint8(opts.MaxRetries)),
mysql.OverrideLockWaitTimeout(1),
mysql.CredentialsProviderName(opts.CredentialsProviderName),
}
return mysql.NewMySQLDatastore(ctx, opts.URI, mysqlOpts...)
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,17 @@ func migrateRun(cmd *cobra.Command, args []string) error {
log.Ctx(cmd.Context()).Fatal().Msg(fmt.Sprintf("unable to get table prefix: %s", err))
}

migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix)
var credentialsProvider datastore.CredentialsProvider
credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name")
if credentialsProviderName != "" {
var err error
credentialsProvider, err = datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName)
if err != nil {
return err
}
}

migrationDriver, err := mysqlmigrations.NewMySQLDriverFromDSN(dbURL, tablePrefix, credentialsProvider)
if err != nil {
return fmt.Errorf("unable to create migration driver for %s: %w", datastoreEngine, err)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/datastore/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type CredentialsProvider interface {
// Name returns the name of the provider
Name() string
// Get returns the username and password to use when connecting to the underlying datastore
Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error)
Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error)
}

var NoCredentialsProvider CredentialsProvider = nil
Expand Down Expand Up @@ -74,8 +74,7 @@ func (d awsIamCredentialsProvider) Name() string {
return AWSIAMCredentialProvider
}

func (d awsIamCredentialsProvider) Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) {
dbEndpoint := fmt.Sprintf("%s:%d", dbHostname, dbPort)
func (d awsIamCredentialsProvider) Get(ctx context.Context, dbEndpoint string, dbUser string) (string, string, error) {
authToken, err := rdsauth.BuildAuthToken(ctx, dbEndpoint, d.awsSdkConfig.Region, dbUser, d.awsSdkConfig.Credentials)
if err != nil {
log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB")
Expand Down
2 changes: 1 addition & 1 deletion pkg/datastore/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestAWSIAMCredentialsProvider(t *testing.T) {
require.NotNil(t, credentialsProvider)
require.NoError(t, err)

username, password, err := credentialsProvider.Get(context.Background(), "some-hostname", 5432, "some-user")
username, password, err := credentialsProvider.Get(context.Background(), "some-hostname:5432", "some-user")
require.NoError(t, err)
require.Equal(t, "some-user", username)
require.Containsf(t, password, "X-Amz-Algorithm", "signed token should contain algorithm attribute")
Expand Down

0 comments on commit 77d32da

Please sign in to comment.