diff --git a/go.mod b/go.mod index 4f630dfa..bf89512c 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/bavix/features v1.0.0 github.com/bavix/gripmock-sdk-go v1.0.4 github.com/bavix/gripmock-ui v1.0.0-alpha5 - github.com/bavix/gripmock/protogen v0.0.0-20240706174427-ef324cdfb46b + github.com/bavix/gripmock/protogen v0.0.0-20240706201937-fc1e72a8ad5f github.com/cristalhq/base64 v0.1.2 github.com/goccy/go-yaml v1.11.3 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 64ce8311..b6871aa2 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/bavix/gripmock-sdk-go v1.0.4 h1:FDBlusqVFoy5Yo49khztYqfVt9+NSEf1mIl1n github.com/bavix/gripmock-sdk-go v1.0.4/go.mod h1:/1cmn8VuN6Pc7ttMejqXLYpvf1CJF08ezoEA9lJIZiU= github.com/bavix/gripmock-ui v1.0.0-alpha5 h1:+2vWLZPeGGrpBSENWXIfyf6bwN+Pou+XucX0XlccLQo= github.com/bavix/gripmock-ui v1.0.0-alpha5/go.mod h1:XEH4YYEKL+wEDtONntoWm6JxjbVWzl7XtDYztUTBfeA= -github.com/bavix/gripmock/protogen v0.0.0-20240706174427-ef324cdfb46b h1:YCLXlvREBDiqSZX8D4e0DtPQB0jm55cl4YsllmB0B4k= -github.com/bavix/gripmock/protogen v0.0.0-20240706174427-ef324cdfb46b/go.mod h1:ARIfXpB9cyL9jIr7C1yrhwnb3wCSrewPNLdyG4URmJk= +github.com/bavix/gripmock/protogen v0.0.0-20240706201937-fc1e72a8ad5f h1:B/nZWWeQRXb4SQFHQWw45cMAnZ6eu/T6TAkJFpTqLHw= +github.com/bavix/gripmock/protogen v0.0.0-20240706201937-fc1e72a8ad5f/go.mod h1:ARIfXpB9cyL9jIr7C1yrhwnb3wCSrewPNLdyG4URmJk= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bufbuild/protocompile v0.14.0 h1:z3DW4IvXE5G/uTOnSQn+qwQQxvhckkTWLS/0No/o7KU= github.com/bufbuild/protocompile v0.14.0/go.mod h1:N6J1NYzkspJo3ZwyL4Xjvli86XOj1xq4qAasUFxGups= diff --git a/go.work b/go.work index 86a0c478..b16d4853 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.22.2 +go 1.22.5 use ( . diff --git a/internal/app/rest_server.go b/internal/app/rest_server.go index 51add7e4..55f59981 100644 --- a/internal/app/rest_server.go +++ b/internal/app/rest_server.go @@ -8,6 +8,7 @@ import ( "log" "net/http" "os" + "path" "strings" "sync/atomic" "time" @@ -30,10 +31,10 @@ var ( ) type RestServer struct { + ok atomic.Bool stuber *stuber.Budgerigar convertor *yaml2json.Convertor caser cases.Caser - ok atomic.Bool reflector *grpcreflector.GReflector } @@ -49,15 +50,12 @@ func NewRestServer(path string, reflector *grpcreflector.GReflector) (*RestServe if path != "" { server.readStubs(path) // TODO: someday you will need to rewrite this code + server.ok.Store(true) } return server, nil } -func (h *RestServer) ServiceReady() { - h.ok.Store(true) -} - func (h *RestServer) ServicesList(w http.ResponseWriter, r *http.Request) { services, err := h.reflector.Services(r.Context()) if err != nil { @@ -288,49 +286,66 @@ func (h *RestServer) writeResponseError(err error, w http.ResponseWriter) { } } -func (h *RestServer) readStubs(path string) { - files, err := os.ReadDir(path) +// readStubs reads all the stubs from the given directory and its subdirectories, +// and adds them to the server's stub store. +// The stub files can be in yaml or json format. +// If a file is in yaml format, it will be converted to json format. +func (h *RestServer) readStubs(pathDir string) { + files, err := os.ReadDir(pathDir) if err != nil { - log.Printf("Can't read stub from %s. %v\n", path, err) + log.Printf("can't read stubs from %s: %v", pathDir, err) return } for _, file := range files { + // If the file is a directory, recursively read its stubs. if file.IsDir() { - h.readStubs(path + "/" + file.Name()) + h.readStubs(path.Join(pathDir, file.Name())) continue } - byt, err := os.ReadFile(path + "/" + file.Name()) + // Read the stub file and add it to the server's stub store. + stubs, err := h.readStub(path.Join(pathDir, file.Name())) if err != nil { - log.Printf("Error when reading file %s. %v. skipping...", file.Name(), err) + log.Printf("cant read stubs from %s: %v", file.Name(), err) continue } - if strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") { - byt, err = h.convertor.Execute(file.Name(), byt) - if err != nil { - log.Printf("Error when unmarshalling file %s. %v. skipping...", file.Name(), err) - - continue - } - } - - var storageStubs []*stuber.Stub + h.stuber.PutMany(stubs...) + } +} - if err = jsondecoder.UnmarshalSlice(byt, &storageStubs); err != nil { - log.Printf("Error when unmarshalling file %s. %v %v. skipping...", file.Name(), string(byt), err) +// readStub reads a stub file and returns a slice of stubs. +// The stub file can be in yaml or json format. +// If the file is in yaml format, it will be converted to json format. +func (h *RestServer) readStub(path string) ([]*stuber.Stub, error) { + // Read the file + byt, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("error when reading file %s: %w", path, err) + } - continue + // If the file is in yaml format, convert it to json format + if strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml") { + byt, err = h.convertor.Execute(path, byt) + if err != nil { + return nil, fmt.Errorf("error when unmarshalling file %s: %w", path, err) } + } - h.stuber.PutMany(storageStubs...) + // Unmarshal the json into a slice of stubs + var stubs []*stuber.Stub + if err := jsondecoder.UnmarshalSlice(byt, &stubs); err != nil { + return nil, fmt.Errorf("error when unmarshalling file %s: %v %w", path, string(byt), err) } + + return stubs, nil } +// validateStub validates if the stub is valid or not. func validateStub(stub *stuber.Stub) error { if stub.Service == "" { return ErrServiceIsMissing @@ -353,12 +368,9 @@ func validateStub(stub *stuber.Stub) error { return fmt.Errorf("input cannot be empty") } - // TODO: validate all input case - if stub.Output.Error == "" && stub.Output.Data == nil && stub.Output.Code == nil { - // fixme //nolint:goerr113,perfsprint - return fmt.Errorf("output can't be empty") + return fmt.Errorf("output cannot be empty") } return nil diff --git a/internal/pkg/muxmiddleware/logger.go b/internal/pkg/muxmiddleware/logger.go index 7e63d71a..c1ef12af 100644 --- a/internal/pkg/muxmiddleware/logger.go +++ b/internal/pkg/muxmiddleware/logger.go @@ -11,33 +11,34 @@ import ( "github.com/bavix/gripmock/pkg/jsondecoder" ) +// RequestLogger logs the request and response. func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger := zerolog.Ctx(r.Context()) ww := &responseWriter{w: w, status: http.StatusOK} ip, err := getIP(r) - now := time.Now() + start := time.Now() bodyBytes, _ := io.ReadAll(r.Body) - r.Body.Close() // must close + r.Body.Close() r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) next.ServeHTTP(ww, r) - event := logger.Info().Err(err). + event := logger.Info(). + Err(err). IPAddr("ip", ip). Str("method", r.Method). - Str("url", r.URL.RequestURI()) + Str("url", r.URL.RequestURI()). + Dur("elapsed", time.Since(start)). + Str("ua", r.UserAgent()). + Int("bytes", ww.bytesWritten). + Int("code", ww.status) if err := jsondecoder.UnmarshalSlice(bodyBytes, nil); err == nil { event.RawJSON("input", bodyBytes) } - event. - Dur("elapsed", time.Since(now)). - Str("ua", r.UserAgent()). - Int("bytes", ww.bytes). - Int("code", ww.status). - Send() + event.Send() }) } diff --git a/internal/pkg/muxmiddleware/resp_writer.go b/internal/pkg/muxmiddleware/resp_writer.go index 811b6680..e3d79bda 100644 --- a/internal/pkg/muxmiddleware/resp_writer.go +++ b/internal/pkg/muxmiddleware/resp_writer.go @@ -1,28 +1,28 @@ package muxmiddleware -import "net/http" +import ( + "net/http" +) type responseWriter struct { w http.ResponseWriter - status int - bytes int + status int + bytesWritten int } -func (r *responseWriter) Header() http.Header { - return r.w.Header() +func (rw *responseWriter) Header() http.Header { + return rw.w.Header() } -func (r *responseWriter) Write(bytes []byte) (int, error) { - n, err := r.w.Write(bytes) - - r.bytes += n +func (rw *responseWriter) Write(bytes []byte) (int, error) { + n, err := rw.w.Write(bytes) + rw.bytesWritten += n return n, err } -func (r *responseWriter) WriteHeader(statusCode int) { - r.w.WriteHeader(statusCode) - - r.status = statusCode +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.status = statusCode + rw.w.WriteHeader(statusCode) } diff --git a/internal/pkg/muxmiddleware/utils.go b/internal/pkg/muxmiddleware/utils.go index b52c37f2..0f0691f9 100644 --- a/internal/pkg/muxmiddleware/utils.go +++ b/internal/pkg/muxmiddleware/utils.go @@ -6,22 +6,24 @@ import ( "strings" ) +// getIP returns the IP address from the request headers. +// It returns the IP address from the X-Forwarded-For header if it exists, +// otherwise it returns the IP address from the RemoteAddr field. func getIP(r *http.Request) (net.IP, error) { - ips := r.Header.Get("X-Forwarded-For") - splitIps := strings.Split(ips, ",") - - if len(splitIps) > 0 { - netIP := net.ParseIP(splitIps[len(splitIps)-1]) - - if netIP != nil { - return netIP, nil + if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { + ips := strings.Split(forwardedFor, ",") + if len(ips) > 0 { + ip := strings.TrimSpace(ips[len(ips)-1]) + if parsedIP := net.ParseIP(ip); parsedIP != nil { + return parsedIP, nil + } } } - ip, _, err := net.SplitHostPort(r.RemoteAddr) + host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return nil, err } - return net.ParseIP(ip), nil + return net.ParseIP(host), nil } diff --git a/main.go b/main.go index c1f444a7..6ebb679d 100644 --- a/main.go +++ b/main.go @@ -2,10 +2,8 @@ package main import ( "context" - "errors" "flag" "io" - "io/fs" "log" "os" "os/exec" @@ -14,15 +12,12 @@ import ( "path/filepath" "strings" "syscall" - "time" _ "github.com/gripmock/grpc-interceptors" "github.com/rs/zerolog" _ "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" _ "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "google.golang.org/grpc" _ "google.golang.org/grpc/health" - healthv1 "google.golang.org/grpc/health/grpc_health_v1" _ "github.com/bavix/gripmock-sdk-go" "github.com/bavix/gripmock/internal/pkg/patcher" @@ -57,104 +52,129 @@ func main() { // parse proto files protoPaths := flag.Args() + // Ensure at least one proto file is provided if len(protoPaths) == 0 { logger.Fatal().Msg("at least one proto file is required") } + // Start GripMock server //nolint:godox - // fixme: move validation of required arguments to a separate service + // TODO: move validation of required arguments to a separate service logger.Info().Str("release", version).Msg("Starting GripMock") + // Check if $GOPATH is set if os.Getenv("GOPATH") == "" { logger.Fatal().Msg("$GOPATH is empty") } + // Set output directory output := *outputPointer if output == "" { + // Default to $GOPATH/src/grpc if output is not provided output = os.Getenv("GOPATH") + "/src/grpc" } - // for safety + // For safety output += "/" - if _, err := os.Stat(output); errors.Is(err, fs.ErrNotExist) { + + // Check if output folder exists, if not create it + // nosemgrep:semgrep-go.os-error-is-not-exist + if _, err := os.Stat(output); os.IsNotExist(err) { + // Create output folder if err := os.Mkdir(output, os.ModePerm); err != nil { - logger.Fatal().Err(err).Msg("can't create output folder") + logger.Fatal().Err(err).Msg("unable to create output folder") } } - chReady := make(chan struct{}) - defer close(chReady) - - // run admin stub server - stub.RunRestServer(ctx, chReady, *stubPath, builder.Config(), builder.Reflector()) + // Run the admin stub server in a separate goroutine. + // + // This goroutine runs the REST server that serves the stub files. + // It waits for the ready signal from the gRPC server goroutine. + // Once the gRPC server is ready, it starts the admin stub server. + go func() { + stub.RunRestServer(ctx, *stubPath, builder.Config(), builder.Reflector()) + }() importDirs := strings.Split(*imports, ",") - // generate pb.go and grpc server based on proto + // Generate protoc-generated code and run the gRPC server. + // + // This section generates the protoc-generated code (pb.go) and runs the gRPC server. + // It creates the output directory if it does not exist. + // It then generates the protoc-generated code using the protocParam struct. + // Finally, it runs the gRPC server using the runGrpcServer function. generateProtoc(ctx, protocParam{ protoPath: protoPaths, output: output, imports: importDirs, }) - // and run + // And run run, chErr := runGrpcServer(ctx, output) - // This is a kind of crutch, but now there is no other solution. - // I have an idea to combine gripmock and grpcmock services into one, then this check will be easier to do. - // Checking the grpc port of the service. If the port appears, the service has started successfully. - go func() { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - waiter := healthv1.NewHealthClient(builder.GRPCClient()) - - check, err := waiter.Check(ctx, &healthv1.HealthCheckRequest{Service: ""}, grpc.WaitForReady(true)) - if err != nil { - return - } - - if check.GetStatus() == healthv1.HealthCheckResponse_SERVING { - chReady <- struct{}{} - } - }() - + // Wait for the gRPC server to exit or the context to be done. select { case err := <-chErr: - log.Fatal(err) + // If the gRPC server exits with an error, log the error. + logger.Fatal().Err(err).Msg("gRPC server exited with an error") case <-ctx.Done(): + // If the context is done, check if there was an error. if err := ctx.Err(); err != nil { logger.Err(err).Msg("an error has occurred") } + // Log that the gRPC server is stopping. logger.Info().Msg("Stopping gRPC Server") + // Kill the gRPC server process. if err := run.Process.Kill(); err != nil { - logger.Fatal().Err(err).Msg("process killed") + logger.Fatal().Err(err).Msg("failed to kill process") } } } +// protocParam represents the parameters for the protoc command. type protocParam struct { + // protoPath is a list of paths to the proto files. protoPath []string - output string - imports []string + + // output is the output directory for the generated files. + output string + + // imports is a list of import paths. + imports []string } +// getProtodirs returns a list of proto directories based on the given protoPath +// and imports. +// +// It takes a context.Context and a protoPath string as well as a slice of strings +// representing the imports. The protoPath string is used to deduce the proto +// directory, and the imports are used to search for a proto directory prefix. +// +// The function returns a slice of strings representing the proto directories. func getProtodirs(_ context.Context, protoPath string, imports []string) []string { - // deduced proto dir from proto path + // Deduce the proto directory from the proto path. splitPath := strings.Split(protoPath, "/") protoDir := "" + // If there are any elements in splitPath, join them up to the second-to-last + // element with path.Join to get the proto directory. if len(splitPath) > 0 { protoDir = path.Join(splitPath[:len(splitPath)-1]...) } - // search protoDir prefix + // Search for the proto directory prefix in the imports. protoDirIdx := -1 for i := range imports { + // Join the "protogen" directory with the import directory to get the full + // directory path. dir := path.Join("protogen", imports[i]) + + // If the proto directory starts with the full directory path, set the proto + // directory to the full directory path and set the index of the proto directory + // in the imports slice. if strings.HasPrefix(protoDir, dir) { protoDir = dir protoDirIdx = i @@ -163,9 +183,14 @@ func getProtodirs(_ context.Context, protoPath string, imports []string) []strin } } + // Create a slice to hold the proto directories. protoDirs := make([]string, 0, len(imports)+1) + + // Append the proto directory to the slice. protoDirs = append(protoDirs, protoDir) - // include all dir in imports, skip if it has been added before + + // Loop through the imports and append each directory to the slice, skipping + // any directories that have already been added. for i, dir := range imports { if i == protoDirIdx { continue @@ -174,37 +199,72 @@ func getProtodirs(_ context.Context, protoPath string, imports []string) []strin protoDirs = append(protoDirs, dir) } + // Return the slice of proto directories. return protoDirs } +// generateProtoc is a function that runs the protoc command with the given +// parameters. +// +// It takes a context.Context and a protocParam struct as parameters. The +// protocParam struct contains the protoPath, output, and imports fields that +// are used to configure the protoc command. +// +// It generates the protoc command arguments and runs the protoc command with +// the given parameters. If there is an error running the protoc command, it +// logs a fatal message. func generateProtoc(ctx context.Context, param protocParam) { + // Fix the go_package option for each proto file in the protoPath. param.protoPath = fixGoPackage(ctx, param.protoPath) + + // Get the proto directories based on the protoPath and imports. protodirs := getProtodirs(ctx, param.protoPath[0], param.imports) - // estimate args length to prevent expand + // Estimate the length of the args slice to prevent expanding it. args := make([]string, 0, len(protodirs)+len(param.protoPath)+2) //nolint:mnd + + // Append the -I option for each proto directory to the args slice. for _, dir := range protodirs { args = append(args, "-I", dir) } - // the latest go-grpc plugin will generate subfolders under $GOPATH/src based on go_package option + // Set the output directory for generated files to $GOPATH/src. pbOutput := os.Getenv("GOPATH") + "/src" + // Append the protoPath, --go_out, --go-grpc_out, and --gripmock_out options + // to the args slice. args = append(args, param.protoPath...) args = append(args, "--go_out="+pbOutput) args = append(args, "--go-grpc_out="+pbOutput) args = append(args, "--gripmock_out="+param.output) + + // Create a new exec.Cmd command with the protoc command and the args. protoc := exec.Command("protoc", args...) + + // Set the environment variables for the command. protoc.Env = os.Environ() + + // Set the stdout and stderr for the command. protoc.Stdout = os.Stdout protoc.Stderr = os.Stderr + // Run the protoc command and log a fatal message if there is an error. if err := protoc.Run(); err != nil { zerolog.Ctx(ctx).Fatal().Err(err).Msg("fail on protoc") } } -// append gopackage in proto files if doesn't have any. +// fixGoPackage is a function that appends the go_package option to each +// proto file in the given protoPaths if the proto file doesn't already have +// one. +// +// It reads each proto file, creates a temporary file with the go_package option, +// and copies the contents of the original file to the temporary file. The +// temporary file is then returned as part of the results. +// +// ctx is the context.Context to use for the function. +// protoPaths is a slice of string paths to the proto files. +// fixGoPackage returns a slice of string paths to the temporary files. func fixGoPackage(ctx context.Context, protoPaths []string) []string { results := make([]string, 0, len(protoPaths)) @@ -247,21 +307,32 @@ func fixGoPackage(ctx context.Context, protoPaths []string) []string { return results } +// runGrpcServer runs the gRPC server in a separate process. +// +// ctx is the context.Context to use for the command. +// output is the output directory where the server.go file is located. +// It returns the exec.Cmd object representing the running process, and a channel +// that receives an error when the process exits. func runGrpcServer(ctx context.Context, output string) (*exec.Cmd, <-chan error) { + // Construct the command to run the gRPC server. run := exec.CommandContext(ctx, "go", "run", output+"server.go") //nolint:gosec run.Env = os.Environ() run.Stdout = os.Stdout run.Stderr = os.Stderr - err := run.Start() - if err != nil { - zerolog.Ctx(ctx).Fatal().Err(err).Msg("unable to start grpc service") + // Start the command. + if err := run.Start(); err != nil { + zerolog.Ctx(ctx).Fatal().Err(err).Msg("unable to start gRPC service") } + // Log the process ID. zerolog.Ctx(ctx).Info().Int("pid", run.Process.Pid).Msg("gRPC-service started") + // Create a channel to receive the process exit error. runErr := make(chan error) + // Start a goroutine to wait for the process to exit and send the error + // to the channel. go func() { runErr <- run.Wait() }() diff --git a/pkg/grpccontext/interceptor.go b/pkg/grpccontext/interceptor.go index 4bbc97da..b52e1841 100644 --- a/pkg/grpccontext/interceptor.go +++ b/pkg/grpccontext/interceptor.go @@ -8,13 +8,23 @@ import ( "google.golang.org/grpc/metadata" ) +// UnaryInterceptor is a gRPC interceptor that adds a logger to the context. +// The logger can be used to log messages related to the gRPC request. +// +// It takes a logger as a parameter and returns a grpc.UnaryServerInterceptor. +// The returned interceptor is used to intercept the gRPC unary requests. func UnaryInterceptor(logger *zerolog.Logger) grpc.UnaryServerInterceptor { + // The interceptor function is called for each gRPC unary request. + // It takes the inner context, the request, the server info, and the handler. + // It returns the response and an error. return func( - innerCtx context.Context, - req interface{}, - _ *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, + innerCtx context.Context, // The context of the gRPC request. + req interface{}, // The request object. + _ *grpc.UnaryServerInfo, // The server info. + handler grpc.UnaryHandler, // The handler function for the request. ) (interface{}, error) { + // Add the logger to the context and call the handler. + // The logger can be accessed using grpc.GetLogger(ctx). return handler(logger.WithContext(innerCtx), req) } } @@ -32,12 +42,23 @@ func (w serverStreamWrapper) SetHeader(md metadata.MD) error { return w.ss.SetH func (w serverStreamWrapper) SetTrailer(md metadata.MD) { w.ss.SetTrailer(md) } func StreamInterceptor(logger *zerolog.Logger) grpc.StreamServerInterceptor { + // StreamInterceptor is a gRPC interceptor that adds a logger to the context. + // The logger can be used to log messages related to the gRPC stream. + // + // It takes a logger as a parameter and returns a grpc.StreamServerInterceptor. + // The returned interceptor is used to intercept the gRPC stream requests. + // + // The interceptor function is called for each gRPC stream request. + // It takes the server, the stream, the server info, and the handler. + // It returns an error. return func( - srv interface{}, - ss grpc.ServerStream, - _ *grpc.StreamServerInfo, - handler grpc.StreamHandler, + srv interface{}, // The server object. + ss grpc.ServerStream, // The stream object. + _ *grpc.StreamServerInfo, // The server info. + handler grpc.StreamHandler, // The handler function for the stream. ) error { + // Create a serverStreamWrapper object with the stream and context. + // The context is created with the logger. return handler(srv, serverStreamWrapper{ ss: ss, ctx: logger.WithContext(ss.Context()), diff --git a/pkg/grpcreflector/reflect.go b/pkg/grpcreflector/reflect.go index 51bd2d0a..22a3a063 100644 --- a/pkg/grpcreflector/reflect.go +++ b/pkg/grpcreflector/reflect.go @@ -8,31 +8,40 @@ import ( "google.golang.org/grpc" ) -const prefix = "grpc.reflection.v1" +// GReflector is a client for the gRPC reflection API. +// It provides methods to list services and methods available on a gRPC server. +type GReflector struct { + conn *grpc.ClientConn // grpc connection to the server +} +// Service represents a gRPC service. type Service struct { - ID string - Package string - Name string + ID string // service ID + Package string // service package + Name string // service name } +// Method represents a gRPC method. type Method struct { - ID string - Name string + ID string // method ID + Name string // method name } -type GReflector struct { - conn *grpc.ClientConn -} +const prefix = "grpc.reflection.v1" +// New creates a new GReflector with the given grpc connection. func New(conn *grpc.ClientConn) *GReflector { return &GReflector{conn: conn} } +// client returns a new gRPC reflection client. +// It uses the given context and the grpc connection of the GReflector. func (g *GReflector) client(ctx context.Context) *grpcreflect.Client { return grpcreflect.NewClientAuto(ctx, g.conn) } +// makeService creates a Service struct from a service ID. +// The service ID is split into its package and name parts. func (g *GReflector) makeService(serviceID string) Service { const sep = "." @@ -45,6 +54,8 @@ func (g *GReflector) makeService(serviceID string) Service { } } +// makeMethod creates a Method struct from a service ID and method name. +// The method ID is created by concatenating the service ID and method name with a slash. func (g *GReflector) makeMethod(serviceID, method string) Method { return Method{ ID: serviceID + "/" + method, @@ -52,6 +63,8 @@ func (g *GReflector) makeMethod(serviceID, method string) Method { } } +// Services lists all services available on the gRPC server. +// It uses the gRPC reflection client to get the list of services and filters out the reflection service. func (g *GReflector) Services(ctx context.Context) ([]Service, error) { services, err := g.client(ctx).ListServices() if err != nil { @@ -69,6 +82,8 @@ func (g *GReflector) Services(ctx context.Context) ([]Service, error) { return results, nil } +// Methods lists all methods available on a service. +// It uses the gRPC reflection client to resolve the service and filter out the reflection methods. func (g *GReflector) Methods(ctx context.Context, serviceID string) ([]Method, error) { dest, err := g.client(ctx).ResolveService(serviceID) if err != nil { diff --git a/pkg/jsondecoder/decode.go b/pkg/jsondecoder/decode.go index bdef2e05..cdc7e87d 100644 --- a/pkg/jsondecoder/decode.go +++ b/pkg/jsondecoder/decode.go @@ -5,15 +5,43 @@ import ( "encoding/json" ) -//nolint:mnd +const minJSONLength = 2 + +// UnmarshalSlice is a function that parses JSON data into a slice of the provided interface. +// It handles the case where the input data is not a JSON array by wrapping it in an array. +// +// Examples: +// +// data := []byte(`{"name": "Bob"}`) +// var result []map[string]interface{} +// err := UnmarshalSlice(data, &result) +// // result is now [{"name": "Bob"}] +// +// data := []byte(`{"name": "Bob"}`) +// var result []map[string]string +// err := UnmarshalSlice(data, &result) +// // result is now [{"name": "Bob"}] +// +// data := []byte(`{"name": "Bob"}`) +// var result []interface{} +// err := UnmarshalSlice(data, &result) +// // result is now [{"name": "Bob"}] +// +// data := []byte(`{"name": "Bob"}`) +// var result []map[string]string +// err := UnmarshalSlice(data, &result) +// // result is now [{"name": "Bob"}] +// // NOTE: if the input data is not a JSON array, it is wrapped in an array before decoding func UnmarshalSlice(data []byte, v interface{}) error { input := bytes.TrimSpace(data) - // input[0] == "{" AND input[len(input)-1] == "}" - if bytes.HasPrefix(input, []byte{123}) && - bytes.HasSuffix(input, []byte{125}) { - // "[${input}]" - input = append(append([]byte{91}, input...), 93) + if len(input) < minJSONLength { + return &json.SyntaxError{} + } + + // If the input is not a JSON array, wrap it in an array + if len(input) > 0 && input[0] == '{' && input[len(input)-1] == '}' { + input = append(append([]byte{'['}, input...), ']') } decoder := json.NewDecoder(bytes.NewReader(input)) diff --git a/pkg/yaml2json/convertor.go b/pkg/yaml2json/convertor.go index aa66fb28..c7328414 100644 --- a/pkg/yaml2json/convertor.go +++ b/pkg/yaml2json/convertor.go @@ -12,11 +12,20 @@ func New() *Convertor { return &Convertor{engine: &engine{}} } +// Execute executes the given YAML data and returns the JSON representation. +// +// It takes a name and data as input parameters. +// The name parameter is used as a reference for the execution. +// The data parameter is the YAML data to be executed. +// +// It returns a byte slice and an error. +// The byte slice contains the JSON representation of the executed YAML data. +// The error is non-nil if there was an error during the execution. func (t *Convertor) Execute(name string, data []byte) ([]byte, error) { - bytes, err := t.engine.Execute(name, data) + jsonData, err := t.engine.Execute(name, data) if err != nil { return nil, err } - return yaml.YAMLToJSON(bytes) + return yaml.YAMLToJSON(jsonData) } diff --git a/pkg/yaml2json/engine.go b/pkg/yaml2json/engine.go index 2504ed32..7070b887 100644 --- a/pkg/yaml2json/engine.go +++ b/pkg/yaml2json/engine.go @@ -11,21 +11,28 @@ import ( type engine struct{} +// Execute executes a Go template with the given name and data. +// It returns the generated bytes and an error if any. func (e *engine) Execute(name string, data []byte) ([]byte, error) { - var buffer bytes.Buffer + // Execute a Go template with the given name and data. + // It returns the generated bytes and an error if any. + t := template.New(name).Funcs(e.funcMap()) - parse, err := template.New(name).Funcs(e.funcMap()).Parse(string(data)) + t, err := t.Parse(string(data)) if err != nil { return nil, err } - if err := parse.Execute(&buffer, nil); err != nil { + var buffer bytes.Buffer + if err := t.Execute(&buffer, nil); err != nil { return nil, err } return buffer.Bytes(), nil } +// uuid2int64 converts a UUID string to a map of two int64 values. +// It returns an empty string if there is an error. func (e *engine) uuid2int64(str string) string { v := e.uuid2bytes(str) _ = v[15] @@ -38,22 +45,25 @@ func (e *engine) uuid2int64(str string) string { low := int64(v[8]) | int64(v[9])<<8 | int64(v[10])<<16 | int64(v[11])<<24 | int64(v[12])<<32 | int64(v[13])<<40 | int64(v[14])<<48 | int64(v[15])<<56 - var buffer bytes.Buffer - - if err := json.NewEncoder(&buffer).Encode(map[string]int64{ + int64Data := map[string]int64{ "high": high, "low": low, - }); err != nil { + } + + int64JSON, err := json.Marshal(int64Data) + if err != nil { return "" } - return buffer.String() + return string(int64JSON) } +// uuid2base64 converts a UUID string to a base64 string. func (e *engine) uuid2base64(input string) string { return e.bytes2base64(e.uuid2bytes(input)) } +// uuid2bytes converts a UUID string to a byte slice. func (e *engine) uuid2bytes(input string) []byte { v := uuid.MustParse(input) diff --git a/protoc-gen-gripmock/generator.go b/protoc-gen-gripmock/generator.go index 5e5bc624..465b82e1 100644 --- a/protoc-gen-gripmock/generator.go +++ b/protoc-gen-gripmock/generator.go @@ -1,4 +1,5 @@ -package main +// Package main contains the main function for the protoc-gen-gripmock generator. +package main // import "github.com/bavix/gripmock/protoc-gen-gripmock" import ( "bytes" @@ -19,10 +20,16 @@ import ( "google.golang.org/protobuf/types/pluginpb" ) -func main() { - // Tip of the hat to Tim Coulson - // https://medium.com/@tim.r.coulson/writing-a-protoc-plugin-with-google-golang-org-protobuf-cd5aa75f5777 +// This package contains the implementation of protoc-gen-gripmock. +// It uses the protoc tool to generate a gRPC mock server from a .proto file. +// +// This package is generated by the go generate tag and should not be edited +// by hand. +// The main function is the entry point for the protoc-gen-gripmock generator. +// It reads input from stdin, unmarshals the request, and creates a new +// CodeGenerator object. It then generates the gRPC mock server in server.go. +func main() { // Protoc passes pluginpb.CodeGeneratorRequest in via stdin // marshalled with Protobuf input, _ := io.ReadAll(os.Stdin) @@ -40,20 +47,25 @@ func main() { plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + // Create a slice of FileDescriptorProto objects for each input file. protos := make([]*descriptorpb.FileDescriptorProto, len(plugin.Files)) for index, file := range plugin.Files { protos[index] = file.Proto } + // Generate the gRPC mock server using the input files. buf := new(bytes.Buffer) err = generateServer(protos, &Options{ - writer: buf, + Writer: buf, }) if err != nil { log.Fatalf("Failed to generate server %v", err) } - file := plugin.NewGeneratedFile("server.go", ".") + // Create a new GeneratedFile with the name "server.go" and ".go" extension. + file := plugin.NewGeneratedFile("server.go", ".go") + + // Write the generated gRPC mock server code to the GeneratedFile. file.Write(buf.Bytes()) // Generate a response from our plugin and marshall as protobuf @@ -66,44 +78,73 @@ func main() { os.Stdout.Write(out) } +// generatorParam contains the parameters used to generate the gRPC mock server. type generatorParam struct { - Services []Service - Dependencies map[string]string + // Services is a slice of Service objects representing the services and methods in the input files. + Services []Service `json:"services"` + // Dependencies is a map of package names to their respective import paths. + // It is used to generate the import statements at the top of the generated server file. + Dependencies map[string]string `json:"dependencies"` } +// Service represents a gRPC service. type Service struct { - Name string - Package string - Methods []methodTemplate + // Name is the name of the service. + Name string `json:"name"` + // Package is the package name of the service. + Package string `json:"package"` + // Methods is a slice of methodTemplate representing the methods in the service. + Methods []methodTemplate `json:"methods"` } +// methodTemplate represents a method in a gRPC service. type methodTemplate struct { - SvcPackage string - Name string - ServiceName string - MethodType string - Input string - Output string + // SvcPackage is the package name of the service. + SvcPackage string `json:"svc_package"` + // Name is the name of the method. + Name string `json:"name"` + // ServiceName is the name of the service. + ServiceName string `json:"service_name"` + // MethodType is the type of the method, which can be "standard", "server-stream", "client-stream", or "bidirectional". + MethodType string `json:"method_type"` + // Input is the name of the input message for the method. + Input string `json:"input"` + // Output is the name of the output message for the method. + Output string `json:"output"` } const ( methodTypeStandard = "standard" // server to client stream + // methodTypeServerStream represents a server-stream method. methodTypeServerStream = "server-stream" - // client to server stream - methodTypeClientStream = "client-stream" + // methodTypeClientStream represents a client-stream method. + methodTypeClientStream = "client-stream" + // methodTypeBidirectional represents a bidirectional method. methodTypeBidirectional = "bidirectional" ) +// Options holds the configuration options for the code generator. type Options struct { - writer io.Writer + // Writer is the io.Writer used to write the generated server code. + // If not provided, the generated code is written to stdout. + Writer io.Writer `json:"writer"` } +// ServerTemplate is the template used to generate the gRPC server code. +// It is populated during the init function. var ServerTemplate string +// serverTmpl is the embed.FS used to read the server template file. +// //go:embed server.tmpl var serverTmpl embed.FS +// Init initializes the ServerTemplate with the contents of the server.tmpl file. +// +// It reads the server.tmpl file from the serverTmpl embed.FS and assigns its contents +// to the ServerTemplate variable. If there is an error reading the file, it logs +// the error and stops the program. func init() { data, err := serverTmpl.ReadFile("server.tmpl") if err != nil { @@ -113,55 +154,89 @@ func init() { ServerTemplate = string(data) } +// generateServer generates the gRPC server code based on the given protobuf +// descriptors and writes it to the provided io.Writer. +// +// It extracts the services from the given protobuf descriptors, resolves their +// dependencies, and generates the server code using the provided options. +// If no io.Writer is provided in the options, the generated code is written to +// os.Stdout. +// +// It returns an error if there is any issue in generating or writing the code. func generateServer(protos []*descriptorpb.FileDescriptorProto, opt *Options) error { + // Extract the services from the given protobuf descriptors services := extractServices(protos) + + // Resolve the dependencies of the services deps := resolveDependencies(protos) + // Prepare the parameters for generating the server code param := generatorParam{ Services: services, Dependencies: deps, } - if opt.writer == nil { - opt.writer = os.Stdout + // If no io.Writer is provided in the options, use os.Stdout + if opt.Writer == nil { + opt.Writer = os.Stdout } + // Create a new template and parse the server template tmpl := template.New("server.tmpl") - tmpl, err := tmpl.Parse(ServerTemplate) + _, err := tmpl.Parse(ServerTemplate) if err != nil { return fmt.Errorf("template parse %v", err) } + // Execute the template with the parameters and write the generated code to a buffer buf := new(bytes.Buffer) err = tmpl.Execute(buf, param) if err != nil { return fmt.Errorf("template execute %v", err) } + // Format the generated code using gofmt byt := buf.Bytes() bytProcessed, err := imports.Process("", byt, nil) if err != nil { return fmt.Errorf("formatting: %v \n%s", err, string(byt)) } - _, err = opt.writer.Write(bytProcessed) + // Write the formatted code to the io.Writer + _, err = opt.Writer.Write(bytProcessed) return err } +// resolveDependencies takes a list of protobuf file descriptors and returns a +// map of go package names to their respective alias. It resolves the +// dependencies by checking the go_package option of each protobuf file. If a +// go_package option is not present, it logs a fatal error. If a go package +// already exists in the map, it is skipped. +// +// Parameters: +// - protos: a list of protobuf file descriptors +// +// Returns: +// - a map of go package names to their respective alias func resolveDependencies(protos []*descriptorpb.FileDescriptorProto) map[string]string { deps := map[string]string{} + + // Iterate over each protobuf file descriptor for _, proto := range protos { + // Get the go package alias and name from the protobuf file descriptor alias, pkg := getGoPackage(proto) - // fatal if go_package is not present + // Log a fatal error if the go_package option is not present if pkg == "" { log.Fatalf("option go_package is required. but %s doesn't have any", proto.GetName()) } + // Skip the go package if it already exists in the map if _, ok := deps[pkg]; ok { continue } + // Add the go package to the map with its alias deps[pkg] = alias } @@ -169,42 +244,61 @@ func resolveDependencies(protos []*descriptorpb.FileDescriptorProto) map[string] } var ( - aliases = map[string]bool{} + // aliases is a map that keeps track of package aliases. The key is the alias and the value + // is a boolean indicating whether the alias is used or not. + aliases = map[string]bool{} + + // aliasNum is an integer that keeps track of the number of used aliases. It is used to + // generate new unique aliases. aliasNum = 1 + + // packages is a map that stores the package names as keys and their corresponding aliases + // as values. The package names are the full go package names and the aliases are the + // generated or specified aliases for the packages. packages = map[string]string{} ) +// getGoPackage returns the go package alias and the go package name +// extracted from the protobuf file's go_package option. +// +// If the go_package option is not present, it returns an empty string for goPackage. +// If the go_package option is present but has no alias, it returns an empty string for alias. +// If the go_package option has an alias, it returns the alias and the go package name. +// The alias is derived from the last folder in the go package name. +// If the last folder contains a dash, it is replaced with an underscore. +// If the alias is a keyword, it appends a random number to the alias. +// If the alias already exists, it appends a number to the alias. +// +// The go_package option format is: package_name;alias. func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPackage string) { goPackage = proto.GetOptions().GetGoPackage() if goPackage == "" { return } - // support go_package alias declaration - // https://github.com/golang/protobuf/issues/139 + // If the go_package option has an alias, it is separated by semicolon. if splits := strings.Split(goPackage, ";"); len(splits) > 1 { goPackage = splits[0] alias = splits[1] } else { - // get the alias based on the latest folder + // Get the alias based on the last folder in the go package name. splitSlash := strings.Split(goPackage, "/") - // replace - with _ + // Replace dash with underscore. alias = strings.ReplaceAll(splitSlash[len(splitSlash)-1], "-", "_") } - // if package already discovered just return + // If the package has already been discovered, return the alias. if al, ok := packages[goPackage]; ok { alias = al return } - // Aliases can't be keywords + // If the alias is a keyword, append a random number to it. if isKeyword(alias) { - alias = fmt.Sprintf("%s_pb", alias) + alias = fmt.Sprintf("%s_pb%d", alias, aliasNum) } - // in case of found same alias - // add numbers on it + // If the alias already exists, append a number to it. if ok := aliases[alias]; ok { alias = fmt.Sprintf("%s%d", alias, aliasNum) aliasNum++ @@ -216,20 +310,42 @@ func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPack return } -// change the structure also translate method type +// extractServices extracts services from a list of file descriptors. It returns +// a slice of Service structs, each representing a gRPC service. +// +// The function iterates over each file descriptor and extracts the services +// defined in each file. It then populates the Services struct with relevant +// information like the service name, package name, and methods. The methods +// include information such as the method name, input and output types, and the +// type of method (standard, server-stream, client-stream, or bidirectional). +// +// Parameters: +// - protos: A slice of FileDescriptorProto structs representing the file +// descriptors. +// +// Returns: +// - svcTmp: A slice of Service structs representing the extracted services. func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service { var svcTmp []Service title := cases.Title(language.English, cases.NoLower) + + // Iterate over each file descriptor for _, proto := range protos { + // Iterate over each service in the file for _, svc := range proto.GetService() { var s Service s.Name = svc.GetName() + + // Get the package alias if available alias, _ := getGoPackage(proto) if alias != "" { s.Package = alias + "." } + + // Populate the methods for the service methods := make([]methodTemplate, len(svc.Method)) for j, method := range svc.Method { + // Determine the type of method tipe := methodTypeStandard if method.GetServerStreaming() && !method.GetClientStreaming() { tipe = methodTypeServerStream @@ -239,6 +355,7 @@ func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service { tipe = methodTypeBidirectional } + // Populate the methodTemplate struct methods[j] = methodTemplate{ Name: title.String(*method.Name), SvcPackage: s.Package, @@ -248,69 +365,102 @@ func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service { MethodType: tipe, } } + s.Methods = methods svcTmp = append(svcTmp, s) } } + return svcTmp } func getMessageType(protos []*descriptorpb.FileDescriptorProto, tipe string) string { + // Split the message type into package and type parts split := strings.Split(tipe, ".")[1:] targetPackage := strings.Join(split[:len(split)-1], ".") targetType := split[len(split)-1] + + // Iterate over the protos to find the target message for _, proto := range protos { + // Check if the proto package matches the target package if proto.GetPackage() != targetPackage { continue } + // Iterate over the messages in the proto for _, msg := range proto.GetMessageType() { + // Check if the message name matches the target type if msg.GetName() == targetType { + // Get the package alias if available alias, _ := getGoPackage(proto) if alias != "" { alias += "." } + + // Return the fully qualified message type return fmt.Sprintf("%s%s", alias, msg.GetName()) } } } + + // Return the target type if no match was found return targetType } -func isKeyword(word string) bool { - keywords := [...]string{ - "break", - "case", - "chan", - "const", - "continue", - "default", - "defer", - "else", - "fallthrough", - "for", - "func", - "go", - "goto", - "if", - "import", - "interface", - "map", - "package", - "range", - "return", - "select", - "struct", - "switch", - "type", - "var", - } +// keywords is a map that contains all the reserved keywords in Go. +// It helps to determine if a given word is a keyword or not. +var keywords = map[string]bool{ + "break": true, + "case": true, + "chan": true, + "const": true, + "continue": true, + "default": true, + "defer": true, + "else": true, + "fallthrough": true, + "for": true, + "func": true, + "go": true, + "goto": true, + "if": true, + "import": true, + "interface": true, + "map": true, + "package": true, + "range": true, + "return": true, + "select": true, + "struct": true, + "switch": true, + "type": true, + "var": true, + "bool": true, + "byte": true, + "complex128": true, + "complex64": true, + "error": true, + "float32": true, + "float64": true, + "int": true, + "int16": true, + "int32": true, + "int64": true, + "int8": true, + "rune": true, + "string": true, + "uint": true, + "uint16": true, + "uint32": true, + "uint64": true, + "uint8": true, + "uintptr": true, +} - for _, keyword := range keywords { - if strings.ToLower(word) == keyword { - return true - } - } +// isKeyword checks if a word is a keyword or not. +// It does a case insensitive comparison. +func isKeyword(word string) bool { + _, ok := keywords[strings.ToLower(word)] - return false + return ok } diff --git a/protoc-gen-gripmock/server.tmpl b/protoc-gen-gripmock/server.tmpl index 04de4660..d0efc55f 100644 --- a/protoc-gen-gripmock/server.tmpl +++ b/protoc-gen-gripmock/server.tmpl @@ -1,12 +1,18 @@ // Code generated by GripMock. DO NOT EDIT. +// +// This file is generated by GripMock, a tool for generating gRPC mock servers. +// GripMock is a mock server for gRPC services. It's using a .proto file to generate implementation of gRPC service for you. +// You can use GripMock for setting up end-to-end testing or as a dummy server in a software development phase. +// The server implementation is in GoLang but the client can be any programming language that support gRPC. +// +// See https://github.com/bavix/gripmock for more information. package main import ( "context" - "errors" + "time" "slices" "fmt" - "io" "log" "net" "net/http" @@ -55,27 +61,71 @@ func main() { } s := grpc.NewServer( - grpc.StatsHandler(otelgrpc.NewServerHandler()), - grpc.ChainUnaryInterceptor([]grpc.UnaryServerInterceptor{ - grpccontext.UnaryInterceptor(builder.Logger()), - }...), - grpc.ChainStreamInterceptor([]grpc.StreamServerInterceptor{ - grpccontext.StreamInterceptor(builder.Logger()), - }...), - ) + grpc.StatsHandler(otelgrpc.NewServerHandler()), + grpc.UnaryInterceptor(grpccontext.UnaryInterceptor(builder.Logger())), + grpc.StreamInterceptor(grpccontext.StreamInterceptor(builder.Logger())), + ) + + healthcheck := health.NewServer() + healthcheck.SetServingStatus("", healthgrpc.HealthCheckResponse_NOT_SERVING) + {{ range .Services }} {{ template "register_services" . }} {{ end }} - - healthgrpc.RegisterHealthServer(s, health.NewServer()) + healthgrpc.RegisterHealthServer(s, healthcheck) reflection.Register(s) builder.Logger().Info(). - Str("addr", fmt.Sprintf("%s://%s", builder.Config().GRPCNetwork, builder.Config().GRPCAddr)). + Str("addr", builder.Config().GRPCAddr). + Str("network", builder.Config().GRPCNetwork). Msg("Serving gRPC") + // Health check goroutine to wait for the HTTP server to become ready. + // Once the HTTP server is ready, it sets the gRPC server to SERVING state. + go func() { + // Create a new client to interact with the HTTP server API. + api, err := sdk.NewClientWithResponses( + fmt.Sprintf("http://%s/api", builder.Config().HTTPAddr), + sdk.WithHTTPClient(http.DefaultClient), + ) + if err != nil { + return + } + + // Create a context with a timeout of 120 seconds. + ctx, cancel := context.WithTimeout(ctx, 120*time.Second) + defer cancel() + + // Create a ticker to periodically check the readiness of the HTTP server. + tick := time.NewTicker(250 * time.Millisecond) + defer tick.Stop() + + for { + select { + // Check if the context is done. + case <-ctx.Done(): + return + + // Check if the ticker has fired. + case <-tick.C: + // Call the Readiness API on the HTTP server. + resp, err := api.ReadinessWithResponse(ctx) + + // If the API call is successful and the response is not nil, + // set the gRPC server to SERVING state and log a message. + if err == nil && resp.JSON200 != nil { + healthcheck.SetServingStatus("", healthgrpc.HealthCheckResponse_SERVING) + + builder.Logger().Info().Msg("gRPC server is ready to accept requests") + + return + } + } + } + }() + if err := s.Serve(lis); err != nil { - builder.Logger().Fatal().Err(err).Msg("server ended") + builder.Logger().Fatal().Err(err).Msg("failed to serve") } } @@ -107,8 +157,17 @@ type {{.Name}} struct{ {{ define "standard_method" }} func (s *{{.ServiceName}}) {{.Name}}(ctx context.Context, in *{{.Input}}) (*{{.Output}},error){ out := &{{.Output}}{} + // Retrieve metadata from the incoming context. + // The metadata is used to find the stub for the method being called. md, _ := metadata.FromIncomingContext(ctx) + + // Find the stub for the given service name, method name, and metadata. + // The stub defines the input and output messages for the method. + // If the stub is found, its output message is returned. + // If the stub is not found, an error is returned. err := findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}", "{{.Name}}", md, in, out) + + // Return the output message and any error encountered while finding the stub. return out, err } {{ end }} @@ -116,13 +175,24 @@ func (s *{{.ServiceName}}) {{.Name}}(ctx context.Context, in *{{.Input}}) (*{{.O {{ define "server_stream_method" }} func (s *{{.ServiceName}}) {{.Name}}(in *{{.Input}},srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error { out := &{{.Output}}{} + // Retrieve metadata from the incoming context. + // The metadata is used to find the stub for the method being called. ctx := srv.Context() md, _ := metadata.FromIncomingContext(ctx) + + // Find the stub for the given service name, method name, and metadata. + // The stub defines the input and output messages for the method. + // If the stub is found, its output message is returned. + // If the stub is not found, an error is returned. err := findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}", "{{.Name}}", md, in, out) if err != nil { + // Return the error encountered while finding the stub. return err } + // Send the output message back to the client. + // This will continue the server-streaming RPC. + // If there is an error sending the message, it will be returned. return srv.Send(out) } {{ end }} @@ -130,15 +200,29 @@ func (s *{{.ServiceName}}) {{.Name}}(in *{{.Input}},srv {{.SvcPackage}}{{.Servic {{ define "client_stream_method"}} func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error { out := &{{.Output}}{} + // Handle the client-streaming RPC. + // This loop will continue until the client closes the RPC. + // For each input message received from the client, it will find the stub + // and generate the output message. + // The output message will be sent back to the client when the RPC is closed. ctx := srv.Context() md, _ := metadata.FromIncomingContext(ctx) for { - input,err := srv.Recv() + // Receive the next input message from the client. + // If the client closes the RPC, io.EOF is returned. + input, err := srv.Recv() if errors.Is(err, io.EOF) { + // If the client closes the RPC, send the output message and close the RPC. return srv.SendAndClose(out) } - err = findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}","{{.Name}}",md,input,out) + + // Find the stub for the given service name, method name, and metadata. + // The stub defines the input and output messages for the method. + // If the stub is found, its output message is returned. + // If the stub is not found, an error is returned. + err = findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}","{{.Name}}", md, input, out) if err != nil { + // If there is an error finding the stub, return the error. return err } } @@ -147,23 +231,40 @@ func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name {{ define "bidirectional_method"}} func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error { + // Handle the bidirectional RPC. + // This loop will continue until the client closes the RPC. + // For each input message received from the client, it will find the stub + // and generate the output message. + // The output message will be sent back to the client when the RPC is closed. ctx := srv.Context() md, _ := metadata.FromIncomingContext(ctx) for { - in, err := srv.Recv() + // Receive the next input message from the client. + // If the client closes the RPC, io.EOF is returned. + input, err := srv.Recv() if errors.Is(err, io.EOF) { + // If the client closes the RPC, send the output message and close the RPC. return nil } + if err != nil { return err } + // Create a new output message. out := &{{.Output}}{} - err = findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}","{{.Name}}", md, in, out) + + // Find the stub for the given service name, method name, and metadata. + // The stub defines the input and output messages for the method. + // If the stub is found, its output message is returned. + // If the stub is not found, an error is returned. + err = findStub(ctx, s.__builder__.Config(), "{{.ServiceName}}","{{.Name}}", md, input, out) if err != nil { return err } + // Send the output message back to the client. + // If there is an error sending the message, it will be returned. if err := srv.Send(out); err != nil{ return err } @@ -178,6 +279,8 @@ func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name {{ define "find_stub" }} func findStub(ctx context.Context, conf environment.Config, service, method string, md metadata.MD, in, out protoreflect.ProtoMessage) error { + // Create a new client with the configured HTTP address. + // Add the default HTTP client as the transport. api, err := sdk.NewClientWithResponses(fmt.Sprintf("http://%s/api", conf.HTTPAddr), sdk.WithHTTPClient(http.DefaultClient), ) @@ -185,57 +288,70 @@ func findStub(ctx context.Context, conf environment.Config, service, method stri return err } - excludes := []string{":authority", "content-type", "grpc-accept-encoding", "user-agent"} - headers := make(map[string]string, len(md)) - for h, v := range md { - if slices.Contains(excludes, h) { - continue - } + // Exclude headers that are not relevant for matching stubs. + excludes := []string{":authority", "content-type", "grpc-accept-encoding", "user-agent"} - headers[h] = strings.Join(v, ";") - } + // Create a map of headers to match with the input metadata. + headers := make(map[string]string, len(md)) + for h, v := range md { + // Exclude headers that are not relevant for matching stubs. + if slices.Contains(excludes, h) { + continue + } - searchStub, err := api.SearchStubsWithResponse(ctx, sdk.SearchStubsJSONRequestBody{ - Service: service, - Method: method, - Headers: headers, - Data: in, + // Join the values of the header with a semicolon. + headers[h] = strings.Join(v, ";") + } + + // Search for a stub that matches the given service, method, and headers. + searchStub, err := api.SearchStubsWithResponse(ctx, sdk.SearchStubsJSONRequestBody{ + Service: service, // The name of the service. + Method: method, // The name of the method. + Headers: headers, // The headers to match. + Data: in, // The input message. }) if err != nil { return err } + // If the search was unsuccessful, return an error with the response body. if searchStub.JSON200 == nil { return fmt.Errorf(string(searchStub.Body)) } + // If the search returned an error, return an error with the error code and message. if searchStub.JSON200.Error != "" || searchStub.JSON200.Code != nil { - if searchStub.JSON200.Code == nil { - return status.Error(codes.Aborted, searchStub.JSON200.Error) - } + if searchStub.JSON200.Code == nil { + return status.Error(codes.Aborted, searchStub.JSON200.Error) + } - if *searchStub.JSON200.Code != codes.OK { - return status.Error(*searchStub.JSON200.Code, searchStub.JSON200.Error) - } - } + if *searchStub.JSON200.Code != codes.OK { + return status.Error(*searchStub.JSON200.Code, searchStub.JSON200.Error) + } + } - data, err := json.Marshal(searchStub.JSON200.Data) - if err != nil { - return err - } + // Convert the search result to JSON. + data, err := json.Marshal(searchStub.JSON200.Data) + if err != nil { + return err + } - mdResp := make(metadata.MD, len(searchStub.JSON200.Headers)) - for k, v := range searchStub.JSON200.Headers { - splits := strings.Split(v, ";") - for i, s := range splits { - splits[i] = strings.TrimSpace(s) - } + // Create a map of headers to set in the context. + mdResp := make(metadata.MD, len(searchStub.JSON200.Headers)) + for k, v := range searchStub.JSON200.Headers { + // Split the values of the header by semicolon and trim each value. + splits := strings.Split(v, ";") + for i, s := range splits { + splits[i] = strings.TrimSpace(s) + } - mdResp[k] = splits - } + mdResp[k] = splits + } - grpc.SetHeader(ctx, mdResp) + // Set the headers in the context. + grpc.SetHeader(ctx, mdResp) - return jsonpb.Unmarshal(data, out) + // Unmarshal the search result into the output message. + return jsonpb.Unmarshal(data, out) } {{ end }} diff --git a/stub/stub.go b/stub/stub.go index 1818bda3..4f8b751e 100644 --- a/stub/stub.go +++ b/stub/stub.go @@ -22,13 +22,10 @@ import ( func RunRestServer( ctx context.Context, - ch chan struct{}, stubPath string, config environment.Config, reflector *grpcreflector.GReflector, ) { - const timeout = time.Millisecond * 25 - apiServer, _ := app.NewRestServer(stubPath, reflector) ui, _ := gripmockui.Assets() @@ -42,6 +39,8 @@ func RunRestServer( }) router.PathPrefix("/").Handler(http.FileServerFS(ui)).Methods(http.MethodGet) + const timeout = time.Millisecond * 25 + srv := &http.Server{ Addr: config.HTTPAddr, ReadHeaderTimeout: timeout, @@ -63,21 +62,8 @@ func RunRestServer( Str("addr", config.HTTPAddr). Msg("stub-manager started") - go func() { - // nosemgrep:go.lang.security.audit.net.use-tls.use-tls - if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - zerolog.Ctx(ctx).Fatal().Err(err).Msg("stub manager completed") - } - }() - - go func() { - select { - case <-ctx.Done(): - return - case <-ch: - apiServer.ServiceReady() - } - - zerolog.Ctx(ctx).Info().Msg("gRPC-service is ready to accept requests") - }() + // nosemgrep:go.lang.security.audit.net.use-tls.use-tls + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + zerolog.Ctx(ctx).Fatal().Err(err).Msg("stub manager completed") + } }