From 49b623667095efa9891932d027ff08d621c12d84 Mon Sep 17 00:00:00 2001 From: "saica.go" Date: Tue, 15 Aug 2023 03:13:34 +0800 Subject: [PATCH] feat: adaptation for gRPC --- pkg/adapters/grpc/traffic.go | 82 ++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 pkg/adapters/grpc/traffic.go diff --git a/pkg/adapters/grpc/traffic.go b/pkg/adapters/grpc/traffic.go new file mode 100644 index 00000000..a6d21ba4 --- /dev/null +++ b/pkg/adapters/grpc/traffic.go @@ -0,0 +1,82 @@ +package grpc + +import ( + "context" + "errors" + "fmt" + "github.com/alibaba/sentinel-golang/core/route" + "github.com/alibaba/sentinel-golang/core/route/base" + "github.com/google/uuid" + "google.golang.org/grpc" + "net" + "strings" +) + +var ( + connToBaggage map[string]map[string]string = make(map[string]map[string]string) + cm *route.ClusterManager = nil +) + +func NewDialer(id string) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + parts := strings.Split(addr, "/") + if len(parts) != 2 { + return nil, errors.New("invalid address format") + } + tc := &base.TrafficContext{ + ServiceName: parts[0], + MethodName: parts[1], + Headers: make(map[string]string), + } + + instance, err := cm.GetOne(tc) + + if err != nil { + return nil, err + } + if instance == nil { + return nil, errors.New("no matched provider") + } + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%v", instance.Host, instance.Port)) + if err != nil { + return nil, err + } + connToBaggage[id] = tc.Baggage + + return conn, nil + } +} + +func NewTrafficUnaryIntercepter(connId string) grpc.DialOption { + return grpc.WithUnaryInterceptor( + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + newCtx := ctx + if baggage, ok := connToBaggage[connId]; ok { + // TODO modify the request + _ = baggage + _ = newCtx + } + return invoker(newCtx, method, req, reply, cc, opts...) + }) +} + +func NewTrafficStreamIntercepter(connId string) grpc.DialOption { + return grpc.WithStreamInterceptor( + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + newCtx := ctx + if baggage, ok := connToBaggage[connId]; ok { + // TODO modify the request + _ = baggage + _ = newCtx + } + return streamer(newCtx, desc, cc, method, opts...) + }) +} + +func Dial(addr string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + id := uuid.New().String() + opts = append(opts, grpc.WithContextDialer(NewDialer(id))) + opts = append(opts, NewTrafficUnaryIntercepter(id)) + opts = append(opts, NewTrafficStreamIntercepter(id)) + return grpc.Dial(addr, opts...) +}