Skip to content

Commit

Permalink
refactor(proxy): forward auth request rather than proxy
Browse files Browse the repository at this point in the history
Signed-off-by: Rodney Osodo <[email protected]>
  • Loading branch information
rodneyosodo committed Sep 12, 2024
1 parent 0cd8cca commit 7ec21af
Show file tree
Hide file tree
Showing 17 changed files with 505 additions and 255 deletions.
3 changes: 2 additions & 1 deletion docker-compose/.env
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ MG_SPICEDB_DB_PORT=5432
### SpiceDB config
MG_SPICEDB_PRE_SHARED_KEY="yejaTYpqgwqn8ACsnt4qzUso9z5auY"
MG_SPICEDB_SCHEMA_FILE="/schema.zed"
MG_SPICEDB_HOST=magistrala-spicedb
MG_SPICEDB_HOST=spicedb
MG_SPICEDB_PORT=50051
MG_SPICEDB_DATASTORE_ENGINE=postgres

Expand Down Expand Up @@ -150,3 +150,4 @@ UV_VAULT_PROXY_PORT=8900
UV_VAULT_PROXY_SERVER_CERT=
UV_VAULT_PROXY_SERVER_KEY=
UV_VAULT_PROXY_TARGET_URL=http://ollama:11434
UV_VAULT_PROXY_INSTANCE_ID=
5 changes: 4 additions & 1 deletion docker-compose/proxy-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ services:
- auth
environment:
UV_VAULT_PROXY_LOG_LEVEL: ${UV_VAULT_PROXY_LOG_LEVEL}
UV_VAULT_PROXY_TARGET_URL: ${UV_VAULT_PROXY_TARGET_URL}
UV_VAULT_PROXY_HOST: ${UV_VAULT_PROXY_HOST}
UV_VAULT_PROXY_PORT: ${UV_VAULT_PROXY_PORT}
UV_VAULT_PROXY_SERVER_CERT: ${UV_VAULT_PROXY_SERVER_CERT}
Expand All @@ -19,7 +20,9 @@ services:
MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY}
UV_VAULT_PROXY_TARGET_URL: ${UV_VAULT_PROXY_TARGET_URL}
UV_VAULT_PROXY_INSTANCE_ID: ${UV_VAULT_PROXY_INSTANCE_ID}
MG_JAEGER_URL: ${MG_JAEGER_URL}
MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO}
volumes:
# Auth gRPC client certificates
- type: bind
Expand Down
15 changes: 12 additions & 3 deletions docker-compose/traefik/dynamic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ http:
stripPrefix:
prefixes:
- "/ollama"
forward-auth-middleware:
forwardAuth:
address: http://vault-proxy:8900

services:
users:
Expand Down Expand Up @@ -56,10 +59,15 @@ http:
interval: 10s
timeout: 10s

vault-proxy:
ollama:
loadBalancer:
servers:
- url: http://vault-proxy:8900
- url: http://ollama:11434
healthCheck:
scheme: http
path: /
interval: 10s
timeout: 10s

routers:
users-health:
Expand Down Expand Up @@ -99,9 +107,10 @@ http:
rule: "PathPrefix(`/ollama`)"
entryPoints:
- websecure
service: vault-proxy
service: ollama
middlewares:
- strip-ollama-prefix-middleware
- forward-auth-middleware
- retry-middleware
- headers-middleware
priority: 10
Expand Down
2 changes: 1 addition & 1 deletion proxy/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ VERSION ?= $(shell git describe --abbrev=0 --tags 2>/dev/null || echo 'v0.0.0')

define compile_service
CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) \
go build -ldflags "-s -w" -o ${BUILD_DIR}/
go build -ldflags "-s -w" -o ${BUILD_DIR}/vault-proxy cmd/main.go
endef

define make_docker
Expand Down
25 changes: 25 additions & 0 deletions proxy/api/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package api

import (
"context"

"github.com/absmach/magistrala/pkg/apiutil"
"github.com/absmach/magistrala/pkg/errors"
"github.com/go-kit/kit/endpoint"
proxy "github.com/ultraviolet/vault-proxy"
)

func identifyEndpoint(svc proxy.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(identifyRequest)
if err := req.Validate(); err != nil {
return identifyResponse{identified: false}, errors.Wrap(apiutil.ErrValidation, err)
}

if err := svc.Identify(ctx, req.Token); err != nil {
return identifyResponse{identified: false}, err
}

return identifyResponse{identified: true}, nil
}
}
15 changes: 15 additions & 0 deletions proxy/api/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package api

import "github.com/absmach/magistrala/pkg/apiutil"

type identifyRequest struct {
Token string `json:"token"`
}

func (i *identifyRequest) Validate() error {
if i.Token == "" {
return apiutil.ErrBearerToken
}

return nil
}
29 changes: 29 additions & 0 deletions proxy/api/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package api

import (
"net/http"

"github.com/absmach/magistrala"
)

var _ magistrala.Response = (*identifyResponse)(nil)

type identifyResponse struct {
identified bool
}

func (i *identifyResponse) Code() int {
if i.identified {
return http.StatusOK
}

return http.StatusUnauthorized
}

func (i *identifyResponse) Headers() map[string]string {
return map[string]string{}
}

func (i identifyResponse) Empty() bool {
return true
}
95 changes: 95 additions & 0 deletions proxy/api/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package api

import (
"context"
"encoding/json"
"log/slog"
"net/http"

"github.com/absmach/magistrala"
"github.com/absmach/magistrala/pkg/apiutil"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
proxy "github.com/ultraviolet/vault-proxy"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)

const (
ContentType = "application/json"
)

func MakeHandler(svc proxy.Service, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, encodeError)),
}

mux := chi.NewRouter()

mux.HandleFunc("/", otelhttp.NewHandler(kithttp.NewServer(
identifyEndpoint(svc),
decodeIdentifyReq,
encodeResponse,
opts...,
), "identify").ServeHTTP)

return mux
}

func decodeIdentifyReq(_ context.Context, r *http.Request) (interface{}, error) {
return identifyRequest{
Token: apiutil.ExtractBearerToken(r),
}, nil
}

func encodeError(_ context.Context, err error, w http.ResponseWriter) {
var wrapper error
if errors.Contains(err, apiutil.ErrValidation) {
wrapper, err = errors.Unwrap(err)
}

w.Header().Set("Content-Type", ContentType)
switch {
case errors.Contains(err, apiutil.ErrBearerToken),
errors.Contains(err, svcerr.ErrAuthentication):
err = unwrap(err)
w.WriteHeader(http.StatusUnauthorized)
default:
w.WriteHeader(http.StatusInternalServerError)
}

if wrapper != nil {
err = errors.Wrap(wrapper, err)
}

if errorVal, ok := err.(errors.Error); ok {
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}
}

func unwrap(err error) error {
wrapper, err := errors.Unwrap(err)
if wrapper != nil {
return wrapper
}
return err
}

func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
if ar, ok := response.(magistrala.Response); ok {
for k, v := range ar.Headers() {
w.Header().Set(k, v)
}
w.Header().Set("Content-Type", ContentType)
w.WriteHeader(ar.Code())

if ar.Empty() {
return nil
}
}

return json.NewEncoder(w).Encode(response)
}
137 changes: 137 additions & 0 deletions proxy/cmd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package main

import (
"context"
"fmt"
"log"
"log/slog"
"net/url"
"os"

"github.com/absmach/callhome/pkg/client"
"github.com/absmach/magistrala"
authclient "github.com/absmach/magistrala/auth/api/grpc"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/grpcclient"
"github.com/absmach/magistrala/pkg/jaeger"
"github.com/absmach/magistrala/pkg/prometheus"
"github.com/absmach/magistrala/pkg/server"
"github.com/absmach/magistrala/pkg/server/http"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/caarlos0/env/v11"
proxy "github.com/ultraviolet/vault-proxy"
"github.com/ultraviolet/vault-proxy/api"
"github.com/ultraviolet/vault-proxy/middleware"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)

const (
svcName = "vault_proxy"
envPrefixHTTP = "UV_VAULT_PROXY_"
envPrefixAuth = "MG_AUTH_GRPC_"
defSvcHTTPPort = "8900"
)

type config struct {
LogLevel string `env:"UV_VAULT_PROXY_LOG_LEVEL" envDefault:"info"`
TargetURL string `env:"UV_VAULT_PROXY_TARGET_URL" envDefault:"http://ollama:11434"`
SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"UV_VAULT_PROXY_INSTANCE_ID" envDefault:""`
JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"`
}

func main() {
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)

cfg := config{}
if err := env.Parse(&cfg); err != nil {
log.Fatalf("failed to load %s configuration : %s", svcName, err)
}

logger, err := mglog.New(os.Stdout, cfg.LogLevel)
if err != nil {
log.Fatalf("failed to init logger: %s", err.Error())
}

var exitCode int
defer mglog.ExitWithError(&exitCode)

if cfg.InstanceID == "" {
if cfg.InstanceID, err = uuid.New().ID(); err != nil {
logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err))
exitCode = 1
return
}
}

authCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&authCfg, env.Options{Prefix: envPrefixAuth}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err))
exitCode = 1
return
}

authClient, authHandler, err := grpcclient.SetupAuthClient(ctx, authCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authHandler.Close()

logger.Info("Auth service gRPC client successfully connected to auth gRPC server " + authHandler.Secure())

tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err))
exitCode = 1
return
}
defer func() {
if err := tp.Shutdown(ctx); err != nil {
logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err))
}
}()
tracer := tp.Tracer(svcName)

svc := newService(authClient, logger, tracer)

httpServerConfig := server.Config{Port: defSvcHTTPPort}
if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err))
exitCode = 1
return
}

httpSvr := http.NewServer(ctx, cancel, svcName, httpServerConfig, api.MakeHandler(svc, logger, cfg.InstanceID), logger)

if cfg.SendTelemetry {
chc := client.New(svcName, magistrala.Version, logger, cancel)
go chc.CallHome(ctx)
}

g.Go(func() error {
return httpSvr.Start()
})

g.Go(func() error {
return server.StopSignalHandler(ctx, cancel, logger, svcName, httpSvr)
})

if err := g.Wait(); err != nil {
logger.Error(fmt.Sprintf("HTTP adapter service terminated: %s", err))
}
}

func newService(authClient authclient.AuthServiceClient, logger *slog.Logger, tracer trace.Tracer) proxy.Service {
svc := proxy.NewService(authClient)
svc = middleware.NewLoggingMiddleware(logger, svc)
svc = middleware.NewTracingMiddleware(tracer, svc)
counter, latency := prometheus.MakeMetrics(svcName, "api")
svc = middleware.NewMetricsMiddleware(counter, latency, svc)

return svc
}
Loading

0 comments on commit 7ec21af

Please sign in to comment.