From c18bd7de7214091cffc3cd6e842db6577c6d5755 Mon Sep 17 00:00:00 2001 From: Javad Date: Sun, 9 Jun 2024 09:13:25 +0330 Subject: [PATCH] feat: support client headers request currently support client custom headers and convert to metadata.MD for grpc. --- _example/proto/echo_jgw.pb.go | 17 ++++++++-- jrpc/server.go | 44 ++++++++++++++++++++++++- jrpc/server_test.go | 16 +++++++-- protoc-gen-jrpc-gateway/internal/jgw.go | 18 ++++++++-- 4 files changed, 87 insertions(+), 8 deletions(-) diff --git a/_example/proto/echo_jgw.pb.go b/_example/proto/echo_jgw.pb.go index cb0dcdc..d66d928 100644 --- a/_example/proto/echo_jgw.pb.go +++ b/_example/proto/echo_jgw.pb.go @@ -13,6 +13,7 @@ import ( "encoding/json" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" ) @@ -20,6 +21,11 @@ type EchoServiceJsonRpcService struct { client EchoServiceClient } +type paramsAndHeaders struct { + Headers metadata.MD `json:"headers,omitempty"` + Params json.RawMessage `json:"params"` +} + // RegisterEchoServiceJsonRpcService register the grpc client EchoService for json-rpc. // The handlers forward requests to the grpc endpoint over "conn". func RegisterEchoServiceJsonRpcService(conn *grpc.ClientConn) *EchoServiceJsonRpcService { @@ -33,11 +39,18 @@ func (s *EchoServiceJsonRpcService) Methods() map[string]func(ctx context.Contex "proto.echo_service.echo": func(ctx context.Context, data json.RawMessage) (any, error) { req := new(EchoRequest) - err := protojson.Unmarshal(data, req) + + var jrpcData paramsAndHeaders + + if err := json.Unmarshal(data, &jrpcData); err != nil { + return nil, err + } + + err := protojson.Unmarshal(jrpcData.Params, req) if err != nil { return nil, err } - return s.client.Echo(ctx, req) + return s.client.Echo(metadata.NewOutgoingContext(ctx, jrpcData.Headers), req) }, } } diff --git a/jrpc/server.go b/jrpc/server.go index 3615eee..705465e 100644 --- a/jrpc/server.go +++ b/jrpc/server.go @@ -3,6 +3,9 @@ package jrpc import ( "context" "encoding/json" + "github.com/creachadair/jrpc2" + "google.golang.org/grpc/metadata" + "io" "net" "net/http" "time" @@ -22,6 +25,11 @@ type Server struct { handler http.Handler } +type paramsAndHeaders struct { + Headers metadata.MD `json:"headers,omitempty"` + Params json.RawMessage `json:"params"` +} + // NewServer create json rpc server func NewServer() *Server { sv := new(Server) @@ -58,9 +66,43 @@ func (s *Server) RegisterServices(svs ...Service) { hd[m] = handler.New(h) } } - s.handler = jhttp.NewBridge(hd, nil) + s.handler = jhttp.NewBridge(hd, &jhttp.BridgeOptions{ + ParseRequest: func(req *http.Request) ([]*jrpc2.ParsedRequest, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + prs, err := jrpc2.ParseRequests(body) + if err != nil { + return nil, err + } + + // Decorate the incoming request parameters with the headers. + for _, pr := range prs { + w, err := json.Marshal(paramsAndHeaders{ + Headers: headersToMetadata(req), + Params: pr.Params, + }) + if err != nil { + return nil, err + } + pr.Params = w + } + return prs, nil + }, + }) } func (s *Server) httpHandler(w http.ResponseWriter, r *http.Request) { s.handler.ServeHTTP(w, r) } + +func headersToMetadata(r *http.Request) metadata.MD { + headersMap := make(map[string]string) + for key, values := range r.Header { + if len(values) > 0 { + headersMap[key] = values[0] + } + } + return metadata.New(headersMap) +} diff --git a/jrpc/server_test.go b/jrpc/server_test.go index dbdcaa4..7b4f7f6 100644 --- a/jrpc/server_test.go +++ b/jrpc/server_test.go @@ -21,10 +21,16 @@ func (ms *MockService) Methods() map[string]method { } func (ms *MockService) testMethod(ctx context.Context, message json.RawMessage) (any, error) { + var ph paramsAndHeaders + if err := json.Unmarshal(message, &ph); err != nil { + return nil, err + } + var params map[string]string - if err := json.Unmarshal(message, ¶ms); err != nil { + if err := json.Unmarshal(ph.Params, ¶ms); err != nil { return nil, err } + return map[string]string{"response": "Hello " + params["name"]}, nil } @@ -34,15 +40,19 @@ func TestServer(t *testing.T) { mockService := &MockService{} server.RegisterServices(mockService) - // Create a listener for the server listener, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err, "Failed to create listener") + serverStarted := make(chan struct{}) + go func() { - err = server.Serve(listener) + close(serverStarted) + err := server.Serve(listener) assert.Error(t, err, "Server failed to serve") }() + <-serverStarted + requestBody, err := json.Marshal(map[string]any{ "jsonrpc": "2.0", "id": "1", diff --git a/protoc-gen-jrpc-gateway/internal/jgw.go b/protoc-gen-jrpc-gateway/internal/jgw.go index 7a5a57b..7eebca9 100644 --- a/protoc-gen-jrpc-gateway/internal/jgw.go +++ b/protoc-gen-jrpc-gateway/internal/jgw.go @@ -32,6 +32,7 @@ import ( "encoding/json" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/encoding/protojson" ) @@ -45,6 +46,11 @@ type {{$serviceName}} struct { client {{$clientName}} } +type paramsAndHeaders struct { + Headers metadata.MD ` + "`json:\"headers,omitempty\"`" + ` + Params json.RawMessage ` + "`json:\"params\"`" + ` +} + // {{$serviceName | printf "Register%s"}} register the grpc client {{$service.GetName}} for json-rpc. // The handlers forward requests to the grpc endpoint over "conn". func {{$serviceName | printf "Register%s"}} (conn *grpc.ClientConn) *{{$serviceName}} { @@ -58,11 +64,19 @@ func (s *{{$serviceName}}) Methods() map[string]func(ctx context.Context, messag {{range $method := $service.GetMethod}} "{{rpcMethod $package $service.GetName $method.GetName}}": func(ctx context.Context, data json.RawMessage) (any, error) { req := new({{methodInput $method.GetInputType}}) - err := protojson.Unmarshal(data, req) + + var jrpcData paramsAndHeaders + + if err := json.Unmarshal(data, &jrpcData); err != nil { + return nil, err + } + + err := protojson.Unmarshal(jrpcData.Params, req) if err != nil { return nil, err } - return s.client.{{$method.GetName}}(ctx, req) + + return s.client.{{$method.GetName}}(metadata.NewOutgoingContext(ctx, jrpcData.Headers), req) }, {{end}} }