Skip to content

Commit

Permalink
Switch to new Azure SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanBaulch committed Nov 5, 2024
1 parent ee1b14f commit 5f54ef9
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 389 deletions.
167 changes: 78 additions & 89 deletions backend/azure/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package azure

import (
"context"
"fmt"
"errors"
"io"
"net/url"
"strings"
"time"

"github.com/Azure/azure-pipeline-go/pipeline"
"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"

"github.com/c2fo/vfs/v6"
"github.com/c2fo/vfs/v6/utils"
Expand All @@ -24,7 +26,7 @@ type Client interface {

// SetMetadata should add the metadata specified by the parameter metadata for the blob specified by the parameter
// file.
SetMetadata(file vfs.File, metadata map[string]string) error
SetMetadata(file vfs.File, metadata map[string]*string) error

// Upload should create or update the blob specified by the file parameter with the contents of the content
// parameter
Expand All @@ -48,7 +50,7 @@ type Client interface {

// DefaultClient is the main implementation that actually makes the calls to Azure Blob Storage
type DefaultClient struct {
pipeline pipeline.Pipeline
credential any
}

// NewClient initializes a new DefaultClient
Expand All @@ -58,43 +60,39 @@ func NewClient(options *Options) (*DefaultClient, error) {
return nil, err
}

// This configures the client to use the default retry policy. The default policy uses exponential backoff with
// maxRetries = 4. If this behavior needs to be changed, add the Retry member to azblob.PipelineOptions. For
// more information on azure retry policies see https://pkg.go.dev/github.com/Azure/azure-storage-blob-go/azblob#RetryOptions
//
// Example (this is not the default):
// RetryOptions{
// Policy: RetryPolicyExponential, // Use exponential backoff as opposed to linear
// MaxTries: 3, // Try at most 3 times to perform the operation (set to 1 to disable retries)
// TryTimeout: time.Second * 3, // Maximum time allowed for any single try
// RetryDelay: time.Second * 1, // Backoff amount for each retry (exponential or linear)
// MaxRetryDelay: time.Second * 3, // Max delay between retries
// }
pl := azblob.NewPipeline(credential, azblob.PipelineOptions{})

return &DefaultClient{pl}, nil
return &DefaultClient{credential}, nil
}

func (a *DefaultClient) newContainerClient(containerURL string) (*container.Client, error) {
switch cred := a.credential.(type) {
case azcore.TokenCredential:
return container.NewClient(containerURL, cred, nil)
case *container.SharedKeyCredential:
return container.NewClientWithSharedKeyCredential(containerURL, cred, nil)
default:
return container.NewClientWithNoCredential(containerURL, nil)
}
}

// Properties fetches the properties for the blob specified by the parameters containerURI and filePath
func (a *DefaultClient) Properties(containerURI, filePath string) (*BlobProperties, error) {
URL, err := url.Parse(containerURI)
cli, err := a.newContainerClient(containerURI)
if err != nil {
return nil, err
}
containerURL := azblob.NewContainerURL(*URL, a.pipeline)

if filePath == "" {
// this is only used to check for the existence of a container so we don't care about anything but the
// error
_, err := containerURL.GetProperties(context.Background(), azblob.LeaseAccessConditions{})
_, err := cli.GetProperties(context.Background(), nil)
if err != nil {
return nil, err
}
return nil, nil
}

blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(filePath))
resp, err := blobURL.GetProperties(context.Background(), azblob.BlobAccessConditions{}, azblob.ClientProvidedKeyOptions{})
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(filePath))
resp, err := blobURL.GetProperties(context.Background(), nil)
if err != nil {
return nil, err
}
Expand All @@ -103,119 +101,110 @@ func (a *DefaultClient) Properties(containerURI, filePath string) (*BlobProperti

// Upload uploads a new file to Azure Blob Storage
func (a *DefaultClient) Upload(file vfs.File, content io.ReadSeeker) error {
URL, err := url.Parse(file.Location().(*Location).ContainerURL())
cli, err := a.newContainerClient(file.Location().(*Location).ContainerURL())
if err != nil {
return err
}

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(file.Path()))
_, err = blobURL.Upload(context.Background(), content, azblob.BlobHTTPHeaders{}, azblob.Metadata{},
azblob.BlobAccessConditions{}, azblob.DefaultAccessTier, nil, azblob.ClientProvidedKeyOptions{}, azblob.ImmutabilityPolicyOptions{})
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(file.Path()))
body, ok := content.(io.ReadSeekCloser)
if !ok {
body = streaming.NopCloser(content)
}
_, err = blobURL.Upload(context.Background(), body, nil)
return err
}

// SetMetadata sets the given metadata for the blob
func (a *DefaultClient) SetMetadata(file vfs.File, metadata map[string]string) error {
URL, err := url.Parse(file.Location().(*Location).ContainerURL())
func (a *DefaultClient) SetMetadata(file vfs.File, metadata map[string]*string) error {
cli, err := a.newContainerClient(file.Location().(*Location).ContainerURL())
if err != nil {
return err
}

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(file.Path()))
_, err = blobURL.SetMetadata(context.Background(), metadata, azblob.BlobAccessConditions{}, azblob.ClientProvidedKeyOptions{})
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(file.Path()))
_, err = blobURL.SetMetadata(context.Background(), metadata, nil)
return err
}

// Download returns an io.ReadCloser for the given vfs.File
func (a *DefaultClient) Download(file vfs.File) (io.ReadCloser, error) {
URL, err := url.Parse(file.Location().(*Location).ContainerURL())
cli, err := a.newContainerClient(file.Location().(*Location).ContainerURL())
if err != nil {
return nil, err
}

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(file.Path()))
get, err := blobURL.Download(context.Background(), 0, 0, azblob.BlobAccessConditions{}, false, azblob.ClientProvidedKeyOptions{})
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(file.Path()))
get, err := blobURL.DownloadStream(context.Background(), nil)
if err != nil {
return nil, err
}
return get.Body(azblob.RetryReaderOptions{}), nil
return get.Body, nil
}

// Copy copies srcFile to the destination tgtFile within Azure Blob Storage. Note that in the case where we get
// encoded spaces in the file name (i.e. %20) the '%' must be encoded or the copy command will return a not found
// error.
func (a *DefaultClient) Copy(srcFile, tgtFile vfs.File) error {
// Can't use url.PathEscape here since that will escape everything (even the directory separators)
srcURL, err := url.Parse(strings.Replace(srcFile.URI(), "%", "%25", -1))
if err != nil {
return err
}
srcURL := strings.Replace(srcFile.URI(), "%", "%25", -1)

tgtURL, err := url.Parse(tgtFile.Location().(*Location).ContainerURL())
tgtURL := tgtFile.Location().(*Location).ContainerURL()

cli, err := a.newContainerClient(tgtURL)
if err != nil {
return err
}

containerURL := azblob.NewContainerURL(*tgtURL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(tgtFile.Path()))
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(tgtFile.Path()))
ctx := context.Background()
resp, err := blobURL.StartCopyFromURL(ctx, *srcURL, azblob.Metadata{}, azblob.ModifiedAccessConditions{},
azblob.BlobAccessConditions{}, azblob.DefaultAccessTier, nil)
resp, err := blobURL.StartCopyFromURL(ctx, srcURL, nil)
if err != nil {
return err
}

for resp.CopyStatus() == azblob.CopyStatusPending {
for *resp.CopyStatus == blob.CopyStatusTypePending {
time.Sleep(2 * time.Second)
}

if resp.CopyStatus() == azblob.CopyStatusSuccess {
if *resp.CopyStatus == blob.CopyStatusTypeSuccess {
return nil
}

return fmt.Errorf("copy failed ERROR[%s]", resp.ErrorCode())
return errors.New("copy failed")
}

// List will return a listing of the contents of the given location. Each item in the list will contain the full key
// as specified by the azure blob (including the virtual 'path').
func (a *DefaultClient) List(l vfs.Location) ([]string, error) {
URL, err := url.Parse(l.(*Location).ContainerURL())
cli, err := a.newContainerClient(l.(*Location).ContainerURL())
if err != nil {
return []string{}, err
}

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
pager := cli.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{
Prefix: to.Ptr(utils.RemoveLeadingSlash(l.Path())),
Include: container.ListBlobsInclude{Metadata: true, Tags: true},
})
ctx := context.Background()
var list []string
for marker := (azblob.Marker{}); marker.NotDone(); {
listBlob, err := containerURL.ListBlobsHierarchySegment(ctx, marker, "/",
azblob.ListBlobsSegmentOptions{Prefix: utils.RemoveLeadingSlash(l.Path())})
for pager.More() {
listBlob, err := pager.NextPage(ctx)
if err != nil {
return []string{}, err
}

marker = listBlob.NextMarker

for i := range listBlob.Segment.BlobItems {
list = append(list, listBlob.Segment.BlobItems[i].Name)
for i := range listBlob.ListBlobsHierarchySegmentResponse.Segment.BlobItems {
list = append(list, *listBlob.ListBlobsHierarchySegmentResponse.Segment.BlobItems[i].Name)
}
}
return list, nil
}

// Delete deletes the given file from Azure Blob Storage.
func (a *DefaultClient) Delete(file vfs.File) error {
URL, err := url.Parse(file.Location().(*Location).ContainerURL())
cli, err := a.newContainerClient(file.Location().(*Location).ContainerURL())
if err != nil {
return err
}

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(file.Path()))
_, err = blobURL.Delete(context.Background(), azblob.DeleteSnapshotsOptionNone, azblob.BlobAccessConditions{})
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(file.Path()))
_, err = blobURL.Delete(context.Background(), nil)
return err
}

Expand All @@ -224,22 +213,24 @@ func (a *DefaultClient) Delete(file vfs.File) error {
// If soft deletion is enabled for blobs in the storage account, each version will be marked for deletion and will be
// permanently deleted by Azure as per the soft deletion policy.
func (a *DefaultClient) DeleteAllVersions(file vfs.File) error {
URL, err := url.Parse(file.Location().(*Location).ContainerURL())
cli, err := a.newContainerClient(file.Location().(*Location).ContainerURL())
if err != nil {
return err
}
blobURL := cli.NewBlockBlobClient(utils.RemoveLeadingSlash(file.Path()))

containerURL := azblob.NewContainerURL(*URL, a.pipeline)
blobURL := containerURL.NewBlockBlobURL(utils.RemoveLeadingSlash(file.Path()))

versions, err := a.getBlobVersions(containerURL, utils.RemoveLeadingSlash(file.Path()))
versions, err := a.getBlobVersions(cli, utils.RemoveLeadingSlash(file.Path()))
if err != nil {
return err
}

for _, version := range versions {
// Delete a specific version
_, err = blobURL.WithVersionID(*version).Delete(context.Background(), azblob.DeleteSnapshotsOptionNone, azblob.BlobAccessConditions{})
cli, err := blobURL.WithVersionID(*version)
if err != nil {
return err
}
_, err = cli.Delete(context.Background(), nil)
if err != nil {
return err
}
Expand All @@ -248,23 +239,21 @@ func (a *DefaultClient) DeleteAllVersions(file vfs.File) error {
return err
}

func (a *DefaultClient) getBlobVersions(containerURL azblob.ContainerURL, blobName string) ([]*string, error) {
func (a *DefaultClient) getBlobVersions(cli *container.Client, blobName string) ([]*string, error) {
ctx := context.Background()
pager := cli.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &blobName,
Include: container.ListBlobsInclude{Versions: true},
})
var versions []*string
for marker := (azblob.Marker{}); marker.NotDone(); {
listBlob, err := containerURL.ListBlobsFlatSegment(ctx, marker,
azblob.ListBlobsSegmentOptions{Prefix: blobName, Details: azblob.BlobListingDetails{Versions: true}})
for pager.More() {
listBlob, err := pager.NextPage(ctx)
if err != nil {
return nil, err
return []*string{}, err
}

marker = listBlob.NextMarker

for i := range listBlob.Segment.BlobItems {
blobItem := listBlob.Segment.BlobItems[i]
if blobItem.VersionID != nil {
versions = append(versions, blobItem.VersionID)
}
for i := range listBlob.ListBlobsFlatSegmentResponse.Segment.BlobItems {
versions = append(versions, listBlob.ListBlobsFlatSegmentResponse.Segment.BlobItems[i].VersionID)
}
}
return versions, nil
Expand Down
28 changes: 15 additions & 13 deletions backend/azure/client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ package azure
import (
"context"
"fmt"
"net/url"
"io"
"os"
"strings"
"testing"
"time"

"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/stretchr/testify/suite"
)

type ClientIntegrationTestSuite struct {
suite.Suite
testContainerURL azblob.ContainerURL
testContainerURL *container.Client
accountName string
accountKey string
}
Expand All @@ -31,28 +34,27 @@ func (s *ClientIntegrationTestSuite) SetupSuite() {
panic(err)
}

p := azblob.NewPipeline(credential, azblob.PipelineOptions{})
baseURL, err := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net", s.accountName))
cli, err := container.NewClientWithSharedKeyCredential(fmt.Sprintf("https://%s.blob.core.windows.net", s.accountName), credential, nil)
s.NoError(err)
serviceURL := azblob.NewServiceURL(*baseURL, p)
s.testContainerURL = serviceURL.NewContainerURL("test-container")
_, err = s.testContainerURL.Create(context.Background(), azblob.Metadata{}, azblob.PublicAccessNone)
s.testContainerURL = cli

_, err = s.testContainerURL.Create(context.Background(), nil)
s.NoError(err)

// The create function claims to be synchronous but for some reason it does not exist for a little bit so
// we need to wait for it to be there.
_, err = s.testContainerURL.GetProperties(context.Background(), azblob.LeaseAccessConditions{})
_, err = s.testContainerURL.GetProperties(context.Background(), nil)
for {
time.Sleep(2 * time.Second)
if err == nil || err.(azblob.StorageError).ServiceCode() != "BlobNotFound" {
if err == nil || !bloberror.HasCode(err, bloberror.BlobNotFound) {
break
}
_, err = s.testContainerURL.GetProperties(context.Background(), azblob.LeaseAccessConditions{})
_, err = s.testContainerURL.GetProperties(context.Background(), nil)
}
}

func (s *ClientIntegrationTestSuite) TearDownSuite() {
_, err := s.testContainerURL.Delete(context.Background(), azblob.ContainerAccessConditions{})
_, err := s.testContainerURL.Delete(context.Background(), nil)
s.NoError(err)
}

Expand Down Expand Up @@ -205,7 +207,7 @@ func (s *ClientIntegrationTestSuite) TestProperties_NonExistentFile() {

_, err = client.Properties(f.Location().URI(), f.Path())
s.Error(err, "The file does not exist so we expect an error")
s.Equal(404, err.(azblob.ResponseError).Response().StatusCode)
s.Equal(404, err.(*azcore.ResponseError).StatusCode)
}

func (s *ClientIntegrationTestSuite) TestDelete_NonExistentFile() {
Expand Down
Loading

0 comments on commit 5f54ef9

Please sign in to comment.