From bc64fc17a0a58502dfdf8726cfab13db3f203eb0 Mon Sep 17 00:00:00 2001
From: Pontus Freyhult
Date: Wed, 10 Apr 2024 17:53:44 +0200
Subject: [PATCH 1/2] Extend s3 reader to be seekable, provide a seekable multireader

---
 sda-download/internal/storage/seekable.go | 478 ++++++++++++++++++
 .../internal/storage/seekable_test.go     | 344 +++++++++++++
 sda-download/internal/storage/storage.go  |   1 +
 3 files changed, 823 insertions(+)
 create mode 100644 sda-download/internal/storage/seekable.go
 create mode 100644 sda-download/internal/storage/seekable_test.go

diff --git a/sda-download/internal/storage/seekable.go b/sda-download/internal/storage/seekable.go
new file mode 100644
index 000000000..6c572f366
--- /dev/null
+++ b/sda-download/internal/storage/seekable.go
@@ -0,0 +1,478 @@
+// Package storage provides an interface for storage areas, e.g. s3 or a POSIX file system.
+package storage
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"strings"
+	"sync"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/service/s3"
+)
+
+// NewFileReadSeeker returns an io.ReadSeekCloser instance for the file
+func (pb *posixBackend) NewFileReadSeeker(filePath string) (io.ReadSeekCloser, error) {
+
+	reader, err := pb.NewFileReader(filePath)
+	if err != nil {
+		return nil, err
+	}
+
+	seeker, ok := reader.(io.ReadSeekCloser)
+	if !ok {
+		return nil, fmt.Errorf("Invalid posixBackend, reader does not support seeking")
+	}
+
+	return seeker, nil
+}
+
+// s3CacheBlock is used to keep track of cached data
+type s3CacheBlock struct {
+	start  int64
+	length int64
+	data   []byte
+}
+
+// s3Reader keeps track of the state needed for a seekable s3 reader
+type s3Reader struct {
+	s3Backend
+	currentOffset         int64
+	local                 []s3CacheBlock
+	filePath              string
+	objectSize            int64
+	lock                  sync.Mutex
+	outstandingPrefetches []int64
+	seeked                bool
+	objectReader          io.Reader
+}
+
+func (sb *s3Backend) NewFileReadSeeker(filePath string) (io.ReadSeekCloser, error) {
+	objectSize, err := sb.GetFileSize(filePath)
+
+	if err != nil {
+		return nil, err
+	}
+
+	reader := &s3Reader{
+		*sb,
+		0,
+		make([]s3CacheBlock, 0, 32),
+		filePath,
+		objectSize,
+		sync.Mutex{},
+		make([]int64, 0, 32),
+		false,
+		nil,
+	}
+
+	return reader, nil
+}
+
+func (r *s3Reader) Close() (err error) {
+	return nil
+}
+
+func (r *s3Reader) pruneCache() {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+
+	if len(r.local) < 16 {
+		return
+	}
+
+	// Prune the cache
+	keepfrom := len(r.local) - 8
+	r.local = r.local[keepfrom:]
+
+}
+
+func (r *s3Reader) prefetchSize() int64 {
+	n := r.Conf.Chunksize
+
+	if n >= 5*1024*1024 {
+		return int64(n)
+	}
+
+	return 50 * 1024 * 1024
+}
+
+func (r *s3Reader) prefetchAt(offset int64) {
+	r.pruneCache()
+
+	r.lock.Lock()
+	defer r.lock.Unlock()
+
+	if r.isPrefetching(offset) {
+		// We're already fetching this
+		return
+	}
+
+	// Check if we have the data in cache
+	for _, p := range r.local {
+		if offset >= p.start && offset < p.start+p.length {
+			// At least part of the data is here
+			return
+		}
+	}
+
+	// Not found in cache, we should fetch the data
+	bucket := aws.String(r.Bucket)
+	key := aws.String(r.filePath)
+	prefetchSize := r.prefetchSize()
+
+	r.outstandingPrefetches = append(r.outstandingPrefetches, offset)
+
+	r.lock.Unlock()
+
+	wantedRange := aws.String(fmt.Sprintf("bytes=%d-%d", offset, offset+prefetchSize-1))
+
+	object, err := r.Client.GetObject(&s3.GetObjectInput{
+		Bucket: bucket,
+		Key:    key,
+		Range:  wantedRange,
+	})
+
+	r.lock.Lock()
+
+	r.removeFromOutstanding(offset)
+
+	if err != nil {
+		return
+	}
+
+	// Validate the response against the offset this prefetch asked for
+	responseRange := fmt.Sprintf("bytes %d-", offset)
+
+	if object.ContentRange == nil || !strings.HasPrefix(*object.ContentRange, responseRange) {
+		// Unexpected content range - ignore
+		return
+	}
+
+	if len(r.local) > 16 {
+		// Don't cache anything more right now
+		return
+	}
+
+	// Read into Buffer
+	b := bytes.Buffer{}
+	_, err = io.Copy(&b, object.Body)
+	if err != nil {
+		return
+	}
+
+	// Store in cache
+	cacheBytes := b.Bytes()
+	r.local = append(r.local, s3CacheBlock{offset, int64(len(cacheBytes)), cacheBytes})
+}
+
+func (r *s3Reader) Seek(offset int64, whence int) (int64, error) {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+
+	// Flag that we've seeked, so we don't use the mode optimised for reading from
+	// start to end
+	r.seeked = true
+
+	switch whence {
+	case io.SeekStart:
+		if offset < 0 {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v - can't be negative when seeking from start", offset)
+		}
+		if offset > r.objectSize {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v - beyond end of object (size %v)", offset, r.objectSize)
+		}
+
+		r.currentOffset = offset
+		go r.prefetchAt(r.currentOffset)
+
+		return offset, nil
+
+	case io.SeekCurrent:
+		if r.currentOffset+offset < 0 {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v from %v - would be before start", offset, r.currentOffset)
+		}
+		if offset > r.objectSize {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v from %v - would end up beyond end of object (size %v)", offset, r.currentOffset, r.objectSize)
+		}
+
+		r.currentOffset += offset
+		go r.prefetchAt(r.currentOffset)
+
+		return r.currentOffset, nil
+
+	case io.SeekEnd:
+		if r.objectSize+offset < 0 {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v from end of %v byte object - would be before start of file", offset, r.objectSize)
+		}
+		if r.objectSize+offset > r.objectSize {
+			return r.currentOffset, fmt.Errorf("Invalid offset %v from end of %v byte object - would be beyond end", offset, r.objectSize)
+		}
+
+		r.currentOffset = r.objectSize + offset
+		go r.prefetchAt(r.currentOffset)
+
+		return r.currentOffset, nil
+	}
+
+	return r.currentOffset, fmt.Errorf("Invalid whence %v", whence)
+}
+
+// removeFromOutstanding removes a prefetch from the list of outstanding prefetches once it's no longer active
+func (r *s3Reader) removeFromOutstanding(toRemove int64) {
+	switch len(r.outstandingPrefetches) {
+	case 0:
+		// Nothing to do
+	case 1:
+		// Check if it's the one we should remove
+		if r.outstandingPrefetches[0] == toRemove {
+			r.outstandingPrefetches = r.outstandingPrefetches[:0]
+		}
+
+	default:
+		remove := 0
+		found := false
+		for i, j := range r.outstandingPrefetches {
+			if j == toRemove {
+				remove = i
+				found = true
+			}
+		}
+		if found {
+			r.outstandingPrefetches[remove] = r.outstandingPrefetches[len(r.outstandingPrefetches)-1]
+			r.outstandingPrefetches = r.outstandingPrefetches[:len(r.outstandingPrefetches)-1]
+		}
+	}
+}
+
+// isPrefetching checks if the data is already being fetched
+func (r *s3Reader) isPrefetching(offset int64) bool {
+	// Walk through the outstanding prefetches
+	for _, p := range r.outstandingPrefetches {
+		if offset >= p && offset < p+r.prefetchSize() {
+			// At least some of this read is already being fetched

+			return true
+		}
+	}
+
+	return false
+}
+
+// wholeReader is a helper for when we read the whole object
+func (r *s3Reader) wholeReader(dst []byte) (int, error) {
+	if r.objectReader == nil {
+		// First call, set up a reader for the object
+		object, err := r.Client.GetObject(&s3.GetObjectInput{
+			Bucket: aws.String(r.Bucket),
+			Key:    aws.String(r.filePath),
+		})
+
+		if err != nil {
+			return 0, err
+		}
+
+		// Store the object body reader for future use
+		r.objectReader = object.Body
+	}
+
+	// Just use the reader, offset is handled in the caller
+	return r.objectReader.Read(dst)
+}
+
+func (r *s3Reader) Read(dst []byte) (n int, err error) {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+
+	if !r.seeked {
+		// We haven't seeked, so assume a sequential read of the whole object
+		// and use the streaming reader for performance
+
+		n, err = r.wholeReader(dst)
+		// We need to keep track of the position in the stream in case we seek
+		r.currentOffset += int64(n)
+
+		return n, err
+	}
+
+	if r.currentOffset >= r.objectSize {
+		// No more data to read, just return EOF
+		return 0, io.EOF
+	}
+
+	start := r.currentOffset
+
+	// Walk through the cache
+	for _, p := range r.local {
+		if start >= p.start && start < p.start+p.length {
+			// At least part of the data is here

+			offsetInBlock := start - p.start
+
+			// Pull out wanted data (as much as we have)
+			n = copy(dst, p.data[offsetInBlock:])
+			r.currentOffset += int64(n)
+
+			// Prefetch the next bit
+			go r.prefetchAt(r.currentOffset)
+
+			return n, nil
+		}
+	}
+
+	// Check if we're already fetching this data
+	if r.isPrefetching(start) {
+		// Return 0, nil to have the client retry

+		return 0, nil
+	}
+
+	// Not found in cache, need to fetch data

+	bucket := aws.String(r.Bucket)
+	key := aws.String(r.filePath)
+
+	wantedRange := aws.String(fmt.Sprintf("bytes=%d-%d", r.currentOffset, r.currentOffset+r.prefetchSize()-1))
+
+	r.outstandingPrefetches = append(r.outstandingPrefetches, start)
+
+	r.lock.Unlock()
+
+	object, err := r.Client.GetObject(&s3.GetObjectInput{
+		Bucket: bucket,
+		Key:    key,
+		Range:  wantedRange,
+	})
+
+	r.lock.Lock()
+
+	r.removeFromOutstanding(start)
+
+	if err != nil {
+		return 0, err
+	}
+
+	responseRange := fmt.Sprintf("bytes %d-", r.currentOffset)
+
+	if object.ContentRange == nil || !strings.HasPrefix(*object.ContentRange, responseRange) {
+		return 0, fmt.Errorf("Unexpected content range %v - expected prefix %v", aws.StringValue(object.ContentRange), responseRange)
+	}
+
+	b := bytes.Buffer{}
+	_, err = io.Copy(&b, object.Body)
+	if err != nil {
+		return 0, err
+	}
+
+	// Add to cache
+	cacheBytes := bytes.Clone(b.Bytes())
+	r.local = append(r.local, s3CacheBlock{start, int64(len(cacheBytes)), cacheBytes})
+
+	n, err = b.Read(dst)
+
+	r.currentOffset += int64(n)
+	go r.prefetchAt(r.currentOffset)
+
+	return n, err
+}
+
+// seekableMultiReader provides an io.MultiReader equivalent that also supports seeking
+type seekableMultiReader struct {
+	readers       []io.Reader
+	sizes         []int64
+	currentOffset int64
+	totalSize     int64
+}
+
+// SeekableMultiReader constructs a multireader that supports seeking.
+// It requires all passed readers to be seekable.
+func SeekableMultiReader(readers ...io.Reader) (io.ReadSeeker, error) {
+
+	r := make([]io.Reader, len(readers))
+	sizes := make([]int64, len(readers))
+
+	copy(r, readers)
+
+	var totalSize int64
+	for i, reader := range readers {
+		seeker, ok := reader.(io.ReadSeeker)
+		if !ok {
+			return nil, fmt.Errorf("Reader %d to SeekableMultiReader is not seekable", i)
+		}
+
+		size, err := seeker.Seek(0, io.SeekEnd)
+		if err != nil {
+			return nil, fmt.Errorf("Size determination failed for reader %d to SeekableMultiReader: %v", i, err)
+		}
+
+		sizes[i] = size
+		totalSize += size
+
+	}
+
+	return &seekableMultiReader{r, sizes, 0, totalSize}, nil
+}
+
+func (r *seekableMultiReader) Seek(offset int64, whence int) (int64, error) {
+
+	switch whence {
+	case io.SeekStart:
+		r.currentOffset = offset
+	case io.SeekCurrent:
+		r.currentOffset += offset
+	case io.SeekEnd:
+		r.currentOffset = r.totalSize + offset
+
+	default:
+		return 0, fmt.Errorf("Unsupported whence %v", whence)
+
+	}
+
+	return r.currentOffset, nil
+}
+
+func (r *seekableMultiReader) Read(dst []byte) (int, error) {
+
+	var readerStartAt int64
+
+	for i, reader := range r.readers {
+
+		if r.currentOffset < readerStartAt {
+			// The wanted offset is before this reader; since readerStartAt only
+			// grows, this should not happen, but skip this reader to be safe
+			readerStartAt += r.sizes[i]
+
+			continue
+		}
+
+		if readerStartAt+r.sizes[i] < r.currentOffset {
+			// We want data from a later reader
+			readerStartAt += r.sizes[i]
+
+			continue
+		}
+
+		// At least part of the data is in this reader

+		seekable, ok := reader.(io.ReadSeeker)
+		if !ok {
+			return 0, fmt.Errorf("Expected a seekable reader but reader %d is not seekable", i)
+		}
+
+		_, err := seekable.Seek(r.currentOffset-readerStartAt, io.SeekStart)
+		if err != nil {
+			return 0, fmt.Errorf("Unexpected error while seeking: %v", err)
+		}
+
+		n, err := seekable.Read(dst)
+		r.currentOffset += int64(n)
+
+		if n > 0 || err != io.EOF {
+			if err == io.EOF && r.currentOffset < r.totalSize {
+				// More data is available from later readers, suppress the EOF
+				err = nil
+			}
+
+			return n, err
+		}
+
+		readerStartAt += r.sizes[i]
+	}
+
+	return 0, io.EOF
+}
diff --git a/sda-download/internal/storage/seekable_test.go b/sda-download/internal/storage/seekable_test.go
new file mode 100644
index 000000000..adbdf7b33
--- /dev/null
+++ b/sda-download/internal/storage/seekable_test.go
@@ -0,0 +1,344 @@
+package storage
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	log "github.com/sirupsen/logrus"
+)
+
+func TestSeekableBackend(t *testing.T) {
+
+	for _, backendType := range []string{posixType, s3Type} {
+
+		testConf.Type = backendType
+
+		backend, err := NewBackend(testConf)
+		assert.Nil(t, err, "Backend failed")
+
+		var buf bytes.Buffer
+
+		path := fmt.Sprintf("%v.%v", s3Creatable, time.Now().UnixNano())
+
+		if testConf.Type == s3Type {
+			s3back := backend.(*s3Backend)
+			assert.IsType(t, s3back, &s3Backend{}, "Wrong type from NewBackend with seekable s3")
+		}
+		if testConf.Type == posixType {
+			path, err = writeName()
+			posix := backend.(*posixBackend)
+			assert.Nil(t, err, "File creation for backend failed")
+			assert.IsType(t, posix, &posixBackend{}, "Wrong type from NewBackend with seekable posix")
+		}
+
+		writer, err := backend.NewFileWriter(path)
+
+		assert.NotNil(t, writer, "Got a nil reader for writer from s3")
+		assert.Nil(t, err, "posix NewFileWriter failed when it shouldn't")
+
+		for i := 0; i < 1000; i++ {
+			written, err := writer.Write(writeData)
+			assert.Nil(t, err, "Failure when writing to s3 writer")
+			assert.Equal(t,
len(writeData), written, "Did not write all writeData") + } + + writer.Close() + + reader, err := backend.NewFileReadSeeker(path) + assert.Nil(t, err, "s3 NewFileReadSeeker failed when it should work") + assert.NotNil(t, reader, "Got a nil reader for s3") + + size, err := backend.GetFileSize(path) + assert.Nil(t, err, "s3 GetFileSize failed when it should work") + assert.Equal(t, int64(len(writeData))*1000, size, "Got an incorrect file size") + + if reader == nil { + t.Error("reader that should be usable is not, bailing out") + + return + } + + var readBackBuffer [4096]byte + seeker := reader + + _, err = seeker.Read(readBackBuffer[0:4096]) + assert.Equal(t, writeData, readBackBuffer[:14], "did not read back data as expected") + assert.Nil(t, err, "read returned unexpected error") + + if testConf.Type == s3Type { + // POSIX is more allowing + _, err := seeker.Seek(95000, io.SeekStart) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(-95000, io.SeekStart) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(-95000, io.SeekCurrent) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(95000, io.SeekCurrent) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(95000, io.SeekEnd) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(-95000, io.SeekEnd) + assert.NotNil(t, err, "Seek didn't fail when it should") + + _, err = seeker.Seek(0, 4) + assert.NotNil(t, err, "Seek didn't fail when it should") + + } + + offset, err := seeker.Seek(15, io.SeekStart) + assert.Nil(t, err, "Seek failed when it shouldn't") + assert.Equal(t, int64(15), offset, "Seek did not return expected offset") + + offset, err = seeker.Seek(5, io.SeekCurrent) + assert.Nil(t, err, "Seek failed when it shouldn't") + assert.Equal(t, int64(20), offset, "Seek did not return expected offset") + + offset, err = seeker.Seek(-5, io.SeekEnd) + assert.Nil(t, err, "Seek failed when it shouldn't") + assert.Equal(t, int64(13995), offset, "Seek did not return expected offset") + + n, err := seeker.Read(readBackBuffer[0:4096]) + assert.Equal(t, 5, n, "Unexpected amount of read bytes") + assert.Nil(t, err, "Read failed when it shouldn't") + + n, err = seeker.Read(readBackBuffer[0:4096]) + + assert.Equal(t, io.EOF, err, "Expected EOF") + assert.Equal(t, 0, n, "Unexpected amount of read bytes") + + offset, err = seeker.Seek(0, io.SeekEnd) + assert.Nil(t, err, "Seek failed when it shouldn't") + assert.Equal(t, int64(14000), offset, "Seek did not return expected offset") + + n, err = seeker.Read(readBackBuffer[0:4096]) + assert.Equal(t, 0, n, "Unexpected amount of read bytes") + assert.Equal(t, io.EOF, err, "Read returned unexpected error when EOF") + + offset, err = seeker.Seek(6302, io.SeekStart) + assert.Nil(t, err, "Seek failed") + assert.Equal(t, int64(6302), offset, "Seek did not return expected offset") + + n = 0 + for i := 0; i < 500000 && n == 0 && err == nil; i++ { + // Allow 0 sizes while waiting for prefetch + n, err = seeker.Read(readBackBuffer[0:4096]) + } + + assert.Equal(t, 4096, n, "Read did not return expected amounts of bytes for %v", seeker) + assert.Equal(t, writeData[2:], readBackBuffer[:12], "did not read back data as expected") + assert.Nil(t, err, "unexpected error when reading back data") + + offset, err = seeker.Seek(6302, io.SeekStart) + assert.Nil(t, err, "unexpected error when seeking to read back data") + assert.Equal(t, int64(6302), offset, "returned offset 
wasn't expected") + + largeBuf := make([]byte, 65536) + readLen, err := seeker.Read(largeBuf) + assert.Equal(t, 7698, readLen, "did not read back expected amount of data") + assert.Nil(t, err, "unexpected error when reading back data") + + buf.Reset() + + log.SetOutput(&buf) + + if !testing.Short() { + _, err = backend.GetFileSize(s3DoesNotExist) + assert.NotNil(t, err, "s3 GetFileSize worked when it should not") + assert.NotZero(t, buf.Len(), "Expected warning missing") + + buf.Reset() + + reader, err = backend.NewFileReadSeeker(s3DoesNotExist) + assert.NotNil(t, err, "s3 NewFileReader worked when it should not") + assert.Nil(t, reader, "Got a non-nil reader for s3") + assert.NotZero(t, buf.Len(), "Expected warning missing") + } + + log.SetOutput(os.Stdout) + } +} + +func TestS3SeekablePrefetchSize(t *testing.T) { + + testConf.Type = s3Type + chunkSize := testConf.S3.Chunksize + testConf.S3.Chunksize = 5 * 1024 * 1024 + backend, err := NewBackend(testConf) + s3back := backend.(*s3Backend) + assert.IsType(t, s3back, &s3Backend{}, "Wrong type from NewBackend with seekable s3") + assert.Nil(t, err, "S3 backend failed") + path := fmt.Sprintf("%v.%v", s3Creatable, time.Now().UnixNano()) + + writer, err := backend.NewFileWriter(path) + + assert.NotNil(t, writer, "Got a nil reader for writer from s3") + assert.Nil(t, err, "posix NewFileWriter failed when it shouldn't") + + writer.Close() + + reader, err := backend.NewFileReadSeeker(path) + assert.Nil(t, err, "s3 NewFileReadSeeker failed when it should work") + assert.NotNil(t, reader, "Got a nil reader for s3") + + s := reader.(*s3Reader) + + assert.Equal(t, int64(5*1024*1024), s.prefetchSize(), "Prefetch size not as expected with chunksize 5MB") + s.Conf.Chunksize = 0 + assert.Equal(t, int64(50*1024*1024), s.prefetchSize(), "Prefetch size not as expected") + + s.Conf.Chunksize = 1024 * 1024 + assert.Equal(t, int64(50*1024*1024), s.prefetchSize(), "Prefetch size not as expected") + + testConf.S3.Chunksize = chunkSize +} + +func TestS3SeekableSpecial(t *testing.T) { + // Some special tests here, messing with internals to expose behaviour + + testConf.Type = s3Type + + backend, err := NewBackend(testConf) + assert.Nil(t, err, "Backend failed") + + path := fmt.Sprintf("%v.%v", s3Creatable, time.Now().UnixNano()) + + s3back := backend.(*s3Backend) + assert.IsType(t, s3back, &s3Backend{}, "Wrong type from NewBackend with seekable s3") + + writer, err := backend.NewFileWriter(path) + + assert.NotNil(t, writer, "Got a nil reader for writer from s3") + assert.Nil(t, err, "posix NewFileWriter failed when it shouldn't") + + for i := 0; i < 1000; i++ { + written, err := writer.Write(writeData) + assert.Nil(t, err, "Failure when writing to s3 writer") + assert.Equal(t, len(writeData), written, "Did not write all writeData") + } + + writer.Close() + + reader, err := backend.NewFileReadSeeker(path) + reader.(*s3Reader).seeked = true + + assert.Nil(t, err, "s3 NewFileReader failed when it should work") + assert.NotNil(t, reader, "Got a nil reader for s3") + size, err := backend.GetFileSize(path) + assert.Nil(t, err, "s3 GetFileSize failed when it should work") + assert.Equal(t, int64(len(writeData))*1000, size, "Got an incorrect file size") + + if reader == nil { + t.Error("reader that should be usable is not, bailing out") + + return + } + + var readBackBuffer [4096]byte + seeker := reader + + _, err = seeker.Read(readBackBuffer[0:4096]) + assert.Equal(t, writeData, readBackBuffer[:14], "did not read back data as expected") + assert.Nil(t, err, "read returned 
unexpected error") + + err = seeker.Close() + assert.Nil(t, err, "unexpected error when closing") + + reader, err = backend.NewFileReadSeeker(path) + assert.Nil(t, err, "unexpected error when creating reader") + + s := reader.(*s3Reader) + s.seeked = true + s.prefetchAt(0) + assert.Equal(t, 1, len(s.local), "nothing cached after prefetch") + // Clear cache + s.local = s.local[:0] + + s.outstandingPrefetches = []int64{0} + t.Logf("Cache %v, outstanding %v", s.local, s.outstandingPrefetches) + + n, err := s.Read(readBackBuffer[0:4096]) + assert.Nil(t, err, "read returned unexpected error") + assert.Equal(t, 0, n, "got data when we should get 0 because of prefetch") + + for i := 0; i < 30; i++ { + s.local = append(s.local, s3CacheBlock{90000000, int64(0), nil}) + } + s.prefetchAt(0) + assert.Equal(t, 8, len(s.local), "unexpected length of cache after prefetch") + + s.outstandingPrefetches = []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + s.removeFromOutstanding(9) + assert.Equal(t, s.outstandingPrefetches, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8}, "unexpected outstanding prefetches after remove") + s.removeFromOutstanding(19) + assert.Equal(t, s.outstandingPrefetches, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8}, "unexpected outstanding prefetches after remove") + s.removeFromOutstanding(5) + // We don't care about the internal order, sort for simplicity + slices.Sort(s.outstandingPrefetches) + assert.Equal(t, s.outstandingPrefetches, []int64{0, 1, 2, 3, 4, 6, 7, 8}, "unexpected outstanding prefetches after remove") + + s.objectReader = nil + s.Bucket = "" + s.filePath = "" + data := make([]byte, 100) + _, err = s.wholeReader(data) + assert.NotNil(t, err, "wholeReader object instantiation worked when it should have failed") +} + +func TestSeekableMultiReader(t *testing.T) { + + readers := make([]io.Reader, 10) + for i := 0; i < 10; i++ { + readers[i] = bytes.NewReader(writeData) + } + + seeker, err := SeekableMultiReader(readers...) 
+	assert.Nil(t, err, "unexpected error from creating SeekableMultiReader")
+
+	var readBackBuffer [4096]byte
+
+	_, err = seeker.Read(readBackBuffer[0:4096])
+	assert.Equal(t, writeData, readBackBuffer[:14], "did not read back data as expected")
+	assert.Nil(t, err, "unexpected error from read")
+
+	offset, err := seeker.Seek(60, io.SeekStart)
+
+	assert.Nil(t, err, "Seek failed")
+	assert.Equal(t, int64(60), offset, "Seek did not return expected offset")
+
+	// We don't know how many bytes this should return
+	_, err = seeker.Read(readBackBuffer[0:4096])
+	assert.Equal(t, writeData[4:], readBackBuffer[:10], "did not read back data as expected")
+	assert.Nil(t, err, "Read failed when it should not")
+
+	offset, err = seeker.Seek(0, io.SeekEnd)
+	assert.Equal(t, int64(140), offset, "Seek did not return expected offset")
+	assert.Nil(t, err, "Seek failed when it should not")
+
+	n, err := seeker.Read(readBackBuffer[0:4096])
+
+	assert.Equal(t, 0, n, "Read did not return expected number of bytes")
+	assert.Equal(t, io.EOF, err, "did not get EOF as expected")
+
+	offset, err = seeker.Seek(56, io.SeekStart)
+	assert.Equal(t, int64(56), offset, "Seek did not return expected offset")
+	assert.Nil(t, err, "Seek failed unexpectedly")
+
+	largeBuf := make([]byte, 65536)
+	readLen, err := seeker.Read(largeBuf)
+	assert.Nil(t, err, "unexpected error when reading back data")
+	assert.Equal(t, 14, readLen, "did not read back expected amount of data")
+
+	log.SetOutput(os.Stdout)
+}
diff --git a/sda-download/internal/storage/storage.go b/sda-download/internal/storage/storage.go
index 3cb68a876..22c34ffd2 100644
--- a/sda-download/internal/storage/storage.go
+++ b/sda-download/internal/storage/storage.go
@@ -27,6 +27,7 @@ import (
 type Backend interface {
 	GetFileSize(filePath string) (int64, error)
 	NewFileReader(filePath string) (io.ReadCloser, error)
+	NewFileReadSeeker(filePath string) (io.ReadSeekCloser, error)
 	NewFileWriter(filePath string) (io.WriteCloser, error)
 }
 

From dd4ed6faa8ce2daba4f08478fbaf0222e43684f6 Mon Sep 17 00:00:00 2001
From: Pontus Freyhult
Date: Mon, 22 Apr 2024 15:34:35 +0200
Subject: [PATCH 2/2] Use seekable file interface when the entire file is not requested

---
 sda-download/api/sda/sda.go | 46 +++++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/sda-download/api/sda/sda.go b/sda-download/api/sda/sda.go
index ae686b61c..7e5856cfe 100644
--- a/sda-download/api/sda/sda.go
+++ b/sda-download/api/sda/sda.go
@@ -299,8 +299,20 @@ func Download(c *gin.Context) {
 		}
 	}
 
+	wholeFile := true
+	if start != 0 || end != 0 {
+		wholeFile = false
+	}
+
 	// Get archive file handle
-	file, err := Backend.NewFileReader(fileDetails.ArchivePath)
+	var file io.Reader
+
+	if wholeFile {
+		file, err = Backend.NewFileReader(fileDetails.ArchivePath)
+	} else {
+		file, err = Backend.NewFileReadSeeker(fileDetails.ArchivePath)
+	}
+
 	if err != nil {
 		log.Errorf("could not find archive file %s, %s", fileDetails.ArchivePath, err)
 		c.String(http.StatusInternalServerError, "archive error")
@@ -368,9 +380,21 @@ func Download(c *gin.Context) {
 			return
 		}
 		log.Debugf("Reencrypted c4gh file header = %v", newHeader)
+
 		newHr := bytes.NewReader(newHeader)
 
-		fileStream = io.MultiReader(newHr, file)
+		if wholeFile {
+			fileStream = io.MultiReader(newHr, file)
+		} else {
+			seeker, _ := file.(io.ReadSeeker)
+			fileStream, err = storage.SeekableMultiReader(newHr, seeker)
+			if err != nil {
+				log.Errorf("Failed to construct SeekableMultiReader, reason: %v", err)
+				
c.String(http.StatusInternalServerError, "file decoding error") + + return + } + } default: // Reencrypt header for use with our temporary key newHeader, err := reencryptHeader(fileDetails.Header, config.Config.App.Crypt4GHPublicKeyB64) @@ -381,8 +405,22 @@ func Download(c *gin.Context) { return } - encryptedFileReader := io.MultiReader(bytes.NewReader(newHeader), file) - c4ghfileStream, err := streaming.NewCrypt4GHReader(encryptedFileReader, config.Config.App.Crypt4GHPrivateKey, nil) + newHr := bytes.NewReader(newHeader) + + if wholeFile { + fileStream = io.MultiReader(newHr, file) + } else { + seeker, _ := file.(io.ReadSeeker) + fileStream, err = storage.SeekableMultiReader(newHr, seeker) + if err != nil { + log.Errorf("Failed to construct SeekableMultiReader, reason: %v", err) + c.String(http.StatusInternalServerError, "file decoding error") + + return + } + } + + c4ghfileStream, err := streaming.NewCrypt4GHReader(fileStream, config.Config.App.Crypt4GHPrivateKey, nil) defer c4ghfileStream.Close() if err != nil { log.Errorf("could not prepare file for streaming, %s", err)
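
A minimal usage sketch of how the pieces introduced above are meant to compose: an in-memory header reader is stitched to a seekable payload reader with SeekableMultiReader, and the combined stream can then be positioned with Seek before streaming a byte range, which is what the Download handler does for ranged requests. This is illustrative only - the import path is an assumption (the storage package is internal to sda-download, so external code could not actually import it like this), and bytes.Reader stands in for the result of backend.NewFileReadSeeker, which satisfies the same io.ReadSeeker contract.

package main

import (
	"bytes"
	"fmt"
	"io"

	// Assumed import path, for illustration only; in the real tree this
	// package is internal and only reachable from within sda-download.
	"github.com/neicnordic/sensitive-data-archive/sda-download/internal/storage"
)

func main() {
	// Stand-ins: a re-encrypted header held in memory and a payload that
	// would normally come from backend.NewFileReadSeeker(path).
	header := bytes.NewReader([]byte("example header bytes"))
	payload := bytes.NewReader([]byte("example payload bytes"))

	stream, err := storage.SeekableMultiReader(header, payload)
	if err != nil {
		panic(err)
	}

	// Position the combined stream as a ranged download would, then read.
	if _, err := stream.Seek(8, io.SeekStart); err != nil {
		panic(err)
	}

	buf := make([]byte, 16)
	n, _ := stream.Read(buf)
	fmt.Printf("read %d bytes at offset 8: %q\n", n, buf[:n])
}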