diff --git a/database/mock/store.go b/database/mock/store.go index df3814b8a2..5f428fa277 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -192,6 +192,20 @@ func (mr *MockStoreMockRecorder) CountUsers(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountUsers", reflect.TypeOf((*MockStore)(nil).CountUsers), ctx) } +// CreateEntitlements mocks base method. +func (m *MockStore) CreateEntitlements(ctx context.Context, arg db.CreateEntitlementsParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateEntitlements", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateEntitlements indicates an expected call of CreateEntitlements. +func (mr *MockStoreMockRecorder) CreateEntitlements(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEntitlements", reflect.TypeOf((*MockStore)(nil).CreateEntitlements), ctx, arg) +} + // CreateEntity mocks base method. func (m *MockStore) CreateEntity(ctx context.Context, arg db.CreateEntityParams) (db.EntityInstance, error) { m.ctrl.T.Helper() diff --git a/database/query/entitlements.sql b/database/query/entitlements.sql index dc14569541..6e1104b04c 100644 --- a/database/query/entitlements.sql +++ b/database/query/entitlements.sql @@ -9,4 +9,9 @@ WHERE e.project_id = sqlc.arg(project_id)::UUID AND e.feature = sqlc.arg(feature -- name: GetEntitlementFeaturesByProjectID :many SELECT feature FROM entitlements -WHERE project_id = sqlc.arg(project_id)::UUID; \ No newline at end of file +WHERE project_id = sqlc.arg(project_id)::UUID; + +-- name: CreateEntitlements :exec +INSERT INTO entitlements (feature, project_id) +SELECT unnest(sqlc.arg(features)::text[]), sqlc.arg(project_id)::UUID +ON CONFLICT DO NOTHING; diff --git a/internal/controlplane/handlers_projects.go b/internal/controlplane/handlers_projects.go index c9c24728bf..dbc1fd0c05 100644 --- a/internal/controlplane/handlers_projects.go +++ b/internal/controlplane/handlers_projects.go @@ -201,6 +201,15 @@ func (s *Server) CreateProject( return nil, status.Errorf(codes.Internal, "error creating subproject: %v", err) } + // Retrieve the membership-to-feature mapping from the configuration + projectFeatures := s.cfg.Features.GetFeaturesForMemberships(ctx) + if err := qtx.CreateEntitlements(ctx, db.CreateEntitlementsParams{ + Features: projectFeatures, + ProjectID: subProject.ID, + }); err != nil { + return nil, status.Errorf(codes.Internal, "error creating entitlements: %v", err) + } + if err := s.authzClient.Adopt(ctx, parent.ID, subProject.ID); err != nil { return nil, status.Errorf(codes.Internal, "error creating subproject: %v", err) } diff --git a/internal/controlplane/handlers_user_test.go b/internal/controlplane/handlers_user_test.go index 75d07224e0..c71028e4ce 100644 --- a/internal/controlplane/handlers_user_test.go +++ b/internal/controlplane/handlers_user_test.go @@ -83,6 +83,8 @@ func TestCreateUser_gRPC(t *testing.T) { store.EXPECT(). CreateUser(gomock.Any(), gomock.Any()). Return(returnedUser, nil) + store.EXPECT().CreateEntitlements(gomock.Any(), gomock.Any()). + Return(nil) store.EXPECT().Commit(gomock.Any()) store.EXPECT().Rollback(gomock.Any()) tokenResult, _ := openid.NewBuilder().GivenName("Foo").FamilyName("Bar").Email("test@stacklok.com").Subject("subject1").Build() @@ -262,6 +264,7 @@ func TestCreateUser_gRPC(t *testing.T) { authz, marketplaces.NewNoopMarketplace(), &serverconfig.DefaultProfilesConfig{}, + &serverconfig.FeaturesConfig{}, ), } diff --git a/internal/db/entitlements.sql.go b/internal/db/entitlements.sql.go index 42e48f9960..ff0320782d 100644 --- a/internal/db/entitlements.sql.go +++ b/internal/db/entitlements.sql.go @@ -10,8 +10,25 @@ import ( "encoding/json" "github.com/google/uuid" + "github.com/lib/pq" ) +const createEntitlements = `-- name: CreateEntitlements :exec +INSERT INTO entitlements (feature, project_id) +SELECT unnest($1::text[]), $2::UUID +ON CONFLICT DO NOTHING +` + +type CreateEntitlementsParams struct { + Features []string `json:"features"` + ProjectID uuid.UUID `json:"project_id"` +} + +func (q *Queries) CreateEntitlements(ctx context.Context, arg CreateEntitlementsParams) error { + _, err := q.db.ExecContext(ctx, createEntitlements, pq.Array(arg.Features), arg.ProjectID) + return err +} + const getEntitlementFeaturesByProjectID = `-- name: GetEntitlementFeaturesByProjectID :many SELECT feature FROM entitlements diff --git a/internal/db/querier.go b/internal/db/querier.go index 25c071eeaf..dc9d33477b 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -20,6 +20,7 @@ type Querier interface { CountRepositories(ctx context.Context) (int64, error) CountRepositoriesByProjectID(ctx context.Context, projectID uuid.UUID) (int64, error) CountUsers(ctx context.Context) (int64, error) + CreateEntitlements(ctx context.Context, arg CreateEntitlementsParams) error // CreateEntity adds an entry to the entity_instances table so it can be tracked by Minder. CreateEntity(ctx context.Context, arg CreateEntityParams) (EntityInstance, error) // CreateEntityWithID adds an entry to the entities table with a specific ID so it can be tracked by Minder. diff --git a/internal/projects/creator.go b/internal/projects/creator.go index 6d1797f6bf..13b4f422a8 100644 --- a/internal/projects/creator.go +++ b/internal/projects/creator.go @@ -39,17 +39,20 @@ type projectCreator struct { authzClient authz.Client marketplace marketplaces.Marketplace profilesCfg *server.DefaultProfilesConfig + featuresCfg *server.FeaturesConfig } // NewProjectCreator creates a new instance of the project creator func NewProjectCreator(authzClient authz.Client, marketplace marketplaces.Marketplace, profilesCfg *server.DefaultProfilesConfig, + featuresCfg *server.FeaturesConfig, ) ProjectCreator { return &projectCreator{ authzClient: authzClient, marketplace: marketplace, profilesCfg: profilesCfg, + featuresCfg: featuresCfg, } } @@ -105,6 +108,15 @@ func (p *projectCreator) ProvisionSelfEnrolledProject( return nil, fmt.Errorf("failed to create default project: %v", err) } + // Retrieve the membership-to-feature mapping from the configuration + projectFeatures := p.featuresCfg.GetFeaturesForMemberships(ctx) + if err := qtx.CreateEntitlements(ctx, db.CreateEntitlementsParams{ + Features: projectFeatures, + ProjectID: project.ID, + }); err != nil { + return nil, fmt.Errorf("error creating entitlements: %w", err) + } + // Enable any default profiles and rule types in the project. // For now, we subscribe to a single bundle and a single profile. // Both are specified in the service config. diff --git a/internal/projects/creator_test.go b/internal/projects/creator_test.go index 8c890038bb..edae7a922f 100644 --- a/internal/projects/creator_test.go +++ b/internal/projects/creator_test.go @@ -6,13 +6,17 @@ package projects_test import ( "context" "fmt" + "reflect" "testing" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt/openid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" mockdb "github.com/mindersec/minder/database/mock" + "github.com/mindersec/minder/internal/auth/jwt" "github.com/mindersec/minder/internal/authz/mock" "github.com/mindersec/minder/internal/db" "github.com/mindersec/minder/internal/marketplaces" @@ -33,10 +37,28 @@ func TestProvisionSelfEnrolledProject(t *testing.T) { Return(db.Project{ ID: uuid.New(), }, nil) + mockStore.EXPECT().CreateEntitlements(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, params db.CreateEntitlementsParams) error { + expectedFeatures := []string{"featureA", "featureB"} + if !reflect.DeepEqual(params.Features, expectedFeatures) { + t.Errorf("expected features %v, got %v", expectedFeatures, params.Features) + } + return nil + }) + + ctx := prepareTestToken(context.Background(), t, []any{ + "teamA", + "teamB", + "teamC", + }) + + creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}, &server.FeaturesConfig{ + MembershipFeatureMapping: map[string]string{ + "teamA": "featureA", + "teamB": "featureB", + }, + }) - ctx := context.Background() - - creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}) _, err := creator.ProvisionSelfEnrolledProject( ctx, mockStore, @@ -62,8 +84,7 @@ func TestProvisionSelfEnrolledProjectFailsWritingProjectToDB(t *testing.T) { Return(db.Project{}, fmt.Errorf("failed to create project")) ctx := context.Background() - - creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}) + creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}, &server.FeaturesConfig{}) _, err := creator.ProvisionSelfEnrolledProject( ctx, mockStore, @@ -94,7 +115,7 @@ func TestProvisionSelfEnrolledProjectInvalidName(t *testing.T) { mockStore := mockdb.NewMockStore(ctrl) ctx := context.Background() - creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}) + creator := projects.NewProjectCreator(authzClient, marketplaces.NewNoopMarketplace(), &server.DefaultProfilesConfig{}, &server.FeaturesConfig{}) for _, tc := range testCases { _, err := creator.ProvisionSelfEnrolledProject( @@ -107,3 +128,15 @@ func TestProvisionSelfEnrolledProjectInvalidName(t *testing.T) { } } + +// prepareTestToken creates a JWT token with the specified roles and returns the context with the token. +func prepareTestToken(ctx context.Context, t *testing.T, roles []any) context.Context { + t.Helper() + + token := openid.New() + require.NoError(t, token.Set("realm_access", map[string]any{ + "roles": roles, + })) + + return jwt.WithAuthTokenContext(ctx, token) +} diff --git a/internal/service/service.go b/internal/service/service.go index 72071aa910..66d8deeb7d 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -102,7 +102,7 @@ func AllInOneServerService( fallbackTokenClient := ghprov.NewFallbackTokenClient(cfg.Provider) ghClientFactory := clients.NewGitHubClientFactory(providerMetrics) providerStore := providers.NewProviderStore(store) - projectCreator := projects.NewProjectCreator(authzClient, marketplace, &cfg.DefaultProfiles) + projectCreator := projects.NewProjectCreator(authzClient, marketplace, &cfg.DefaultProfiles, &cfg.Features) propSvc := propService.NewPropertiesService(store) // TODO: isolate GitHub-specific wiring. We'll need to isolate GitHub diff --git a/pkg/config/server/config.go b/pkg/config/server/config.go index 31e1127ef4..235d1e662b 100644 --- a/pkg/config/server/config.go +++ b/pkg/config/server/config.go @@ -30,6 +30,7 @@ type Config struct { Auth AuthConfig `mapstructure:"auth"` WebhookConfig WebhookConfig `mapstructure:"webhook-config"` Events EventConfig `mapstructure:"events"` + Features FeaturesConfig `mapstructure:"features"` Authz AuthzConfig `mapstructure:"authz"` Provider ProviderConfig `mapstructure:"provider"` Marketplace MarketplaceConfig `mapstructure:"marketplace"` diff --git a/pkg/config/server/features.go b/pkg/config/server/features.go new file mode 100644 index 0000000000..a07c4c7237 --- /dev/null +++ b/pkg/config/server/features.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "context" + + "github.com/mindersec/minder/internal/auth/jwt" +) + +// FeaturesConfig is the configuration for the features +type FeaturesConfig struct { + // MembershipFeatureMapping maps a membership to a feature + MembershipFeatureMapping map[string]string `mapstructure:"membership_feature_mapping"` +} + +// GetFeaturesForMemberships returns the features associated with the memberships in the context +func (fc *FeaturesConfig) GetFeaturesForMemberships(ctx context.Context) []string { + memberships := extractMembershipsFromContext(ctx) + + features := make([]string, 0, len(memberships)) + for _, m := range memberships { + if feature := fc.MembershipFeatureMapping[m]; feature != "" { + features = append(features, feature) + } + } + + return features +} + +// extractMembershipsFromContext extracts memberships from the JWT in the context. +// Returns empty slice if no memberships are found. +func extractMembershipsFromContext(ctx context.Context) []string { + realmAccess, ok := jwt.GetUserClaimFromContext[map[string]any](ctx, "realm_access") + if !ok { + return nil + } + + rawMemberships, ok := realmAccess["roles"].([]any) + if !ok { + return nil + } + + memberships := make([]string, 0, len(rawMemberships)) + for _, membership := range rawMemberships { + if membershipStr, ok := membership.(string); ok { + memberships = append(memberships, membershipStr) + } + } + + return memberships +}