diff --git a/internal/datasources/service/service.go b/internal/datasources/service/service.go index 3219eac96f..6bf70ca171 100644 --- a/internal/datasources/service/service.go +++ b/internal/datasources/service/service.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/encoding/protojson" "github.com/mindersec/minder/internal/datasources" "github.com/mindersec/minder/internal/db" @@ -172,11 +173,86 @@ func (d *dataSourceService) List( return outDS, nil } -// nolint:revive // there is a TODO func (d *dataSourceService) Create( ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { - //TODO implement me - panic("implement me") + if ds == nil { + return nil, errors.New("data source is nil") + } + + stx, err := d.txBuilder(d, opts) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + + defer func(stx serviceTX) { + err := stx.Rollback() + if err != nil { + fmt.Printf("failed to rollback transaction: %v", err) + } + }(stx) + + tx := stx.Q() + + projectID, err := uuid.Parse(ds.GetContext().GetProjectId()) + if err != nil { + return nil, fmt.Errorf("invalid project ID: %w", err) + } + + // Check if such data source already exists in project hierarchy + projs, err := listRelevantProjects(ctx, tx, projectID, true) + if err != nil { + return nil, fmt.Errorf("failed to list relevant projects: %w", err) + } + existing, err := tx.GetDataSourceByName(ctx, db.GetDataSourceByNameParams{ + Name: ds.GetName(), + Projects: projs, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("failed to check for existing data source: %w", err) + } + if existing.ID != uuid.Nil { + return nil, util.UserVisibleError(codes.AlreadyExists, + "data source with name %s already exists", ds.GetName()) + } + + // Create data source record + dsRecord, err := tx.CreateDataSource(ctx, db.CreateDataSourceParams{ + ProjectID: projectID, + Name: ds.GetName(), + DisplayName: ds.GetName(), + }) + if err != nil { + return nil, fmt.Errorf("failed to create data source: %w", err) + } + + // Create function records based on driver type + switch drv := ds.GetDriver().(type) { + case *minderv1.DataSource_Rest: + for name, def := range drv.Rest.GetDef() { + defBytes, err := protojson.Marshal(def) + if err != nil { + return nil, fmt.Errorf("failed to marshal REST definition: %w", err) + } + + if _, err := tx.AddDataSourceFunction(ctx, db.AddDataSourceFunctionParams{ + DataSourceID: dsRecord.ID, + ProjectID: projectID, + Name: name, + Type: v1datasources.DataSourceDriverRest, + Definition: defBytes, + }); err != nil { + return nil, fmt.Errorf("failed to create data source function: %w", err) + } + } + default: + return nil, fmt.Errorf("unsupported data source driver type: %T", drv) + } + + if err := stx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit transaction: %w", err) + } + + return ds, nil } // nolint:revive // there is a TODO diff --git a/internal/datasources/service/service_test.go b/internal/datasources/service/service_test.go index 6d1e0f20a5..d0f9ed4f8d 100644 --- a/internal/datasources/service/service_test.go +++ b/internal/datasources/service/service_test.go @@ -21,7 +21,7 @@ import ( "github.com/mindersec/minder/internal/db" "github.com/mindersec/minder/internal/util/ptr" minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" - "github.com/mindersec/minder/pkg/datasources/v1" + v1 "github.com/mindersec/minder/pkg/datasources/v1" ) func TestGetByName(t *testing.T) { @@ -388,6 +388,167 @@ func TestList(t *testing.T) { } } +func TestCreate(t *testing.T) { + t.Parallel() + + type args struct { + ds *minderv1.DataSource + opts *Options + } + tests := []struct { + name string + args args + setup func(mockDB *mockdb.MockStore) + want *minderv1.DataSource + wantErr bool + }{ + { + name: "Successfully create REST data source", + args: args{ + ds: &minderv1.DataSource{ + Name: "test_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: &minderv1.DataSource_Rest{ + Rest: &minderv1.RestDataSource{ + Def: map[string]*minderv1.RestDataSource_Def{ + "test_function": { + Endpoint: "http://example.com", + InputSchema: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]any{ + "type": "object", + "properties": map[string]any{ + "test": "string", + }, + }) + return s + }(), + }, + }, + }, + }, + }, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()). + Return([]uuid.UUID{uuid.New()}, nil) + + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{}, sql.ErrNoRows) + + mockDB.EXPECT().CreateDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.New(), + Name: "test_ds", + }, nil) + + mockDB.EXPECT().AddDataSourceFunction(gomock.Any(), gomock.Any()). + Return(db.DataSourcesFunction{}, nil) + }, + want: &minderv1.DataSource{ + Name: "test_ds", + }, + wantErr: false, + }, + { + name: "Nil data source", + args: args{ + ds: nil, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, + { + name: "Invalid project ID", + args: args{ + ds: &minderv1.DataSource{ + Context: &minderv1.ContextV2{ + ProjectId: "invalid-uuid", + }, + }, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, + { + name: "Data source already exists", + args: args{ + ds: &minderv1.DataSource{ + Name: "existing_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + }, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()). + Return([]uuid.UUID{uuid.New()}, nil) + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{ID: uuid.New()}, nil) + }, + wantErr: true, + }, + { + name: "Unsupported driver type", + args: args{ + ds: &minderv1.DataSource{ + Name: "test_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: nil, + }, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()). + Return([]uuid.UUID{uuid.New()}, nil) + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{}, sql.ErrNoRows) + mockDB.EXPECT().CreateDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.New(), + Name: "test_ds", + }, nil) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := mockdb.NewMockStore(ctrl) + + svc := NewDataSourceService(mockStore) + svc.txBuilder = func(_ *dataSourceService, _ txGetter) (serviceTX, error) { + return &fakeTxBuilder{ + store: mockStore, + }, nil + } + tt.setup(mockStore) + + got, err := svc.Create(context.Background(), tt.args.ds, tt.args.opts) + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want.Name, got.Name) + }) + } +} + func TestBuildDataSourceRegistry(t *testing.T) { t.Parallel()