diff --git a/bazel/oci/containers.bzl b/bazel/oci/containers.bzl index 9d6a026547..8f9e47db1b 100644 --- a/bazel/oci/containers.bzl +++ b/bazel/oci/containers.bzl @@ -55,6 +55,14 @@ def containers(): "repotag_file": "//bazel/release:libvirt_tag.txt", "used_by": ["config"], }, + { + "identifier": "s3proxy", + "image_name": "s3proxy", + "name": "s3proxy", + "oci": "//s3proxy/cmd:s3proxy", + "repotag_file": "//bazel/release:s3proxy_tag.txt", + "used_by": ["config"], + }, ] def helm_containers(): diff --git a/bazel/toolchains/go_module_deps.bzl b/bazel/toolchains/go_module_deps.bzl index f7ef9f3545..4bb3bc8bba 100644 --- a/bazel/toolchains/go_module_deps.bzl +++ b/bazel/toolchains/go_module_deps.bzl @@ -3050,6 +3050,7 @@ def go_dependencies(): sum = "h1:YjkZLJ7K3inKgMZ0wzCU9OHqc+UqMQyXsPXnf3Cl2as=", version = "v1.9.2", ) + go_repository( name = "com_github_hexops_gotextdiff", build_file_generation = "on", @@ -5033,6 +5034,15 @@ def go_dependencies(): sum = "h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=", version = "v1.2.0", ) + go_repository( + name = "com_github_tink_crypto_tink_go_v2", + build_file_generation = "on", + build_file_proto_mode = "disable_global", + importpath = "github.com/tink-crypto/tink-go/v2", + replace = "github.com/derpsteb/tink-go/v2", + sum = "h1:FVii9oXvddz9sFir5TRYjQKrzJLbVD/hibT+SnRSDzg=", + version = "v2.0.0-20231002051717-a808e454eed6", + ) go_repository( name = "com_github_titanous_rocacheck", diff --git a/go.mod b/go.mod index a4c06a41f5..295ee217bb 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ replace ( github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api => ./operators/constellation-node-operator/api github.com/google/go-tpm => github.com/thomasten/go-tpm v0.0.0-20230629092004-f43f8e2a59eb github.com/martinjungblut/go-cryptsetup => github.com/daniel-weisse/go-cryptsetup v0.0.0-20230705150314-d8c07bd1723c + github.com/tink-crypto/tink-go/v2 v2.0.0 => github.com/derpsteb/tink-go/v2 
v2.0.0-20231002051717-a808e454eed6 ) require ( @@ -108,6 +109,7 @@ require ( github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 github.com/theupdateframework/go-tuf v0.5.2 + github.com/tink-crypto/tink-go/v2 v2.0.0 go.uber.org/goleak v1.2.1 go.uber.org/zap v1.26.0 golang.org/x/crypto v0.13.0 diff --git a/go.sum b/go.sum index da0467ccde..632e19e438 100644 --- a/go.sum +++ b/go.sum @@ -298,6 +298,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/derpsteb/tink-go/v2 v2.0.0-20231002051717-a808e454eed6 h1:FVii9oXvddz9sFir5TRYjQKrzJLbVD/hibT+SnRSDzg= +github.com/derpsteb/tink-go/v2 v2.0.0-20231002051717-a808e454eed6/go.mod h1:QAbyq9LZncomYnScxlfaHImbV4ieNIe6bnu/Xcqqox4= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U= diff --git a/s3proxy/cmd/BUILD.bazel b/s3proxy/cmd/BUILD.bazel new file mode 100644 index 0000000000..a601597b6d --- /dev/null +++ b/s3proxy/cmd/BUILD.bazel @@ -0,0 +1,47 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_cross_binary", "go_library") +load("@rules_oci//oci:defs.bzl", "oci_image") +load("@rules_pkg//:pkg.bzl", "pkg_tar") + +go_library( + name = "cmd_lib", + srcs = ["main.go"], + importpath = "github.com/edgelesssys/constellation/v2/s3proxy/cmd", + visibility = ["//visibility:private"], + deps = [ + "//internal/logger", + "//s3proxy/internal/router", + "@org_uber_go_zap//:zap", + ], +) + +go_binary( + name = "cmd", + embed = [":cmd_lib"], + visibility = 
["//visibility:public"], +) + +go_cross_binary( + name = "s3proxy_linux_amd64", + platform = "@io_bazel_rules_go//go/toolchain:linux_amd64", + target = ":cmd", + visibility = ["//visibility:public"], +) + +pkg_tar( + name = "layer", + srcs = [ + ":s3proxy_linux_amd64", + ], + mode = "0755", + remap_paths = {"/s3proxy_linux_amd64": "/s3proxy"}, +) + +oci_image( + name = "s3proxy", + base = "@distroless_static_linux_amd64", + entrypoint = ["/s3proxy"], + tars = [ + ":layer", + ], + visibility = ["//visibility:public"], +) diff --git a/s3proxy/cmd/main.go b/s3proxy/cmd/main.go new file mode 100644 index 0000000000..5b471d8afa --- /dev/null +++ b/s3proxy/cmd/main.go @@ -0,0 +1,133 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +/* +Package main parses command line flags and starts the s3proxy server. +*/ +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "net" + "net/http" + + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/s3proxy/internal/router" + "go.uber.org/zap" +) + +const ( + // defaultPort is the default port to listen on. + defaultPort = 4433 + // defaultIP is the default IP to listen on. + defaultIP = "0.0.0.0" + // defaultRegion is the default AWS region to use. + defaultRegion = "eu-west-1" + // defaultCertLocation is the default location of the TLS certificate. + defaultCertLocation = "/etc/s3proxy/certs" + // defaultLogLevel is the default log level. + defaultLogLevel = 0 +) + +func main() { + flags, err := parseFlags() + if err != nil { + panic(err) + } + + // logLevel can be made a public variable so logging level can be changed dynamically. + // TODO (derpsteb): enable once we are on go 1.21. 
+ // logLevel := new(slog.LevelVar) + // handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel}) + // logger := slog.New(handler) + // logLevel.Set(flags.logLevel) + + logger := logger.New(logger.JSONLog, logger.VerbosityFromInt(flags.logLevel)) + + if err := runServer(flags, logger); err != nil { + panic(err) + } +} + +func runServer(flags cmdFlags, log *logger.Logger) error { + log.With(zap.String("ip", flags.ip), zap.Int("port", defaultPort), zap.String("region", flags.region)).Infof("listening") + + router, err := router.New(flags.region, flags.kmsEndpoint, log) + if err != nil { + return fmt.Errorf("creating router: %w", err) + } + + server := http.Server{ + Addr: fmt.Sprintf("%s:%d", flags.ip, defaultPort), + Handler: http.HandlerFunc(router.Serve), + // Disable HTTP/2. Serving HTTP/2 will cause some clients to use HTTP/2. + // It seems like AWS S3 does not support HTTP/2. + // Having HTTP/2 enabled will at least cause the aws-sdk-go V1 copy-object operation to fail. + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + } + + // i.e. if TLS is enabled. + if !flags.noTLS { + cert, err := tls.LoadX509KeyPair(flags.certLocation+"/s3proxy.crt", flags.certLocation+"/s3proxy.key") + if err != nil { + return fmt.Errorf("loading TLS certificate: %w", err) + } + + server.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + // TLSConfig is populated, so we can safely pass empty strings to ListenAndServeTLS. 
+ return server.ListenAndServeTLS("", "") + } + + log.Warnf("TLS is disabled") + return server.ListenAndServe() +} + +func parseFlags() (cmdFlags, error) { + noTLS := flag.Bool("no-tls", false, "disable TLS and listen on port 80, otherwise listen on 443") + ip := flag.String("ip", defaultIP, "ip to listen on") + region := flag.String("region", defaultRegion, "AWS region in which target bucket is located") + certLocation := flag.String("cert", defaultCertLocation, "location of TLS certificate") + kmsEndpoint := flag.String("kms", "key-service.kube-system:9000", "endpoint of the KMS service to get key encryption keys from") + level := flag.Int("level", defaultLogLevel, "log level") + + flag.Parse() + + netIP := net.ParseIP(*ip) + if netIP == nil { + return cmdFlags{}, fmt.Errorf("not a valid IPv4 address: %s", *ip) + } + + // TODO(derpsteb): enable once we are on go 1.21. + // logLevel := new(slog.Level) + // if err := logLevel.UnmarshalText([]byte(*level)); err != nil { + // return cmdFlags{}, fmt.Errorf("parsing log level: %w", err) + // } + + return cmdFlags{ + noTLS: *noTLS, + ip: netIP.String(), + region: *region, + certLocation: *certLocation, + kmsEndpoint: *kmsEndpoint, + logLevel: *level, + }, nil +} + +type cmdFlags struct { + noTLS bool + ip string + region string + certLocation string + kmsEndpoint string + // TODO(derpsteb): enable once we are on go 1.21. + // logLevel slog.Level + logLevel int +} diff --git a/s3proxy/deploy/README.md b/s3proxy/deploy/README.md new file mode 100644 index 0000000000..e95d802061 --- /dev/null +++ b/s3proxy/deploy/README.md @@ -0,0 +1,63 @@ +# Deploying s3proxy + +**Caution:** Using s3proxy outside Constellation is insecure as the connection between the key management service (KMS) and s3proxy is protected by Constellation's WireGuard VPN. +The VPN is a feature of Constellation and will not be present by default in other environments. + +Disclaimer: the following steps will be automated next. 
+- Within `constellation/build`: `bazel run //:devbuild`
+- Copy the container name displayed for the s3proxy image. Look for the line starting with `[@//bazel/release:s3proxy_push]`.
+- Replace the image key in `deployment-s3proxy.yaml` with the image value you just copied. Use the sha256 hash instead of the tag to make sure you use the latest image.
+- Replace the `replaceme` values with valid AWS credentials. The s3proxy uses those credentials to access S3.
+- Run `kubectl apply -f deployment-s3proxy.yaml`
+
+# Deploying Filestash
+
+Filestash is a demo application that can be used to see s3proxy in action.
+To deploy Filestash, first deploy s3proxy as described above.
+Then run the below commands:
+
+```sh
+$ cat << EOF > "deployment-filestash.yaml"
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: filestash
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: filestash
+  template:
+    metadata:
+      labels:
+        app: filestash
+    spec:
+      imagePullSecrets:
+      - name: regcred
+      hostAliases:
+      - ip: $(kubectl get svc s3proxy-service -o=jsonpath='{.spec.clusterIP}')
+        hostnames:
+        - "s3.eu-west-1.amazonaws.com"
+      containers:
+      - name: filestash
+        image: machines/filestash:latest
+        ports:
+        - containerPort: 8334
+        volumeMounts:
+        - name: ca-cert
+          mountPath: /etc/ssl/certs/kube-ca.crt
+          subPath: kube-ca.crt
+      volumes:
+      - name: ca-cert
+        secret:
+          secretName: s3proxy-tls
+          items:
+          - key: ca.crt
+            path: kube-ca.crt
+EOF
+
+$ kubectl apply -f deployment-filestash.yaml
+```
+
+Afterwards you can use a port forward to access the Filestash pod:
+- `kubectl port-forward pod/$(kubectl get pod --selector='app=filestash' -o=jsonpath='{.items[*].metadata.name}') 8334:8334`
diff --git a/s3proxy/deploy/deployment-s3proxy.yaml b/s3proxy/deploy/deployment-s3proxy.yaml
new file mode 100644
index 0000000000..441770effb
--- /dev/null
+++ b/s3proxy/deploy/deployment-s3proxy.yaml
@@ -0,0 +1,94 @@
+apiVersion: cert-manager.io/v1
+kind: Issuer
+metadata:
+  name: selfsigned-issuer
+
 labels:
+    app: s3proxy
+spec:
+  selfSigned: {}
+---
+apiVersion: cert-manager.io/v1
+kind: Certificate
+metadata:
+  name: selfsigned-ca
+  labels:
+    app: s3proxy
+spec:
+  isCA: true
+  commonName: s3proxy-selfsigned-ca
+  secretName: s3proxy-tls
+  privateKey:
+    algorithm: ECDSA
+    size: 256
+  dnsNames:
+    - "s3.eu-west-1.amazonaws.com"
+  issuerRef:
+    name: selfsigned-issuer
+    kind: Issuer
+    group: cert-manager.io
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: s3proxy
+  labels:
+    app: s3proxy
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: s3proxy
+  template:
+    metadata:
+      labels:
+        app: s3proxy
+    spec:
+      imagePullSecrets:
+      - name: regcred
+      containers:
+      - name: s3proxy
+        image: ghcr.io/edgelesssys/constellation/s3proxy@sha256:2394a804e8b5ff487a55199dd83138885322a4de8e71ac7ce67b79d4ffc842b2
+        args:
+        - "--level=-1"
+        ports:
+        - containerPort: 4433
+          name: s3proxy-port
+        volumeMounts:
+        - name: tls-cert-data
+          mountPath: /etc/s3proxy/certs/s3proxy.crt
+          subPath: tls.crt
+        - name: tls-cert-data
+          mountPath: /etc/s3proxy/certs/s3proxy.key
+          subPath: tls.key
+        envFrom:
+        - secretRef:
+            name: s3-creds
+      volumes:
+      - name: tls-cert-data
+        secret:
+          secretName: s3proxy-tls
+      - name: s3-creds
+        secret:
+          secretName: s3-creds
+---
+apiVersion: v1
+kind: Service
+metadata:
+  name: s3proxy-service
+spec:
+  selector:
+    app: s3proxy
+  ports:
+  - name: https
+    port: 443
+    targetPort: s3proxy-port
+  type: ClusterIP
+---
+apiVersion: v1
+kind: Secret
+metadata:
+  name: s3-creds
+type: Opaque
+stringData:
+  AWS_ACCESS_KEY_ID: "replaceme"
+  AWS_SECRET_ACCESS_KEY: "replaceme"
diff --git a/s3proxy/internal/crypto/BUILD.bazel b/s3proxy/internal/crypto/BUILD.bazel
new file mode 100644
index 0000000000..cf29bfe783
--- /dev/null
+++ b/s3proxy/internal/crypto/BUILD.bazel
@@ -0,0 +1,24 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//bazel/go:go_test.bzl", "go_test")
+
+go_library(
+    name = "crypto",
+    srcs = ["crypto.go"],
+    importpath = 
"github.com/edgelesssys/constellation/v2/s3proxy/internal/crypto",
+    visibility = ["//s3proxy:__subpackages__"],
+    deps = [
+        "@com_github_tink_crypto_tink_go_v2//aead/subtle",
+        "@com_github_tink_crypto_tink_go_v2//kwp/subtle",
+        "@com_github_tink_crypto_tink_go_v2//subtle/random",
+    ],
+)
+
+go_test(
+    name = "crypto_test",
+    srcs = ["crypto_test.go"],
+    embed = [":crypto"],
+    deps = [
+        "@com_github_stretchr_testify//assert",
+        "@com_github_stretchr_testify//require",
+    ],
+)
diff --git a/s3proxy/internal/crypto/crypto.go b/s3proxy/internal/crypto/crypto.go
new file mode 100644
index 0000000000..bdc117a7bc
--- /dev/null
+++ b/s3proxy/internal/crypto/crypto.go
@@ -0,0 +1,73 @@
+/*
+Copyright (c) Edgeless Systems GmbH
+
+SPDX-License-Identifier: AGPL-3.0-only
+*/
+
+/*
+Package crypto provides encryption and decryption functions for the s3proxy.
+It uses AES-256-GCM-SIV to encrypt and decrypt data.
+*/
+package crypto
+
+import (
+	"fmt"
+
+	aeadsubtle "github.com/tink-crypto/tink-go/v2/aead/subtle"
+	kwpsubtle "github.com/tink-crypto/tink-go/v2/kwp/subtle"
+	"github.com/tink-crypto/tink-go/v2/subtle/random"
+)
+
+// Encrypt generates a random key to encrypt a plaintext using AES-256-GCM-SIV.
+// The generated key is encrypted using the supplied key encryption key (KEK).
+// The ciphertext and encrypted data encryption key (DEK) are returned.
+func Encrypt(plaintext []byte, kek [32]byte) (ciphertext []byte, encryptedDEK []byte, err error) {
+	dek := random.GetRandomBytes(32)
+	aesgcm, err := aeadsubtle.NewAESGCMSIV(dek)
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting aesgcm: %w", err)
+	}
+
+	ciphertext, err = aesgcm.Encrypt(plaintext, []byte(""))
+	if err != nil {
+		return nil, nil, fmt.Errorf("encrypting plaintext: %w", err)
+	}
+
+	keywrapper, err := kwpsubtle.NewKWP(kek[:])
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting kwp: %w", err)
+	}
+
+	encryptedDEK, err = keywrapper.Wrap(dek)
+	if err != nil {
+		return nil, nil, fmt.Errorf("wrapping dek: %w", err)
+	}
+
+	return ciphertext, encryptedDEK, nil
+}
+
+// Decrypt decrypts a ciphertext using AES-256-GCM-SIV.
+// The encrypted DEK is decrypted using the supplied KEK.
+func Decrypt(ciphertext, encryptedDEK []byte, kek [32]byte) ([]byte, error) {
+	keywrapper, err := kwpsubtle.NewKWP(kek[:])
+	if err != nil {
+		return nil, fmt.Errorf("getting kwp: %w", err)
+	}
+
+	dek, err := keywrapper.Unwrap(encryptedDEK)
+	if err != nil {
+		return nil, fmt.Errorf("unwrapping dek: %w", err)
+	}
+
+	aesgcm, err := aeadsubtle.NewAESGCMSIV(dek)
+	if err != nil {
+		return nil, fmt.Errorf("getting aesgcm: %w", err)
+	}
+
+	plaintext, err := aesgcm.Decrypt(ciphertext, []byte(""))
+	if err != nil {
+		return nil, fmt.Errorf("decrypting ciphertext: %w", err)
+	}
+
+	return plaintext, nil
+}
diff --git a/s3proxy/internal/crypto/crypto_test.go b/s3proxy/internal/crypto/crypto_test.go
new file mode 100644
index 0000000000..4fb17e87db
--- /dev/null
+++ b/s3proxy/internal/crypto/crypto_test.go
@@ -0,0 +1,48 @@
+/*
+Copyright (c) Edgeless Systems GmbH
+
+SPDX-License-Identifier: AGPL-3.0-only
+*/
+package crypto
+
+import (
+	"crypto/rand"
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestEncryptDecrypt(t *testing.T) {
+	tests := map[string]struct {
+		plaintext []byte
+	}{
+		"simple": {
+			plaintext: 
[]byte("hello, world"), + }, + "long": { + plaintext: []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed non risus. Suspendisse lectus tortor, dignissim sit amet, adipiscing nec, ultricies sed, dolor."), + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + kek := [32]byte{} + _, err := rand.Read(kek[:]) + require.NoError(t, err) + + ciphertext, encryptedDEK, err := Encrypt(tt.plaintext, kek) + require.NoError(t, err) + + assert.NotContains(t, ciphertext, tt.plaintext) + + // Decrypt the ciphertext using the KEK and encrypted DEK + decrypted, err := Decrypt(ciphertext, encryptedDEK, kek) + require.NoError(t, err) + + // Verify that the decrypted plaintext matches the original plaintext + assert.Equal(t, tt.plaintext, decrypted, fmt.Sprintf("expected plaintext %s, got %s", tt.plaintext, decrypted)) + }) + } +} diff --git a/s3proxy/internal/kms/BUILD.bazel b/s3proxy/internal/kms/BUILD.bazel new file mode 100644 index 0000000000..e4d4d25b82 --- /dev/null +++ b/s3proxy/internal/kms/BUILD.bazel @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//bazel/go:go_test.bzl", "go_test") + +go_library( + name = "kms", + srcs = ["kms.go"], + importpath = "github.com/edgelesssys/constellation/v2/s3proxy/internal/kms", + visibility = ["//s3proxy:__subpackages__"], + deps = [ + "//internal/logger", + "//keyservice/keyserviceproto", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//credentials/insecure", + ], +) + +go_test( + name = "kms_test", + srcs = ["kms_test.go"], + embed = [":kms"], + deps = [ + "//internal/logger", + "//keyservice/keyserviceproto", + "@com_github_stretchr_testify//assert", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//test/bufconn", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/s3proxy/internal/kms/kms.go b/s3proxy/internal/kms/kms.go new file mode 100644 index 0000000000..24e53ed5c8 --- /dev/null +++ 
b/s3proxy/internal/kms/kms.go @@ -0,0 +1,76 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +/* +Package kms is used to interact with the Constellation keyservice. +So far it is a copy of the joinservice's kms package. +*/ +package kms + +import ( + "context" + "fmt" + + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/keyservice/keyserviceproto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// Client interacts with Constellation's keyservice. +type Client struct { + log *logger.Logger + endpoint string + grpc grpcClient +} + +// New creates a new KMS. +func New(log *logger.Logger, endpoint string) Client { + return Client{ + log: log, + endpoint: endpoint, + grpc: client{}, + } +} + +// GetDataKey returns a data encryption key for the given UUID. +func (c Client) GetDataKey(ctx context.Context, keyID string, length int) ([]byte, error) { + log := c.log.With("keyID", keyID, "endpoint", c.endpoint) + // the KMS does not use aTLS since traffic is only routed through the Constellation cluster + // cluster internal connections are considered trustworthy + log.Infof("Connecting to KMS") + conn, err := grpc.DialContext(ctx, c.endpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + + log.Infof("Requesting data key") + res, err := c.grpc.GetDataKey( + ctx, + &keyserviceproto.GetDataKeyRequest{ + DataKeyId: keyID, + Length: uint32(length), + }, + conn, + ) + if err != nil { + return nil, fmt.Errorf("fetching data encryption key from Constellation KMS: %w", err) + } + + log.Infof("Data key request successful") + return res.DataKey, nil +} + +type grpcClient interface { + GetDataKey(context.Context, *keyserviceproto.GetDataKeyRequest, *grpc.ClientConn) (*keyserviceproto.GetDataKeyResponse, error) +} + +type client struct{} + +func (c client) GetDataKey(ctx context.Context, 
req *keyserviceproto.GetDataKeyRequest, conn *grpc.ClientConn) (*keyserviceproto.GetDataKeyResponse, error) { + return keyserviceproto.NewAPIClient(conn).GetDataKey(ctx, req) +} diff --git a/s3proxy/internal/kms/kms_test.go b/s3proxy/internal/kms/kms_test.go new file mode 100644 index 0000000000..e91fb34d6a --- /dev/null +++ b/s3proxy/internal/kms/kms_test.go @@ -0,0 +1,72 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package kms + +import ( + "context" + "errors" + "testing" + + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/keyservice/keyserviceproto" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +type stubClient struct { + getDataKeyErr error + dataKey []byte +} + +func (c *stubClient) GetDataKey(context.Context, *keyserviceproto.GetDataKeyRequest, *grpc.ClientConn) (*keyserviceproto.GetDataKeyResponse, error) { + return &keyserviceproto.GetDataKeyResponse{DataKey: c.dataKey}, c.getDataKeyErr +} + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestGetDataKey(t *testing.T) { + testCases := map[string]struct { + client *stubClient + wantErr bool + }{ + "GetDataKey success": { + client: &stubClient{dataKey: []byte{0x1, 0x2, 0x3}}, + }, + "GetDataKey error": { + client: &stubClient{getDataKeyErr: errors.New("error")}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + listener := bufconn.Listen(1) + defer listener.Close() + + client := New( + logger.NewTest(t), + listener.Addr().String(), + ) + + client.grpc = tc.client + + res, err := client.GetDataKey(context.Background(), "disk-uuid", 32) + if tc.wantErr { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal(tc.client.dataKey, res) + } + }) + } +} diff --git a/s3proxy/internal/router/BUILD.bazel 
b/s3proxy/internal/router/BUILD.bazel new file mode 100644 index 0000000000..c60568bcea --- /dev/null +++ b/s3proxy/internal/router/BUILD.bazel @@ -0,0 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//bazel/go:go_test.bzl", "go_test") + +go_library( + name = "router", + srcs = [ + "object.go", + "router.go", + ], + importpath = "github.com/edgelesssys/constellation/v2/s3proxy/internal/router", + visibility = ["//s3proxy:__subpackages__"], + deps = [ + "//internal/logger", + "//s3proxy/internal/crypto", + "//s3proxy/internal/kms", + "//s3proxy/internal/s3", + "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "router_test", + srcs = ["router_test.go"], + embed = [":router"], + deps = ["@com_github_stretchr_testify//assert"], +) diff --git a/s3proxy/internal/router/object.go b/s3proxy/internal/router/object.go new file mode 100644 index 0000000000..0f58a2900d --- /dev/null +++ b/s3proxy/internal/router/object.go @@ -0,0 +1,220 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package router + +import ( + "context" + "encoding/hex" + "io" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/s3proxy/internal/crypto" + "go.uber.org/zap" +) + +const ( + // dekTag is the name of the header that holds the encrypted data encryption key for the attached object. Presence of the key implies the object needs to be decrypted. + // Use lowercase only, as AWS automatically lowercases all metadata keys. + dekTag = "constellation-dek" +) + +// object bundles data to implement http.Handler methods that use data from incoming requests. 
+type object struct {
+	kek                       [32]byte
+	client                    s3Client
+	key                       string
+	bucket                    string
+	data                      []byte
+	query                     url.Values
+	tags                      string
+	contentType               string
+	metadata                  map[string]string
+	objectLockLegalHoldStatus string
+	objectLockMode            string
+	objectLockRetainUntilDate time.Time
+	sseCustomerAlgorithm      string
+	sseCustomerKey            string
+	sseCustomerKeyMD5         string
+	log                       *logger.Logger
+}
+
+// get is a http.HandlerFunc that implements the GET method for objects.
+func (o object) get(w http.ResponseWriter, r *http.Request) {
+	o.log.With(zap.String("key", o.key), zap.String("host", o.bucket)).Debugf("getObject")
+
+	versionID, ok := o.query["versionId"]
+	if !ok {
+		versionID = []string{""}
+	}
+
+	output, err := o.client.GetObject(r.Context(), o.bucket, o.key, versionID[0], o.sseCustomerAlgorithm, o.sseCustomerKey, o.sseCustomerKeyMD5)
+	if err != nil {
+		// Logged at error level for visibility, even though the failure can be expected behavior (e.g. object not found).
+		o.log.With(zap.Error(err)).Errorf("GetObject sending request to S3")
+
+		// We want to forward error codes from the s3 API to clients as much as possible.
+ code := parseErrorCode(err) + if code != 0 { + http.Error(w, err.Error(), code) + return + } + + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if output.ETag != nil { + w.Header().Set("ETag", strings.Trim(*output.ETag, "\"")) + } + if output.Expiration != nil { + w.Header().Set("x-amz-expiration", *output.Expiration) + } + if output.ChecksumCRC32 != nil { + w.Header().Set("x-amz-checksum-crc32", *output.ChecksumCRC32) + } + if output.ChecksumCRC32C != nil { + w.Header().Set("x-amz-checksum-crc32c", *output.ChecksumCRC32C) + } + if output.ChecksumSHA1 != nil { + w.Header().Set("x-amz-checksum-sha1", *output.ChecksumSHA1) + } + if output.ChecksumSHA256 != nil { + w.Header().Set("x-amz-checksum-sha256", *output.ChecksumSHA256) + } + if output.SSECustomerAlgorithm != nil { + w.Header().Set("x-amz-server-side-encryption-customer-algorithm", *output.SSECustomerAlgorithm) + } + if output.SSECustomerKeyMD5 != nil { + w.Header().Set("x-amz-server-side-encryption-customer-key-MD5", *output.SSECustomerKeyMD5) + } + if output.SSEKMSKeyId != nil { + w.Header().Set("x-amz-server-side-encryption-aws-kms-key-id", *output.SSEKMSKeyId) + } + if output.ServerSideEncryption != "" { + w.Header().Set("x-amz-server-side-encryption-context", string(output.ServerSideEncryption)) + } + + body, err := io.ReadAll(output.Body) + if err != nil { + o.log.With(zap.Error(err)).Errorf("GetObject reading S3 response") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + plaintext := body + rawEncryptedDEK, ok := output.Metadata[dekTag] + if ok { + encryptedDEK, err := hex.DecodeString(rawEncryptedDEK) + if err != nil { + o.log.Errorf("GetObject decoding DEK", "error", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + plaintext, err = crypto.Decrypt(body, encryptedDEK, o.kek) + if err != nil { + o.log.With(zap.Error(err)).Errorf("GetObject decrypting response") + http.Error(w, err.Error(), 
http.StatusInternalServerError) + return + } + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write(plaintext); err != nil { + o.log.With(zap.Error(err)).Errorf("GetObject sending response") + } +} + +// put is a http.HandlerFunc that implements the PUT method for objects. +func (o object) put(w http.ResponseWriter, r *http.Request) { + ciphertext, encryptedDEK, err := crypto.Encrypt(o.data, o.kek) + if err != nil { + o.log.With(zap.Error(err)).Errorf("PutObject") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + o.metadata[dekTag] = hex.EncodeToString(encryptedDEK) + + output, err := o.client.PutObject(r.Context(), o.bucket, o.key, o.tags, o.contentType, o.objectLockLegalHoldStatus, o.objectLockMode, o.sseCustomerAlgorithm, o.sseCustomerKey, o.sseCustomerKeyMD5, o.objectLockRetainUntilDate, o.metadata, ciphertext) + if err != nil { + o.log.With(zap.Error(err)).Errorf("PutObject sending request to S3") + + // We want to forward error codes from the s3 API to clients whenever possible. 
+ code := parseErrorCode(err) + if code != 0 { + http.Error(w, err.Error(), code) + return + } + + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("x-amz-server-side-encryption", string(output.ServerSideEncryption)) + + if output.VersionId != nil { + w.Header().Set("x-amz-version-id", *output.VersionId) + } + if output.ETag != nil { + w.Header().Set("ETag", strings.Trim(*output.ETag, "\"")) + } + if output.Expiration != nil { + w.Header().Set("x-amz-expiration", *output.Expiration) + } + if output.ChecksumCRC32 != nil { + w.Header().Set("x-amz-checksum-crc32", *output.ChecksumCRC32) + } + if output.ChecksumCRC32C != nil { + w.Header().Set("x-amz-checksum-crc32c", *output.ChecksumCRC32C) + } + if output.ChecksumSHA1 != nil { + w.Header().Set("x-amz-checksum-sha1", *output.ChecksumSHA1) + } + if output.ChecksumSHA256 != nil { + w.Header().Set("x-amz-checksum-sha256", *output.ChecksumSHA256) + } + if output.SSECustomerAlgorithm != nil { + w.Header().Set("x-amz-server-side-encryption-customer-algorithm", *output.SSECustomerAlgorithm) + } + if output.SSECustomerKeyMD5 != nil { + w.Header().Set("x-amz-server-side-encryption-customer-key-MD5", *output.SSECustomerKeyMD5) + } + if output.SSEKMSKeyId != nil { + w.Header().Set("x-amz-server-side-encryption-aws-kms-key-id", *output.SSEKMSKeyId) + } + if output.SSEKMSEncryptionContext != nil { + w.Header().Set("x-amz-server-side-encryption-context", *output.SSEKMSEncryptionContext) + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write(nil); err != nil { + o.log.With(zap.Error(err)).Errorf("PutObject sending response") + } +} + +func parseErrorCode(err error) int { + regex := regexp.MustCompile(`https response error StatusCode: (\d+)`) + matches := regex.FindStringSubmatch(err.Error()) + if len(matches) > 1 { + code, _ := strconv.Atoi(matches[1]) + return code + } + + return 0 +} + +type s3Client interface { + GetObject(ctx context.Context, bucket, key, versionID, 
sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string) (*s3.GetObjectOutput, error) + PutObject(ctx context.Context, bucket, key, tags, contentType, objectLockLegalHoldStatus, objectLockMode, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string, objectLockRetainUntilDate time.Time, metadata map[string]string, body []byte) (*s3.PutObjectOutput, error) +} diff --git a/s3proxy/internal/router/router.go b/s3proxy/internal/router/router.go new file mode 100644 index 0000000000..bd9b844276 --- /dev/null +++ b/s3proxy/internal/router/router.go @@ -0,0 +1,432 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +/* +Package router implements the main interception logic of s3proxy. +It decides which packages to forward and which to intercept. + +The routing logic in this file is taken from this blog post: https://benhoyt.com/writings/go-routing/#regex-switch. +We should be able to replace this once this is part of the stdlib: https://github.com/golang/go/issues/61410. + +If the router intercepts a PutObject request it will encrypt the body before forwarding it to the S3 API. +The stored object will have a tag that holds an encrypted data encryption key (DEK). +That DEK is used to encrypt the object's body. +The DEK is generated randomly for each PutObject request. +The DEK is encrypted with a key encryption key (KEK) fetched from Constellation's keyservice. +*/ +package router + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/s3proxy/internal/kms" + "github.com/edgelesssys/constellation/v2/s3proxy/internal/s3" + "go.uber.org/zap" +) + +const ( + // Use a 32*8 = 256 bit key for AES-256. 
+ kekSizeBytes = 32 + kekID = "s3proxy-kek" +) + +var ( + keyPattern = regexp.MustCompile("/(.+)") + bucketAndKeyPattern = regexp.MustCompile("/([^/?]+)/(.+)") +) + +// Router implements the interception logic for the s3proxy. +type Router struct { + region string + kek [32]byte + log *logger.Logger +} + +// New creates a new Router. +func New(region, endpoint string, log *logger.Logger) (Router, error) { + kms := kms.New(log, endpoint) + + // Get the key encryption key that encrypts all DEKs. + kek, err := kms.GetDataKey(context.Background(), kekID, kekSizeBytes) + if err != nil { + return Router{}, fmt.Errorf("getting KEK: %w", err) + } + + kekArray, err := byteSliceToByteArray(kek) + if err != nil { + return Router{}, fmt.Errorf("converting KEK to byte array: %w", err) + } + + return Router{region: region, kek: kekArray, log: log}, nil +} + +// Serve implements the routing logic for the s3 proxy. +// It intercepts GetObject and PutObject requests, encrypting/decrypting their bodies if necessary. +// All other requests are forwarded to the S3 API. +// Ideally we could separate routing logic, request handling and s3 interactions. +// Currently routing logic and request handling are integrated. +func (r Router) Serve(w http.ResponseWriter, req *http.Request) { + client, err := s3.NewClient(r.region) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var key string + var bucket string + var matchingPath bool + if containsBucket(req.Host) { + // BUCKET.s3.REGION.amazonaws.com + parts := strings.Split(req.Host, ".") + bucket = parts[0] + + matchingPath = match(req.URL.Path, keyPattern, &key) + + } else { + matchingPath = match(req.URL.Path, bucketAndKeyPattern, &bucket, &key) + } + + var h http.Handler + switch { + // intercept GetObject. + case matchingPath && req.Method == "GET" && !isUnwantedGetEndpoint(req.URL.Query()): + h = handleGetObject(client, key, bucket, r.log) + // intercept PutObject. 
+ case matchingPath && req.Method == "PUT" && !isUnwantedPutEndpoint(req.Header, req.URL.Query()): + h = handlePutObject(client, key, bucket, r.log) + // Forward all other requests. + default: + h = handleForwards(r.log) + } + + h.ServeHTTP(w, req) +} + +// ContentSHA256MismatchError is a helper struct to create an XML formatted error message. +// s3 clients might try to parse error messages, so we need to serve correctly formatted messages. +type ContentSHA256MismatchError struct { + XMLName xml.Name `xml:"Error"` + Code string `xml:"Code"` + Message string `xml:"Message"` + ClientComputedContentSHA256 string `xml:"ClientComputedContentSHA256"` + S3ComputedContentSHA256 string `xml:"S3ComputedContentSHA256"` +} + +// NewContentSHA256MismatchError creates a new ContentSHA256MismatchError. +func NewContentSHA256MismatchError(clientComputedContentSHA256, s3ComputedContentSHA256 string) ContentSHA256MismatchError { + return ContentSHA256MismatchError{ + Code: "XAmzContentSHA256Mismatch", + Message: "The provided 'x-amz-content-sha256' header does not match what was computed.", + ClientComputedContentSHA256: clientComputedContentSHA256, + S3ComputedContentSHA256: s3ComputedContentSHA256, + } +} + +func handleGetObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") + if req.Header.Get("Range") != "" { + log.Errorf("GetObject Range header unsupported") + http.Error(w, "s3proxy currently does not support Range headers", http.StatusNotImplemented) + return + } + + obj := object{ + client: client, + key: key, + bucket: bucket, + query: req.URL.Query(), + sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), + sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), + sseCustomerKeyMD5: 
req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), + log: log, + } + get(obj.get)(w, req) + } +} + +func handlePutObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") + body, err := io.ReadAll(req.Body) + if err != nil { + log.With(zap.Error(err)).Errorf("PutObject") + http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) + return + } + + clientDigest := req.Header.Get("x-amz-content-sha256") + serverDigest := sha256sum(body) + + // There may be a client that wants to test that incorrect content digests result in API errors. + // For encrypting the body we have to recalculate the content digest. + // If the client intentionally sends a mismatching content digest, we would take the client request, rewrap it, + // calculate the correct digest for the new body and NOT get an error. + // Thus we have to check incoming requets for matching content digests. + // UNSIGNED-PAYLOAD can be used to disabled payload signing. In that case we don't check the content digest. + if clientDigest != "" && clientDigest != "UNSIGNED-PAYLOAD" && clientDigest != serverDigest { + log.Debugf("PutObject", "error", "x-amz-content-sha256 mismatch") + // The S3 API responds with an XML formatted error message. 
+ mismatchErr := NewContentSHA256MismatchError(clientDigest, serverDigest) + marshalled, err := xml.Marshal(mismatchErr) + if err != nil { + log.With(zap.Error(err)).Errorf("PutObject") + http.Error(w, fmt.Sprintf("marshalling error: %s", err.Error()), http.StatusInternalServerError) + return + } + + http.Error(w, string(marshalled), http.StatusBadRequest) + return + } + + metadata := getMetadataHeaders(req.Header) + + raw := req.Header.Get("x-amz-object-lock-retain-until-date") + retentionTime, err := parseRetentionTime(raw) + if err != nil { + log.With(zap.String("data", raw), zap.Error(err)).Errorf("parsing lock retention time") + http.Error(w, fmt.Sprintf("parsing x-amz-object-lock-retain-until-date: %s", err.Error()), http.StatusInternalServerError) + return + } + + err = validateContentMD5(req.Header.Get("content-md5"), body) + if err != nil { + log.With(zap.Error(err)).Errorf("validating content md5") + http.Error(w, fmt.Sprintf("validating content md5: %s", err.Error()), http.StatusBadRequest) + return + } + + obj := object{ + client: client, + key: key, + bucket: bucket, + data: body, + query: req.URL.Query(), + tags: req.Header.Get("x-amz-tagging"), + contentType: req.Header.Get("Content-Type"), + metadata: metadata, + objectLockLegalHoldStatus: req.Header.Get("x-amz-object-lock-legal-hold"), + objectLockMode: req.Header.Get("x-amz-object-lock-mode"), + objectLockRetainUntilDate: retentionTime, + sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), + sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), + sseCustomerKeyMD5: req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), + log: log, + } + + put(obj.put)(w, req) + } +} + +func handleForwards(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("forwarding") + + newReq := 
repackage(req) + + httpClient := http.DefaultClient + resp, err := httpClient.Do(&newReq) + if err != nil { + log.With(zap.Error(err)).Errorf("do request") + http.Error(w, fmt.Sprintf("do request: %s", err.Error()), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + for key := range resp.Header { + w.Header().Set(key, resp.Header.Get(key)) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + log.With(zap.Error(err)).Errorf("ReadAll") + http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) + return + } + w.WriteHeader(resp.StatusCode) + if body == nil { + return + } + + if _, err := w.Write(body); err != nil { + log.With(zap.Error(err)).Errorf("Write") + http.Error(w, fmt.Sprintf("writing body: %s", err.Error()), http.StatusInternalServerError) + return + } + } +} + +// byteSliceToByteArray casts a byte slice to a byte array of length 32. +// It does a length check to prevent the cast from panic'ing. +func byteSliceToByteArray(input []byte) ([32]byte, error) { + if len(input) != 32 { + return [32]byte{}, fmt.Errorf("input length mismatch, got: %d", len(input)) + } + + return ([32]byte)(input), nil +} + +// containsBucket is a helper to recognizes cases where the bucket name is sent as part of the host. +// In other cases the bucket name is sent as part of the path. +func containsBucket(host string) bool { + parts := strings.Split(host, ".") + return len(parts) > 4 +} + +// isUnwantedGetEndpoint returns true if the request is any of these requests: GetObjectAcl, GetObjectAttributes, GetObjectLegalHold, GetObjectRetention, GetObjectTagging, GetObjectTorrent, ListParts. +// These requests are all structured similarly: they all have a query param that is not present in GetObject. +// Otherwise those endpoints are similar to GetObject. 
+func isUnwantedGetEndpoint(query url.Values) bool { + _, acl := query["acl"] + _, attributes := query["attributes"] + _, legalHold := query["legal-hold"] + _, retention := query["retention"] + _, tagging := query["tagging"] + _, torrent := query["torrent"] + _, uploadID := query["uploadId"] + + return acl || attributes || legalHold || retention || tagging || torrent || uploadID +} + +// isUnwantedPutEndpoint returns true if the request is any of these requests: UploadPart, PutObjectTagging. +// These requests are all structured similarly: they all have a query param that is not present in PutObject. +// Otherwise those endpoints are similar to PutObject. +func isUnwantedPutEndpoint(header http.Header, query url.Values) bool { + if header.Get("x-amz-copy-source") != "" { + return true + } + + _, partNumber := query["partNumber"] + _, uploadID := query["uploadId"] + _, tagging := query["tagging"] + _, legalHold := query["legal-hold"] + _, objectLock := query["object-lock"] + _, retention := query["retention"] + _, publicAccessBlock := query["publicAccessBlock"] + _, acl := query["acl"] + + return partNumber || uploadID || tagging || legalHold || objectLock || retention || publicAccessBlock || acl +} + +func sha256sum(data []byte) string { + digest := sha256.Sum256(data) + return fmt.Sprintf("%x", digest) +} + +// getMetadataHeaders parses user-defined metadata headers from a +// http.Header object. Users can define custom headers by taking +// HEADERNAME and prefixing it with "x-amz-meta-". 
+func getMetadataHeaders(header http.Header) map[string]string { + result := map[string]string{} + + for key := range header { + key = strings.ToLower(key) + + if strings.HasPrefix(key, "x-amz-meta-") { + name := strings.TrimPrefix(key, "x-amz-meta-") + result[name] = strings.Join(header.Values(key), ",") + } + } + + return result +} + +func parseRetentionTime(raw string) (time.Time, error) { + if raw == "" { + return time.Time{}, nil + } + return time.Parse(time.RFC3339, raw) +} + +// repackage implements all modifications we need to do to an incoming request that we want to forward to the s3 API. +func repackage(r *http.Request) http.Request { + req := r.Clone(r.Context()) + + // HTTP clients are not supposed to set this field, however when we receive a request it is set. + // So, we unset it. + req.RequestURI = "" + + req.URL.Host = r.Host + // We always want to use HTTPS when talking to S3. + req.URL.Scheme = "https" + + return *req +} + +// validateContentMD5 checks if the content-md5 header matches the body. +func validateContentMD5(contentMD5 string, body []byte) error { + if contentMD5 == "" { + return nil + } + + expected, err := base64.StdEncoding.DecodeString(contentMD5) + if err != nil { + return fmt.Errorf("decoding base64: %w", err) + } + + if len(expected) != 16 { + return fmt.Errorf("content-md5 must be 16 bytes long, got %d bytes", len(expected)) + } + + actual := md5.Sum(body) + + if !bytes.Equal(actual[:], expected) { + return fmt.Errorf("content-md5 mismatch, header is %x, body is %x", expected, actual) + } + + return nil +} + +// match reports whether path matches pattern, and if it matches, +// assigns any capture groups to the *string or *int vars. +func match(path string, pattern *regexp.Regexp, vars ...*string) bool { + matches := pattern.FindStringSubmatch(path) + if len(matches) <= 0 { + return false + } + + for i, match := range matches[1:] { + // assign the value of 'match' to the i-th argument. 
+ *vars[i] = match + } + return true +} + +// allowMethod takes a HandlerFunc and wraps it in a handler that only +// responds if the request method is the given method, otherwise it +// responds with HTTP 405 Method Not Allowed. +func allowMethod(h http.HandlerFunc, method string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if method != r.Method { + w.Header().Set("Allow", method) + http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed) + return + } + h(w, r) + } +} + +// get takes a HandlerFunc and wraps it to only allow the GET method. +func get(h http.HandlerFunc) http.HandlerFunc { + return allowMethod(h, "GET") +} + +// put takes a HandlerFunc and wraps it to only allow the POST method. +func put(h http.HandlerFunc) http.HandlerFunc { + return allowMethod(h, "PUT") +} diff --git a/s3proxy/internal/router/router_test.go b/s3proxy/internal/router/router_test.go new file mode 100644 index 0000000000..a690ce669a --- /dev/null +++ b/s3proxy/internal/router/router_test.go @@ -0,0 +1,87 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ +package router + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateContentMD5(t *testing.T) { + tests := map[string]struct { + contentMD5 string + body []byte + expectedErrMsg string + }{ + "empty content-md5": { + contentMD5: "", + body: []byte("hello, world"), + }, + // https://datatracker.ietf.org/doc/html/rfc1864#section-2 + "valid content-md5": { + contentMD5: "Q2hlY2sgSW50ZWdyaXR5IQ==", + body: []byte("Check Integrity!"), + }, + "invalid content-md5": { + contentMD5: "invalid base64", + body: []byte("hello, world"), + expectedErrMsg: "decoding base64", + }, + } + + // Iterate over the test cases + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + // Call the validateContentMD5 function + err := validateContentMD5(tc.contentMD5, tc.body) + + // Check the result against the expected value + 
if tc.expectedErrMsg != "" { + assert.ErrorContains(t, err, tc.expectedErrMsg) + } + }) + } +} + +func TestByteSliceToByteArray(t *testing.T) { + tests := map[string]struct { + input []byte + output [32]byte + wantErr bool + }{ + "empty input": { + input: []byte{}, + output: [32]byte{}, + }, + "successful input": { + input: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + output: [32]byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41}, + }, + "input too short": { + input: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + output: [32]byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41}, + wantErr: true, + }, + "input too long": { + input: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + output: [32]byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41}, + wantErr: true, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result, err := byteSliceToByteArray(tc.input) + if tc.wantErr { + assert.Error(t, err) + return + } + + assert.Equal(t, tc.output, result) + }) + } +} diff --git a/s3proxy/internal/s3/BUILD.bazel b/s3proxy/internal/s3/BUILD.bazel new file mode 100644 index 0000000000..e095631974 --- /dev/null +++ b/s3proxy/internal/s3/BUILD.bazel @@ -0,0 +1,13 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "s3", + srcs = ["s3.go"], + importpath = "github.com/edgelesssys/constellation/v2/s3proxy/internal/s3", + visibility = ["//s3proxy:__subpackages__"], + deps = [ + "@com_github_aws_aws_sdk_go_v2_config//:config", + "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", + 
"@com_github_aws_aws_sdk_go_v2_service_s3//types", + ], +) diff --git a/s3proxy/internal/s3/s3.go b/s3proxy/internal/s3/s3.go new file mode 100644 index 0000000000..462530be73 --- /dev/null +++ b/s3proxy/internal/s3/s3.go @@ -0,0 +1,116 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +/* +Package s3 implements a very thin wrapper around the AWS S3 client. +It only exists to enable stubbing of the AWS S3 client in tests. +*/ +package s3 + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/base64" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +// Client is a wrapper around the AWS S3 client. +type Client struct { + s3client *s3.Client +} + +// NewClient creates a new AWS S3 client. +func NewClient(region string) (*Client, error) { + // Use context.Background here because this context will not influence the later operations of the client. + // The context given here is used for http requests that are made during client construction. + // Client construction happens once during proxy setup. + clientCfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion(region), + ) + if err != nil { + return nil, fmt.Errorf("loading AWS S3 client config: %w", err) + } + + client := s3.NewFromConfig(clientCfg) + + return &Client{client}, nil +} + +// GetObject returns the object with the given key from the given bucket. +// If a versionID is given, the specific version of the object is returned. 
+func (c Client) GetObject(ctx context.Context, bucket, key, versionID, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string) (*s3.GetObjectOutput, error) { + getObjectInput := &s3.GetObjectInput{ + Bucket: &bucket, + Key: &key, + } + if versionID != "" { + getObjectInput.VersionId = &versionID + } + if sseCustomerAlgorithm != "" { + getObjectInput.SSECustomerAlgorithm = &sseCustomerAlgorithm + } + if sseCustomerKey != "" { + getObjectInput.SSECustomerKey = &sseCustomerKey + } + if sseCustomerKeyMD5 != "" { + getObjectInput.SSECustomerKeyMD5 = &sseCustomerKeyMD5 + } + + return c.s3client.GetObject(ctx, getObjectInput) +} + +// PutObject creates a new object in the given bucket with the given key and body. +// Various optional parameters can be set. +func (c Client) PutObject(ctx context.Context, bucket, key, tags, contentType, objectLockLegalHoldStatus, objectLockMode, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string, objectLockRetainUntilDate time.Time, metadata map[string]string, body []byte) (*s3.PutObjectOutput, error) { + // The AWS Go SDK has two versions. V1 does not set the Content-Type header. + // V2 always sets the Content-Type header. We use V2. + // The s3 API sets an object's content-type to binary/octet-stream if + // it receives a request without a Content-Type header set. + // Since a client using V1 may depend on the Content-Type binary/octet-stream + // we have to explicitly emulate the S3 API behavior, if we receive a request + // without a Content-Type. 
+ if contentType == "" { + contentType = "binary/octet-stream" + } + + contentMD5 := md5.Sum(body) + encodedContentMD5 := base64.StdEncoding.EncodeToString(contentMD5[:]) + + putObjectInput := &s3.PutObjectInput{ + Bucket: &bucket, + Key: &key, + Body: bytes.NewReader(body), + Tagging: &tags, + Metadata: metadata, + ContentMD5: &encodedContentMD5, + ContentType: &contentType, + ObjectLockLegalHoldStatus: types.ObjectLockLegalHoldStatus(objectLockLegalHoldStatus), + } + if sseCustomerAlgorithm != "" { + putObjectInput.SSECustomerAlgorithm = &sseCustomerAlgorithm + } + if sseCustomerKey != "" { + putObjectInput.SSECustomerKey = &sseCustomerKey + } + if sseCustomerKeyMD5 != "" { + putObjectInput.SSECustomerKeyMD5 = &sseCustomerKeyMD5 + } + + // It is not allowed to only set one of these two properties. + if objectLockMode != "" && !objectLockRetainUntilDate.IsZero() { + putObjectInput.ObjectLockMode = types.ObjectLockMode(objectLockMode) + putObjectInput.ObjectLockRetainUntilDate = &objectLockRetainUntilDate + } + + return c.s3client.PutObject(ctx, putObjectInput) +}