diff --git a/cmd/cli/handler.go b/cmd/cli/handler.go index ce650479292..da2d3746452 100644 --- a/cmd/cli/handler.go +++ b/cmd/cli/handler.go @@ -16,7 +16,7 @@ type Handler struct { func NewHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *Handler { return &Handler{ - Migration: newMigrateHandler(), + Migration: newMigrateHandler(slOpts, dOpts, cOpts), Janitor: NewJanitorHandler(slOpts, dOpts, cOpts), } } diff --git a/cmd/cli/handler_migrate.go b/cmd/cli/handler_migrate.go index 0172584028f..a4c2dd4885d 100644 --- a/cmd/cli/handler_migrate.go +++ b/cmd/cli/handler_migrate.go @@ -34,10 +34,18 @@ import ( "github.com/ory/x/flagx" ) -type MigrateHandler struct{} +type MigrateHandler struct { + slOpts []servicelocatorx.Option + dOpts []driver.OptionsModifier + cOpts []configx.OptionModifier +} -func newMigrateHandler() *MigrateHandler { - return &MigrateHandler{} +func newMigrateHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *MigrateHandler { + return &MigrateHandler{ + slOpts: slOpts, + dOpts: dOpts, + cOpts: cOpts, + } } const ( @@ -262,21 +270,21 @@ func (h *MigrateHandler) MigrateGen(cmd *cobra.Command, args []string) { os.Exit(0) } -func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) { +func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) { var d driver.Registry if flagx.MustGetBool(cmd, "read-from-env") { d, err = driver.New( cmd.Context(), servicelocatorx.NewOptions(), - []driver.OptionsModifier{ + append([]driver.OptionsModifier{ driver.WithOptions( configx.SkipValidation(), configx.WithFlags(cmd.Flags())), driver.DisableValidation(), driver.DisablePreloading(), driver.SkipNetworkInit(), - }) + }, h.dOpts...)) if err != nil { return nil, err } @@ -292,7 +300,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, d, err = driver.New( cmd.Context(), servicelocatorx.NewOptions(), - []driver.OptionsModifier{ + append([]driver.OptionsModifier{ driver.WithOptions( configx.WithFlags(cmd.Flags()), configx.SkipValidation(), @@ -301,7 +309,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, driver.DisableValidation(), driver.DisablePreloading(), driver.SkipNetworkInit(), - }) + }, h.dOpts...)) if err != nil { return nil, err } @@ -310,7 +318,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, } func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err error) { - p, err := makePersister(cmd, args) + p, err := h.makePersister(cmd, args) if err != nil { return err } @@ -360,7 +368,7 @@ func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err erro } func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error { - p, err := makePersister(cmd, args) + p, err := h.makePersister(cmd, args) if err != nil { return err } diff --git a/cmd/migrate_status.go b/cmd/migrate_status.go index 397f5e86e48..940728b79b6 100644 --- a/cmd/migrate_status.go +++ b/cmd/migrate_status.go @@ -4,6 +4,7 @@ package cmd import ( + "github.com/ory/x/cmdx" "github.com/ory/x/configx" "github.com/ory/x/servicelocatorx" @@ -20,6 +21,7 @@ func NewMigrateStatusCmd(slOpts []servicelocatorx.Option, dOpts []driver.Options RunE: cli.NewHandler(slOpts, dOpts, cOpts).Migration.MigrateStatus, } + cmdx.RegisterFormatFlags(cmd.PersistentFlags()) cmd.Flags().BoolP("read-from-env", "e", false, "If set, reads the database connection string from the environment variable DSN or config file key dsn.") cmd.Flags().Bool("block", false, "Block until all migrations have been applied") diff --git a/driver/factory.go b/driver/factory.go index 2e5fe949c29..4b206ac71c7 100644 --- a/driver/factory.go +++ b/driver/factory.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/configx" @@ -22,7 +23,10 @@ type ( // The first default refers to determining the NID at startup; the second default referes to the fact that the Contextualizer may dynamically change the NID. skipNetworkInit bool tracerWrapper TracerWrapper + extraMigrations []fs.FS } + OptionsModifier func(*options) + TracerWrapper func(*otelx.Tracer) *otelx.Tracer ) @@ -34,14 +38,12 @@ func newOptions() *options { } } -func WithConfig(config *config.DefaultProvider) func(o *options) { +func WithConfig(config *config.DefaultProvider) OptionsModifier { return func(o *options) { o.config = config } } -type OptionsModifier func(*options) - func WithOptions(opts ...configx.OptionModifier) OptionsModifier { return func(o *options) { o.opts = append(o.opts, opts...) @@ -77,6 +79,13 @@ func WithTracerWrapper(wrapper TracerWrapper) OptionsModifier { } } +// WithExtraMigrations specifies additional database migration. +func WithExtraMigrations(m ...fs.FS) OptionsModifier { + return func(o *options) { + o.extraMigrations = append(o.extraMigrations, m...) + } +} + func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifier) (Registry, error) { o := newOptions() for _, f := range opts { @@ -115,7 +124,7 @@ func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifie r.WithTracerWrapper(o.tracerWrapper) } - if err = r.Init(ctx, o.skipNetworkInit, false, ctxter); err != nil { + if err = r.Init(ctx, o.skipNetworkInit, false, ctxter, o.extraMigrations); err != nil { l.WithError(err).Error("Unable to initialize service registry.") return nil, err } diff --git a/driver/registry.go b/driver/registry.go index c75213e52c7..4c956c4cd48 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "net/http" "go.opentelemetry.io/otel/trace" @@ -44,7 +45,7 @@ import ( type Registry interface { dbal.Driver - Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer) error + Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, extraMigrations []fs.FS) error WithBuildInfo(v, h, d string) Registry WithConfig(c *config.DefaultProvider) Registry @@ -89,7 +90,7 @@ func NewRegistryFromDSN(ctx context.Context, c *config.DefaultProvider, l *logru if err != nil { return nil, err } - if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer); err != nil { + if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer, nil); err != nil { return nil, err } return registry, nil diff --git a/driver/registry_base_test.go b/driver/registry_base_test.go index 4e0f80ef859..4dedab5dead 100644 --- a/driver/registry_base_test.go +++ b/driver/registry_base_test.go @@ -67,7 +67,7 @@ func TestRegistryBase_newKeyStrategy_handlesNetworkError(t *testing.T) { r := registry.(*RegistrySQL) r.initialPing = failedPing(errors.New("snizzles")) - _ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{}) + _ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{}, nil) registryBase := RegistryBase{r: r, l: l} registryBase.WithConfig(c) diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 7660a884c90..361d1aa154d 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "strings" "time" @@ -64,7 +65,11 @@ func NewRegistrySQL() *RegistrySQL { } func (m *RegistrySQL) Init( - ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, + ctx context.Context, + skipNetworkInit bool, + migrate bool, + ctxer contextx.Contextualizer, + extraMigrations []fs.FS, ) error { if m.persister == nil { m.WithContextualizer(ctxer) @@ -100,7 +105,7 @@ func (m *RegistrySQL) Init( return errorsx.WithStack(err) } - p, err := sql.NewPersister(ctx, c, m, m.Config(), m.l) + p, err := sql.NewPersister(ctx, c, m, m.Config(), extraMigrations) if err != nil { return err } diff --git a/driver/registry_sql_test.go b/driver/registry_sql_test.go index ffbc071f1b5..cd126a3f711 100644 --- a/driver/registry_sql_test.go +++ b/driver/registry_sql_test.go @@ -31,7 +31,7 @@ func TestDefaultKeyManager_HsmDisabled(t *testing.T) { reg, err := NewRegistryWithoutInit(c, l) r := reg.(*RegistrySQL) r.initialPing = sussessfulPing() - if err := r.Init(context.Background(), true, false, &contextx.Default{}); err != nil { + if err := r.Init(context.Background(), true, false, &contextx.Default{}, nil); err != nil { t.Fatalf("unable to init registry: %s", err) } assert.NoError(t, err) diff --git a/go.mod b/go.mod index baa5d57cd2b..5e6d3fb1f98 100644 --- a/go.mod +++ b/go.mod @@ -40,10 +40,10 @@ require ( github.com/ory/fosite v0.44.1-0.20230704083823-8098e48b2e09 github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe github.com/ory/graceful v0.1.3 - github.com/ory/herodot v0.10.2 + github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 github.com/ory/hydra-client-go/v2 v2.1.1 github.com/ory/jsonschema/v3 v3.0.8 - github.com/ory/x v0.0.567 + github.com/ory/x v0.0.574 github.com/pborman/uuid v1.2.1 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 diff --git a/go.sum b/go.sum index 8dc45a304d3..e8f6c2baca8 100644 --- a/go.sum +++ b/go.sum @@ -549,6 +549,7 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/laher/mergefs v0.1.1 h1:nV2bTS57vrmbMxeR6uvJpI8LyGl3QHj4bLBZO3aUV58= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -653,12 +654,12 @@ github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTs github.com/ory/go-convenience v0.1.0/go.mod h1:uEY/a60PL5c12nYz4V5cHY03IBmwIAEm8TWB0yn9KNs= github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc= github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA= -github.com/ory/herodot v0.10.2 h1:gGvNMHgAwWzdP/eo+roSiT5CGssygHSjDU7MSQNlJ4E= -github.com/ory/herodot v0.10.2/go.mod h1:MMNmY6MG1uB6fnXYFaHoqdV23DTWctlPsmRCeq/2+wc= +github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 h1:J0CIFKdpUeqKbVMw7pQ1qLtUnflRM1JWAcOEq7Hp4yg= +github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88/go.mod h1:MMNmY6MG1uB6fnXYFaHoqdV23DTWctlPsmRCeq/2+wc= github.com/ory/jsonschema/v3 v3.0.8 h1:Ssdb3eJ4lDZ/+XnGkvQS/te0p+EkolqwTsDOCxr/FmU= github.com/ory/jsonschema/v3 v3.0.8/go.mod h1:ZPzqjDkwd3QTnb2Z6PAS+OTvBE2x5i6m25wCGx54W/0= -github.com/ory/x v0.0.567 h1:oUj75hIqBv3ESsmIwc/4u8jaD2zSx/HTGzRnfMKUykg= -github.com/ory/x v0.0.567/go.mod h1:g0QdN0Z47vdCYtfrTQkgWJdIOPuez9VGiuQivLxa4lo= +github.com/ory/x v0.0.574 h1:JjdOP6iIh4ngoR1zDxaZL9bsBzIAyvw0aZdqSfJOEVI= +github.com/ory/x v0.0.574/go.mod h1:aeJFTlvDLGYSABzPS3z5SeLcYC52Ek7uGZiuYGcTMSU= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= diff --git a/hsm/manager_hsm_test.go b/hsm/manager_hsm_test.go index cf629bd62c9..e7bc145180e 100644 --- a/hsm/manager_hsm_test.go +++ b/hsm/manager_hsm_test.go @@ -52,7 +52,7 @@ func TestDefaultKeyManager_HSMEnabled(t *testing.T) { reg.WithLogger(l) reg.WithConfig(c) reg.WithHsmContext(mockHsmContext) - err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}) + err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}, nil) assert.NoError(t, err) assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager()) assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index e454f527fac..a9000dd314c 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -6,11 +6,11 @@ package sql import ( "context" "database/sql" + "io/fs" "reflect" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/pkg/errors" "github.com/ory/fosite" @@ -21,6 +21,7 @@ import ( "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/x/errorsx" + "github.com/ory/x/fsx" "github.com/ory/x/logrusx" "github.com/ory/x/networkx" "github.com/ory/x/otelx" @@ -104,8 +105,8 @@ func (p *Persister) Rollback(ctx context.Context) (err error) { return errorsx.WithStack(tx.TX.Rollback()) } -func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, l *logrusx.Logger) (*Persister, error) { - mb, err := popx.NewMigrationBox(migrations, popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) +func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, extraMigrations []fs.FS) (*Persister, error) { + mb, err := popx.NewMigrationBox(fsx.Merge(append([]fs.FS{migrations}, extraMigrations...)...), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) if err != nil { return nil, errorsx.WithStack(err) } @@ -115,7 +116,7 @@ func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config mb: mb, r: r, config: config, - l: l, + l: r.Logger(), p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)), }, nil }