diff --git a/cmd/main.go b/cmd/main.go index 61e4e17..4770f4d 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,6 +4,8 @@ import ( "log/slog" "os" + "go.openfort.xyz/shield/pkg/logger" + "go.openfort.xyz/shield/cmd/cli" ) @@ -11,7 +13,7 @@ func main() { slog.Info("Starting OpenFort Shield") rootCmd := cli.NewCmdRoot() if err := rootCmd.Execute(); err != nil { - slog.Info("Error executing command", slog.String("error", err.Error())) + slog.Info("Error executing command", logger.Error(err)) os.Exit(1) } } diff --git a/di/wire.go b/di/wire.go index a451744..2245bab 100644 --- a/di/wire.go +++ b/di/wire.go @@ -6,7 +6,7 @@ package di import ( "github.com/google/wire" "go.openfort.xyz/shield/internal/applications/projectapp" - "go.openfort.xyz/shield/internal/applications/userapp" + "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/internal/core/services/projectsvc" @@ -114,14 +114,12 @@ func ProvideProviderManager() (pm *providersmgr.Manager, err error) { return } -func ProvideUserApplication() (a *userapp.UserApplication, err error) { +func ProvideShareApplication() (a *shareapp.ShareApplication, err error) { wire.Build( - userapp.New, - ProvideUserService, + shareapp.New, ProvideShareService, - ProvideProjectService, - ProvideProviderService, - ProvideProviderManager, + ProvideSQLShareRepository, + ProvideSQLProjectRepository, ) return @@ -131,7 +129,10 @@ func ProvideProjectApplication() (a *projectapp.ProjectApplication, err error) { wire.Build( projectapp.New, ProvideProjectService, + ProvideSQLProjectRepository, ProvideProviderService, + ProvideSQLProviderRepository, + ProvideSQLShareRepository, ) return @@ -152,7 +153,7 @@ func ProvideRESTServer() (s *rest.Server, err error) { wire.Build( rest.New, rest.GetConfigFromEnv, - ProvideUserApplication, + ProvideShareApplication, ProvideProjectApplication, ProvideAuthenticationManager, ) diff --git a/di/wire_gen.go b/di/wire_gen.go index ae93910..1da86c9 100644 --- a/di/wire_gen.go +++ b/di/wire_gen.go @@ -8,7 +8,7 @@ package di import ( "go.openfort.xyz/shield/internal/applications/projectapp" - "go.openfort.xyz/shield/internal/applications/userapp" + "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/internal/core/services/projectsvc" @@ -124,41 +124,45 @@ func ProvideProviderManager() (*providersmgr.Manager, error) { return manager, nil } -func ProvideUserApplication() (*userapp.UserApplication, error) { - userService, err := ProvideUserService() +func ProvideShareApplication() (*shareapp.ShareApplication, error) { + shareService, err := ProvideShareService() if err != nil { return nil, err } - shareService, err := ProvideShareService() + shareRepository, err := ProvideSQLShareRepository() if err != nil { return nil, err } + projectRepository, err := ProvideSQLProjectRepository() + if err != nil { + return nil, err + } + shareApplication := shareapp.New(shareService, shareRepository, projectRepository) + return shareApplication, nil +} + +func ProvideProjectApplication() (*projectapp.ProjectApplication, error) { projectService, err := ProvideProjectService() if err != nil { return nil, err } - providerService, err := ProvideProviderService() + projectRepository, err := ProvideSQLProjectRepository() if err != nil { return nil, err } - manager, err := ProvideProviderManager() + providerService, err := ProvideProviderService() if err != nil { return nil, err } - userApplication := userapp.New(userService, shareService, projectService, providerService, manager) - return userApplication, nil -} - -func ProvideProjectApplication() (*projectapp.ProjectApplication, error) { - projectService, err := ProvideProjectService() + providerRepository, err := ProvideSQLProviderRepository() if err != nil { return nil, err } - providerService, err := ProvideProviderService() + shareRepository, err := ProvideSQLShareRepository() if err != nil { return nil, err } - projectApplication := projectapp.New(projectService, providerService) + projectApplication := projectapp.New(projectService, projectRepository, providerService, providerRepository, shareRepository) return projectApplication, nil } @@ -188,7 +192,7 @@ func ProvideRESTServer() (*rest.Server, error) { if err != nil { return nil, err } - userApplication, err := ProvideUserApplication() + shareApplication, err := ProvideShareApplication() if err != nil { return nil, err } @@ -196,6 +200,6 @@ func ProvideRESTServer() (*rest.Server, error) { if err != nil { return nil, err } - server := rest.New(config, projectApplication, userApplication, manager) + server := rest.New(config, projectApplication, shareApplication, manager) return server, nil } diff --git a/go.mod b/go.mod index 9abd0c4..c22a939 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/MicahParks/jwkset v0.5.15 // indirect + github.com/benbjohnson/clock v1.3.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.8.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -38,6 +39,7 @@ require ( github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.0 // indirect + go.uber.org/ratelimit v0.3.1 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 531ab1a..7fb9e85 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/MicahParks/jwkset v0.5.15 h1:ACJY045Zuvo2TVWikeFLnKTIsEDQQHUHrNYiMW+g github.com/MicahParks/jwkset v0.5.15/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= github.com/MicahParks/keyfunc/v3 v3.2.9 h1:juKYzZvb5q4mWnox3439WNq6cusvSdt2fJ5nj+osgCk= github.com/MicahParks/keyfunc/v3 v3.2.9/go.mod h1:Yx3jN/pn7ZMCxwFsyIrsmSqRfp0HGHAcyezBlhYi1Ew= +github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= +github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/codahale/sss v0.0.0-20160501174526-0cb9f6d3f7f1 h1:PJJtqFbZH8ZW9PtsfB+ALZKVPRiRwNbPrNe+gliLpGo= @@ -70,6 +72,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= +go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index 33e96fc..f1170d5 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -2,50 +2,74 @@ package projectapp import ( "context" + "errors" "log/slog" - "os" + "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/domain/provider" + "go.openfort.xyz/shield/internal/core/domain/share" + "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/ofcontext" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/contexter" + "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/logger" ) type ProjectApplication struct { - projectSvc services.ProjectService - providerSvc services.ProviderService - logger *slog.Logger + projectSvc services.ProjectService + projectRepo repositories.ProjectRepository + providerSvc services.ProviderService + providerRepo repositories.ProviderRepository + sharesRepo repositories.ShareRepository + logger *slog.Logger } -func New(projectSvc services.ProjectService, providerSvc services.ProviderService) *ProjectApplication { +func New(projectSvc services.ProjectService, projectRepo repositories.ProjectRepository, providerSvc services.ProviderService, providerRepo repositories.ProviderRepository, sharesRepo repositories.ShareRepository) *ProjectApplication { return &ProjectApplication{ - projectSvc: projectSvc, - providerSvc: providerSvc, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("project_application"), + projectSvc: projectSvc, + projectRepo: projectRepo, + providerSvc: providerSvc, + providerRepo: providerRepo, + sharesRepo: sharesRepo, + logger: logger.New("project_application"), } } -func (a *ProjectApplication) CreateProject(ctx context.Context, name string) (*project.Project, error) { +func (a *ProjectApplication) CreateProject(ctx context.Context, name string, opts ...ProjectOption) (*project.Project, error) { a.logger.InfoContext(ctx, "creating project") proj, err := a.projectSvc.Create(ctx, name) if err != nil { - a.logger.ErrorContext(ctx, "failed to create project", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to create project", logger.Error(err)) return nil, fromDomainError(err) } + var o projectOptions + for _, opt := range opts { + opt(&o) + } + + if o.generateEncryptionKey { + part, err := a.registerEncryptionKey(ctx, proj.ID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to register encryption key", logger.Error(err)) + return nil, fromDomainError(err) + } + + proj.EncryptionPart = part + } + return proj, nil } func (a *ProjectApplication) GetProject(ctx context.Context) (*project.Project, error) { a.logger.InfoContext(ctx, "getting project") + projectID := contexter.GetProjectID(ctx) - projectID := ofcontext.GetProjectID(ctx) - - proj, err := a.projectSvc.Get(ctx, projectID) + proj, err := a.projectRepo.Get(ctx, projectID) if err != nil { - a.logger.ErrorContext(ctx, "failed to get project", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get project", logger.Error(err)) return nil, fromDomainError(err) } @@ -54,8 +78,7 @@ func (a *ProjectApplication) GetProject(ctx context.Context) (*project.Project, func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderOption) ([]*provider.Provider, error) { a.logger.InfoContext(ctx, "adding providers") - - projectID := ofcontext.GetProjectID(ctx) + projectID := contexter.GetProjectID(ctx) cfg := &providerConfig{} for _, opt := range opts { @@ -63,43 +86,52 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO } var providers []*provider.Provider - if cfg.jwkURL != nil { - a.logger.InfoContext(ctx, "configuring custom provider") - prov, err := a.providerSvc.Configure(ctx, projectID, &services.CustomProviderConfig{JWKUrl: *cfg.jwkURL}) - if err != nil { - a.logger.ErrorContext(ctx, "failed to configure custom provider", slog.String("error", err.Error())) + if cfg.openfortPublishableKey != nil { + prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeOpenfort) + if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } - - providers = append(providers, prov) + if err == nil && prov != nil { + return nil, ErrProviderAlreadyExists + } + providers = append(providers, &provider.Provider{ProjectID: projectID, Type: provider.TypeOpenfort, Config: provider.OpenfortConfig{PublishableKey: *cfg.openfortPublishableKey}}) } - if cfg.openfortPublishableKey != nil { - a.logger.InfoContext(ctx, "configuring openfort provider") - prov, err := a.providerSvc.Configure(ctx, projectID, &services.OpenfortProviderConfig{OpenfortProject: *cfg.openfortPublishableKey}) - if err != nil { - a.logger.ErrorContext(ctx, "failed to configure openfort provider", slog.String("error", err.Error())) + if cfg.jwkURL != nil { + prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeCustom) + if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } - - providers = append(providers, prov) + if err == nil && prov != nil { + return nil, ErrProviderAlreadyExists + } + providers = append(providers, &provider.Provider{ProjectID: projectID, Type: provider.TypeCustom, Config: provider.CustomConfig{JWK: *cfg.jwkURL}}) } if len(providers) == 0 { return nil, ErrNoProviderSpecified } + for _, prov := range providers { + err := a.providerSvc.Configure(ctx, prov) + if err != nil { + a.logger.ErrorContext(ctx, "failed to create provider", logger.Error(err)) + return nil, fromDomainError(err) + } + } + return providers, nil } func (a *ProjectApplication) GetProviders(ctx context.Context) ([]*provider.Provider, error) { a.logger.InfoContext(ctx, "listing providers") + projectID := contexter.GetProjectID(ctx) - projectID := ofcontext.GetProjectID(ctx) - - providers, err := a.providerSvc.List(ctx, projectID) + providers, err := a.providerRepo.List(ctx, projectID) if err != nil { - a.logger.ErrorContext(ctx, "failed to list providers", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to list providers", logger.Error(err)) return nil, fromDomainError(err) } @@ -108,12 +140,11 @@ func (a *ProjectApplication) GetProviders(ctx context.Context) ([]*provider.Prov func (a *ProjectApplication) GetProviderDetail(ctx context.Context, providerID string) (*provider.Provider, error) { a.logger.InfoContext(ctx, "getting provider detail") + projectID := contexter.GetProjectID(ctx) - projectID := ofcontext.GetProjectID(ctx) - - prov, err := a.providerSvc.Get(ctx, providerID) + prov, err := a.providerRepo.Get(ctx, providerID) if err != nil { - a.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } @@ -127,12 +158,11 @@ func (a *ProjectApplication) GetProviderDetail(ctx context.Context, providerID s func (a *ProjectApplication) UpdateProvider(ctx context.Context, providerID string, opts ...ProviderOption) error { a.logger.InfoContext(ctx, "updating provider") + projectID := contexter.GetProjectID(ctx) - projectID := ofcontext.GetProjectID(ctx) - - prov, err := a.providerSvc.Get(ctx, providerID) + prov, err := a.providerRepo.Get(ctx, providerID) if err != nil { - a.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return fromDomainError(err) } @@ -151,9 +181,9 @@ func (a *ProjectApplication) UpdateProvider(ctx context.Context, providerID stri return ErrProviderMismatch } - err := a.providerSvc.UpdateConfig(ctx, &provider.CustomConfig{ProviderID: providerID, JWK: *cfg.jwkURL}) + err = a.providerRepo.UpdateCustom(ctx, &provider.CustomConfig{ProviderID: providerID, JWK: *cfg.jwkURL}) if err != nil { - a.logger.ErrorContext(ctx, "failed to update custom provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to update custom provider", logger.Error(err)) return fromDomainError(err) } } @@ -163,9 +193,9 @@ func (a *ProjectApplication) UpdateProvider(ctx context.Context, providerID stri return ErrProviderMismatch } - err = a.providerSvc.UpdateConfig(ctx, &provider.OpenfortConfig{ProviderID: providerID, PublishableKey: *cfg.openfortPublishableKey}) + err = a.providerRepo.UpdateOpenfort(ctx, &provider.OpenfortConfig{ProviderID: providerID, PublishableKey: *cfg.openfortPublishableKey}) if err != nil { - a.logger.ErrorContext(ctx, "failed to update openfort provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to update openfort provider", logger.Error(err)) return fromDomainError(err) } } @@ -174,12 +204,22 @@ func (a *ProjectApplication) UpdateProvider(ctx context.Context, providerID stri func (a *ProjectApplication) RemoveProvider(ctx context.Context, providerID string) error { a.logger.InfoContext(ctx, "removing provider") + projectID := contexter.GetProjectID(ctx) - projectID := ofcontext.GetProjectID(ctx) + prov, err := a.providerRepo.Get(ctx, providerID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) + return fromDomainError(err) + } - err := a.providerSvc.Remove(ctx, projectID, providerID) + if prov.ProjectID != projectID { + a.logger.ErrorContext(ctx, "unauthorized access, trying to remove provider from different project", slog.String("project_id", projectID), slog.String("provider_project_id", prov.ProjectID)) + return ErrProviderNotFound + } + + err = a.providerRepo.Delete(ctx, providerID) if err != nil { - a.logger.ErrorContext(ctx, "failed to remove provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to remove provider", logger.Error(err)) return fromDomainError(err) } @@ -188,12 +228,11 @@ func (a *ProjectApplication) RemoveProvider(ctx context.Context, providerID stri func (a *ProjectApplication) AddAllowedOrigin(ctx context.Context, origin string) error { a.logger.InfoContext(ctx, "adding allowed origin") - - projectID := ofcontext.GetProjectID(ctx) + projectID := contexter.GetProjectID(ctx) err := a.projectSvc.AddAllowedOrigin(ctx, projectID, origin) if err != nil { - a.logger.ErrorContext(ctx, "failed to add allowed origin", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to add allowed origin", logger.Error(err)) return fromDomainError(err) } @@ -203,11 +242,11 @@ func (a *ProjectApplication) AddAllowedOrigin(ctx context.Context, origin string func (a *ProjectApplication) RemoveAllowedOrigin(ctx context.Context, origin string) error { a.logger.InfoContext(ctx, "removing allowed origin") - projectID := ofcontext.GetProjectID(ctx) + projectID := contexter.GetProjectID(ctx) err := a.projectSvc.RemoveAllowedOrigin(ctx, projectID, origin) if err != nil { - a.logger.ErrorContext(ctx, "failed to remove allowed origin", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to remove allowed origin", logger.Error(err)) return fromDomainError(err) } @@ -217,13 +256,87 @@ func (a *ProjectApplication) RemoveAllowedOrigin(ctx context.Context, origin str func (a *ProjectApplication) GetAllowedOrigins(ctx context.Context) ([]string, error) { a.logger.InfoContext(ctx, "getting allowed origins") - projectID := ofcontext.GetProjectID(ctx) + projectID := contexter.GetProjectID(ctx) origins, err := a.projectSvc.GetAllowedOrigins(ctx, projectID) if err != nil { - a.logger.ErrorContext(ctx, "failed to get allowed origins", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get allowed origins", logger.Error(err)) return nil, fromDomainError(err) } return origins, nil } + +func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, projectID, externalPart string) error { + a.logger.InfoContext(ctx, "encrypting project shares") + + storedPart, err := a.projectRepo.GetEncryptionPart(ctx, projectID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + return fromDomainError(err) + } + + encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, externalPart) + if err != nil { + a.logger.ErrorContext(ctx, "failed to reconstruct encryption key", logger.Error(err)) + return ErrInvalidEncryptionPart + } + + shares, err := a.sharesRepo.ListDecryptedByProjectID(ctx, projectID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to list shares", logger.Error(err)) + return fromDomainError(err) + } + + var encryptedShares []*share.Share + for _, shr := range shares { + if shr.EncryptionParameters != nil && shr.EncryptionParameters.Entropy != share.EntropyNone { + continue + } + + shr.Secret, err = cypher.Encrypt(shr.Secret, encryptionKey) + if err != nil { + a.logger.ErrorContext(ctx, "failed to encrypt share", logger.Error(err)) + return fromDomainError(err) + } + + shr.EncryptionParameters = &share.EncryptionParameters{ + Entropy: share.EntropyProject, + } + } + + for _, encryptedShare := range encryptedShares { + err = a.sharesRepo.Update(ctx, encryptedShare) + if err != nil { + a.logger.ErrorContext(ctx, "failed to update share", logger.Error(err)) + return fromDomainError(err) + } + } + + return nil +} + +func (a *ProjectApplication) registerEncryptionKey(ctx context.Context, projectID string) (externalPart string, err error) { + defer func() { + if err != nil { + a.logger.Info("deleting project") + errD := a.projectRepo.Delete(ctx, projectID) + if errD != nil { + a.logger.Error("failed to delete project", logger.Error(errD)) + err = errors.Join(err, errD) + } + } + }() + var shieldPart string + shieldPart, externalPart, err = cypher.GenerateEncryptionKey() + if err != nil { + return "", err + } + + err = a.projectSvc.SetEncryptionPart(ctx, projectID, shieldPart) + if err != nil { + return "", err + } + + return externalPart, nil +} diff --git a/internal/applications/projectapp/errors.go b/internal/applications/projectapp/errors.go index 97a6eea..50f7d5c 100644 --- a/internal/applications/projectapp/errors.go +++ b/internal/applications/projectapp/errors.go @@ -14,6 +14,7 @@ var ( ErrUnknownProviderType = errors.New("unknown provider type") ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") ErrProviderNotFound = errors.New("custom authentication not found") + ErrInvalidEncryptionPart = errors.New("invalid encryption part") ErrInternal = errors.New("internal error") ) diff --git a/internal/applications/projectapp/options.go b/internal/applications/projectapp/options.go index b632342..7617481 100644 --- a/internal/applications/projectapp/options.go +++ b/internal/applications/projectapp/options.go @@ -18,3 +18,15 @@ type providerConfig struct { jwkURL *string openfortPublishableKey *string } + +type ProjectOption func(options *projectOptions) + +type projectOptions struct { + generateEncryptionKey bool +} + +func WithEncryptionKey() ProjectOption { + return func(o *projectOptions) { + o.generateEncryptionKey = true + } +} diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go new file mode 100644 index 0000000..0e7e215 --- /dev/null +++ b/internal/applications/shareapp/app.go @@ -0,0 +1,114 @@ +package shareapp + +import ( + "context" + "log/slog" + + "go.openfort.xyz/shield/internal/core/domain/share" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/core/ports/services" + "go.openfort.xyz/shield/pkg/contexter" + "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/logger" +) + +type ShareApplication struct { + shareSvc services.ShareService + shareRepo repositories.ShareRepository + projectRepo repositories.ProjectRepository + logger *slog.Logger +} + +func New(shareSvc services.ShareService, shareRepo repositories.ShareRepository, projectRepo repositories.ProjectRepository) *ShareApplication { + return &ShareApplication{ + shareSvc: shareSvc, + shareRepo: shareRepo, + projectRepo: projectRepo, + logger: logger.New("share_application"), + } +} + +func (a *ShareApplication) RegisterShare(ctx context.Context, shr *share.Share, opts ...Option) error { + a.logger.InfoContext(ctx, "registering share") + usrID := contexter.GetUserID(ctx) + projID := contexter.GetProjectID(ctx) + shr.UserID = usrID + + var opt options + for _, o := range opts { + o(&opt) + } + + var shrOpts []services.ShareOption + if shr.RequiresEncryption() { + if opt.encryptionPart == nil { + return ErrEncryptionPartRequired + } + + encryptionKey, err := a.reconstructEncryptionKey(ctx, projID, opt) + if err != nil { + return err + } + + shrOpts = append(shrOpts, services.WithEncryptionKey(encryptionKey)) + } + + err := a.shareSvc.Create(ctx, shr, shrOpts...) + if err != nil { + a.logger.ErrorContext(ctx, "failed to create share", logger.Error(err)) + return fromDomainError(err) + } + + return nil +} + +func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share.Share, error) { + a.logger.InfoContext(ctx, "getting share") + usrID := contexter.GetUserID(ctx) + projID := contexter.GetProjectID(ctx) + + shr, err := a.shareRepo.GetByUserID(ctx, usrID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get share by user ID", logger.Error(err)) + return nil, fromDomainError(err) + } + + var opt options + for _, o := range opts { + o(&opt) + } + + if shr.RequiresEncryption() { + encryptionKey, err := a.reconstructEncryptionKey(ctx, projID, opt) + if err != nil { + return nil, err + } + + shr.Secret, err = cypher.Decrypt(shr.Secret, encryptionKey) + if err != nil { + a.logger.ErrorContext(ctx, "failed to decrypt secret", logger.Error(err)) + return nil, ErrInternal + } + } + + return shr, nil +} + +func (a *ShareApplication) reconstructEncryptionKey(ctx context.Context, projID string, opt options) (string, error) { + if opt.encryptionPart == nil || *opt.encryptionPart == "" { + return "", ErrEncryptionPartRequired + } + + storedPart, err := a.projectRepo.GetEncryptionPart(ctx, projID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + return "", fromDomainError(err) + } + + encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, *opt.encryptionPart) + if err != nil { + a.logger.ErrorContext(ctx, "failed to reconstruct encryption key", logger.Error(err)) + return "", ErrInvalidEncryptionPart + } + return encryptionKey, nil +} diff --git a/internal/applications/userapp/errors.go b/internal/applications/shareapp/errors.go similarity index 70% rename from internal/applications/userapp/errors.go rename to internal/applications/shareapp/errors.go index cd73502..f5afe6b 100644 --- a/internal/applications/userapp/errors.go +++ b/internal/applications/shareapp/errors.go @@ -1,4 +1,4 @@ -package userapp +package shareapp import ( "errors" @@ -12,6 +12,9 @@ var ( ErrUserNotFound = errors.New("user not found") ErrExternalUserNotFound = errors.New("external user not found") ErrExternalUserAlreadyExists = errors.New("external user already exists") + ErrEncryptionPartRequired = errors.New("encryption part is required") + ErrEncryptionNotConfigured = errors.New("encryption not configured") + ErrInvalidEncryptionPart = errors.New("invalid encryption part") ErrInternal = errors.New("internal error") ) @@ -39,5 +42,13 @@ func fromDomainError(err error) error { return ErrExternalUserAlreadyExists } + if errors.Is(err, domain.ErrEncryptionPartRequired) { + return ErrEncryptionPartRequired + } + + if errors.Is(err, domain.ErrEncryptionPartNotFound) { + return ErrEncryptionNotConfigured + } + return ErrInternal } diff --git a/internal/applications/shareapp/options.go b/internal/applications/shareapp/options.go new file mode 100644 index 0000000..69a06c8 --- /dev/null +++ b/internal/applications/shareapp/options.go @@ -0,0 +1,13 @@ +package shareapp + +type options struct { + encryptionPart *string +} + +type Option func(*options) + +func WithEncryptionPart(encryptionPart string) Option { + return func(o *options) { + o.encryptionPart = &encryptionPart + } +} diff --git a/internal/applications/userapp/types.go b/internal/applications/shareapp/types.go similarity index 86% rename from internal/applications/userapp/types.go rename to internal/applications/shareapp/types.go index 84c6077..5338eae 100644 --- a/internal/applications/userapp/types.go +++ b/internal/applications/shareapp/types.go @@ -1,4 +1,4 @@ -package userapp +package shareapp type EncryptionParameters struct { Salt string diff --git a/internal/applications/userapp/app.go b/internal/applications/userapp/app.go deleted file mode 100644 index 23830f4..0000000 --- a/internal/applications/userapp/app.go +++ /dev/null @@ -1,70 +0,0 @@ -package userapp - -import ( - "context" - "log/slog" - "os" - - "go.openfort.xyz/shield/internal/core/domain/share" - "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/internal/infrastructure/providersmgr" - "go.openfort.xyz/shield/pkg/ofcontext" - "go.openfort.xyz/shield/pkg/oflog" -) - -type UserApplication struct { - userSvc services.UserService - shareSvc services.ShareService - projectSvc services.ProjectService - providerSvc services.ProviderService - providerManager *providersmgr.Manager - logger *slog.Logger -} - -func New(userSvc services.UserService, shareSvc services.ShareService, projectSvc services.ProjectService, providerSvc services.ProviderService, providerManager *providersmgr.Manager) *UserApplication { - return &UserApplication{ - userSvc: userSvc, - shareSvc: shareSvc, - projectSvc: projectSvc, - providerSvc: providerSvc, - providerManager: providerManager, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("user_application"), - } -} - -func (a *UserApplication) RegisterShare(ctx context.Context, secret string, userEntropy bool, parameters *EncryptionParameters) error { - a.logger.InfoContext(ctx, "registering share") - usrID := ofcontext.GetUserID(ctx) - - shre := &share.Share{ - Data: secret, - UserID: usrID, - UserEntropy: userEntropy, - } - if parameters != nil { - shre.Salt = parameters.Salt - shre.Iterations = parameters.Iterations - shre.Length = parameters.Length - shre.Digest = parameters.Digest - } - err := a.shareSvc.Create(ctx, shre) - if err != nil { - a.logger.ErrorContext(ctx, "failed to create share", slog.String("error", err.Error())) - return fromDomainError(err) - } - - return nil -} - -func (a *UserApplication) GetShare(ctx context.Context) (*share.Share, error) { - a.logger.InfoContext(ctx, "getting share") - - usrID := ofcontext.GetUserID(ctx) - shr, err := a.shareSvc.GetByUserID(ctx, usrID) - if err != nil { - a.logger.ErrorContext(ctx, "failed to get share by user ID", slog.String("error", err.Error())) - return nil, fromDomainError(err) - } - - return shr, nil -} diff --git a/internal/core/domain/errors.go b/internal/core/domain/errors.go index 6e57877..4f95cfa 100644 --- a/internal/core/domain/errors.go +++ b/internal/core/domain/errors.go @@ -4,7 +4,10 @@ import "errors" var ( // Project errors - ErrProjectNotFound = errors.New("project not found") + ErrProjectNotFound = errors.New("project not found") + ErrEncryptionPartNotFound = errors.New("encryption part not found") + ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") + ErrEncryptionPartRequired = errors.New("encryption part is required") // Provider errors ErrNoProviderConfig = errors.New("no provider config found") diff --git a/internal/core/domain/project/project.go b/internal/core/domain/project/project.go index d4150e7..996b930 100644 --- a/internal/core/domain/project/project.go +++ b/internal/core/domain/project/project.go @@ -1,8 +1,9 @@ package project type Project struct { - ID string - Name string - APIKey string - APISecret string + ID string + Name string + APIKey string + APISecret string + EncryptionPart string } diff --git a/internal/core/domain/share/encryption_parameters.go b/internal/core/domain/share/encryption_parameters.go new file mode 100644 index 0000000..944cda6 --- /dev/null +++ b/internal/core/domain/share/encryption_parameters.go @@ -0,0 +1,9 @@ +package share + +type EncryptionParameters struct { + Entropy Entropy + Salt string + Iterations int + Length int + Digest string +} diff --git a/internal/core/domain/share/entropy.go b/internal/core/domain/share/entropy.go new file mode 100644 index 0000000..3f5ab50 --- /dev/null +++ b/internal/core/domain/share/entropy.go @@ -0,0 +1,9 @@ +package share + +type Entropy int8 + +const ( + EntropyNone Entropy = iota + EntropyUser + EntropyProject +) diff --git a/internal/core/domain/share/share.go b/internal/core/domain/share/share.go index c706bcb..df33e51 100644 --- a/internal/core/domain/share/share.go +++ b/internal/core/domain/share/share.go @@ -1,12 +1,12 @@ package share type Share struct { - ID string - Data string - UserID string - UserEntropy bool - Salt string - Iterations int - Length int - Digest string + ID string + Secret string + UserID string + EncryptionParameters *EncryptionParameters +} + +func (s *Share) RequiresEncryption() bool { + return s.EncryptionParameters != nil && s.EncryptionParameters.Entropy == EntropyProject } diff --git a/internal/core/ports/repositories/project.go b/internal/core/ports/repositories/project.go index 464b839..027b3d4 100644 --- a/internal/core/ports/repositories/project.go +++ b/internal/core/ports/repositories/project.go @@ -10,8 +10,13 @@ type ProjectRepository interface { Create(ctx context.Context, project *project.Project) error Get(ctx context.Context, projectID string) (*project.Project, error) GetByAPIKey(ctx context.Context, apiKey string) (*project.Project, error) + Delete(ctx context.Context, projectID string) error + AddAllowedOrigin(ctx context.Context, projectID, origin string) error RemoveAllowedOrigin(ctx context.Context, projectID, origin string) error GetAllowedOrigins(ctx context.Context, projectID string) ([]string, error) GetAllowedOriginsByAPIKey(ctx context.Context, apiKey string) ([]string, error) + + GetEncryptionPart(ctx context.Context, projectID string) (string, error) + SetEncryptionPart(ctx context.Context, projectID, part string) error } diff --git a/internal/core/ports/repositories/shares.go b/internal/core/ports/repositories/shares.go index c62c33b..37bd7fa 100644 --- a/internal/core/ports/repositories/shares.go +++ b/internal/core/ports/repositories/shares.go @@ -9,4 +9,6 @@ import ( type ShareRepository interface { Create(ctx context.Context, shr *share.Share) error GetByUserID(ctx context.Context, userID string) (*share.Share, error) + ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) + Update(ctx context.Context, shr *share.Share) error } diff --git a/internal/core/ports/services/project.go b/internal/core/ports/services/project.go index e9c6d60..c01eb29 100644 --- a/internal/core/ports/services/project.go +++ b/internal/core/ports/services/project.go @@ -8,9 +8,10 @@ import ( type ProjectService interface { Create(ctx context.Context, name string) (*project.Project, error) - Get(ctx context.Context, projectID string) (*project.Project, error) GetByAPIKey(ctx context.Context, apiKey string) (*project.Project, error) AddAllowedOrigin(ctx context.Context, projectID, origin string) error RemoveAllowedOrigin(ctx context.Context, projectID, origin string) error GetAllowedOrigins(ctx context.Context, projectID string) ([]string, error) + GetEncryptionPart(ctx context.Context, projectID string) (string, error) + SetEncryptionPart(ctx context.Context, projectID, part string) error } diff --git a/internal/core/ports/services/provider.go b/internal/core/ports/services/provider.go index 092dc77..fcd608b 100644 --- a/internal/core/ports/services/provider.go +++ b/internal/core/ports/services/provider.go @@ -7,11 +7,7 @@ import ( ) type ProviderService interface { - Configure(ctx context.Context, projectID string, config ProviderConfig) (*provider.Provider, error) - Get(ctx context.Context, providerID string) (*provider.Provider, error) - List(ctx context.Context, projectID string) ([]*provider.Provider, error) - UpdateConfig(ctx context.Context, config interface{}) error - Remove(ctx context.Context, projectID string, providerID string) error + Configure(ctx context.Context, prov *provider.Provider) error } type ProviderConfig interface { diff --git a/internal/core/ports/services/share.go b/internal/core/ports/services/share.go index 8883058..32904b2 100644 --- a/internal/core/ports/services/share.go +++ b/internal/core/ports/services/share.go @@ -7,6 +7,17 @@ import ( ) type ShareService interface { - Create(ctx context.Context, share *share.Share) error - GetByUserID(ctx context.Context, userID string) (*share.Share, error) + Create(ctx context.Context, share *share.Share, opts ...ShareOption) error +} + +type ShareOption func(*ShareOptions) + +type ShareOptions struct { + EncryptionKey *string +} + +func WithEncryptionKey(key string) ShareOption { + return func(o *ShareOptions) { + o.EncryptionKey = &key + } } diff --git a/internal/core/services/projectsvc/svc.go b/internal/core/services/projectsvc/svc.go index b023858..6c58fa7 100644 --- a/internal/core/services/projectsvc/svc.go +++ b/internal/core/services/projectsvc/svc.go @@ -2,14 +2,15 @@ package projectsvc import ( "context" + "errors" "log/slog" - "os" "github.com/google/uuid" + "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "golang.org/x/crypto/bcrypt" ) @@ -24,7 +25,7 @@ var _ services.ProjectService = (*service)(nil) func New(repo repositories.ProjectRepository) services.ProjectService { return &service{ repo: repo, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("project_service"), + logger: logger.New("project_service"), cost: bcrypt.DefaultCost, } } @@ -34,7 +35,7 @@ func (s *service) Create(ctx context.Context, name string) (*project.Project, er apiSecret := uuid.NewString() encryptedSecret, err := bcrypt.GenerateFromPassword([]byte(apiSecret), s.cost) if err != nil { - s.logger.ErrorContext(ctx, "failed to encrypt secret", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to encrypt secret", logger.Error(err)) return nil, err } @@ -46,7 +47,7 @@ func (s *service) Create(ctx context.Context, name string) (*project.Project, er err = s.repo.Create(ctx, proj) if err != nil { - s.logger.ErrorContext(ctx, "failed to create project", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create project", logger.Error(err)) return nil, err } @@ -54,22 +55,11 @@ func (s *service) Create(ctx context.Context, name string) (*project.Project, er return proj, nil } -func (s *service) Get(ctx context.Context, projectID string) (*project.Project, error) { - s.logger.InfoContext(ctx, "getting project", slog.String("project_id", projectID)) - proj, err := s.repo.Get(ctx, projectID) - if err != nil { - s.logger.ErrorContext(ctx, "failed to get project", slog.String("error", err.Error())) - return nil, err - } - - return proj, nil -} - func (s *service) GetByAPIKey(ctx context.Context, apiKey string) (*project.Project, error) { s.logger.InfoContext(ctx, "getting project by API key", slog.String("api_key", apiKey)) proj, err := s.repo.GetByAPIKey(ctx, apiKey) if err != nil { - s.logger.ErrorContext(ctx, "failed to get project by API key", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get project by API key", logger.Error(err)) return nil, err } @@ -80,7 +70,7 @@ func (s *service) AddAllowedOrigin(ctx context.Context, projectID, origin string s.logger.InfoContext(ctx, "adding allowed origin", slog.String("project_id", projectID), slog.String("origin", origin)) err := s.repo.AddAllowedOrigin(ctx, projectID, origin) if err != nil { - s.logger.ErrorContext(ctx, "failed to add allowed origin", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to add allowed origin", logger.Error(err)) return err } @@ -91,7 +81,7 @@ func (s *service) RemoveAllowedOrigin(ctx context.Context, projectID, origin str s.logger.InfoContext(ctx, "removing allowed origin", slog.String("project_id", projectID), slog.String("origin", origin)) err := s.repo.RemoveAllowedOrigin(ctx, projectID, origin) if err != nil { - s.logger.ErrorContext(ctx, "failed to remove allowed origin", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to remove allowed origin", logger.Error(err)) return err } @@ -102,9 +92,42 @@ func (s *service) GetAllowedOrigins(ctx context.Context, projectID string) ([]st s.logger.InfoContext(ctx, "getting allowed origins", slog.String("project_id", projectID)) origins, err := s.repo.GetAllowedOrigins(ctx, projectID) if err != nil { - s.logger.ErrorContext(ctx, "failed to get allowed origins", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get allowed origins", logger.Error(err)) return nil, err } return origins, nil } + +func (s *service) GetEncryptionPart(ctx context.Context, projectID string) (string, error) { + s.logger.InfoContext(ctx, "getting encryption part", slog.String("project_id", projectID)) + part, err := s.repo.GetEncryptionPart(ctx, projectID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + return "", err + } + + return part, nil +} + +func (s *service) SetEncryptionPart(ctx context.Context, projectID, part string) error { + s.logger.InfoContext(ctx, "setting encryption part", slog.String("project_id", projectID)) + ep, err := s.repo.GetEncryptionPart(ctx, projectID) + if err != nil && !errors.Is(err, domain.ErrEncryptionPartNotFound) { + s.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + return err + } + + if ep != "" { + s.logger.Warn("encryption part already exists", slog.String("project_id", projectID)) + return domain.ErrEncryptionPartAlreadyExists + } + + err = s.repo.SetEncryptionPart(ctx, projectID, part) + if err != nil { + s.logger.ErrorContext(ctx, "failed to set encryption part", logger.Error(err)) + return err + } + + return nil +} diff --git a/internal/core/services/projectsvc/svc_test.go b/internal/core/services/projectsvc/svc_test.go index 1668d64..0b067d4 100644 --- a/internal/core/services/projectsvc/svc_test.go +++ b/internal/core/services/projectsvc/svc_test.go @@ -61,54 +61,6 @@ func TestCreateProject(t *testing.T) { } } -func TestGetProject(t *testing.T) { - mockRepo := new(projectmockrepo.MockProjectRepository) - svc := New(mockRepo) - ctx := context.Background() - testProjectID := "get-test-project-id" - - tc := []struct { - name string - wantErr bool - mock func() - }{ - { - name: "success", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, testProjectID).Return(&project.Project{}, nil) - }, - }, - { - name: "project not found", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, testProjectID).Return(nil, errors.New("project not found")) - }, - }, - { - name: "repository error", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, testProjectID).Return(nil, errors.New("repository error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - _, err := svc.Get(ctx, testProjectID) - if (err != nil) != tt.wantErr { - t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func TestGetProjectByAPIKey(t *testing.T) { mockRepo := new(projectmockrepo.MockProjectRepository) svc := New(mockRepo) diff --git a/internal/core/services/providersvc/svc.go b/internal/core/services/providersvc/svc.go index 174485e..7e9ab35 100644 --- a/internal/core/services/providersvc/svc.go +++ b/internal/core/services/providersvc/svc.go @@ -4,13 +4,12 @@ import ( "context" "errors" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type service struct { @@ -23,181 +22,75 @@ var _ services.ProviderService = (*service)(nil) func New(repo repositories.ProviderRepository) services.ProviderService { return &service{ repo: repo, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("provider_service"), + logger: logger.New("provider_service"), } } -func (s *service) Configure(ctx context.Context, projectID string, config services.ProviderConfig) (*provider.Provider, error) { - if config == nil { - return nil, domain.ErrNoProviderConfig - } - - switch config.GetType() { +func (s *service) Configure(ctx context.Context, prov *provider.Provider) error { + switch prov.Type { case provider.TypeCustom: - customConfig, ok := config.GetConfig().(*services.CustomProviderConfig) - if !ok { - return nil, domain.ErrInvalidProviderConfig - } - - return s.configureCustomProvider(ctx, projectID, customConfig.JWKUrl) + return s.configureCustomProvider(ctx, prov) case provider.TypeOpenfort: - openfortConfig, ok := config.GetConfig().(*services.OpenfortProviderConfig) - if !ok { - return nil, domain.ErrInvalidProviderConfig - } - - return s.configureOpenfortProvider(ctx, projectID, openfortConfig.OpenfortProject) + return s.configureOpenfortProvider(ctx, prov) default: - return nil, domain.ErrUnknownProviderType + return domain.ErrUnknownProviderType } } -func (s *service) Get(ctx context.Context, providerID string) (*provider.Provider, error) { - s.logger.InfoContext(ctx, "getting provider", slog.String("provider_id", providerID)) +func (s *service) configureCustomProvider(ctx context.Context, prov *provider.Provider) error { + s.logger.InfoContext(ctx, "configuring custom provider", slog.String("project_id", prov.ProjectID)) - prov, err := s.repo.Get(ctx, providerID) + err := s.repo.Create(ctx, prov) if err != nil { - s.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) - return nil, err - } - - return prov, nil -} - -func (s *service) List(ctx context.Context, projectID string) ([]*provider.Provider, error) { - s.logger.InfoContext(ctx, "listing providers", slog.String("project_id", projectID)) - - provs, err := s.repo.List(ctx, projectID) - if err != nil { - s.logger.ErrorContext(ctx, "failed to list providers", slog.String("error", err.Error())) - return nil, err - } - - return provs, nil -} - -func (s *service) UpdateConfig(ctx context.Context, config interface{}) error { - s.logger.InfoContext(ctx, "updating provider config") - - if cfg, ok := config.(*provider.CustomConfig); ok { - s.logger.InfoContext(ctx, "updating custom provider config", slog.String("provider_id", cfg.ProviderID)) - err := s.repo.UpdateCustom(ctx, cfg) - if err != nil { - s.logger.ErrorContext(ctx, "failed to update provider config", slog.String("error", err.Error())) - return err - } - - return nil - } - - if cfg, ok := config.(*provider.OpenfortConfig); ok { - s.logger.InfoContext(ctx, "updating openfort provider config", slog.String("provider_id", cfg.ProviderID)) - err := s.repo.UpdateOpenfort(ctx, cfg) - if err != nil { - s.logger.ErrorContext(ctx, "failed to update provider config", slog.String("error", err.Error())) - return err - } - - return nil - } - - return domain.ErrInvalidProviderConfig -} - -func (s *service) Remove(ctx context.Context, projectID string, providerID string) error { - s.logger.InfoContext(ctx, "removing provider", slog.String("project_id", projectID), slog.String("provider_id", providerID)) - - err := s.repo.Delete(ctx, providerID) - if err != nil { - s.logger.ErrorContext(ctx, "failed to delete provider", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create provider", logger.Error(err)) return err } - return nil -} - -func (s *service) configureCustomProvider(ctx context.Context, projectID, jwkURL string) (*provider.Provider, error) { - s.logger.InfoContext(ctx, "configuring custom provider", slog.String("project_id", projectID)) - - prov, err := s.repo.GetByProjectAndType(ctx, projectID, provider.TypeCustom) - if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { - s.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) - return nil, err - } - - if prov != nil { - s.logger.ErrorContext(ctx, "provider already exists") - return nil, domain.ErrProviderAlreadyExists - } - - prov = &provider.Provider{ - ProjectID: projectID, - Type: provider.TypeCustom, - } - err = s.repo.Create(ctx, prov) - if err != nil { - s.logger.ErrorContext(ctx, "failed to create provider", slog.String("error", err.Error())) - return nil, err + customAuth, ok := prov.Config.(*provider.CustomConfig) + if !ok { + s.logger.ErrorContext(ctx, "invalid custom provider config") + return domain.ErrInvalidProviderConfig } - customAuth := &provider.CustomConfig{ - ProviderID: prov.ID, - JWK: jwkURL, - } err = s.repo.CreateCustom(ctx, customAuth) if err != nil { - s.logger.ErrorContext(ctx, "failed to create custom provider", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create custom provider", logger.Error(err)) errD := s.repo.Delete(ctx, prov.ID) if errD != nil { - s.logger.ErrorContext(ctx, "failed to delete provider", slog.String("provider", prov.ID), slog.String("error", errD.Error())) + s.logger.ErrorContext(ctx, "failed to delete provider", slog.String("provider", prov.ID), logger.Error(errD)) err = errors.Join(err, errD) } - return nil, err + return err } - prov.Config = customAuth - return prov, nil + return nil } -func (s *service) configureOpenfortProvider(ctx context.Context, projectID, openfortProject string) (*provider.Provider, error) { - s.logger.InfoContext(ctx, "configuring openfort provider", slog.String("project_id", projectID)) - - prov, err := s.repo.GetByProjectAndType(ctx, projectID, provider.TypeOpenfort) - if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { - s.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) - return nil, err - } - - if prov != nil { - s.logger.ErrorContext(ctx, "provider already exists") - return nil, domain.ErrProviderAlreadyExists - } +func (s *service) configureOpenfortProvider(ctx context.Context, prov *provider.Provider) error { + s.logger.InfoContext(ctx, "configuring openfort provider", slog.String("project_id", prov.ProjectID)) - prov = &provider.Provider{ - ProjectID: projectID, - Type: provider.TypeOpenfort, - } - err = s.repo.Create(ctx, prov) + err := s.repo.Create(ctx, prov) if err != nil { - s.logger.ErrorContext(ctx, "failed to create provider", slog.String("error", err.Error())) - return nil, err + s.logger.ErrorContext(ctx, "failed to create provider", logger.Error(err)) + return err } - openfortAuth := &provider.OpenfortConfig{ - ProviderID: prov.ID, - PublishableKey: openfortProject, + openfortAuth, ok := prov.Config.(*provider.OpenfortConfig) + if !ok { + s.logger.ErrorContext(ctx, "invalid openfort provider config") + return domain.ErrInvalidProviderConfig } + err = s.repo.CreateOpenfort(ctx, openfortAuth) if err != nil { - s.logger.ErrorContext(ctx, "failed to create openfort provider", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create openfort provider", logger.Error(err)) errD := s.repo.Delete(ctx, prov.ID) if errD != nil { - s.logger.ErrorContext(ctx, "failed to delete provider", slog.String("provider", prov.ID), slog.String("error", errD.Error())) + s.logger.ErrorContext(ctx, "failed to delete provider", slog.String("provider", prov.ID), logger.Error(errD)) err = errors.Join(err, errD) } - return nil, err + return err } - prov.Config = openfortAuth - return prov, nil + return nil } diff --git a/internal/core/services/providersvc/svc_test.go b/internal/core/services/providersvc/svc_test.go index 61f8348..79ee1d7 100644 --- a/internal/core/services/providersvc/svc_test.go +++ b/internal/core/services/providersvc/svc_test.go @@ -194,7 +194,7 @@ func TestConfigureProvider(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.Configure(ctx, projectID, tt.config) + err := svc.Configure(ctx, nil) if (err != nil) != tt.wantErr { t.Errorf("Configure() error = %v, wantErr %v", err, tt.wantErr) } @@ -205,237 +205,6 @@ func TestConfigureProvider(t *testing.T) { } } -func TestGetProvider(t *testing.T) { - mockRepo := new(providermockrepo.MockProviderRepository) - svc := New(mockRepo) - ctx := context.Background() - testProviderID := "test-provider-id" - - tc := []struct { - name string - wantErr bool - err error - mock func() - }{ - { - name: "success", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, mock.AnythingOfType("string")).Return(&provider.Provider{}, nil) - }, - }, - { - name: "provider not found", - wantErr: true, - err: domain.ErrProviderNotFound, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, mock.AnythingOfType("string")).Return(nil, domain.ErrProviderNotFound) - }, - }, - { - name: "repository error", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Get", mock.Anything, mock.AnythingOfType("string")).Return(nil, errors.New("repository error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - _, err := svc.Get(ctx, testProviderID) - if (err != nil) != tt.wantErr { - t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("Get() error = %v, wantErr %v", err, tt.err) - } - }) - } -} - -func TestListProviders(t *testing.T) { - mockRepo := new(providermockrepo.MockProviderRepository) - svc := New(mockRepo) - ctx := context.Background() - testProjectID := "test-project-id" - - tc := []struct { - name string - wantErr bool - mock func() - }{ - { - name: "success with providers", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("List", mock.Anything, mock.AnythingOfType("string")).Return([]*provider.Provider{{}}, nil) - }, - }, - { - name: "success with no providers", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("List", mock.Anything, mock.AnythingOfType("string")).Return([]*provider.Provider{}, nil) - }, - }, - { - name: "repository error", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("List", mock.Anything, mock.AnythingOfType("string")).Return(nil, errors.New("repository error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - result, err := svc.List(ctx, testProjectID) - if (err != nil) != tt.wantErr { - t.Errorf("List() error = %v, wantErr %v", err, tt.wantErr) - } - if !tt.wantErr && result == nil { - t.Errorf("List() expected a result but got nil") - } - }) - } -} - -func TestUpdateConfig(t *testing.T) { - mockRepo := new(providermockrepo.MockProviderRepository) - svc := New(mockRepo) - ctx := context.Background() - customConfig := &provider.CustomConfig{ProviderID: "custom-id", JWK: "http://jwk.url"} - openfortConfig := &provider.OpenfortConfig{ProviderID: "openfort-id", PublishableKey: "openfort-project"} - - tc := []struct { - name string - config interface{} - wantErr bool - err error - mock func() - }{ - { - name: "update custom config success", - config: customConfig, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("UpdateCustom", mock.Anything, customConfig).Return(nil) - }, - }, - { - name: "update openfort config success", - config: openfortConfig, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("UpdateOpenfort", mock.Anything, openfortConfig).Return(nil) - }, - }, - { - name: "custom config repository error", - config: customConfig, - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("UpdateCustom", mock.Anything, customConfig).Return(errors.New("repository error")) - }, - }, - { - name: "openfort config repository error", - config: openfortConfig, - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("UpdateOpenfort", mock.Anything, openfortConfig).Return(errors.New("repository error")) - }, - }, - { - name: "invalid config type", - config: unknownProviderConfig{}, - wantErr: true, - err: domain.ErrInvalidProviderConfig, - mock: func() { - mockRepo.ExpectedCalls = nil - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - err := svc.UpdateConfig(ctx, tt.config) - if (err != nil) != tt.wantErr { - t.Errorf("UpdateConfig() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("UpdateConfig() error = %v, wantErr %v", err, tt.err) - } - }) - } -} - -func TestRemoveProvider(t *testing.T) { - mockRepo := new(providermockrepo.MockProviderRepository) - svc := New(mockRepo) - ctx := context.Background() - testProviderID := "test-provider-id" - testProjectID := "test-project-id" - - tc := []struct { - name string - wantErr bool - err error - mock func() - }{ - { - name: "successful removal", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Delete", mock.Anything, testProviderID).Return(nil) - }, - }, - { - name: "provider not found", - wantErr: true, - err: domain.ErrProviderNotFound, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Delete", mock.Anything, testProviderID).Return(domain.ErrProviderNotFound) - }, - }, - { - name: "repository error", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("Delete", mock.Anything, testProviderID).Return(errors.New("repository error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - err := svc.Remove(ctx, testProjectID, testProviderID) - if (err != nil) != tt.wantErr { - t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("Remove() error = %v, wantErr %v", err, tt.err) - } - }) - } -} - type unknownProviderConfig struct{} func (f *unknownProviderConfig) GetConfig() interface{} { return nil } diff --git a/internal/core/services/sharesvc/svc.go b/internal/core/services/sharesvc/svc.go index f9bb16c..07686ef 100644 --- a/internal/core/services/sharesvc/svc.go +++ b/internal/core/services/sharesvc/svc.go @@ -4,13 +4,13 @@ import ( "context" "errors" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/logger" ) type service struct { @@ -23,16 +23,16 @@ var _ services.ShareService = (*service)(nil) func New(repo repositories.ShareRepository) services.ShareService { return &service{ repo: repo, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("share_service"), + logger: logger.New("share_service"), } } -func (s *service) Create(ctx context.Context, shr *share.Share) error { +func (s *service) Create(ctx context.Context, shr *share.Share, opts ...services.ShareOption) error { s.logger.InfoContext(ctx, "creating share", slog.String("user_id", shr.UserID)) shrRepo, err := s.repo.GetByUserID(ctx, shr.UserID) if err != nil && !errors.Is(err, domain.ErrShareNotFound) { - s.logger.ErrorContext(ctx, "failed to get share", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get share", logger.Error(err)) return err } @@ -41,22 +41,28 @@ func (s *service) Create(ctx context.Context, shr *share.Share) error { return domain.ErrShareAlreadyExists } - err = s.repo.Create(ctx, shr) - if err != nil { - s.logger.ErrorContext(ctx, "failed to create share", slog.String("error", err.Error())) - return err + var o services.ShareOptions + for _, opt := range opts { + opt(&o) } - return nil -} + if shr.RequiresEncryption() { + if o.EncryptionKey == nil { + return domain.ErrEncryptionPartRequired + } -func (s *service) GetByUserID(ctx context.Context, userID string) (*share.Share, error) { - s.logger.InfoContext(ctx, "getting share by user", slog.String("user_id", userID)) - shr, err := s.repo.GetByUserID(ctx, userID) + shr.Secret, err = cypher.Encrypt(shr.Secret, *o.EncryptionKey) + if err != nil { + s.logger.ErrorContext(ctx, "failed to encrypt secret", logger.Error(err)) + return err + } + } + + err = s.repo.Create(ctx, shr) if err != nil { - s.logger.ErrorContext(ctx, "failed to get share", slog.String("error", err.Error())) - return nil, err + s.logger.ErrorContext(ctx, "failed to create share", logger.Error(err)) + return err } - return shr, nil + return nil } diff --git a/internal/core/services/sharesvc/svc_test.go b/internal/core/services/sharesvc/svc_test.go index abe825f..06a5943 100644 --- a/internal/core/services/sharesvc/svc_test.go +++ b/internal/core/services/sharesvc/svc_test.go @@ -19,7 +19,7 @@ func TestCreateShare(t *testing.T) { testData := "test-data" testShare := &share.Share{ UserID: testUserID, - Data: testData, + Secret: testData, } tc := []struct { @@ -78,59 +78,3 @@ func TestCreateShare(t *testing.T) { }) } } - -func TestGetShareByUserID(t *testing.T) { - mockRepo := new(sharemockrepo.MockShareRepository) - svc := New(mockRepo) - ctx := context.Background() - testUserID := "test-user" - - tc := []struct { - name string - wantErr bool - err error - mock func() - }{ - { - name: "success", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(&share.Share{}, nil) - }, - }, - { - name: "share not found", - wantErr: true, - err: domain.ErrShareNotFound, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) - }, - }, - { - name: "repository error", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, errors.New("repository error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - result, err := svc.GetByUserID(ctx, testUserID) - if (err != nil) != tt.wantErr { - t.Errorf("GetByUserID() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("GetByUserID() error = %v, expected error %v", err, tt.err) - } - if !tt.wantErr && result == nil { - t.Errorf("GetByUserID() expected a result but got nil") - } - }) - } -} diff --git a/internal/core/services/usersvc/svc.go b/internal/core/services/usersvc/svc.go index c53784c..7fb8125 100644 --- a/internal/core/services/usersvc/svc.go +++ b/internal/core/services/usersvc/svc.go @@ -4,13 +4,12 @@ import ( "context" "errors" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/user" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type service struct { @@ -23,7 +22,7 @@ var _ services.UserService = (*service)(nil) func New(repo repositories.UserRepository) services.UserService { return &service{ repo: repo, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("user_service"), + logger: logger.New("user_service"), } } @@ -35,7 +34,7 @@ func (s *service) Create(ctx context.Context, projectID string) (*user.User, err err := s.repo.Create(ctx, usr) if err != nil { - s.logger.ErrorContext(ctx, "failed to create user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) return nil, err } @@ -46,7 +45,7 @@ func (s *service) Get(ctx context.Context, userID string) (*user.User, error) { s.logger.InfoContext(ctx, "getting user", slog.String("user_id", userID)) usr, err := s.repo.Get(ctx, userID) if err != nil { - s.logger.ErrorContext(ctx, "failed to get user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get user", logger.Error(err)) return nil, err } @@ -58,7 +57,7 @@ func (s *service) GetByExternal(ctx context.Context, externalUserID, providerID extUsrs, err := s.repo.FindExternalBy(ctx, s.repo.WithExternalUserID(externalUserID), s.repo.WithProviderID(providerID)) if err != nil { - s.logger.ErrorContext(ctx, "failed to get external user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get external user", logger.Error(err)) return nil, err } @@ -70,7 +69,7 @@ func (s *service) GetByExternal(ctx context.Context, externalUserID, providerID extUsr := extUsrs[0] usr, err := s.repo.Get(ctx, extUsr.UserID) if err != nil { - s.logger.ErrorContext(ctx, "failed to get user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get user", logger.Error(err)) return nil, err } @@ -82,7 +81,7 @@ func (s *service) CreateExternal(ctx context.Context, projectID, userID, externa usr, err := s.repo.Get(ctx, userID) if err != nil { - s.logger.ErrorContext(ctx, "failed to get user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get user", logger.Error(err)) return nil, err } @@ -98,7 +97,7 @@ func (s *service) CreateExternal(ctx context.Context, projectID, userID, externa extUsrs, err := s.repo.FindExternalBy(ctx, s.repo.WithUserID(userID), s.repo.WithProviderID(providerID)) if err != nil && !errors.Is(err, domain.ErrExternalUserNotFound) { - s.logger.ErrorContext(ctx, "failed to get external user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to get external user", logger.Error(err)) return nil, err } @@ -115,7 +114,7 @@ func (s *service) CreateExternal(ctx context.Context, projectID, userID, externa err = s.repo.CreateExternal(ctx, extUsr) if err != nil { - s.logger.ErrorContext(ctx, "failed to create external user", slog.String("error", err.Error())) + s.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) return nil, err } diff --git a/internal/infrastructure/authenticationmgr/apisecret.go b/internal/infrastructure/authenticationmgr/apisecret.go index 4a30b7a..85b84e2 100644 --- a/internal/infrastructure/authenticationmgr/apisecret.go +++ b/internal/infrastructure/authenticationmgr/apisecret.go @@ -3,11 +3,10 @@ package authenticationmgr import ( "context" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/ports/authentication" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "golang.org/x/crypto/bcrypt" ) @@ -21,7 +20,7 @@ var _ authentication.APISecretAuthenticator = (*apiSecret)(nil) func newAPISecretAuthenticator(repository repositories.ProjectRepository) authentication.APISecretAuthenticator { return &apiSecret{ projectRepo: repository, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("api_key_authenticator"), + logger: logger.New("api_key_authenticator"), } } @@ -30,13 +29,13 @@ func (a *apiSecret) Authenticate(ctx context.Context, apiKey, apiSecret string) proj, err := a.projectRepo.GetByAPIKey(ctx, apiKey) if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api key", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) return "", err } err = bcrypt.CompareHashAndPassword([]byte(proj.APISecret), []byte(apiSecret)) if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api secret", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to authenticate api secret", logger.Error(err)) return "", err } diff --git a/internal/infrastructure/authenticationmgr/manager.go b/internal/infrastructure/authenticationmgr/manager.go index 55b6fd0..e96a40f 100644 --- a/internal/infrastructure/authenticationmgr/manager.go +++ b/internal/infrastructure/authenticationmgr/manager.go @@ -58,7 +58,6 @@ func (m *Manager) IsAllowedOrigin(ctx context.Context, apiKey string, origin str dbOrigins, err := m.repo.GetAllowedOriginsByAPIKey(ctx, apiKey) if err != nil { - return false, err } m.mapOrigins[apiKey] = dbOrigins diff --git a/internal/infrastructure/authenticationmgr/user.go b/internal/infrastructure/authenticationmgr/user.go index 22b355b..d5b30d3 100644 --- a/internal/infrastructure/authenticationmgr/user.go +++ b/internal/infrastructure/authenticationmgr/user.go @@ -4,7 +4,6 @@ import ( "context" "errors" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" @@ -13,7 +12,7 @@ import ( "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/internal/infrastructure/providersmgr" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type user struct { @@ -30,7 +29,7 @@ func newUserAuthenticator(repository repositories.ProjectRepository, providerMan projectRepo: repository, providerManager: providerManager, userService: userService, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("api_key_authenticator"), + logger: logger.New("api_key_authenticator"), } } @@ -39,13 +38,13 @@ func (a *user) Authenticate(ctx context.Context, apiKey, token string, providerT proj, err := a.projectRepo.GetByAPIKey(ctx, apiKey) if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api key", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) return "", err } prov, err := a.providerManager.GetProvider(ctx, proj.ID, providerType) if err != nil { - a.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return "", err } @@ -64,25 +63,25 @@ func (a *user) Authenticate(ctx context.Context, apiKey, token string, providerT externalUserID, err := prov.Identify(ctx, token, providerCustomOptions...) if err != nil { - a.logger.ErrorContext(ctx, "failed to identify user", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to identify user", logger.Error(err)) return "", err } usr, err := a.userService.GetByExternal(ctx, externalUserID, prov.GetProviderID()) if err != nil { if !errors.Is(err, domain.ErrUserNotFound) && !errors.Is(err, domain.ErrExternalUserNotFound) { - a.logger.ErrorContext(ctx, "failed to get user by external", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) return "", err } usr, err = a.userService.Create(ctx, proj.ID) if err != nil { - a.logger.ErrorContext(ctx, "failed to create user", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) return "", err } _, err = a.userService.CreateExternal(ctx, proj.ID, usr.ID, externalUserID, prov.GetProviderID()) if err != nil { - a.logger.ErrorContext(ctx, "failed to create external user", slog.String("error", err.Error())) + a.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) return "", err } } diff --git a/internal/infrastructure/handlers/rest/api/errors.go b/internal/infrastructure/handlers/rest/api/errors.go index 92842b1..c5f4ce1 100644 --- a/internal/infrastructure/handlers/rest/api/errors.go +++ b/internal/infrastructure/handlers/rest/api/errors.go @@ -35,6 +35,9 @@ var ( ErrUserNotFound = &Error{"User not found", http.StatusNotFound} ErrExternalUserNotFound = &Error{"External user not found", http.StatusNotFound} ErrExternalUserAlreadyExists = &Error{"External user already exists", http.StatusConflict} + ErrEncryptionPartRequired = &Error{"The requested share have project entropy and encryption part is required", http.StatusConflict} + ErrEncryptionNotConfigured = &Error{"Encryption not configured", http.StatusConflict} + ErrInvalidEncryptionPart = &Error{"Invalid encryption part", http.StatusBadRequest} ErrMissingAPIKey = &Error{"Missing API key", http.StatusUnauthorized} ErrMissingAPISecret = &Error{"Missing API secret", http.StatusUnauthorized} diff --git a/internal/infrastructure/handlers/rest/authmdw/middleware.go b/internal/infrastructure/handlers/rest/authmdw/middleware.go index 990a2ac..584a170 100644 --- a/internal/infrastructure/handlers/rest/authmdw/middleware.go +++ b/internal/infrastructure/handlers/rest/authmdw/middleware.go @@ -7,15 +7,16 @@ import ( authenticate "go.openfort.xyz/shield/internal/core/ports/authentication" "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" - "go.openfort.xyz/shield/pkg/ofcontext" + "go.openfort.xyz/shield/pkg/contexter" ) -const TokenHeader = "Authorization" //nolint:gosec -const AuthProviderHeader = "X-Auth-Provider" //nolint:gosec -const APIKeyHeader = "X-API-Key" //nolint:gosec -const APISecretHeader = "X-API-Secret" //nolint:gosec -const OpenfortProviderHeader = "X-Openfort-Provider" //nolint:gosec -const OpenfortTokenTypeHeader = "X-Openfort-Token-Type" //nolint:gosec +const TokenHeader = "Authorization" //nolint:gosec +const AuthProviderHeader = "X-Auth-Provider" //nolint:gosec +const APIKeyHeader = "X-API-Key" //nolint:gosec +const APISecretHeader = "X-API-Secret" //nolint:gosec +const OpenfortProviderHeader = "X-Openfort-Provider" //nolint:gosec +const OpenfortTokenTypeHeader = "X-Openfort-Token-Type" //nolint:gosec +const AccessControlAllowOriginHeader = "Access-Control-Allow-Origin" //nolint:gosec type Middleware struct { manager *authenticationmgr.Manager @@ -47,7 +48,7 @@ func (m *Middleware) AuthenticateAPISecret(next http.Handler) http.Handler { return } - ctx := ofcontext.WithProjectID(r.Context(), projectID) + ctx := contexter.WithProjectID(r.Context(), projectID) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -101,7 +102,7 @@ func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler { return } - ctx := ofcontext.WithUserID(r.Context(), userID) + ctx := contexter.WithUserID(r.Context(), userID) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -121,9 +122,5 @@ func (m *Middleware) AllowedOrigin(r *http.Request, origin string) bool { } allowed, err := m.manager.IsAllowedOrigin(r.Context(), apiKey, origin) - if err != nil { - return false - } - - return allowed + return err == nil && allowed } diff --git a/internal/infrastructure/handlers/rest/config.go b/internal/infrastructure/handlers/rest/config.go index 2300cc6..db52ff2 100644 --- a/internal/infrastructure/handlers/rest/config.go +++ b/internal/infrastructure/handlers/rest/config.go @@ -1,11 +1,32 @@ package rest -import "github.com/caarlos0/env/v10" +import ( + "time" + "github.com/caarlos0/env/v10" +) + +// Config holds the configuration for the REST server. +// The default values are used if the environment variables are not set. +// The environment variables are: +// - PORT: the port the server listens on +// - REQUESTS_PER_SECOND: the number of requests per second the server can handle (if 0, the rate limiter is disabled) +// - READ_TIMEOUT: the read timeout for the server (if 0, no timeout is set) +// - WRITE_TIMEOUT: the write timeout for the server (if 0, no timeout is set) +// - IDLE_TIMEOUT: the idle timeout for the server (if 0, no timeout is set) +// - CORS_MAX_AGE: the max age for the CORS header +// - CORS_EXTRA_ALLOWED_HEADERS: the extra allowed headers for the CORS header (comma separated) type Config struct { - Port int `env:"PORT" envDefault:"8080"` + Port int `env:"PORT" envDefault:"8080"` + RPS int `env:"REQUESTS_PER_SECOND" envDefault:"100"` + ReadTimeout time.Duration `env:"READ_TIMEOUT" envDefault:"5s"` + WriteTimeout time.Duration `env:"WRITE_TIMEOUT" envDefault:"10s"` + IdleTimeout time.Duration `env:"IDLE_TIMEOUT" envDefault:"15s"` + CORSMaxAge int `env:"CORS_MAX_AGE" envDefault:"86400"` + CORSExtraAllowedHeaders string `env:"CORS_EXTRA_ALLOWED_HEADERS" envDefault:""` } +// GetConfigFromEnv gets the configuration from the environment variables. func GetConfigFromEnv() (*Config, error) { config := &Config{} err := env.Parse(config) diff --git a/internal/infrastructure/handlers/rest/projecthdl/errors.go b/internal/infrastructure/handlers/rest/projecthdl/errors.go index 194fdf6..c130a84 100644 --- a/internal/infrastructure/handlers/rest/projecthdl/errors.go +++ b/internal/infrastructure/handlers/rest/projecthdl/errors.go @@ -26,6 +26,8 @@ func fromApplicationError(err error) *api.Error { return api.ErrProviderAlreadyExists case errors.Is(err, projectapp.ErrProviderNotFound): return api.ErrProviderNotFound + case errors.Is(err, projectapp.ErrInvalidEncryptionPart): + return api.ErrInvalidEncryptionPart default: return api.ErrInternal } diff --git a/internal/infrastructure/handlers/rest/projecthdl/handler.go b/internal/infrastructure/handlers/rest/projecthdl/handler.go index f150f0e..b72323a 100644 --- a/internal/infrastructure/handlers/rest/projecthdl/handler.go +++ b/internal/infrastructure/handlers/rest/projecthdl/handler.go @@ -5,12 +5,11 @@ import ( "io" "log/slog" "net/http" - "os" "github.com/gorilla/mux" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type Handler struct { @@ -22,7 +21,7 @@ type Handler struct { func New(app *projectapp.ProjectApplication) *Handler { return &Handler{ app: app, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("project_handler"), + logger: logger.New("project_handler"), parser: newParser(), } } diff --git a/internal/infrastructure/handlers/rest/ratelimitermdw/middleware.go b/internal/infrastructure/handlers/rest/ratelimitermdw/middleware.go new file mode 100644 index 0000000..cf2404b --- /dev/null +++ b/internal/infrastructure/handlers/rest/ratelimitermdw/middleware.go @@ -0,0 +1,24 @@ +package ratelimitermdw + +import ( + "net/http" + + "go.uber.org/ratelimit" +) + +type Middleware struct { + limiter ratelimit.Limiter +} + +func New(limit int) *Middleware { + return &Middleware{ + limiter: ratelimit.New(limit), + } +} + +func (m *Middleware) RateLimitMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.limiter.Take() + next.ServeHTTP(w, r) + }) +} diff --git a/internal/infrastructure/handlers/rest/requestmdw/middleware.go b/internal/infrastructure/handlers/rest/requestmdw/middleware.go index e508aaf..7e4085f 100644 --- a/internal/infrastructure/handlers/rest/requestmdw/middleware.go +++ b/internal/infrastructure/handlers/rest/requestmdw/middleware.go @@ -4,12 +4,21 @@ import ( "net/http" "github.com/google/uuid" - "go.openfort.xyz/shield/pkg/ofcontext" + "go.openfort.xyz/shield/pkg/contexter" ) +const RequestIDHeader = "X-Request-ID" + func RequestIDMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := ofcontext.WithRequestID(r.Context(), uuid.NewString()) + ctx := r.Context() + if r.Header.Get(RequestIDHeader) != "" { + ctx = contexter.WithRequestID(ctx, r.Header.Get(RequestIDHeader)) + } else { + ctx = contexter.WithRequestID(ctx, uuid.NewString()) + } + + w.Header().Set(RequestIDHeader, contexter.GetRequestID(ctx)) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/infrastructure/handlers/rest/server.go b/internal/infrastructure/handlers/rest/server.go index ed7201c..e0ce61d 100644 --- a/internal/infrastructure/handlers/rest/server.go +++ b/internal/infrastructure/handlers/rest/server.go @@ -5,55 +5,53 @@ import ( "fmt" "log/slog" "net/http" - "os" + "strings" "github.com/gorilla/mux" "github.com/rs/cors" "go.openfort.xyz/shield/internal/applications/projectapp" - "go.openfort.xyz/shield/internal/applications/userapp" + "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/authmdw" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/projecthdl" + "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/ratelimitermdw" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/requestmdw" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/responsemdw" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/userhdl" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/sharehdl" + "go.openfort.xyz/shield/pkg/logger" ) +// Server is the REST server for the shield API type Server struct { projectApp *projectapp.ProjectApplication - userApp *userapp.UserApplication + shareApp *shareapp.ShareApplication authManager *authenticationmgr.Manager server *http.Server logger *slog.Logger config *Config } -func New(cfg *Config, projectApp *projectapp.ProjectApplication, userApp *userapp.UserApplication, authManager *authenticationmgr.Manager) *Server { +// New creates a new REST server +func New(cfg *Config, projectApp *projectapp.ProjectApplication, shareApp *shareapp.ShareApplication, authManager *authenticationmgr.Manager) *Server { return &Server{ projectApp: projectApp, - userApp: userApp, + shareApp: shareApp, authManager: authManager, server: new(http.Server), - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("rest_server"), + logger: logger.New("rest_server"), config: cfg, } } -type CORSLogger struct { - logger *slog.Logger -} - -func (l *CORSLogger) Printf(s string, i ...interface{}) { - l.logger.Info(fmt.Sprintf(s, i...)) -} - +// Start starts the REST server func (s *Server) Start(ctx context.Context) error { projectHdl := projecthdl.New(s.projectApp) - userHdl := userhdl.New(s.userApp) + shareHdl := sharehdl.New(s.shareApp) authMdw := authmdw.New(s.authManager) + rateLimiterMdw := ratelimitermdw.New(s.config.RPS) r := mux.NewRouter() + r.Use(rateLimiterMdw.RateLimitMiddleware) r.Use(requestmdw.RequestIDMiddleware) r.Use(responsemdw.ResponseMiddleware) r.HandleFunc("/register", projectHdl.CreateProject).Methods(http.MethodPost) @@ -71,25 +69,37 @@ func (s *Server) Start(ctx context.Context) error { u := r.PathPrefix("/shares").Subrouter() u.Use(authMdw.AuthenticateUser) - u.HandleFunc("", userHdl.GetShare).Methods(http.MethodGet) - u.HandleFunc("", userHdl.RegisterShare).Methods(http.MethodPost) + u.HandleFunc("", shareHdl.GetShare).Methods(http.MethodGet) + u.HandleFunc("", shareHdl.RegisterShare).Methods(http.MethodPost) + extraHeaders := strings.Split(s.config.CORSExtraAllowedHeaders, ",") c := cors.New(cors.Options{ AllowOriginRequestFunc: authMdw.AllowedOrigin, - AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowedHeaders: []string{"Access-Control-Allow-Origin", authmdw.TokenHeader, responsemdw.ContentTypeHeader, authmdw.APIKeyHeader, authmdw.APISecretHeader, authmdw.AuthProviderHeader, authmdw.OpenfortProviderHeader, authmdw.OpenfortTokenTypeHeader}, - MaxAge: 86400, - AllowCredentials: false, - Logger: &CORSLogger{s.logger}, - Debug: true, + AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodOptions}, + AllowedHeaders: append([]string{ + authmdw.AccessControlAllowOriginHeader, + authmdw.TokenHeader, + responsemdw.ContentTypeHeader, + authmdw.APIKeyHeader, + authmdw.APISecretHeader, + authmdw.AuthProviderHeader, + authmdw.OpenfortProviderHeader, + authmdw.OpenfortTokenTypeHeader, + }, extraHeaders...), + MaxAge: s.config.CORSMaxAge, }).Handler(r) + s.server.Addr = fmt.Sprintf(":%d", s.config.Port) s.server.Handler = c + s.server.ReadTimeout = s.config.ReadTimeout + s.server.WriteTimeout = s.config.WriteTimeout + s.server.IdleTimeout = s.config.IdleTimeout s.logger.InfoContext(ctx, "starting server", slog.String("address", s.server.Addr)) return s.server.ListenAndServe() } +// Stop stops the REST server gracefully func (s *Server) Stop(ctx context.Context) error { return s.server.Shutdown(ctx) } diff --git a/internal/infrastructure/handlers/rest/sharehdl/errors.go b/internal/infrastructure/handlers/rest/sharehdl/errors.go new file mode 100644 index 0000000..613798e --- /dev/null +++ b/internal/infrastructure/handlers/rest/sharehdl/errors.go @@ -0,0 +1,34 @@ +package sharehdl + +import ( + "errors" + + "go.openfort.xyz/shield/internal/applications/shareapp" + "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" +) + +func fromApplicationError(err error) *api.Error { + if err == nil { + return nil + } + switch { + case errors.Is(err, shareapp.ErrShareNotFound): + return api.ErrShareNotFound + case errors.Is(err, shareapp.ErrShareAlreadyExists): + return api.ErrShareAlreadyExists + case errors.Is(err, shareapp.ErrUserNotFound): + return api.ErrUserNotFound + case errors.Is(err, shareapp.ErrExternalUserNotFound): + return api.ErrExternalUserNotFound + case errors.Is(err, shareapp.ErrExternalUserAlreadyExists): + return api.ErrExternalUserAlreadyExists + case errors.Is(err, shareapp.ErrEncryptionPartRequired): + return api.ErrEncryptionPartRequired + case errors.Is(err, shareapp.ErrEncryptionNotConfigured): + return api.ErrEncryptionNotConfigured + case errors.Is(err, shareapp.ErrInvalidEncryptionPart): + return api.ErrInvalidEncryptionPart + default: + return api.ErrInternal + } +} diff --git a/internal/infrastructure/handlers/rest/userhdl/handler.go b/internal/infrastructure/handlers/rest/sharehdl/handler.go similarity index 70% rename from internal/infrastructure/handlers/rest/userhdl/handler.go rename to internal/infrastructure/handlers/rest/sharehdl/handler.go index 6d844b9..df0688c 100644 --- a/internal/infrastructure/handlers/rest/userhdl/handler.go +++ b/internal/infrastructure/handlers/rest/sharehdl/handler.go @@ -1,26 +1,29 @@ -package userhdl +package sharehdl import ( "encoding/json" "io" "log/slog" "net/http" - "os" - "go.openfort.xyz/shield/internal/applications/userapp" + "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type Handler struct { - app *userapp.UserApplication - logger *slog.Logger + app *shareapp.ShareApplication + logger *slog.Logger + parser *parser + validator *validator } -func New(app *userapp.UserApplication) *Handler { +func New(app *shareapp.ShareApplication) *Handler { return &Handler{ - app: app, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("user_handler"), + app: app, + logger: logger.New("share_handler"), + parser: newParser(), + validator: newValidator(), } } @@ -58,22 +61,17 @@ func (h *Handler) RegisterShare(w http.ResponseWriter, r *http.Request) { return } - if req.Secret == "" { - api.RespondWithError(w, api.ErrBadRequestWithMessage("secret is required")) + if errV := h.validator.validateShare((*Share)(&req)); errV != nil { + api.RespondWithError(w, errV) return } - var parameters *userapp.EncryptionParameters - if req.Salt != "" || req.Iterations != 0 || req.Length != 0 || req.Digest != "" { - parameters = &userapp.EncryptionParameters{ - Salt: req.Salt, - Iterations: req.Iterations, - Length: req.Length, - Digest: req.Digest, - } + share := h.parser.toDomain((*Share)(&req)) + var opts []shareapp.Option + if req.EncryptionPart != "" { + opts = append(opts, shareapp.WithEncryptionPart(req.EncryptionPart)) } - - err = h.app.RegisterShare(ctx, req.Secret, req.UserEntropy, parameters) + err = h.app.RegisterShare(ctx, share, opts...) if err != nil { api.RespondWithError(w, fromApplicationError(err)) return @@ -93,6 +91,7 @@ func (h *Handler) RegisterShare(w http.ResponseWriter, r *http.Request) { // @Param X-Auth-Provider header string true "Auth Provider" // @Param X-Openfort-Provider header string false "Openfort Provider" // @Param X-Openfort-Token-Type header string false "Openfort Token Type" +// @Param X-Encryption-Part header string false "Encryption Part" // @Success 200 {object} GetShareResponse "Successful response" // @Failure 404 "Description: Not Found" // @Failure 500 "Description: Internal Server Error" @@ -101,20 +100,19 @@ func (h *Handler) GetShare(w http.ResponseWriter, r *http.Request) { ctx := r.Context() h.logger.InfoContext(ctx, "getting share") - shr, err := h.app.GetShare(ctx) + var opts []shareapp.Option + encryptionPart := r.Header.Get(EncryptionPartHeader) + if encryptionPart != "" { + opts = append(opts, shareapp.WithEncryptionPart(encryptionPart)) + } + + shr, err := h.app.GetShare(ctx, opts...) if err != nil { api.RespondWithError(w, fromApplicationError(err)) return } - resp, err := json.Marshal(GetShareResponse{ - Secret: shr.Data, - UserEntropy: shr.UserEntropy, - Salt: shr.Salt, - Iterations: shr.Iterations, - Length: shr.Length, - Digest: shr.Digest, - }) + resp, err := json.Marshal(GetShareResponse(*h.parser.fromDomain(shr))) if err != nil { api.RespondWithError(w, api.ErrInternal) return diff --git a/internal/infrastructure/handlers/rest/sharehdl/parser.go b/internal/infrastructure/handlers/rest/sharehdl/parser.go new file mode 100644 index 0000000..fc1d43f --- /dev/null +++ b/internal/infrastructure/handlers/rest/sharehdl/parser.go @@ -0,0 +1,69 @@ +package sharehdl + +import "go.openfort.xyz/shield/internal/core/domain/share" + +type parser struct { + mapEntropyDomain map[Entropy]share.Entropy + mapDomainEntropy map[share.Entropy]Entropy +} + +func newParser() *parser { + return &parser{ + mapEntropyDomain: map[Entropy]share.Entropy{ + EntropyNone: share.EntropyNone, + EntropyUser: share.EntropyUser, + EntropyProject: share.EntropyProject, + }, + mapDomainEntropy: map[share.Entropy]Entropy{ + share.EntropyNone: EntropyNone, + share.EntropyUser: EntropyUser, + share.EntropyProject: EntropyProject, + }, + } +} + +func (p *parser) toDomain(s *Share) *share.Share { + shr := &share.Share{ + Secret: s.Secret, + EncryptionParameters: &share.EncryptionParameters{ + Entropy: p.mapEntropyDomain[s.Entropy], + }, + } + + if s.Salt != "" { + shr.EncryptionParameters.Salt = s.Salt + } + if s.Iterations != 0 { + shr.EncryptionParameters.Iterations = s.Iterations + } + if s.Length != 0 { + shr.EncryptionParameters.Length = s.Length + } + if s.Digest != "" { + shr.EncryptionParameters.Digest = s.Digest + } + + return shr +} + +func (p *parser) fromDomain(s *share.Share) *Share { + shr := &Share{ + Secret: s.Secret, + Entropy: p.mapDomainEntropy[s.EncryptionParameters.Entropy], + } + + if s.EncryptionParameters.Salt != "" { + shr.Salt = s.EncryptionParameters.Salt + } + if s.EncryptionParameters.Iterations != 0 { + shr.Iterations = s.EncryptionParameters.Iterations + } + if s.EncryptionParameters.Length != 0 { + shr.Length = s.EncryptionParameters.Length + } + if s.EncryptionParameters.Digest != "" { + shr.Digest = s.EncryptionParameters.Digest + } + + return shr +} diff --git a/internal/infrastructure/handlers/rest/sharehdl/types.go b/internal/infrastructure/handlers/rest/sharehdl/types.go new file mode 100644 index 0000000..9fafbc5 --- /dev/null +++ b/internal/infrastructure/handlers/rest/sharehdl/types.go @@ -0,0 +1,24 @@ +package sharehdl + +const EncryptionPartHeader = "X-Encryption-Part" + +type Share struct { + Secret string `json:"secret"` + Entropy Entropy `json:"entropy"` + Salt string `json:"salt,omitempty"` + Iterations int `json:"iterations,omitempty"` + Length int `json:"length,omitempty"` + Digest string `json:"digest,omitempty"` + EncryptionPart string `json:"encryption_part,omitempty"` +} + +type RegisterShareRequest Share +type GetShareResponse Share + +type Entropy string + +const ( + EntropyNone Entropy = "none" + EntropyUser Entropy = "user" + EntropyProject Entropy = "project" +) diff --git a/internal/infrastructure/handlers/rest/sharehdl/validator.go b/internal/infrastructure/handlers/rest/sharehdl/validator.go new file mode 100644 index 0000000..89ec6fb --- /dev/null +++ b/internal/infrastructure/handlers/rest/sharehdl/validator.go @@ -0,0 +1,43 @@ +package sharehdl + +import "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" + +type validator struct { +} + +func newValidator() *validator { + return &validator{} +} + +func (v *validator) validateShare(share *Share) *api.Error { + if share.Secret == "" { + return api.ErrBadRequestWithMessage("secret is required") + } + + switch share.Entropy { + case EntropyNone: + case "": + share.Entropy = EntropyNone + case EntropyUser: + if share.Salt == "" { + return api.ErrBadRequestWithMessage("salt is required when entropy is user") + } + if share.Iterations == 0 { + return api.ErrBadRequestWithMessage("iterations is required when entropy is user") + } + if share.Length == 0 { + return api.ErrBadRequestWithMessage("length is required when entropy is user") + } + if share.Digest == "" { + return api.ErrBadRequestWithMessage("digest is required when entropy is user") + } + case EntropyProject: + if share.EncryptionPart == "" { + return api.ErrBadRequestWithMessage("encryption_part is required when entropy is project") + } + default: + return api.ErrBadRequestWithMessage("invalid entropy") + } + + return nil +} diff --git a/internal/infrastructure/handlers/rest/userhdl/errors.go b/internal/infrastructure/handlers/rest/userhdl/errors.go deleted file mode 100644 index 4ea98f8..0000000 --- a/internal/infrastructure/handlers/rest/userhdl/errors.go +++ /dev/null @@ -1,28 +0,0 @@ -package userhdl - -import ( - "errors" - - "go.openfort.xyz/shield/internal/applications/userapp" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" -) - -func fromApplicationError(err error) *api.Error { - if err == nil { - return nil - } - switch { - case errors.Is(err, userapp.ErrShareNotFound): - return api.ErrShareNotFound - case errors.Is(err, userapp.ErrShareAlreadyExists): - return api.ErrShareAlreadyExists - case errors.Is(err, userapp.ErrUserNotFound): - return api.ErrUserNotFound - case errors.Is(err, userapp.ErrExternalUserNotFound): - return api.ErrExternalUserNotFound - case errors.Is(err, userapp.ErrExternalUserAlreadyExists): - return api.ErrExternalUserAlreadyExists - default: - return api.ErrInternal - } -} diff --git a/internal/infrastructure/handlers/rest/userhdl/types.go b/internal/infrastructure/handlers/rest/userhdl/types.go deleted file mode 100644 index 0a4a05c..0000000 --- a/internal/infrastructure/handlers/rest/userhdl/types.go +++ /dev/null @@ -1,19 +0,0 @@ -package userhdl - -type RegisterShareRequest struct { - Secret string `json:"secret"` - UserEntropy bool `json:"user_entropy"` - Salt string `json:"salt,omitempty"` - Iterations int `json:"iterations,omitempty"` - Length int `json:"length,omitempty"` - Digest string `json:"digest,omitempty"` -} - -type GetShareResponse struct { - Secret string `json:"secret"` - UserEntropy bool `json:"user_entropy"` - Salt string `json:"salt,omitempty"` - Iterations int `json:"iterations,omitempty"` - Length int `json:"length,omitempty"` - Digest string `json:"digest,omitempty"` -} diff --git a/internal/infrastructure/providersmgr/custom.go b/internal/infrastructure/providersmgr/custom.go index a2d1966..fbdd2e3 100644 --- a/internal/infrastructure/providersmgr/custom.go +++ b/internal/infrastructure/providersmgr/custom.go @@ -3,11 +3,10 @@ package providersmgr import ( "context" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/providers" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type custom struct { @@ -22,7 +21,7 @@ func newCustomProvider(providerConfig *provider.CustomConfig) providers.Identity return &custom{ jwkURL: providerConfig.JWK, providerID: providerConfig.ProviderID, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("custom_provider"), + logger: logger.New("custom_provider"), } } @@ -35,7 +34,7 @@ func (c *custom) Identify(ctx context.Context, token string, _ ...providers.Cust externalUserID, err := validateJWKs(token, c.jwkURL) if err != nil { - c.logger.ErrorContext(ctx, "failed to validate jwks", slog.String("error", err.Error())) + c.logger.ErrorContext(ctx, "failed to validate jwks", logger.Error(err)) return "", err } diff --git a/internal/infrastructure/providersmgr/manager.go b/internal/infrastructure/providersmgr/manager.go index 62900b2..a8d9458 100644 --- a/internal/infrastructure/providersmgr/manager.go +++ b/internal/infrastructure/providersmgr/manager.go @@ -4,13 +4,12 @@ import ( "context" "errors" "log/slog" - "os" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/providers" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type Manager struct { @@ -23,7 +22,7 @@ func NewManager(cfg *Config, repo repositories.ProviderRepository) *Manager { return &Manager{ config: cfg, repo: repo, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("provider_manager"), + logger: logger.New("provider_manager"), } } @@ -35,7 +34,7 @@ func (p *Manager) GetProvider(ctx context.Context, projectID string, providerTyp if errors.Is(err, domain.ErrProjectNotFound) { return nil, ErrProviderNotConfigured } - p.logger.ErrorContext(ctx, "failed to get provider", slog.String("error", err.Error())) + p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, err } diff --git a/internal/infrastructure/providersmgr/openfort.go b/internal/infrastructure/providersmgr/openfort.go index 98412bc..db6ffc4 100644 --- a/internal/infrastructure/providersmgr/openfort.go +++ b/internal/infrastructure/providersmgr/openfort.go @@ -9,12 +9,11 @@ import ( "io" "log/slog" "net/http" - "os" "time" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/providers" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" ) type openfort struct { @@ -31,7 +30,7 @@ func newOpenfortProvider(config *Config, providerConfig *provider.OpenfortConfig publishableKey: providerConfig.PublishableKey, providerID: providerConfig.ProviderID, baseURL: config.OpenfortBaseURL, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("openfort_provider"), + logger: logger.New("openfort_provider"), } } @@ -45,7 +44,7 @@ func (o *openfort) Identify(ctx context.Context, token string, opts ...providers userID, err := validateJWKs(token, fmt.Sprintf("%s/iam/v1/%s/jwks.json", o.baseURL, o.publishableKey)) if err != nil { if !errors.Is(err, ErrInvalidToken) { - o.logger.ErrorContext(ctx, "failed to validate jwks", slog.String("error", err.Error())) + o.logger.ErrorContext(ctx, "failed to validate jwks", logger.Error(err)) return "", err } diff --git a/internal/infrastructure/repositories/mocks/projectmockrepo/repo.go b/internal/infrastructure/repositories/mocks/projectmockrepo/repo.go index b4ced4a..40ed2c4 100644 --- a/internal/infrastructure/repositories/mocks/projectmockrepo/repo.go +++ b/internal/infrastructure/repositories/mocks/projectmockrepo/repo.go @@ -35,6 +35,11 @@ func (m *MockProjectRepository) GetByAPIKey(ctx context.Context, apiKey string) return args.Get(0).(*project.Project), args.Error(1) } +func (m *MockProjectRepository) Delete(ctx context.Context, projectID string) error { + args := m.Mock.Called(ctx, projectID) + return args.Error(0) +} + func (m *MockProjectRepository) AddAllowedOrigin(ctx context.Context, projectID, origin string) error { args := m.Mock.Called(ctx, projectID, origin) return args.Error(0) @@ -60,3 +65,16 @@ func (m *MockProjectRepository) GetAllowedOriginsByAPIKey(ctx context.Context, a } return args.Get(0).([]string), args.Error(1) } + +func (m *MockProjectRepository) GetEncryptionPart(ctx context.Context, projectID string) (string, error) { + args := m.Mock.Called(ctx, projectID) + if args.Get(0) == nil { + return "", args.Error(1) + } + return args.Get(0).(string), args.Error(1) +} + +func (m *MockProjectRepository) SetEncryptionPart(ctx context.Context, projectID, part string) error { + args := m.Mock.Called(ctx, projectID, part) + return args.Error(0) +} diff --git a/internal/infrastructure/repositories/mocks/sharemockrepo/repo.go b/internal/infrastructure/repositories/mocks/sharemockrepo/repo.go index b0c6db8..e5a6d94 100644 --- a/internal/infrastructure/repositories/mocks/sharemockrepo/repo.go +++ b/internal/infrastructure/repositories/mocks/sharemockrepo/repo.go @@ -26,3 +26,16 @@ func (m *MockShareRepository) GetByUserID(ctx context.Context, userID string) (* } return args.Get(0).(*share.Share), args.Error(1) } + +func (m *MockShareRepository) ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) { + args := m.Mock.Called(ctx, projectID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*share.Share), args.Error(1) +} + +func (m *MockShareRepository) Update(ctx context.Context, shr *share.Share) error { + args := m.Mock.Called(ctx, shr) + return args.Error(0) +} diff --git a/internal/infrastructure/repositories/sql/client.go b/internal/infrastructure/repositories/sql/client.go index 253de05..4f345de 100644 --- a/internal/infrastructure/repositories/sql/client.go +++ b/internal/infrastructure/repositories/sql/client.go @@ -24,27 +24,23 @@ func New(cfg *Config) (*Client, error) { return nil, ErrMissingDriver } + var dialector gorm.Dialector + var err error switch cfg.Driver { case DriverMySQL: - return newMySQL(cfg) + dialector, err = newMySQL(cfg) case DriverCloudSQL: - return newCloudSQL(cfg) + dialector, err = newCloudSQL(cfg) case DriverPostgres: - return newPostgres(cfg) + dialector = newPostgres(cfg) default: return nil, ErrDriverNotSupported } -} - -func newMySQL(cfg *Config) (*Client, error) { - sqlDB, err := sql.Open("mysql", cfg.MySQLDSN()) if err != nil { return nil, err } - db, err := gorm.Open(mysql.New(mysql.Config{ - Conn: sqlDB, - }), &gorm.Config{}) + db, err := gorm.Open(dialector, &gorm.Config{}) if err != nil { return nil, err } @@ -52,31 +48,32 @@ func newMySQL(cfg *Config) (*Client, error) { return &Client{db}, nil } -func newCloudSQL(cfg *Config) (*Client, error) { - dsn := cfg.CloudSQLDSN() - fmt.Println("DSN: " + dsn) - sqlDB, err := sql.Open("mysql", dsn) +func newMySQL(cfg *Config) (gorm.Dialector, error) { + sqlDB, err := sql.Open("mysql", cfg.MySQLDSN()) if err != nil { return nil, err } - db, err := gorm.Open(mysql.New(mysql.Config{ + return mysql.New(mysql.Config{ Conn: sqlDB, - }), &gorm.Config{}) - if err != nil { - return nil, err - } - - return &Client{db}, nil + }), nil } -func newPostgres(cfg *Config) (*Client, error) { - db, err := gorm.Open(postgres.Open(cfg.PostgresDSN()), &gorm.Config{}) +func newCloudSQL(cfg *Config) (gorm.Dialector, error) { + dsn := cfg.CloudSQLDSN() + fmt.Println("DSN: " + dsn) + sqlDB, err := sql.Open("mysql", dsn) if err != nil { return nil, err } - return &Client{db}, nil + return mysql.New(mysql.Config{ + Conn: sqlDB, + }), nil +} + +func newPostgres(cfg *Config) gorm.Dialector { + return postgres.Open(cfg.PostgresDSN()) } func (c *Client) Migrate() error { diff --git a/internal/infrastructure/repositories/sql/migrations/20240404172915_entropy.sql b/internal/infrastructure/repositories/sql/migrations/20240404172915_entropy.sql new file mode 100644 index 0000000..bcb6cf0 --- /dev/null +++ b/internal/infrastructure/repositories/sql/migrations/20240404172915_entropy.sql @@ -0,0 +1,13 @@ +-- +goose Up +ALTER TABLE shld_shares ADD COLUMN entropy VARCHAR(255) DEFAULT 'none'; +UPDATE shld_shares SET entropy = CASE WHEN user_entropy = TRUE THEN 'user' ELSE 'none' END; +ALTER TABLE shld_shares DROP COLUMN user_entropy; +-- +goose StatementBegin +-- +goose StatementEnd + +-- +goose Down +ALTER TABLE shld_shares ADD COLUMN user_entropy BOOLEAN DEFAULT FALSE; +UPDATE shld_shares SET user_entropy = CASE WHEN entropy = 'user' THEN TRUE ELSE FALSE END; +ALTER TABLE shld_shares DROP COLUMN entropy; +-- +goose StatementBegin +-- +goose StatementEnd diff --git a/internal/infrastructure/repositories/sql/projectrepo/repo.go b/internal/infrastructure/repositories/sql/projectrepo/repo.go index 9f51995..bb4e4f0 100644 --- a/internal/infrastructure/repositories/sql/projectrepo/repo.go +++ b/internal/infrastructure/repositories/sql/projectrepo/repo.go @@ -5,14 +5,13 @@ import ( "errors" "log/slog" - "os" "github.com/google/uuid" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -27,7 +26,7 @@ var _ repositories.ProjectRepository = &repository{} func New(db *sql.Client) repositories.ProjectRepository { return &repository{ db: db, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("project_repository"), + logger: logger.New("project_repository"), parser: newParser(), } } @@ -41,7 +40,7 @@ func (r *repository) Create(ctx context.Context, proj *project.Project) error { dbProj := r.parser.toDatabase(proj) err := r.db.Create(dbProj).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating project", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating project", logger.Error(err)) return err } @@ -57,7 +56,7 @@ func (r *repository) Get(ctx context.Context, projectID string) (*project.Projec if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProjectNotFound } - r.logger.ErrorContext(ctx, "error getting project", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting project", logger.Error(err)) return nil, err } @@ -73,13 +72,25 @@ func (r *repository) GetByAPIKey(ctx context.Context, apiKey string) (*project.P if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProjectNotFound } - r.logger.ErrorContext(ctx, "error getting project", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting project", logger.Error(err)) return nil, err } return r.parser.toDomain(dbProj), nil } +func (r *repository) Delete(ctx context.Context, projectID string) error { + r.logger.InfoContext(ctx, "deleting project") + + err := r.db.Delete(&Project{}, "id = ?", projectID).Error + if err != nil { + r.logger.ErrorContext(ctx, "error deleting project", logger.Error(err)) + return err + } + + return nil +} + func (r *repository) AddAllowedOrigin(ctx context.Context, projectID, origin string) error { r.logger.InfoContext(ctx, "adding allowed origin") @@ -91,7 +102,7 @@ func (r *repository) AddAllowedOrigin(ctx context.Context, projectID, origin str err := r.db.Create(allowedOrigin).Error if err != nil { - r.logger.ErrorContext(ctx, "error adding allowed origin", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error adding allowed origin", logger.Error(err)) return err } @@ -103,7 +114,7 @@ func (r *repository) RemoveAllowedOrigin(ctx context.Context, projectID, origin err := r.db.Delete(&AllowedOrigin{}, "project_id = ? AND origin = ?", projectID, origin).Error if err != nil { - r.logger.ErrorContext(ctx, "error removing allowed origin", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error removing allowed origin", logger.Error(err)) return err } @@ -116,7 +127,7 @@ func (r *repository) GetAllowedOrigins(ctx context.Context, projectID string) ([ var origins []AllowedOrigin err := r.db.Model(&AllowedOrigin{}).Where("project_id = ?", projectID).Find(&origins).Error if err != nil { - r.logger.ErrorContext(ctx, "error getting allowed origins", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting allowed origins", logger.Error(err)) return nil, err } @@ -129,9 +140,43 @@ func (r *repository) GetAllowedOriginsByAPIKey(ctx context.Context, apiKey strin var origins []AllowedOrigin err := r.db.Model(&AllowedOrigin{}).Joins("JOIN shld_projects ON shld_projects.id = shld_allowed_origins.project_id").Where("shld_projects.api_key = ?", apiKey).Find(&origins).Error if err != nil { - r.logger.ErrorContext(ctx, "error getting allowed origins", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting allowed origins", logger.Error(err)) return nil, err } return r.parser.toDomainAllowedOrigins(origins), nil } + +func (r *repository) GetEncryptionPart(ctx context.Context, projectID string) (string, error) { + r.logger.InfoContext(ctx, "getting encryption part") + + encryptionPart := &EncryptionPart{} + err := r.db.Where("project_id = ?", projectID).First(encryptionPart).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", domain.ErrEncryptionPartNotFound + } + r.logger.ErrorContext(ctx, "error getting encryption part", logger.Error(err)) + return "", err + } + + return encryptionPart.Part, nil +} + +func (r *repository) SetEncryptionPart(ctx context.Context, projectID, part string) error { + r.logger.InfoContext(ctx, "setting encryption part") + + encryptionPart := &EncryptionPart{ + ID: uuid.NewString(), + ProjectID: projectID, + Part: part, + } + + err := r.db.Create(encryptionPart).Error + if err != nil { + r.logger.ErrorContext(ctx, "error setting encryption part", logger.Error(err)) + return err + } + + return nil +} diff --git a/internal/infrastructure/repositories/sql/projectrepo/types.go b/internal/infrastructure/repositories/sql/projectrepo/types.go index 10da550..8514e4b 100644 --- a/internal/infrastructure/repositories/sql/projectrepo/types.go +++ b/internal/infrastructure/repositories/sql/projectrepo/types.go @@ -32,3 +32,13 @@ type AllowedOrigin struct { func (AllowedOrigin) TableName() string { return "shld_allowed_origins" } + +type EncryptionPart struct { + ID string `gorm:"column:id;primaryKey"` + ProjectID string `gorm:"column:project_id"` + Part string `gorm:"column:part"` +} + +func (EncryptionPart) TableName() string { + return "shld_encryption_parts" +} diff --git a/internal/infrastructure/repositories/sql/providerrepo/repo.go b/internal/infrastructure/repositories/sql/providerrepo/repo.go index 2d3dcce..7c20769 100644 --- a/internal/infrastructure/repositories/sql/providerrepo/repo.go +++ b/internal/infrastructure/repositories/sql/providerrepo/repo.go @@ -4,14 +4,13 @@ import ( "context" "errors" "log/slog" - "os" "github.com/google/uuid" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -26,7 +25,7 @@ var _ repositories.ProviderRepository = (*repository)(nil) func New(db *sql.Client) repositories.ProviderRepository { return &repository{ db: db, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("provider_repository"), + logger: logger.New("provider_repository"), parser: newParser(), } } @@ -41,7 +40,7 @@ func (r *repository) Create(ctx context.Context, prov *provider.Provider) error dbProv := r.parser.toDatabaseProvider(prov) err := r.db.Create(dbProv).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating provider", logger.Error(err)) return err } @@ -57,7 +56,7 @@ func (r *repository) GetByProjectAndType(ctx context.Context, projectID string, if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProviderNotFound } - r.logger.ErrorContext(ctx, "error getting provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) return nil, err } @@ -73,7 +72,7 @@ func (r *repository) Get(ctx context.Context, id string) (*provider.Provider, er if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProviderNotFound } - r.logger.ErrorContext(ctx, "error getting provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) return nil, err } @@ -86,7 +85,7 @@ func (r *repository) List(ctx context.Context, projectID string) ([]*provider.Pr var dbProvs []Provider err := r.db.Where("project_id = ?", projectID).Find(&dbProvs).Error if err != nil { - r.logger.ErrorContext(ctx, "error listing providers", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error listing providers", logger.Error(err)) return nil, err } @@ -103,7 +102,7 @@ func (r *repository) Delete(ctx context.Context, providerID string) error { cmd := r.db.Delete(&Provider{ID: providerID}) if cmd.Error != nil { - r.logger.ErrorContext(ctx, "error deleting provider", slog.String("error", cmd.Error.Error())) + r.logger.ErrorContext(ctx, "error deleting provider", logger.Error(cmd.Error)) return cmd.Error } @@ -120,7 +119,7 @@ func (r *repository) CreateCustom(ctx context.Context, prov *provider.CustomConf dbProv := r.parser.toDatabaseCustomProvider(prov) err := r.db.Create(dbProv).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating custom provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating custom provider", logger.Error(err)) return err } @@ -136,7 +135,7 @@ func (r *repository) GetCustom(ctx context.Context, providerID string) (*provide if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProviderNotFound } - r.logger.ErrorContext(ctx, "error getting custom provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting custom provider", logger.Error(err)) return nil, err } @@ -149,7 +148,7 @@ func (r *repository) UpdateCustom(ctx context.Context, prov *provider.CustomConf dbProv := r.parser.toDatabaseCustomProvider(prov) err := r.db.Save(dbProv).Error if err != nil { - r.logger.ErrorContext(ctx, "error updating custom provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error updating custom provider", logger.Error(err)) return err } @@ -162,7 +161,7 @@ func (r *repository) CreateOpenfort(ctx context.Context, prov *provider.Openfort dbProv := r.parser.toDatabaseOpenfortProvider(prov) err := r.db.Create(dbProv).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating openfort provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating openfort provider", logger.Error(err)) return err } @@ -178,7 +177,7 @@ func (r *repository) GetOpenfort(ctx context.Context, providerID string) (*provi if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrProviderNotFound } - r.logger.ErrorContext(ctx, "error getting openfort provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting openfort provider", logger.Error(err)) return nil, err } @@ -191,7 +190,7 @@ func (r *repository) UpdateOpenfort(ctx context.Context, prov *provider.Openfort dbProv := r.parser.toDatabaseOpenfortProvider(prov) err := r.db.Save(dbProv).Error if err != nil { - r.logger.ErrorContext(ctx, "error updating openfort provider", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error updating openfort provider", logger.Error(err)) return err } diff --git a/internal/infrastructure/repositories/sql/sharerepo/parser.go b/internal/infrastructure/repositories/sql/sharerepo/parser.go index 67634a8..1bafde8 100644 --- a/internal/infrastructure/repositories/sql/sharerepo/parser.go +++ b/internal/infrastructure/repositories/sql/sharerepo/parser.go @@ -5,34 +5,75 @@ import ( ) type parser struct { + mapEntropyDomain map[Entropy]share.Entropy + mapDomainEntropy map[share.Entropy]Entropy } func newParser() *parser { - return &parser{} + return &parser{ + mapEntropyDomain: map[Entropy]share.Entropy{ + EntropyNone: share.EntropyNone, + EntropyUser: share.EntropyUser, + EntropyProject: share.EntropyProject, + }, + mapDomainEntropy: map[share.Entropy]Entropy{ + share.EntropyNone: EntropyNone, + share.EntropyUser: EntropyUser, + share.EntropyProject: EntropyProject, + }, + } } func (p *parser) toDomain(s *Share) *share.Share { + encryptionParameters := &share.EncryptionParameters{ + Entropy: p.mapEntropyDomain[s.Entropy], + } + + if s.Salt != "" { + encryptionParameters.Salt = s.Salt + } + if s.Iterations != 0 { + encryptionParameters.Iterations = s.Iterations + } + if s.Length != 0 { + encryptionParameters.Length = s.Length + } + if s.Digest != "" { + encryptionParameters.Digest = s.Digest + } + return &share.Share{ - ID: s.ID, - Data: s.Data, - UserID: s.UserID, - UserEntropy: s.UserEntropy, - Salt: s.Salt, - Iterations: s.Iterations, - Length: s.Length, - Digest: s.Digest, + ID: s.ID, + Secret: s.Data, + UserID: s.UserID, + EncryptionParameters: encryptionParameters, } } func (p *parser) toDatabase(s *share.Share) *Share { - return &Share{ - ID: s.ID, - Data: s.Data, - UserID: s.UserID, - UserEntropy: s.UserEntropy, - Salt: s.Salt, - Iterations: s.Iterations, - Length: s.Length, - Digest: s.Digest, + shr := &Share{ + ID: s.ID, + Data: s.Secret, + UserID: s.UserID, } + + if s.EncryptionParameters != nil { + shr.Entropy = p.mapDomainEntropy[s.EncryptionParameters.Entropy] + if s.EncryptionParameters.Salt != "" { + shr.Salt = s.EncryptionParameters.Salt + } + if s.EncryptionParameters.Iterations != 0 { + shr.Iterations = s.EncryptionParameters.Iterations + } + if s.EncryptionParameters.Length != 0 { + shr.Length = s.EncryptionParameters.Length + } + if s.EncryptionParameters.Digest != "" { + shr.Digest = s.EncryptionParameters.Digest + } + } else { + shr.Entropy = EntropyNone + } + + return shr } diff --git a/internal/infrastructure/repositories/sql/sharerepo/repo.go b/internal/infrastructure/repositories/sql/sharerepo/repo.go index ce6d978..b3efb35 100644 --- a/internal/infrastructure/repositories/sql/sharerepo/repo.go +++ b/internal/infrastructure/repositories/sql/sharerepo/repo.go @@ -4,14 +4,13 @@ import ( "context" "errors" "log/slog" - "os" "github.com/google/uuid" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -26,7 +25,7 @@ var _ repositories.ShareRepository = (*repository)(nil) func New(db *sql.Client) repositories.ShareRepository { return &repository{ db: db, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("share_repository"), + logger: logger.New("share_repository"), parser: newParser(), } } @@ -41,7 +40,7 @@ func (r *repository) Create(ctx context.Context, shr *share.Share) error { dbShr := r.parser.toDatabase(shr) err := r.db.Create(dbShr).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating share", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating share", logger.Error(err)) return err } @@ -57,9 +56,44 @@ func (r *repository) GetByUserID(ctx context.Context, userID string) (*share.Sha if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrShareNotFound } - r.logger.ErrorContext(ctx, "error getting share", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting share", logger.Error(err)) return nil, err } return r.parser.toDomain(dbShr), nil } + +func (r *repository) ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) { + r.logger.InfoContext(ctx, "listing shares", slog.String("project_id", projectID)) + + var dbShares []*Share + err := r.db.Joins("JOIN shld_users ON shld_shares.user_id = shld_users.id"). + Joins("JOIN shld_projects ON shld_users.project_id = shld_projects.id"). + Where("shld_projects.id = ?", projectID). + Where("shld_shares.entropy = ?", EntropyNone). + Find(&dbShares).Error + if err != nil { + r.logger.ErrorContext(ctx, "error listing shares", logger.Error(err)) + return nil, err + } + + var shares []*share.Share + for _, dbShr := range dbShares { + shares = append(shares, r.parser.toDomain(dbShr)) + } + + return shares, nil +} + +func (r *repository) Update(ctx context.Context, shr *share.Share) error { + r.logger.InfoContext(ctx, "updating share", slog.String("id", shr.ID)) + + dbShr := r.parser.toDatabase(shr) + err := r.db.Save(dbShr).Error + if err != nil { + r.logger.ErrorContext(ctx, "error updating share", logger.Error(err)) + return err + } + + return nil +} diff --git a/internal/infrastructure/repositories/sql/sharerepo/types.go b/internal/infrastructure/repositories/sql/sharerepo/types.go index 8f59565..5a2f2b9 100644 --- a/internal/infrastructure/repositories/sql/sharerepo/types.go +++ b/internal/infrastructure/repositories/sql/sharerepo/types.go @@ -7,19 +7,27 @@ import ( ) type Share struct { - ID string `gorm:"column:id;primary_key"` - Data string `gorm:"column:data; not null"` - UserID string `gorm:"column:user_id;not null"` - UserEntropy bool `gorm:"column:user_entropy;default:false"` - Salt string `gorm:"column:salt;default:null"` - Iterations int `gorm:"column:iterations;default:null"` - Length int `gorm:"column:length;default:null"` - Digest string `gorm:"column:digest;default:null"` - CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` - UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` - DeletedAt gorm.DeletedAt `gorm:"column:deleted_at"` + ID string `gorm:"column:id;primary_key"` + Data string `gorm:"column:data; not null"` + UserID string `gorm:"column:user_id;not null"` + Entropy Entropy `gorm:"column:entropy;default:none"` + Salt string `gorm:"column:salt;default:null"` + Iterations int `gorm:"column:iterations;default:null"` + Length int `gorm:"column:length;default:null"` + Digest string `gorm:"column:digest;default:null"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at"` } func (Share) TableName() string { return "shld_shares" } + +type Entropy string + +const ( + EntropyNone Entropy = "none" + EntropyUser Entropy = "user" + EntropyProject Entropy = "project" +) diff --git a/internal/infrastructure/repositories/sql/userrepo/repo.go b/internal/infrastructure/repositories/sql/userrepo/repo.go index 4e36c42..28b9bd5 100644 --- a/internal/infrastructure/repositories/sql/userrepo/repo.go +++ b/internal/infrastructure/repositories/sql/userrepo/repo.go @@ -4,14 +4,13 @@ import ( "context" "errors" "log/slog" - "os" "github.com/google/uuid" "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/user" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/pkg/oflog" + "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -26,7 +25,7 @@ var _ repositories.UserRepository = (*repository)(nil) func New(db *sql.Client) repositories.UserRepository { return &repository{ db: db, - logger: slog.New(oflog.NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup("user_repository"), + logger: logger.New("user_repository"), parser: newParser(), } } @@ -41,7 +40,7 @@ func (r *repository) Create(ctx context.Context, usr *user.User) error { dbUsr := r.parser.toDatabase(usr) err := r.db.Create(dbUsr).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating user", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating user", logger.Error(err)) return err } @@ -57,7 +56,7 @@ func (r *repository) Get(ctx context.Context, userID string) (*user.User, error) if errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrUserNotFound } - r.logger.ErrorContext(ctx, "error getting user", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error getting user", logger.Error(err)) return nil, err } @@ -74,7 +73,7 @@ func (r *repository) CreateExternal(ctx context.Context, extUsr *user.ExternalUs dbExtUsr := r.parser.toDatabaseExternalUser(extUsr) err := r.db.Create(dbExtUsr).Error if err != nil { - r.logger.ErrorContext(ctx, "error creating external user", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error creating external user", logger.Error(err)) return err } @@ -94,7 +93,7 @@ func (r *repository) FindExternalBy(ctx context.Context, opts ...repositories.Op var dbExtUsrs []ExternalUser err := r.db.Where(options.query).Find(&dbExtUsrs).Error if err != nil { - r.logger.ErrorContext(ctx, "error finding external user", slog.String("error", err.Error())) + r.logger.ErrorContext(ctx, "error finding external user", logger.Error(err)) return nil, err } diff --git a/pkg/ofcontext/context.go b/pkg/contexter/context.go similarity index 98% rename from pkg/ofcontext/context.go rename to pkg/contexter/context.go index fc4e776..2b99b5f 100644 --- a/pkg/ofcontext/context.go +++ b/pkg/contexter/context.go @@ -1,4 +1,4 @@ -package ofcontext +package contexter import "context" diff --git a/pkg/ofcontext/keys.go b/pkg/contexter/keys.go similarity index 93% rename from pkg/ofcontext/keys.go rename to pkg/contexter/keys.go index d16fa13..53c5df0 100644 --- a/pkg/ofcontext/keys.go +++ b/pkg/contexter/keys.go @@ -1,4 +1,4 @@ -package ofcontext +package contexter type ContextKey string diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index ba64a3e..261a006 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -6,8 +6,9 @@ import ( "crypto/rand" "encoding/base64" "errors" - "github.com/codahale/sss" "io" + + "github.com/codahale/sss" ) func generateRandomBytes(n int) ([]byte, error) { @@ -27,12 +28,7 @@ func generateRandomString(n int) (string, error) { return base64.StdEncoding.EncodeToString(b), nil } -func Encrypt(plaintext, share1, share2 string) (string, error) { - key, err := combineShares(share1, share2) - if err != nil { - return "", err - } - +func Encrypt(plaintext, key string) (string, error) { keyBytes, err := base64.StdEncoding.DecodeString(key) if err != nil { return "", err @@ -57,12 +53,7 @@ func Encrypt(plaintext, share1, share2 string) (string, error) { return base64.StdEncoding.EncodeToString(ciphertext), nil } -func Decrypt(encrypted, share1, share2 string) (string, error) { - key, err := combineShares(share1, share2) - if err != nil { - return "", err - } - +func Decrypt(encrypted, key string) (string, error) { encryptedBytes, err := base64.StdEncoding.DecodeString(encrypted) if err != nil { return "", err @@ -128,19 +119,19 @@ func splitKey(key string) (string, string, error) { return base64.StdEncoding.EncodeToString(subset[0]), base64.StdEncoding.EncodeToString(subset[1]), nil } -func combineShares(share1, share2 string) (string, error) { - rawShare1, err := base64.StdEncoding.DecodeString(share1) +func ReconstructEncryptionKey(part1, part2 string) (string, error) { + rawPart1, err := base64.StdEncoding.DecodeString(part1) if err != nil { return "", err } - rawShare2, err := base64.StdEncoding.DecodeString(share2) + rawPart2, err := base64.StdEncoding.DecodeString(part2) if err != nil { return "", err } subset := make(map[byte][]byte, 2) - subset[0] = rawShare1 - subset[1] = rawShare2 + subset[0] = rawPart1 + subset[1] = rawPart2 key := sss.Combine(subset) diff --git a/pkg/oflog/handler.go b/pkg/logger/handler.go similarity index 68% rename from pkg/oflog/handler.go rename to pkg/logger/handler.go index 4ed0cda..336c5e5 100644 --- a/pkg/oflog/handler.go +++ b/pkg/logger/handler.go @@ -1,12 +1,21 @@ -package oflog +package logger import ( "context" "log/slog" + "os" - "go.openfort.xyz/shield/pkg/ofcontext" + "go.openfort.xyz/shield/pkg/contexter" ) +func New(name string) *slog.Logger { + return slog.New(NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup(name) +} + +func Error(err error) slog.Attr { + return slog.String("error", err.Error()) +} + type ContextHandler struct { baseHandler slog.Handler } @@ -24,11 +33,11 @@ func (c *ContextHandler) Enabled(ctx context.Context, level slog.Level) bool { } func (c *ContextHandler) Handle(ctx context.Context, record slog.Record) error { - if projID := ofcontext.GetProjectID(ctx); projID != "" { + if projID := contexter.GetProjectID(ctx); projID != "" { record.Add(slog.String(ProjectID, projID)) } - if reqID := ofcontext.GetRequestID(ctx); reqID != "" { + if reqID := contexter.GetRequestID(ctx); reqID != "" { record.Add(slog.String(RequestID, reqID)) } diff --git a/pkg/oflog/keys.go b/pkg/logger/keys.go similarity index 80% rename from pkg/oflog/keys.go rename to pkg/logger/keys.go index 25b9831..5e86a28 100644 --- a/pkg/oflog/keys.go +++ b/pkg/logger/keys.go @@ -1,4 +1,4 @@ -package oflog +package logger const ( ProjectID = "project_id"