Use S3 to generate checksum (#4873)
* Use MD5 hash generated by S3 in S3 storage

* Calculate content size from the stream without re-processing the entire stream

* Tidy up

* Use S3 digests

* Ensure S3 storage can't be configured with an invalid digest algorithm
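
In short: rather than re-reading the uploaded stream to hash it, Delta now asks S3 to compute the digest during the PUT and reads it back from the response. A minimal sketch of that round trip with the synchronous AWS SDK v2 client (bucket, key and payload are hypothetical; the commit itself uses the async client shown in the diff below):

    import software.amazon.awssdk.core.sync.RequestBody
    import software.amazon.awssdk.services.s3.S3Client
    import software.amazon.awssdk.services.s3.model.{ChecksumAlgorithm, PutObjectRequest}

    val s3 = S3Client.create() // credentials/region from the default provider chain

    val request = PutObjectRequest
      .builder()
      .bucket("some-bucket")                       // hypothetical bucket
      .key("some/key")                             // hypothetical key
      .checksumAlgorithm(ChecksumAlgorithm.SHA256) // ask S3 to compute SHA-256 server-side
      .build()

    val response = s3.putObject(request, RequestBody.fromString("file contents"))

    response.eTag()           // MD5-based ETag of a single-part upload
    response.checksumSHA256() // base64-encoded SHA-256, present because the request asked for it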
shinyhappydan authored Apr 19, 2024
1 parent 0e20680 commit 82238c6
Showing 6 changed files with 181 additions and 105 deletions.
@@ -10,10 +10,11 @@ import ch.epfl.bluebrain.nexus.delta.sdk.model.BaseUri
import ch.epfl.bluebrain.nexus.delta.sdk.model.search.PaginationConfig
import ch.epfl.bluebrain.nexus.delta.sdk.permissions.model.Permission
import ch.epfl.bluebrain.nexus.delta.sourcing.config.EventLogConfig
import pureconfig.ConfigReader.Result
import pureconfig.ConvertHelpers.{catchReadError, optF}
import pureconfig.error.{CannotConvert, ConfigReaderFailures, ConvertFailure, FailureReason}
import pureconfig.generic.auto._
import pureconfig.{ConfigConvert, ConfigReader}
import pureconfig.{ConfigConvert, ConfigObjectCursor, ConfigReader}

import scala.annotation.nowarn
import scala.concurrent.duration.FiniteDuration
@@ -80,6 +81,29 @@ object StoragesConfig {
val description: String = s"'allowed-volumes' must contain at least '$defaultVolume' (default-volume)"
}

final case class DigestNotSupportedOnS3(digestAlgorithm: DigestAlgorithm) extends FailureReason {
val description: String = s"Digest algorithm '${digestAlgorithm.value}' is not supported on S3"
}

private def assertValidS3Algorithm(
digestAlgorithm: DigestAlgorithm,
amazonCursor: ConfigObjectCursor
): Result[Unit] = {
digestAlgorithm.value match {
case "SHA-256" | "SHA-1" | "MD5" => Right(())
case _ =>
Left(
ConfigReaderFailures(
ConvertFailure(
DigestNotSupportedOnS3(digestAlgorithm),
None,
amazonCursor.atKeyOrUndefined("digest-algorithm").path
)
)
)
}
}

implicit val storageTypeConfigReader: ConfigReader[StorageTypeConfig] = ConfigReader.fromCursor { cursor =>
for {
obj <- cursor.asObjectCursor
@@ -96,6 +120,7 @@ object StoragesConfig {
amazonEnabledCursor <- amazonCursor.atKey("enabled")
amazonEnabled <- amazonEnabledCursor.asBoolean
amazon <- ConfigReader[S3StorageConfig].from(amazonCursor)
_ <- assertValidS3Algorithm(amazon.digestAlgorithm, amazonCursor)
remoteCursor <- obj.atKeyOrUndefined("remote-disk").asObjectCursor
remoteEnabledCursor <- remoteCursor.atKey("enabled")
remoteEnabled <- remoteEnabledCursor.asBoolean
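
A note on the hunks above: the new guard rides on pureconfig's failure machinery, turning an unsupported algorithm into a ConvertFailure at the digest-algorithm path, so misconfiguration fails at config load time rather than at upload time. The same pattern in standalone form (hypothetical S3Digest wrapper, not part of the commit):

    import pureconfig.ConfigReader
    import pureconfig.error.CannotConvert

    final case class S3Digest(value: String)

    // Reject anything S3 cannot compute server-side, mirroring assertValidS3Algorithm.
    implicit val s3DigestReader: ConfigReader[S3Digest] =
      ConfigReader.fromString { s =>
        if (Set("SHA-256", "SHA-1", "MD5").contains(s)) Right(S3Digest(s))
        else Left(CannotConvert(s, "S3Digest", "digest algorithm is not supported on S3"))
      }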
@@ -29,6 +29,9 @@ object DigestAlgorithm {
final val default: DigestAlgorithm =
new DigestAlgorithm("SHA-256")

final val MD5: DigestAlgorithm =
new DigestAlgorithm("MD5")

/**
* Safely construct an [[DigestAlgorithm]]
*/
@@ -128,7 +128,6 @@ object StorageValue {
) extends StorageValue {

override val tpe: StorageType = StorageType.S3Storage

}

object S3StorageValue {
@@ -5,7 +5,7 @@ import akka.actor.ActorSystem
import akka.http.scaladsl.model.{BodyPartEntity, Uri}
import akka.stream.scaladsl.Source
import akka.util.ByteString
import cats.effect.IO
import cats.effect.{IO, Ref}
import cats.implicits._
import ch.epfl.bluebrain.nexus.delta.kernel.Logger
import ch.epfl.bluebrain.nexus.delta.kernel.utils.UUIDF
@@ -15,94 +15,78 @@ import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.model.DigestAlgorithm
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.model.Storage.S3Storage
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.FileOperations.intermediateFolders
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.SaveFileRejection._
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.s3.S3StorageSaveFile.PutObjectRequestOps
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.s3.client.S3StorageClient
import ch.epfl.bluebrain.nexus.delta.rdf.syntax.uriSyntax
import ch.epfl.bluebrain.nexus.delta.sdk.stream.StreamConverter
import eu.timepit.refined.refineMV
import eu.timepit.refined.types.string.NonEmptyString
import fs2.Stream
import fs2.aws.s3.S3.MultipartETagValidation
import fs2.aws.s3.models.Models.{BucketName, ETag, FileKey, PartSizeMB}
import fs2.{Chunk, Pipe, Stream}
import org.apache.commons.codec.binary.Hex
import software.amazon.awssdk.core.async.AsyncRequestBody
import software.amazon.awssdk.services.s3.model._

import java.util.UUID
import java.util.{Base64, UUID}

final class S3StorageSaveFile(s3StorageClient: S3StorageClient)(implicit
as: ActorSystem,
uuidf: UUIDF
) {

private val s3 = s3StorageClient.underlyingClient
private val multipartETagValidation = MultipartETagValidation.create[IO]
private val logger = Logger[S3StorageSaveFile]
private val partSizeMB: PartSizeMB = refineMV(5)
private val s3 = s3StorageClient.underlyingClient
private val logger = Logger[S3StorageSaveFile]

def apply(
storage: S3Storage,
filename: String,
entity: BodyPartEntity
): IO[FileStorageMetadata] = {

val bucket = BucketName(NonEmptyString.unsafeFrom(storage.value.bucket))

def storeFile(key: String, uuid: UUID, entity: BodyPartEntity): IO[FileStorageMetadata] = {
val fileData: Stream[IO, Byte] = convertStream(entity.dataBytes)

(for {
_ <- log(key, s"Checking for object existence")
_ <- validateObjectDoesNotExist(bucket.value.value, key)
_ <- log(key, s"Beginning multipart upload")
maybeEtags <- uploadFileMultipart(fileData, bucket, key)
_ <- log(key, s"Finished multipart upload. Etag by part: $maybeEtags")
attr <- collectFileMetadata(fileData, key, uuid, maybeEtags)
} yield attr)
.onError(e => logger.error(e)("Unexpected error when storing file"))
.adaptError { err => UnexpectedSaveError(key, err.getMessage) }
}

def collectFileMetadata(
bytes: Stream[IO, Byte],
key: String,
uuid: UUID,
maybeEtags: List[Option[ETag]]
): IO[FileStorageMetadata] =
maybeEtags.sequence match {
case Some(onlyPartETag :: Nil) =>
// TODO our tests expect specific values for digests and the only algorithm currently used is SHA-256.
// If we want to continue to check this, but allow for different algorithms, this needs to be abstracted
// in the tests and verified for specific file contents.
// The result will also depend on whether we use a multipart upload or a standard put object.
for {
_ <- log(key, s"Received ETag for single part upload: $onlyPartETag")
fileSize <- computeSize(bytes)
digest <- computeDigest(bytes, storage.storageValue.algorithm)
metadata <- fileMetadata(key, uuid, fileSize, digest)
} yield metadata
case Some(other) => raiseUnexpectedErr(key, s"S3 multipart upload returned multiple etags unexpectedly: $other")
case None => raiseUnexpectedErr(key, "S3 multipart upload was aborted because no data was received")
}

def fileMetadata(key: String, uuid: UUID, fileSize: Long, digest: String) =
s3StorageClient.baseEndpoint.map { base =>
FileStorageMetadata(
uuid = uuid,
bytes = fileSize,
digest = Digest.ComputedDigest(storage.value.algorithm, digest),
origin = Client,
location = base / bucket.value.value / Uri.Path(key),
path = Uri.Path(key)
)
}

def log(key: String, msg: String) = logger.info(s"Bucket: ${bucket.value}. Key: $key. $msg")

for {
uuid <- uuidf()
path = Uri.Path(intermediateFolders(storage.project, uuid, filename))
result <- storeFile(path.toString(), uuid, entity)
result <- storeFile(storage.value.bucket, path.toString(), uuid, entity, storage.value.algorithm)
} yield result
}

private def storeFile(
bucket: String,
key: String,
uuid: UUID,
entity: BodyPartEntity,
algorithm: DigestAlgorithm
): IO[FileStorageMetadata] = {
val fileData: Stream[IO, Byte] = convertStream(entity.dataBytes)

(for {
_ <- log(bucket, key, s"Checking for object existence")
_ <- validateObjectDoesNotExist(bucket, key)
_ <- log(bucket, key, s"Beginning upload")
(digest, fileSize) <- uploadFile(fileData, bucket, key, algorithm)
_ <- log(bucket, key, s"Finished upload. Digest: $digest")
attr <- fileMetadata(bucket, key, uuid, fileSize, algorithm, digest)
} yield attr)
.onError(e => logger.error(e)("Unexpected error when storing file"))
.adaptError { err => UnexpectedSaveError(key, err.getMessage) }
}

private def fileMetadata(
bucket: String,
key: String,
uuid: UUID,
fileSize: Long,
algorithm: DigestAlgorithm,
digest: String
): IO[FileStorageMetadata] =
s3StorageClient.baseEndpoint.map { base =>
FileStorageMetadata(
uuid = uuid,
bytes = fileSize,
digest = Digest.ComputedDigest(algorithm, digest),
origin = Client,
location = base / bucket / Uri.Path(key),
path = Uri.Path(key)
)
}

private def validateObjectDoesNotExist(bucket: String, key: String) =
getFileAttributes(bucket, key).redeemWith(
{
@@ -119,42 +103,70 @@ final class S3StorageSaveFile(s3StorageClient: S3StorageClient)(implicit
.mapMaterializedValue(_ => NotUsed)
)

private def uploadFileMultipart(fileData: Stream[IO, Byte], bucket: BucketName, key: String): IO[List[Option[ETag]]] =
fileData
.through(
s3.uploadFileMultipart(
bucket,
FileKey(NonEmptyString.unsafeFrom(key)),
partSizeMB,
uploadEmptyFiles = true,
multipartETagValidation = multipartETagValidation.some
)
)
.compile
.to(List)
private def uploadFile(
fileData: Stream[IO, Byte],
bucket: String,
key: String,
algorithm: DigestAlgorithm
): IO[(String, Long)] = {
for {
fileSizeAcc <- Ref.of[IO, Long](0L)
digest <- fileData
.evalTap(_ => fileSizeAcc.update(_ + 1))
.through(
uploadFilePipe(bucket, key, algorithm)
)
.compile
.onlyOrError
fileSize <- fileSizeAcc.get
} yield (digest, fileSize)
}
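
// Note: the Ref is bumped once per byte as the stream is pulled, so the size
// accumulates during the upload itself and the stream is never re-processed
// (this is the "calculate content size from the stream" part of the commit).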

private def uploadFilePipe(bucket: String, key: String, algorithm: DigestAlgorithm): Pipe[IO, Byte, String] = { in =>
fs2.Stream.eval {
in.compile.to(Chunk).flatMap { chunks =>
val bs = chunks.toByteBuffer
for {
response <- s3.putObject(
PutObjectRequest
.builder()
.bucket(bucket)
.deltaDigest(algorithm)
.key(key)
.build(),
AsyncRequestBody.fromByteBuffer(bs)
)
} yield {
checksumFromResponse(response, algorithm)
}
}
}
}
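
// uploadFilePipe drains the entire stream into a single Chunk and issues one
// PutObject, so S3 returns exactly one digest for the whole object. This
// replaces the earlier multipart upload, whose per-part ETags do not combine
// into a digest of the full file.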

private def checksumFromResponse(response: PutObjectResponse, algorithm: DigestAlgorithm): String = {
algorithm.value match {
case "MD5" => response.eTag().stripPrefix("\"").stripSuffix("\"")
case "SHA-256" => Hex.encodeHexString(Base64.getDecoder.decode(response.checksumSHA256()))
case "SHA-1" => Hex.encodeHexString(Base64.getDecoder.decode(response.checksumSHA1()))
case _ => throw new IllegalArgumentException(s"Unsupported algorithm for S3: ${algorithm.value}")
}
}
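
// S3 returns SHA-256/SHA-1 checksums base64-encoded, while Delta stores hex
// digests, hence the decode/re-encode. For MD5 there is no checksum field:
// the ETag of a plain single-part upload is already the hex MD5, so only the
// surrounding quotes are stripped.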

private def getFileAttributes(bucket: String, key: String): IO[GetObjectAttributesResponse] =
s3StorageClient.getFileAttributes(bucket, key)

// TODO issue fetching attributes when tested against localstack, only after the object is saved
// Verify if it's the same for real S3. Error msg: 'Could not parse XML response.'
// For now we just compute it manually.
// private def getFileSize(key: String) =
// getFileAttributes(key).flatMap { attr =>
// log(key, s"File attributes from S3: $attr").as(attr.objectSize())
// }
private def computeSize(bytes: Stream[IO, Byte]): IO[Long] = bytes.fold(0L)((acc, _) => acc + 1).compile.lastOrError

private def computeDigest(bytes: Stream[IO, Byte], algorithm: DigestAlgorithm): IO[String] = {
val digest = algorithm.digest
bytes.chunks
.evalMap(chunk => IO(digest.update(chunk.toArray)))
.compile
.last
.map { _ =>
digest.digest().map("%02x".format(_)).mkString
private def log(bucket: String, key: String, msg: String): IO[Unit] =
logger.info(s"Bucket: ${bucket}. Key: $key. $msg")
}

object S3StorageSaveFile {
implicit class PutObjectRequestOps(request: PutObjectRequest.Builder) {
def deltaDigest(algorithm: DigestAlgorithm): PutObjectRequest.Builder =
algorithm.value match {
case "MD5" => request
case "SHA-256" => request.checksumAlgorithm(ChecksumAlgorithm.SHA256)
case "SHA-1" => request.checksumAlgorithm(ChecksumAlgorithm.SHA1)
case _ => throw new IllegalArgumentException(s"Unsupported algorithm for S3: ${algorithm.value}")
}
}
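
// The MD5 case deliberately leaves the request untouched: a plain single-part
// PutObject always yields an MD5-based ETag, whereas SHA-256/SHA-1 are only
// computed by S3 when a ChecksumAlgorithm is set on the request.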

private def raiseUnexpectedErr[A](key: String, msg: String): IO[A] = IO.raiseError(UnexpectedSaveError(key, msg))
}
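
To ground the encoding conversions above: S3 hands back SHA checksums base64-encoded, while Delta stores lowercase hex. A self-contained sketch of the round trip (plain JVM, no S3 involved):

    import java.security.MessageDigest
    import java.util.Base64
    import org.apache.commons.codec.binary.Hex

    val payload   = "file contents".getBytes("UTF-8")
    val sha256    = MessageDigest.getInstance("SHA-256").digest(payload)
    // What S3 would return in checksumSHA256() for this payload:
    val fromS3    = Base64.getEncoder.encodeToString(sha256)
    // What checksumFromResponse stores:
    val hexDigest = Hex.encodeHexString(Base64.getDecoder.decode(fromS3))
    assert(hexDigest == Hex.encodeHexString(sha256))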
@@ -28,7 +28,7 @@ trait S3StorageClient {

def getFileAttributes(bucket: String, key: String): IO[GetObjectAttributesResponse]

def underlyingClient: S3[IO]
def underlyingClient: S3AsyncClientOp[IO]

def baseEndpoint: IO[Uri]
}
@@ -80,7 +80,7 @@ object S3StorageClient {
.build()
)

override def underlyingClient: S3[IO] = s3
override def underlyingClient: S3AsyncClientOp[IO] = client

override def baseEndpoint: IO[Uri] = IO.pure(baseEndpoint)
}
@@ -97,7 +97,7 @@

override def getFileAttributes(bucket: String, key: String): IO[GetObjectAttributesResponse] = raiseDisabledErr

override def underlyingClient: S3[IO] = throw disabledErr
override def underlyingClient: S3AsyncClientOp[IO] = throw disabledErr

override def baseEndpoint: IO[Uri] = raiseDisabledErr
}
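
A note on the client change above: underlyingClient now exposes the interpreted AWS SDK v2 async client (S3AsyncClientOp[IO], from fs2-aws's laserdisc pure modules; import path assumed below) instead of the higher-level fs2-aws S3[IO] wrapper, which is what gives S3StorageSaveFile access to putObject with checksums. A hedged usage sketch:

    import cats.effect.IO
    import io.laserdisc.pure.s3.tagless.S3AsyncClientOp
    import software.amazon.awssdk.core.async.AsyncRequestBody
    import software.amazon.awssdk.services.s3.model.{ChecksumAlgorithm, PutObjectRequest}

    // Given an interpreted client such as the one returned by underlyingClient:
    def putWithSha256(client: S3AsyncClientOp[IO], bucket: String, key: String, bytes: Array[Byte]): IO[String] =
      client
        .putObject(
          PutObjectRequest.builder().bucket(bucket).key(key).checksumAlgorithm(ChecksumAlgorithm.SHA256).build(),
          AsyncRequestBody.fromBytes(bytes)
        )
        .map(_.checksumSHA256()) // base64-encoded SHA-256, as read by checksumFromResponse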