Skip to content

Commit

Permalink
Implement DynamoDBLock (#4880)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Oct 27, 2023
1 parent e3cce56 commit b3e5f8e
Show file tree
Hide file tree
Showing 6 changed files with 507 additions and 89 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/object_store.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: test
AWS_ENDPOINT: http://localhost:4566
AWS_ALLOW_HTTP: true
AWS_COPY_IF_NOT_EXISTS: dynamo:test-table
HTTP_URL: "http://localhost:8080"
GOOGLE_BUCKET: test-bucket
GOOGLE_SERVICE_ACCOUNT: "/tmp/gcs.json"
Expand All @@ -136,6 +137,7 @@ jobs:
docker run -d -p 4566:4566 localstack/localstack:2.0
docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2
aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket
aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=key,KeyType=HASH --attribute-definitions AttributeName=key,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5
- name: Configure Azurite (Azure emulation)
# the magical connection string is from
Expand Down
24 changes: 10 additions & 14 deletions object_store/src/aws/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,27 +823,23 @@ impl AmazonS3Builder {
)) as _
};

let endpoint: String;
let bucket_endpoint: String;

// If `endpoint` is provided then its assumed to be consistent with
// `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then
// `endpoint` should have bucket name included.
if self.virtual_hosted_style_request.get()? {
endpoint = self
.endpoint
.unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com"));
bucket_endpoint = endpoint.clone();
let bucket_endpoint = if self.virtual_hosted_style_request.get()? {
self.endpoint
.clone()
.unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com"))
} else {
endpoint = self
.endpoint
.unwrap_or_else(|| format!("https://s3.{region}.amazonaws.com"));
bucket_endpoint = format!("{endpoint}/{bucket}");
}
match &self.endpoint {
None => format!("https://s3.{region}.amazonaws.com/{bucket}"),
Some(endpoint) => format!("{endpoint}/{bucket}"),
}
};

let config = S3Config {
region,
endpoint,
endpoint: self.endpoint,
bucket,
bucket_endpoint,
credentials,
Expand Down
99 changes: 36 additions & 63 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::aws::{
AwsCredentialProvider, S3ConditionalPut, S3CopyIfNotExists, STORE, STRICT_PATH_ENCODE_SET,
};
use crate::client::get::GetClient;
use crate::client::header::HeaderConfig;
use crate::client::header::{get_etag, HeaderConfig};
use crate::client::header::{get_put_result, get_version};
use crate::client::list::ListClient;
use crate::client::retry::RetryExt;
Expand All @@ -39,13 +39,14 @@ use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::{Buf, Bytes};
use hyper::http;
use hyper::http::HeaderName;
use itertools::Itertools;
use percent_encoding::{utf8_percent_encode, PercentEncode};
use quick_xml::events::{self as xml_events};
use reqwest::{
header::{CONTENT_LENGTH, CONTENT_TYPE},
Client as ReqwestClient, Method, RequestBuilder, Response, StatusCode,
Client as ReqwestClient, Method, RequestBuilder, Response,
};
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
Expand Down Expand Up @@ -196,7 +197,7 @@ impl From<DeleteError> for Error {
#[derive(Debug)]
pub struct S3Config {
pub region: String,
pub endpoint: String,
pub endpoint: Option<String>,
pub bucket: String,
pub bucket_endpoint: String,
pub credentials: AwsCredentialProvider,
Expand All @@ -214,34 +215,38 @@ impl S3Config {
format!("{}/{}", self.bucket_endpoint, encode_path(path))
}

async fn get_credential(&self) -> Result<Option<Arc<AwsCredential>>> {
pub(crate) async fn get_credential(&self) -> Result<Option<Arc<AwsCredential>>> {
Ok(match self.skip_signature {
false => Some(self.credentials.get_credential().await?),
true => None,
})
}
}

/// A builder for a put request allowing customisation of the headers and query string
pub(crate) struct PutRequest<'a> {
/// A builder for a request allowing customisation of the headers and query string
pub(crate) struct Request<'a> {
path: &'a Path,
config: &'a S3Config,
builder: RequestBuilder,
payload_sha256: Option<Vec<u8>>,
}

impl<'a> PutRequest<'a> {
impl<'a> Request<'a> {
pub fn query<T: Serialize + ?Sized + Sync>(self, query: &T) -> Self {
let builder = self.builder.query(query);
Self { builder, ..self }
}

pub fn header(self, k: &HeaderName, v: &str) -> Self {
pub fn header<K>(self, k: K, v: &str) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
{
let builder = self.builder.header(k, v);
Self { builder, ..self }
}

pub async fn send(self) -> Result<PutResult> {
pub async fn send(self) -> Result<Response> {
let credential = self.config.get_credential().await?;

let response = self
Expand All @@ -259,14 +264,19 @@ impl<'a> PutRequest<'a> {
path: self.path.as_ref(),
})?;

Ok(response)
}

pub async fn do_put(self) -> Result<PutResult> {
let response = self.send().await?;
Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?)
}
}

#[derive(Debug)]
pub(crate) struct S3Client {
config: S3Config,
client: ReqwestClient,
pub config: S3Config,
pub client: ReqwestClient,
}

impl S3Client {
Expand All @@ -275,20 +285,15 @@ impl S3Client {
Ok(Self { config, client })
}

/// Returns the config
pub fn config(&self) -> &S3Config {
&self.config
}

/// Make an S3 PUT request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html>
///
/// Returns the ETag
pub fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> PutRequest<'a> {
pub fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> Request<'a> {
let url = self.config.path_url(path);
let mut builder = self.client.request(Method::PUT, url);
let mut payload_sha256 = None;

if let Some(checksum) = self.config().checksum {
if let Some(checksum) = self.config.checksum {
let digest = checksum.digest(&bytes);
builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
if checksum == Checksum::SHA256 {
Expand All @@ -301,11 +306,11 @@ impl S3Client {
false => builder.body(bytes),
};

if let Some(value) = self.config().client_options.get_content_type(path) {
if let Some(value) = self.config.client_options.get_content_type(path) {
builder = builder.header(CONTENT_TYPE, value);
}

PutRequest {
Request {
path,
builder,
payload_sha256,
Expand Down Expand Up @@ -399,7 +404,7 @@ impl S3Client {

// Compute checksum - S3 *requires* this for DeleteObjects requests, so we default to
// their algorithm if the user hasn't specified one.
let checksum = self.config().checksum.unwrap_or(Checksum::SHA256);
let checksum = self.config.checksum.unwrap_or(Checksum::SHA256);
let digest = checksum.digest(&body);
builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
let payload_sha256 = if checksum == Checksum::SHA256 {
Expand Down Expand Up @@ -450,52 +455,21 @@ impl S3Client {
}

/// Make an S3 Copy request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html>
pub async fn copy_request(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> {
let credential = self.config.get_credential().await?;
pub fn copy_request<'a>(&'a self, from: &Path, to: &'a Path) -> Request<'a> {
let url = self.config.path_url(to);
let source = format!("{}/{}", self.config.bucket, encode_path(from));

let mut builder = self
let builder = self
.client
.request(Method::PUT, url)
.header("x-amz-copy-source", source);

if !overwrite {
match &self.config.copy_if_not_exists {
Some(S3CopyIfNotExists::Header(k, v)) => {
builder = builder.header(k, v);
}
None => {
return Err(crate::Error::NotSupported {
source: "S3 does not support copy-if-not-exists".to_string().into(),
})
}
}
Request {
builder,
path: to,
config: &self.config,
payload_sha256: None,
}

builder
.with_aws_sigv4(
credential.as_deref(),
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
.map_err(|source| match source.status() {
Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists {
source: Box::new(source),
path: to.to_string(),
},
_ => Error::CopyRequest {
source,
path: from.to_string(),
}
.into(),
})?;

Ok(())
}

pub async fn create_multipart(&self, location: &Path) -> Result<MultipartId> {
Expand Down Expand Up @@ -534,15 +508,14 @@ impl S3Client {
) -> Result<PartId> {
let part = (part_idx + 1).to_string();

let result = self
let response = self
.put_request(path, data)
.query(&[("partNumber", &part), ("uploadId", upload_id)])
.send()
.await?;

Ok(PartId {
content_id: result.e_tag.unwrap(),
})
let content_id = get_etag(response.headers()).context(MetadataSnafu)?;
Ok(PartId { content_id })
}

pub async fn complete_multipart(
Expand Down
Loading

0 comments on commit b3e5f8e

Please sign in to comment.