Skip to content

Commit

Permalink
Merge pull request #4 from openfort-xyz/feat/add-pem-to-custom-provider
Browse files Browse the repository at this point in the history
feat: add pem to validate custom provider
  • Loading branch information
gllm-dev authored Apr 11, 2024
2 parents 4be3af2 + 9c9a7e9 commit 37a3a1c
Show file tree
Hide file tree
Showing 18 changed files with 356 additions and 34 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ curl --location --request DELETE 'https://shield.openfort.xyz/project/providers/
}
}
```
- Custom provider also can work with a PEM and Key type ("rsa" or "ecdsa" or "ed25519")
```json
{
"providers": {
"custom": {
"pem": "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEVs/o5+uQbTjL3chynL4wXgUg2R9\nq9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg==\n-----END PUBLIC KEY-----",
"key_type": "ecdsa"
}
}
}
```

#### 8. Get Allowed Origins
- **GET**: `https://shield.openfort.xyz/project/allowed-origins`
Expand Down
4 changes: 3 additions & 1 deletion cmd/cli/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cli

import (
"errors"
"net/http"
"os"
"os/signal"
"sync"
Expand Down Expand Up @@ -43,7 +45,7 @@ func NewCmdServer() *cobra.Command {
wg.Done()
}()

if err = server.Start(cmd.Context()); err != nil {
if err = server.Start(cmd.Context()); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}

Expand Down
32 changes: 32 additions & 0 deletions internal/applications/projectapp/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO
providers = append(providers, &provider.Provider{ProjectID: projectID, Type: provider.TypeOpenfort, Config: &provider.OpenfortConfig{PublishableKey: *cfg.openfortPublishableKey}})
}

if cfg.jwkURL != nil && cfg.pem != nil {
return nil, ErrJWKPemConflict
}

if cfg.jwkURL != nil {
prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeCustom)
if err != nil && !errors.Is(err, domain.ErrProviderNotFound) {
Expand All @@ -115,6 +119,18 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO
providers = append(providers, &provider.Provider{ProjectID: projectID, Type: provider.TypeCustom, Config: &provider.CustomConfig{JWK: *cfg.jwkURL}})
}

if cfg.pem != nil {
prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeCustom)
if err != nil && !errors.Is(err, domain.ErrProviderNotFound) {
a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err))
return nil, fromDomainError(err)
}
if err == nil && prov != nil {
return nil, ErrProviderAlreadyExists
}
providers = append(providers, &provider.Provider{ProjectID: projectID, Type: provider.TypeCustom, Config: &provider.CustomConfig{PEM: *cfg.pem, KeyType: cfg.keyType}})
}

if len(providers) == 0 {
return nil, ErrNoProviderSpecified
}
Expand Down Expand Up @@ -204,6 +220,22 @@ func (a *ProjectApplication) UpdateProvider(ctx context.Context, providerID stri
return fromDomainError(err)
}
}

if cfg.pem != nil {
if prov.Type != provider.TypeCustom {
return ErrProviderMismatch
}

if prov.Config.(*provider.CustomConfig).KeyType == provider.KeyTypeUnknown && cfg.keyType == provider.KeyTypeUnknown {
return ErrKeyTypeNotSpecified
}

err = a.providerRepo.UpdateCustom(ctx, &provider.CustomConfig{ProviderID: providerID, PEM: *cfg.pem, KeyType: cfg.keyType})
if err != nil {
a.logger.ErrorContext(ctx, "failed to update custom provider", logger.Error(err))
return fromDomainError(err)
}
}
return nil
}

Expand Down
128 changes: 121 additions & 7 deletions internal/applications/projectapp/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestProjectApplication_AddProviders(t *testing.T) {
name: "success",
options: []ProviderOption{
WithOpenfort("publishableKey"),
WithCustom("ur"),
WithCustomJWK("ur"),
},
wantErr: nil,
wantProviders: 2,
Expand All @@ -220,6 +220,21 @@ func TestProjectApplication_AddProviders(t *testing.T) {
providerRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(nil)
},
},
{
name: "success with pem",
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
wantErr: nil,
wantProviders: 1,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound)
providerRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil)
providerRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(nil)
},
},
{
name: "no providers",
wantErr: ErrNoProviderSpecified,
Expand All @@ -243,7 +258,7 @@ func TestProjectApplication_AddProviders(t *testing.T) {
{
name: "custom provider already exists",
options: []ProviderOption{
WithCustom("ur"),
WithCustomJWK("ur"),
},
wantErr: ErrProviderAlreadyExists,
mock: func() {
Expand All @@ -252,6 +267,30 @@ func TestProjectApplication_AddProviders(t *testing.T) {
providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(&provider.Provider{}, nil)
},
},
{
name: "custom provider already exists",
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
wantErr: ErrProviderAlreadyExists,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(&provider.Provider{}, nil)
},
},
{
name: "custom provider conflict config",
options: []ProviderOption{
WithCustomJWK("ur"),
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
wantErr: ErrJWKPemConflict,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
},
},
{
name: "error getting openfort provider",
options: []ProviderOption{
Expand All @@ -267,7 +306,19 @@ func TestProjectApplication_AddProviders(t *testing.T) {
{
name: "error getting custom provider",
options: []ProviderOption{
WithCustom("ur"),
WithCustomJWK("ur"),
},
wantErr: ErrInternal,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, errors.New("repository error"))
},
},
{
name: "error getting custom provider",
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
wantErr: ErrInternal,
mock: func() {
Expand All @@ -280,7 +331,7 @@ func TestProjectApplication_AddProviders(t *testing.T) {
name: "error configuring provider",
options: []ProviderOption{
WithOpenfort("publishableKey"),
WithCustom("ur"),
WithCustomJWK("ur"),
},
wantErr: ErrInternal,
mock: func() {
Expand Down Expand Up @@ -496,7 +547,7 @@ func TestProjectApplication_UpdateProvider(t *testing.T) {
mock func()
}{
{
name: "success",
name: "success openfort",
providerID: "provider-id",
wantErr: nil,
mock: func() {
Expand All @@ -510,6 +561,29 @@ func TestProjectApplication_UpdateProvider(t *testing.T) {
WithOpenfort("publishable-key"),
},
},
{
name: "success custom jwk",
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("Get", mock.Anything, mock.Anything).Return(customProvider, nil)
providerRepo.On("Update", mock.Anything, mock.Anything).Return(nil)
providerRepo.On("UpdateCustom", mock.Anything, mock.Anything).Return(nil)
},
options: []ProviderOption{
WithCustomJWK("url"),
},
},
{
name: "success custom pem",
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("Get", mock.Anything, mock.Anything).Return(customProvider, nil)
providerRepo.On("Update", mock.Anything, mock.Anything).Return(nil)
providerRepo.On("UpdateCustom", mock.Anything, mock.Anything).Return(nil)
},
},
{
name: "provider not found",
providerID: "provider-id",
Expand Down Expand Up @@ -562,7 +636,32 @@ func TestProjectApplication_UpdateProvider(t *testing.T) {
providerRepo.On("Get", mock.Anything, mock.Anything).Return(&provider.Provider{ProjectID: "project_id", Type: provider.TypeOpenfort}, nil)
},
options: []ProviderOption{
WithCustom("ur"),
WithCustomJWK("ur"),
},
},
{
name: "error provider mismatch",
wantErr: ErrProviderMismatch,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("Get", mock.Anything, mock.Anything).Return(&provider.Provider{ProjectID: "project_id", Type: provider.TypeOpenfort}, nil)
},
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
},
{
name: "error key not specified",
providerID: "provider-id",
wantErr: ErrKeyTypeNotSpecified,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("Get", mock.Anything, mock.Anything).Return(&provider.Provider{ProjectID: "project_id", Type: provider.TypeCustom, Config: &provider.CustomConfig{}}, nil)
},
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeUnknown),
},
},
{
Expand Down Expand Up @@ -592,7 +691,22 @@ func TestProjectApplication_UpdateProvider(t *testing.T) {
providerRepo.On("UpdateCustom", mock.Anything, mock.Anything).Return(errors.New("repository error"))
},
options: []ProviderOption{
WithCustom("ur"),
WithCustomJWK("ur"),
},
},
{
name: "error updating custom provider",
providerID: "provider-id",
wantErr: ErrInternal,
mock: func() {
projectRepo.ExpectedCalls = nil
providerRepo.ExpectedCalls = nil
providerRepo.On("Get", mock.Anything, mock.Anything).Return(customProvider, nil)
providerRepo.On("Update", mock.Anything, mock.Anything).Return(nil)
providerRepo.On("UpdateCustom", mock.Anything, mock.Anything).Return(errors.New("repository error"))
},
options: []ProviderOption{
WithCustomPEM("pem", provider.KeyTypeECDSA),
},
},
}
Expand Down
2 changes: 2 additions & 0 deletions internal/applications/projectapp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var (
ErrProjectNotFound = errors.New("project not found")
ErrNoProviderSpecified = errors.New("no provider specified")
ErrProviderMismatch = errors.New("provider mismatch")
ErrKeyTypeNotSpecified = errors.New("key type not specified")
ErrInvalidProviderConfig = errors.New("invalid provider config")
ErrUnknownProviderType = errors.New("unknown provider type")
ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project")
Expand All @@ -18,6 +19,7 @@ var (
ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists")
ErrAllowedOriginNotFound = errors.New("allowed origin not found")
ErrEncryptionNotConfigured = errors.New("encryption not configured")
ErrJWKPemConflict = errors.New("jwk and pem cannot be set at the same time")
ErrInternal = errors.New("internal error")
)

Expand Down
13 changes: 12 additions & 1 deletion internal/applications/projectapp/options.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package projectapp

import "go.openfort.xyz/shield/internal/core/domain/provider"

type ProviderOption func(*providerConfig)

func WithCustom(url string) ProviderOption {
func WithCustomJWK(url string) ProviderOption {
return func(c *providerConfig) {
c.jwkURL = &url
}
}

func WithCustomPEM(pem string, keyType provider.KeyType) ProviderOption {
return func(c *providerConfig) {
c.pem = &pem
c.keyType = keyType
}
}

func WithOpenfort(openfortProjectID string) ProviderOption {
return func(c *providerConfig) {
c.openfortPublishableKey = &openfortProjectID
Expand All @@ -16,6 +25,8 @@ func WithOpenfort(openfortProjectID string) ProviderOption {

type providerConfig struct {
jwkURL *string
pem *string
keyType provider.KeyType
openfortPublishableKey *string
}

Expand Down
11 changes: 11 additions & 0 deletions internal/core/domain/provider/customcfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,15 @@ package provider
type CustomConfig struct {
ProviderID string
JWK string
PEM string
KeyType KeyType
}

type KeyType int8

const (
KeyTypeUnknown KeyType = iota
KeyTypeRSA
KeyTypeECDSA
KeyTypeEd25519
)
2 changes: 2 additions & 0 deletions internal/infrastructure/handlers/rest/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
ErrMissingProvider = &Error{"Missing provider", http.StatusBadRequest}
ErrProviderNotFound = &Error{"Provider not found", http.StatusNotFound}
ErrInvalidProviderConfig = &Error{"Invalid provider config", http.StatusBadRequest}
ErrMissingKeyType = &Error{"Missing key type", http.StatusBadRequest}
ErrProviderAlreadyExists = &Error{"Custom authentication already registered for this project", http.StatusConflict}

ErrShareNotFound = &Error{"Share not found", http.StatusNotFound}
Expand All @@ -37,6 +38,7 @@ var (
ErrExternalUserAlreadyExists = &Error{"External user already exists", http.StatusConflict}
ErrEncryptionPartRequired = &Error{"The requested share have project entropy and encryption part is required", http.StatusConflict}
ErrEncryptionNotConfigured = &Error{"Encryption not configured", http.StatusConflict}
ErrJWKPemConflict = &Error{"JWK and PEM cannot be set at the same time", http.StatusConflict}
ErrInvalidEncryptionPart = &Error{"Invalid encryption part", http.StatusBadRequest}
ErrEncryptionPartAlreadyExists = &Error{"Encryption part already exists", http.StatusConflict}
ErrAllowedOriginNotFound = &Error{"Allowed origin not found", http.StatusNotFound}
Expand Down
4 changes: 4 additions & 0 deletions internal/infrastructure/handlers/rest/projecthdl/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ func fromApplicationError(err error) *api.Error {
return api.ErrMissingProvider
case errors.Is(err, projectapp.ErrProviderMismatch):
return api.ErrInvalidProviderConfig
case errors.Is(err, projectapp.ErrKeyTypeNotSpecified):
return api.ErrMissingKeyType
case errors.Is(err, projectapp.ErrInvalidProviderConfig):
return api.ErrInvalidProviderConfig
case errors.Is(err, projectapp.ErrUnknownProviderType):
Expand All @@ -34,6 +36,8 @@ func fromApplicationError(err error) *api.Error {
return api.ErrAllowedOriginNotFound
case errors.Is(err, projectapp.ErrEncryptionNotConfigured):
return api.ErrEncryptionNotConfigured
case errors.Is(err, projectapp.ErrJWKPemConflict):
return api.ErrJWKPemConflict
default:
return api.ErrInternal
}
Expand Down
Loading

0 comments on commit 37a3a1c

Please sign in to comment.