Skip to content

Commit

Permalink
s3proxy: use log/slog for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
derpsteb committed Sep 29, 2023
1 parent 6ef937b commit 8bbc9b6
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 52 deletions.
38 changes: 17 additions & 21 deletions s3proxy/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
"crypto/tls"
"flag"
"fmt"
"log/slog"
"net"
"net/http"
"os"

"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/s3proxy/internal/router"
)

Expand All @@ -30,7 +31,7 @@ const (
// defaultCertLocation is the default location of the TLS certificate.
defaultCertLocation = "/etc/s3proxy/certs"
// defaultLogLevel is the default log level.
defaultLogLevel = 0
defaultLogLevel = "info"
)

func main() {
Expand All @@ -40,21 +41,18 @@ func main() {
}

// logLevel can be made a public variable so logging level can be changed dynamically.
// TODO (derpsteb): enable once we are on go 1.21.
// logLevel := new(slog.LevelVar)
// handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel})
// logger := slog.New(handler)
// logLevel.Set(flags.logLevel)

logger := logger.New(logger.JSONLog, logger.VerbosityFromInt(flags.logLevel))
logLevel := new(slog.LevelVar)
handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel})
logger := slog.New(handler)
logLevel.Set(flags.logLevel)

if err := runServer(flags, logger); err != nil {
panic(err)
}
}

func runServer(flags cmdFlags, log *logger.Logger) error {
log.Infof("listening", "ip", flags.ip, "port", flags.port, "region", flags.region)
func runServer(flags cmdFlags, log *slog.Logger) error {
log.Info("listening", "ip", flags.ip, "port", flags.port, "region", flags.region)

router := router.New(flags.region, log)

Expand All @@ -81,7 +79,7 @@ func runServer(flags cmdFlags, log *logger.Logger) error {
return server.ListenAndServeTLS("", "")
}

log.Warnf("TLS is disabled")
log.Warn("TLS is disabled")
return server.ListenAndServe()
}

Expand All @@ -90,7 +88,7 @@ func parseFlags() (cmdFlags, error) {
ip := flag.String("ip", defaultIP, "ip to listen on")
region := flag.String("region", defaultRegion, "AWS region in which target bucket is located")
certLocation := flag.String("cert", defaultCertLocation, "location of TLS certificate")
level := flag.Int("level", defaultLogLevel, "log level")
level := flag.String("level", defaultLogLevel, "log level")

flag.Parse()

Expand All @@ -100,20 +98,18 @@ func parseFlags() (cmdFlags, error) {
}

// TODO(derpsteb): enable once we are on go 1.21.
// logLevel := new(slog.Level)
// if err := logLevel.UnmarshalText([]byte(*level)); err != nil {
// return cmdFlags{}, fmt.Errorf("parsing log level: %w", err)
// }
logLevel := new(slog.Level)
if err := logLevel.UnmarshalText([]byte(*level)); err != nil {
return cmdFlags{}, fmt.Errorf("parsing log level: %w", err)
}

return cmdFlags{port: *port, ip: netIP.String(), region: *region, certLocation: *certLocation, logLevel: *level}, nil
return cmdFlags{port: *port, ip: netIP.String(), region: *region, certLocation: *certLocation, logLevel: *logLevel}, nil
}

type cmdFlags struct {
port int
ip string
region string
certLocation string
// TODO(derpsteb): enable once we are on go 1.21.
// logLevel slog.Level
logLevel int
logLevel slog.Level
}
20 changes: 10 additions & 10 deletions s3proxy/internal/router/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package router
import (
"context"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
Expand All @@ -17,7 +18,6 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/s3proxy/internal/crypto"
)

Expand All @@ -42,12 +42,12 @@ type object struct {
objectLockLegalHoldStatus string
objectLockMode string
objectLockRetainUntilDate time.Time
log *logger.Logger
log *slog.Logger
}

// TODO(derpsteb): serve all headers present in s3.GetObjectOutput in s3 proxy response. currently we only serve those required to make minio/mint pass.
func (o object) get(w http.ResponseWriter, r *http.Request) {
o.log.Debugf("getObject", "key", o.key, "host", o.bucket)
o.log.Debug("getObject", "key", o.key, "host", o.bucket)

versionID, ok := o.query["versionId"]
if !ok {
Expand All @@ -57,7 +57,7 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {
data, err := o.client.GetObject(r.Context(), o.bucket, o.key, versionID[0])
if err != nil {
// log with Info as it might be expected behavior (e.g. object not found).
o.log.Errorf("GetObject sending request to S3", "error", err)
o.log.Error("GetObject sending request to S3", "error", err)

// We want to forward error codes from the s3 API to clients as much as possible.
code := parseErrorCode(err)
Expand All @@ -76,7 +76,7 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {

body, err := io.ReadAll(data.Body)
if err != nil {
o.log.Errorf("GetObject reading S3 response", "error", err)
o.log.Error("GetObject reading S3 response", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -87,22 +87,22 @@ func (o object) get(w http.ResponseWriter, r *http.Request) {
if ok && decrypt == "true" {
plaintext, err = crypto.Decrypt(body, []byte(testingKey))
if err != nil {
o.log.Errorf("GetObject decrypting response", "error", err)
o.log.Error("GetObject decrypting response", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

w.WriteHeader(http.StatusOK)
if _, err := w.Write(plaintext); err != nil {
o.log.Errorf("GetObject sending response", "error", err)
o.log.Error("GetObject sending response", "error", err)
}
}

func (o object) put(w http.ResponseWriter, r *http.Request) {
ciphertext, err := crypto.Encrypt(o.data, []byte(testingKey))
if err != nil {
o.log.Errorf("PutObject", "error", err)
o.log.Error("PutObject", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -113,7 +113,7 @@ func (o object) put(w http.ResponseWriter, r *http.Request) {

output, err := o.client.PutObject(r.Context(), o.bucket, o.key, o.tags, o.contentType, o.objectLockLegalHoldStatus, o.objectLockMode, o.objectLockRetainUntilDate, o.metadata, ciphertext)
if err != nil {
o.log.Errorf("PutObject sending request to S3", "error", err)
o.log.Error("PutObject sending request to S3", "error", err)

// We want to forward error codes from the s3 API to clients whenever possible.
code := parseErrorCode(err)
Expand Down Expand Up @@ -164,7 +164,7 @@ func (o object) put(w http.ResponseWriter, r *http.Request) {

w.WriteHeader(http.StatusOK)
if _, err := w.Write(nil); err != nil {
o.log.Errorf("PutObject sending response", "error", err)
o.log.Error("PutObject sending response", "error", err)
}
}

Expand Down
38 changes: 19 additions & 19 deletions s3proxy/internal/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/xml"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
Expand All @@ -29,7 +30,6 @@ import (
"sync"
"time"

"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/s3proxy/internal/s3"
)

Expand All @@ -41,11 +41,11 @@ var (
// Router implements the interception logic for the s3proxy.
type Router struct {
region string
log *logger.Logger
log *slog.Logger
}

// New creates a new Router.
func New(region string, log *logger.Logger) Router {
func New(region string, log *slog.Logger) Router {
return Router{region: region, log: log}
}

Expand Down Expand Up @@ -96,10 +96,10 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
parts := strings.Split(req.Host, ".")
bucket := parts[0]

r.log.Debugf("intercepting", "path", req.URL.Path, "method", req.Method, "host", req.Host)
r.log.Debug("intercepting", "path", req.URL.Path, "method", req.Method, "host", req.Host)
body, err := io.ReadAll(req.Body)
if err != nil {
r.log.Errorf("PutObject", "error", err)
r.log.Error("PutObject", "error", err)
http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -113,12 +113,12 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
// calculate the correct digest for the new body and NOT get an error.
// Thus we have to check incoming requets for matching content digests.
if clientDigest != "" && clientDigest != serverDigest {
r.log.Debugf("PutObject", "error", "x-amz-content-sha256 mismatch")
r.log.Debug("PutObject", "error", "x-amz-content-sha256 mismatch")
// The S3 API responds with an XML formatted error message.
mismatchErr := NewContentSHA256MismatchError(clientDigest, serverDigest)
marshalled, err := xml.Marshal(mismatchErr)
if err != nil {
r.log.Errorf("PutObject", "error", err)
r.log.Error("PutObject", "error", err)
http.Error(w, fmt.Sprintf("marshalling error: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -132,14 +132,14 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
raw := req.Header.Get("x-amz-object-lock-retain-until-date")
retentionTime, err := parseRetentionTime(raw)
if err != nil {
r.log.Errorf("parsing lock retention time", "data", raw, "error", err.Error())
r.log.Error("parsing lock retention time", "data", raw, "error", err.Error())
http.Error(w, fmt.Sprintf("parsing x-amz-object-lock-retain-until-date: %s", err.Error()), http.StatusInternalServerError)
return
}

err = validateContentMD5(req.Header.Get("content-md5"), body)
if err != nil {
r.log.Errorf("validating content md5", "error", err.Error())
r.log.Error("validating content md5", "error", err.Error())
http.Error(w, fmt.Sprintf("validating content md5: %s", err.Error()), http.StatusBadRequest)
return
}
Expand All @@ -162,10 +162,10 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
h = put(obj.put)

case !containsBucket(req.Host) && match(path, "/([^/?]+)/(.+)", &bucket, &key) && req.Method == "PUT" && !isUnwantedPutEndpoint(req.Header, req.URL.Query()):
r.log.Debugf("intercepting", "path", req.URL.Path, "method", req.Method, "host", req.Host)
r.log.Debug("intercepting", "path", req.URL.Path, "method", req.Method, "host", req.Host)
body, err := io.ReadAll(req.Body)
if err != nil {
r.log.Errorf("PutObject", "error", err)
r.log.Error("PutObject", "error", err)
http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -179,12 +179,12 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
// calculate the correct digest for the new body and NOT get an error.
// Thus we have to check incoming requets for matching content digests.
if clientDigest != "" && clientDigest != serverDigest {
r.log.Debugf("PutObject", "error", "x-amz-content-sha256 mismatch")
r.log.Debug("PutObject", "error", "x-amz-content-sha256 mismatch")
// The S3 API responds with an XML formatted error message.
mismatchErr := NewContentSHA256MismatchError(clientDigest, serverDigest)
marshalled, err := xml.Marshal(mismatchErr)
if err != nil {
r.log.Errorf("PutObject", "error", err)
r.log.Error("PutObject", "error", err)
http.Error(w, fmt.Sprintf("marshalling error: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -198,14 +198,14 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
raw := req.Header.Get("x-amz-object-lock-retain-until-date")
retentionTime, err := parseRetentionTime(raw)
if err != nil {
r.log.Errorf("parsing lock retention time", "data", raw, "error", err.Error())
r.log.Error("parsing lock retention time", "data", raw, "error", err.Error())
http.Error(w, fmt.Sprintf("parsing x-amz-object-lock-retain-until-date: %s", err.Error()), http.StatusInternalServerError)
return
}

err = validateContentMD5(req.Header.Get("content-md5"), body)
if err != nil {
r.log.Errorf("validating content md5", "error", err.Error())
r.log.Error("validating content md5", "error", err.Error())
http.Error(w, fmt.Sprintf("validating content md5: %s", err.Error()), http.StatusBadRequest)
return
}
Expand All @@ -229,14 +229,14 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {

// Forward all other requests.
default:
r.log.Debugf("forwarding", "path", req.URL.Path, "method", req.Method, "host", req.Host, "headers", req.Header)
r.log.Debug("forwarding", "path", req.URL.Path, "method", req.Method, "host", req.Host, "headers", req.Header)

newReq := repackage(req)

httpClient := http.DefaultClient
resp, err := httpClient.Do(&newReq)
if err != nil {
r.log.Errorf("do request", "error", err)
r.log.Error("do request", "error", err)
http.Error(w, fmt.Sprintf("do request: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -247,7 +247,7 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
}
body, err := io.ReadAll(resp.Body)
if err != nil {
r.log.Errorf("ReadAll", "error", err)
r.log.Error("ReadAll", "error", err)
http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand All @@ -257,7 +257,7 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) {
}

if _, err := w.Write(body); err != nil {
r.log.Errorf("Write", "error", err)
r.log.Error("Write", "error", err)
http.Error(w, fmt.Sprintf("writing body: %s", err.Error()), http.StatusInternalServerError)
return
}
Expand Down
22 changes: 20 additions & 2 deletions s3proxy/internal/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,30 @@ func NewClient(region string) (*Client, error) {

// GetObject returns the object with the given key from the given bucket.
// If a versionID is given, the specific version of the object is returned.
func (c Client) GetObject(ctx context.Context, bucket, key, versionID string) (*s3.GetObjectOutput, error) {
func (c Client) GetObject(ctx context.Context, bucket, key, versionID, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string) (*s3.GetObjectOutput, error) {
getObjectInput := &s3.GetObjectInput{
Bucket: &bucket,
Key: &key,
}
if versionID != "" {
getObjectInput.VersionId = &versionID
}
if sseCustomerAlgorithm != "" {
getObjectInput.SSECustomerAlgorithm = &sseCustomerAlgorithm
}
if sseCustomerKey != "" {
getObjectInput.SSECustomerKey = &sseCustomerKey
}
if sseCustomerKeyMD5 != "" {
getObjectInput.SSECustomerKeyMD5 = &sseCustomerKeyMD5
}

return c.s3client.GetObject(ctx, getObjectInput)
}

// PutObject creates a new object in the given bucket with the given key and body.
// Various optional parameters can be set.
func (c Client) PutObject(ctx context.Context, bucket, key, tags, contentType, objectLockLegalHoldStatus, objectLockMode string, objectLockRetainUntilDate time.Time, metadata map[string]string, body []byte) (*s3.PutObjectOutput, error) {
func (c Client) PutObject(ctx context.Context, bucket, key, tags, contentType, objectLockLegalHoldStatus, objectLockMode, sseCustomerAlgorithm, sseCustomerKey, sseCustomerKeyMD5 string, objectLockRetainUntilDate time.Time, metadata map[string]string, body []byte) (*s3.PutObjectOutput, error) {
// The AWS Go SDK has two versions. V1 does not set the Content-Type header.
// V2 always sets the Content-Type header. We use V2.
// The s3 API sets an object's content-type to binary/octet-stream if
Expand All @@ -87,6 +96,15 @@ func (c Client) PutObject(ctx context.Context, bucket, key, tags, contentType, o
ContentType: &contentType,
ObjectLockLegalHoldStatus: types.ObjectLockLegalHoldStatus(objectLockLegalHoldStatus),
}
if sseCustomerAlgorithm != "" {
putObjectInput.SSECustomerAlgorithm = &sseCustomerAlgorithm
}
if sseCustomerKey != "" {
putObjectInput.SSECustomerKey = &sseCustomerKey
}
if sseCustomerKeyMD5 != "" {
putObjectInput.SSECustomerKeyMD5 = &sseCustomerKeyMD5
}

// It is not allowed to only set one of these two properties.
if objectLockMode != "" && !objectLockRetainUntilDate.IsZero() {
Expand Down

0 comments on commit 8bbc9b6

Please sign in to comment.