From a31feae6d4aabb8755d0416291ab3f6d0a501c98 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:18:23 +0200 Subject: [PATCH 01/19] api/render: initial implementation of the package --- api/render/render.go | 58 +++++++++++++++++++++++++++++++++++++++ api/render/render_test.go | 53 +++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 api/render/render.go create mode 100644 api/render/render_test.go diff --git a/api/render/render.go b/api/render/render.go new file mode 100644 index 000000000..e4fbaffa9 --- /dev/null +++ b/api/render/render.go @@ -0,0 +1,58 @@ +// Package render implements functionality related to response rendering. +package render + +import ( + "encoding/json" + "net/http" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/api/log" +) + +// JSON writes the passed value into the http.ResponseWriter. +func JSON(w http.ResponseWriter, v interface{}) { + JSONStatus(w, v, http.StatusOK) +} + +// JSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func JSONStatus(w http.ResponseWriter, v interface{}, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(v); err != nil { + log.Error(w, err) + + return + } + + log.EnabledResponse(w, v) +} + +// ProtoJSON writes the passed value into the http.ResponseWriter. +func ProtoJSON(w http.ResponseWriter, m proto.Message) { + ProtoJSONStatus(w, m, http.StatusOK) +} + +// ProtoJSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + b, err := protojson.Marshal(m) + if err != nil { + log.Error(w, err) + + return + } + + if _, err := w.Write(b); err != nil { + log.Error(w, err) + + return + } + + // log.EnabledResponse(w, v) +} diff --git a/api/render/render_test.go b/api/render/render_test.go new file mode 100644 index 000000000..de5919725 --- /dev/null +++ b/api/render/render_test.go @@ -0,0 +1,53 @@ +package render + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/smallstep/certificates/logging" +) + +func TestJSON(t *testing.T) { + type args struct { + rw http.ResponseWriter + v interface{} + } + tests := []struct { + name string + args args + ok bool + }{ + {"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true}, + {"fail", args{httptest.NewRecorder(), make(chan int)}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rw := logging.NewResponseLogger(tt.args.rw) + JSON(rw, tt.args.v) + + rr, ok := tt.args.rw.(*httptest.ResponseRecorder) + if !ok { + t.Error("ResponseWriter does not implement *httptest.ResponseRecorder") + return + } + + fields := rw.Fields() + if tt.ok { + if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" { + t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body) + } + if len(fields) != 0 { + t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields) + } + } else { + if body := rr.Body.String(); body != "" { + t.Errorf("Unexpected body = %s, want empty string", body) + } + if len(fields) != 1 { + t.Errorf("ResponseLogger fields = %v, wants 1 element", fields) + } + } + }) + } +} From eae0211a3e7de57f829c56e4bfb2e869aacdf366 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:18:41 +0200 Subject: [PATCH 02/19] acme/api: refactored to support api/render --- acme/api/account.go | 8 +++++--- acme/api/handler.go | 8 +++++--- acme/api/order.go | 11 +++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index 0dc8ab402..88522ea68 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -5,8 +5,10 @@ import ( "net/http" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) @@ -149,7 +151,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) - api.JSONStatus(w, acc, httpStatus) + render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. @@ -196,7 +198,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) - api.JSON(w, acc) + render.JSON(w, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { @@ -229,6 +231,6 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrdersByAccountID(ctx, orders) - api.JSON(w, orders) + render.JSON(w, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/handler.go b/acme/api/handler.go index c3a481f9e..7c4be8ae4 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -12,8 +12,10 @@ import ( "time" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" ) @@ -185,7 +187,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { return } - api.JSON(w, &Directory{ + render.JSON(w, &Directory{ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), @@ -229,7 +231,7 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { h.linker.LinkAuthorization(ctx, az) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) - api.JSON(w, az) + render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. @@ -280,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) - api.JSON(w, ch) + render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. diff --git a/acme/api/order.go b/acme/api/order.go index 9cf2c1eb4..125a21f93 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -11,9 +11,12 @@ import ( "time" "github.com/go-chi/chi" + + "go.step.sm/crypto/randutil" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" - "go.step.sm/crypto/randutil" + "github.com/smallstep/certificates/api/render" ) // NewOrderRequest represents the body for a NewOrder request. @@ -142,7 +145,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSONStatus(w, o, http.StatusCreated) + render.JSONStatus(w, o, http.StatusCreated) } func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { @@ -217,7 +220,7 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. @@ -272,7 +275,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // challengeTypes determines the types of challenges that should be used From b79af0456caeacb935961f3132977590998b0fa5 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:18:52 +0200 Subject: [PATCH 03/19] authority/admin: refactored to support api/render --- authority/admin/api/admin.go | 11 ++++++----- authority/admin/api/provisioner.go | 12 +++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 43607c523..b01c82e8b 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -10,6 +10,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) @@ -89,7 +90,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { "admin %s not found", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } // GetAdmins returns a segment of admins associated with the authority. @@ -106,7 +107,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } - api.JSON(w, &GetAdminsResponse{ + render.JSON(w, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) @@ -141,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - api.ProtoJSONStatus(w, adm, http.StatusCreated) + render.ProtoJSONStatus(w, adm, http.StatusCreated) } // DeleteAdmin deletes admin. @@ -153,7 +154,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. @@ -177,5 +178,5 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 2106733df..872435c2b 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -4,10 +4,12 @@ import ( "net/http" "github.com/go-chi/chi" + "go.step.sm/linkedca" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" @@ -48,7 +50,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - api.ProtoJSON(w, prov) + render.ProtoJSON(w, prov) } // GetProvisioners returns the given segment of provisioners associated with the authority. @@ -65,7 +67,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { api.WriteError(w, errs.InternalServerErr(err)) return } - api.JSON(w, &GetProvisionersResponse{ + render.JSON(w, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -89,7 +91,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } - api.ProtoJSONStatus(w, prov, http.StatusCreated) + render.ProtoJSONStatus(w, prov, http.StatusCreated) } // DeleteProvisioner deletes a provisioner. @@ -118,7 +120,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. @@ -173,5 +175,5 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - api.ProtoJSON(w, nu) + render.ProtoJSON(w, nu) } From 833ea1e6950925f6fe04b05ea2e68477219d178c Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:19:02 +0200 Subject: [PATCH 04/19] ca: refactored to support api/render --- ca/acmeClient_test.go | 51 ++++++++++++++++++++++--------------------- ca/bootstrap_test.go | 9 +++++--- ca/client_test.go | 29 ++++++++++++------------ 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index ad5f21164..f17a2f7a3 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -12,12 +12,13 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" - "github.com/smallstep/certificates/api" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/api/render" ) func TestNewACMEClient(t *testing.T) { @@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header switch { case i == 0: - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ case i == 1: w.Header().Set("Replay-Nonce", "abc123") - api.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) @@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { @@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) { assert.Equals(t, hdr.KeyID, ac.kid) } - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { @@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, norb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { @@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { @@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { @@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { @@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { assert.Equals(t, payload, []byte("{}")) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { @@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, frb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { @@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { @@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { if tc.certBytes != nil { w.Write(tc.certBytes) } else { - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 0e16bd7dc..eb524e2e1 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -14,11 +14,14 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/randutil" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" ) func newLocalListener() net.Listener { @@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { - api.JSON(w, api.VersionResponse{ + render.JSON(w, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) diff --git a/ca/client_test.go b/ca/client_test.go index 4628d19b3..d3a2914e8 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -24,6 +24,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -182,7 +183,7 @@ func TestClient_Version(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() @@ -232,7 +233,7 @@ func TestClient_Health(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() @@ -290,7 +291,7 @@ func TestClient_Root(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -371,7 +372,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -443,7 +444,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -503,7 +504,7 @@ func TestClient_Renew(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -640,7 +641,7 @@ func TestClient_Rekey(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) @@ -705,7 +706,7 @@ func TestClient_Provisioners(t *testing.T) { if req.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -762,7 +763,7 @@ func TestClient_ProvisionerKey(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -821,7 +822,7 @@ func TestClient_Roots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -879,7 +880,7 @@ func TestClient_Federation(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -941,7 +942,7 @@ func TestClient_SSHRoots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() @@ -1041,7 +1042,7 @@ func TestClient_RootFingerprint(t *testing.T) { } tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() @@ -1102,7 +1103,7 @@ func TestClient_SSHBastion(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) From 9aa480a09abc8f7bb8f86699c7e247608a19a78a Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:19:12 +0200 Subject: [PATCH 05/19] api: refactored to support api/render --- api/api.go | 15 +++++++------ api/rekey.go | 3 ++- api/renew.go | 3 ++- api/revoke.go | 3 ++- api/sign.go | 3 ++- api/ssh.go | 15 +++++++------ api/sshRekey.go | 3 ++- api/sshRenew.go | 3 ++- api/sshRevoke.go | 3 ++- api/utils.go | 57 ----------------------------------------------- api/utils_test.go | 53 ------------------------------------------- 11 files changed, 30 insertions(+), 131 deletions(-) delete mode 100644 api/utils.go delete mode 100644 api/utils_test.go diff --git a/api/api.go b/api/api.go index 47d2fd27b..4a1c8afeb 100644 --- a/api/api.go +++ b/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -282,7 +283,7 @@ func (h *caHandler) Route(r Router) { // Version is an HTTP handler that returns the version of the server. func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { v := h.Authority.Version() - JSON(w, VersionResponse{ + render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) @@ -290,7 +291,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { // Health is an HTTP handler that returns the status of the server. func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { - JSON(w, HealthResponse{Status: "ok"}) + render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root @@ -305,7 +306,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { return } - JSON(w, &RootResponse{RootPEM: Certificate{cert}}) + render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { @@ -329,7 +330,7 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.InternalServerErr(err)) return } - JSON(w, &ProvisionersResponse{ + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -343,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.NotFoundErr(err)) return } - JSON(w, &ProvisionerKeyResponse{key}) + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. @@ -359,7 +360,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - JSONStatus(w, &RootsResponse{ + render.JSONStatus(w, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } @@ -377,7 +378,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - JSONStatus(w, &FederationResponse{ + render.JSONStatus(w, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } diff --git a/api/rekey.go b/api/rekey.go index 269086bb7..571594d95 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -55,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/renew.go b/api/renew.go index 408d91a30..8af88d8f4 100644 --- a/api/renew.go +++ b/api/renew.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -34,7 +35,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/revoke.go b/api/revoke.go index 49822e6da..205b8f384 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -103,7 +104,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } logRevoke(w, opts) - JSON(w, &RevokeResponse{Status: "ok"}) + render.JSON(w, &RevokeResponse{Status: "ok"}) } func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/sign.go b/api/sign.go index b2eef45d0..eadb525e0 100644 --- a/api/sign.go +++ b/api/sign.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -84,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/ssh.go b/api/ssh.go index fc185d07d..b0e9d77ae 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -334,7 +335,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { identityCertificate = certChainToPEM(certChain) } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, IdentityCertificate: identityCertificate, @@ -363,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys @@ -388,7 +389,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients @@ -421,7 +422,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - JSON(w, cfg) + render.JSON(w, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. @@ -441,7 +442,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHCheckPrincipalResponse{ + render.JSON(w, &SSHCheckPrincipalResponse{ Exists: exists, }) } @@ -458,7 +459,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHGetHostsResponse{ + render.JSON(w, &SSHGetHostsResponse{ Hosts: hosts, }) } @@ -481,7 +482,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - JSON(w, &SSHBastionResponse{ + render.JSON(w, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) diff --git a/api/sshRekey.go b/api/sshRekey.go index b75817494..ba7f47a85 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -84,7 +85,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - JSONStatus(w, &SSHRekeyResponse{ + render.JSONStatus(w, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRenew.go b/api/sshRenew.go index b98466bfc..33a30d250 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -76,7 +77,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 2d2da1f72..cc01f441c 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -6,6 +6,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -82,7 +83,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { } logSSHRevoke(w, opts) - JSON(w, &SSHRevokeResponse{Status: "ok"}) + render.JSON(w, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/utils.go b/api/utils.go deleted file mode 100644 index e3fcc9c40..000000000 --- a/api/utils.go +++ /dev/null @@ -1,57 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" - - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - "github.com/smallstep/certificates/api/log" -) - -// JSON writes the passed value into the http.ResponseWriter. -func JSON(w http.ResponseWriter, v interface{}) { - JSONStatus(w, v, http.StatusOK) -} - -// JSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func JSONStatus(w http.ResponseWriter, v interface{}, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(v); err != nil { - log.Error(w, err) - - return - } - - log.EnabledResponse(w, v) -} - -// ProtoJSON writes the passed value into the http.ResponseWriter. -func ProtoJSON(w http.ResponseWriter, m proto.Message) { - ProtoJSONStatus(w, m, http.StatusOK) -} - -// ProtoJSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - b, err := protojson.Marshal(m) - if err != nil { - log.Error(w, err) - - return - } - - if _, err := w.Write(b); err != nil { - log.Error(w, err) - - return - } - - // log.EnabledResponse(w, v) -} diff --git a/api/utils_test.go b/api/utils_test.go deleted file mode 100644 index f5e1e1cb2..000000000 --- a/api/utils_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package api - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/smallstep/certificates/logging" -) - -func TestJSON(t *testing.T) { - type args struct { - rw http.ResponseWriter - v interface{} - } - tests := []struct { - name string - args args - ok bool - }{ - {"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true}, - {"fail", args{httptest.NewRecorder(), make(chan int)}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rw := logging.NewResponseLogger(tt.args.rw) - JSON(rw, tt.args.v) - - rr, ok := tt.args.rw.(*httptest.ResponseRecorder) - if !ok { - t.Error("ResponseWriter does not implement *httptest.ResponseRecorder") - return - } - - fields := rw.Fields() - if tt.ok { - if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" { - t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body) - } - if len(fields) != 0 { - t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields) - } - } else { - if body := rr.Body.String(); body != "" { - t.Errorf("Unexpected body = %s, want empty string", body) - } - if len(fields) != 1 { - t.Errorf("ResponseLogger fields = %v, wants 1 element", fields) - } - } - }) - } -} From 3389e57c48ef58543b30d8756237c29da3ba00d1 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:34:10 +0200 Subject: [PATCH 06/19] api/render: implemented Error --- api/errors.go | 66 -------------------------------------------- api/render/render.go | 57 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 66 deletions(-) delete mode 100644 api/errors.go diff --git a/api/errors.go b/api/errors.go deleted file mode 100644 index 49efd486d..000000000 --- a/api/errors.go +++ /dev/null @@ -1,66 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/pkg/errors" - - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api/log" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/scep" -) - -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err error) { - switch k := err.(type) { - case *acme.Error: - acme.WriteError(w, k) - return - case *admin.Error: - admin.WriteError(w, k) - return - case *scep.Error: - w.Header().Set("Content-Type", "text/plain") - default: - w.Header().Set("Content-Type", "application/json") - } - - cause := errors.Cause(err) - if sc, ok := err.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - if sc, ok := cause.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - w.WriteHeader(http.StatusInternalServerError) - } - } - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } else if e, ok := cause.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Error(w, err) - } -} diff --git a/api/render/render.go b/api/render/render.go index e4fbaffa9..56a68f303 100644 --- a/api/render/render.go +++ b/api/render/render.go @@ -3,12 +3,20 @@ package render import ( "encoding/json" + "fmt" "net/http" + "os" + "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/scep" ) // JSON writes the passed value into the http.ResponseWriter. @@ -56,3 +64,52 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { // log.EnabledResponse(w, v) } + +// Error encodes the JSON representation of err to w. +func Error(w http.ResponseWriter, err error) { + switch k := err.(type) { + case *acme.Error: + acme.WriteError(w, k) + return + case *admin.Error: + admin.WriteError(w, k) + return + case *scep.Error: + w.Header().Set("Content-Type", "text/plain") + default: + w.Header().Set("Content-Type", "application/json") + } + + cause := errors.Cause(err) + if sc, ok := err.(errs.StatusCoder); ok { + w.WriteHeader(sc.StatusCode()) + } else { + if sc, ok := cause.(errs.StatusCoder); ok { + w.WriteHeader(sc.StatusCode()) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + } + + // Write errors in the response writer + if rl, ok := w.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "error": err, + }) + if os.Getenv("STEPDEBUG") == "1" { + if e, ok := err.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) + } else if e, ok := cause.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) + } + } + } + + if err := json.NewEncoder(w).Encode(err); err != nil { + log.Error(w, err) + } +} From 9b6c1f608ee36c83484eaecea5ad8eb29a317607 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:35:12 +0200 Subject: [PATCH 07/19] api: refactored to support api/render.Error --- api/api.go | 12 ++++++------ api/rekey.go | 8 ++++---- api/renew.go | 4 ++-- api/revoke.go | 12 ++++++------ api/sign.go | 8 ++++---- api/ssh.go | 48 ++++++++++++++++++++++++------------------------ api/sshRekey.go | 15 ++++++++------- api/sshRenew.go | 13 +++++++------ api/sshRevoke.go | 8 ++++---- 9 files changed, 65 insertions(+), 63 deletions(-) diff --git a/api/api.go b/api/api.go index 4a1c8afeb..191d242d8 100644 --- a/api/api.go +++ b/api/api.go @@ -302,7 +302,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) + render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } @@ -321,13 +321,13 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } render.JSON(w, &ProvisionersResponse{ @@ -341,7 +341,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, errs.NotFoundErr(err)) + render.Error(w, errs.NotFoundErr(err)) return } render.JSON(w, &ProvisionerKeyResponse{key}) @@ -351,7 +351,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return } @@ -369,7 +369,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return } diff --git a/api/rekey.go b/api/rekey.go index 571594d95..3116cf74b 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -29,24 +29,24 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + render.Error(w, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) diff --git a/api/renew.go b/api/renew.go index 8af88d8f4..9c4bff32a 100644 --- a/api/renew.go +++ b/api/renew.go @@ -19,13 +19,13 @@ const ( func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { cert, err := h.getPeerCertificate(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Renew(cert) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) diff --git a/api/revoke.go b/api/revoke.go index 205b8f384..c9da2c188 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -52,12 +52,12 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -74,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -83,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or client certificate")) + render.Error(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("serial number in client certificate different than body")) + render.Error(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -99,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/sign.go b/api/sign.go index eadb525e0..b6bfcc8b3 100644 --- a/api/sign.go +++ b/api/sign.go @@ -52,13 +52,13 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -70,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) diff --git a/api/ssh.go b/api/ssh.go index b0e9d77ae..3b0de7c12 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -253,19 +253,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -273,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -290,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -304,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -317,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } @@ -329,7 +329,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) @@ -347,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -372,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -397,17 +397,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -418,7 +418,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: cfg.HostTemplates = ts default: - WriteError(w, errs.InternalServer("it should hot get here")) + render.Error(w, errs.InternalServer("it should hot get here")) return } @@ -429,17 +429,17 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } render.JSON(w, &SSHCheckPrincipalResponse{ @@ -456,7 +456,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } render.JSON(w, &SSHGetHostsResponse{ @@ -468,17 +468,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index ba7f47a85..922789506 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -42,36 +42,37 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -81,7 +82,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index 33a30d250..78d16fa65 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -40,30 +40,31 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -73,7 +74,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index cc01f441c..a33082cdc 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -51,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -72,13 +72,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } From 6636e87fc7610ebb1509c796bcef29554929299c Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:35:23 +0200 Subject: [PATCH 08/19] acme/api: refactored to support api/render.Error --- acme/api/account.go | 39 +++++++++---------- acme/api/handler.go | 30 +++++++-------- acme/api/middleware.go | 87 +++++++++++++++++++++--------------------- acme/api/order.go | 45 +++++++++++----------- acme/api/revoke.go | 39 ++++++++++--------- 5 files changed, 120 insertions(+), 120 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index 88522ea68..ade51aef9 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -7,7 +7,6 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) @@ -72,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -98,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... - api.WriteError(w, err) + render.Error(w, err) return } // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } eak, err := h.validateExternalAccountBinding(ctx, &nar) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -127,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Status: acme.StatusValid, } if err := h.db.CreateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) + render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response err := eak.BindTo(acc) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) + render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding @@ -159,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -173,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { @@ -189,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } if err := h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } @@ -215,17 +214,17 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } diff --git a/acme/api/handler.go b/acme/api/handler.go index 7c4be8ae4..10eb22cb2 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -183,7 +183,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -202,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. @@ -210,21 +210,21 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } @@ -239,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. _, err = payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -259,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { azID := chi.URLParam(r, "authzID") ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } @@ -290,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } certID := chi.URLParam(r, "certID") cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 0cdeaabb1..10f7841fc 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -10,13 +10,14 @@ import ( "strings" "github.com/go-chi/chi" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { nonce, err := h.db.CreateNonce(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } w.Header().Set("Replay-Nonce", string(nonce)) @@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) + render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) @@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } @@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) + render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } @@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // For NewAccount and Revoke requests ... break case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := h.ca.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) @@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { ctx := r.Context() ok, err := h.prerequisitesChecker(ctx) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return } if !ok { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } next(w, r.WithContext(ctx)) @@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return @@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ @@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !payload.isPostAsGet { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) diff --git a/acme/api/order.go b/acme/api/order.go index 125a21f93..99eb0e95d 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -15,7 +15,6 @@ import ( "go.step.sm/crypto/randutil" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" ) @@ -73,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -119,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o.AuthorizationIDs[i] = az.ID @@ -138,7 +137,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } if err := h.db.CreateOrder(ctx, o); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) + render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } @@ -189,31 +188,31 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } @@ -228,47 +227,47 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } diff --git a/acme/api/revoke.go b/acme/api/revoke.go index d01e401c9..4b71bc22e 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -10,13 +10,14 @@ import ( "net/http" "strings" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ocsp" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" - "go.step.sm/crypto/jose" - "golang.org/x/crypto/ocsp" ) type revokePayload struct { @@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := h.db.GetCertificateBySerial(ctx, serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen - api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } } else { @@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { - api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } @@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = h.ca.Revoke(ctx, options) if err != nil { - api.WriteError(w, wrapRevokeErr(err)) + render.Error(w, wrapRevokeErr(err)) return } From 098c2e1134900ef01b41cb51a83cd1ba3e9e6f82 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:35:33 +0200 Subject: [PATCH 09/19] authority/admin: refactored to support api/render.Error --- authority/admin/api/acme.go | 16 +++++++----- authority/admin/api/admin.go | 22 ++++++++-------- authority/admin/api/middleware.go | 8 +++--- authority/admin/api/provisioner.go | 42 +++++++++++++++--------------- 4 files changed, 45 insertions(+), 43 deletions(-) diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 27c3ba6f3..21a7229d4 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -6,10 +6,12 @@ import ( "net/http" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) const ( @@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { provName := chi.URLParam(r, "provisionerName") eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !eabEnabled { - api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) return } ctx = context.WithValue(ctx, provisionerContextKey, prov) @@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder { // GetExternalAccountKeys writes the response for the EAB keys GET endpoint func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index b01c82e8b..5e4b9c301 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -86,7 +86,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { adm, ok := h.auth.LoadAdminByID(id) if !ok { - api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } @@ -97,14 +97,14 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } render.JSON(w, &GetAdminsResponse{ @@ -117,18 +117,18 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } p, err := h.auth.LoadProvisionerByName(body.Provisioner) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ @@ -138,7 +138,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // Store to authority collection. if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) + render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } @@ -150,7 +150,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } @@ -161,12 +161,12 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -174,7 +174,7 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 19025a9d7..b57dd6eb5 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" ) @@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { if !h.auth.IsAdminAPIEnabled() { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } @@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { - api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, + render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } adm, err := h.auth.AuthorizeAdminToken(r, tok) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 872435c2b..1cad62dde 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -35,19 +35,19 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } render.ProtoJSON(w, prov) @@ -57,14 +57,14 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := h.auth.GetProvisioners(cursor, limit) if err != nil { - api.WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } render.JSON(w, &GetProvisionersResponse{ @@ -77,18 +77,18 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } render.ProtoJSONStatus(w, prov, http.StatusCreated) @@ -105,18 +105,18 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } @@ -127,52 +127,52 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } name := chi.URLParam(r, "name") _old, err := h.auth.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) return } if nu.Id != old.Id { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) + render.Error(w, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } render.ProtoJSON(w, nu) From d49c00b0d7b2de8ef00afdf6224660d917842d44 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:35:42 +0200 Subject: [PATCH 10/19] ca: refactored to support api/render.Error --- ca/bootstrap_test.go | 2 +- ca/client_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index eb524e2e1..2332b4d43 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -96,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { - api.WriteError(w, errs.Unauthorized("missing peer certificate")) + render.Error(w, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } diff --git a/ca/client_test.go b/ca/client_test.go index d3a2914e8..b83f689e7 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -361,7 +361,7 @@ func TestClient_Sign(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -433,7 +433,7 @@ func TestClient_Revoke(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { From 23c81db95a34e73ef918731699af3c30a9b5a039 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Fri, 18 Mar 2022 17:35:55 +0200 Subject: [PATCH 11/19] scep: refactored to support api/render.Error --- scep/api/api.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scep/api/api.go b/scep/api/api.go index a326ea922..39fc56b0f 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -16,6 +16,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/scep" ) @@ -189,19 +190,19 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name)) + render.Error(w, errors.Errorf("error url unescaping provisioner name '%s'", name)) return } p, err := h.Auth.LoadProvisionerByName(provisionerName) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, ok := p.(*provisioner.SCEP) if !ok { - api.WriteError(w, errors.New("provisioner must be of type SCEP")) + render.Error(w, errors.New("provisioner must be of type SCEP")) return } @@ -356,7 +357,7 @@ func writeError(w http.ResponseWriter, err error) { Message: err.Error(), Status: http.StatusInternalServerError, // TODO: make this a param? } - api.WriteError(w, scepError) + render.Error(w, scepError) } func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) { From 4cdb38b2e8baa49c100d8c67d87916eb69c52164 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Mon, 21 Mar 2022 18:50:12 +0200 Subject: [PATCH 12/19] ca: fixed broken tests --- ca/client_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ca/client_test.go b/ca/client_test.go index b83f689e7..adef10310 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -569,9 +569,9 @@ func TestClient_RenewWithToken(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { - api.JSONStatus(w, errs.InternalServer("force"), 500) + render.JSONStatus(w, errs.InternalServer("force"), 500) } else { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) } }) From e82b21c1cbbe24df1f0dac44cf300c68dea8d805 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 15:36:05 +0200 Subject: [PATCH 13/19] api/render: implemented BadRequest --- api/render/render.go | 27 +++++++++++++++++++++++++++ api/render/render_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/api/render/render.go b/api/render/render.go index 56a68f303..12ad70aae 100644 --- a/api/render/render.go +++ b/api/render/render.go @@ -113,3 +113,30 @@ func Error(w http.ResponseWriter, err error) { log.Error(w, err) } } + +// BadRequest renders the JSON representation of err into w and sets its +// status code to http.StatusBadRequest. +func BadRequest(w http.ResponseWriter, err error) { + codedError(w, http.StatusBadRequest, err) +} + +func codedError(w http.ResponseWriter, code int, err error) { + var wrapper = struct { + Status int `json:"status"` + Message string `json:"message"` + }{ + Status: code, + Message: err.Error(), + } + + data, err := json.Marshal(wrapper) + if err != nil { + log.Error(w, err) + + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + w.Write(data) +} diff --git a/api/render/render_test.go b/api/render/render_test.go index de5919725..f1d7c145a 100644 --- a/api/render/render_test.go +++ b/api/render/render_test.go @@ -3,8 +3,11 @@ package render import ( "net/http" "net/http/httptest" + "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/smallstep/certificates/logging" ) @@ -51,3 +54,35 @@ func TestJSON(t *testing.T) { }) } } + +func TestErrors(t *testing.T) { + cases := []struct { + fn func(http.ResponseWriter, error) // helper + err error // error being render + code int // expected status code + body string // expected body + }{ + // --- BadRequest + 0: { + fn: BadRequest, + err: assert.AnError, + code: http.StatusBadRequest, + body: `{"status":400,"message":"assert.AnError general error for testing"}`, + }, + } + + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + rec := httptest.NewRecorder() + kase.fn(rec, kase.err) + + ret := rec.Result() + + assert.Equal(t, "application/json", ret.Header.Get("Content-Type")) + assert.Equal(t, kase.code, ret.StatusCode) + assert.Equal(t, kase.body, rec.Body.String()) + }) + } +} From 2fd84227f083ff16231d57dbfde6aaeca6a5f473 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 18:36:57 +0200 Subject: [PATCH 14/19] internal/buffer: initial implementation of the package --- internal/buffer/buffer.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 internal/buffer/buffer.go diff --git a/internal/buffer/buffer.go b/internal/buffer/buffer.go new file mode 100644 index 000000000..85d96a3f8 --- /dev/null +++ b/internal/buffer/buffer.go @@ -0,0 +1,23 @@ +// Package buffer implements a reusable buffer pool. +package buffer + +import ( + "bytes" + "sync" +) + +func Get() *bytes.Buffer { + return pool.Get().(*bytes.Buffer) +} + +func Put(b *bytes.Buffer) { + b.Reset() + + pool.Put(b) +} + +var pool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} From a715e57d04597211985f7f7e49e7432cc6779115 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 18:38:03 +0200 Subject: [PATCH 15/19] api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON --- api/read/read.go | 87 ++++++++++++++++++++++++++++++++++++------- api/read/read_test.go | 71 ++++++++++++++++++++--------------- 2 files changed, 115 insertions(+), 43 deletions(-) diff --git a/api/read/read.go b/api/read/read.go index de92c5d77..ee72cdb7f 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -2,30 +2,91 @@ package read import ( + "bytes" "encoding/json" - "io" + "fmt" + "net/http" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/admin" + + "github.com/smallstep/certificates/internal/buffer" ) -// JSON reads JSON from the request body and stores it in the value -// pointed by v. -func JSON(r io.Reader, v interface{}) error { - if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.BadRequestErr(err, "error decoding json") +// JSON unmarshals from the given request's JSON body into v. In case of an +// error a HTTP Bad Request error will be written to w. +func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool { + b := read(w, r) + if b == nil { + return false + } + defer buffer.Put(b) + + if err := json.NewDecoder(b).Decode(v); err != nil { + err = fmt.Errorf("error decoding json: %w", err) + + render.BadRequest(w, err) + + return false + } + + return true +} + +// AdminJSON is obsolete; it's here for backwards compatibility. +// +// Please don't use. +func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool { + b := read(w, r) + if b == nil { + return false + } + defer buffer.Put(b) + + if err := json.NewDecoder(b).Decode(v); err != nil { + e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body") + admin.WriteError(w, e) + + return false } - return nil + + return true } // ProtoJSON reads JSON from the request body and stores it in the value // pointed by v. -func ProtoJSON(r io.Reader, m proto.Message) error { - data, err := io.ReadAll(r) - if err != nil { - return errs.BadRequestErr(err, "error reading request body") +func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool { + b := read(w, r) + if b == nil { + return false } - return protojson.Unmarshal(data, m) + defer buffer.Put(b) + + if err := protojson.Unmarshal(b.Bytes(), m); err != nil { + err = fmt.Errorf("error decoding proto json: %w", err) + + render.BadRequest(w, err) + + return false + } + + return true +} + +func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer { + b := buffer.Get() + if _, err := b.ReadFrom(r.Body); err != nil { + buffer.Put(b) + + err = fmt.Errorf("error reading request body: %w", err) + + render.BadRequest(w, err) + + return nil + } + + return b } diff --git a/api/read/read_test.go b/api/read/read_test.go index f2eff1bc8..3fde9d608 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -2,45 +2,56 @@ package read import ( "io" - "reflect" + "net/http" + "net/http/httptest" + "strconv" "strings" "testing" + "testing/iotest" - "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" ) func TestJSON(t *testing.T) { - type args struct { - r io.Reader - v interface{} - } - tests := []struct { - name string - args args - wantErr bool + cases := []struct { + src io.Reader + exp interface{} + ok bool + code int }{ - {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, - {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, + 0: { + src: strings.NewReader(`{"foo":"bar"}`), + exp: map[string]interface{}{"foo": "bar"}, + ok: true, + code: http.StatusOK, + }, + 1: { + src: strings.NewReader(`{"foo"}`), + code: http.StatusBadRequest, + }, + 2: { + src: io.MultiReader( + strings.NewReader(`{`), + iotest.ErrReader(assert.AnError), + strings.NewReader(`"foo":"bar"}`), + ), + code: http.StatusBadRequest, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := JSON(tt.args.r, &tt.args.v) - if (err != nil) != tt.wantErr { - t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - e, ok := err.(*errs.Error) - if ok { - if code := e.StatusCode(); code != 400 { - t.Errorf("error.StatusCode() = %v, wants 400", code) - } - } else { - t.Errorf("error type = %T, wants *Error", err) - } - } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { - t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) - } + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", kase.src) + rec := httptest.NewRecorder() + + var body interface{} + got := JSON(rec, req, &body) + + assert.Equal(t, kase.ok, got) + assert.Equal(t, kase.code, rec.Result().StatusCode) + assert.Equal(t, kase.exp, body) }) } } From a5c171e750b14b222e567586b414e69a45ea2a67 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 18:38:46 +0200 Subject: [PATCH 16/19] api: refactored to support the new signatures for read.JSON, read.AdminJSON & read.ProtoJSON --- api/rekey.go | 3 +-- api/revoke.go | 3 +-- api/sign.go | 3 +-- api/ssh.go | 12 ++++-------- api/sshRekey.go | 3 +-- api/sshRenew.go | 3 +-- api/sshRevoke.go | 3 +-- 7 files changed, 10 insertions(+), 20 deletions(-) diff --git a/api/rekey.go b/api/rekey.go index 3116cf74b..b1a6c2fcd 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -34,8 +34,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } var body RekeyRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } diff --git a/api/revoke.go b/api/revoke.go index c9da2c188..001b36245 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -51,8 +51,7 @@ func (r *RevokeRequest) Validate() (err error) { // TODO: Add CRL and OCSP support. func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } diff --git a/api/sign.go b/api/sign.go index b6bfcc8b3..6d07029de 100644 --- a/api/sign.go +++ b/api/sign.go @@ -51,8 +51,7 @@ type SignResponse struct { // information in the certificate request. func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } diff --git a/api/ssh.go b/api/ssh.go index 3b0de7c12..21d30ebc5 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -252,8 +252,7 @@ type SSHBastionResponse struct { // the request. func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } @@ -396,8 +395,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // and servers. func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } if err := body.Validate(); err != nil { @@ -428,8 +426,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } if err := body.Validate(); err != nil { @@ -467,8 +464,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { // SSHBastion provides returns the bastion configured if any. func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } if err := body.Validate(); err != nil { diff --git a/api/sshRekey.go b/api/sshRekey.go index 922789506..acb17cca2 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -41,8 +41,7 @@ type SSHRekeyResponse struct { // the request. func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } diff --git a/api/sshRenew.go b/api/sshRenew.go index 78d16fa65..f85819b6b 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -39,8 +39,7 @@ type SSHRenewResponse struct { // the request. func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if read.JSON(w, r, &body) { return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index a33082cdc..ba1a1de9b 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -50,8 +50,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + if !read.JSON(w, r, &body) { return } From 2e729ebb262e81e870c785943611ff673b164e86 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 18:39:13 +0200 Subject: [PATCH 17/19] authority/admin/api: refactored to support read.AdminJSON --- authority/admin/api/admin.go | 6 ++---- authority/admin/api/provisioner.go | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 5e4b9c301..6f6000e03 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -116,8 +116,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { // CreateAdmin creates a new admin. func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + if !read.AdminJSON(w, r, &body) { return } @@ -160,8 +159,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { // UpdateAdmin updates an existing admin. func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest - if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + if !read.AdminJSON(w, r, &body) { return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 1cad62dde..540188f17 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -76,8 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { // CreateProvisioner creates a new prov. func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) - if err := read.ProtoJSON(r.Body, prov); err != nil { - render.Error(w, err) + if !read.ProtoJSON(w, r, prov) { return } @@ -126,8 +125,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { // UpdateProvisioner updates an existing prov. func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) - if err := read.ProtoJSON(r.Body, nu); err != nil { - render.Error(w, err) + if !read.ProtoJSON(w, r, nu) { return } From c3cc60e21187253ba5adacdaee9f14032a6291a4 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 18:39:42 +0200 Subject: [PATCH 18/19] go.mod: added dependency to stretchr/testify --- go.mod | 1 + go.sum | 3 +++ 2 files changed, 4 insertions(+) diff --git a/go.mod b/go.mod index 6033d05e9..519fbf1be 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/nosql v0.4.0 + github.com/stretchr/testify v1.7.0 github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 diff --git a/go.sum b/go.sum index c7a18aad1..ae6f5ce91 100644 --- a/go.sum +++ b/go.sum @@ -445,9 +445,11 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -1120,6 +1122,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= From 845e41967d42ec023ffd650d57ecbe53b6c3e0b9 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 22 Mar 2022 19:15:23 +0200 Subject: [PATCH 19/19] api/render: use default error messages if none specified --- api/render/render.go | 6 ++++++ api/render/render_test.go | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/api/render/render.go b/api/render/render.go index 12ad70aae..5a37ec84c 100644 --- a/api/render/render.go +++ b/api/render/render.go @@ -116,11 +116,17 @@ func Error(w http.ResponseWriter, err error) { // BadRequest renders the JSON representation of err into w and sets its // status code to http.StatusBadRequest. +// +// In case err is nil, a default error message will be used in its place. func BadRequest(w http.ResponseWriter, err error) { codedError(w, http.StatusBadRequest, err) } func codedError(w http.ResponseWriter, code int, err error) { + if err == nil { + err = errors.New(http.StatusText(code)) + } + var wrapper = struct { Status int `json:"status"` Message string `json:"message"` diff --git a/api/render/render_test.go b/api/render/render_test.go index f1d7c145a..8c7c99a4d 100644 --- a/api/render/render_test.go +++ b/api/render/render_test.go @@ -69,6 +69,11 @@ func TestErrors(t *testing.T) { code: http.StatusBadRequest, body: `{"status":400,"message":"assert.AnError general error for testing"}`, }, + 1: { + fn: BadRequest, + code: http.StatusBadRequest, + body: `{"status":400,"message":"Bad Request"}`, + }, } for caseIndex := range cases {