Skip to content

Commit

Permalink
test(gateway): slight refactor of some presentation and add some endp…
Browse files Browse the repository at this point in the history
…oint registration tests (#2939)

goals: readability, test coverage

Appreciate that the gateway is v unimportant but is exists. I was
starting at the top just looking for places to give a clean up and add
some unit test coverage and found this.

Mostly trying to make it cleaner to read, then added some tests, also
wanted to add a `/status/health` endpoint just for ease, i would want
this as a node operator if i was running with `--gateway` for whatever
reason

Also took liberty to take some error handling code out of `writeError`
that would never execute. In trying to write a test to trigger the
json.Marshal error case, i realized that as err.Error() always returns a
string it could never trigger json.Unmarshal to fail without panic
before even getting there, so removed that code path.

i might revisit this and give it another refactor at some point soon,
unless we're planning to remove this completely?

on to next area
  • Loading branch information
ramin authored Jan 23, 2024
1 parent 8f768a7 commit c173a1e
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 43 deletions.
73 changes: 73 additions & 0 deletions api/gateway/bindings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package gateway

import (
"fmt"
"net/http"
)

func (h *Handler) RegisterEndpoints(rpc *Server) {
// state endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey),
h.handleBalanceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
submitTxEndpoint,
h.handleSubmitTx,
http.MethodPost,
)

rpc.RegisterHandlerFunc(
healthEndpoint,
h.handleHealthRequest,
http.MethodGet,
)

// share endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf(
"%s/{%s}/height/{%s}",
namespacedSharesEndpoint,
namespaceKey,
heightKey,
),
h.handleSharesByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey),
h.handleSharesByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey),
h.handleDataByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey),
h.handleDataByNamespaceRequest,
http.MethodGet,
)

// DAS endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", heightAvailabilityEndpoint, heightKey),
h.handleHeightAvailabilityRequest,
http.MethodGet,
)

// header endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", headerByHeightEndpoint, heightKey),
h.handleHeaderRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(headEndpoint, h.handleHeadRequest, http.MethodGet)
}
119 changes: 119 additions & 0 deletions api/gateway/bindings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package gateway

import (
"fmt"
"net/http"
"testing"

"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
)

func TestRegisterEndpoints(t *testing.T) {
handler := &Handler{}
rpc := NewServer("localhost", "6969")

handler.RegisterEndpoints(rpc)

testCases := []struct {
name string
path string
method string
expected bool
}{
{
name: "Get balance endpoint",
path: fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey),
method: http.MethodGet,
expected: true,
},
{
name: "Submit transaction endpoint",
path: submitTxEndpoint,
method: http.MethodPost,
expected: true,
},
{
name: "Get namespaced shares by height endpoint",
path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedSharesEndpoint, namespaceKey, heightKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced shares endpoint",
path: fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced data by height endpoint",
path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced data endpoint",
path: fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get health endpoint",
path: "/status/health",
method: http.MethodGet,
expected: true,
},

// Going forward, we can add previously deprecated and since
// removed endpoints here to ensure we don't accidentally re-enable
// them in the future and accidentally expand surface area
{
name: "example totally bogus endpoint",
path: fmt.Sprintf("/wutang/{%s}/%s", "chambers", "36"),
method: http.MethodGet,
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(
t,
tc.expected,
hasEndpointRegistered(rpc.Router(), tc.path, tc.method),
"Endpoint registration mismatch for: %s %s %s", tc.name, tc.method, tc.path)
})
}
}

func hasEndpointRegistered(router *mux.Router, path string, method string) bool {
var registered bool
err := router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
template, err := route.GetPathTemplate()
if err != nil {
return err
}

if template == path {
methods, err := route.GetMethods()
if err != nil {
return err
}

for _, m := range methods {
if m == method {
registered = true
return nil
}
}
}
return nil
})

if err != nil {
fmt.Println("Error walking through routes:", err)
return false
}

return registered
}
32 changes: 0 additions & 32 deletions api/gateway/endpoints.go

This file was deleted.

16 changes: 16 additions & 0 deletions api/gateway/health.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package gateway

import "net/http"

const (
healthEndpoint = "/status/health"
)

func (h *Handler) handleHealthRequest(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("ok"))
if err != nil {
log.Errorw("serving request", "endpoint", healthEndpoint, "err", err)
writeError(w, http.StatusBadGateway, healthEndpoint, err)
return
}
}
4 changes: 4 additions & 0 deletions api/gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func NewServer(address, port string) *Server {
return server
}

func (s *Server) Router() *mux.Router {
return s.srvMux
}

// Start starts the gateway Server, listening on the given address.
func (s *Server) Start(context.Context) error {
couldStart := s.started.CompareAndSwap(false, true)
Expand Down
16 changes: 7 additions & 9 deletions api/gateway/util.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
package gateway

import (
"encoding/json"
"net/http"
)

func writeError(w http.ResponseWriter, statusCode int, endpoint string, err error) {
log.Debugw("serving request", "endpoint", endpoint, "err", err)

w.WriteHeader(statusCode)
errBody, jerr := json.Marshal(err.Error())
if jerr != nil {
log.Errorw("serializing error", "endpoint", endpoint, "err", jerr)
return
}
_, werr := w.Write(errBody)
if werr != nil {
log.Errorw("writing error response", "endpoint", endpoint, "err", werr)

errorMessage := err.Error() // Get the error message as a string
errorBytes := []byte(errorMessage)

_, err = w.Write(errorBytes)
if err != nil {
log.Errorw("writing error response", "endpoint", endpoint, "err", err)
}
}
24 changes: 24 additions & 0 deletions api/gateway/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package gateway

import (
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestWriteError(t *testing.T) {
t.Run("writeError", func(t *testing.T) {
// Create a mock HTTP response writer
w := httptest.NewRecorder()

testErr := errors.New("test error")

writeError(w, http.StatusInternalServerError, "/api/endpoint", testErr)
assert.Equal(t, http.StatusInternalServerError, w.Code)
responseBody := w.Body.Bytes()
assert.Equal(t, testErr.Error(), string(responseBody))
})
}
4 changes: 2 additions & 2 deletions cmd/celestia/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestCompletionHelpString(t *testing.T) {
}
methods := reflect.VisibleFields(reflect.TypeOf(TestFields{}))
for i, method := range methods {
require.Equal(t, testOutputs[i], parseSignatureForHelpstring(method))
require.Equal(t, testOutputs[i], parseSignatureForHelpString(method))
}
}

Expand Down Expand Up @@ -129,7 +129,7 @@ func TestBridge(t *testing.T) {
*/
}

func parseSignatureForHelpstring(methodSig reflect.StructField) string {
func parseSignatureForHelpString(methodSig reflect.StructField) string {
simplifiedSignature := "("
in, out := methodSig.Type.NumIn(), methodSig.Type.NumOut()
for i := 1; i < in; i++ {
Expand Down

0 comments on commit c173a1e

Please sign in to comment.