diff --git a/cmd/cmd.go b/cmd/cmd.go index a40609f..fb1eb75 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -2,12 +2,12 @@ package cmd import ( stdcontext "context" - "errors" "fmt" "io" "log" "os" "path" + "strings" "github.com/spf13/cobra" @@ -281,26 +281,48 @@ func getProjectContext( // getRepositoryContext makes sure that we're in a repository context, this is useful to add extra commands, which are only useful when in a repository with a shuttle file func getRepositoryContext(projectPath string) bool { - - var fullProjectPath string - if path.IsAbs(projectPath) { - fullProjectPath = projectPath + if projectPath != "" && projectPath != "." { + return shuttleFileExists(projectPath, fileExists) } else { dir, err := os.Getwd() if err != nil { log.Fatal(err) } - fullProjectPath = path.Join(dir, projectPath) + fullProjectPath := path.Join(dir, projectPath) + exists := shuttleFileExistsRecursive(fullProjectPath, fileExists) + return exists } +} - shuttleFile := path.Join(fullProjectPath, "shuttle.yaml") +type fileExistsFunc func(filePath string) bool - if _, err := os.Stat(shuttleFile); err != nil { - if errors.Is(err, os.ErrNotExist) { - return false +// shuttleFileExistsRecursive tries to find a shuttle file by going towards the root the path, it will check each folder towards the root. +func shuttleFileExistsRecursive(projectPath string, existsFunc fileExistsFunc) bool { + if strings.Contains(projectPath, "/") { + exists := shuttleFileExists(projectPath, existsFunc) + if exists { + return true } + + return shuttleFileExistsRecursive(path.Dir(projectPath), existsFunc) + } + return shuttleFileExists(projectPath, existsFunc) + +} + +// shuttleFileExists will check the given directory and return if a shuttle.yaml file is found +func shuttleFileExists(projectPath string, existsFunc fileExistsFunc) bool { + shuttleFile := path.Join(projectPath, "shuttle.yaml") + return existsFunc(shuttleFile) +} + +func fileExists(filePath string) bool { + _, err := os.Stat(filePath) + if err != nil { + return false + } return true } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go new file mode 100644 index 0000000..f2ce2b4 --- /dev/null +++ b/cmd/cmd_test.go @@ -0,0 +1,155 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShuttleFileExists(t *testing.T) { + t.Parallel() + + t.Run("full path, with file", func(t *testing.T) { + actual := shuttleFileExists("/some/long/path", func(filePath string) bool { + switch filePath { + case "/some/long/path/shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) + + t.Run("full path, no file", func(t *testing.T) { + actual := shuttleFileExists("/some/long/path", func(filePath string) bool { + switch filePath { + case "/some/long/path/shuttle.yaml": + return false + default: + pathNotExpected(t, filePath) + return true + } + }) + + assert.False(t, actual) + }) + + t.Run("current path, with file", func(t *testing.T) { + actual := shuttleFileExists(".", func(filePath string) bool { + switch filePath { + case "shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) + + t.Run("current path, no file", func(t *testing.T) { + actual := shuttleFileExists(".", func(filePath string) bool { + switch filePath { + case "shuttle.yaml": + return false + default: + pathNotExpected(t, filePath) + return true + } + }) + + assert.False(t, actual) + }) +} + +func TestShuttleFileExistsRecursive(t *testing.T) { + t.Parallel() + + t.Run("full path, file in given path", func(t *testing.T) { + actual := shuttleFileExistsRecursive("/some/long/path", func(filePath string) bool { + switch filePath { + case "/some/long/path/shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) + + t.Run("full path, file in sub directory", func(t *testing.T) { + actual := shuttleFileExistsRecursive("/some/long/path", func(filePath string) bool { + switch filePath { + case "/some/long/path/shuttle.yaml": + return false + case "/some/long/shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) + + t.Run("full path, file in root", func(t *testing.T) { + actual := shuttleFileExistsRecursive("/some/long/path", func(filePath string) bool { + switch filePath { + case "/some/long/path/shuttle.yaml": + return false + case "/some/long/shuttle.yaml": + return false + case "/some/shuttle.yaml": + return false + case "/shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) + + t.Run("empty path, file false", func(t *testing.T) { + actual := shuttleFileExistsRecursive("", func(filePath string) bool { + switch filePath { + case "shuttle.yaml": + return false + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.False(t, actual) + }) + + t.Run("current dir, file found", func(t *testing.T) { + actual := shuttleFileExistsRecursive(".", func(filePath string) bool { + switch filePath { + case "shuttle.yaml": + return true + default: + pathNotExpected(t, filePath) + return false + } + }) + + assert.True(t, actual) + }) +} + +func pathNotExpected(t *testing.T, filePath string) { + t.Helper() + + assert.Fail(t, "path was not expected", "the path %s was not expected in matcher", filePath) +}