diff --git a/pkg/mount/mount.go b/pkg/mount/mount.go index 7afa8f077..3acdec9f7 100644 --- a/pkg/mount/mount.go +++ b/pkg/mount/mount.go @@ -75,6 +75,9 @@ type Manager interface { RemoveMountPath(path string, opts map[string]string) error // EmptyTrashDir removes all directories from the mounter trash directory EmptyTrashDir() error + // SafeEmptyTrashDir removes all the directories from the mounter trash directory + // only if the targets have the provided targetPrefix + SafeEmptyTrashDir(targetPrefix string, trashLocation string) error } // MountImpl backend implementation for Mount/Unmount calls @@ -809,20 +812,27 @@ func (m *Mounter) RemoveMountPath(mountPath string, opts map[string]string) erro return nil } +func (m *Mounter) SafeEmptyTrashDir(targetPrefix string, trashLocation string) error { + return m.emptyTrashDir(targetPrefix, trashLocation) +} + func (m *Mounter) EmptyTrashDir() error { - files, err := ioutil.ReadDir(m.trashLocation) + return m.emptyTrashDir("", m.trashLocation) +} + +func (m *Mounter) emptyTrashDir(safeRemovalPrefix, trashLocation string) error { + files, err := ioutil.ReadDir(trashLocation) if err != nil { - logrus.Errorf("failed to read trash dir: %s. Err: %v", m.trashLocation, err) + logrus.Errorf("failed to read trash dir: %s. Err: %v", trashLocation, err) return err } if _, err := sched.Instance().Schedule( func(sched.Interval) { for _, file := range files { - logrus.Infof("[EmptyTrashDir] Scheduled removing file %v in trash location %v", file.Name(), m.trashLocation) - e := m.removeSoftlinkAndTarget(path.Join(m.trashLocation, file.Name())) + e := m.removeSoftlinkAndTarget(safeRemovalPrefix, path.Join(trashLocation, file.Name())) if e != nil { - logrus.Errorf("failed to remove link: %s. Err: %v", path.Join(m.trashLocation, file.Name()), e) + logrus.Errorf("failed to remove link: %s. Err: %v", path.Join(trashLocation, file.Name()), e) } } }, @@ -836,16 +846,29 @@ func (m *Mounter) EmptyTrashDir() error { return nil } -func (m *Mounter) removeSoftlinkAndTarget(link string) error { +func (m *Mounter) removeSoftlinkAndTarget(safeRemovalPrefix, link string) error { if _, err := os.Stat(link); err == nil { target, err := os.Readlink(link) if err != nil { + if len(safeRemovalPrefix) > 0 { + // In case of safe removals if we are not able to validate the target path + // and its prefix we dont want the caller to think we hit an error with this file. + // This is primarily done to not log the error. + return nil + } return err } + if len(safeRemovalPrefix) > 0 && !strings.HasPrefix(target, safeRemovalPrefix) { + return fmt.Errorf("target %s does not have prefix %s, skipping removal", target, safeRemovalPrefix) + } + + logrus.Infof("[EmptyTrashDir] Scheduled removing file %v", target) if err = m.removeMountPath(target); err != nil { return err } + } else { + return fmt.Errorf("failed to stat link: %s. Err: %w", link, err) } if err := os.Remove(link); err != nil { diff --git a/pkg/mount/mount_test.go b/pkg/mount/mount_test.go index 55cf1f065..7c1e89fd4 100644 --- a/pkg/mount/mount_test.go +++ b/pkg/mount/mount_test.go @@ -7,8 +7,10 @@ import ( "sync" "syscall" "testing" + "time" "github.com/libopenstorage/openstorage/pkg/options" + "github.com/libopenstorage/openstorage/pkg/sched" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) @@ -23,17 +25,27 @@ const ( var m Manager +func setLogger(fn string, t *testing.T) { + // The mount tests log a lot of messages, so we route the logs + // to a tmp location to avoid Travis CI log limits. + logFile, err := os.Create("/tmp/" + fn + ".log") + require.NoError(t, err, "unable to create log file") + logrus.SetOutput(logFile) +} func TestNFSMounter(t *testing.T) { + setLogger("TestNFSMounter", t) setupNFS(t) allTests(t, source, dest) } func TestBindMounter(t *testing.T) { + setLogger("TestBindMounter", t) setupBindMounter(t) allTests(t, source, dest) } func TestRawMounter(t *testing.T) { + setLogger("TestRawMounter", t) setupRawMounter(t) allTests(t, rawSource, rawDest) } @@ -124,7 +136,6 @@ func enoentUnmountTestWithoutOptions(t *testing.T, source, dest string) { syscall.Unmount(dest, 0) } - // mountTestParallel runs mount and unmount in parallel with serveral dirs // in addition, we trigger failed unmount to test race condition in the case // source directory is not found in the cache @@ -274,3 +285,50 @@ func makeFile(pathname string) error { return nil } + +func TestSafeEmptyTrashDir(t *testing.T) { + sched.Init(time.Second) + m, err := New(NFSMount, nil, []*regexp.Regexp{regexp.MustCompile("")}, nil, []string{}, "") + require.NoError(t, err, "Failed to setup test %v", err) + + err = os.MkdirAll("/tmp/safe-empty-trash-dir-tests", 0755) + require.NoError(t, err) + + defer func() { + err = os.RemoveAll("/tmp/safe-empty-trash-dir-tests") + require.NoError(t, err, "Failed to cleanup after test") + }() + + // Create files that should not be removed + file, err := os.Create("/tmp/safe-empty-trash-dir-tests/should-not-remove.txt") + require.NoError(t, err, "Failed to create file: %v", err) + file.Close() + + // Create a symbolic link that should not be removed + err = os.Symlink("/tmp/safe-empty-trash-dir-tests/should-not-remove.txt", "/tmp/safe-empty-trash-dir-tests/should-not-remove-symlink.txt") + require.NoError(t, err, "Failed to create symlink: %v", err) + + // Create a file that should be removed + file, err = os.Create("/tmp/safe-empty-trash-dir-tests/should-remove-file.txt") + require.NoError(t, err, "Failed to create file: %v", err) + + file.Close() + + // Create a symbolic link + err = os.Symlink("/tmp/safe-empty-trash-dir-tests/should-remove-file.txt", "/tmp/safe-empty-trash-dir-tests/should-remove-symlink.txt") + require.NoError(t, err, "Failed to create symlink: %v", err) + + err = m.SafeEmptyTrashDir("/tmp/safe-empty-trash-dir-tests/should-remove", "/tmp/safe-empty-trash-dir-tests") + require.NoError(t, err, "Failed to empty trash dir %v", err) + + time.Sleep(mountPathRemoveDelay + 5*time.Second) + + _, err = os.Stat("/tmp/safe-empty-trash-dir-tests/should-remove-file.txt") + require.True(t, os.IsNotExist(err), "File should be removed") + _, err = os.Stat("/tmp/safe-empty-trash-dir-tests/should-remove-symlink.txt") + require.True(t, os.IsNotExist(err), "File should be removed") + _, err = os.Stat("/tmp/safe-empty-trash-dir-tests/should-not-remove.txt") + require.NoError(t, err, "File should not be removed") + _, err = os.Stat("/tmp/safe-empty-trash-dir-tests/should-not-remove-symlink.txt") + require.NoError(t, err, "File should not be removed") +}