diff --git a/internal/services/api_router.go b/internal/services/api_router.go index e2d6b1f8..1a027625 100644 --- a/internal/services/api_router.go +++ b/internal/services/api_router.go @@ -12,9 +12,13 @@ import ( "github.com/freifunkMUC/wg-access-server/proto/proto" "github.com/freifunkMUC/wg-embed/pkg/wgembed" - grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcLogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + grpcRecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" "github.com/improbable-eng/grpc-web/go/grpcweb" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type ApiServices struct { @@ -27,9 +31,18 @@ func ApiRouter(deps *ApiServices) http.Handler { // Native GRPC server server := grpc.NewServer([]grpc.ServerOption{ grpc.MaxRecvMsgSize(int(1 * math.Pow(2, 20))), // 1MB - grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - return grpc_logrus.UnaryServerInterceptor(traces.Logger(ctx))(ctx, req, info, handler) - }), + grpc.UnaryInterceptor(grpcMiddleware.ChainUnaryServer( + func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + // wrapped in anonymous func to get ctx + return grpcLogrus.UnaryServerInterceptor(traces.Logger(ctx))(ctx, req, info, handler) + }, + grpcRecovery.UnaryServerInterceptor( + grpcRecovery.WithRecoveryHandlerContext(func(ctx context.Context, p interface{}) (err error) { + // add trace id to error message so it's visible for the client + return status.Errorf(codes.Internal, "%v; trace = %s", p, traces.TraceID(ctx)) + }), + ), + )), }...) // Register GRPC services diff --git a/internal/services/health.go b/internal/services/health.go index 756f4fc1..06e09fc1 100644 --- a/internal/services/health.go +++ b/internal/services/health.go @@ -6,8 +6,8 @@ import ( ) func HealthEndpoint() http.Handler { - return http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) fmt.Fprintf(w, "ok") - })) + }) } diff --git a/internal/services/middleware.go b/internal/services/middleware.go index 6d5cfc3b..57eb82d5 100644 --- a/internal/services/middleware.go +++ b/internal/services/middleware.go @@ -5,7 +5,6 @@ import ( "net/http" "runtime/debug" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/freifunkMUC/wg-access-server/internal/traces" ) @@ -19,7 +18,7 @@ func RecoveryMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { - ctxlogrus.Extract(r.Context()). + traces.Logger(r.Context()). WithField("stack", string(debug.Stack())). Error(err) w.WriteHeader(500) diff --git a/pkg/authnz/authconfig/oidc.go b/pkg/authnz/authconfig/oidc.go index eb839b69..54adbb08 100644 --- a/pkg/authnz/authconfig/oidc.go +++ b/pkg/authnz/authconfig/oidc.go @@ -8,12 +8,13 @@ import ( "strings" "time" - "github.com/coreos/go-oidc" - "github.com/gorilla/mux" - "github.com/pkg/errors" "github.com/freifunkMUC/wg-access-server/pkg/authnz/authruntime" "github.com/freifunkMUC/wg-access-server/pkg/authnz/authsession" "github.com/freifunkMUC/wg-access-server/pkg/authnz/authutil" + + "github.com/coreos/go-oidc" + "github.com/gorilla/mux" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/oauth2" "gopkg.in/Knetic/govaluate.v2" @@ -138,14 +139,18 @@ func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauth } } + identity := &authsession.Identity{ + Provider: c.Name, + Subject: info.Subject, + Email: info.Email, + Claims: *claims, + } + if name, ok := oidcProfileData["name"].(string); ok { + identity.Name = name + } + err = runtime.SetSession(w, r, &authsession.AuthSession{ - Identity: &authsession.Identity{ - Provider: c.Name, - Subject: info.Subject, - Email: info.Email, - Name: oidcProfileData["name"].(string), - Claims: *claims, - }, + Identity: identity, }) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized)