Skip to content

Commit

Permalink
Fix routine leak on s3 write failure (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored May 26, 2022
1 parent 63ba605 commit 57d7882
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 31 deletions.
70 changes: 44 additions & 26 deletions s3sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package s3sync

import (
"context"
"errors"
"net/url"
"os"
Expand Down Expand Up @@ -89,6 +90,9 @@ func (m *Manager) Sync(source, dest string) error {
return err
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

chJob := make(chan func())
var wg sync.WaitGroup
for i := 0; i < m.nJobs; i++ {
Expand All @@ -115,17 +119,17 @@ func (m *Manager) Sync(source, dest string) error {
if err != nil {
return err
}
return m.syncS3ToS3(chJob, sourceS3Path, destS3Path)
return m.syncS3ToS3(ctx, chJob, sourceS3Path, destS3Path)
}
return m.syncS3ToLocal(chJob, sourceS3Path, dest)
return m.syncS3ToLocal(ctx, chJob, sourceS3Path, dest)
}

if isS3URL(destURL) {
destS3Path, err := urlToS3Path(destURL)
if err != nil {
return err
}
return m.syncLocalToS3(chJob, source, destS3Path)
return m.syncLocalToS3(ctx, chJob, source, destS3Path)
}

return errors.New("local to local sync is not supported")
Expand All @@ -135,15 +139,15 @@ func isS3URL(url *url.URL) bool {
return url.Scheme == "s3"
}

func (m *Manager) syncS3ToS3(chJob chan func(), sourcePath, destPath *s3Path) error {
func (m *Manager) syncS3ToS3(ctx context.Context, chJob chan func(), sourcePath, destPath *s3Path) error {
return errors.New("S3 to S3 sync feature is not implemented")
}

func (m *Manager) syncLocalToS3(chJob chan func(), sourcePath string, destPath *s3Path) error {
func (m *Manager) syncLocalToS3(ctx context.Context, chJob chan func(), sourcePath string, destPath *s3Path) error {
wg := &sync.WaitGroup{}
errs := &multiErr{}
for source := range filterFilesForSync(
listLocalFiles(sourcePath), m.listS3Files(destPath), m.del,
listLocalFiles(ctx, sourcePath), m.listS3Files(ctx, destPath), m.del,
) {
wg.Add(1)
source := source
Expand Down Expand Up @@ -171,11 +175,11 @@ func (m *Manager) syncLocalToS3(chJob chan func(), sourcePath string, destPath *
}

// syncS3ToLocal syncs the given s3 path to the given local path.
func (m *Manager) syncS3ToLocal(chJob chan func(), sourcePath *s3Path, destPath string) error {
func (m *Manager) syncS3ToLocal(ctx context.Context, chJob chan func(), sourcePath *s3Path, destPath string) error {
wg := &sync.WaitGroup{}
errs := &multiErr{}
for source := range filterFilesForSync(
m.listS3Files(sourcePath), listLocalFiles(destPath), m.del,
m.listS3Files(ctx, sourcePath), listLocalFiles(ctx, destPath), m.del,
) {
wg.Add(1)
source := source
Expand Down Expand Up @@ -351,14 +355,14 @@ func (m *Manager) deleteRemote(file *fileInfo, destPath *s3Path) error {
}

// listS3Files returns a channel which receives the file infos under the given s3Path.
func (m *Manager) listS3Files(path *s3Path) chan *fileInfo {
func (m *Manager) listS3Files(ctx context.Context, path *s3Path) chan *fileInfo {
c := make(chan *fileInfo, 50000) // TODO: revisit this buffer size later

go func() {
defer close(c)
var token *string
for {
if token = m.listS3FileWithToken(c, path, token); token == nil {
if token = m.listS3FileWithToken(ctx, c, path, token); token == nil {
break
}
}
Expand All @@ -368,14 +372,14 @@ func (m *Manager) listS3Files(path *s3Path) chan *fileInfo {
}

// listS3FileWithToken lists (sends to the result channel) the s3 files from the given continuation token.
func (m *Manager) listS3FileWithToken(c chan *fileInfo, path *s3Path, token *string) *string {
func (m *Manager) listS3FileWithToken(ctx context.Context, c chan *fileInfo, path *s3Path, token *string) *string {
list, err := m.s3.ListObjectsV2(&s3.ListObjectsV2Input{
Bucket: &path.bucket,
Prefix: &path.bucketPrefix,
ContinuationToken: token,
})
if err != nil {
sendErrorInfoToChannel(c, err)
sendErrorInfoToChannel(ctx, c, err)
return nil
}

Expand All @@ -386,34 +390,40 @@ func (m *Manager) listS3FileWithToken(c chan *fileInfo, path *s3Path, token *str
}
name, err := filepath.Rel(path.bucketPrefix, *object.Key)
if err != nil {
sendErrorInfoToChannel(c, err)
sendErrorInfoToChannel(ctx, c, err)
continue
}
var fi *fileInfo
if name == "." {
// Single file was specified
c <- &fileInfo{
fi = &fileInfo{
name: filepath.Base(*object.Key),
path: filepath.Dir(*object.Key),
size: *object.Size,
lastModified: *object.LastModified,
singleFile: true,
}
} else {
c <- &fileInfo{
fi = &fileInfo{
name: name,
path: *object.Key,
size: *object.Size,
lastModified: *object.LastModified,
}
}
select {
case c <- fi:
case <-ctx.Done():
return nil
}
}

return list.NextContinuationToken
}

// listLocalFiles returns a channel which receives the infos of the files under the given basePath.
// basePath has to be an absolute path.
func listLocalFiles(basePath string) chan *fileInfo {
func listLocalFiles(ctx context.Context, basePath string) chan *fileInfo {
c := make(chan *fileInfo)

basePath = filepath.ToSlash(basePath)
Expand All @@ -427,51 +437,59 @@ func listLocalFiles(basePath string) chan *fileInfo {
// Returns and closes the channel without sending any.
return
} else if err != nil {
sendErrorInfoToChannel(c, err)
sendErrorInfoToChannel(ctx, c, err)
return
}

if !stat.IsDir() {
sendFileInfoToChannel(c, filepath.Dir(basePath), basePath, stat, true)
sendFileInfoToChannel(ctx, c, filepath.Dir(basePath), basePath, stat, true)
return
}

sendFileInfoToChannel(c, basePath, basePath, stat, false)
sendFileInfoToChannel(ctx, c, basePath, basePath, stat, false)

err = filepath.Walk(basePath, func(path string, stat os.FileInfo, err error) error {
if err != nil {
return err
}
sendFileInfoToChannel(c, basePath, path, stat, false)
return nil
sendFileInfoToChannel(ctx, c, basePath, path, stat, false)
return ctx.Err()
})

if err != nil {
sendErrorInfoToChannel(c, err)
sendErrorInfoToChannel(ctx, c, err)
}

}()
return c
}

func sendFileInfoToChannel(c chan *fileInfo, basePath, path string, stat os.FileInfo, singleFile bool) {
func sendFileInfoToChannel(ctx context.Context, c chan *fileInfo, basePath, path string, stat os.FileInfo, singleFile bool) {
if stat == nil || stat.IsDir() {
return
}
relPath, _ := filepath.Rel(basePath, path)
c <- &fileInfo{
fi := &fileInfo{
name: relPath,
path: path,
size: stat.Size(),
lastModified: stat.ModTime(),
singleFile: singleFile,
}
select {
case c <- fi:
case <-ctx.Done():
}
}

func sendErrorInfoToChannel(c chan *fileInfo, err error) {
c <- &fileInfo{
func sendErrorInfoToChannel(ctx context.Context, c chan *fileInfo, err error) {
fi := &fileInfo{
err: err,
}
select {
case c <- fi:
case <-ctx.Done():
}
}

// filterFilesForSync filters the source files from the given destination files, and returns
Expand Down
11 changes: 6 additions & 5 deletions s3sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package s3sync
import (
"bytes"
"compress/gzip"
"context"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -578,7 +579,7 @@ func TestListLocalFiles(t *testing.T) {
}

t.Run("Root", func(t *testing.T) {
paths := collectFilePaths(listLocalFiles(temp))
paths := collectFilePaths(listLocalFiles(context.Background(), temp))
expected := []string{
filepath.Join(temp, "bar", "baz", "test3"),
filepath.Join(temp, "foo", "test2"),
Expand All @@ -590,15 +591,15 @@ func TestListLocalFiles(t *testing.T) {
})

t.Run("EmptyDir", func(t *testing.T) {
paths := collectFilePaths(listLocalFiles(filepath.Join(temp, "empty")))
paths := collectFilePaths(listLocalFiles(context.Background(), filepath.Join(temp, "empty")))
expected := []string{}
if !reflect.DeepEqual(expected, paths) {
t.Errorf("Local file list is expected to be %v, got %v", expected, paths)
}
})

t.Run("File", func(t *testing.T) {
paths := collectFilePaths(listLocalFiles(filepath.Join(temp, "test1")))
paths := collectFilePaths(listLocalFiles(context.Background(), filepath.Join(temp, "test1")))
expected := []string{
filepath.Join(temp, "test1"),
}
Expand All @@ -608,7 +609,7 @@ func TestListLocalFiles(t *testing.T) {
})

t.Run("Dir", func(t *testing.T) {
paths := collectFilePaths(listLocalFiles(filepath.Join(temp, "foo")))
paths := collectFilePaths(listLocalFiles(context.Background(), filepath.Join(temp, "foo")))
expected := []string{
filepath.Join(temp, "foo", "test2"),
}
Expand All @@ -618,7 +619,7 @@ func TestListLocalFiles(t *testing.T) {
})

t.Run("Dir2", func(t *testing.T) {
paths := collectFilePaths(listLocalFiles(filepath.Join(temp, "bar")))
paths := collectFilePaths(listLocalFiles(context.Background(), filepath.Join(temp, "bar")))
expected := []string{
filepath.Join(temp, "bar", "baz", "test3"),
}
Expand Down

0 comments on commit 57d7882

Please sign in to comment.