Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datasource Service Create #5056

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions internal/datasources/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/encoding/protojson"

"github.com/mindersec/minder/internal/datasources"
"github.com/mindersec/minder/internal/db"
Expand Down Expand Up @@ -172,11 +173,86 @@ func (d *dataSourceService) List(
return outDS, nil
}

// nolint:revive // there is a TODO
func (d *dataSourceService) Create(
ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) {
//TODO implement me
panic("implement me")
if ds == nil {
return nil, errors.New("data source is nil")
}

stx, err := d.txBuilder(d, opts)
if err != nil {
return nil, fmt.Errorf("failed to start transaction: %w", err)
}

defer func(stx serviceTX) {
err := stx.Rollback()
if err != nil {
fmt.Printf("failed to rollback transaction: %v", err)
}
}(stx)

tx := stx.Q()

projectID, err := uuid.Parse(ds.GetContext().GetProjectId())
if err != nil {
return nil, fmt.Errorf("invalid project ID: %w", err)
}

// Check if such data source already exists in project hierarchy
projs, err := listRelevantProjects(ctx, tx, projectID, true)
if err != nil {
return nil, fmt.Errorf("failed to list relevant projects: %w", err)
}
existing, err := tx.GetDataSourceByName(ctx, db.GetDataSourceByNameParams{
Name: ds.GetName(),
Projects: projs,
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("failed to check for existing data source: %w", err)
}
if existing.ID != uuid.Nil {
return nil, util.UserVisibleError(codes.AlreadyExists,
"data source with name %s already exists", ds.GetName())
}

// Create data source record
dsRecord, err := tx.CreateDataSource(ctx, db.CreateDataSourceParams{
ProjectID: projectID,
Name: ds.GetName(),
DisplayName: ds.GetName(),
})
if err != nil {
return nil, fmt.Errorf("failed to create data source: %w", err)
}

// Create function records based on driver type
switch drv := ds.GetDriver().(type) {
case *minderv1.DataSource_Rest:
for name, def := range drv.Rest.GetDef() {
defBytes, err := protojson.Marshal(def)
if err != nil {
return nil, fmt.Errorf("failed to marshal REST definition: %w", err)
}

if _, err := tx.AddDataSourceFunction(ctx, db.AddDataSourceFunctionParams{
DataSourceID: dsRecord.ID,
ProjectID: projectID,
Name: name,
Type: v1datasources.DataSourceDriverRest,
Definition: defBytes,
}); err != nil {
return nil, fmt.Errorf("failed to create data source function: %w", err)
}
}
default:
return nil, fmt.Errorf("unsupported data source driver type: %T", drv)
}

if err := stx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}

return ds, nil
}

// nolint:revive // there is a TODO
Expand Down
163 changes: 162 additions & 1 deletion internal/datasources/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/mindersec/minder/internal/db"
"github.com/mindersec/minder/internal/util/ptr"
minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
"github.com/mindersec/minder/pkg/datasources/v1"
v1 "github.com/mindersec/minder/pkg/datasources/v1"
)

func TestGetByName(t *testing.T) {
Expand Down Expand Up @@ -388,6 +388,167 @@ func TestList(t *testing.T) {
}
}

func TestCreate(t *testing.T) {
t.Parallel()

type args struct {
ds *minderv1.DataSource
opts *Options
}
tests := []struct {
name string
args args
setup func(mockDB *mockdb.MockStore)
want *minderv1.DataSource
wantErr bool
}{
{
name: "Successfully create REST data source",
args: args{
ds: &minderv1.DataSource{
Name: "test_ds",
Context: &minderv1.ContextV2{
ProjectId: uuid.New().String(),
},
Driver: &minderv1.DataSource_Rest{
Rest: &minderv1.RestDataSource{
Def: map[string]*minderv1.RestDataSource_Def{
"test_function": {
Endpoint: "http://example.com",
InputSchema: func() *structpb.Struct {
s, _ := structpb.NewStruct(map[string]any{
"type": "object",
"properties": map[string]any{
"test": "string",
},
})
return s
}(),
},
},
},
},
},
opts: &Options{},
},
setup: func(mockDB *mockdb.MockStore) {
mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()).
Return([]uuid.UUID{uuid.New()}, nil)

mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()).
Return(db.DataSource{}, sql.ErrNoRows)

mockDB.EXPECT().CreateDataSource(gomock.Any(), gomock.Any()).
Return(db.DataSource{
ID: uuid.New(),
Name: "test_ds",
}, nil)

mockDB.EXPECT().AddDataSourceFunction(gomock.Any(), gomock.Any()).
Return(db.DataSourcesFunction{}, nil)
},
want: &minderv1.DataSource{
Name: "test_ds",
},
wantErr: false,
},
{
name: "Nil data source",
args: args{
ds: nil,
opts: &Options{},
},
setup: func(_ *mockdb.MockStore) {},
wantErr: true,
},
{
name: "Invalid project ID",
args: args{
ds: &minderv1.DataSource{
Context: &minderv1.ContextV2{
ProjectId: "invalid-uuid",
},
},
opts: &Options{},
},
setup: func(_ *mockdb.MockStore) {},
wantErr: true,
},
{
name: "Data source already exists",
args: args{
ds: &minderv1.DataSource{
Name: "existing_ds",
Context: &minderv1.ContextV2{
ProjectId: uuid.New().String(),
},
},
opts: &Options{},
},
setup: func(mockDB *mockdb.MockStore) {
mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()).
Return([]uuid.UUID{uuid.New()}, nil)
mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()).
Return(db.DataSource{ID: uuid.New()}, nil)
},
wantErr: true,
},
{
name: "Unsupported driver type",
args: args{
ds: &minderv1.DataSource{
Name: "test_ds",
Context: &minderv1.ContextV2{
ProjectId: uuid.New().String(),
},
Driver: nil,
},
opts: &Options{},
},
setup: func(mockDB *mockdb.MockStore) {
mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()).
Return([]uuid.UUID{uuid.New()}, nil)
mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()).
Return(db.DataSource{}, sql.ErrNoRows)
mockDB.EXPECT().CreateDataSource(gomock.Any(), gomock.Any()).
Return(db.DataSource{
ID: uuid.New(),
Name: "test_ds",
}, nil)
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockStore := mockdb.NewMockStore(ctrl)

svc := NewDataSourceService(mockStore)
svc.txBuilder = func(_ *dataSourceService, _ txGetter) (serviceTX, error) {
return &fakeTxBuilder{
store: mockStore,
}, nil
}
tt.setup(mockStore)

got, err := svc.Create(context.Background(), tt.args.ds, tt.args.opts)
if tt.wantErr {
assert.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tt.want.Name, got.Name)
})
}
}

func TestBuildDataSourceRegistry(t *testing.T) {
t.Parallel()

Expand Down