diff --git a/file.go b/file.go index 3fe688f..a2845c9 100644 --- a/file.go +++ b/file.go @@ -81,14 +81,42 @@ func extractFiles(zipFilePath, filesToExtractFromZipPath, localPath string) erro // Add the path from which we will extract files to the path prefix so we can exclude the appropriate files pathPrefix = filepath.Join(pathPrefix, filesToExtractFromZipPath) - pathPrefix = pathPrefix + "/" // Iterate through the files in the archive, // printing some of their contents. for _, f := range r.File { - // If the given file is in the filesToExtractFromZipPath, proceed - if strings.Index(f.Name, pathPrefix) == 0 { + // + // Skip the current archive item being processed based on + // rules described in comments below. + // + if f.FileInfo().IsDir() { + // The current archive item is a directory. + // Archive item's f.Name will always be appended with a "/", so we use + // that fact to ensure we are working with a full directory name. Skip + // to next item if (pathPrefix + "/") is not a prefix in f.Name + if strings.Index(f.Name, pathPrefix + "/") != 0 { + continue + } + } else { + // The current archive item is a file. + // There are three things possible here: + // 1 User specified a filename using --source-path option and if we hit + // that file in archive, we need to process this file, so do not skip. + // 2 We do additional checks if we did not hit the exact file specified by user: + // 2a User specified a directory (or wants to download full repo); + // the (pathPrefix + "/") is a prefix in f.Name, we want to process this + // file, so not skipping. + // 2b User specified either file or directory. + // (pathPrefix + "/") is not a prefix in the current item's f.Name, we + // have hit a file that is not within the folder that the user specified, + // so we skip it. + if f.Name != pathPrefix { + if strings.Index(f.Name, pathPrefix + "/") != 0 { + continue + } + } + } if f.FileInfo().IsDir() { // Create a directory @@ -115,7 +143,6 @@ func extractFiles(zipFilePath, filesToExtractFromZipPath, localPath string) erro return fmt.Errorf("Failed to write file: %s", err) } } - } } return nil diff --git a/file_test.go b/file_test.go index 775e356..2a87af1 100644 --- a/file_test.go +++ b/file_test.go @@ -340,6 +340,35 @@ func TestExtractFiles(t *testing.T) { } } +func TestExtractFilesExtractFile(t *testing.T) { + // Create a temp directory + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + defer os.RemoveAll(tempDir) + + zipFilePath := "test-fixtures/fetch-test-public-0.0.4.zip" + filePathToExtract := "zzz.txt" + localFileName := "/localzzz.txt" + localPathName := filepath.Join(tempDir, localFileName) + err = extractFiles(zipFilePath, filePathToExtract, localPathName) + if err != nil { + t.Fatalf("Failed to extract files: %s", err) + } + + filepath.Walk(tempDir, func(path string, info os.FileInfo, err error) error { + relativeFilename := strings.TrimPrefix(path, tempDir) + + if ! info.IsDir() { + if relativeFilename != localFileName { + t.Fatalf("Expected local file %s to be created, but not found.\n", localFileName) + } + } + return nil + }) +} + // Return ture if the given slice contains the given string func stringInSlice(s string, slice []string) bool { for _, val := range slice {