diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index ff4c3c46d..de812156a 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -3,40 +3,40 @@ package main import "github.com/spf13/pflag" const ( - // configuration defaults support local development (i.e. "go run ...") - defaultDatabaseDSN = "" - defaultDatabaseDriver = "postgres" - defaultDatabaseAddress = "0.0.0.0:5432" - defaultDatabaseName = "" - defaultDatabaseUser = "postgres" - defaultDatabasePassword = "postgres" - defaultDatabaseSSL = "disable" - defaultConfigDirectory = "/cli/config" + // configuration defaults support local development (i.e. "go run ...") + defaultDatabaseDSN = "" + defaultDatabaseDriver = "postgres" + defaultDatabaseAddress = "0.0.0.0:5432" + defaultDatabaseName = "" + defaultDatabaseUser = "postgres" + defaultDatabasePassword = "postgres" + defaultDatabaseSSL = "disable" + defaultConfigDirectory = "/cli/config" ) var ( - // define flag overrides - flagHelp = pflag.Bool("help", false, "Print usage") - flagVersion = pflag.String("version", Version, "Print version") - flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") - flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") - flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") + // define flag overrides + flagHelp = pflag.Bool("help", false, "Print usage") + flagVersion = pflag.String("version", Version, "Print version") + flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") + flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") + flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") - flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") - flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") - flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") - flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") - flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") - flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") - flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") + flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") + flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") + flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") + flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") + flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") + flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") + flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") - flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") - flagPath = pflag.String("path", "", "Shorthand for -source=file://path") + flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") + flagPath = pflag.String("path", "", "Shorthand for -source=file://path") - flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") - flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") + flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") - // goto command flags - flagDirty = pflag.Bool("dirty", false, "migration is dirty") - flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") + // goto command flags + flagDirty = pflag.Bool("dirty", false, "migration is dirty") + flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") ) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 868938f8e..7adec2f84 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -1,248 +1,248 @@ package cli import ( - "errors" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/golang-migrate/migrate/v4" - _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again - _ "github.com/golang-migrate/migrate/v4/source/file" + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( - errInvalidSequenceWidth = errors.New("Digits must be positive") - errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") - errInvalidTimeFormat = errors.New("Time format may not be empty") + errInvalidSequenceWidth = errors.New("Digits must be positive") + errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") + errInvalidTimeFormat = errors.New("Time format may not be empty") ) func nextSeqVersion(matches []string, seqDigits int) (string, error) { - if seqDigits <= 0 { - return "", errInvalidSequenceWidth - } + if seqDigits <= 0 { + return "", errInvalidSequenceWidth + } - nextSeq := uint64(1) + nextSeq := uint64(1) - if len(matches) > 0 { - filename := matches[len(matches)-1] - matchSeqStr := filepath.Base(filename) - idx := strings.Index(matchSeqStr, "_") + if len(matches) > 0 { + filename := matches[len(matches)-1] + matchSeqStr := filepath.Base(filename) + idx := strings.Index(matchSeqStr, "_") - if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit - return "", fmt.Errorf("Malformed migration filename: %s", filename) - } + if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit + return "", fmt.Errorf("Malformed migration filename: %s", filename) + } - var err error - matchSeqStr = matchSeqStr[0:idx] - nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) + var err error + matchSeqStr = matchSeqStr[0:idx] + nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) - if err != nil { - return "", err - } + if err != nil { + return "", err + } - nextSeq++ - } + nextSeq++ + } - version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) + version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) - if len(version) > seqDigits { - return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) - } + if len(version) > seqDigits { + return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) + } - return version, nil + return version, nil } func timeVersion(startTime time.Time, format string) (version string, err error) { - switch format { - case "": - err = errInvalidTimeFormat - case "unix": - version = strconv.FormatInt(startTime.Unix(), 10) - case "unixNano": - version = strconv.FormatInt(startTime.UnixNano(), 10) - default: - version = startTime.Format(format) - } - - return + switch format { + case "": + err = errInvalidTimeFormat + case "unix": + version = strconv.FormatInt(startTime.Unix(), 10) + case "unixNano": + version = strconv.FormatInt(startTime.UnixNano(), 10) + default: + version = startTime.Format(format) + } + + return } // createCmd (meant to be called via a CLI command) creates a new migration func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error { - if seq && format != defaultTimeFormat { - return errIncompatibleSeqAndFormat - } + if seq && format != defaultTimeFormat { + return errIncompatibleSeqAndFormat + } - var version string - var err error + var version string + var err error - dir = filepath.Clean(dir) - ext = "." + strings.TrimPrefix(ext, ".") + dir = filepath.Clean(dir) + ext = "." + strings.TrimPrefix(ext, ".") - if seq { - matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) + if seq { + matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) - if err != nil { - return err - } + if err != nil { + return err + } - version, err = nextSeqVersion(matches, seqDigits) + version, err = nextSeqVersion(matches, seqDigits) - if err != nil { - return err - } - } else { - version, err = timeVersion(startTime, format) + if err != nil { + return err + } + } else { + version, err = timeVersion(startTime, format) - if err != nil { - return err - } - } + if err != nil { + return err + } + } - versionGlob := filepath.Join(dir, version+"_*"+ext) - matches, err := filepath.Glob(versionGlob) + versionGlob := filepath.Join(dir, version+"_*"+ext) + matches, err := filepath.Glob(versionGlob) - if err != nil { - return err - } + if err != nil { + return err + } - if len(matches) > 0 { - return fmt.Errorf("duplicate migration version: %s", version) - } + if len(matches) > 0 { + return fmt.Errorf("duplicate migration version: %s", version) + } - if err = os.MkdirAll(dir, os.ModePerm); err != nil { - return err - } + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } - for _, direction := range []string{"up", "down"} { - basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) - filename := filepath.Join(dir, basename) + for _, direction := range []string{"up", "down"} { + basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) + filename := filepath.Join(dir, basename) - if err = createFile(filename); err != nil { - return err - } + if err = createFile(filename); err != nil { + return err + } - if print { - absPath, _ := filepath.Abs(filename) - log.Println(absPath) - } - } + if print { + absPath, _ := filepath.Abs(filename) + log.Println(absPath) + } + } - return nil + return nil } func createFile(filename string) error { - // create exclusive (fails if file already exists) - // os.Create() specifies 0666 as the FileMode, so we're doing the same - f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) + // create exclusive (fails if file already exists) + // os.Create() specifies 0666 as the FileMode, so we're doing the same + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) - if err != nil { - return err - } + if err != nil { + return err + } - return f.Close() + return f.Close() } func gotoCmd(m *migrate.Migrate, v uint) error { - if err := m.Migrate(v); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - return nil + if err := m.Migrate(v); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + return nil } func upCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Up(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Up(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func downCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(-limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Down(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(-limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Down(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func dropCmd(m *migrate.Migrate) error { - if err := m.Drop(); err != nil { - return err - } - return nil + if err := m.Drop(); err != nil { + return err + } + return nil } func forceCmd(m *migrate.Migrate, v int) error { - if err := m.Force(v); err != nil { - return err - } - return nil + if err := m.Force(v); err != nil { + return err + } + return nil } func versionCmd(m *migrate.Migrate) error { - v, dirty, err := m.Version() - if err != nil { - return err - } - if dirty { - log.Printf("%v (dirty)\n", v) - } else { - log.Println(v) - } - return nil + v, dirty, err := m.Version() + if err != nil { + return err + } + if dirty { + log.Printf("%v (dirty)\n", v) + } else { + log.Println(v) + } + return nil } // numDownMigrationsFromArgs returns an int for number of migrations to apply // and a bool indicating if we need a confirm before applying func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { - if applyAll { - if len(args) > 0 { - return 0, false, errors.New("-all cannot be used with other arguments") - } - return -1, false, nil - } - - switch len(args) { - case 0: - return -1, true, nil - case 1: - downValue := args[0] - n, err := strconv.ParseUint(downValue, 10, 64) - if err != nil { - return 0, false, errors.New("can't read limit argument N") - } - return int(n), false, nil - default: - return 0, false, errors.New("too many arguments") - } + if applyAll { + if len(args) > 0 { + return 0, false, errors.New("-all cannot be used with other arguments") + } + return -1, false, nil + } + + switch len(args) { + case 0: + return -1, true, nil + case 1: + downValue := args[0] + n, err := strconv.ParseUint(downValue, 10, 64) + if err != nil { + return 0, false, errors.New("can't read limit argument N") + } + return int(n), false, nil + default: + return 0, false, errors.New("too many arguments") + } } diff --git a/internal/cli/main.go b/internal/cli/main.go index 74459da29..5158de116 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -1,103 +1,96 @@ package cli import ( - "database/sql" - "fmt" - "net/url" - "os" - "os/signal" - "strconv" - "strings" - "syscall" - "time" - - flag "github.com/spf13/pflag" - "github.com/spf13/viper" - - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database" - "github.com/golang-migrate/migrate/v4/database/postgres" - "github.com/golang-migrate/migrate/v4/source" + "database/sql" + "fmt" + "net/url" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source" ) const ( - defaultTimeFormat = "20060102150405" - defaultTimezone = "UTC" - createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME + defaultTimeFormat = "20060102150405" + defaultTimezone = "UTC" + createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME Create a set of timestamped up/down migrations titled NAME, in directory D with extension E. Use -seq option to generate sequential up/down migrations with N digits. Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V [-dirty] Migrate to version V` - upUsage = `up [N] Apply all or N up migrations` - downUsage = `down [N] [-all] Apply all or N down migrations + gotoUsage = `goto V [-dirty] [-intermediate-path] Migrate to version V` + upUsage = `up [N] Apply all or N up migrations` + downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` - dropUsage = `drop [-f] Drop everything inside database + dropUsage = `drop [-f] Drop everything inside database Use -f to bypass confirmation` - forceUsage = `force V Set version V but don't run migration (ignores dirty state)` + forceUsage = `force V Set version V but don't run migration (ignores dirty state)` ) func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) { - if help { - fmt.Fprintln(os.Stderr, usage) - flagSet.PrintDefaults() - os.Exit(0) - } + if help { + fmt.Fprintln(os.Stderr, usage) + flagSet.PrintDefaults() + os.Exit(0) + } } func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { - flagSet := flag.NewFlagSet(name, flag.ExitOnError) - helpPtr := flagSet.Bool("help", false, "Print help information") - return flagSet, helpPtr -} - -func newGoToFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { - flagSet := flag.NewFlagSet(name, flag.ExitOnError) - flagSet.Bool("dirty", false, "Migration in dirty state") - helpPtr := flagSet.Bool("help", false, "Print help information") - return flagSet, helpPtr + flagSet := flag.NewFlagSet(name, flag.ExitOnError) + helpPtr := flagSet.Bool("help", false, "Print help information") + return flagSet, helpPtr } // set main log var log = &Log{} func printUsageAndExit() { - flag.Usage() + flag.Usage() - // If a command is not found we exit with a status 2 to match the behavior - // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. - os.Exit(2) + // If a command is not found we exit with a status 2 to match the behavior + // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. + os.Exit(2) } func dbMakeConnectionString(driver, user, password, address, name, ssl string) string { - return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", - driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, - ) + return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", + driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, + ) } // Main function of a cli application. It is public for backwards compatibility with `cli` package func Main(version string) { - help := viper.GetBool("help") - version = viper.GetString("version") - verbose := viper.GetBool("verbose") - prefetch := viper.GetInt("prefetch") - lockTimeout := viper.GetInt("lock-timeout") - path := viper.GetString("path") - sourcePtr := viper.GetString("source") - - databasePtr := viper.GetString("database.dsn") - if databasePtr == "" { - databasePtr = dbMakeConnectionString( - viper.GetString("database.driver"), viper.GetString("database.user"), - viper.GetString("database.password"), viper.GetString("database.address"), - viper.GetString("database.name"), viper.GetString("database.ssl"), - ) - } - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, - `Usage: migrate OPTIONS COMMAND [arg...] + help := viper.GetBool("help") + version = viper.GetString("version") + verbose := viper.GetBool("verbose") + prefetch := viper.GetInt("prefetch") + lockTimeout := viper.GetInt("lock-timeout") + path := viper.GetString("path") + sourcePtr := viper.GetString("source") + + databasePtr := viper.GetString("database.dsn") + if databasePtr == "" { + databasePtr = dbMakeConnectionString( + viper.GetString("database.driver"), viper.GetString("database.user"), + viper.GetString("database.password"), viper.GetString("database.address"), + viper.GetString("database.name"), viper.GetString("database.ssl"), + ) + } + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, + `Usage: migrate OPTIONS COMMAND [arg...] migrate [ -version | -help ] Options: @@ -132,323 +125,322 @@ Commands: Source drivers: `+strings.Join(source.List(), ", ")+` Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage) - } - - // initialize logger - log.verbose = verbose - - // show cli version - if version == "" { - fmt.Fprintln(os.Stderr, version) - os.Exit(0) - } - - // show help - if help { - flag.Usage() - os.Exit(0) - } - - // translate -path into -source if given - if sourcePtr == "" && path != "" { - sourcePtr = fmt.Sprintf("file://%v", path) - } - - // initialize migrate - // don't catch migraterErr here and let each command decide - // how it wants to handle the error - var migrater *migrate.Migrate - var migraterErr error - - if driver := viper.GetString("database.driver"); driver == "hotload" { - db, err := sql.Open(driver, databasePtr) - if err != nil { - log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) - } - var dbname, user string - if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { - log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) - } - // dbname is not needed since it gets filled in by the driver but we want to be complete - migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) - if err != nil { - log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) - } - migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) - } else { - migrater, migraterErr = migrate.New(sourcePtr, databasePtr) - } - defer func() { - if migraterErr == nil { - if _, err := migrater.Close(); err != nil { - log.Println(err) - } - } - }() - if migraterErr == nil { - migrater.Log = log - migrater.PrefetchMigrations = uint(prefetch) - migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second - - // handle Ctrl+c - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT) - go func() { - for range signals { - log.Println("Stopping after this running migration ...") - migrater.GracefulStop <- true - return - } - }() - } - - startTime := time.Now() - - if len(flag.Args()) < 1 { - printUsageAndExit() - } - args := flag.Args()[1:] - - switch flag.Arg(0) { - case "create": - - seq := false - seqDigits := 6 - - createFlagSet, help := newFlagSetWithHelp("create") - extPtr := createFlagSet.String("ext", "", "File extension") - dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") - formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) - timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) - createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") - createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") - - if err := createFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, createUsage, createFlagSet) - - if createFlagSet.NArg() == 0 { - log.fatal("error: please specify name") - } - name := createFlagSet.Arg(0) - - if *extPtr == "" { - log.fatal("error: -ext flag must be specified") - } - - timezone, err := time.LoadLocation(*timezoneName) - if err != nil { - log.fatal(err) - } - - if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { - log.fatalErr(err) - } - - case "goto": - - gotoSet, helpPtr := newFlagSetWithHelp("goto") - - if err := gotoSet.Parse(args); err != nil { - log.fatalErr(err) - } - handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if gotoSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - handleDirty := viper.GetBool("dirty") - destPath := viper.GetString("intermediate-path") - srcPath := "" - // if sourcePtr is set, use it to get the source path - // otherwise, use the path flag - if path != "" { - srcPath = path - } - if sourcePtr != "" { - // parse the source path from the source argument - parse, err := url.Parse(sourcePtr) - if err != nil { - log.fatal("error: can't parse the source path from the source argument") - } - srcPath = parse.Path - } - - if handleDirty && destPath == "" { - log.fatal("error: intermediate-path must be specified when dirty is set") - } - log.Printf("running goto with handleDirty: %t, destPath: %s, srcPath: %s\n", handleDirty, destPath, srcPath) - migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) - if err = gotoCmd(migrater, uint(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "up": - upSet, helpPtr := newFlagSetWithHelp("up") - - if err := upSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, upUsage, upSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - limit := -1 - if upSet.NArg() > 0 { - n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read limit argument N") - } - limit = int(n) - } - - if err := upCmd(migrater, limit); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "down": - downFlagSet, helpPtr := newFlagSetWithHelp("down") - applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") - if err := downFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - downArgs := downFlagSet.Args() - - log.Println(*applyAll, downArgs) - - num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) - if err != nil { - log.fatalErr(err) - } - if needsConfirm { - log.Println("Are you sure you want to apply all down migrations? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Applying all down migrations") - } else { - log.fatal("Not applying all down migrations") - } - } - - if err := downCmd(migrater, num); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "drop": - dropFlagSet, help := newFlagSetWithHelp("drop") - forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") - - if err := dropFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, dropUsage, dropFlagSet) - - if !*forceDrop { - log.Println("Are you sure you want to drop the entire database schema? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Dropping the entire database schema") - } else { - log.fatal("Aborted dropping the entire database schema") - } - } - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := dropCmd(migrater); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "force": - forceSet, helpPtr := newFlagSetWithHelp("force") - - if err := forceSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, forceUsage, forceSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if forceSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - - if v < -1 { - log.fatal("error: argument V must be >= -1") - } - - if err := forceCmd(migrater, int(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "version": - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := versionCmd(migrater); err != nil { - log.fatalErr(err) - } - - default: - printUsageAndExit() - } + } + + // initialize logger + log.verbose = verbose + + // show cli version + if version == "" { + fmt.Fprintln(os.Stderr, version) + os.Exit(0) + } + + // show help + if help { + flag.Usage() + os.Exit(0) + } + + // translate -path into -source if given + if sourcePtr == "" && path != "" { + sourcePtr = fmt.Sprintf("file://%v", path) + } + + // initialize migrate + // don't catch migraterErr here and let each command decide + // how it wants to handle the error + var migrater *migrate.Migrate + var migraterErr error + + if driver := viper.GetString("database.driver"); driver == "hotload" { + db, err := sql.Open(driver, databasePtr) + if err != nil { + log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) + } + var dbname, user string + if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { + log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) + } + // dbname is not needed since it gets filled in by the driver but we want to be complete + migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) + if err != nil { + log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) + } + migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) + } else { + migrater, migraterErr = migrate.New(sourcePtr, databasePtr) + } + defer func() { + if migraterErr == nil { + if _, err := migrater.Close(); err != nil { + log.Println(err) + } + } + }() + if migraterErr == nil { + migrater.Log = log + migrater.PrefetchMigrations = uint(prefetch) + migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second + + // handle Ctrl+c + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT) + go func() { + for range signals { + log.Println("Stopping after this running migration ...") + migrater.GracefulStop <- true + return + } + }() + } + + startTime := time.Now() + + if len(flag.Args()) < 1 { + printUsageAndExit() + } + args := flag.Args()[1:] + + switch flag.Arg(0) { + case "create": + + seq := false + seqDigits := 6 + + createFlagSet, help := newFlagSetWithHelp("create") + extPtr := createFlagSet.String("ext", "", "File extension") + dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") + formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) + timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) + createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") + createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") + + if err := createFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, createUsage, createFlagSet) + + if createFlagSet.NArg() == 0 { + log.fatal("error: please specify name") + } + name := createFlagSet.Arg(0) + + if *extPtr == "" { + log.fatal("error: -ext flag must be specified") + } + + timezone, err := time.LoadLocation(*timezoneName) + if err != nil { + log.fatal(err) + } + + if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { + log.fatalErr(err) + } + + case "goto": + + gotoSet, helpPtr := newFlagSetWithHelp("goto") + + if err := gotoSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if gotoSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + handleDirty := viper.GetBool("dirty") + destPath := viper.GetString("intermediate-path") + srcPath := "" + // if sourcePtr is set, use it to get the source path + // otherwise, use the path flag + if path != "" { + srcPath = path + } + if sourcePtr != "" { + // parse the source path from the source argument + parse, err := url.Parse(sourcePtr) + if err != nil { + log.fatal("error: can't parse the source path from the source argument") + } + srcPath = parse.Path + } + + if handleDirty && destPath == "" { + log.fatal("error: intermediate-path must be specified when dirty is set") + } + + migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) + if err = gotoCmd(migrater, uint(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "up": + upSet, helpPtr := newFlagSetWithHelp("up") + + if err := upSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, upUsage, upSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + limit := -1 + if upSet.NArg() > 0 { + n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read limit argument N") + } + limit = int(n) + } + + if err := upCmd(migrater, limit); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "down": + downFlagSet, helpPtr := newFlagSetWithHelp("down") + applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") + + if err := downFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + downArgs := downFlagSet.Args() + num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) + if err != nil { + log.fatalErr(err) + } + if needsConfirm { + log.Println("Are you sure you want to apply all down migrations? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Applying all down migrations") + } else { + log.fatal("Not applying all down migrations") + } + } + + if err := downCmd(migrater, num); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "drop": + dropFlagSet, help := newFlagSetWithHelp("drop") + forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") + + if err := dropFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, dropUsage, dropFlagSet) + + if !*forceDrop { + log.Println("Are you sure you want to drop the entire database schema? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Dropping the entire database schema") + } else { + log.fatal("Aborted dropping the entire database schema") + } + } + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := dropCmd(migrater); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "force": + forceSet, helpPtr := newFlagSetWithHelp("force") + + if err := forceSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, forceUsage, forceSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if forceSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + + if v < -1 { + log.fatal("error: argument V must be >= -1") + } + + if err := forceCmd(migrater, int(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "version": + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := versionCmd(migrater); err != nil { + log.fatalErr(err) + } + + default: + printUsageAndExit() + } } diff --git a/migrate.go b/migrate.go index e7631c843..0793924c1 100644 --- a/migrate.go +++ b/migrate.go @@ -5,17 +5,21 @@ package migrate import ( - "errors" - "fmt" - "os" - "sync" - "time" - - "github.com/hashicorp/go-multierror" - - "github.com/golang-migrate/migrate/v4/database" - iurl "github.com/golang-migrate/migrate/v4/internal/url" - "github.com/golang-migrate/migrate/v4/source" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + + "github.com/golang-migrate/migrate/v4/database" + iurl "github.com/golang-migrate/migrate/v4/internal/url" + "github.com/golang-migrate/migrate/v4/source" ) // DefaultPrefetchMigrations sets the number of migrations to pre-read @@ -29,107 +33,110 @@ var DefaultPrefetchMigrations = uint(10) var DefaultLockTimeout = 15 * time.Second var ( - ErrNoChange = errors.New("no change") - ErrNilVersion = errors.New("no migration") - ErrInvalidVersion = errors.New("version must be >= -1") - ErrLocked = errors.New("database locked") - ErrLockTimeout = errors.New("timeout: can't acquire database lock") + ErrNoChange = errors.New("no change") + ErrNilVersion = errors.New("no migration") + ErrInvalidVersion = errors.New("version must be >= -1") + ErrLocked = errors.New("database locked") + ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +// Define a constant for the migration file name +const lastSuccessfulMigrationFile = "lastSuccessfulMigration" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { - Short uint + Short uint } // Error implements the error interface. func (e ErrShortLimit) Error() string { - return fmt.Sprintf("limit %v short", e.Short) + return fmt.Sprintf("limit %v short", e.Short) } type ErrDirty struct { - Version int + Version int } func (e ErrDirty) Error() string { - return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) + return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } type Migrate struct { - sourceName string - sourceDrv source.Driver - databaseName string - databaseDrv database.Driver + sourceName string + sourceDrv source.Driver + databaseName string + databaseDrv database.Driver - // Log accepts a Logger interface - Log Logger + // Log accepts a Logger interface + Log Logger - // GracefulStop accepts `true` and will stop executing migrations - // as soon as possible at a safe break point, so that the database - // is not corrupted. - GracefulStop chan bool - isLockedMu *sync.Mutex + // GracefulStop accepts `true` and will stop executing migrations + // as soon as possible at a safe break point, so that the database + // is not corrupted. + GracefulStop chan bool + isLockedMu *sync.Mutex - isGracefulStop bool - isLocked bool + isGracefulStop bool + isLocked bool - // PrefetchMigrations defaults to DefaultPrefetchMigrations, - // but can be set per Migrate instance. - PrefetchMigrations uint + // PrefetchMigrations defaults to DefaultPrefetchMigrations, + // but can be set per Migrate instance. + PrefetchMigrations uint - // LockTimeout defaults to DefaultLockTimeout, - // but can be set per Migrate instance. - LockTimeout time.Duration + // LockTimeout defaults to DefaultLockTimeout, + // but can be set per Migrate instance. + LockTimeout time.Duration - // DirtyStateHandler is used to handle dirty state of the database - ds *dirtyStateHandler + // DirtyStateHandler is used to handle dirty state of the database + ds *dirtyStateHandler } type dirtyStateHandler struct { - srcPath string - destPath string - isDirty bool + srcPath string + destPath string + isDirty bool } // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. func New(sourceURL, databaseURL string) (*Migrate, error) { - m := newCommon() - - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) - } - m.sourceName = sourceName - - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName - - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv - - return m, nil + m := newCommon() + + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName + + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName + + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv + + return m, nil } func (m *Migrate) updateSourceDrv(sourceURL string) error { - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - return nil + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + return nil } // NewWithDatabaseInstance returns a new Migrate instance from a source URL @@ -137,25 +144,25 @@ func (m *Migrate) updateSourceDrv(sourceURL string) error { // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, err - } - m.sourceName = sourceName + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, err + } + m.sourceName = sourceName - m.databaseName = databaseName + m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv - m.databaseDrv = databaseInstance + m.databaseDrv = databaseInstance - return m, nil + return m, nil } // NewWithSourceInstance returns a new Migrate instance from an existing source instance @@ -163,25 +170,25 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() + m := newCommon() - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName - m.sourceName = sourceName + m.sourceName = sourceName - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv - m.sourceDrv = sourceInstance + m.sourceDrv = sourceInstance - return m, nil + return m, nil } // NewWithInstance returns a new Migrate instance from an existing source and @@ -189,191 +196,194 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - m.sourceName = sourceName - m.databaseName = databaseName + m.sourceName = sourceName + m.databaseName = databaseName - m.sourceDrv = sourceInstance - m.databaseDrv = databaseInstance + m.sourceDrv = sourceInstance + m.databaseDrv = databaseInstance - return m, nil + return m, nil } func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) { - m.ds = &dirtyStateHandler{ - srcPath: srcPath, - destPath: destPath, - isDirty: isDirty, - } + m.ds = &dirtyStateHandler{ + srcPath: srcPath, + destPath: destPath, + isDirty: isDirty, + } } func newCommon() *Migrate { - return &Migrate{ - GracefulStop: make(chan bool, 1), - PrefetchMigrations: DefaultPrefetchMigrations, - LockTimeout: DefaultLockTimeout, - isLockedMu: &sync.Mutex{}, - } + return &Migrate{ + GracefulStop: make(chan bool, 1), + PrefetchMigrations: DefaultPrefetchMigrations, + LockTimeout: DefaultLockTimeout, + isLockedMu: &sync.Mutex{}, + } } // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { - databaseSrvClose := make(chan error) - sourceSrvClose := make(chan error) + databaseSrvClose := make(chan error) + sourceSrvClose := make(chan error) - m.logVerbosePrintf("Closing source and database\n") + m.logVerbosePrintf("Closing source and database\n") - go func() { - databaseSrvClose <- m.databaseDrv.Close() - }() + go func() { + databaseSrvClose <- m.databaseDrv.Close() + }() - go func() { - sourceSrvClose <- m.sourceDrv.Close() - }() + go func() { + sourceSrvClose <- m.sourceDrv.Close() + }() - return <-sourceSrvClose, <-databaseSrvClose + return <-sourceSrvClose, <-databaseSrvClose } // Migrate looks at the currently active migration version, // then migrates either up or down to the specified version. func (m *Migrate) Migrate(version uint) error { - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - m.Log.Printf("******************Failed to get current version: %v\n", err) - return err - } - - if err = m.CopyFiles(); err != nil { - return err - } - - m.Log.Printf("Current version: %d, dirty: %t\n", curVersion, dirty) - // if the dirty flag is passed to the 'goto' command, handle the dirty state - if dirty { - if m.ds.isDirty { - m.Log.Printf("Version: %d, handle dirty: %t\n", version, m.ds.isDirty) - if err = m.HandleDirtyState(); err != nil { - return err - } - if err = m.updateSourceDrv(fmt.Sprintf("file://%s", m.ds.destPath)); err != nil { - return err - } - - } else { - // default behaviour - m.Log.Printf("Database is set to dirty for version: %v\n", curVersion) - return ErrDirty{curVersion} - } - } - - if err = m.lock(); err != nil { - return err - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(version), ret) - - err = m.runMigrations(ret) - if err != nil { - if m.ds.isDirty { - // Handle failure: store last successful migration version and exit - if err = m.HandleMigrationFailure(curVersion, version); err != nil { - return err - } - } - return m.unlockErr(err) - } - // Success: Clean up and confirm - if err = m.CleanupFiles(version); err != nil { - return m.unlockErr(err) - } - return nil + if err := m.lock(); err != nil { + return err + } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + // if the dirty flag is passed to the 'goto' command, handle the dirty state + if dirty { + if m.ds != nil && m.ds.isDirty { + if err = m.unlock(); err != nil { + return m.unlockErr(err) + } + if err = m.HandleDirtyState(); err != nil { + return m.unlockErr(err) + } + if err = m.lock(); err != nil { + return err + } + if err = m.updateSourceDrv("file://" + m.ds.destPath); err != nil { + return m.unlockErr(err) + } + + } else { + // default behaviour + return m.unlockErr(ErrDirty{curVersion}) + } + } + + // Copy migrations to the destination directory, + // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations + if err = m.CopyFiles(); err != nil { + return m.unlockErr(err) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(version), ret) + + if err = m.runMigrations(ret); err != nil { + if m.ds != nil && m.ds.isDirty { + // Handle failure: store last successful migration version and exit + if err = m.HandleMigrationFailure(curVersion, version); err != nil { + return m.unlockErr(err) + } + } + return m.unlockErr(err) + } + // Success: Clean up and confirm + if err = m.CleanupFiles(version); err != nil { + return m.unlockErr(err) + } + // unlock the database + return m.unlock() } // Steps looks at the currently active migration version. // It will migrate up if n > 0, and down if n < 0. func (m *Migrate) Steps(n int) error { - if n == 0 { - return ErrNoChange - } + if n == 0 { + return ErrNoChange + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - if n > 0 { - go m.readUp(curVersion, n, ret) - } else { - go m.readDown(curVersion, -n, ret) - } + if n > 0 { + go m.readUp(curVersion, n, ret) + } else { + go m.readDown(curVersion, -n, ret) + } - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(m.runMigrations(ret)) } // Up looks at the currently active migration version // and will migrate all the way up (applying all up migrations). func (m *Migrate) Up() error { - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - go m.readUp(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + go m.readUp(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Down looks at the currently active migration version // and will migrate all the way down (applying all down migrations). func (m *Migrate) Down() error { - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.readDown(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.readDown(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Drop deletes everything in the database. func (m *Migrate) Drop() error { - if err := m.lock(); err != nil { - return err - } - if err := m.databaseDrv.Drop(); err != nil { - return m.unlockErr(err) - } - return m.unlock() + if err := m.lock(); err != nil { + return err + } + if err := m.databaseDrv.Drop(); err != nil { + return m.unlockErr(err) + } + return m.unlock() } // Run runs any migration provided by you against the database. @@ -381,78 +391,78 @@ func (m *Migrate) Drop() error { // Usually you don't need this function at all. Use Migrate, // Steps, Up or Down instead. func (m *Migrate) Run(migration ...*Migration) error { - if len(migration) == 0 { - return ErrNoChange - } - - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - - go func() { - defer close(ret) - for _, migr := range migration { - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - ret <- migr - go func(migr *Migration) { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }(migr) - } - }() - - return m.unlockErr(m.runMigrations(ret)) + if len(migration) == 0 { + return ErrNoChange + } + + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + + go func() { + defer close(ret) + for _, migr := range migration { + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + ret <- migr + go func(migr *Migration) { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }(migr) + } + }() + + return m.unlockErr(m.runMigrations(ret)) } // Force sets a migration version. // It does not check any currently active version in database. // It resets the dirty state to false. func (m *Migrate) Force(version int) error { - if version < -1 { - return ErrInvalidVersion - } + if version < -1 { + return ErrInvalidVersion + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - if err := m.databaseDrv.SetVersion(version, false); err != nil { - return m.unlockErr(err) - } + if err := m.databaseDrv.SetVersion(version, false); err != nil { + return m.unlockErr(err) + } - return m.unlock() + return m.unlock() } // Version returns the currently active migration version. // If no migration has been applied, yet, it will return ErrNilVersion. func (m *Migrate) Version() (version uint, dirty bool, err error) { - v, d, err := m.databaseDrv.Version() - if err != nil { - return 0, false, err - } + v, d, err := m.databaseDrv.Version() + if err != nil { + return 0, false, err + } - if v == database.NilVersion { - return 0, false, ErrNilVersion - } + if v == database.NilVersion { + return 0, false, ErrNilVersion + } - return suint(v), d, nil + return suint(v), d, nil } // read reads either up or down migrations from source `from` to `to`. @@ -460,130 +470,130 @@ func (m *Migrate) Version() (version uint, dirty bool, err error) { // If an error occurs during reading, that error is written to the ret channel, too. // Once read is done reading it will close the ret channel. func (m *Migrate) read(from int, to int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - // check if to version exists - if to >= 0 { - if err := m.versionExists(suint(to)); err != nil { - ret <- err - return - } - } - - // no change? - if from == to { - ret <- ErrNoChange - return - } - - if from < to { - // it's going up - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(firstVersion) - } - - // run until we reach target ... - for from < to { - if m.stop() { - return - } - - next, err := m.sourceDrv.Next(suint(from)) - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(next) - } - - } else { - // it's going down - // run until we reach target ... - for from > to && from >= 0 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) && to == -1 { - // apply nil migration - migr, err := m.newMigration(suint(from), -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - return - - } else if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(prev) - } - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + // check if to version exists + if to >= 0 { + if err := m.versionExists(suint(to)); err != nil { + ret <- err + return + } + } + + // no change? + if from == to { + ret <- ErrNoChange + return + } + + if from < to { + // it's going up + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(firstVersion) + } + + // run until we reach target ... + for from < to { + if m.stop() { + return + } + + next, err := m.sourceDrv.Next(suint(from)) + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(next) + } + + } else { + // it's going down + // run until we reach target ... + for from > to && from >= 0 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) && to == -1 { + // apply nil migration + migr, err := m.newMigration(suint(from), -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + return + + } else if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(prev) + } + } } // readUp reads up migrations from `from` limitted by `limit`. @@ -592,98 +602,98 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readUp is done reading it will close the ret channel. func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(firstVersion) - count++ - continue - } - - // apply next migration - next, err := m.sourceDrv.Next(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit, but no migrations applied? - if limit == -1 && count == 0 { - ret <- ErrNoChange - return - } - - // no limit, reached end - if limit == -1 { - return - } - - // reached end, and didn't apply any migrations - if limit > 0 && count == 0 { - ret <- os.ErrNotExist - return - } - - // applied less migrations than limit? - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - return - } - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(next) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(firstVersion) + count++ + continue + } + + // apply next migration + next, err := m.sourceDrv.Next(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit, but no migrations applied? + if limit == -1 && count == 0 { + ret <- ErrNoChange + return + } + + // no limit, reached end + if limit == -1 { + return + } + + // reached end, and didn't apply any migrations + if limit > 0 && count == 0 { + ret <- os.ErrNotExist + return + } + + // applied less migrations than limit? + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + return + } + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(next) + count++ + } } // readDown reads down migrations from `from` limitted by `limit`. @@ -692,88 +702,88 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readDown is done reading it will close the ret channel. func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - // no change if already at nil version - if from == -1 && limit == -1 { - ret <- ErrNoChange - return - } - - // can't go over limit if already at nil version - if from == -1 && limit > 0 { - ret <- os.ErrNotExist - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit or haven't reached limit, apply "first" migration - if limit == -1 || limit-count > 0 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - count++ - } - - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - } - return - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(prev) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + // no change if already at nil version + if from == -1 && limit == -1 { + ret <- ErrNoChange + return + } + + // can't go over limit if already at nil version + if from == -1 && limit > 0 { + ret <- os.ErrNotExist + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit or haven't reached limit, apply "first" migration + if limit == -1 || limit-count > 0 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + count++ + } + + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + } + return + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(prev) + count++ + } } // runMigrations reads *Migration and error from a channel. Any other type @@ -783,260 +793,418 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { - m.Log.Printf("Starting %s migrations\n", m.sourceDrv) - for r := range ret { - - if m.stop() { - return nil - } - - switch r := r.(type) { - case error: - return r - - case *Migration: - migr := r - - // set version with dirty state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { - return err - } - - if migr.Body != nil { - m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) - if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { - return err - } - } - - // set clean state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { - return err - } - - endTime := time.Now() - readTime := migr.FinishedReading.Sub(migr.StartedBuffering) - runTime := endTime.Sub(migr.FinishedReading) - - // log either verbose or normal - if m.Log != nil { - if m.Log.Verbose() { - m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) - } else { - m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) - } - } - - default: - return fmt.Errorf("unknown type: %T with value: %+v", r, r) - } - } - return nil + for r := range ret { + + if m.stop() { + return nil + } + + switch r := r.(type) { + case error: + return r + + case *Migration: + migr := r + + // set version with dirty state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { + return err + } + + if migr.Body != nil { + m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) + if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + return err + } + } + + // set clean state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { + return err + } + + endTime := time.Now() + readTime := migr.FinishedReading.Sub(migr.StartedBuffering) + runTime := endTime.Sub(migr.FinishedReading) + + // log either verbose or normal + if m.Log != nil { + if m.Log.Verbose() { + m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) + } else { + m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) + } + } + + default: + return fmt.Errorf("unknown type: %T with value: %+v", r, r) + } + } + return nil } // versionExists checks the source if either the up or down migration for // the specified migration version exists. func (m *Migrate) versionExists(version uint) (result error) { - // try up migration first - up, _, err := m.sourceDrv.ReadUp(version) - if err == nil { - defer func() { - if errClose := up.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - // then try down migration - down, _, err := m.sourceDrv.ReadDown(version) - if err == nil { - defer func() { - if errClose := down.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - err = fmt.Errorf("no migration found for version %d: %w", version, err) - m.logErr(err) - return err + // try up migration first + up, _, err := m.sourceDrv.ReadUp(version) + if err == nil { + defer func() { + if errClose := up.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + // then try down migration + down, _, err := m.sourceDrv.ReadDown(version) + if err == nil { + defer func() { + if errClose := down.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + err = fmt.Errorf("no migration found for version %d: %w", version, err) + m.logErr(err) + return err } // stop returns true if no more migrations should be run against the database // because a stop signal was received on the GracefulStop channel. // Calls are cheap and this function is not blocking. func (m *Migrate) stop() bool { - if m.isGracefulStop { - return true - } - - select { - case <-m.GracefulStop: - m.isGracefulStop = true - return true - - default: - return false - } + if m.isGracefulStop { + return true + } + + select { + case <-m.GracefulStop: + m.isGracefulStop = true + return true + + default: + return false + } } // newMigration is a helper func that returns a *Migration for the // specified version and targetVersion. func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) { - var migr *Migration - - if targetVersion >= int(version) { - r, identifier, err := m.sourceDrv.ReadUp(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from up source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - - } else { - r, identifier, err := m.sourceDrv.ReadDown(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from down source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - } - - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - return migr, nil + var migr *Migration + + if targetVersion >= int(version) { + r, identifier, err := m.sourceDrv.ReadUp(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from up source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + + } else { + r, identifier, err := m.sourceDrv.ReadDown(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from down source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + } + + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + return migr, nil } // lock is a thread safe helper function to lock the database. // It should be called as late as possible when running migrations. func (m *Migrate) lock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() - - if m.isLocked { - return ErrLocked - } - - // create done channel, used in the timeout goroutine - done := make(chan bool, 1) - defer func() { - done <- true - }() - - // use errchan to signal error back to this context - errchan := make(chan error, 2) - - // start timeout goroutine - timeout := time.After(m.LockTimeout) - go func() { - for { - select { - case <-done: - return - case <-timeout: - errchan <- ErrLockTimeout - return - } - } - }() - - // now try to acquire the lock - go func() { - if err := m.databaseDrv.Lock(); err != nil { - errchan <- err - } else { - errchan <- nil - } - }() - - // wait until we either receive ErrLockTimeout or error from Lock operation - err := <-errchan - if err == nil { - m.isLocked = true - } - return err + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() + + if m.isLocked { + return ErrLocked + } + + // create done channel, used in the timeout goroutine + done := make(chan bool, 1) + defer func() { + done <- true + }() + + // use errchan to signal error back to this context + errchan := make(chan error, 2) + + // start timeout goroutine + timeout := time.After(m.LockTimeout) + go func() { + for { + select { + case <-done: + return + case <-timeout: + errchan <- ErrLockTimeout + return + } + } + }() + + // now try to acquire the lock + go func() { + if err := m.databaseDrv.Lock(); err != nil { + errchan <- err + } else { + errchan <- nil + } + }() + + // wait until we either receive ErrLockTimeout or error from Lock operation + err := <-errchan + if err == nil { + m.isLocked = true + } + return err } // unlock is a thread safe helper function to unlock the database. // It should be called as early as possible when no more migrations are // expected to be executed. func (m *Migrate) unlock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() - if err := m.databaseDrv.Unlock(); err != nil { - // BUG: Can potentially create a deadlock. Add a timeout. - return err - } + if err := m.databaseDrv.Unlock(); err != nil { + // BUG: Can potentially create a deadlock. Add a timeout. + return err + } - m.isLocked = false - return nil + m.isLocked = false + return nil } // unlockErr calls unlock and returns a combined error // if a prevErr is not nil. func (m *Migrate) unlockErr(prevErr error) error { - if err := m.unlock(); err != nil { - return multierror.Append(prevErr, err) - } - return prevErr + if err := m.unlock(); err != nil { + return multierror.Append(prevErr, err) + } + return prevErr } // logPrintf writes to m.Log if not nil func (m *Migrate) logPrintf(format string, v ...interface{}) { - if m.Log != nil { - m.Log.Printf(format, v...) - } + if m.Log != nil { + m.Log.Printf(format, v...) + } } // logVerbosePrintf writes to m.Log if not nil. Use for verbose logging output. func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) { - if m.Log != nil && m.Log.Verbose() { - m.Log.Printf(format, v...) - } + if m.Log != nil && m.Log.Verbose() { + m.Log.Printf(format, v...) + } } // logErr writes error to m.Log if not nil func (m *Migrate) logErr(err error) { - if m.Log != nil { - m.Log.Printf("error: %v", err) - } + if m.Log != nil { + m.Log.Printf("error: %v", err) + } +} + +func (m *Migrate) HandleDirtyState() error { + // Perform actions when the database state is dirty + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + return err + } + lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) + lastVersion, err := strconv.ParseInt(lastVersionStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse last successful migration version: %w", err) + } + + if err = m.Force(int(lastVersion)); err != nil { + return fmt.Errorf("failed to apply last successful migration: %w", err) + } + + m.logPrintf("Successfully applied migration: %s", lastVersionStr) + + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + return err + } + + m.logPrintf("Successfully deleted file: %s", lastSuccessfulMigrationPath) + return nil +} + +func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { + failedVersion, _, err := m.databaseDrv.Version() + if err != nil { + return err + } + // Determine the last successful migration + lastSuccessfulMigration := strconv.Itoa(curVersion) + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(v), ret) + + for r := range ret { + mig, ok := r.(*Migration) + if ok { + if mig.Version == uint(failedVersion) { + break + } + lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) + } + } + + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + return os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644) +} + +func (m *Migrate) CleanupFiles(v uint) error { + if m.ds == nil || m.ds.destPath == "" { + return nil + } + files, err := os.ReadDir(m.ds.destPath) + if err != nil { + return err + } + + targetVersion := uint64(v) + + for _, file := range files { + fileName := file.Name() + + // Check if file is a migration file we want to process + if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { + continue + } + + // Extract version and compare + versionEnd := strings.Index(fileName, "_") + if versionEnd == -1 { + // Skip files that don't match the expected naming pattern + continue + } + + fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) + if err != nil { + m.logErr(fmt.Errorf("skipping file %s due to version parse error: %v", fileName, err)) + continue + } + + // Delete file if version is greater than targetVersion + if fileVersion > targetVersion { + if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { + m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err)) + continue + } + m.logPrintf("Deleted file: %s", fileName) + } + } + + return nil +} + +// CopyFiles copies all files from srcDir to destDir. +func (m *Migrate) CopyFiles() error { + if m.ds == nil || m.ds.destPath == "" { + return nil + } + _, err := os.ReadDir(m.ds.destPath) + if err != nil { + // If the directory does not exist + return err + } + + m.logPrintf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) + + return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // ignore sub-directories in the migration directory + if info.IsDir() { + // Skip the tests directory and its files + if info.Name() == "tests" { + return filepath.SkipDir + } + return nil + } + // Ignore the current.sql file + if info.Name() == "current.sql" { + return nil + } + + var ( + srcFile *os.File + destFile *os.File + ) + dest := filepath.Join(m.ds.destPath, info.Name()) + if srcFile, err = os.Open(src); err != nil { + return err + } + defer func(srcFile *os.File) { + if err = srcFile.Close(); err != nil { + m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + } + }(srcFile) + + // Create the destination file + if destFile, err = os.Create(dest); err != nil { + return err + } + defer func(destFile *os.File) { + if err = destFile.Close(); err != nil { + m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + } + }(destFile) + + // Copy the file + if _, err = io.Copy(destFile, srcFile); err != nil { + return err + } + return os.Chmod(dest, info.Mode()) + }) } diff --git a/migrate_goto_temp.go b/migrate_goto_temp.go deleted file mode 100644 index 517d63adb..000000000 --- a/migrate_goto_temp.go +++ /dev/null @@ -1,176 +0,0 @@ -package migrate - -import ( - "io" - "os" - "path/filepath" - "strconv" - "strings" - - "github.com/pkg/errors" -) - -// Define a constant for the migration file name -const lastSuccessfulMigrationFile = "lastSuccessfulMigration" - -func (m *Migrate) HandleDirtyState() error { - // Perform actions when the database state is dirty - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) - lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) - if err != nil { - return err - } - lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) - lastVersion, err := strconv.ParseUint(lastVersionStr, 10, 64) - if err != nil { - return errors.Wrap(err, "failed to parse last successful migration version") - } - - if err = m.Force(int(lastVersion)); err != nil { - return errors.Wrap(err, "failed to apply last successful migration") - } - - m.Log.Printf("Successfully applied migration: %s", lastVersionStr) - - if err = os.Remove(lastSuccessfulMigrationPath); err != nil { - return err - } - - m.Log.Printf("Successfully deleted file: %s", lastSuccessfulMigrationPath) - return nil -} - -func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { - failedVersion, _, err := m.databaseDrv.Version() - if err != nil { - return err - } - - // Determine the last successful migration - lastSuccessfulMigration := strconv.Itoa(curVersion) - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(v), ret) - - for r := range ret { - mig, ok := r.(*Migration) - if ok { - if mig.Version == uint(failedVersion) { - break - } - lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) - } - } - - m.Log.Printf("migration failed, last successful migration version: %s", lastSuccessfulMigration) - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) - if err = os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644); err != nil { - return err - } - - return nil -} - -func (m *Migrate) CleanupFiles(v uint) error { - if m.ds.destPath == "" { - return nil - } - files, err := os.ReadDir(m.ds.destPath) - if err != nil { - return err - } - - targetVersion := uint64(v) - - for _, file := range files { - fileName := file.Name() - - // Check if file is a migration file we want to process - if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { - continue - } - - // Extract version and compare - versionEnd := strings.Index(fileName, "_") - if versionEnd == -1 { - // Skip files that don't match the expected naming pattern - continue - } - - fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) - if err != nil { - m.Log.Printf("Skipping file %s due to version parse error: %v", fileName, err) - continue - } - - // Delete file if version is greater than targetVersion - if fileVersion > targetVersion { - if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { - m.Log.Printf("Failed to delete file %s: %v", fileName, err) - continue - } - m.Log.Printf("Deleted file: %s", fileName) - } - } - - return nil -} - -// CopyFiles copies all files from srcDir to destDir. -func (m *Migrate) CopyFiles() error { - if m.ds.destPath == "" { - return nil - } - _, err := os.ReadDir(m.ds.destPath) - if err != nil { - // If the directory does not exist - return err - } - - m.Log.Printf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) - - return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // ignore sub-directories in the migration directory - if info.IsDir() { - // Skip the tests directory and its files - if info.Name() == "tests" { - m.Log.Printf("Ignoring directory %s", info.Name()) - return filepath.SkipDir - } - return nil - } - // Ignore the current.sql file - if info.Name() == "current.sql" { - m.Log.Printf("Ignoring file %s", info.Name()) - return nil - } - - var ( - srcFile *os.File - destFile *os.File - ) - dest := filepath.Join(m.ds.destPath, info.Name()) - if srcFile, err = os.Open(src); err != nil { - return err - } - defer func(srcFile *os.File) { - if err = srcFile.Close(); err != nil { - m.Log.Printf("failed to close file %s: %s", srcFile.Name, err) - } - }(srcFile) - - // Create the destination file - if destFile, err = os.Create(dest); err != nil { - return err - } - - // Copy the file - if _, err = io.Copy(destFile, srcFile); err == nil { - return err - } - return os.Chmod(dest, info.Mode()) - }) -} diff --git a/migrate_test.go b/migrate_test.go index f2728179e..61c6b7d73 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -4,9 +4,12 @@ import ( "bytes" "database/sql" "errors" + "fmt" "io" "log" "os" + "path/filepath" + "strconv" "strings" "testing" @@ -1414,3 +1417,321 @@ func equalDbSeq(t *testing.T, i int, expected migrationSequence, got *dStub.Stub t.Fatalf("\nexpected sequence %v,\ngot %v, in %v", bs, got.MigrationSequence, i) } } + +// Setting up temp directory to be used as a PVC mount +func setupTempDir(t *testing.T) (string, func()) { + tempDir, err := os.MkdirTemp("", "migrate_test") + if err != nil { + t.Fatal(err) + } + return tempDir, func() { + if err = os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + } +} + +func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { + m, _ := New("stub://", "stub://") + m.ds = &dirtyStateHandler{ + destPath: tempDir, + isDirty: true, + } + return m, m.databaseDrv.(*dStub.Stub) +} + +func setupSourceStubMigrations() *source.Migrations { + migrations := source.NewMigrations() + migrations.Append(&source.Migration{Version: 1, Direction: source.Up, Identifier: "CREATE 1"}) + migrations.Append(&source.Migration{Version: 1, Direction: source.Down, Identifier: "DROP 1"}) + migrations.Append(&source.Migration{Version: 2, Direction: source.Up, Identifier: "CREATE 2"}) + migrations.Append(&source.Migration{Version: 2, Direction: source.Down, Identifier: "DROP 2"}) + migrations.Append(&source.Migration{Version: 3, Direction: source.Up, Identifier: "CREATE 3"}) + migrations.Append(&source.Migration{Version: 3, Direction: source.Down, Identifier: "DROP 3"}) + migrations.Append(&source.Migration{Version: 4, Direction: source.Up, Identifier: "CREATE 4"}) + migrations.Append(&source.Migration{Version: 4, Direction: source.Down, Identifier: "DROP 4"}) + migrations.Append(&source.Migration{Version: 5, Direction: source.Up, Identifier: "CREATE 5"}) + migrations.Append(&source.Migration{Version: 5, Direction: source.Down, Identifier: "DROP 5"}) + migrations.Append(&source.Migration{Version: 6, Direction: source.Up, Identifier: "CREATE 6"}) + migrations.Append(&source.Migration{Version: 6, Direction: source.Down, Identifier: "DROP 6"}) + migrations.Append(&source.Migration{Version: 7, Direction: source.Up, Identifier: "CREATE 7"}) + migrations.Append(&source.Migration{Version: 7, Direction: source.Down, Identifier: "DROP 7"}) + return migrations +} + +func TestHandleDirtyState(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + + tests := []struct { + lastSuccessful int + currentVersion int + err error + setupFailure bool + }{ + {lastSuccessful: 1, currentVersion: 2, err: nil, setupFailure: false}, + {lastSuccessful: 4, currentVersion: 5, err: nil, setupFailure: false}, + {lastSuccessful: 3, currentVersion: 4, err: nil, setupFailure: false}, + {lastSuccessful: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, + {lastSuccessful: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + var lastSuccessfulMigrationPath string + // setupFailure tests scenario where the 'lastSuccessfulMigrationFile' doesn't exist + if !test.setupFailure { + lastSuccessfulMigrationPath = filepath.Join(tempDir, lastSuccessfulMigrationFile) + if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessful)), 0644); err != nil { + t.Fatal(err) + } + } + // Setting the DB version as dirty + if err := dbDrv.SetVersion(test.currentVersion, true); err != nil { + t.Fatal(err) + } + + // Quick check to see if set correctly + version, b, err := dbDrv.Version() + if err != nil { + t.Fatal(err) + } + if version != test.currentVersion { + t.Fatalf("expected version %d, got %d", test.currentVersion, version) + } + + if !b { + t.Fatalf("expected false, got true") + } + + // Handle dirty state + if err = m.HandleDirtyState(); err != nil { + if strings.Contains(err.Error(), test.err.Error()) { + t.Logf("expected error %v, got %v", test.err, err) + if !test.setupFailure { + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + t.Fatal(err) + } + } + return + } else { + t.Fatal(err) + } + } + // Check 1: DB should no longer be dirty + if dbDrv.IsDirty { + t.Fatalf("expected dirty to be false, got true") + } + // Check 2: Current version should be the last successful version + if dbDrv.CurrentVersion != test.lastSuccessful { + t.Fatalf("expected version %d, got %d", test.lastSuccessful, dbDrv.CurrentVersion) + } + // Check 3: The lastSuccessfulMigration file shouldn't exists + if _, err = os.Stat(lastSuccessfulMigrationPath); !os.IsNotExist(err) { + t.Fatalf("expected file to be deleted, but it still exists") + } + }) + } +} + +func TestHandleMigrationFailure(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + + tests := []struct { + curVersion int + targetVersion uint + dirtyVersion int + }{ + {curVersion: 1, targetVersion: 7, dirtyVersion: 4}, + {curVersion: 4, targetVersion: 6, dirtyVersion: 5}, + {curVersion: 3, targetVersion: 7, dirtyVersion: 6}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + t.Cleanup(func() { + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + dbDrv = m.databaseDrv.(*dStub.Stub) + }) + + // Setup: Simulate a migration failure by setting the dirty version in the DB + if err := dbDrv.SetVersion(test.dirtyVersion, true); err != nil { + t.Fatal(err) + } + + // Test + if err := m.HandleMigrationFailure(test.curVersion, test.targetVersion); err != nil { + t.Fatal(err) + } + + // Check 1: Should no longer be dirty + if !dbDrv.IsDirty { + t.Fatalf("expected dirty to be true, got false") + } + + // Check 2: last successful Migration version should be stored in a file + lastSuccessfulMigrationPath := filepath.Join(tempDir, lastSuccessfulMigrationFile) + if _, err := os.Stat(lastSuccessfulMigrationPath); os.IsNotExist(err) { + t.Fatalf("expected file to be created, but it does not exist") + } + + // Check 3: Check if the content of last successful migration has the correct version + content, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + t.Fatal(err) + } + + if string(content) != strconv.Itoa(test.dirtyVersion-1) { + t.Fatalf("expected %d, got %s", test.dirtyVersion-1, string(content)) + } + }) + } +} + +func TestCleanupFiles(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, _ := setupMigrateInstance(tempDir) + + tests := []struct { + migrationFiles []string + targetVersion uint + remainingFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + targetVersion: 2, + remainingFiles: []string{"1_up.sql", "2_up.sql"}, + }, + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "5_up.sql"}, + targetVersion: 3, + remainingFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + }, + { + migrationFiles: []string{}, + targetVersion: 1, + remainingFiles: []string{}, + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(tempDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + + if test.emptyDestPath { + m.ds.destPath = "" + } + + if err := m.CleanupFiles(test.targetVersion); err != nil { + t.Fatal(err) + } + + // check 1: only files upto the target version should exist + for _, file := range test.remainingFiles { + if _, err := os.Stat(filepath.Join(tempDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist, but it does not", file) + } + } + + // check 2: the files removed are as expected + deletedFiles := diff(test.migrationFiles, test.remainingFiles) + for _, deletedFile := range deletedFiles { + if _, err := os.Stat(filepath.Join(tempDir, deletedFile)); !os.IsNotExist(err) { + t.Fatalf("expected file %s to be deleted, but it still exists", deletedFile) + } + } + }) + } +} + +func TestCopyFiles(t *testing.T) { + srcDir, cleanupSrc := setupTempDir(t) + defer cleanupSrc() + + destDir, cleanupDest := setupTempDir(t) + defer cleanupDest() + + m, _ := New("stub://", "stub://") + m.ds = &dirtyStateHandler{ + srcPath: srcDir, + destPath: destDir, + } + + tests := []struct { + migrationFiles []string + copiedFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + }, + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "current.sql"}, + copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql"}, + }, + { + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(srcDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + if test.emptyDestPath { + m.ds.destPath = "" + } + + if err := m.CopyFiles(); err != nil { + t.Fatal(err) + } + + for _, file := range test.copiedFiles { + if _, err := os.Stat(filepath.Join(destDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to be copied, but it does not exist", file) + } + } + }) + } +} + +/* + diff returns an array containing the elements in Array A and not in B +*/ + +func diff(a, b []string) []string { + temp := map[string]int{} + for _, s := range a { + temp[s]++ + } + for _, s := range b { + temp[s]-- + } + + var result []string + for s, v := range temp { + if v != 0 { + result = append(result, s) + } + } + return result +} diff --git a/test/main.go b/test/main.go deleted file mode 100644 index 56e540407..000000000 --- a/test/main.go +++ /dev/null @@ -1 +0,0 @@ -package test