diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 00000000..9cd8b36a --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,41 @@ +name: Go + +on: + push: + branches: ['*', '*/*'] + tags: ['v*'] + pull_request: + branches: ['*'] + +permissions: + contents: read + +jobs: + lint: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/setup-go@v5 + with: + go-version: '1.21' + - uses: actions/checkout@v4 + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + args: --timeout=5m + test: + name: test + strategy: + matrix: + go: ["1.21.x"] + runs-on: ubuntu-latest + steps: + - name: Setup Go + with: + go-version: ${{ matrix.go }} + uses: actions/setup-go@v2 + + - uses: actions/checkout@v2 + + - name: Test + run: go test ./... diff --git a/accounting/component/metering.go b/accounting/component/metering.go index 4e04a086..73d245ad 100644 --- a/accounting/component/metering.go +++ b/accounting/component/metering.go @@ -8,18 +8,23 @@ import ( "opencsg.com/csghub-server/common/types" ) -type MeteringComponent struct { - ams *database.AccountMeteringStore +type meteringComponentImpl struct { + ams database.AccountMeteringStore } -func NewMeteringComponent() *MeteringComponent { - ams := &MeteringComponent{ +type MeteringComponent interface { + SaveMeteringEventRecord(ctx context.Context, req *types.METERING_EVENT) error + ListMeteringByUserIDAndDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error) +} + +func NewMeteringComponent() MeteringComponent { + ams := &meteringComponentImpl{ ams: database.NewAccountMeteringStore(), } return ams } -func (mc *MeteringComponent) SaveMeteringEventRecord(ctx context.Context, req *types.METERING_EVENT) error { +func (mc *meteringComponentImpl) SaveMeteringEventRecord(ctx context.Context, req *types.METERING_EVENT) error { am := database.AccountMetering{ EventUUID: req.Uuid, UserUUID: req.UserUUID, @@ -41,7 +46,7 @@ func (mc *MeteringComponent) SaveMeteringEventRecord(ctx context.Context, req *t return nil } -func (mc *MeteringComponent) ListMeteringByUserIDAndDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error) { +func (mc *meteringComponentImpl) ListMeteringByUserIDAndDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error) { meters, total, err := mc.ams.ListByUserIDAndTime(ctx, req) if err != nil { return nil, 0, fmt.Errorf("failed to list metering by UserIDAndDate, error: %w", err) diff --git a/accounting/consumer/metering.go b/accounting/consumer/metering.go index bcf7a315..10fd77af 100644 --- a/accounting/consumer/metering.go +++ b/accounting/consumer/metering.go @@ -17,7 +17,7 @@ import ( type Metering struct { sysMQ *mq.NatsHandler - meterComp *component.MeteringComponent + meterComp component.MeteringComponent } func NewMetering(natHandler *mq.NatsHandler, config *config.Config) *Metering { diff --git a/accounting/handler/metering.go b/accounting/handler/metering.go index 97bb79f0..a3e86bd9 100644 --- a/accounting/handler/metering.go +++ b/accounting/handler/metering.go @@ -18,7 +18,7 @@ func NewMeteringHandler() (*MeteringHandler, error) { } type MeteringHandler struct { - amc *component.MeteringComponent + amc component.MeteringComponent } func (mh *MeteringHandler) QueryMeteringStatementByUserID(ctx *gin.Context) { diff --git a/api/handler/accounting.go b/api/handler/accounting.go index 47cbb389..71ec3346 100644 --- a/api/handler/accounting.go +++ b/api/handler/accounting.go @@ -16,7 +16,7 @@ import ( ) type AccountingHandler struct { - ac *component.AccountingComponent + ac component.AccountingComponent apiToken string } diff --git a/api/handler/cluster.go b/api/handler/cluster.go index a1cdd430..4211872b 100644 --- a/api/handler/cluster.go +++ b/api/handler/cluster.go @@ -21,7 +21,7 @@ func NewClusterHandler(config *config.Config) (*ClusterHandler, error) { } type ClusterHandler struct { - c *component.ClusterComponent + c component.ClusterComponent } // Getclusters godoc diff --git a/api/handler/code.go b/api/handler/code.go index 586a755c..5792c0c0 100644 --- a/api/handler/code.go +++ b/api/handler/code.go @@ -31,8 +31,8 @@ func NewCodeHandler(config *config.Config) (*CodeHandler, error) { } type CodeHandler struct { - c *component.CodeComponent - sc *component.SensitiveComponent + c component.CodeComponent + sc component.SensitiveComponent } // CreateCode godoc diff --git a/api/handler/collection.go b/api/handler/collection.go index 9f99de99..37efd370 100644 --- a/api/handler/collection.go +++ b/api/handler/collection.go @@ -32,8 +32,8 @@ func NewCollectionHandler(cfg *config.Config) (*CollectionHandler, error) { } type CollectionHandler struct { - cc *component.CollectionComponent - sc *component.SensitiveComponent + cc component.CollectionComponent + sc component.SensitiveComponent } // GetCollections godoc diff --git a/api/handler/dataset.go b/api/handler/dataset.go index 0463dca9..ca9bc94d 100644 --- a/api/handler/dataset.go +++ b/api/handler/dataset.go @@ -27,15 +27,22 @@ func NewDatasetHandler(config *config.Config) (*DatasetHandler, error) { if err != nil { return nil, fmt.Errorf("error creating sensitive component:%w", err) } + repo, err := component.NewRepoComponent(config) + if err != nil { + return nil, fmt.Errorf("error creating repo component:%w", err) + } + return &DatasetHandler{ - c: tc, - sc: sc, + c: tc, + sc: sc, + repo: repo, }, nil } type DatasetHandler struct { - c *component.DatasetComponent - sc *component.SensitiveComponent + c component.DatasetComponent + sc component.SensitiveComponent + repo component.RepoComponent } // CreateDataset godoc @@ -344,7 +351,7 @@ func (h *DatasetHandler) AllFiles(ctx *gin.Context) { req.RepoType = types.DatasetRepo req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Ref = "" - detail, err := h.c.AllFiles(ctx, req) + detail, err := h.repo.AllFiles(ctx, req) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) diff --git a/api/handler/dataset_viewer.go b/api/handler/dataset_viewer.go index a7f23eb6..2f679cd4 100644 --- a/api/handler/dataset_viewer.go +++ b/api/handler/dataset_viewer.go @@ -12,7 +12,7 @@ import ( ) type DatasetViewerHandler struct { - c *component.DatasetViewerComponent + c component.DatasetViewerComponent } func NewDatasetViewerHandler(cfg *config.Config) (*DatasetViewerHandler, error) { diff --git a/api/handler/discussion.go b/api/handler/discussion.go index 404a28af..109c5553 100644 --- a/api/handler/discussion.go +++ b/api/handler/discussion.go @@ -15,8 +15,8 @@ import ( ) type DiscussionHandler struct { - c *component.DiscussionComponent - sc *component.SensitiveComponent + c component.DiscussionComponent + sc component.SensitiveComponent } func NewDiscussionHandler(cfg *config.Config) (*DiscussionHandler, error) { diff --git a/api/handler/event.go b/api/handler/event.go index 3dd6060a..c49fae9d 100644 --- a/api/handler/event.go +++ b/api/handler/event.go @@ -11,7 +11,7 @@ import ( ) type EventHandler struct { - ec *component.EventComponent + ec component.EventComponent } func NewEventHandler() (*EventHandler, error) { diff --git a/api/handler/git_http.go b/api/handler/git_http.go index 04144893..5047cea1 100644 --- a/api/handler/git_http.go +++ b/api/handler/git_http.go @@ -30,7 +30,7 @@ func NewGitHTTPHandler(config *config.Config) (*GitHTTPHandler, error) { } type GitHTTPHandler struct { - c *component.GitHTTPComponent + c component.GitHTTPComponent } func (h *GitHTTPHandler) InfoRefs(ctx *gin.Context) { diff --git a/api/handler/hf_dataset.go b/api/handler/hf_dataset.go index d930b806..7c111bd3 100644 --- a/api/handler/hf_dataset.go +++ b/api/handler/hf_dataset.go @@ -24,7 +24,7 @@ func NewHFDatasetHandler(config *config.Config) (*HFDatasetHandler, error) { } type HFDatasetHandler struct { - dc *component.HFDatasetComponent + dc component.HFDatasetComponent } func (h *HFDatasetHandler) DatasetPathsInfo(ctx *gin.Context) { diff --git a/api/handler/internal.go b/api/handler/internal.go index 70acb549..020206f7 100644 --- a/api/handler/internal.go +++ b/api/handler/internal.go @@ -26,7 +26,7 @@ func NewInternalHandler(config *config.Config) (*InternalHandler, error) { } type InternalHandler struct { - c *component.InternalComponent + c component.InternalComponent config *config.Config } diff --git a/api/handler/list.go b/api/handler/list.go index f129524a..8004f73e 100644 --- a/api/handler/list.go +++ b/api/handler/list.go @@ -27,8 +27,8 @@ func NewListHandler(config *config.Config) (*ListHandler, error) { } type ListHandler struct { - c *component.ListComponent - sc *component.SpaceComponent + c component.ListComponent + sc component.SpaceComponent } // ListTrendingModels godoc diff --git a/api/handler/mirror.go b/api/handler/mirror.go index 25465396..36d35d46 100644 --- a/api/handler/mirror.go +++ b/api/handler/mirror.go @@ -23,7 +23,7 @@ func NewMirrorHandler(config *config.Config) (*MirrorHandler, error) { } type MirrorHandler struct { - mc *component.MirrorComponent + mc component.MirrorComponent } // CreateMirrorRepo godoc diff --git a/api/handler/mirror_source.go b/api/handler/mirror_source.go index 17bd8585..b55406b8 100644 --- a/api/handler/mirror_source.go +++ b/api/handler/mirror_source.go @@ -23,7 +23,7 @@ func NewMirrorSourceHandler(config *config.Config) (*MirrorSourceHandler, error) } type MirrorSourceHandler struct { - c *component.MirrorSourceComponent + c component.MirrorSourceComponent } // CreateMirrorSource godoc diff --git a/api/handler/model.go b/api/handler/model.go index 2e60c8a5..9d489f7a 100644 --- a/api/handler/model.go +++ b/api/handler/model.go @@ -26,15 +26,22 @@ func NewModelHandler(config *config.Config) (*ModelHandler, error) { if err != nil { return nil, fmt.Errorf("error creating sensitive component:%w", err) } + repo, err := component.NewRepoComponent(config) + if err != nil { + return nil, fmt.Errorf("error creating repo component:%w", err) + } + return &ModelHandler{ - c: uc, - sc: sc, + c: uc, + sc: sc, + repo: repo, }, nil } type ModelHandler struct { - c *component.ModelComponent - sc *component.SensitiveComponent + c component.ModelComponent + sc component.SensitiveComponent + repo component.RepoComponent } // GetVisiableModels godoc @@ -604,7 +611,7 @@ func (h *ModelHandler) DeployDedicated(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - allow, err := h.c.AllowReadAccess(ctx, types.ModelRepo, namespace, name, currentUser) + allow, err := h.repo.AllowReadAccess(ctx, types.ModelRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) @@ -691,7 +698,7 @@ func (h *ModelHandler) FinetuneCreate(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - allow, err := h.c.AllowAdminAccess(ctx, types.ModelRepo, namespace, name, currentUser) + allow, err := h.repo.AllowAdminAccess(ctx, types.ModelRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) @@ -791,7 +798,7 @@ func (h *ModelHandler) DeployDelete(ctx *gin.Context) { DeployID: id, DeployType: types.InferenceType, } - err = h.c.DeleteDeploy(ctx, delReq) + err = h.repo.DeleteDeploy(ctx, delReq) if err != nil { slog.Error("Failed to delete deploy", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -846,7 +853,7 @@ func (h *ModelHandler) FinetuneDelete(ctx *gin.Context) { DeployID: id, DeployType: types.FinetuneType, } - err = h.c.DeleteDeploy(ctx, delReq) + err = h.repo.DeleteDeploy(ctx, delReq) if err != nil { slog.Error("Failed to delete deploy", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -900,7 +907,7 @@ func (h *ModelHandler) DeployStop(ctx *gin.Context) { DeployID: id, DeployType: types.InferenceType, } - err = h.c.DeployStop(ctx, stopReq) + err = h.repo.DeployStop(ctx, stopReq) if err != nil { slog.Error("Failed to stop deploy", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -956,7 +963,7 @@ func (h *ModelHandler) DeployStart(ctx *gin.Context) { DeployType: types.InferenceType, } - err = h.c.DeployStart(ctx, startReq) + err = h.repo.DeployStart(ctx, startReq) if err != nil { slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) httpbase.ServerError(ctx, err) @@ -1069,7 +1076,7 @@ func (h *ModelHandler) FinetuneStop(ctx *gin.Context) { DeployID: id, DeployType: types.FinetuneType, } - err = h.c.DeployStop(ctx, stopReq) + err = h.repo.DeployStop(ctx, stopReq) if err != nil { slog.Error("Failed to stop deploy", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1123,7 +1130,7 @@ func (h *ModelHandler) FinetuneStart(ctx *gin.Context) { DeployID: id, DeployType: types.FinetuneType, } - err = h.c.DeployStart(ctx, startReq) + err = h.repo.DeployStart(ctx, startReq) if err != nil { slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) httpbase.ServerError(ctx, err) @@ -1351,7 +1358,7 @@ func (h *ModelHandler) AllFiles(ctx *gin.Context) { req.Name = name req.RepoType = types.ModelRepo req.CurrentUser = httpbase.GetCurrentUser(ctx) - detail, err := h.c.AllFiles(ctx, req) + detail, err := h.repo.AllFiles(ctx, req) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -1482,7 +1489,7 @@ func (h *ModelHandler) ServerlessStart(ctx *gin.Context) { DeployType: types.ServerlessType, } - err = h.c.DeployStart(ctx, startReq) + err = h.repo.DeployStart(ctx, startReq) if err != nil { slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) httpbase.ServerError(ctx, err) @@ -1538,7 +1545,7 @@ func (h *ModelHandler) ServerlessStop(ctx *gin.Context) { DeployType: types.ServerlessType, } - err = h.c.DeployStop(ctx, stopReq) + err = h.repo.DeployStop(ctx, stopReq) if err != nil { slog.Error("Failed to stop deploy", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/multi_sync.go b/api/handler/multi_sync.go index 1eab76e9..fcbb4215 100644 --- a/api/handler/multi_sync.go +++ b/api/handler/multi_sync.go @@ -12,7 +12,7 @@ import ( ) type SyncHandler struct { - c *component.MultiSyncComponent + c component.MultiSyncComponent } func NewSyncHandler(config *config.Config) (*SyncHandler, error) { diff --git a/api/handler/organization.go b/api/handler/organization.go index 29064279..69caf46e 100644 --- a/api/handler/organization.go +++ b/api/handler/organization.go @@ -48,12 +48,12 @@ func NewOrganizationHandler(config *config.Config) (*OrganizationHandler, error) } type OrganizationHandler struct { - sc *component.SpaceComponent - cc *component.CodeComponent - mc *component.ModelComponent - dsc *component.DatasetComponent - colc *component.CollectionComponent - pc *component.PromptComponent + sc component.SpaceComponent + cc component.CodeComponent + mc component.ModelComponent + dsc component.DatasetComponent + colc component.CollectionComponent + pc component.PromptComponent } // GetOrganizationModels godoc diff --git a/api/handler/prompt.go b/api/handler/prompt.go index e0132584..0eecbdee 100644 --- a/api/handler/prompt.go +++ b/api/handler/prompt.go @@ -20,8 +20,9 @@ import ( ) type PromptHandler struct { - pc *component.PromptComponent - sc *component.SensitiveComponent + pc component.PromptComponent + sc component.SensitiveComponent + repo component.RepoComponent } func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { @@ -33,9 +34,15 @@ func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { if err != nil { return nil, fmt.Errorf("failed to create SensitiveComponent: %w", err) } + repo, err := component.NewRepoComponent(cfg) + if err != nil { + return nil, fmt.Errorf("error creating repo component:%w", err) + } + return &PromptHandler{ - pc: promptComp, - sc: sc, + pc: promptComp, + sc: sc, + repo: repo, }, nil } @@ -1100,7 +1107,7 @@ func (h *PromptHandler) Branches(ctx *gin.Context) { RepoType: types.PromptRepo, CurrentUser: currentUser, } - branches, err := h.pc.Branches(ctx, req) + branches, err := h.repo.Branches(ctx, req) if err != nil { slog.Error("Failed to get prompt repo branches", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1137,7 +1144,7 @@ func (h *PromptHandler) Tags(ctx *gin.Context) { RepoType: types.PromptRepo, CurrentUser: currentUser, } - tags, err := h.pc.Tags(ctx, req) + tags, err := h.repo.Tags(ctx, req) if err != nil { slog.Error("Failed to get prompt repo tags", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1180,7 +1187,7 @@ func (h *PromptHandler) UpdateTags(ctx *gin.Context) { } category := ctx.Param("category") - err = h.pc.UpdateTags(ctx, namespace, name, types.PromptRepo, category, currentUser, tags) + err = h.repo.UpdateTags(ctx, namespace, name, types.PromptRepo, category, currentUser, tags) if err != nil { slog.Error("Failed to update tags", slog.String("error", err.Error()), slog.String("category", category), slog.String("namespace", namespace), slog.String("name", name)) httpbase.ServerError(ctx, err) @@ -1216,7 +1223,7 @@ func (h *PromptHandler) UpdateDownloads(ctx *gin.Context) { } req.Date = date - err = h.pc.UpdateDownloads(ctx, req) + err = h.repo.UpdateDownloads(ctx, req) if err != nil { slog.Error("Failed to update repo download count", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name), slog.Time("date", date), slog.Int64("clone_count", req.CloneCount)) httpbase.ServerError(ctx, err) diff --git a/api/handler/recom.go b/api/handler/recom.go index 1a7c87b1..7d03e43d 100644 --- a/api/handler/recom.go +++ b/api/handler/recom.go @@ -12,7 +12,7 @@ import ( // RecomHandler handles requests for repo recommendation type RecomHandler struct { - c *component.RecomComponent + c component.RecomComponent } func NewRecomHandler(cfg *config.Config) (*RecomHandler, error) { diff --git a/api/handler/repo.go b/api/handler/repo.go index 8b6a435a..8ebc865c 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -34,7 +34,7 @@ func NewRepoHandler(config *config.Config) (*RepoHandler, error) { } type RepoHandler struct { - c *component.RepoComponent + c component.RepoComponent } // CreateRepoFile godoc diff --git a/api/handler/rproxy.go b/api/handler/rproxy.go index 38b54428..cfa66551 100644 --- a/api/handler/rproxy.go +++ b/api/handler/rproxy.go @@ -16,8 +16,8 @@ import ( type RProxyHandler struct { SpaceRootDomain string - spaceComp *component.SpaceComponent - repoComp *component.RepoComponent + spaceComp component.SpaceComponent + repoComp component.RepoComponent } func NewRProxyHandler(config *config.Config) (*RProxyHandler, error) { diff --git a/api/handler/runtime_architecture.go b/api/handler/runtime_architecture.go index 88dca304..6c73a798 100644 --- a/api/handler/runtime_architecture.go +++ b/api/handler/runtime_architecture.go @@ -29,8 +29,8 @@ func NewRuntimeArchitectureHandler(config *config.Config) (*RuntimeArchitectureH } type RuntimeArchitectureHandler struct { - rc *component.RepoComponent - rac *component.RuntimeArchitectureComponent + rc component.RepoComponent + rac component.RuntimeArchitectureComponent } // GetArchitectures godoc diff --git a/api/handler/sensitive.go b/api/handler/sensitive.go index 57b08a39..13d1eb17 100644 --- a/api/handler/sensitive.go +++ b/api/handler/sensitive.go @@ -11,7 +11,7 @@ import ( ) type SensitiveHandler struct { - c *component.SensitiveComponent + c component.SensitiveComponent } func NewSensitiveHandler(cfg *config.Config) (*SensitiveHandler, error) { diff --git a/api/handler/space.go b/api/handler/space.go index 3f2120ae..04d7c290 100644 --- a/api/handler/space.go +++ b/api/handler/space.go @@ -25,15 +25,22 @@ func NewSpaceHandler(config *config.Config) (*SpaceHandler, error) { if err != nil { return nil, fmt.Errorf("error creating sensitive component:%w", err) } + repo, err := component.NewRepoComponent(config) + if err != nil { + return nil, fmt.Errorf("error creating repo component:%w", err) + } + return &SpaceHandler{ - c: c, - ssc: ssc, + c: c, + ssc: ssc, + repo: repo, }, nil } type SpaceHandler struct { - c *component.SpaceComponent - ssc *component.SensitiveComponent + c component.SpaceComponent + ssc component.SensitiveComponent + repo component.RepoComponent } // GetAllSpaces godoc @@ -289,7 +296,7 @@ func (h *SpaceHandler) Run(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - allow, err := h.c.AllowAdminAccess(ctx, types.SpaceRepo, namespace, name, currentUser) + allow, err := h.repo.AllowAdminAccess(ctx, types.SpaceRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) @@ -367,7 +374,7 @@ func (h *SpaceHandler) Stop(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - allow, err := h.c.AllowAdminAccess(ctx, types.SpaceRepo, namespace, name, currentUser) + allow, err := h.repo.AllowAdminAccess(ctx, types.SpaceRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) @@ -419,7 +426,7 @@ func (h *SpaceHandler) Status(ctx *gin.Context) { } currentUser := httpbase.GetCurrentUser(ctx) - allow, err := h.c.AllowReadAccess(ctx, types.SpaceRepo, namespace, name, currentUser) + allow, err := h.repo.AllowReadAccess(ctx, types.SpaceRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) @@ -516,7 +523,7 @@ func (h *SpaceHandler) Logs(ctx *gin.Context) { } currentUser := httpbase.GetCurrentUser(ctx) - allow, err := h.c.AllowReadAccess(ctx, types.SpaceRepo, namespace, name, currentUser) + allow, err := h.repo.AllowReadAccess(ctx, types.SpaceRepo, namespace, name, currentUser) if err != nil { slog.Error("failed to check user permission", "error", err) httpbase.ServerError(ctx, errors.New("failed to check user permission")) diff --git a/api/handler/space_resource.go b/api/handler/space_resource.go index 94cf8096..86ad56db 100644 --- a/api/handler/space_resource.go +++ b/api/handler/space_resource.go @@ -22,7 +22,7 @@ func NewSpaceResourceHandler(config *config.Config) (*SpaceResourceHandler, erro } type SpaceResourceHandler struct { - c *component.SpaceResourceComponent + c component.SpaceResourceComponent } // GetSpaceResources godoc diff --git a/api/handler/space_sdk.go b/api/handler/space_sdk.go index d753a783..f3fc43d8 100644 --- a/api/handler/space_sdk.go +++ b/api/handler/space_sdk.go @@ -22,7 +22,7 @@ func NewSpaceSdkHandler(config *config.Config) (*SpaceSdkHandler, error) { } type SpaceSdkHandler struct { - c *component.SpaceSdkComponent + c component.SpaceSdkComponent } // GetSpaceSdks godoc diff --git a/api/handler/sshkey.go b/api/handler/sshkey.go index cead4967..fb808706 100644 --- a/api/handler/sshkey.go +++ b/api/handler/sshkey.go @@ -28,8 +28,8 @@ func NewSSHKeyHandler(config *config.Config) (*SSHKeyHandler, error) { } type SSHKeyHandler struct { - c *component.SSHKeyComponent - sc *component.SensitiveComponent + c component.SSHKeyComponent + sc component.SensitiveComponent } // CreateUserSSHKey godoc diff --git a/api/handler/sync_client_setting.go b/api/handler/sync_client_setting.go index 37e4599d..7ae70ee2 100644 --- a/api/handler/sync_client_setting.go +++ b/api/handler/sync_client_setting.go @@ -11,7 +11,7 @@ import ( ) type SyncClientSettingHandler struct { - c *component.SyncClientSettingComponent + c component.SyncClientSettingComponent } func NewSyncClientSettingHandler(config *config.Config) (*SyncClientSettingHandler, error) { diff --git a/api/handler/tag.go b/api/handler/tag.go index 16bfb382..a707bd06 100644 --- a/api/handler/tag.go +++ b/api/handler/tag.go @@ -21,7 +21,7 @@ func NewTagHandler(config *config.Config) (*TagsHandler, error) { } type TagsHandler struct { - tc *component.TagComponent + tc component.TagComponent } // GetAllTags godoc diff --git a/api/handler/telemetry.go b/api/handler/telemetry.go index 27aa88f5..a384f9aa 100644 --- a/api/handler/telemetry.go +++ b/api/handler/telemetry.go @@ -12,7 +12,7 @@ import ( ) type TelemetryHandler struct { - c *component.TelemetryComponent + c component.TelemetryComponent } func NewTelemetryHandler() (*TelemetryHandler, error) { diff --git a/api/handler/user.go b/api/handler/user.go index fcb8ea82..b0454146 100644 --- a/api/handler/user.go +++ b/api/handler/user.go @@ -25,7 +25,7 @@ func NewUserHandler(config *config.Config) (*UserHandler, error) { } type UserHandler struct { - c *component.UserComponent + c component.UserComponent } // GetUserDatasets godoc diff --git a/builder/deploy/cluster/cluster_manager.go b/builder/deploy/cluster/cluster_manager.go index cf833c90..a5692c5b 100644 --- a/builder/deploy/cluster/cluster_manager.go +++ b/builder/deploy/cluster/cluster_manager.go @@ -33,7 +33,7 @@ type Cluster struct { // ClusterPool is a resource pool of cluster information type ClusterPool struct { Clusters []Cluster - ClusterStore *database.ClusterInfoStore + ClusterStore database.ClusterInfoStore } // NewClusterPool initializes and returns a ClusterPool by reading kubeconfig files from $HOME/.kube directory diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 985772e9..2365db6a 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -48,14 +48,14 @@ type deployer struct { ib imagebuilder.Builder ir imagerunner.Runner - store *database.DeployTaskStore - spaceStore *database.SpaceStore - spaceResourceStore *database.SpaceResourceStore + store database.DeployTaskStore + spaceStore database.SpaceStore + spaceResourceStore database.SpaceResourceStore runnerStatuscache map[string]types.StatusResponse internalRootDomain string sfNode *snowflake.Node eventPub *event.EventPublisher - rtfm *database.RuntimeFrameworksStore + rtfm database.RuntimeFrameworksStore } func newDeployer(s scheduler.Scheduler, ib imagebuilder.Builder, ir imagerunner.Runner) (*deployer, error) { diff --git a/builder/deploy/scheduler/builder_runner.go b/builder/deploy/scheduler/builder_runner.go index 0b6dd0b9..55cad33f 100644 --- a/builder/deploy/scheduler/builder_runner.go +++ b/builder/deploy/scheduler/builder_runner.go @@ -18,8 +18,8 @@ type BuilderRunner struct { repo *RepoInfo task *database.DeployTask ib imagebuilder.Builder - deployStore *database.DeployTaskStore - tokenStore *database.AccessTokenStore + deployStore database.DeployTaskStore + tokenStore database.AccessTokenStore } func NewBuidRunner(b imagebuilder.Builder, r *RepoInfo, t *database.DeployTask) Runner { diff --git a/builder/deploy/scheduler/deploy_runner.go b/builder/deploy/scheduler/deploy_runner.go index 998e43e6..be12529d 100644 --- a/builder/deploy/scheduler/deploy_runner.go +++ b/builder/deploy/scheduler/deploy_runner.go @@ -21,8 +21,8 @@ type DeployRunner struct { repo *RepoInfo task *database.DeployTask ir imagerunner.Runner - store *database.DeployTaskStore - tokenStore *database.AccessTokenStore + store database.DeployTaskStore + tokenStore database.AccessTokenStore deployStartTime time.Time deployCfg common.DeployConfig } diff --git a/builder/deploy/scheduler/scheduler.go b/builder/deploy/scheduler/scheduler.go index bb79955d..bb4a734a 100644 --- a/builder/deploy/scheduler/scheduler.go +++ b/builder/deploy/scheduler/scheduler.go @@ -29,10 +29,10 @@ type FIFOScheduler struct { tasks chan Runner last *database.DeployTask - store *database.DeployTaskStore - spaceStore *database.SpaceStore - modelStore *database.ModelStore - spaceResourcesStore *database.SpaceResourceStore + store database.DeployTaskStore + spaceStore database.SpaceStore + modelStore database.ModelStore + spaceResourcesStore database.SpaceResourceStore ib imagebuilder.Builder ir imagerunner.Runner diff --git a/builder/store/database/access_token.go b/builder/store/database/access_token.go index 640c239b..ae63d318 100644 --- a/builder/store/database/access_token.go +++ b/builder/store/database/access_token.go @@ -11,12 +11,26 @@ import ( "opencsg.com/csghub-server/common/types" ) -type AccessTokenStore struct { +type accessTokenStoreImpl struct { db *DB } -func NewAccessTokenStore() *AccessTokenStore { - return &AccessTokenStore{ +type AccessTokenStore interface { + Create(ctx context.Context, token *AccessToken) (err error) + // Refresh will disable existing access token, and then generate new one + Refresh(ctx context.Context, token *AccessToken, newTokenValue string, newExpiredAt time.Time) (*AccessToken, error) + FindByID(ctx context.Context, id int64) (token *AccessToken, err error) + Delete(ctx context.Context, username, tkName, app string) (err error) + IsExist(ctx context.Context, username, tkName, app string) (exists bool, err error) + FindByUID(ctx context.Context, uid int64) (token *AccessToken, err error) + GetUserGitToken(ctx context.Context, username string) (*AccessToken, error) + FindByToken(ctx context.Context, tokenValue, app string) (*AccessToken, error) + FindByTokenName(ctx context.Context, username, tokenName, app string) (*AccessToken, error) + FindByUser(ctx context.Context, username, app string) ([]AccessToken, error) +} + +func NewAccessTokenStore() AccessTokenStore { + return &accessTokenStoreImpl{ db: defaultDB, } } @@ -36,13 +50,13 @@ type AccessToken struct { times } -func (s *AccessTokenStore) Create(ctx context.Context, token *AccessToken) (err error) { +func (s *accessTokenStoreImpl) Create(ctx context.Context, token *AccessToken) (err error) { err = s.db.Operator.Core.NewInsert().Model(token).Scan(ctx) return } // Refresh will disable existing access token, and then generate new one -func (s *AccessTokenStore) Refresh(ctx context.Context, token *AccessToken, newTokenValue string, newExpiredAt time.Time) (*AccessToken, error) { +func (s *accessTokenStoreImpl) Refresh(ctx context.Context, token *AccessToken, newTokenValue string, newExpiredAt time.Time) (*AccessToken, error) { var newToken *AccessToken err := s.db.Core.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { _, err := tx.NewUpdate().Model(token). @@ -80,7 +94,7 @@ func (s *AccessTokenStore) Refresh(ctx context.Context, token *AccessToken, newT return newToken, err } -func (s *AccessTokenStore) FindByID(ctx context.Context, id int64) (token *AccessToken, err error) { +func (s *accessTokenStoreImpl) FindByID(ctx context.Context, id int64) (token *AccessToken, err error) { var tokens []AccessToken err = s.db.Operator.Core. NewSelect(). @@ -92,7 +106,7 @@ func (s *AccessTokenStore) FindByID(ctx context.Context, id int64) (token *Acces return } -func (s *AccessTokenStore) Delete(ctx context.Context, username, tkName, app string) (err error) { +func (s *accessTokenStoreImpl) Delete(ctx context.Context, username, tkName, app string) (err error) { var token AccessToken _, err = s.db.Operator.Core. NewDelete(). @@ -105,7 +119,7 @@ func (s *AccessTokenStore) Delete(ctx context.Context, username, tkName, app str return } -func (s *AccessTokenStore) IsExist(ctx context.Context, username, tkName, app string) (exists bool, err error) { +func (s *accessTokenStoreImpl) IsExist(ctx context.Context, username, tkName, app string) (exists bool, err error) { var token AccessToken exists, err = s.db.Operator.Core. NewSelect(). @@ -117,7 +131,7 @@ func (s *AccessTokenStore) IsExist(ctx context.Context, username, tkName, app st return } -func (s *AccessTokenStore) FindByUID(ctx context.Context, uid int64) (token *AccessToken, err error) { +func (s *accessTokenStoreImpl) FindByUID(ctx context.Context, uid int64) (token *AccessToken, err error) { var tokens []AccessToken err = s.db.Operator.Core. NewSelect(). @@ -139,7 +153,7 @@ func (s *AccessTokenStore) FindByUID(ctx context.Context, uid int64) (token *Acc return } -func (s *AccessTokenStore) GetUserGitToken(ctx context.Context, username string) (*AccessToken, error) { +func (s *accessTokenStoreImpl) GetUserGitToken(ctx context.Context, username string) (*AccessToken, error) { var token AccessToken err := s.db.Operator.Core. NewSelect(). @@ -157,7 +171,7 @@ func (s *AccessTokenStore) GetUserGitToken(ctx context.Context, username string) return &token, nil } -func (s *AccessTokenStore) FindByToken(ctx context.Context, tokenValue, app string) (*AccessToken, error) { +func (s *accessTokenStoreImpl) FindByToken(ctx context.Context, tokenValue, app string) (*AccessToken, error) { var token AccessToken q := s.db.Operator.Core. NewSelect(). @@ -174,7 +188,7 @@ func (s *AccessTokenStore) FindByToken(ctx context.Context, tokenValue, app stri return &token, nil } -func (s *AccessTokenStore) FindByTokenName(ctx context.Context, username, tokenName, app string) (*AccessToken, error) { +func (s *accessTokenStoreImpl) FindByTokenName(ctx context.Context, username, tokenName, app string) (*AccessToken, error) { var token AccessToken q := s.db.Operator.Core. NewSelect(). @@ -188,7 +202,7 @@ func (s *AccessTokenStore) FindByTokenName(ctx context.Context, username, tokenN return &token, nil } -func (s *AccessTokenStore) FindByUser(ctx context.Context, username, app string) ([]AccessToken, error) { +func (s *accessTokenStoreImpl) FindByUser(ctx context.Context, username, app string) ([]AccessToken, error) { var tokens []AccessToken q := s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/account_metering.go b/builder/store/database/account_metering.go index 391bf8c3..5185abf6 100644 --- a/builder/store/database/account_metering.go +++ b/builder/store/database/account_metering.go @@ -10,12 +10,18 @@ import ( commonTypes "opencsg.com/csghub-server/common/types" ) -type AccountMeteringStore struct { +type accountMeteringStoreImpl struct { db *DB } -func NewAccountMeteringStore() *AccountMeteringStore { - return &AccountMeteringStore{ +type AccountMeteringStore interface { + Create(ctx context.Context, input AccountMetering) error + ListByUserIDAndTime(ctx context.Context, req commonTypes.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) + ListAllByUserUUID(ctx context.Context, userUUID string) ([]AccountMetering, error) +} + +func NewAccountMeteringStore() AccountMeteringStore { + return &accountMeteringStoreImpl{ db: defaultDB, } } @@ -37,7 +43,7 @@ type AccountMetering struct { SkuUnitType string `json:"sku_unit_type"` } -func (am *AccountMeteringStore) Create(ctx context.Context, input AccountMetering) error { +func (am *accountMeteringStoreImpl) Create(ctx context.Context, input AccountMetering) error { res, err := am.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("failed to save metering event, error: %w", err) @@ -45,7 +51,7 @@ func (am *AccountMeteringStore) Create(ctx context.Context, input AccountMeterin return nil } -func (am *AccountMeteringStore) ListByUserIDAndTime(ctx context.Context, req commonTypes.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) { +func (am *accountMeteringStoreImpl) ListByUserIDAndTime(ctx context.Context, req commonTypes.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) { var accountMeters []AccountMetering q := am.db.Operator.Core.NewSelect().Model(&accountMeters).Where("user_uuid = ? and scene = ? and customer_id = ? and recorded_at >= ? and recorded_at <= ?", req.UserUUID, req.Scene, req.InstanceName, req.StartTime, req.EndTime) @@ -61,7 +67,7 @@ func (am *AccountMeteringStore) ListByUserIDAndTime(ctx context.Context, req com return accountMeters, count, nil } -func (am *AccountMeteringStore) ListAllByUserUUID(ctx context.Context, userUUID string) ([]AccountMetering, error) { +func (am *accountMeteringStoreImpl) ListAllByUserUUID(ctx context.Context, userUUID string) ([]AccountMetering, error) { var accountMeters []AccountMetering err := am.db.Operator.Core.NewSelect().Model(&accountMeters).Where("user_uuid = ?", userUUID).Scan(ctx, &accountMeters) if err != nil { diff --git a/builder/store/database/cluster.go b/builder/store/database/cluster.go index 3a17d93a..f79a3f53 100644 --- a/builder/store/database/cluster.go +++ b/builder/store/database/cluster.go @@ -9,12 +9,20 @@ import ( "github.com/uptrace/bun" ) -type ClusterInfoStore struct { +type clusterInfoStoreImpl struct { db *DB } -func NewClusterInfoStore() *ClusterInfoStore { - return &ClusterInfoStore{ +type ClusterInfoStore interface { + Add(ctx context.Context, clusterConfig string, region string) error + Update(ctx context.Context, clusterInfo ClusterInfo) error + ByClusterID(ctx context.Context, clusterId string) (clusterInfo ClusterInfo, err error) + ByClusterConfig(ctx context.Context, clusterConfig string) (clusterInfo ClusterInfo, err error) + List(ctx context.Context) ([]ClusterInfo, error) +} + +func NewClusterInfoStore() ClusterInfoStore { + return &clusterInfoStoreImpl{ db: defaultDB, } } @@ -29,7 +37,7 @@ type ClusterInfo struct { Enable bool `bun:",notnull" json:"enable"` } -func (r *ClusterInfoStore) Add(ctx context.Context, clusterConfig string, region string) error { +func (r *clusterInfoStoreImpl) Add(ctx context.Context, clusterConfig string, region string) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { cluster := &ClusterInfo{ ClusterID: uuid.New().String(), @@ -47,7 +55,7 @@ func (r *ClusterInfoStore) Add(ctx context.Context, clusterConfig string, region return err } -func (r *ClusterInfoStore) Update(ctx context.Context, clusterInfo ClusterInfo) error { +func (r *clusterInfoStoreImpl) Update(ctx context.Context, clusterInfo ClusterInfo) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { _, err := r.ByClusterConfig(ctx, clusterInfo.ClusterConfig) if err == nil { @@ -58,19 +66,19 @@ func (r *ClusterInfoStore) Update(ctx context.Context, clusterInfo ClusterInfo) return err } -func (s *ClusterInfoStore) ByClusterID(ctx context.Context, clusterId string) (clusterInfo ClusterInfo, err error) { +func (s *clusterInfoStoreImpl) ByClusterID(ctx context.Context, clusterId string) (clusterInfo ClusterInfo, err error) { clusterInfo.ClusterID = clusterId err = s.db.Operator.Core.NewSelect().Model(&clusterInfo).Where("cluster_id = ?", clusterId).Scan(ctx) return } -func (s *ClusterInfoStore) ByClusterConfig(ctx context.Context, clusterConfig string) (clusterInfo ClusterInfo, err error) { +func (s *clusterInfoStoreImpl) ByClusterConfig(ctx context.Context, clusterConfig string) (clusterInfo ClusterInfo, err error) { clusterInfo.ClusterConfig = clusterConfig err = s.db.Operator.Core.NewSelect().Model(&clusterInfo).Where("cluster_config = ?", clusterConfig).Scan(ctx) return } -func (s *ClusterInfoStore) List(ctx context.Context) ([]ClusterInfo, error) { +func (s *clusterInfoStoreImpl) List(ctx context.Context) ([]ClusterInfo, error) { var result []ClusterInfo _, err := s.db.Operator.Core.NewSelect().Model(&result).Order("region").Exec(ctx, &result) if err != nil { diff --git a/builder/store/database/code.go b/builder/store/database/code.go index aa9c69c3..404c3031 100644 --- a/builder/store/database/code.go +++ b/builder/store/database/code.go @@ -9,12 +9,25 @@ import ( "github.com/uptrace/bun" ) -type CodeStore struct { +type codeStoreImpl struct { db *DB } -func NewCodeStore() *CodeStore { - return &CodeStore{db: defaultDB} +type CodeStore interface { + ByRepoIDs(ctx context.Context, repoIDs []int64) (codes []Code, err error) + ByRepoID(ctx context.Context, repoID int64) (*Code, error) + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (codes []Code, total int, err error) + UserLikesCodes(ctx context.Context, userID int64, per, page int) (codes []Code, total int, err error) + ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (codes []Code, total int, err error) + Create(ctx context.Context, input Code) (*Code, error) + Update(ctx context.Context, input Code) (err error) + FindByPath(ctx context.Context, namespace string, repoPath string) (code *Code, err error) + Delete(ctx context.Context, input Code) error + ListByPath(ctx context.Context, paths []string) ([]Code, error) +} + +func NewCodeStore() CodeStore { + return &codeStoreImpl{db: defaultDB} } type Code struct { @@ -25,7 +38,7 @@ type Code struct { times } -func (s *CodeStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (codes []Code, err error) { +func (s *codeStoreImpl) ByRepoIDs(ctx context.Context, repoIDs []int64) (codes []Code, err error) { err = s.db.Operator.Core.NewSelect(). Model(&codes). Where("repository_id in (?)", bun.In(repoIDs)). @@ -34,7 +47,7 @@ func (s *CodeStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (codes []Cod return } -func (s *CodeStore) ByRepoID(ctx context.Context, repoID int64) (*Code, error) { +func (s *codeStoreImpl) ByRepoID(ctx context.Context, repoID int64) (*Code, error) { var code Code err := s.db.Operator.Core.NewSelect(). Model(&code). @@ -48,7 +61,7 @@ func (s *CodeStore) ByRepoID(ctx context.Context, repoID int64) (*Code, error) { return &code, nil } -func (s *CodeStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (codes []Code, total int, err error) { +func (s *codeStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (codes []Code, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&codes). @@ -74,7 +87,7 @@ func (s *CodeStore) ByUsername(ctx context.Context, username string, per, page i return } -func (s *CodeStore) UserLikesCodes(ctx context.Context, userID int64, per, page int) (codes []Code, total int, err error) { +func (s *codeStoreImpl) UserLikesCodes(ctx context.Context, userID int64, per, page int) (codes []Code, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&codes). @@ -97,7 +110,7 @@ func (s *CodeStore) UserLikesCodes(ctx context.Context, userID int64, per, page return } -func (s *CodeStore) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (codes []Code, total int, err error) { +func (s *codeStoreImpl) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (codes []Code, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&codes). @@ -123,7 +136,7 @@ func (s *CodeStore) ByOrgPath(ctx context.Context, namespace string, per, page i return } -func (s *CodeStore) Create(ctx context.Context, input Code) (*Code, error) { +func (s *codeStoreImpl) Create(ctx context.Context, input Code) (*Code, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { slog.Error("create code in db failed", slog.String("error", err.Error())) @@ -133,12 +146,12 @@ func (s *CodeStore) Create(ctx context.Context, input Code) (*Code, error) { return &input, nil } -func (s *CodeStore) Update(ctx context.Context, input Code) (err error) { +func (s *codeStoreImpl) Update(ctx context.Context, input Code) (err error) { _, err = s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return } -func (s *CodeStore) FindByPath(ctx context.Context, namespace string, repoPath string) (code *Code, err error) { +func (s *codeStoreImpl) FindByPath(ctx context.Context, namespace string, repoPath string) (code *Code, err error) { resCode := new(Code) err = s.db.Operator.Core. NewSelect(). @@ -159,7 +172,7 @@ func (s *CodeStore) FindByPath(ctx context.Context, namespace string, repoPath s return resCode, err } -func (s *CodeStore) Delete(ctx context.Context, input Code) error { +func (s *codeStoreImpl) Delete(ctx context.Context, input Code) error { res, err := s.db.Operator.Core.NewDelete().Model(&input).WherePK().Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("delete code in tx failed,error:%w", err) @@ -167,7 +180,7 @@ func (s *CodeStore) Delete(ctx context.Context, input Code) error { return nil } -func (s *CodeStore) ListByPath(ctx context.Context, paths []string) ([]Code, error) { +func (s *codeStoreImpl) ListByPath(ctx context.Context, paths []string) ([]Code, error) { var codes []Code err := s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/collection.go b/builder/store/database/collection.go index 00097b73..d7d48838 100644 --- a/builder/store/database/collection.go +++ b/builder/store/database/collection.go @@ -10,12 +10,31 @@ import ( "opencsg.com/csghub-server/common/types" ) -type CollectionStore struct { +type collectionStoreImpl struct { db *DB } -func NewCollectionStore() *CollectionStore { - return &CollectionStore{ +type CollectionStore interface { + // query collections in the database + GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int, showPrivate bool) (collections []Collection, total int, err error) + // query collections in the database + QueryByTrending(ctx context.Context, filter *types.CollectionFilter, per, page int) (collections []Collection, total int, err error) + CreateCollection(ctx context.Context, collection Collection) (*Collection, error) + DeleteCollection(ctx context.Context, id int64, uid int64) error + UpdateCollection(ctx context.Context, collection Collection) (*Collection, error) + GetCollection(ctx context.Context, id int64) (*Collection, error) + ByUserLikes(ctx context.Context, userID int64, per, page int) (collections []Collection, total int, err error) + ByUserOrgs(ctx context.Context, namespace string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) + // get collections by ids + GetCollectionsByIDs(ctx context.Context, collections []Collection, ids []interface{}, total int, onlyPublic bool) ([]Collection, int, error) + FindById(ctx context.Context, id int64) (collection Collection, err error) + AddCollectionRepos(ctx context.Context, crs []CollectionRepository) error + RemoveCollectionRepos(ctx context.Context, crs []CollectionRepository) error + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) +} + +func NewCollectionStore() CollectionStore { + return &collectionStoreImpl{ db: defaultDB, } } @@ -52,7 +71,7 @@ type RankedRepository struct { var Fields = []string{"id", "download_count", "likes", "path", "private", "repository_type", "updated_at", "created_at", "user_id", "name", "nickname", "description"} // query collections in the database -func (cs *CollectionStore) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int, showPrivate bool) (collections []Collection, total int, err error) { +func (cs *collectionStoreImpl) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int, showPrivate bool) (collections []Collection, total int, err error) { if filter.Sort == "trending" { return cs.QueryByTrending(ctx, filter, per, page) } @@ -85,7 +104,7 @@ func (cs *CollectionStore) GetCollections(ctx context.Context, filter *types.Col } // query collections in the database -func (cs *CollectionStore) QueryByTrending(ctx context.Context, filter *types.CollectionFilter, per, page int) (collections []Collection, total int, err error) { +func (cs *collectionStoreImpl) QueryByTrending(ctx context.Context, filter *types.CollectionFilter, per, page int) (collections []Collection, total int, err error) { query := cs.db.Operator.Core.NewSelect(). Model(&collections). Column("collection.*"). @@ -122,7 +141,7 @@ func (cs *CollectionStore) QueryByTrending(ctx context.Context, filter *types.Co return cs.GetCollectionsByIDs(ctx, collections, ids, total, true) } -func (cs *CollectionStore) CreateCollection(ctx context.Context, collection Collection) (*Collection, error) { +func (cs *collectionStoreImpl) CreateCollection(ctx context.Context, collection Collection) (*Collection, error) { res, err := cs.db.Core.NewInsert().Model(&collection).Exec(ctx, &collection) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("failed to create collection in db, error:%w", err) @@ -131,7 +150,7 @@ func (cs *CollectionStore) CreateCollection(ctx context.Context, collection Coll return &collection, nil } -func (cs *CollectionStore) DeleteCollection(ctx context.Context, id int64, uid int64) error { +func (cs *collectionStoreImpl) DeleteCollection(ctx context.Context, id int64, uid int64) error { var collection Collection res, err := cs.db.Operator.Core.NewDelete().Model(&collection).Where("id =?", id).Where("user_id =?", uid).Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { @@ -140,13 +159,13 @@ func (cs *CollectionStore) DeleteCollection(ctx context.Context, id int64, uid i return nil } -func (cs *CollectionStore) UpdateCollection(ctx context.Context, collection Collection) (*Collection, error) { +func (cs *collectionStoreImpl) UpdateCollection(ctx context.Context, collection Collection) (*Collection, error) { _, err := cs.db.Core.NewUpdate().Model(&collection).WherePK().Exec(ctx) return &collection, err } -func (cs *CollectionStore) GetCollection(ctx context.Context, id int64) (*Collection, error) { +func (cs *collectionStoreImpl) GetCollection(ctx context.Context, id int64) (*Collection, error) { collection := new(Collection) err := cs.db.Operator.Core. NewSelect(). @@ -166,7 +185,7 @@ func (cs *CollectionStore) GetCollection(ctx context.Context, id int64) (*Collec return collection, err } -func (cs *CollectionStore) ByUserLikes(ctx context.Context, userID int64, per, page int) (collections []Collection, total int, err error) { +func (cs *collectionStoreImpl) ByUserLikes(ctx context.Context, userID int64, per, page int) (collections []Collection, total int, err error) { query := cs.db.Operator.Core. NewSelect(). Model(&collections). @@ -192,7 +211,7 @@ func (cs *CollectionStore) ByUserLikes(ctx context.Context, userID int64, per, p return cs.GetCollectionsByIDs(ctx, collections, ids, total, true) } -func (cs *CollectionStore) ByUserOrgs(ctx context.Context, namespace string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) { +func (cs *collectionStoreImpl) ByUserOrgs(ctx context.Context, namespace string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) { query := cs.db.Operator.Core. NewSelect(). Model(&collections). @@ -223,7 +242,7 @@ func (cs *CollectionStore) ByUserOrgs(ctx context.Context, namespace string, per } // get collections by ids -func (cs *CollectionStore) GetCollectionsByIDs(ctx context.Context, collections []Collection, ids []interface{}, total int, onlyPublic bool) ([]Collection, int, error) { +func (cs *collectionStoreImpl) GetCollectionsByIDs(ctx context.Context, collections []Collection, ids []interface{}, total int, onlyPublic bool) ([]Collection, int, error) { subQuery := cs.db.Operator.Core.NewSelect(). Column("cr.collection_id"). ColumnExpr("repository.id as repository_id"). @@ -285,7 +304,7 @@ func getCollectionMaps(rankedRepos []RankedRepository, repositories []Repository return } -func (cs *CollectionStore) FindById(ctx context.Context, id int64) (collection Collection, err error) { +func (cs *collectionStoreImpl) FindById(ctx context.Context, id int64) (collection Collection, err error) { q := cs.db.Operator.Core. NewSelect() err = q. @@ -295,7 +314,7 @@ func (cs *CollectionStore) FindById(ctx context.Context, id int64) (collection C return } -func (cs *CollectionStore) AddCollectionRepos(ctx context.Context, crs []CollectionRepository) error { +func (cs *collectionStoreImpl) AddCollectionRepos(ctx context.Context, crs []CollectionRepository) error { result, err := cs.db.Core.NewInsert().Model(&crs).Exec(ctx) if err != nil { @@ -305,7 +324,7 @@ func (cs *CollectionStore) AddCollectionRepos(ctx context.Context, crs []Collect return assertAffectedXRows(int64(len(crs)), result, err) } -func (cs *CollectionStore) RemoveCollectionRepos(ctx context.Context, crs []CollectionRepository) error { +func (cs *collectionStoreImpl) RemoveCollectionRepos(ctx context.Context, crs []CollectionRepository) error { for _, cr := range crs { _, err := cs.db.Core.NewDelete(). Model((*CollectionRepository)(nil)). @@ -318,7 +337,7 @@ func (cs *CollectionStore) RemoveCollectionRepos(ctx context.Context, crs []Coll return nil } -func (cs *CollectionStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) { +func (cs *collectionStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (collections []Collection, total int, err error) { query := cs.db.Operator.Core. NewSelect(). Model(&collections). diff --git a/builder/store/database/dataset.go b/builder/store/database/dataset.go index 55040d83..d6c49cf3 100644 --- a/builder/store/database/dataset.go +++ b/builder/store/database/dataset.go @@ -16,12 +16,26 @@ var sortBy = map[string]string{ "most_favorite": "likes DESC NULLS LAST", } -type DatasetStore struct { +type datasetStoreImpl struct { db *DB } -func NewDatasetStore() *DatasetStore { - return &DatasetStore{db: defaultDB} +type DatasetStore interface { + ByRepoIDs(ctx context.Context, repoIDs []int64) (datasets []Dataset, err error) + ByRepoID(ctx context.Context, repoID int64) (*Dataset, error) + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) + UserLikesDatasets(ctx context.Context, userID int64, per, page int) (datasets []Dataset, total int, err error) + ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) + Create(ctx context.Context, input Dataset) (*Dataset, error) + Update(ctx context.Context, input Dataset) (err error) + FindByPath(ctx context.Context, namespace string, repoPath string) (dataset *Dataset, err error) + Delete(ctx context.Context, input Dataset) error + ListByPath(ctx context.Context, paths []string) ([]Dataset, error) + CreateIfNotExist(ctx context.Context, input Dataset) (*Dataset, error) +} + +func NewDatasetStore() DatasetStore { + return &datasetStoreImpl{db: defaultDB} } type Dataset struct { @@ -32,7 +46,7 @@ type Dataset struct { times } -func (s *DatasetStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (datasets []Dataset, err error) { +func (s *datasetStoreImpl) ByRepoIDs(ctx context.Context, repoIDs []int64) (datasets []Dataset, err error) { q := s.db.Operator.Core.NewSelect(). Model(&datasets). Relation("Repository"). @@ -42,7 +56,7 @@ func (s *DatasetStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (datasets return } -func (s *DatasetStore) ByRepoID(ctx context.Context, repoID int64) (*Dataset, error) { +func (s *datasetStoreImpl) ByRepoID(ctx context.Context, repoID int64) (*Dataset, error) { var dataset Dataset err := s.db.Operator.Core.NewSelect(). Model(&dataset). @@ -55,7 +69,7 @@ func (s *DatasetStore) ByRepoID(ctx context.Context, repoID int64) (*Dataset, er return &dataset, nil } -func (s *DatasetStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) { +func (s *datasetStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&datasets). @@ -81,7 +95,7 @@ func (s *DatasetStore) ByUsername(ctx context.Context, username string, per, pag return } -func (s *DatasetStore) UserLikesDatasets(ctx context.Context, userID int64, per, page int) (datasets []Dataset, total int, err error) { +func (s *datasetStoreImpl) UserLikesDatasets(ctx context.Context, userID int64, per, page int) (datasets []Dataset, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&datasets). @@ -104,7 +118,7 @@ func (s *DatasetStore) UserLikesDatasets(ctx context.Context, userID int64, per, return } -func (s *DatasetStore) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) { +func (s *datasetStoreImpl) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (datasets []Dataset, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&datasets). @@ -130,7 +144,7 @@ func (s *DatasetStore) ByOrgPath(ctx context.Context, namespace string, per, pag return } -func (s *DatasetStore) Create(ctx context.Context, input Dataset) (*Dataset, error) { +func (s *datasetStoreImpl) Create(ctx context.Context, input Dataset) (*Dataset, error) { input.LastUpdatedAt = time.Now() res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { @@ -141,13 +155,13 @@ func (s *DatasetStore) Create(ctx context.Context, input Dataset) (*Dataset, err return &input, nil } -func (s *DatasetStore) Update(ctx context.Context, input Dataset) (err error) { +func (s *datasetStoreImpl) Update(ctx context.Context, input Dataset) (err error) { input.LastUpdatedAt = time.Now() _, err = s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return } -func (s *DatasetStore) FindByPath(ctx context.Context, namespace string, repoPath string) (dataset *Dataset, err error) { +func (s *datasetStoreImpl) FindByPath(ctx context.Context, namespace string, repoPath string) (dataset *Dataset, err error) { resDataset := new(Dataset) err = s.db.Operator.Core. NewSelect(). @@ -169,7 +183,7 @@ func (s *DatasetStore) FindByPath(ctx context.Context, namespace string, repoPat return resDataset, err } -func (s *DatasetStore) Delete(ctx context.Context, input Dataset) error { +func (s *datasetStoreImpl) Delete(ctx context.Context, input Dataset) error { res, err := s.db.Operator.Core.NewDelete().Model(&input).WherePK().Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("delete dataset in tx failed,error:%w", err) @@ -177,7 +191,7 @@ func (s *DatasetStore) Delete(ctx context.Context, input Dataset) error { return nil } -func (s *DatasetStore) ListByPath(ctx context.Context, paths []string) ([]Dataset, error) { +func (s *datasetStoreImpl) ListByPath(ctx context.Context, paths []string) ([]Dataset, error) { var datasets []Dataset err := s.db.Operator.Core. NewSelect(). @@ -202,7 +216,7 @@ func (s *DatasetStore) ListByPath(ctx context.Context, paths []string) ([]Datase return sortedDatasets, nil } -func (s *DatasetStore) CreateIfNotExist(ctx context.Context, input Dataset) (*Dataset, error) { +func (s *datasetStoreImpl) CreateIfNotExist(ctx context.Context, input Dataset) (*Dataset, error) { err := s.db.Core.NewSelect(). Model(&input). Where("repository_id = ?", input.RepositoryID). diff --git a/builder/store/database/deploy_task.go b/builder/store/database/deploy_task.go index 05b09ba5..d455a26a 100644 --- a/builder/store/database/deploy_task.go +++ b/builder/store/database/deploy_task.go @@ -61,41 +61,66 @@ type DeployTask struct { times } -type DeployTaskStore struct { +type deployTaskStoreImpl struct { db *DB } -func NewDeployTaskStore() *DeployTaskStore { - return &DeployTaskStore{db: defaultDB} +type DeployTaskStore interface { + CreateDeploy(ctx context.Context, deploy *Deploy) error + UpdateDeploy(ctx context.Context, deploy *Deploy) error + GetLatestDeployBySpaceID(ctx context.Context, spaceID int64) (*Deploy, error) + CreateDeployTask(ctx context.Context, deployTask *DeployTask) error + UpdateDeployTask(ctx context.Context, deployTask *DeployTask) error + GetDeployTask(ctx context.Context, id int64) (*DeployTask, error) + GetDeployTasksOfDeploy(ctx context.Context, deployID int64) ([]*DeployTask, error) + // GetNewTaskAfter return the first task of the next deploy + GetNewTaskAfter(ctx context.Context, currentDeployTaskID int64) (*DeployTask, error) + // GetNewTaskFirst returns the first task which has not end + GetNewTaskFirst(ctx context.Context) (*DeployTask, error) + UpdateInTx(ctx context.Context, deployColumns, deployTaskColumns []string, deploy *Deploy, deployTasks ...*DeployTask) error + ListDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64) ([]Deploy, error) + DeleteDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error + ListDeployByUserID(ctx context.Context, userID int64, req *types.DeployReq) ([]Deploy, int, error) + ListInstancesByUserID(ctx context.Context, userID int64, per, page int) ([]Deploy, int, error) + GetDeployByID(ctx context.Context, deployID int64) (*Deploy, error) + GetDeployBySvcName(ctx context.Context, svcName string) (*Deploy, error) + StopDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error + GetServerlessDeployByRepID(ctx context.Context, repoID int64) (*Deploy, error) + ListServerless(ctx context.Context, req types.DeployReq) ([]Deploy, int, error) + ListAllDeployments(ctx context.Context, userID int64) ([]Deploy, error) } -func (s *DeployTaskStore) CreateDeploy(ctx context.Context, deploy *Deploy) error { +func NewDeployTaskStore() DeployTaskStore { + return &deployTaskStoreImpl{db: defaultDB} +} + +func (s *deployTaskStoreImpl) CreateDeploy(ctx context.Context, deploy *Deploy) error { _, err := s.db.Core.NewInsert().Model(deploy).Exec(ctx, deploy) return err } -func (s *DeployTaskStore) UpdateDeploy(ctx context.Context, deploy *Deploy) error { +func (s *deployTaskStoreImpl) UpdateDeploy(ctx context.Context, deploy *Deploy) error { _, err := s.db.Core.NewUpdate().Model(deploy).WherePK().Exec(ctx) return err } -func (s *DeployTaskStore) GetLatestDeployBySpaceID(ctx context.Context, spaceID int64) (*Deploy, error) { +func (s *deployTaskStoreImpl) GetLatestDeployBySpaceID(ctx context.Context, spaceID int64) (*Deploy, error) { deploy := &Deploy{} err := s.db.Core.NewSelect().Model(deploy).Where("space_id = ?", spaceID).Order("created_at DESC").Limit(1).Scan(ctx, deploy) return deploy, err } -func (s *DeployTaskStore) CreateDeployTask(ctx context.Context, deployTask *DeployTask) error { +func (s *deployTaskStoreImpl) CreateDeployTask(ctx context.Context, deployTask *DeployTask) error { _, err := s.db.Core.NewInsert().Model(deployTask).Exec(ctx, deployTask) return err } -func (s *DeployTaskStore) UpdateDeployTask(ctx context.Context, deployTask *DeployTask) error { +func (s *deployTaskStoreImpl) UpdateDeployTask(ctx context.Context, deployTask *DeployTask) error { _, err := s.db.Core.NewUpdate().Model(deployTask).WherePK().Exec(ctx) return err } -func (s *DeployTaskStore) GetDeployTask(ctx context.Context, id int64) (*DeployTask, error) { +func (s *deployTaskStoreImpl) GetDeployTask(ctx context.Context, id int64) (*DeployTask, error) { deployTask := &DeployTask{} err := s.db.Core.NewSelect().Model(deployTask).Where("deploy_task.id = ?", id). Relation("Deploy"). @@ -104,14 +129,14 @@ func (s *DeployTaskStore) GetDeployTask(ctx context.Context, id int64) (*DeployT return deployTask, err } -func (s *DeployTaskStore) GetDeployTasksOfDeploy(ctx context.Context, deployID int64) ([]*DeployTask, error) { +func (s *deployTaskStoreImpl) GetDeployTasksOfDeploy(ctx context.Context, deployID int64) ([]*DeployTask, error) { var deployTasks []*DeployTask err := s.db.Core.NewSelect().Model((*DeployTask)(nil)).Where("deploy_id = ?", deployID).Scan(ctx, &deployTasks) return deployTasks, err } // GetNewTaskAfter return the first task of the next deploy -func (s *DeployTaskStore) GetNewTaskAfter(ctx context.Context, currentDeployTaskID int64) (*DeployTask, error) { +func (s *deployTaskStoreImpl) GetNewTaskAfter(ctx context.Context, currentDeployTaskID int64) (*DeployTask, error) { deployTask := &DeployTask{} err := s.db.Core.NewSelect().Model(deployTask).Relation("Deploy"). Where("deploy_task.id > ? ", currentDeployTaskID). @@ -123,7 +148,7 @@ func (s *DeployTaskStore) GetNewTaskAfter(ctx context.Context, currentDeployTask } // GetNewTaskFirst returns the first task which has not end -func (s *DeployTaskStore) GetNewTaskFirst(ctx context.Context) (*DeployTask, error) { +func (s *deployTaskStoreImpl) GetNewTaskFirst(ctx context.Context) (*DeployTask, error) { deployTask := &DeployTask{} err := s.db.Core.NewSelect().Model(deployTask).Relation("Deploy"). Where("(task_type = 0 and deploy_task.status in (0,1)) or (task_type = 1 and deploy_task.status in (0,1,3))"). @@ -133,7 +158,7 @@ func (s *DeployTaskStore) GetNewTaskFirst(ctx context.Context) (*DeployTask, err return deployTask, err } -func (s *DeployTaskStore) UpdateInTx(ctx context.Context, deployColumns, deployTaskColumns []string, deploy *Deploy, deployTasks ...*DeployTask) error { +func (s *deployTaskStoreImpl) UpdateInTx(ctx context.Context, deployColumns, deployTaskColumns []string, deploy *Deploy, deployTasks ...*DeployTask) error { tx, err := s.db.Core.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction,%w", err) @@ -169,7 +194,7 @@ func (s *DeployTaskStore) UpdateInTx(ctx context.Context, deployColumns, deployT return tx.Commit() } -func (s *DeployTaskStore) ListDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64) ([]Deploy, error) { +func (s *deployTaskStoreImpl) ListDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64) ([]Deploy, error) { var result []Deploy query := s.db.Operator.Core.NewSelect().Model(&result).Where("user_id = ? and repo_id = ?", userID, repoID) if repoType == types.ModelRepo { @@ -186,7 +211,7 @@ func (s *DeployTaskStore) ListDeploy(ctx context.Context, repoType types.Reposit return result, nil } -func (s *DeployTaskStore) DeleteDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error { +func (s *deployTaskStoreImpl) DeleteDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error { // only delete the deploy of specific repo was triggered by current login user res, err := s.db.BunDB.Exec("Update deploys set status = ? where id = ? and repo_id = ? and user_id = ?", common.Deleted, deployID, repoID, userID) if err != nil { @@ -196,7 +221,7 @@ func (s *DeployTaskStore) DeleteDeploy(ctx context.Context, repoType types.Repos return err } -func (s *DeployTaskStore) ListDeployByUserID(ctx context.Context, userID int64, req *types.DeployReq) ([]Deploy, int, error) { +func (s *deployTaskStoreImpl) ListDeployByUserID(ctx context.Context, userID int64, req *types.DeployReq) ([]Deploy, int, error) { var result []Deploy query := s.db.Operator.Core.NewSelect().Model(&result).Where("user_id = ? and type = ?", userID, req.DeployType) if req.RepoType == types.ModelRepo { @@ -219,7 +244,7 @@ func (s *DeployTaskStore) ListDeployByUserID(ctx context.Context, userID int64, return result, total, nil } -func (s *DeployTaskStore) ListInstancesByUserID(ctx context.Context, userID int64, per, page int) ([]Deploy, int, error) { +func (s *deployTaskStoreImpl) ListInstancesByUserID(ctx context.Context, userID int64, per, page int) ([]Deploy, int, error) { var result []Deploy query := s.db.Operator.Core.NewSelect().Model(&result).Where("user_id = ?", userID) query = query.Where("type = ? and status != ?", types.FinetuneType, common.Deleted) @@ -236,19 +261,19 @@ func (s *DeployTaskStore) ListInstancesByUserID(ctx context.Context, userID int6 return result, total, nil } -func (s *DeployTaskStore) GetDeployByID(ctx context.Context, deployID int64) (*Deploy, error) { +func (s *deployTaskStoreImpl) GetDeployByID(ctx context.Context, deployID int64) (*Deploy, error) { deploy := &Deploy{} err := s.db.Operator.Core.NewSelect().Model(deploy).Where("id = ?", deployID).Scan(ctx, deploy) return deploy, err } -func (s *DeployTaskStore) GetDeployBySvcName(ctx context.Context, svcName string) (*Deploy, error) { +func (s *deployTaskStoreImpl) GetDeployBySvcName(ctx context.Context, svcName string) (*Deploy, error) { deploy := &Deploy{} err := s.db.Operator.Core.NewSelect().Model(deploy).Where("svc_name = ?", svcName).Scan(ctx, deploy) return deploy, err } -func (s *DeployTaskStore) StopDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error { +func (s *deployTaskStoreImpl) StopDeploy(ctx context.Context, repoType types.RepositoryType, repoID, userID int64, deployID int64) error { // only stop the deploy of specific repo was triggered by current login user res, err := s.db.BunDB.Exec("Update deploys set status=?,updated_at=current_timestamp where id = ? and repo_id = ? and user_id = ?", common.Stopped, deployID, repoID, userID) if err != nil { @@ -258,7 +283,7 @@ func (s *DeployTaskStore) StopDeploy(ctx context.Context, repoType types.Reposit return err } -func (s *DeployTaskStore) GetServerlessDeployByRepID(ctx context.Context, repoID int64) (*Deploy, error) { +func (s *deployTaskStoreImpl) GetServerlessDeployByRepID(ctx context.Context, repoID int64) (*Deploy, error) { deploy := &Deploy{} err := s.db.Operator.Core.NewSelect().Model(deploy).Where("repo_id = ? and type = ?", repoID, types.ServerlessType).Scan(ctx, deploy) if errors.Is(err, sql.ErrNoRows) { @@ -270,7 +295,7 @@ func (s *DeployTaskStore) GetServerlessDeployByRepID(ctx context.Context, repoID return deploy, nil } -func (s *DeployTaskStore) ListServerless(ctx context.Context, req types.DeployReq) ([]Deploy, int, error) { +func (s *deployTaskStoreImpl) ListServerless(ctx context.Context, req types.DeployReq) ([]Deploy, int, error) { var result []Deploy query := s.db.Operator.Core.NewSelect().Model(&result).Where("type = ?", req.DeployType) query = query.Limit(req.PageSize).Offset((req.Page - 1) * req.PageSize) @@ -285,7 +310,7 @@ func (s *DeployTaskStore) ListServerless(ctx context.Context, req types.DeployRe return result, total, nil } -func (s *DeployTaskStore) ListAllDeployments(ctx context.Context, userID int64) ([]Deploy, error) { +func (s *deployTaskStoreImpl) ListAllDeployments(ctx context.Context, userID int64) ([]Deploy, error) { var result []Deploy err := s.db.Operator.Core.NewSelect(). Model(&result). diff --git a/builder/store/database/discussion.go b/builder/store/database/discussion.go index 4b9daee4..fb8c7d92 100644 --- a/builder/store/database/discussion.go +++ b/builder/store/database/discussion.go @@ -33,17 +33,30 @@ const ( DiscussionableTypeCollection = "collection" ) -type DiscussionStore struct { +type discussionStoreImpl struct { db *DB } -func NewDiscussionStore() *DiscussionStore { - return &DiscussionStore{ +type DiscussionStore interface { + Create(ctx context.Context, discussion Discussion) (*Discussion, error) + FindByID(ctx context.Context, id int64) (*Discussion, error) + FindByDiscussionableID(ctx context.Context, discussionableType string, discussionableID int64) ([]Discussion, error) + UpdateByID(ctx context.Context, id int64, title string) error + DeleteByID(ctx context.Context, id int64) error + FindDiscussionComments(ctx context.Context, discussionID int64) ([]Comment, error) + CreateComment(ctx context.Context, comment Comment) (*Comment, error) + UpdateComment(ctx context.Context, id int64, content string) error + FindCommentByID(ctx context.Context, id int64) (*Comment, error) + DeleteComment(ctx context.Context, id int64) error +} + +func NewDiscussionStore() DiscussionStore { + return &discussionStoreImpl{ db: defaultDB, } } -func (s *DiscussionStore) Create(ctx context.Context, discussion Discussion) (*Discussion, error) { +func (s *discussionStoreImpl) Create(ctx context.Context, discussion Discussion) (*Discussion, error) { _, err := s.db.Core.NewInsert().Model(&discussion).Exec(ctx) if err != nil { return nil, err @@ -51,7 +64,7 @@ func (s *DiscussionStore) Create(ctx context.Context, discussion Discussion) (*D return &discussion, nil } -func (s *DiscussionStore) FindByID(ctx context.Context, id int64) (*Discussion, error) { +func (s *discussionStoreImpl) FindByID(ctx context.Context, id int64) (*Discussion, error) { discussion := Discussion{} err := s.db.Core.NewSelect().Model(&discussion). Where("discussion.id = ?", id). @@ -63,7 +76,7 @@ func (s *DiscussionStore) FindByID(ctx context.Context, id int64) (*Discussion, return &discussion, nil } -func (s *DiscussionStore) FindByDiscussionableID(ctx context.Context, discussionableType string, discussionableID int64) ([]Discussion, error) { +func (s *discussionStoreImpl) FindByDiscussionableID(ctx context.Context, discussionableType string, discussionableID int64) ([]Discussion, error) { discussions := []Discussion{} err := s.db.Core.NewSelect().Model(&discussions). Where("discussionable_type = ? AND discussionable_id = ?", discussionableType, discussionableID). @@ -75,7 +88,7 @@ func (s *DiscussionStore) FindByDiscussionableID(ctx context.Context, discussion return discussions, nil } -func (s *DiscussionStore) UpdateByID(ctx context.Context, id int64, title string) error { +func (s *discussionStoreImpl) UpdateByID(ctx context.Context, id int64, title string) error { _, err := s.db.Core.NewUpdate().Model(&Discussion{}).Set("title = ?", title).Where("id = ?", id).Exec(ctx) if err != nil { return err @@ -83,7 +96,7 @@ func (s *DiscussionStore) UpdateByID(ctx context.Context, id int64, title string return nil } -func (s *DiscussionStore) DeleteByID(ctx context.Context, id int64) error { +func (s *discussionStoreImpl) DeleteByID(ctx context.Context, id int64) error { _, err := s.db.Core.NewDelete().Model(&Discussion{}).Where("id = ?", id).Exec(ctx) if err != nil { return err @@ -91,7 +104,7 @@ func (s *DiscussionStore) DeleteByID(ctx context.Context, id int64) error { return nil } -func (s *DiscussionStore) FindDiscussionComments(ctx context.Context, discussionID int64) ([]Comment, error) { +func (s *discussionStoreImpl) FindDiscussionComments(ctx context.Context, discussionID int64) ([]Comment, error) { comments := []Comment{} err := s.db.Core.NewSelect().Model(&comments). Relation("User"). @@ -103,7 +116,7 @@ func (s *DiscussionStore) FindDiscussionComments(ctx context.Context, discussion return comments, nil } -func (s *DiscussionStore) CreateComment(ctx context.Context, comment Comment) (*Comment, error) { +func (s *discussionStoreImpl) CreateComment(ctx context.Context, comment Comment) (*Comment, error) { _, err := s.db.Core.NewInsert().Model(&comment).Exec(ctx) if err != nil { return nil, err @@ -111,7 +124,7 @@ func (s *DiscussionStore) CreateComment(ctx context.Context, comment Comment) (* return &comment, nil } -func (s *DiscussionStore) UpdateComment(ctx context.Context, id int64, content string) error { +func (s *discussionStoreImpl) UpdateComment(ctx context.Context, id int64, content string) error { _, err := s.db.Core.NewUpdate().Model(&Comment{}).Set("content = ?", content).Where("id = ?", id).Exec(ctx) if err != nil { return err @@ -119,7 +132,7 @@ func (s *DiscussionStore) UpdateComment(ctx context.Context, id int64, content s return nil } -func (s *DiscussionStore) FindCommentByID(ctx context.Context, id int64) (*Comment, error) { +func (s *discussionStoreImpl) FindCommentByID(ctx context.Context, id int64) (*Comment, error) { comment := Comment{} err := s.db.Core.NewSelect().Model(&comment). Where("comment.id = ?", id). @@ -131,7 +144,7 @@ func (s *DiscussionStore) FindCommentByID(ctx context.Context, id int64) (*Comme return &comment, nil } -func (s *DiscussionStore) DeleteComment(ctx context.Context, id int64) error { +func (s *discussionStoreImpl) DeleteComment(ctx context.Context, id int64) error { _, err := s.db.Core.NewDelete().Model(&Comment{}).Where("id = ?", id).Exec(ctx) if err != nil { return err diff --git a/builder/store/database/event.go b/builder/store/database/event.go index d7af8793..1add4e6e 100644 --- a/builder/store/database/event.go +++ b/builder/store/database/event.go @@ -2,22 +2,28 @@ package database import "context" -type EventStore struct { +type eventStoreImpl struct { db *DB } -func NewEventStore() *EventStore { - return &EventStore{ +type EventStore interface { + Save(ctx context.Context, event Event) error + // batch insert + BatchSave(ctx context.Context, events []Event) error +} + +func NewEventStore() EventStore { + return &eventStoreImpl{ db: defaultDB, } } -func (s *EventStore) Save(ctx context.Context, event Event) error { +func (s *eventStoreImpl) Save(ctx context.Context, event Event) error { return assertAffectedOneRow(s.db.Core.NewInsert().Model(&event).Exec(ctx)) } // batch insert -func (s *EventStore) BatchSave(ctx context.Context, events []Event) error { +func (s *eventStoreImpl) BatchSave(ctx context.Context, events []Event) error { result, err := s.db.Core.NewInsert().Model(&events).Exec(ctx) return assertAffectedXRows(int64(len(events)), result, err) } diff --git a/builder/store/database/file.go b/builder/store/database/file.go index 5fe7d678..0278a516 100644 --- a/builder/store/database/file.go +++ b/builder/store/database/file.go @@ -4,12 +4,17 @@ import ( "context" ) -type FileStore struct { +type fileStoreImpl struct { db *DB } -func NewFileStore() *FileStore { - return &FileStore{ +type FileStore interface { + FindByParentPath(ctx context.Context, repoID int64, path string) ([]File, error) + BatchCreate(ctx context.Context, files []File) error +} + +func NewFileStore() FileStore { + return &fileStoreImpl{ db: defaultDB, } } @@ -28,7 +33,7 @@ type File struct { times } -func (s *FileStore) FindByParentPath(ctx context.Context, repoID int64, path string) ([]File, error) { +func (s *fileStoreImpl) FindByParentPath(ctx context.Context, repoID int64, path string) ([]File, error) { var files []File err := s.db.Operator.Core.NewSelect(). Model(&files). @@ -40,7 +45,7 @@ func (s *FileStore) FindByParentPath(ctx context.Context, repoID int64, path str return files, nil } -func (s *FileStore) BatchCreate(ctx context.Context, files []File) error { +func (s *fileStoreImpl) BatchCreate(ctx context.Context, files []File) error { result, err := s.db.Operator.Core.NewInsert(). Model(&files). Exec(ctx) diff --git a/builder/store/database/files.csv b/builder/store/database/files.csv new file mode 100644 index 00000000..7edf8317 --- /dev/null +++ b/builder/store/database/files.csv @@ -0,0 +1,68 @@ +File Name +repo_relation.go +account_present.go +account_price.go +lfs_lock.go +account_sync_quota.go +space_resource.go +mirror.go +runtime_architecture.go +recom.go +event.go +user.go +user_test.go +db.go +access_token.go +repository_file_check.go +prompt_test.go +user_like.go +ssh_key.go +multi_sync.go +account_statement.go +account_sync_quota_statement.go +db_query_option.go +lfs_meta_object.go +lfs.go +mirror_source.go +telemetry.go +code.go +member.go +account_event_test.go +qq.py +tag_rule.go +prompt_prefix.go +repository_file.go +organization.go +argo_workflow.go +tables.go +license.go +collection.go +sync_version.go +files.csv +runtime_framework.go +account_order.go +common.go +account_bill.go +cluster.go +repository_runtime_framework.go +file.go +prompt_conversation.go +sync_client_setting.go +model.go +deploy_task.go +dataset.go +account_users.go +tag.go +account_metering.go +user_resources.go +llm_config.go +git_server_access_token.go +repository.go +account_event.go +namespace.go +space.go +space_sdk.go +resources_models.go +prompt.go +repository_download.go +discussion.go diff --git a/builder/store/database/git_server_access_token.go b/builder/store/database/git_server_access_token.go index 731a637b..b2e8fcd9 100644 --- a/builder/store/database/git_server_access_token.go +++ b/builder/store/database/git_server_access_token.go @@ -2,12 +2,18 @@ package database import "context" -type GitServerAccessTokenStore struct { +type gitServerAccessTokenStoreImpl struct { db *DB } -func NewGitServerAccessTokenStore() *GitServerAccessTokenStore { - return &GitServerAccessTokenStore{ +type GitServerAccessTokenStore interface { + Create(ctx context.Context, gToken *GitServerAccessToken) (*GitServerAccessToken, error) + Index(ctx context.Context) ([]GitServerAccessToken, error) + FindByType(ctx context.Context, serverType string) ([]GitServerAccessToken, error) +} + +func NewGitServerAccessTokenStore() GitServerAccessTokenStore { + return &gitServerAccessTokenStoreImpl{ db: defaultDB, } } @@ -26,7 +32,7 @@ type GitServerAccessToken struct { times } -func (s *GitServerAccessTokenStore) Create(ctx context.Context, gToken *GitServerAccessToken) (*GitServerAccessToken, error) { +func (s *gitServerAccessTokenStoreImpl) Create(ctx context.Context, gToken *GitServerAccessToken) (*GitServerAccessToken, error) { err := s.db.Operator.Core.NewInsert(). Model(gToken). Scan(ctx) @@ -36,7 +42,7 @@ func (s *GitServerAccessTokenStore) Create(ctx context.Context, gToken *GitServe return gToken, nil } -func (s *GitServerAccessTokenStore) Index(ctx context.Context) ([]GitServerAccessToken, error) { +func (s *gitServerAccessTokenStoreImpl) Index(ctx context.Context) ([]GitServerAccessToken, error) { var gTokens []GitServerAccessToken err := s.db.Operator.Core.NewSelect(). Model(&gTokens). @@ -47,7 +53,7 @@ func (s *GitServerAccessTokenStore) Index(ctx context.Context) ([]GitServerAcces return gTokens, nil } -func (s *GitServerAccessTokenStore) FindByType(ctx context.Context, serverType string) ([]GitServerAccessToken, error) { +func (s *gitServerAccessTokenStoreImpl) FindByType(ctx context.Context, serverType string) ([]GitServerAccessToken, error) { var gTokens []GitServerAccessToken err := s.db.Operator.Core.NewSelect(). Model(&gTokens). diff --git a/builder/store/database/lfs_lock.go b/builder/store/database/lfs_lock.go index b049d91c..bd762eba 100644 --- a/builder/store/database/lfs_lock.go +++ b/builder/store/database/lfs_lock.go @@ -4,12 +4,20 @@ import ( "context" ) -type LfsLockStore struct { +type lfsLockStoreImpl struct { db *DB } -func NewLfsLockStore() *LfsLockStore { - return &LfsLockStore{ +type LfsLockStore interface { + FindByID(ctx context.Context, ID int64) (*LfsLock, error) + FindByPath(ctx context.Context, RepoId int64, path string) (*LfsLock, error) + FindByRepoID(ctx context.Context, RepoId int64, page, per int) ([]LfsLock, error) + Create(ctx context.Context, lfsLock LfsLock) (*LfsLock, error) + RemoveByID(ctx context.Context, ID int64) error +} + +func NewLfsLockStore() LfsLockStore { + return &lfsLockStoreImpl{ db: defaultDB, } } @@ -24,7 +32,7 @@ type LfsLock struct { times } -func (s *LfsLockStore) FindByID(ctx context.Context, ID int64) (*LfsLock, error) { +func (s *lfsLockStoreImpl) FindByID(ctx context.Context, ID int64) (*LfsLock, error) { var lfsLock LfsLock err := s.db.Operator.Core.NewSelect(). Model(&lfsLock). @@ -37,7 +45,7 @@ func (s *LfsLockStore) FindByID(ctx context.Context, ID int64) (*LfsLock, error) return &lfsLock, nil } -func (s *LfsLockStore) FindByPath(ctx context.Context, RepoId int64, path string) (*LfsLock, error) { +func (s *lfsLockStoreImpl) FindByPath(ctx context.Context, RepoId int64, path string) (*LfsLock, error) { var lfsLock LfsLock err := s.db.Operator.Core.NewSelect(). Model(&lfsLock). @@ -50,7 +58,7 @@ func (s *LfsLockStore) FindByPath(ctx context.Context, RepoId int64, path string return &lfsLock, nil } -func (s *LfsLockStore) FindByRepoID(ctx context.Context, RepoId int64, page, per int) ([]LfsLock, error) { +func (s *lfsLockStoreImpl) FindByRepoID(ctx context.Context, RepoId int64, page, per int) ([]LfsLock, error) { var lfsLocks []LfsLock query := s.db.Operator.Core.NewSelect(). Model(&lfsLocks). @@ -67,7 +75,7 @@ func (s *LfsLockStore) FindByRepoID(ctx context.Context, RepoId int64, page, per return lfsLocks, nil } -func (s *LfsLockStore) Create(ctx context.Context, lfsLock LfsLock) (*LfsLock, error) { +func (s *lfsLockStoreImpl) Create(ctx context.Context, lfsLock LfsLock) (*LfsLock, error) { err := s.db.Operator.Core.NewInsert(). Model(&lfsLock). Scan(ctx) @@ -77,7 +85,7 @@ func (s *LfsLockStore) Create(ctx context.Context, lfsLock LfsLock) (*LfsLock, e return &lfsLock, nil } -func (s *LfsLockStore) RemoveByID(ctx context.Context, ID int64) error { +func (s *lfsLockStoreImpl) RemoveByID(ctx context.Context, ID int64) error { _, err := s.db.Operator.Core.NewDelete(). Model(&LfsLock{}). Where("id = ?", ID). diff --git a/builder/store/database/lfs_meta_object.go b/builder/store/database/lfs_meta_object.go index bc61d3f8..22b3db70 100644 --- a/builder/store/database/lfs_meta_object.go +++ b/builder/store/database/lfs_meta_object.go @@ -6,12 +6,21 @@ import ( "time" ) -type LfsMetaObjectStore struct { +type lfsMetaObjectStoreImpl struct { db *DB } -func NewLfsMetaObjectStore() *LfsMetaObjectStore { - return &LfsMetaObjectStore{ +type LfsMetaObjectStore interface { + FindByOID(ctx context.Context, RepoId int64, Oid string) (*LfsMetaObject, error) + FindByRepoID(ctx context.Context, repoID int64) ([]LfsMetaObject, error) + Create(ctx context.Context, lfsObj LfsMetaObject) (*LfsMetaObject, error) + RemoveByOid(ctx context.Context, oid string, repoID int64) error + UpdateOrCreate(ctx context.Context, input LfsMetaObject) (*LfsMetaObject, error) + BulkUpdateOrCreate(ctx context.Context, input []LfsMetaObject) error +} + +func NewLfsMetaObjectStore() LfsMetaObjectStore { + return &lfsMetaObjectStoreImpl{ db: defaultDB, } } @@ -26,7 +35,7 @@ type LfsMetaObject struct { times } -func (s *LfsMetaObjectStore) FindByOID(ctx context.Context, RepoId int64, Oid string) (*LfsMetaObject, error) { +func (s *lfsMetaObjectStoreImpl) FindByOID(ctx context.Context, RepoId int64, Oid string) (*LfsMetaObject, error) { var lfsMetaObject LfsMetaObject err := s.db.Operator.Core.NewSelect(). Model(&lfsMetaObject). @@ -38,7 +47,7 @@ func (s *LfsMetaObjectStore) FindByOID(ctx context.Context, RepoId int64, Oid st return &lfsMetaObject, nil } -func (s *LfsMetaObjectStore) FindByRepoID(ctx context.Context, repoID int64) ([]LfsMetaObject, error) { +func (s *lfsMetaObjectStoreImpl) FindByRepoID(ctx context.Context, repoID int64) ([]LfsMetaObject, error) { var lfsMetaObjects []LfsMetaObject err := s.db.Operator.Core.NewSelect(). Model(&lfsMetaObjects). @@ -50,7 +59,7 @@ func (s *LfsMetaObjectStore) FindByRepoID(ctx context.Context, repoID int64) ([] return lfsMetaObjects, nil } -func (s *LfsMetaObjectStore) Create(ctx context.Context, lfsObj LfsMetaObject) (*LfsMetaObject, error) { +func (s *lfsMetaObjectStoreImpl) Create(ctx context.Context, lfsObj LfsMetaObject) (*LfsMetaObject, error) { err := s.db.Operator.Core.NewInsert(). Model(&lfsObj). Scan(ctx) @@ -60,7 +69,7 @@ func (s *LfsMetaObjectStore) Create(ctx context.Context, lfsObj LfsMetaObject) ( return &lfsObj, nil } -func (s *LfsMetaObjectStore) RemoveByOid(ctx context.Context, oid string, repoID int64) error { +func (s *lfsMetaObjectStoreImpl) RemoveByOid(ctx context.Context, oid string, repoID int64) error { err := s.db.Operator.Core.NewDelete(). Model(&LfsMetaObject{}). Where("oid = ? and repository_id= ?", oid, repoID). @@ -69,7 +78,7 @@ func (s *LfsMetaObjectStore) RemoveByOid(ctx context.Context, oid string, repoID return err } -func (s *LfsMetaObjectStore) UpdateOrCreate(ctx context.Context, input LfsMetaObject) (*LfsMetaObject, error) { +func (s *lfsMetaObjectStoreImpl) UpdateOrCreate(ctx context.Context, input LfsMetaObject) (*LfsMetaObject, error) { input.UpdatedAt = time.Now() _, err := s.db.Core.NewUpdate(). Model(&input). @@ -88,7 +97,7 @@ func (s *LfsMetaObjectStore) UpdateOrCreate(ctx context.Context, input LfsMetaOb return &input, nil } -func (s *LfsMetaObjectStore) BulkUpdateOrCreate(ctx context.Context, input []LfsMetaObject) error { +func (s *lfsMetaObjectStoreImpl) BulkUpdateOrCreate(ctx context.Context, input []LfsMetaObject) error { if len(input) == 0 { return nil } diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 1219f9dc..6c1758f8 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -5,7 +5,7 @@ import ( "fmt" ) -type LLMConfigStore struct { +type lLMConfigStoreImpl struct { db *DB } @@ -19,11 +19,15 @@ type LLMConfig struct { times } -func NewLLMConfigStore() *LLMConfigStore { - return &LLMConfigStore{db: defaultDB} +type LLMConfigStore interface { + GetOptimization(ctx context.Context) (*LLMConfig, error) } -func (s *LLMConfigStore) GetOptimization(ctx context.Context) (*LLMConfig, error) { +func NewLLMConfigStore() LLMConfigStore { + return &lLMConfigStoreImpl{db: defaultDB} +} + +func (s *lLMConfigStoreImpl) GetOptimization(ctx context.Context) (*LLMConfig, error) { var config LLMConfig err := s.db.Operator.Core.NewSelect().Model(&config).Where("type = 1 and enabled = true").Limit(1).Scan(ctx) if err != nil { diff --git a/builder/store/database/member.go b/builder/store/database/member.go index 41777132..12f4316f 100644 --- a/builder/store/database/member.go +++ b/builder/store/database/member.go @@ -5,12 +5,20 @@ import ( "fmt" ) -type MemberStore struct { +type memberStoreImpl struct { db *DB } -func NewMemberStore() *MemberStore { - return &MemberStore{ +type MemberStore interface { + Find(ctx context.Context, orgID, userID int64) (*Member, error) + Add(ctx context.Context, orgID, userID int64, role string) error + Delete(ctx context.Context, orgID, userID int64, role string) error + UserMembers(ctx context.Context, userID int64) ([]Member, error) + OrganizationMembers(ctx context.Context, orgID int64, pageSize, page int) ([]Member, int, error) +} + +func NewMemberStore() MemberStore { + return &memberStoreImpl{ db: defaultDB, } } @@ -26,7 +34,7 @@ type Member struct { times } -func (s *MemberStore) Find(ctx context.Context, orgID, userID int64) (*Member, error) { +func (s *memberStoreImpl) Find(ctx context.Context, orgID, userID int64) (*Member, error) { var member Member err := s.db.Core.NewSelect().Model(&member).Where("organization_id = ? AND user_id = ?", orgID, userID).Scan(ctx) if err != nil { @@ -35,7 +43,7 @@ func (s *MemberStore) Find(ctx context.Context, orgID, userID int64) (*Member, e return &member, nil } -func (s *MemberStore) Add(ctx context.Context, orgID, userID int64, role string) error { +func (s *memberStoreImpl) Add(ctx context.Context, orgID, userID int64, role string) error { member := &Member{ OrganizationID: orgID, UserID: userID, @@ -48,19 +56,19 @@ func (s *MemberStore) Add(ctx context.Context, orgID, userID int64, role string) return assertAffectedOneRow(result, err) } -func (s *MemberStore) Delete(ctx context.Context, orgID, userID int64, role string) error { +func (s *memberStoreImpl) Delete(ctx context.Context, orgID, userID int64, role string) error { var member Member _, err := s.db.Core.NewDelete().Model(&member).Where("organization_id=? and user_id=? and role=?", orgID, userID, role).Exec(ctx) return err } -func (s *MemberStore) UserMembers(ctx context.Context, userID int64) ([]Member, error) { +func (s *memberStoreImpl) UserMembers(ctx context.Context, userID int64) ([]Member, error) { var members []Member err := s.db.Core.NewSelect().Model((*Member)(nil)).Where("user_id=?", userID).Scan(ctx, &members) return members, err } -func (s *MemberStore) OrganizationMembers(ctx context.Context, orgID int64, pageSize, page int) ([]Member, int, error) { +func (s *memberStoreImpl) OrganizationMembers(ctx context.Context, orgID int64, pageSize, page int) ([]Member, int, error) { var members []Member var total int q := s.db.Core.NewSelect().Model((*Member)(nil)). diff --git a/builder/store/database/mirror.go b/builder/store/database/mirror.go index a58a88a0..9f733d56 100644 --- a/builder/store/database/mirror.go +++ b/builder/store/database/mirror.go @@ -9,12 +9,34 @@ import ( "opencsg.com/csghub-server/common/types" ) -type MirrorStore struct { +type mirrorStoreImpl struct { db *DB } -func NewMirrorStore() *MirrorStore { - return &MirrorStore{ +type MirrorStore interface { + IsExist(ctx context.Context, repoID int64) (exists bool, err error) + IsRepoExist(ctx context.Context, repoType types.RepositoryType, namespace, name string) (exists bool, err error) + FindByRepoID(ctx context.Context, repoID int64) (*Mirror, error) + FindByID(ctx context.Context, ID int64) (*Mirror, error) + FindByRepoPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Mirror, error) + FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Mirror, error) + Create(ctx context.Context, mirror *Mirror) (*Mirror, error) + WithPagination(ctx context.Context) ([]Mirror, error) + WithPaginationWithRepository(ctx context.Context) ([]Mirror, error) + NoPushMirror(ctx context.Context) ([]Mirror, error) + PushedMirror(ctx context.Context) ([]Mirror, error) + Update(ctx context.Context, mirror *Mirror) (err error) + Delete(ctx context.Context, mirror *Mirror) (err error) + Unfinished(ctx context.Context) ([]Mirror, error) + Finished(ctx context.Context) ([]Mirror, error) + ToSyncRepo(ctx context.Context) ([]Mirror, error) + ToSyncLfs(ctx context.Context) ([]Mirror, error) + IndexWithPagination(ctx context.Context, per, page int) (mirrors []Mirror, count int, err error) + UpdateMirrorAndRepository(ctx context.Context, mirror *Mirror, repo *Repository) error +} + +func NewMirrorStore() MirrorStore { + return &mirrorStoreImpl{ db: defaultDB, } } @@ -48,7 +70,7 @@ type Mirror struct { times } -func (s *MirrorStore) IsExist(ctx context.Context, repoID int64) (exists bool, err error) { +func (s *mirrorStoreImpl) IsExist(ctx context.Context, repoID int64) (exists bool, err error) { var mirror Mirror exists, err = s.db.Operator.Core. NewSelect(). @@ -57,7 +79,7 @@ func (s *MirrorStore) IsExist(ctx context.Context, repoID int64) (exists bool, e Exists(ctx) return } -func (s *MirrorStore) IsRepoExist(ctx context.Context, repoType types.RepositoryType, namespace, name string) (exists bool, err error) { +func (s *mirrorStoreImpl) IsRepoExist(ctx context.Context, repoType types.RepositoryType, namespace, name string) (exists bool, err error) { var repo Repository exists, err = s.db.Operator.Core. NewSelect(). @@ -67,7 +89,7 @@ func (s *MirrorStore) IsRepoExist(ctx context.Context, repoType types.Repository return } -func (s *MirrorStore) FindByRepoID(ctx context.Context, repoID int64) (*Mirror, error) { +func (s *mirrorStoreImpl) FindByRepoID(ctx context.Context, repoID int64) (*Mirror, error) { var mirror Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirror). @@ -79,7 +101,7 @@ func (s *MirrorStore) FindByRepoID(ctx context.Context, repoID int64) (*Mirror, return &mirror, nil } -func (s *MirrorStore) FindByID(ctx context.Context, ID int64) (*Mirror, error) { +func (s *mirrorStoreImpl) FindByID(ctx context.Context, ID int64) (*Mirror, error) { var mirror Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirror). @@ -92,7 +114,7 @@ func (s *MirrorStore) FindByID(ctx context.Context, ID int64) (*Mirror, error) { return &mirror, nil } -func (s *MirrorStore) FindByRepoPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Mirror, error) { +func (s *mirrorStoreImpl) FindByRepoPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Mirror, error) { var mirror Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirror). @@ -105,7 +127,7 @@ func (s *MirrorStore) FindByRepoPath(ctx context.Context, repoType types.Reposit return &mirror, nil } -func (s *MirrorStore) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Mirror, error) { +func (s *mirrorStoreImpl) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Mirror, error) { var mirror Mirror var err error if mapping == types.CSGHubMapping { @@ -138,7 +160,7 @@ func (s *MirrorStore) FindWithMapping(ctx context.Context, repoType types.Reposi return &mirror, nil } -func (s *MirrorStore) Create(ctx context.Context, mirror *Mirror) (*Mirror, error) { +func (s *mirrorStoreImpl) Create(ctx context.Context, mirror *Mirror) (*Mirror, error) { err := s.db.Operator.Core.NewInsert(). Model(mirror). Scan(ctx) @@ -148,7 +170,7 @@ func (s *MirrorStore) Create(ctx context.Context, mirror *Mirror) (*Mirror, erro return mirror, nil } -func (s *MirrorStore) WithPagination(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) WithPagination(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -159,7 +181,7 @@ func (s *MirrorStore) WithPagination(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) WithPaginationWithRepository(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) WithPaginationWithRepository(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -171,7 +193,7 @@ func (s *MirrorStore) WithPaginationWithRepository(ctx context.Context) ([]Mirro return mirrors, nil } -func (s *MirrorStore) NoPushMirror(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) NoPushMirror(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -183,7 +205,7 @@ func (s *MirrorStore) NoPushMirror(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) PushedMirror(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) PushedMirror(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -196,7 +218,7 @@ func (s *MirrorStore) PushedMirror(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) Update(ctx context.Context, mirror *Mirror) (err error) { +func (s *mirrorStoreImpl) Update(ctx context.Context, mirror *Mirror) (err error) { err = assertAffectedOneRow(s.db.Operator.Core.NewUpdate(). Model(mirror). WherePK(). @@ -206,7 +228,7 @@ func (s *MirrorStore) Update(ctx context.Context, mirror *Mirror) (err error) { return } -func (s *MirrorStore) Delete(ctx context.Context, mirror *Mirror) (err error) { +func (s *mirrorStoreImpl) Delete(ctx context.Context, mirror *Mirror) (err error) { _, err = s.db.Operator.Core. NewDelete(). Model(mirror). @@ -215,7 +237,7 @@ func (s *MirrorStore) Delete(ctx context.Context, mirror *Mirror) (err error) { return } -func (s *MirrorStore) Unfinished(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) Unfinished(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -228,7 +250,7 @@ func (s *MirrorStore) Unfinished(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) Finished(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) Finished(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -241,7 +263,7 @@ func (s *MirrorStore) Finished(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) ToSyncRepo(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) ToSyncRepo(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -253,7 +275,7 @@ func (s *MirrorStore) ToSyncRepo(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) ToSyncLfs(ctx context.Context) ([]Mirror, error) { +func (s *mirrorStoreImpl) ToSyncLfs(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). @@ -265,7 +287,7 @@ func (s *MirrorStore) ToSyncLfs(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *MirrorStore) IndexWithPagination(ctx context.Context, per, page int) (mirrors []Mirror, count int, err error) { +func (s *mirrorStoreImpl) IndexWithPagination(ctx context.Context, per, page int) (mirrors []Mirror, count int, err error) { q := s.db.Operator.Core.NewSelect(). Model(&mirrors). Relation("Repository"). @@ -285,7 +307,7 @@ func (s *MirrorStore) IndexWithPagination(ctx context.Context, per, page int) (m return } -func (s *MirrorStore) UpdateMirrorAndRepository(ctx context.Context, mirror *Mirror, repo *Repository) error { +func (s *mirrorStoreImpl) UpdateMirrorAndRepository(ctx context.Context, mirror *Mirror, repo *Repository) error { err := s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { _, err := tx.NewUpdate().Model(mirror).WherePK().Exec(ctx) if err != nil { diff --git a/builder/store/database/mirror_source.go b/builder/store/database/mirror_source.go index 31dbe25b..314f171c 100644 --- a/builder/store/database/mirror_source.go +++ b/builder/store/database/mirror_source.go @@ -6,12 +6,21 @@ import ( "strings" ) -type MirrorSourceStore struct { +type mirrorSourceStoreImpl struct { db *DB } -func NewMirrorSourceStore() *MirrorSourceStore { - return &MirrorSourceStore{ +type MirrorSourceStore interface { + Create(ctx context.Context, mirrorSource *MirrorSource) (*MirrorSource, error) + Index(ctx context.Context) ([]MirrorSource, error) + Get(ctx context.Context, id int64) (*MirrorSource, error) + FindByName(ctx context.Context, name string) (*MirrorSource, error) + Update(ctx context.Context, mirrorSource *MirrorSource) (err error) + Delete(ctx context.Context, mirrorSource *MirrorSource) (err error) +} + +func NewMirrorSourceStore() MirrorSourceStore { + return &mirrorSourceStoreImpl{ db: defaultDB, } } @@ -24,7 +33,7 @@ type MirrorSource struct { times } -func (s *MirrorSourceStore) Create(ctx context.Context, mirrorSource *MirrorSource) (*MirrorSource, error) { +func (s *mirrorSourceStoreImpl) Create(ctx context.Context, mirrorSource *MirrorSource) (*MirrorSource, error) { err := s.db.Operator.Core.NewInsert(). Model(mirrorSource). Scan(ctx) @@ -34,7 +43,7 @@ func (s *MirrorSourceStore) Create(ctx context.Context, mirrorSource *MirrorSour return mirrorSource, nil } -func (s *MirrorSourceStore) Index(ctx context.Context) ([]MirrorSource, error) { +func (s *mirrorSourceStoreImpl) Index(ctx context.Context) ([]MirrorSource, error) { var mirrorSources []MirrorSource err := s.db.Operator.Core.NewSelect(). Model(&mirrorSources). @@ -45,7 +54,7 @@ func (s *MirrorSourceStore) Index(ctx context.Context) ([]MirrorSource, error) { return mirrorSources, nil } -func (s *MirrorSourceStore) Get(ctx context.Context, id int64) (*MirrorSource, error) { +func (s *mirrorSourceStoreImpl) Get(ctx context.Context, id int64) (*MirrorSource, error) { var mirrorSource MirrorSource err := s.db.Operator.Core.NewSelect(). Model(&mirrorSource). @@ -57,7 +66,7 @@ func (s *MirrorSourceStore) Get(ctx context.Context, id int64) (*MirrorSource, e return &mirrorSource, nil } -func (s *MirrorSourceStore) FindByName(ctx context.Context, name string) (*MirrorSource, error) { +func (s *mirrorSourceStoreImpl) FindByName(ctx context.Context, name string) (*MirrorSource, error) { var mirrorSource MirrorSource err := s.db.Operator.Core.NewSelect(). Model(&mirrorSource). @@ -69,7 +78,7 @@ func (s *MirrorSourceStore) FindByName(ctx context.Context, name string) (*Mirro return &mirrorSource, nil } -func (s *MirrorSourceStore) Update(ctx context.Context, mirrorSource *MirrorSource) (err error) { +func (s *mirrorSourceStoreImpl) Update(ctx context.Context, mirrorSource *MirrorSource) (err error) { err = assertAffectedOneRow(s.db.Operator.Core.NewUpdate(). Model(mirrorSource). WherePK(). @@ -79,7 +88,7 @@ func (s *MirrorSourceStore) Update(ctx context.Context, mirrorSource *MirrorSour return } -func (s *MirrorSourceStore) Delete(ctx context.Context, mirrorSource *MirrorSource) (err error) { +func (s *mirrorSourceStoreImpl) Delete(ctx context.Context, mirrorSource *MirrorSource) (err error) { _, err = s.db.Operator.Core. NewDelete(). Model(mirrorSource). diff --git a/builder/store/database/model.go b/builder/store/database/model.go index b115bbba..34452bc7 100644 --- a/builder/store/database/model.go +++ b/builder/store/database/model.go @@ -10,12 +10,29 @@ import ( "opencsg.com/csghub-server/common/types" ) -type ModelStore struct { +type modelStoreImpl struct { db *DB } -func NewModelStore() *ModelStore { - return &ModelStore{ +type ModelStore interface { + ByRepoIDs(ctx context.Context, repoIDs []int64) (models []Model, err error) + ByRepoID(ctx context.Context, repoID int64) (*Model, error) + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (models []Model, total int, err error) + UserLikesModels(ctx context.Context, userID int64, per, page int) (models []Model, total int, err error) + ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (models []Model, total int, err error) + Count(ctx context.Context) (count int, err error) + PublicCount(ctx context.Context) (count int, err error) + Create(ctx context.Context, input Model) (*Model, error) + Update(ctx context.Context, input Model) (*Model, error) + FindByPath(ctx context.Context, namespace string, name string) (*Model, error) + Delete(ctx context.Context, input Model) error + ListByPath(ctx context.Context, paths []string) ([]Model, error) + ByID(ctx context.Context, id int64) (*Model, error) + CreateIfNotExist(ctx context.Context, input Model) (*Model, error) +} + +func NewModelStore() ModelStore { + return &modelStoreImpl{ db: defaultDB, } } @@ -29,7 +46,7 @@ type Model struct { times } -func (s *ModelStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (models []Model, err error) { +func (s *modelStoreImpl) ByRepoIDs(ctx context.Context, repoIDs []int64) (models []Model, err error) { err = s.db.Operator.Core.NewSelect(). Model(&models). Relation("Repository"). @@ -39,7 +56,7 @@ func (s *ModelStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (models []M return } -func (s *ModelStore) ByRepoID(ctx context.Context, repoID int64) (*Model, error) { +func (s *modelStoreImpl) ByRepoID(ctx context.Context, repoID int64) (*Model, error) { var m Model err := s.db.Core.NewSelect(). Model(&m). @@ -52,7 +69,7 @@ func (s *ModelStore) ByRepoID(ctx context.Context, repoID int64) (*Model, error) return &m, nil } -func (s *ModelStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (models []Model, total int, err error) { +func (s *modelStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (models []Model, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&models). @@ -78,7 +95,7 @@ func (s *ModelStore) ByUsername(ctx context.Context, username string, per, page return } -func (s *ModelStore) UserLikesModels(ctx context.Context, userID int64, per, page int) (models []Model, total int, err error) { +func (s *modelStoreImpl) UserLikesModels(ctx context.Context, userID int64, per, page int) (models []Model, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&models). @@ -101,7 +118,7 @@ func (s *ModelStore) UserLikesModels(ctx context.Context, userID int64, per, pag return } -func (s *ModelStore) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (models []Model, total int, err error) { +func (s *modelStoreImpl) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (models []Model, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&models). @@ -127,7 +144,7 @@ func (s *ModelStore) ByOrgPath(ctx context.Context, namespace string, per, page return } -func (s *ModelStore) Count(ctx context.Context) (count int, err error) { +func (s *modelStoreImpl) Count(ctx context.Context) (count int, err error) { count, err = s.db.Operator.Core. NewSelect(). Model(&Repository{}). @@ -139,7 +156,7 @@ func (s *ModelStore) Count(ctx context.Context) (count int, err error) { return } -func (s *ModelStore) PublicCount(ctx context.Context) (count int, err error) { +func (s *modelStoreImpl) PublicCount(ctx context.Context) (count int, err error) { count, err = s.db.Operator.Core. NewSelect(). Model(&Repository{}). @@ -152,7 +169,7 @@ func (s *ModelStore) PublicCount(ctx context.Context) (count int, err error) { return } -func (s *ModelStore) Create(ctx context.Context, input Model) (*Model, error) { +func (s *modelStoreImpl) Create(ctx context.Context, input Model) (*Model, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { slog.Error("create model in db failed", slog.String("error", err.Error())) @@ -162,13 +179,13 @@ func (s *ModelStore) Create(ctx context.Context, input Model) (*Model, error) { return &input, nil } -func (s *ModelStore) Update(ctx context.Context, input Model) (*Model, error) { +func (s *modelStoreImpl) Update(ctx context.Context, input Model) (*Model, error) { _, err := s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return &input, err } -func (s *ModelStore) FindByPath(ctx context.Context, namespace string, name string) (*Model, error) { +func (s *modelStoreImpl) FindByPath(ctx context.Context, namespace string, name string) (*Model, error) { resModel := new(Model) err := s.db.Operator.Core. NewSelect(). @@ -191,7 +208,7 @@ func (s *ModelStore) FindByPath(ctx context.Context, namespace string, name stri return resModel, err } -func (s *ModelStore) Delete(ctx context.Context, input Model) error { +func (s *modelStoreImpl) Delete(ctx context.Context, input Model) error { res, err := s.db.Operator.Core.NewDelete().Model(&input).WherePK().Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("delete model in tx failed,error:%w", err) @@ -199,7 +216,7 @@ func (s *ModelStore) Delete(ctx context.Context, input Model) error { return nil } -func (s *ModelStore) ListByPath(ctx context.Context, paths []string) ([]Model, error) { +func (s *modelStoreImpl) ListByPath(ctx context.Context, paths []string) ([]Model, error) { var models []Model err := s.db.Operator.Core. NewSelect(). @@ -223,7 +240,7 @@ func (s *ModelStore) ListByPath(ctx context.Context, paths []string) ([]Model, e return sortedModels, nil } -func (s *ModelStore) ByID(ctx context.Context, id int64) (*Model, error) { +func (s *modelStoreImpl) ByID(ctx context.Context, id int64) (*Model, error) { var model Model err := s.db.Core.NewSelect().Model(&model).Relation("Repository").Where("model.id = ?", id).Scan(ctx) if err != nil { @@ -232,7 +249,7 @@ func (s *ModelStore) ByID(ctx context.Context, id int64) (*Model, error) { return &model, err } -func (s *ModelStore) CreateIfNotExist(ctx context.Context, input Model) (*Model, error) { +func (s *modelStoreImpl) CreateIfNotExist(ctx context.Context, input Model) (*Model, error) { err := s.db.Core.NewSelect(). Model(&input). Where("repository_id = ?", input.RepositoryID). diff --git a/builder/store/database/multi_sync.go b/builder/store/database/multi_sync.go index 00b90785..521efba5 100644 --- a/builder/store/database/multi_sync.go +++ b/builder/store/database/multi_sync.go @@ -5,17 +5,26 @@ import ( "fmt" ) -type MultiSyncStore struct { +type multiSyncStoreImpl struct { db *DB } -func NewMultiSyncStore() *MultiSyncStore { - return &MultiSyncStore{ +type MultiSyncStore interface { + Create(ctx context.Context, v SyncVersion) (*SyncVersion, error) + // GetAfter get N records after version in ASC order + GetAfter(ctx context.Context, version, limit int64) ([]SyncVersion, error) + // GetLatest get max sync version + GetLatest(ctx context.Context) (SyncVersion, error) + GetAfterDistinct(ctx context.Context, version int64) ([]SyncVersion, error) +} + +func NewMultiSyncStore() MultiSyncStore { + return &multiSyncStoreImpl{ db: defaultDB, } } -func (s *MultiSyncStore) Create(ctx context.Context, v SyncVersion) (*SyncVersion, error) { +func (s *multiSyncStoreImpl) Create(ctx context.Context, v SyncVersion) (*SyncVersion, error) { res, err := s.db.Core.NewInsert().Model(&v).Exec(ctx, &v) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create sync version in db failed,error:%w", err) @@ -25,7 +34,7 @@ func (s *MultiSyncStore) Create(ctx context.Context, v SyncVersion) (*SyncVersio } // GetAfter get N records after version in ASC order -func (s *MultiSyncStore) GetAfter(ctx context.Context, version, limit int64) ([]SyncVersion, error) { +func (s *multiSyncStoreImpl) GetAfter(ctx context.Context, version, limit int64) ([]SyncVersion, error) { var vs []SyncVersion err := s.db.Core.NewSelect().Model(&vs).Where("version > ?", version). Order("version asc"). @@ -35,7 +44,7 @@ func (s *MultiSyncStore) GetAfter(ctx context.Context, version, limit int64) ([] } // GetLatest get max sync version -func (s *MultiSyncStore) GetLatest(ctx context.Context) (SyncVersion, error) { +func (s *multiSyncStoreImpl) GetLatest(ctx context.Context) (SyncVersion, error) { var v SyncVersion err := s.db.Core.NewSelect().Model(&v). Order("version desc"). @@ -45,7 +54,7 @@ func (s *MultiSyncStore) GetLatest(ctx context.Context) (SyncVersion, error) { return v, err } -func (s *MultiSyncStore) GetAfterDistinct(ctx context.Context, version int64) ([]SyncVersion, error) { +func (s *multiSyncStoreImpl) GetAfterDistinct(ctx context.Context, version int64) ([]SyncVersion, error) { var vs []SyncVersion err := s.db.Core.NewSelect(). ColumnExpr("DISTINCT ON (source_id, repo_path, repo_type) version, source_id, repo_path, repo_type, last_modified_at, change_log"). diff --git a/builder/store/database/namespace.go b/builder/store/database/namespace.go index 2ed376b1..063fa6f6 100644 --- a/builder/store/database/namespace.go +++ b/builder/store/database/namespace.go @@ -4,12 +4,17 @@ import ( "context" ) -type NamespaceStore struct { +type namespaceStoreImpl struct { db *DB } -func NewNamespaceStore() *NamespaceStore { - return &NamespaceStore{db: defaultDB} +type NamespaceStore interface { + FindByPath(ctx context.Context, path string) (namespace Namespace, err error) + Exists(ctx context.Context, path string) (exists bool, err error) +} + +func NewNamespaceStore() NamespaceStore { + return &namespaceStoreImpl{db: defaultDB} } type NamespaceType string @@ -29,13 +34,13 @@ type Namespace struct { times } -func (s *NamespaceStore) FindByPath(ctx context.Context, path string) (namespace Namespace, err error) { +func (s *namespaceStoreImpl) FindByPath(ctx context.Context, path string) (namespace Namespace, err error) { namespace.Path = path err = s.db.Operator.Core.NewSelect().Model(&namespace).Relation("User").Where("path = ?", path).Scan(ctx) return } -func (s *NamespaceStore) Exists(ctx context.Context, path string) (exists bool, err error) { +func (s *namespaceStoreImpl) Exists(ctx context.Context, path string) (exists bool, err error) { var namespace Namespace return s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/organization.go b/builder/store/database/organization.go index 430bf1b5..ede0fda2 100644 --- a/builder/store/database/organization.go +++ b/builder/store/database/organization.go @@ -6,12 +6,22 @@ import ( "github.com/uptrace/bun" ) -type OrgStore struct { +type orgStoreImpl struct { db *DB } -func NewOrgStore() *OrgStore { - return &OrgStore{ +type OrgStore interface { + Create(ctx context.Context, org *Organization, namepace *Namespace) (err error) + GetUserOwnOrgs(ctx context.Context, username string) (orgs []Organization, err error) + Update(ctx context.Context, org *Organization) (err error) + Delete(ctx context.Context, path string) (err error) + FindByPath(ctx context.Context, path string) (org Organization, err error) + Exists(ctx context.Context, path string) (exists bool, err error) + GetUserBelongOrgs(ctx context.Context, userID int64) (orgs []Organization, err error) +} + +func NewOrgStore() OrgStore { + return &orgStoreImpl{ db: defaultDB, } } @@ -34,7 +44,7 @@ type Organization struct { times } -func (s *OrgStore) Create(ctx context.Context, org *Organization, namepace *Namespace) (err error) { +func (s *orgStoreImpl) Create(ctx context.Context, org *Organization, namepace *Namespace) (err error) { err = s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err = assertAffectedOneRow(tx.NewInsert().Model(org).Exec(ctx)); err != nil { return err @@ -48,7 +58,7 @@ func (s *OrgStore) Create(ctx context.Context, org *Organization, namepace *Name return } -func (s *OrgStore) GetUserOwnOrgs(ctx context.Context, username string) (orgs []Organization, err error) { +func (s *orgStoreImpl) GetUserOwnOrgs(ctx context.Context, username string) (orgs []Organization, err error) { query := s.db.Operator.Core. NewSelect(). Model(&orgs). @@ -63,7 +73,7 @@ func (s *OrgStore) GetUserOwnOrgs(ctx context.Context, username string) (orgs [] return } -func (s *OrgStore) Update(ctx context.Context, org *Organization) (err error) { +func (s *orgStoreImpl) Update(ctx context.Context, org *Organization) (err error) { err = assertAffectedOneRow(s.db.Operator.Core. NewUpdate(). Model(org). @@ -72,7 +82,7 @@ func (s *OrgStore) Update(ctx context.Context, org *Organization) (err error) { return } -func (s *OrgStore) Delete(ctx context.Context, path string) (err error) { +func (s *orgStoreImpl) Delete(ctx context.Context, path string) (err error) { err = s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err = assertAffectedOneRow( tx.NewDelete(). @@ -93,7 +103,7 @@ func (s *OrgStore) Delete(ctx context.Context, path string) (err error) { return } -func (s *OrgStore) FindByPath(ctx context.Context, path string) (org Organization, err error) { +func (s *orgStoreImpl) FindByPath(ctx context.Context, path string) (org Organization, err error) { org.Nickname = path err = s.db.Operator.Core. NewSelect(). @@ -103,7 +113,7 @@ func (s *OrgStore) FindByPath(ctx context.Context, path string) (org Organizatio return } -func (s *OrgStore) Exists(ctx context.Context, path string) (exists bool, err error) { +func (s *orgStoreImpl) Exists(ctx context.Context, path string) (exists bool, err error) { var org Organization exists, err = s.db.Operator.Core. NewSelect(). @@ -116,7 +126,7 @@ func (s *OrgStore) Exists(ctx context.Context, path string) (exists bool, err er return } -func (s *OrgStore) GetUserBelongOrgs(ctx context.Context, userID int64) (orgs []Organization, err error) { +func (s *orgStoreImpl) GetUserBelongOrgs(ctx context.Context, userID int64) (orgs []Organization, err error) { err = s.db.Operator.Core. NewSelect(). Model(&orgs). diff --git a/builder/store/database/prompt.go b/builder/store/database/prompt.go index c76dcee0..80891a90 100644 --- a/builder/store/database/prompt.go +++ b/builder/store/database/prompt.go @@ -14,15 +14,26 @@ type Prompt struct { times } -type PromptStore struct { +type promptStoreImpl struct { db *DB } -func NewPromptStore() *PromptStore { - return &PromptStore{db: defaultDB} +type PromptStore interface { + Create(ctx context.Context, input Prompt) (*Prompt, error) + ByRepoIDs(ctx context.Context, repoIDs []int64) (prompts []Prompt, err error) + ByRepoID(ctx context.Context, repoID int64) (*Prompt, error) + Update(ctx context.Context, input Prompt) (err error) + FindByPath(ctx context.Context, namespace string, repoPath string) (*Prompt, error) + Delete(ctx context.Context, input Prompt) error + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) + ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) } -func (s *PromptStore) Create(ctx context.Context, input Prompt) (*Prompt, error) { +func NewPromptStore() PromptStore { + return &promptStoreImpl{db: defaultDB} +} + +func (s *promptStoreImpl) Create(ctx context.Context, input Prompt) (*Prompt, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create prompt in db failed,error:%w", err) @@ -31,7 +42,7 @@ func (s *PromptStore) Create(ctx context.Context, input Prompt) (*Prompt, error) return &input, nil } -func (s *PromptStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (prompts []Prompt, err error) { +func (s *promptStoreImpl) ByRepoIDs(ctx context.Context, repoIDs []int64) (prompts []Prompt, err error) { q := s.db.Operator.Core.NewSelect(). Model(&prompts). Relation("Repository"). @@ -41,7 +52,7 @@ func (s *PromptStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (prompts [ return } -func (s *PromptStore) ByRepoID(ctx context.Context, repoID int64) (*Prompt, error) { +func (s *promptStoreImpl) ByRepoID(ctx context.Context, repoID int64) (*Prompt, error) { var prompt Prompt err := s.db.Operator.Core.NewSelect(). Model(&prompt). @@ -54,12 +65,12 @@ func (s *PromptStore) ByRepoID(ctx context.Context, repoID int64) (*Prompt, erro return &prompt, nil } -func (s *PromptStore) Update(ctx context.Context, input Prompt) (err error) { +func (s *promptStoreImpl) Update(ctx context.Context, input Prompt) (err error) { _, err = s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return } -func (s *PromptStore) FindByPath(ctx context.Context, namespace string, repoPath string) (*Prompt, error) { +func (s *promptStoreImpl) FindByPath(ctx context.Context, namespace string, repoPath string) (*Prompt, error) { resPrompt := new(Prompt) err := s.db.Operator.Core. NewSelect(). @@ -80,7 +91,7 @@ func (s *PromptStore) FindByPath(ctx context.Context, namespace string, repoPath return resPrompt, err } -func (s *PromptStore) Delete(ctx context.Context, input Prompt) error { +func (s *promptStoreImpl) Delete(ctx context.Context, input Prompt) error { res, err := s.db.Operator.Core.NewDelete().Model(&input).WherePK().Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("delete prompt failed,error:%w", err) @@ -88,7 +99,7 @@ func (s *PromptStore) Delete(ctx context.Context, input Prompt) error { return nil } -func (s *PromptStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) { +func (s *promptStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&prompts). @@ -113,7 +124,7 @@ func (s *PromptStore) ByUsername(ctx context.Context, username string, per, page return } -func (s *PromptStore) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) { +func (s *promptStoreImpl) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (prompts []Prompt, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&prompts). diff --git a/builder/store/database/prompt_conversation.go b/builder/store/database/prompt_conversation.go index 87f32138..f2bfa535 100644 --- a/builder/store/database/prompt_conversation.go +++ b/builder/store/database/prompt_conversation.go @@ -7,7 +7,7 @@ import ( "github.com/uptrace/bun" ) -type PromptConversationStore struct { +type promptConversationStoreImpl struct { db *DB } @@ -30,11 +30,22 @@ type PromptConversationMessage struct { times } -func NewPromptConversationStore() *PromptConversationStore { - return &PromptConversationStore{db: defaultDB} +type PromptConversationStore interface { + CreateConversation(ctx context.Context, conversation PromptConversation) error + SaveConversationMessage(ctx context.Context, message PromptConversationMessage) (*PromptConversationMessage, error) + UpdateConversation(ctx context.Context, conversation PromptConversation) error + FindConversationsByUserID(ctx context.Context, userID int64) ([]PromptConversation, error) + GetConversationByID(ctx context.Context, userID int64, uuid string, hasDetail bool) (*PromptConversation, error) + DeleteConversationsByID(ctx context.Context, userID int64, uuid string) error + LikeMessageByID(ctx context.Context, id int64) error + HateMessageByID(ctx context.Context, id int64) error } -func (p *PromptConversationStore) CreateConversation(ctx context.Context, conversation PromptConversation) error { +func NewPromptConversationStore() PromptConversationStore { + return &promptConversationStoreImpl{db: defaultDB} +} + +func (p *promptConversationStoreImpl) CreateConversation(ctx context.Context, conversation PromptConversation) error { err := p.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err := assertAffectedOneRow(tx.NewInsert().Model(&conversation).Exec(ctx)); err != nil { return fmt.Errorf("save conversation, %v, error:%w", conversation, err) @@ -44,7 +55,7 @@ func (p *PromptConversationStore) CreateConversation(ctx context.Context, conver return err } -func (p *PromptConversationStore) SaveConversationMessage(ctx context.Context, message PromptConversationMessage) (*PromptConversationMessage, error) { +func (p *promptConversationStoreImpl) SaveConversationMessage(ctx context.Context, message PromptConversationMessage) (*PromptConversationMessage, error) { res, err := p.db.Core.NewInsert().Model(&message).Exec(ctx, &message) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("insert message, %v, error:%w", message, err) @@ -52,7 +63,7 @@ func (p *PromptConversationStore) SaveConversationMessage(ctx context.Context, m return &message, nil } -func (p *PromptConversationStore) UpdateConversation(ctx context.Context, conversation PromptConversation) error { +func (p *promptConversationStoreImpl) UpdateConversation(ctx context.Context, conversation PromptConversation) error { res, err := p.db.Core.NewUpdate().Model(&conversation). Where("user_id = ?", conversation.UserID). Where("conversation_id = ?", conversation.ConversationID). @@ -63,7 +74,7 @@ func (p *PromptConversationStore) UpdateConversation(ctx context.Context, conver return nil } -func (p *PromptConversationStore) FindConversationsByUserID(ctx context.Context, userID int64) ([]PromptConversation, error) { +func (p *promptConversationStoreImpl) FindConversationsByUserID(ctx context.Context, userID int64) ([]PromptConversation, error) { var conversations []PromptConversation err := p.db.Operator.Core.NewSelect().Model(&conversations).Where("user_id = ?", userID).Order("id desc").Scan(ctx) if err != nil { @@ -72,7 +83,7 @@ func (p *PromptConversationStore) FindConversationsByUserID(ctx context.Context, return conversations, nil } -func (p *PromptConversationStore) GetConversationByID(ctx context.Context, userID int64, uuid string, hasDetail bool) (*PromptConversation, error) { +func (p *promptConversationStoreImpl) GetConversationByID(ctx context.Context, userID int64, uuid string, hasDetail bool) (*PromptConversation, error) { var conversation PromptConversation q := p.db.Operator.Core.NewSelect().Model(&conversation) if hasDetail { @@ -85,7 +96,7 @@ func (p *PromptConversationStore) GetConversationByID(ctx context.Context, userI return &conversation, nil } -func (p *PromptConversationStore) DeleteConversationsByID(ctx context.Context, userID int64, uuid string) error { +func (p *promptConversationStoreImpl) DeleteConversationsByID(ctx context.Context, userID int64, uuid string) error { err := p.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { res, err := tx.NewDelete().Model(&PromptConversation{}).Where("user_id = ? and conversation_id = ?", userID, uuid).Exec(ctx) err = assertAffectedOneRow(res, err) @@ -102,7 +113,7 @@ func (p *PromptConversationStore) DeleteConversationsByID(ctx context.Context, u return err } -func (p *PromptConversationStore) LikeMessageByID(ctx context.Context, id int64) error { +func (p *promptConversationStoreImpl) LikeMessageByID(ctx context.Context, id int64) error { res, err := p.db.BunDB.Exec("update prompt_conversation_messages set user_like=NOT user_like where id = ?", id) if err != nil { return err @@ -111,7 +122,7 @@ func (p *PromptConversationStore) LikeMessageByID(ctx context.Context, id int64) return err } -func (p *PromptConversationStore) HateMessageByID(ctx context.Context, id int64) error { +func (p *promptConversationStoreImpl) HateMessageByID(ctx context.Context, id int64) error { res, err := p.db.BunDB.Exec("update prompt_conversation_messages set user_hate=NOT user_hate where id = ?", id) if err != nil { return err diff --git a/builder/store/database/prompt_prefix.go b/builder/store/database/prompt_prefix.go index c7940f9b..03f14e3f 100644 --- a/builder/store/database/prompt_prefix.go +++ b/builder/store/database/prompt_prefix.go @@ -5,7 +5,7 @@ import ( "fmt" ) -type PromptPrefixStore struct { +type promptPrefixStoreImpl struct { db *DB } @@ -15,11 +15,15 @@ type PromptPrefix struct { EN string `bun:",notnull" json:"en"` } -func NewPromptPrefixStore() *PromptPrefixStore { - return &PromptPrefixStore{db: defaultDB} +type PromptPrefixStore interface { + Get(ctx context.Context) (*PromptPrefix, error) } -func (p *PromptPrefixStore) Get(ctx context.Context) (*PromptPrefix, error) { +func NewPromptPrefixStore() PromptPrefixStore { + return &promptPrefixStoreImpl{db: defaultDB} +} + +func (p *promptPrefixStoreImpl) Get(ctx context.Context) (*PromptPrefix, error) { var prefix PromptPrefix err := p.db.Operator.Core.NewSelect().Model(&prefix).Order("id desc").Limit(1).Scan(ctx) if err != nil { diff --git a/builder/store/database/recom.go b/builder/store/database/recom.go index 9b78c848..5535f19e 100644 --- a/builder/store/database/recom.go +++ b/builder/store/database/recom.go @@ -2,18 +2,28 @@ package database import "context" -type RecomStore struct { +type recomStoreImpl struct { db *DB } -func NewRecomStore() *RecomStore { - return &RecomStore{ +type RecomStore interface { + // Index returns repos in descend order of score. + Index(ctx context.Context, page, pageSize int) ([]*RecomRepoScore, error) + // Upsert recom repo score + UpsertScore(ctx context.Context, repoID int64, score float64) error + LoadWeights(ctx context.Context) ([]*RecomWeight, error) + LoadOpWeights(ctx context.Context) ([]*RecomOpWeight, error) + UpsetOpWeights(ctx context.Context, repoID, weight int64) error +} + +func NewRecomStore() RecomStore { + return &recomStoreImpl{ db: defaultDB, } } // Index returns repos in descend order of score. -func (s *RecomStore) Index(ctx context.Context, page, pageSize int) ([]*RecomRepoScore, error) { +func (s *recomStoreImpl) Index(ctx context.Context, page, pageSize int) ([]*RecomRepoScore, error) { items := make([]*RecomRepoScore, 0) err := s.db.Operator.Core.NewSelect().Model(&RecomRepoScore{}). Order("score desc"). @@ -23,7 +33,7 @@ func (s *RecomStore) Index(ctx context.Context, page, pageSize int) ([]*RecomRep } // Upsert recom repo score -func (s *RecomStore) UpsertScore(ctx context.Context, repoID int64, score float64) error { +func (s *recomStoreImpl) UpsertScore(ctx context.Context, repoID int64, score float64) error { _, err := s.db.Operator.Core.NewInsert(). Model(&RecomRepoScore{ RepositoryID: repoID, @@ -34,19 +44,19 @@ func (s *RecomStore) UpsertScore(ctx context.Context, repoID int64, score float6 return err } -func (s *RecomStore) LoadWeights(ctx context.Context) ([]*RecomWeight, error) { +func (s *recomStoreImpl) LoadWeights(ctx context.Context) ([]*RecomWeight, error) { weights := make([]*RecomWeight, 0) err := s.db.Operator.Core.NewSelect().Model(&RecomWeight{}).Scan(ctx, &weights) return weights, err } -func (s *RecomStore) LoadOpWeights(ctx context.Context) ([]*RecomOpWeight, error) { +func (s *recomStoreImpl) LoadOpWeights(ctx context.Context) ([]*RecomOpWeight, error) { weights := make([]*RecomOpWeight, 0) err := s.db.Operator.Core.NewSelect().Model(&RecomOpWeight{}).Scan(ctx, &weights) return weights, err } -func (s *RecomStore) UpsetOpWeights(ctx context.Context, repoID, weight int64) error { +func (s *recomStoreImpl) UpsetOpWeights(ctx context.Context, repoID, weight int64) error { _, err := s.db.Core.NewInsert(). Model(&RecomOpWeight{ RepositoryID: repoID, diff --git a/builder/store/database/repo_relation.go b/builder/store/database/repo_relation.go index fbaffbde..595c5d9d 100644 --- a/builder/store/database/repo_relation.go +++ b/builder/store/database/repo_relation.go @@ -5,12 +5,25 @@ import ( "fmt" ) -type RepoRelationsStore struct { +type repoRelationsStoreImpl struct { db *DB } -func NewRepoRelationsStore() *RepoRelationsStore { - return &RepoRelationsStore{ +type RepoRelationsStore interface { + // From gets the relationships from a repository + From(ctx context.Context, repoID int64) ([]*RepoRelation, error) + // To gets the relationships to a repository + To(ctx context.Context, repoID int64) ([]*RepoRelation, error) + // Override replaces all existing relationships from a repository to others + // + // `to` can be empty, in which case all existing relationships will be deleted + Override(ctx context.Context, from int64, to ...int64) error + // Delete removes a relationship from a repository to another + Delete(ctx context.Context, from, to int64) error +} + +func NewRepoRelationsStore() RepoRelationsStore { + return &repoRelationsStoreImpl{ db: defaultDB, } } @@ -22,14 +35,14 @@ type RepoRelation struct { } // From gets the relationships from a repository -func (r *RepoRelationsStore) From(ctx context.Context, repoID int64) ([]*RepoRelation, error) { +func (r *repoRelationsStoreImpl) From(ctx context.Context, repoID int64) ([]*RepoRelation, error) { var rrs []*RepoRelation err := r.db.Core.NewSelect().Model(&rrs).Where("from_repo_id = ?", repoID).Scan(ctx) return rrs, err } // To gets the relationships to a repository -func (r *RepoRelationsStore) To(ctx context.Context, repoID int64) ([]*RepoRelation, error) { +func (r *repoRelationsStoreImpl) To(ctx context.Context, repoID int64) ([]*RepoRelation, error) { var rrs []*RepoRelation err := r.db.Core.NewSelect().Model(&rrs).Where("to_repo_id = ?", repoID).Scan(ctx) return rrs, err @@ -38,7 +51,7 @@ func (r *RepoRelationsStore) To(ctx context.Context, repoID int64) ([]*RepoRelat // Override replaces all existing relationships from a repository to others // // `to` can be empty, in which case all existing relationships will be deleted -func (r *RepoRelationsStore) Override(ctx context.Context, from int64, to ...int64) error { +func (r *repoRelationsStoreImpl) Override(ctx context.Context, from int64, to ...int64) error { var relations []*RepoRelation for _, toRepoID := range to { relations = append(relations, &RepoRelation{ @@ -70,7 +83,7 @@ func (r *RepoRelationsStore) Override(ctx context.Context, from int64, to ...int } // Delete removes a relationship from a repository to another -func (r *RepoRelationsStore) Delete(ctx context.Context, from, to int64) error { +func (r *repoRelationsStoreImpl) Delete(ctx context.Context, from, to int64) error { result, err := r.db.Core.NewDelete(). Model((*RepoRelation)(nil)). Where("from_repo_id = ? and to_repo_id = ?", from, to). diff --git a/builder/store/database/repository.go b/builder/store/database/repository.go index 81aea84d..1d49165c 100644 --- a/builder/store/database/repository.go +++ b/builder/store/database/repository.go @@ -19,12 +19,52 @@ var RepositorySourceAndPrefixMapping = map[types.RepositorySource]string{ types.LocalSource: "", } -type RepoStore struct { +type repoStoreImpl struct { db *DB } -func NewRepoStore() *RepoStore { - return &RepoStore{ +type RepoStore interface { + CreateRepoTx(ctx context.Context, tx bun.Tx, input Repository) (*Repository, error) + CreateRepo(ctx context.Context, input Repository) (*Repository, error) + UpdateRepo(ctx context.Context, input Repository) (*Repository, error) + DeleteRepo(ctx context.Context, input Repository) error + Find(ctx context.Context, owner, repoType, repoName string) (*Repository, error) + FindById(ctx context.Context, id int64) (*Repository, error) + FindByIds(ctx context.Context, ids []int64, opts ...SelectOption) ([]*Repository, error) + FindByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Repository, error) + FindByGitPath(ctx context.Context, path string) (*Repository, error) + FindByGitPaths(ctx context.Context, paths []string, opts ...SelectOption) ([]*Repository, error) + Exists(ctx context.Context, repoType types.RepositoryType, namespace string, name string) (bool, error) + All(ctx context.Context) ([]*Repository, error) + UpdateRepoFileDownloads(ctx context.Context, repo *Repository, date time.Time, clickDownloadCount int64) (err error) + UpdateRepoCloneDownloads(ctx context.Context, repo *Repository, date time.Time, cloneCount int64) (err error) + UpdateDownloads(ctx context.Context, repo *Repository) error + Tags(ctx context.Context, repoID int64) (tags []Tag, err error) + TagsWithCategory(ctx context.Context, repoID int64, category string) (tags []Tag, err error) + // TagIDs get tag ids by repo id, if category is not empty, return only tags of the category + TagIDs(ctx context.Context, repoID int64, category string) (tagIDs []int64, err error) + SetUpdateTimeByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string, update time.Time) error + PublicToUser(ctx context.Context, repoType types.RepositoryType, userIDs []int64, filter *types.RepoFilter, per, page int) (repos []*Repository, count int, err error) + IsMirrorRepo(ctx context.Context, repoType types.RepositoryType, namespace, name string) (bool, error) + ListRepoPublicToUserByRepoIDs(ctx context.Context, repoType types.RepositoryType, userID int64, search, sort string, per, page int, repoIDs []int64) (repos []*Repository, count int, err error) + WithMirror(ctx context.Context, per, page int) (repos []Repository, count int, err error) + CleanRelationsByRepoID(ctx context.Context, repoId int64) error + BatchCreateRepoTags(ctx context.Context, repoTags []RepositoryTag) error + DeleteAllFiles(ctx context.Context, repoID int64) error + DeleteAllTags(ctx context.Context, repoID int64) error + UpdateOrCreateRepo(ctx context.Context, input Repository) (*Repository, error) + UpdateLicenseByTag(ctx context.Context, repoID int64) error + CountByRepoType(ctx context.Context, repoType types.RepositoryType) (int, error) + GetRepoWithoutRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) + GetRepoWithRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) + BatchGet(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, batch int) ([]Repository, error) + FindWithBatch(ctx context.Context, batchSize, batch int) ([]Repository, error) + FindByRepoSourceWithBatch(ctx context.Context, repoSource types.RepositorySource, batchSize, batch int) ([]Repository, error) + ByUser(ctx context.Context, userID int64) ([]Repository, error) +} + +func NewRepoStore() RepoStore { + return &repoStoreImpl{ db: defaultDB, } } @@ -86,7 +126,7 @@ func (r Repository) PathWithOutPrefix() string { } -func (s *RepoStore) CreateRepoTx(ctx context.Context, tx bun.Tx, input Repository) (*Repository, error) { +func (s *repoStoreImpl) CreateRepoTx(ctx context.Context, tx bun.Tx, input Repository) (*Repository, error) { res, err := tx.NewInsert().Model(&input).Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create repository in tx failed,error:%w", err) @@ -95,7 +135,7 @@ func (s *RepoStore) CreateRepoTx(ctx context.Context, tx bun.Tx, input Repositor return &input, nil } -func (s *RepoStore) CreateRepo(ctx context.Context, input Repository) (*Repository, error) { +func (s *repoStoreImpl) CreateRepo(ctx context.Context, input Repository) (*Repository, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create repository in tx failed,error:%w", err) @@ -104,19 +144,19 @@ func (s *RepoStore) CreateRepo(ctx context.Context, input Repository) (*Reposito return &input, nil } -func (s *RepoStore) UpdateRepo(ctx context.Context, input Repository) (*Repository, error) { +func (s *repoStoreImpl) UpdateRepo(ctx context.Context, input Repository) (*Repository, error) { _, err := s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return &input, err } -func (s *RepoStore) DeleteRepo(ctx context.Context, input Repository) error { +func (s *repoStoreImpl) DeleteRepo(ctx context.Context, input Repository) error { _, err := s.db.Core.NewDelete().Model(&input).WherePK().Exec(ctx) return err } -func (s *RepoStore) Find(ctx context.Context, owner, repoType, repoName string) (*Repository, error) { +func (s *repoStoreImpl) Find(ctx context.Context, owner, repoType, repoName string) (*Repository, error) { var err error repo := &Repository{} err = s.db.Operator.Core. @@ -128,7 +168,7 @@ func (s *RepoStore) Find(ctx context.Context, owner, repoType, repoName string) return repo, err } -func (s *RepoStore) FindById(ctx context.Context, id int64) (*Repository, error) { +func (s *repoStoreImpl) FindById(ctx context.Context, id int64) (*Repository, error) { resRepo := new(Repository) err := s.db.Operator.Core. NewSelect(). @@ -138,7 +178,7 @@ func (s *RepoStore) FindById(ctx context.Context, id int64) (*Repository, error) return resRepo, err } -func (s *RepoStore) FindByIds(ctx context.Context, ids []int64, opts ...SelectOption) ([]*Repository, error) { +func (s *repoStoreImpl) FindByIds(ctx context.Context, ids []int64, opts ...SelectOption) ([]*Repository, error) { repos := make([]*Repository, 0) q := s.db.Operator.Core. NewSelect() @@ -152,7 +192,7 @@ func (s *RepoStore) FindByIds(ctx context.Context, ids []int64, opts ...SelectOp return repos, err } -func (s *RepoStore) FindByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Repository, error) { +func (s *repoStoreImpl) FindByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Repository, error) { resRepo := new(Repository) err := s.db.Operator.Core. NewSelect(). @@ -166,7 +206,7 @@ func (s *RepoStore) FindByPath(ctx context.Context, repoType types.RepositoryTyp return resRepo, err } -func (s *RepoStore) FindByGitPath(ctx context.Context, path string) (*Repository, error) { +func (s *repoStoreImpl) FindByGitPath(ctx context.Context, path string) (*Repository, error) { resRepo := new(Repository) err := s.db.Operator.Core. NewSelect(). @@ -176,7 +216,7 @@ func (s *RepoStore) FindByGitPath(ctx context.Context, path string) (*Repository return resRepo, err } -func (s *RepoStore) FindByGitPaths(ctx context.Context, paths []string, opts ...SelectOption) ([]*Repository, error) { +func (s *repoStoreImpl) FindByGitPaths(ctx context.Context, paths []string, opts ...SelectOption) ([]*Repository, error) { for i := range paths { paths[i] = strings.ToLower(paths[i]) } @@ -192,13 +232,13 @@ func (s *RepoStore) FindByGitPaths(ctx context.Context, paths []string, opts ... return repos, err } -func (s *RepoStore) Exists(ctx context.Context, repoType types.RepositoryType, namespace string, name string) (bool, error) { +func (s *repoStoreImpl) Exists(ctx context.Context, repoType types.RepositoryType, namespace string, name string) (bool, error) { return s.db.Operator.Core.NewSelect().Model((*Repository)(nil)). Where("LOWER(git_path) = LOWER(?)", fmt.Sprintf("%ss_%s/%s", repoType, namespace, name)). Exists(ctx) } -func (s *RepoStore) All(ctx context.Context) ([]*Repository, error) { +func (s *repoStoreImpl) All(ctx context.Context) ([]*Repository, error) { repos := make([]*Repository, 0) err := s.db.Operator.Core. NewSelect(). @@ -207,7 +247,7 @@ func (s *RepoStore) All(ctx context.Context) ([]*Repository, error) { return repos, err } -func (s *RepoStore) UpdateRepoFileDownloads(ctx context.Context, repo *Repository, date time.Time, clickDownloadCount int64) (err error) { +func (s *repoStoreImpl) UpdateRepoFileDownloads(ctx context.Context, repo *Repository, date time.Time, clickDownloadCount int64) (err error) { rd := new(RepositoryDownload) err = s.db.Operator.Core.NewSelect(). Model(rd). @@ -247,7 +287,7 @@ func (s *RepoStore) UpdateRepoFileDownloads(ctx context.Context, repo *Repositor return } -func (s *RepoStore) UpdateRepoCloneDownloads(ctx context.Context, repo *Repository, date time.Time, cloneCount int64) (err error) { +func (s *repoStoreImpl) UpdateRepoCloneDownloads(ctx context.Context, repo *Repository, date time.Time, cloneCount int64) (err error) { rd := new(RepositoryDownload) err = s.db.Operator.Core.NewSelect(). Model(rd). @@ -287,7 +327,7 @@ func (s *RepoStore) UpdateRepoCloneDownloads(ctx context.Context, repo *Reposito return } -func (s *RepoStore) UpdateDownloads(ctx context.Context, repo *Repository) error { +func (s *repoStoreImpl) UpdateDownloads(ctx context.Context, repo *Repository) error { var downloadCount int64 err := s.db.Operator.Core.NewSelect(). ColumnExpr("(SUM(clone_count)+SUM(click_download_count)) AS total_count"). @@ -309,7 +349,7 @@ func (s *RepoStore) UpdateDownloads(ctx context.Context, repo *Repository) error return nil } -func (s *RepoStore) Tags(ctx context.Context, repoID int64) (tags []Tag, err error) { +func (s *repoStoreImpl) Tags(ctx context.Context, repoID int64) (tags []Tag, err error) { query := s.db.Operator.Core.NewSelect(). ColumnExpr("tags.*"). Model(&RepositoryTag{}). @@ -320,7 +360,7 @@ func (s *RepoStore) Tags(ctx context.Context, repoID int64) (tags []Tag, err err return } -func (s *RepoStore) TagsWithCategory(ctx context.Context, repoID int64, category string) (tags []Tag, err error) { +func (s *repoStoreImpl) TagsWithCategory(ctx context.Context, repoID int64, category string) (tags []Tag, err error) { query := s.db.Operator.Core.NewSelect(). ColumnExpr("tags.*"). Model(&RepositoryTag{}). @@ -333,7 +373,7 @@ func (s *RepoStore) TagsWithCategory(ctx context.Context, repoID int64, category } // TagIDs get tag ids by repo id, if category is not empty, return only tags of the category -func (s *RepoStore) TagIDs(ctx context.Context, repoID int64, category string) (tagIDs []int64, err error) { +func (s *repoStoreImpl) TagIDs(ctx context.Context, repoID int64, category string) (tagIDs []int64, err error) { query := s.db.Operator.Core.NewSelect(). Model(&RepositoryTag{}). Join("JOIN tags ON repository_tag.tag_id = tags.id"). @@ -346,7 +386,7 @@ func (s *RepoStore) TagIDs(ctx context.Context, repoID int64, category string) ( return tagIDs, err } -func (s *RepoStore) SetUpdateTimeByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string, update time.Time) error { +func (s *repoStoreImpl) SetUpdateTimeByPath(ctx context.Context, repoType types.RepositoryType, namespace, name string, update time.Time) error { repo := new(Repository) repo.UpdatedAt = update _, err := s.db.Operator.Core.NewUpdate().Model(repo). @@ -356,7 +396,7 @@ func (s *RepoStore) SetUpdateTimeByPath(ctx context.Context, repoType types.Repo return err } -func (s *RepoStore) PublicToUser(ctx context.Context, repoType types.RepositoryType, userIDs []int64, filter *types.RepoFilter, per, page int) (repos []*Repository, count int, err error) { +func (s *repoStoreImpl) PublicToUser(ctx context.Context, repoType types.RepositoryType, userIDs []int64, filter *types.RepoFilter, per, page int) (repos []*Repository, count int, err error) { q := s.db.Operator.Core. NewSelect(). Column("repository.*"). @@ -409,7 +449,7 @@ func (s *RepoStore) PublicToUser(ctx context.Context, repoType types.RepositoryT return } -func (s *RepoStore) IsMirrorRepo(ctx context.Context, repoType types.RepositoryType, namespace, name string) (bool, error) { +func (s *repoStoreImpl) IsMirrorRepo(ctx context.Context, repoType types.RepositoryType, namespace, name string) (bool, error) { var result struct { Exists bool `bun:"exists"` } @@ -427,7 +467,7 @@ func (s *RepoStore) IsMirrorRepo(ctx context.Context, repoType types.RepositoryT return result.Exists, nil } -func (s *RepoStore) ListRepoPublicToUserByRepoIDs(ctx context.Context, repoType types.RepositoryType, userID int64, search, sort string, per, page int, repoIDs []int64) (repos []*Repository, count int, err error) { +func (s *repoStoreImpl) ListRepoPublicToUserByRepoIDs(ctx context.Context, repoType types.RepositoryType, userID int64, search, sort string, per, page int, repoIDs []int64) (repos []*Repository, count int, err error) { q := s.db.Operator.Core. NewSelect(). Column("repository.*"). @@ -474,7 +514,7 @@ func (s *RepoStore) ListRepoPublicToUserByRepoIDs(ctx context.Context, repoType return } -func (s *RepoStore) WithMirror(ctx context.Context, per, page int) (repos []Repository, count int, err error) { +func (s *repoStoreImpl) WithMirror(ctx context.Context, per, page int) (repos []Repository, count int, err error) { q := s.db.Operator.Core.NewSelect(). Model(&repos). Relation("Mirror"). @@ -494,7 +534,7 @@ func (s *RepoStore) WithMirror(ctx context.Context, per, page int) (repos []Repo return } -func (s *RepoStore) CleanRelationsByRepoID(ctx context.Context, repoId int64) error { +func (s *repoStoreImpl) CleanRelationsByRepoID(ctx context.Context, repoId int64) error { err := s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.Exec("delete from repositories_runtime_frameworks where repo_id=?", repoId); err != nil { return err @@ -508,7 +548,7 @@ func (s *RepoStore) CleanRelationsByRepoID(ctx context.Context, repoId int64) er return err } -func (s *RepoStore) BatchCreateRepoTags(ctx context.Context, repoTags []RepositoryTag) error { +func (s *repoStoreImpl) BatchCreateRepoTags(ctx context.Context, repoTags []RepositoryTag) error { result, err := s.db.Operator.Core.NewInsert(). Model(&repoTags). Exec(ctx) @@ -519,7 +559,7 @@ func (s *RepoStore) BatchCreateRepoTags(ctx context.Context, repoTags []Reposito return assertAffectedXRows(int64(len(repoTags)), result, err) } -func (s *RepoStore) DeleteAllFiles(ctx context.Context, repoID int64) error { +func (s *repoStoreImpl) DeleteAllFiles(ctx context.Context, repoID int64) error { err := s.db.Operator.Core.NewDelete(). Model(&File{}). Where("repository_id = ?", repoID). @@ -531,7 +571,7 @@ func (s *RepoStore) DeleteAllFiles(ctx context.Context, repoID int64) error { return nil } -func (s *RepoStore) DeleteAllTags(ctx context.Context, repoID int64) error { +func (s *repoStoreImpl) DeleteAllTags(ctx context.Context, repoID int64) error { err := s.db.Operator.Core.NewDelete(). Model(&RepositoryTag{}). Where("repository_id = ?", repoID). @@ -543,7 +583,7 @@ func (s *RepoStore) DeleteAllTags(ctx context.Context, repoID int64) error { return nil } -func (s *RepoStore) UpdateOrCreateRepo(ctx context.Context, input Repository) (*Repository, error) { +func (s *repoStoreImpl) UpdateOrCreateRepo(ctx context.Context, input Repository) (*Repository, error) { input.UpdatedAt = time.Now() _, err := s.db.Core.NewUpdate(). Model(&input). @@ -562,7 +602,7 @@ func (s *RepoStore) UpdateOrCreateRepo(ctx context.Context, input Repository) (* return &input, nil } -func (s *RepoStore) UpdateLicenseByTag(ctx context.Context, repoID int64) error { +func (s *repoStoreImpl) UpdateLicenseByTag(ctx context.Context, repoID int64) error { var tag Tag err := s.db.Core.NewSelect(). Model(&tag). @@ -587,11 +627,11 @@ func (s *RepoStore) UpdateLicenseByTag(ctx context.Context, repoID int64) error return nil } -func (s *RepoStore) CountByRepoType(ctx context.Context, repoType types.RepositoryType) (int, error) { +func (s *repoStoreImpl) CountByRepoType(ctx context.Context, repoType types.RepositoryType) (int, error) { return s.db.Core.NewSelect().Model(&Repository{}).Where("repository_type = ?", repoType).Count(ctx) } -func (s *RepoStore) GetRepoWithoutRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) { +func (s *repoStoreImpl) GetRepoWithoutRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) { var res []Repository q := s.db.Operator.Core.NewSelect().Model(&res) if len(paths) > 0 { @@ -606,7 +646,7 @@ func (s *RepoStore) GetRepoWithoutRuntimeByID(ctx context.Context, rfID int64, p return res, nil } -func (s *RepoStore) GetRepoWithRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) { +func (s *repoStoreImpl) GetRepoWithRuntimeByID(ctx context.Context, rfID int64, paths []string) ([]Repository, error) { var res []Repository q := s.db.Operator.Core.NewSelect().Model(&res) if len(paths) > 0 { @@ -621,7 +661,7 @@ func (s *RepoStore) GetRepoWithRuntimeByID(ctx context.Context, rfID int64, path return res, nil } -func (s *RepoStore) BatchGet(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, batch int) ([]Repository, error) { +func (s *repoStoreImpl) BatchGet(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, batch int) ([]Repository, error) { var res []Repository q := s.db.Operator.Core.NewSelect().Model(&res) if lastRepoID > 0 { @@ -637,7 +677,7 @@ func (s *RepoStore) BatchGet(ctx context.Context, repoType types.RepositoryType, return res, nil } -func (s *RepoStore) FindWithBatch(ctx context.Context, batchSize, batch int) ([]Repository, error) { +func (s *repoStoreImpl) FindWithBatch(ctx context.Context, batchSize, batch int) ([]Repository, error) { var res []Repository err := s.db.Operator.Core.NewSelect(). Model(&res). @@ -648,7 +688,7 @@ func (s *RepoStore) FindWithBatch(ctx context.Context, batchSize, batch int) ([] return res, err } -func (s *RepoStore) FindByRepoSourceWithBatch(ctx context.Context, repoSource types.RepositorySource, batchSize, batch int) ([]Repository, error) { +func (s *repoStoreImpl) FindByRepoSourceWithBatch(ctx context.Context, repoSource types.RepositorySource, batchSize, batch int) ([]Repository, error) { var res []Repository err := s.db.Operator.Core.NewSelect(). Model(&res). @@ -660,7 +700,7 @@ func (s *RepoStore) FindByRepoSourceWithBatch(ctx context.Context, repoSource ty return res, err } -func (s *RepoStore) ByUser(ctx context.Context, userID int64) ([]Repository, error) { +func (s *repoStoreImpl) ByUser(ctx context.Context, userID int64) ([]Repository, error) { var repos []Repository err := s.db.Operator.Core.NewSelect().Model(&repos).Where("user_id = ?", userID).Scan(ctx) return repos, err diff --git a/builder/store/database/repository_file.go b/builder/store/database/repository_file.go index f1a241fe..4c268310 100644 --- a/builder/store/database/repository_file.go +++ b/builder/store/database/repository_file.go @@ -21,22 +21,30 @@ type RepositoryFile struct { Repository *Repository `bun:"rel:belongs-to,join:repository_id=id"` } -type RepoFileStore struct { +type repoFileStoreImpl struct { db *DB } -func NewRepoFileStore() *RepoFileStore { - return &RepoFileStore{ +type RepoFileStore interface { + Create(ctx context.Context, file *RepositoryFile) error + BatchGet(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) + BatchGetUnchcked(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) + Exists(ctx context.Context, file RepositoryFile) (bool, error) + ExistsSensitiveCheckRecord(ctx context.Context, repoID int64, branch string, status types.SensitiveCheckStatus) (bool, error) +} + +func NewRepoFileStore() RepoFileStore { + return &repoFileStoreImpl{ db: defaultDB, } } -func (s *RepoFileStore) Create(ctx context.Context, file *RepositoryFile) error { +func (s *repoFileStoreImpl) Create(ctx context.Context, file *RepositoryFile) error { _, err := s.db.Operator.Core.NewInsert().Model(file).Exec(ctx) return err } -func (s *RepoFileStore) BatchGet(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) { +func (s *repoFileStoreImpl) BatchGet(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) { files := make([]*RepositoryFile, 0, batch) err := s.db.Operator.Core.NewSelect(). Model(&files). @@ -49,7 +57,7 @@ func (s *RepoFileStore) BatchGet(ctx context.Context, repoID, lastRepoFileID, ba return files, err } -func (s *RepoFileStore) BatchGetUnchcked(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) { +func (s *repoFileStoreImpl) BatchGetUnchcked(ctx context.Context, repoID, lastRepoFileID, batch int64) ([]*RepositoryFile, error) { files := make([]*RepositoryFile, 0, batch) err := s.db.Operator.Core.NewSelect(). Model(&files). @@ -63,14 +71,14 @@ func (s *RepoFileStore) BatchGetUnchcked(ctx context.Context, repoID, lastRepoFi return files, err } -func (s *RepoFileStore) Exists(ctx context.Context, file RepositoryFile) (bool, error) { +func (s *repoFileStoreImpl) Exists(ctx context.Context, file RepositoryFile) (bool, error) { slog.Debug("file", slog.Any("file", file)) return s.db.Operator.Core.NewSelect().Model(&file). Where("path = ? and repository_id = ? and branch = ? and COALESCE(commit_sha, '') = ?", file.Path, file.RepositoryID, file.Branch, file.CommitSha). Exists(ctx) } -func (s *RepoFileStore) ExistsSensitiveCheckRecord(ctx context.Context, repoID int64, branch string, status types.SensitiveCheckStatus) (bool, error) { +func (s *repoFileStoreImpl) ExistsSensitiveCheckRecord(ctx context.Context, repoID int64, branch string, status types.SensitiveCheckStatus) (bool, error) { return s.db.Operator.Core.NewSelect().Model(&RepositoryFileCheck{}). Join("INNER JOIN repository_files rf ON rf.id = repository_file_check.repo_file_id"). Where("rf.repository_id = ? and rf.branch = ? and repository_file_check.status = ?", repoID, branch, status). diff --git a/builder/store/database/repository_file_check.go b/builder/store/database/repository_file_check.go index e9e9b7ba..09d82b31 100644 --- a/builder/store/database/repository_file_check.go +++ b/builder/store/database/repository_file_check.go @@ -18,22 +18,27 @@ type RepositoryFileCheck struct { TaskID string `bun:",nullzero" json:"task_id"` } -type RepoFileCheckStore struct { +type repoFileCheckStoreImpl struct { db *DB } -func NewRepoFileCheckStore() *RepoFileCheckStore { - return &RepoFileCheckStore{ +type RepoFileCheckStore interface { + Create(ctx context.Context, history RepositoryFileCheck) error + Upsert(ctx context.Context, history RepositoryFileCheck) error +} + +func NewRepoFileCheckStore() RepoFileCheckStore { + return &repoFileCheckStoreImpl{ db: defaultDB, } } -func (s *RepoFileCheckStore) Create(ctx context.Context, history RepositoryFileCheck) error { +func (s *repoFileCheckStoreImpl) Create(ctx context.Context, history RepositoryFileCheck) error { _, err := s.db.Operator.Core.NewInsert().Model(&history).Exec(ctx) return err } -func (s *RepoFileCheckStore) Upsert(ctx context.Context, history RepositoryFileCheck) error { +func (s *repoFileCheckStoreImpl) Upsert(ctx context.Context, history RepositoryFileCheck) error { _, err := s.db.Operator.Core.NewInsert().Model(&history). On("CONFLICT (repo_file_id) DO UPDATE"). Exec(ctx) diff --git a/builder/store/database/resources_models.go b/builder/store/database/resources_models.go index 8fb893c1..02348331 100644 --- a/builder/store/database/resources_models.go +++ b/builder/store/database/resources_models.go @@ -4,12 +4,19 @@ import ( "context" ) -type ResourceModelStore struct { +type resourceModelStoreImpl struct { db *DB } -func NewResourceModelStore() *ResourceModelStore { - return &ResourceModelStore{db: defaultDB} +type ResourceModelStore interface { + // find multi Resource model by model name with fuzzy matching, parameter modelName like model_name in db + FindByModelName(ctx context.Context, modelName string) ([]*ResourceModel, error) + // find model by name which is in resource model table but not in runtime framework repo + CheckModelNameNotInRFRepo(ctx context.Context, modelName string, repoId int64) (*ResourceModel, error) +} + +func NewResourceModelStore() ResourceModelStore { + return &resourceModelStoreImpl{db: defaultDB} } type ResourceModel struct { @@ -22,14 +29,14 @@ type ResourceModel struct { } // find multi Resource model by model name with fuzzy matching, parameter modelName like model_name in db -func (s *ResourceModelStore) FindByModelName(ctx context.Context, modelName string) ([]*ResourceModel, error) { +func (s *resourceModelStoreImpl) FindByModelName(ctx context.Context, modelName string) ([]*ResourceModel, error) { var models []*ResourceModel err := s.db.Core.NewSelect().Model(&models).Where("model_name LIKE ?", "%"+modelName+"%").Scan(ctx) return models, err } // find model by name which is in resource model table but not in runtime framework repo -func (s *ResourceModelStore) CheckModelNameNotInRFRepo(ctx context.Context, modelName string, repoId int64) (*ResourceModel, error) { +func (s *resourceModelStoreImpl) CheckModelNameNotInRFRepo(ctx context.Context, modelName string, repoId int64) (*ResourceModel, error) { var rm ResourceModel _, err := s.db.Core.NewSelect().Model(&rm). Where("LOWER(model_name) LIKE ?", "%"+modelName+"%"). diff --git a/builder/store/database/runtime_architecture.go b/builder/store/database/runtime_architecture.go index 30c685da..98b1d9b1 100644 --- a/builder/store/database/runtime_architecture.go +++ b/builder/store/database/runtime_architecture.go @@ -8,12 +8,22 @@ import ( "strings" ) -type RuntimeArchitecturesStore struct { +type runtimeArchitecturesStoreImpl struct { db *DB } -func NewRuntimeArchitecturesStore() *RuntimeArchitecturesStore { - return &RuntimeArchitecturesStore{ +type RuntimeArchitecturesStore interface { + ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]RuntimeArchitecture, error) + Add(ctx context.Context, arch RuntimeArchitecture) error + DeleteByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) error + FindByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) (*RuntimeArchitecture, error) + ListByRArchName(ctx context.Context, archName string) ([]RuntimeArchitecture, error) + ListByRArchNameAndModel(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) + GetRuntimeByModelName(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) +} + +func NewRuntimeArchitecturesStore() RuntimeArchitecturesStore { + return &runtimeArchitecturesStoreImpl{ db: defaultDB, } } @@ -24,7 +34,7 @@ type RuntimeArchitecture struct { ArchitectureName string `bun:",notnull" json:"architecture_name"` } -func (ra *RuntimeArchitecturesStore) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]RuntimeArchitecture, error) { +func (ra *runtimeArchitecturesStoreImpl) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]RuntimeArchitecture, error) { var result []RuntimeArchitecture _, err := ra.db.Operator.Core.NewSelect().Model(&result).Where("runtime_framework_id = ?", id).Exec(ctx, &result) if err != nil { @@ -33,7 +43,7 @@ func (ra *RuntimeArchitecturesStore) ListByRuntimeFrameworkID(ctx context.Contex return result, nil } -func (ra *RuntimeArchitecturesStore) Add(ctx context.Context, arch RuntimeArchitecture) error { +func (ra *runtimeArchitecturesStoreImpl) Add(ctx context.Context, arch RuntimeArchitecture) error { res, err := ra.db.Core.NewInsert().Model(&arch).Exec(ctx, &arch) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("creating runtime architecture in the db failed,error:%w", err) @@ -41,7 +51,7 @@ func (ra *RuntimeArchitecturesStore) Add(ctx context.Context, arch RuntimeArchit return nil } -func (ra *RuntimeArchitecturesStore) DeleteByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) error { +func (ra *runtimeArchitecturesStoreImpl) DeleteByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) error { var arch RuntimeArchitecture _, err := ra.db.Core.NewDelete().Model(&arch).Where("runtime_framework_id = ? and architecture_name = ?", id, archName).Exec(ctx) if err != nil { @@ -50,7 +60,7 @@ func (ra *RuntimeArchitecturesStore) DeleteByRuntimeIDAndArchName(ctx context.Co return nil } -func (ra *RuntimeArchitecturesStore) FindByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) (*RuntimeArchitecture, error) { +func (ra *runtimeArchitecturesStoreImpl) FindByRuntimeIDAndArchName(ctx context.Context, id int64, archName string) (*RuntimeArchitecture, error) { var arch RuntimeArchitecture _, err := ra.db.Core.NewSelect().Model(&arch).Where("runtime_framework_id = ? and architecture_name = ?", id, archName).Exec(ctx, &arch) if errors.Is(err, sql.ErrNoRows) { @@ -62,7 +72,7 @@ func (ra *RuntimeArchitecturesStore) FindByRuntimeIDAndArchName(ctx context.Cont return &arch, nil } -func (ra *RuntimeArchitecturesStore) ListByRArchName(ctx context.Context, archName string) ([]RuntimeArchitecture, error) { +func (ra *runtimeArchitecturesStoreImpl) ListByRArchName(ctx context.Context, archName string) ([]RuntimeArchitecture, error) { var result []RuntimeArchitecture _, err := ra.db.Operator.Core.NewSelect().Model(&result).Where("architecture_name = ?", archName).Exec(ctx, &result) if err != nil { @@ -71,7 +81,7 @@ func (ra *RuntimeArchitecturesStore) ListByRArchName(ctx context.Context, archNa return result, nil } -func (ra *RuntimeArchitecturesStore) ListByRArchNameAndModel(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) { +func (ra *runtimeArchitecturesStoreImpl) ListByRArchNameAndModel(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) { var result []RuntimeArchitecture _, err := ra.db.Operator.Core.NewSelect().Model(&result).Where("architecture_name = ?", archName).Exec(ctx, &result) if err != nil { @@ -116,7 +126,7 @@ Meta-Llama-3-8B-Instruct --> llama3-8b-instruct Llama-2-13b-chat --> llama-2-13b-chat */ -func (ra *RuntimeArchitecturesStore) GetRuntimeByModelName(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) { +func (ra *runtimeArchitecturesStoreImpl) GetRuntimeByModelName(ctx context.Context, archName, modelName string) ([]RuntimeArchitecture, error) { var result []RuntimeArchitecture var resModel []ResourceModel err := ra.db.Core.NewSelect().Model(&resModel).Where("LOWER(model_name) like ? and engine_name != ?", fmt.Sprintf("%%%s%%", strings.ToLower(modelName)), "nim").Scan(ctx) diff --git a/builder/store/database/runtime_framework.go b/builder/store/database/runtime_framework.go index eab7366a..9562d5a8 100644 --- a/builder/store/database/runtime_framework.go +++ b/builder/store/database/runtime_framework.go @@ -8,12 +8,25 @@ import ( "github.com/uptrace/bun" ) -type RuntimeFrameworksStore struct { +type runtimeFrameworksStoreImpl struct { db *DB } -func NewRuntimeFrameworksStore() *RuntimeFrameworksStore { - return &RuntimeFrameworksStore{ +type RuntimeFrameworksStore interface { + List(ctx context.Context, deployType int) ([]RuntimeFramework, error) + ListByRepoID(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) + FindByID(ctx context.Context, id int64) (*RuntimeFramework, error) + Add(ctx context.Context, frame RuntimeFramework) error + Update(ctx context.Context, frame RuntimeFramework) (*RuntimeFramework, error) + Delete(ctx context.Context, frame RuntimeFramework) error + FindEnabledByID(ctx context.Context, id int64) (*RuntimeFramework, error) + FindEnabledByName(ctx context.Context, name string) (*RuntimeFramework, error) + ListAll(ctx context.Context) ([]RuntimeFramework, error) + ListByIDs(ctx context.Context, ids []int64) ([]RuntimeFramework, error) +} + +func NewRuntimeFrameworksStore() RuntimeFrameworksStore { + return &runtimeFrameworksStoreImpl{ db: defaultDB, } } @@ -30,7 +43,7 @@ type RuntimeFramework struct { times } -func (rf *RuntimeFrameworksStore) List(ctx context.Context, deployType int) ([]RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) List(ctx context.Context, deployType int) ([]RuntimeFramework, error) { var result []RuntimeFramework _, err := rf.db.Operator.Core.NewSelect().Model(&result).Where("type = ?", deployType).Exec(ctx, &result) if err != nil { @@ -39,7 +52,7 @@ func (rf *RuntimeFrameworksStore) List(ctx context.Context, deployType int) ([]R return result, nil } -func (rf *RuntimeFrameworksStore) ListByRepoID(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) ListByRepoID(ctx context.Context, repoID int64, deployType int) ([]RepositoriesRuntimeFramework, error) { var result []RepositoriesRuntimeFramework err := rf.db.Operator.Core.NewSelect().Model(&RepositoriesRuntimeFramework{}).Relation("RuntimeFramework").Where("repositories_runtime_framework.type = ? and repositories_runtime_framework.repo_id = ?", deployType, repoID).Scan(ctx, &result) if err != nil { @@ -48,14 +61,14 @@ func (rf *RuntimeFrameworksStore) ListByRepoID(ctx context.Context, repoID int64 return result, err } -func (rf *RuntimeFrameworksStore) FindByID(ctx context.Context, id int64) (*RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) FindByID(ctx context.Context, id int64) (*RuntimeFramework, error) { var res RuntimeFramework res.ID = id _, err := rf.db.Core.NewSelect().Model(&res).WherePK().Exec(ctx, &res) return &res, err } -func (rf *RuntimeFrameworksStore) Add(ctx context.Context, frame RuntimeFramework) error { +func (rf *runtimeFrameworksStoreImpl) Add(ctx context.Context, frame RuntimeFramework) error { res, err := rf.db.Core.NewInsert().Model(&frame).Exec(ctx, &frame) if err := assertAffectedOneRow(res, err); err != nil { slog.Error("create runtime framework in db failed", slog.String("error", err.Error())) @@ -64,30 +77,30 @@ func (rf *RuntimeFrameworksStore) Add(ctx context.Context, frame RuntimeFramewor return nil } -func (rf *RuntimeFrameworksStore) Update(ctx context.Context, frame RuntimeFramework) (*RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) Update(ctx context.Context, frame RuntimeFramework) (*RuntimeFramework, error) { _, err := rf.db.Core.NewUpdate().Model(&frame).WherePK().Exec(ctx) return &frame, err } -func (rf *RuntimeFrameworksStore) Delete(ctx context.Context, frame RuntimeFramework) error { +func (rf *runtimeFrameworksStoreImpl) Delete(ctx context.Context, frame RuntimeFramework) error { _, err := rf.db.Core.NewDelete().Model(&frame).WherePK().Exec(ctx) return err } -func (rf *RuntimeFrameworksStore) FindEnabledByID(ctx context.Context, id int64) (*RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) FindEnabledByID(ctx context.Context, id int64) (*RuntimeFramework, error) { var res RuntimeFramework res.ID = id _, err := rf.db.Core.NewSelect().Model(&res).WherePK().Where("enabled = 1").Exec(ctx, &res) return &res, err } -func (rf *RuntimeFrameworksStore) FindEnabledByName(ctx context.Context, name string) (*RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) FindEnabledByName(ctx context.Context, name string) (*RuntimeFramework, error) { var res RuntimeFramework _, err := rf.db.Core.NewSelect().Model(&res).Where("frame_name = ?", name).Where("enabled = 1").Exec(ctx, &res) return &res, err } -func (rf *RuntimeFrameworksStore) ListAll(ctx context.Context) ([]RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) ListAll(ctx context.Context) ([]RuntimeFramework, error) { var result []RuntimeFramework _, err := rf.db.Operator.Core.NewSelect().Model(&result).Exec(ctx, &result) if err != nil { @@ -96,7 +109,7 @@ func (rf *RuntimeFrameworksStore) ListAll(ctx context.Context) ([]RuntimeFramewo return result, nil } -func (rf *RuntimeFrameworksStore) ListByIDs(ctx context.Context, ids []int64) ([]RuntimeFramework, error) { +func (rf *runtimeFrameworksStoreImpl) ListByIDs(ctx context.Context, ids []int64) ([]RuntimeFramework, error) { var result []RuntimeFramework _, err := rf.db.Operator.Core.NewSelect().Model(&result).Where("id in (?)", bun.In(ids)).Exec(ctx, &result) if err != nil { diff --git a/builder/store/database/space.go b/builder/store/database/space.go index d97bc4cf..c9ecb6d4 100644 --- a/builder/store/database/space.go +++ b/builder/store/database/space.go @@ -9,21 +9,38 @@ import ( "opencsg.com/csghub-server/common/types" ) -type SpaceStore struct { +type spaceStoreImpl struct { db *DB } -func NewSpaceStore() *SpaceStore { - return &SpaceStore{ +type SpaceStore interface { + BeginTx(ctx context.Context) (bun.Tx, error) + CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) + Create(ctx context.Context, input Space) (*Space, error) + Update(ctx context.Context, input Space) (err error) + FindByPath(ctx context.Context, namespace, name string) (*Space, error) + Delete(ctx context.Context, input Space) error + ByID(ctx context.Context, id int64) (*Space, error) + // ByRepoIDs get spaces by repoIDs, only basice info, no related repo + ByRepoIDs(ctx context.Context, repoIDs []int64) (spaces []Space, err error) + ByRepoID(ctx context.Context, repoID int64) (*Space, error) + ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) + ByUserLikes(ctx context.Context, userID int64, per, page int) (spaces []Space, total int, err error) + ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) + ListByPath(ctx context.Context, paths []string) ([]Space, error) +} + +func NewSpaceStore() SpaceStore { + return &spaceStoreImpl{ db: defaultDB, } } -func (s *SpaceStore) BeginTx(ctx context.Context) (bun.Tx, error) { +func (s *spaceStoreImpl) BeginTx(ctx context.Context) (bun.Tx, error) { return s.db.Core.BeginTx(ctx, nil) } -func (s *SpaceStore) CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) { +func (s *spaceStoreImpl) CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Space, error) { res, err := tx.NewInsert().Model(&input).Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { slog.Error("create space in tx failed", slog.String("error", err.Error())) @@ -34,7 +51,7 @@ func (s *SpaceStore) CreateTx(ctx context.Context, tx bun.Tx, input Space) (*Spa return &input, nil } -func (s *SpaceStore) Create(ctx context.Context, input Space) (*Space, error) { +func (s *spaceStoreImpl) Create(ctx context.Context, input Space) (*Space, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { slog.Error("create space in db failed", slog.String("error", err.Error())) @@ -45,12 +62,12 @@ func (s *SpaceStore) Create(ctx context.Context, input Space) (*Space, error) { return &input, nil } -func (s *SpaceStore) Update(ctx context.Context, input Space) (err error) { +func (s *spaceStoreImpl) Update(ctx context.Context, input Space) (err error) { _, err = s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return } -func (s *SpaceStore) FindByPath(ctx context.Context, namespace, name string) (*Space, error) { +func (s *spaceStoreImpl) FindByPath(ctx context.Context, namespace, name string) (*Space, error) { resSpace := new(Space) err := s.db.Operator.Core. NewSelect(). @@ -65,7 +82,7 @@ func (s *SpaceStore) FindByPath(ctx context.Context, namespace, name string) (*S return resSpace, err } -func (s *SpaceStore) Delete(ctx context.Context, input Space) error { +func (s *spaceStoreImpl) Delete(ctx context.Context, input Space) error { res, err := s.db.Operator.Core.NewDelete().Model(&input).WherePK().Exec(ctx) if err := assertAffectedOneRow(res, err); err != nil { return fmt.Errorf("delete space in tx failed,error:%w", err) @@ -73,7 +90,7 @@ func (s *SpaceStore) Delete(ctx context.Context, input Space) error { return nil } -func (s *SpaceStore) ByID(ctx context.Context, id int64) (*Space, error) { +func (s *spaceStoreImpl) ByID(ctx context.Context, id int64) (*Space, error) { var space Space err := s.db.Core.NewSelect().Model(&space).Relation("Repository").Where("space.id = ?", id).Scan(ctx) if err != nil { @@ -83,7 +100,7 @@ func (s *SpaceStore) ByID(ctx context.Context, id int64) (*Space, error) { } // ByRepoIDs get spaces by repoIDs, only basice info, no related repo -func (s *SpaceStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (spaces []Space, err error) { +func (s *spaceStoreImpl) ByRepoIDs(ctx context.Context, repoIDs []int64) (spaces []Space, err error) { err = s.db.Operator.Core.NewSelect(). Model(&spaces). Where("repository_id in (?)", bun.In(repoIDs)). @@ -92,7 +109,7 @@ func (s *SpaceStore) ByRepoIDs(ctx context.Context, repoIDs []int64) (spaces []S return } -func (s *SpaceStore) ByRepoID(ctx context.Context, repoID int64) (*Space, error) { +func (s *spaceStoreImpl) ByRepoID(ctx context.Context, repoID int64) (*Space, error) { var space Space err := s.db.Core.NewSelect().Model(&space).Where("repository_id = ?", repoID).Scan(ctx) if err != nil { @@ -101,7 +118,7 @@ func (s *SpaceStore) ByRepoID(ctx context.Context, repoID int64) (*Space, error) return &space, err } -func (s *SpaceStore) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) { +func (s *spaceStoreImpl) ByUsername(ctx context.Context, username string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&spaces). @@ -126,7 +143,7 @@ func (s *SpaceStore) ByUsername(ctx context.Context, username string, per, page return } -func (s *SpaceStore) ByUserLikes(ctx context.Context, userID int64, per, page int) (spaces []Space, total int, err error) { +func (s *spaceStoreImpl) ByUserLikes(ctx context.Context, userID int64, per, page int) (spaces []Space, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&spaces). @@ -148,7 +165,7 @@ func (s *SpaceStore) ByUserLikes(ctx context.Context, userID int64, per, page in return } -func (s *SpaceStore) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) { +func (s *spaceStoreImpl) ByOrgPath(ctx context.Context, namespace string, per, page int, onlyPublic bool) (spaces []Space, total int, err error) { query := s.db.Operator.Core. NewSelect(). Model(&spaces). @@ -174,7 +191,7 @@ func (s *SpaceStore) ByOrgPath(ctx context.Context, namespace string, per, page return } -func (s *SpaceStore) ListByPath(ctx context.Context, paths []string) ([]Space, error) { +func (s *spaceStoreImpl) ListByPath(ctx context.Context, paths []string) ([]Space, error) { var spaces []Space err := s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/space_resource.go b/builder/store/database/space_resource.go index f11d9d07..deaeada8 100644 --- a/builder/store/database/space_resource.go +++ b/builder/store/database/space_resource.go @@ -5,12 +5,22 @@ import ( "fmt" ) -type SpaceResourceStore struct { +type spaceResourceStoreImpl struct { db *DB } -func NewSpaceResourceStore() *SpaceResourceStore { - return &SpaceResourceStore{db: defaultDB} +type SpaceResourceStore interface { + Index(ctx context.Context, clusterId string) ([]SpaceResource, error) + Create(ctx context.Context, input SpaceResource) (*SpaceResource, error) + Update(ctx context.Context, input SpaceResource) (*SpaceResource, error) + Delete(ctx context.Context, input SpaceResource) error + FindByID(ctx context.Context, id int64) (*SpaceResource, error) + FindByName(ctx context.Context, name string) (*SpaceResource, error) + FindAll(ctx context.Context) ([]SpaceResource, error) +} + +func NewSpaceResourceStore() SpaceResourceStore { + return &spaceResourceStoreImpl{db: defaultDB} } type SpaceResource struct { @@ -21,7 +31,7 @@ type SpaceResource struct { times } -func (s *SpaceResourceStore) Index(ctx context.Context, clusterId string) ([]SpaceResource, error) { +func (s *spaceResourceStoreImpl) Index(ctx context.Context, clusterId string) ([]SpaceResource, error) { var result []SpaceResource _, err := s.db.Operator.Core.NewSelect().Model(&result).Where("cluster_id = ?", clusterId).Order("name asc").Exec(ctx, &result) if err != nil { @@ -30,7 +40,7 @@ func (s *SpaceResourceStore) Index(ctx context.Context, clusterId string) ([]Spa return result, nil } -func (s *SpaceResourceStore) Create(ctx context.Context, input SpaceResource) (*SpaceResource, error) { +func (s *spaceResourceStoreImpl) Create(ctx context.Context, input SpaceResource) (*SpaceResource, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create space resource in tx failed,error:%w", err) @@ -39,19 +49,19 @@ func (s *SpaceResourceStore) Create(ctx context.Context, input SpaceResource) (* return &input, nil } -func (s *SpaceResourceStore) Update(ctx context.Context, input SpaceResource) (*SpaceResource, error) { +func (s *spaceResourceStoreImpl) Update(ctx context.Context, input SpaceResource) (*SpaceResource, error) { _, err := s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return &input, err } -func (s *SpaceResourceStore) Delete(ctx context.Context, input SpaceResource) error { +func (s *spaceResourceStoreImpl) Delete(ctx context.Context, input SpaceResource) error { _, err := s.db.Core.NewDelete().Model(&input).WherePK().Exec(ctx) return err } -func (s *SpaceResourceStore) FindByID(ctx context.Context, id int64) (*SpaceResource, error) { +func (s *spaceResourceStoreImpl) FindByID(ctx context.Context, id int64) (*SpaceResource, error) { var res SpaceResource res.ID = id _, err := s.db.Core.NewSelect().Model(&res).WherePK().Exec(ctx, &res) @@ -59,14 +69,14 @@ func (s *SpaceResourceStore) FindByID(ctx context.Context, id int64) (*SpaceReso return &res, err } -func (s *SpaceResourceStore) FindByName(ctx context.Context, name string) (*SpaceResource, error) { +func (s *spaceResourceStoreImpl) FindByName(ctx context.Context, name string) (*SpaceResource, error) { var res SpaceResource err := s.db.Core.NewSelect().Model(&res).Where("name = ?", name).Scan(ctx) return &res, err } -func (s *SpaceResourceStore) FindAll(ctx context.Context) ([]SpaceResource, error) { +func (s *spaceResourceStoreImpl) FindAll(ctx context.Context) ([]SpaceResource, error) { var result []SpaceResource _, err := s.db.Operator.Core.NewSelect().Model(&result).Exec(ctx, &result) if err != nil { diff --git a/builder/store/database/space_sdk.go b/builder/store/database/space_sdk.go index 48fd7d4a..a5f243aa 100644 --- a/builder/store/database/space_sdk.go +++ b/builder/store/database/space_sdk.go @@ -5,12 +5,20 @@ import ( "fmt" ) -type SpaceSdkStore struct { +type spaceSdkStoreImpl struct { db *DB } -func NewSpaceSdkStore() *SpaceSdkStore { - return &SpaceSdkStore{db: defaultDB} +type SpaceSdkStore interface { + Index(ctx context.Context) ([]SpaceSdk, error) + Create(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) + Update(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) + Delete(ctx context.Context, input SpaceSdk) error + FindByID(ctx context.Context, id int64) (*SpaceSdk, error) +} + +func NewSpaceSdkStore() SpaceSdkStore { + return &spaceSdkStoreImpl{db: defaultDB} } type SpaceSdk struct { @@ -20,7 +28,7 @@ type SpaceSdk struct { times } -func (s *SpaceSdkStore) Index(ctx context.Context) ([]SpaceSdk, error) { +func (s *spaceSdkStoreImpl) Index(ctx context.Context) ([]SpaceSdk, error) { var result []SpaceSdk _, err := s.db.Operator.Core. NewSelect(). @@ -32,7 +40,7 @@ func (s *SpaceSdkStore) Index(ctx context.Context) ([]SpaceSdk, error) { return result, nil } -func (s *SpaceSdkStore) Create(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) { +func (s *spaceSdkStoreImpl) Create(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) { res, err := s.db.Core.NewInsert().Model(&input).Exec(ctx, &input) if err := assertAffectedOneRow(res, err); err != nil { return nil, fmt.Errorf("create space sdk in tx failed,error:%w", err) @@ -41,19 +49,19 @@ func (s *SpaceSdkStore) Create(ctx context.Context, input SpaceSdk) (*SpaceSdk, return &input, nil } -func (s *SpaceSdkStore) Update(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) { +func (s *spaceSdkStoreImpl) Update(ctx context.Context, input SpaceSdk) (*SpaceSdk, error) { _, err := s.db.Core.NewUpdate().Model(&input).WherePK().Exec(ctx) return &input, err } -func (s *SpaceSdkStore) Delete(ctx context.Context, input SpaceSdk) error { +func (s *spaceSdkStoreImpl) Delete(ctx context.Context, input SpaceSdk) error { _, err := s.db.Core.NewDelete().Model(&input).WherePK().Exec(ctx) return err } -func (s *SpaceSdkStore) FindByID(ctx context.Context, id int64) (*SpaceSdk, error) { +func (s *spaceSdkStoreImpl) FindByID(ctx context.Context, id int64) (*SpaceSdk, error) { var res SpaceSdk res.ID = id _, err := s.db.Core.NewSelect().Model(&res).WherePK().Exec(ctx, &res) diff --git a/builder/store/database/ssh_key.go b/builder/store/database/ssh_key.go index 046f7ad2..dd49d114 100644 --- a/builder/store/database/ssh_key.go +++ b/builder/store/database/ssh_key.go @@ -4,12 +4,24 @@ import ( "context" ) -type SSHKeyStore struct { +type sSHKeyStoreImpl struct { db *DB } -func NewSSHKeyStore() *SSHKeyStore { - return &SSHKeyStore{ +type SSHKeyStore interface { + Index(ctx context.Context, username string, per, page int) (sshKeys []SSHKey, err error) + Create(ctx context.Context, sshKey *SSHKey) (*SSHKey, error) + FindByID(ctx context.Context, id int64) (*SSHKey, error) + FindByFingerpringSHA256(ctx context.Context, fingerprint string) (*SSHKey, error) + Delete(ctx context.Context, gid int64) (err error) + IsExist(ctx context.Context, username, keyName string) (exists bool, err error) + FindByUsernameAndName(ctx context.Context, username, keyName string) (sshKey SSHKey, err error) + FindByKeyContent(ctx context.Context, key string) (*SSHKey, error) + FindByNameAndUserID(ctx context.Context, name string, userID int64) (*SSHKey, error) +} + +func NewSSHKeyStore() SSHKeyStore { + return &sSHKeyStoreImpl{ db: defaultDB, } } @@ -25,7 +37,7 @@ type SSHKey struct { times } -func (s *SSHKeyStore) Index(ctx context.Context, username string, per, page int) (sshKeys []SSHKey, err error) { +func (s *sSHKeyStoreImpl) Index(ctx context.Context, username string, per, page int) (sshKeys []SSHKey, err error) { err = s.db.Operator.Core. NewSelect(). Model(&sshKeys). @@ -39,7 +51,7 @@ func (s *SSHKeyStore) Index(ctx context.Context, username string, per, page int) return } -func (s *SSHKeyStore) Create(ctx context.Context, sshKey *SSHKey) (*SSHKey, error) { +func (s *sSHKeyStoreImpl) Create(ctx context.Context, sshKey *SSHKey) (*SSHKey, error) { err := s.db.Operator.Core. NewInsert(). Model(sshKey). @@ -48,7 +60,7 @@ func (s *SSHKeyStore) Create(ctx context.Context, sshKey *SSHKey) (*SSHKey, erro return sshKey, err } -func (s *SSHKeyStore) FindByID(ctx context.Context, id int64) (*SSHKey, error) { +func (s *sSHKeyStoreImpl) FindByID(ctx context.Context, id int64) (*SSHKey, error) { var sshKey SSHKey err := s.db.Operator.Core. NewSelect(). @@ -59,7 +71,7 @@ func (s *SSHKeyStore) FindByID(ctx context.Context, id int64) (*SSHKey, error) { return &sshKey, err } -func (s *SSHKeyStore) FindByFingerpringSHA256(ctx context.Context, fingerprint string) (*SSHKey, error) { +func (s *sSHKeyStoreImpl) FindByFingerpringSHA256(ctx context.Context, fingerprint string) (*SSHKey, error) { var sshKey SSHKey err := s.db.Operator.Core. NewSelect(). @@ -70,7 +82,7 @@ func (s *SSHKeyStore) FindByFingerpringSHA256(ctx context.Context, fingerprint s return &sshKey, err } -func (s *SSHKeyStore) Delete(ctx context.Context, gid int64) (err error) { +func (s *sSHKeyStoreImpl) Delete(ctx context.Context, gid int64) (err error) { var sshKey SSHKey _, err = s.db.Operator.Core. NewDelete(). @@ -80,7 +92,7 @@ func (s *SSHKeyStore) Delete(ctx context.Context, gid int64) (err error) { return } -func (s *SSHKeyStore) IsExist(ctx context.Context, username, keyName string) (exists bool, err error) { +func (s *sSHKeyStoreImpl) IsExist(ctx context.Context, username, keyName string) (exists bool, err error) { var sshKey SSHKey exists, err = s.db.Operator.Core. NewSelect(). @@ -92,7 +104,7 @@ func (s *SSHKeyStore) IsExist(ctx context.Context, username, keyName string) (ex return } -func (s *SSHKeyStore) FindByUsernameAndName(ctx context.Context, username, keyName string) (sshKey SSHKey, err error) { +func (s *sSHKeyStoreImpl) FindByUsernameAndName(ctx context.Context, username, keyName string) (sshKey SSHKey, err error) { sshKey.Name = keyName err = s.db.Operator.Core. NewSelect(). @@ -104,7 +116,7 @@ func (s *SSHKeyStore) FindByUsernameAndName(ctx context.Context, username, keyNa return sshKey, err } -func (s *SSHKeyStore) FindByKeyContent(ctx context.Context, key string) (*SSHKey, error) { +func (s *sSHKeyStoreImpl) FindByKeyContent(ctx context.Context, key string) (*SSHKey, error) { sshKey := new(SSHKey) err := s.db.Operator.Core. NewSelect(). @@ -114,7 +126,7 @@ func (s *SSHKeyStore) FindByKeyContent(ctx context.Context, key string) (*SSHKey return sshKey, err } -func (s *SSHKeyStore) FindByNameAndUserID(ctx context.Context, name string, userID int64) (*SSHKey, error) { +func (s *sSHKeyStoreImpl) FindByNameAndUserID(ctx context.Context, name string, userID int64) (*SSHKey, error) { sshKey := new(SSHKey) err := s.db.Operator.Core. NewSelect(). diff --git a/builder/store/database/sync_client_setting.go b/builder/store/database/sync_client_setting.go index 2c8b4617..f80c7c0d 100644 --- a/builder/store/database/sync_client_setting.go +++ b/builder/store/database/sync_client_setting.go @@ -2,12 +2,19 @@ package database import "context" -type SyncClientSettingStore struct { +type syncClientSettingStoreImpl struct { db *DB } -func NewSyncClientSettingStore() *SyncClientSettingStore { - return &SyncClientSettingStore{ +type SyncClientSettingStore interface { + Create(ctx context.Context, setting *SyncClientSetting) (*SyncClientSetting, error) + SyncClientSettingExists(ctx context.Context) (bool, error) + DeleteAll(ctx context.Context) error + First(ctx context.Context) (*SyncClientSetting, error) +} + +func NewSyncClientSettingStore() SyncClientSettingStore { + return &syncClientSettingStoreImpl{ db: defaultDB, } } @@ -21,7 +28,7 @@ type SyncClientSetting struct { times } -func (s *SyncClientSettingStore) Create(ctx context.Context, setting *SyncClientSetting) (*SyncClientSetting, error) { +func (s *syncClientSettingStoreImpl) Create(ctx context.Context, setting *SyncClientSetting) (*SyncClientSetting, error) { err := s.db.Operator.Core.NewInsert(). Model(setting). Scan(ctx) @@ -31,18 +38,18 @@ func (s *SyncClientSettingStore) Create(ctx context.Context, setting *SyncClient return setting, nil } -func (s *SyncClientSettingStore) SyncClientSettingExists(ctx context.Context) (bool, error) { +func (s *syncClientSettingStoreImpl) SyncClientSettingExists(ctx context.Context) (bool, error) { return s.db.Operator.Core.NewSelect(). Model((*SyncClientSetting)(nil)). Exists(ctx) } -func (s *SyncClientSettingStore) DeleteAll(ctx context.Context) error { +func (s *syncClientSettingStoreImpl) DeleteAll(ctx context.Context) error { _, err := s.db.Operator.Core.NewDelete().Model((*SyncClientSetting)(nil)).Where("1=1").Exec(ctx) return err } -func (s *SyncClientSettingStore) First(ctx context.Context) (*SyncClientSetting, error) { +func (s *syncClientSettingStoreImpl) First(ctx context.Context) (*SyncClientSetting, error) { var mt SyncClientSetting err := s.db.Operator.Core.NewSelect(). Model(&mt). diff --git a/builder/store/database/sync_version.go b/builder/store/database/sync_version.go index d6e2bb64..15762858 100644 --- a/builder/store/database/sync_version.go +++ b/builder/store/database/sync_version.go @@ -6,29 +6,36 @@ import ( "opencsg.com/csghub-server/common/types" ) -type SyncVersionStore struct { +type syncVersionStoreImpl struct { db *DB } type SyncVersionSource int -func NewSyncVersionStore() *SyncVersionStore { - return &SyncVersionStore{ +type SyncVersionStore interface { + Create(ctx context.Context, version *SyncVersion) (err error) + BatchCreate(ctx context.Context, versions []SyncVersion) error + FindByPath(ctx context.Context, path string) (*SyncVersion, error) + FindByRepoTypeAndPath(ctx context.Context, path string, repoType types.RepositoryType) (*SyncVersion, error) +} + +func NewSyncVersionStore() SyncVersionStore { + return &syncVersionStoreImpl{ db: defaultDB, } } -func (s *SyncVersionStore) Create(ctx context.Context, version *SyncVersion) (err error) { +func (s *syncVersionStoreImpl) Create(ctx context.Context, version *SyncVersion) (err error) { _, err = s.db.Operator.Core.NewInsert().Model(version).Exec(ctx) return } -func (s *SyncVersionStore) BatchCreate(ctx context.Context, versions []SyncVersion) error { +func (s *syncVersionStoreImpl) BatchCreate(ctx context.Context, versions []SyncVersion) error { result, err := s.db.Core.NewInsert().Model(&versions).Exec(ctx) return assertAffectedXRows(int64(len(versions)), result, err) } -func (s *SyncVersionStore) FindByPath(ctx context.Context, path string) (*SyncVersion, error) { +func (s *syncVersionStoreImpl) FindByPath(ctx context.Context, path string) (*SyncVersion, error) { var syncVersion SyncVersion err := s.db.Core.NewSelect(). Model(&syncVersion). @@ -41,7 +48,7 @@ func (s *SyncVersionStore) FindByPath(ctx context.Context, path string) (*SyncVe return &syncVersion, nil } -func (s *SyncVersionStore) FindByRepoTypeAndPath(ctx context.Context, path string, repoType types.RepositoryType) (*SyncVersion, error) { +func (s *syncVersionStoreImpl) FindByRepoTypeAndPath(ctx context.Context, path string, repoType types.RepositoryType) (*SyncVersion, error) { var syncVersion SyncVersion err := s.db.Core.NewSelect(). Model(&syncVersion). diff --git a/builder/store/database/tag.go b/builder/store/database/tag.go index 68073224..15886f2f 100644 --- a/builder/store/database/tag.go +++ b/builder/store/database/tag.go @@ -12,12 +12,39 @@ import ( "opencsg.com/csghub-server/common/types" ) -type TagStore struct { +type tagStoreImpl struct { db *DB } -func NewTagStore() *TagStore { - return &TagStore{ +type TagStore interface { + // Alltags returns all tags in the database + AllTags(ctx context.Context) ([]Tag, error) + AllTagsByScope(ctx context.Context, scope TagScope) ([]*Tag, error) + AllTagsByScopeAndCategory(ctx context.Context, scope TagScope, category string) ([]*Tag, error) + GetTagsByScopeAndCategories(ctx context.Context, scope TagScope, categories []string) ([]*Tag, error) + AllModelTags(ctx context.Context) ([]*Tag, error) + AllPromptTags(ctx context.Context) ([]*Tag, error) + AllDatasetTags(ctx context.Context) ([]*Tag, error) + AllCodeTags(ctx context.Context) ([]*Tag, error) + AllSpaceTags(ctx context.Context) ([]*Tag, error) + AllModelCategories(ctx context.Context) ([]TagCategory, error) + AllPromptCategories(ctx context.Context) ([]TagCategory, error) + AllDatasetCategories(ctx context.Context) ([]TagCategory, error) + AllCodeCategories(ctx context.Context) ([]TagCategory, error) + AllSpaceCategories(ctx context.Context) ([]TagCategory, error) + CreateTag(ctx context.Context, category, name, group string, scope TagScope) (Tag, error) + SaveTags(ctx context.Context, tags []*Tag) error + UpsertTags(ctx context.Context, tagScope TagScope, categoryTagMap map[string][]string) ([]Tag, error) + // SetMetaTags will delete existing tags and create new ones + SetMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string, tags []*Tag) (repoTags []*RepositoryTag, err error) + SetLibraryTag(ctx context.Context, repoType types.RepositoryType, namespace, name string, newTag, oldTag *Tag) (err error) + UpsertRepoTags(ctx context.Context, repoID int64, oldTagIDs, newTagIDs []int64) (err error) + RemoveRepoTags(ctx context.Context, repoID int64, tagIDs []int64) (err error) + FindOrCreate(ctx context.Context, tag Tag) (*Tag, error) +} + +func NewTagStore() TagStore { + return &tagStoreImpl{ db: defaultDB, } } @@ -53,7 +80,7 @@ type TagCategory struct { } // Alltags returns all tags in the database -func (ts *TagStore) AllTags(ctx context.Context) ([]Tag, error) { +func (ts *tagStoreImpl) AllTags(ctx context.Context) ([]Tag, error) { var tags []Tag err := ts.db.Operator.Core.NewSelect().Model(&Tag{}).Scan(ctx, &tags) if err != nil { @@ -63,7 +90,7 @@ func (ts *TagStore) AllTags(ctx context.Context) ([]Tag, error) { return tags, nil } -func (ts *TagStore) AllTagsByScope(ctx context.Context, scope TagScope) ([]*Tag, error) { +func (ts *tagStoreImpl) AllTagsByScope(ctx context.Context, scope TagScope) ([]*Tag, error) { var tags []*Tag err := ts.db.Operator.Core.NewSelect().Model(&tags). Where("scope =?", scope). @@ -74,7 +101,7 @@ func (ts *TagStore) AllTagsByScope(ctx context.Context, scope TagScope) ([]*Tag, return tags, nil } -func (ts *TagStore) AllTagsByScopeAndCategory(ctx context.Context, scope TagScope, category string) ([]*Tag, error) { +func (ts *tagStoreImpl) AllTagsByScopeAndCategory(ctx context.Context, scope TagScope, category string) ([]*Tag, error) { var tags []*Tag err := ts.db.Operator.Core.NewSelect().Model(&tags). Where("scope = ? and category = ?", scope, category). @@ -85,7 +112,7 @@ func (ts *TagStore) AllTagsByScopeAndCategory(ctx context.Context, scope TagScop return tags, nil } -func (ts *TagStore) GetTagsByScopeAndCategories(ctx context.Context, scope TagScope, categories []string) ([]*Tag, error) { +func (ts *tagStoreImpl) GetTagsByScopeAndCategories(ctx context.Context, scope TagScope, categories []string) ([]*Tag, error) { var tags []*Tag err := ts.db.Operator.Core.NewSelect().Model(&tags). Where("scope = ? and category in (?)", scope, bun.In(categories)). @@ -96,47 +123,47 @@ func (ts *TagStore) GetTagsByScopeAndCategories(ctx context.Context, scope TagSc return tags, nil } -func (ts *TagStore) AllModelTags(ctx context.Context) ([]*Tag, error) { +func (ts *tagStoreImpl) AllModelTags(ctx context.Context) ([]*Tag, error) { return ts.AllTagsByScope(ctx, ModelTagScope) } -func (ts *TagStore) AllPromptTags(ctx context.Context) ([]*Tag, error) { +func (ts *tagStoreImpl) AllPromptTags(ctx context.Context) ([]*Tag, error) { return ts.AllTagsByScope(ctx, PromptTagScope) } -func (ts *TagStore) AllDatasetTags(ctx context.Context) ([]*Tag, error) { +func (ts *tagStoreImpl) AllDatasetTags(ctx context.Context) ([]*Tag, error) { return ts.AllTagsByScope(ctx, DatasetTagScope) } -func (ts *TagStore) AllCodeTags(ctx context.Context) ([]*Tag, error) { +func (ts *tagStoreImpl) AllCodeTags(ctx context.Context) ([]*Tag, error) { return ts.AllTagsByScope(ctx, CodeTagScope) } -func (ts *TagStore) AllSpaceTags(ctx context.Context) ([]*Tag, error) { +func (ts *tagStoreImpl) AllSpaceTags(ctx context.Context) ([]*Tag, error) { return ts.AllTagsByScope(ctx, SpaceTagScope) } -func (ts *TagStore) AllModelCategories(ctx context.Context) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllModelCategories(ctx context.Context) ([]TagCategory, error) { return ts.allCategories(ctx, ModelTagScope) } -func (ts *TagStore) AllPromptCategories(ctx context.Context) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllPromptCategories(ctx context.Context) ([]TagCategory, error) { return ts.allCategories(ctx, PromptTagScope) } -func (ts *TagStore) AllDatasetCategories(ctx context.Context) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllDatasetCategories(ctx context.Context) ([]TagCategory, error) { return ts.allCategories(ctx, DatasetTagScope) } -func (ts *TagStore) AllCodeCategories(ctx context.Context) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllCodeCategories(ctx context.Context) ([]TagCategory, error) { return ts.allCategories(ctx, CodeTagScope) } -func (ts *TagStore) AllSpaceCategories(ctx context.Context) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllSpaceCategories(ctx context.Context) ([]TagCategory, error) { return ts.allCategories(ctx, SpaceTagScope) } -func (ts *TagStore) allCategories(ctx context.Context, scope TagScope) ([]TagCategory, error) { +func (ts *tagStoreImpl) allCategories(ctx context.Context, scope TagScope) ([]TagCategory, error) { var tags []TagCategory err := ts.db.Operator.Core.NewSelect().Model(&TagCategory{}). Where("scope = ?", scope). @@ -148,7 +175,7 @@ func (ts *TagStore) allCategories(ctx context.Context, scope TagScope) ([]TagCat return tags, nil } -func (ts *TagStore) CreateTag(ctx context.Context, category, name, group string, scope TagScope) (Tag, error) { +func (ts *tagStoreImpl) CreateTag(ctx context.Context, category, name, group string, scope TagScope) (Tag, error) { tag := Tag{ Name: name, Category: category, @@ -159,7 +186,7 @@ func (ts *TagStore) CreateTag(ctx context.Context, category, name, group string, return tag, err } -func (ts *TagStore) SaveTags(ctx context.Context, tags []*Tag) error { +func (ts *tagStoreImpl) SaveTags(ctx context.Context, tags []*Tag) error { if len(tags) == 0 { return nil } @@ -171,7 +198,7 @@ func (ts *TagStore) SaveTags(ctx context.Context, tags []*Tag) error { return nil } -func (ts *TagStore) UpsertTags(ctx context.Context, tagScope TagScope, categoryTagMap map[string][]string) ([]Tag, error) { +func (ts *tagStoreImpl) UpsertTags(ctx context.Context, tagScope TagScope, categoryTagMap map[string][]string) ([]Tag, error) { var tags []Tag for category, tagNames := range categoryTagMap { ctags := make([]Tag, 0) @@ -201,7 +228,7 @@ func (ts *TagStore) UpsertTags(ctx context.Context, tagScope TagScope, categoryT } // SetMetaTags will delete existing tags and create new ones -func (ts *TagStore) SetMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string, tags []*Tag) (repoTags []*RepositoryTag, err error) { +func (ts *tagStoreImpl) SetMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string, tags []*Tag) (repoTags []*RepositoryTag, err error) { repo := new(Repository) err = ts.db.Operator.Core.NewSelect().Model(repo). Column("id"). @@ -247,7 +274,7 @@ func (ts *TagStore) SetMetaTags(ctx context.Context, repoType types.RepositoryTy return repoTags, err } -func (ts *TagStore) SetLibraryTag(ctx context.Context, repoType types.RepositoryType, namespace, name string, newTag, oldTag *Tag) (err error) { +func (ts *tagStoreImpl) SetLibraryTag(ctx context.Context, repoType types.RepositoryType, namespace, name string, newTag, oldTag *Tag) (err error) { slog.Debug("set library tag", slog.Any("newTag", newTag), slog.Any("oldTag", oldTag)) repo := new(Repository) err = ts.db.Operator.Core.NewSelect().Model(repo). @@ -298,7 +325,7 @@ func (ts *TagStore) SetLibraryTag(ctx context.Context, repoType types.Repository return err } -func (ts *TagStore) UpsertRepoTags(ctx context.Context, repoID int64, oldTagIDs, newTagIDs []int64) (err error) { +func (ts *tagStoreImpl) UpsertRepoTags(ctx context.Context, repoID int64, oldTagIDs, newTagIDs []int64) (err error) { err = ts.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { var err error if len(oldTagIDs) > 0 { @@ -336,7 +363,7 @@ func (ts *TagStore) UpsertRepoTags(ctx context.Context, repoID int64, oldTagIDs, return err } -func (ts *TagStore) RemoveRepoTags(ctx context.Context, repoID int64, tagIDs []int64) (err error) { +func (ts *tagStoreImpl) RemoveRepoTags(ctx context.Context, repoID int64, tagIDs []int64) (err error) { if len(tagIDs) == 0 { return nil } @@ -351,7 +378,7 @@ func (ts *TagStore) RemoveRepoTags(ctx context.Context, repoID int64, tagIDs []i return err } -func (ts *TagStore) FindOrCreate(ctx context.Context, tag Tag) (*Tag, error) { +func (ts *tagStoreImpl) FindOrCreate(ctx context.Context, tag Tag) (*Tag, error) { var resTag Tag err := ts.db.Operator.Core.NewSelect(). Model(&resTag). diff --git a/builder/store/database/telemetry.go b/builder/store/database/telemetry.go index 14637cd9..b56b42a7 100644 --- a/builder/store/database/telemetry.go +++ b/builder/store/database/telemetry.go @@ -27,16 +27,20 @@ type Telemetry struct { Counts interface{} `bun:"type:jsonb" json:"counts,omitempty"` } -type TelemetryStore struct { +type telemetryStoreImpl struct { db *DB } -func NewTelemetryStore() *TelemetryStore { - return &TelemetryStore{ +type TelemetryStore interface { + Save(ctx context.Context, telemetry *Telemetry) error +} + +func NewTelemetryStore() TelemetryStore { + return &telemetryStoreImpl{ db: defaultDB, } } -func (s *TelemetryStore) Save(ctx context.Context, telemetry *Telemetry) error { +func (s *telemetryStoreImpl) Save(ctx context.Context, telemetry *Telemetry) error { return assertAffectedOneRow(s.db.Core.NewInsert().Model(telemetry).Exec(ctx)) } diff --git a/builder/store/database/user.go b/builder/store/database/user.go index 03a97d79..cb459a06 100644 --- a/builder/store/database/user.go +++ b/builder/store/database/user.go @@ -8,12 +8,30 @@ import ( "github.com/uptrace/bun" ) -type UserStore struct { +type userStoreImpl struct { db *DB } -func NewUserStore() *UserStore { - return &UserStore{ +type UserStore interface { + Index(ctx context.Context) (users []User, err error) + IndexWithSearch(ctx context.Context, search string, per, page int) (users []User, count int, err error) + FindByUsername(ctx context.Context, username string) (user User, err error) + FindByID(ctx context.Context, id int) (user User, err error) + Update(ctx context.Context, user *User) (err error) + ChangeUserName(ctx context.Context, username string, newUsername string) (err error) + Create(ctx context.Context, user *User, namespace *Namespace) (err error) + IsExist(ctx context.Context, username string) (exists bool, err error) + IsExistByUUID(ctx context.Context, uuid string) (exists bool, err error) + // FindByAccessToken retrieves user information based on the access token. The access token must be active and not expired. + FindByAccessToken(ctx context.Context, token string) (*User, error) + FindByGitAccessToken(ctx context.Context, token string) (*User, error) + FindByUUID(ctx context.Context, uuid string) (*User, error) + GetActiveUserCount(ctx context.Context) (int, error) + DeleteUserAndRelations(ctx context.Context, input User) (err error) +} + +func NewUserStore() UserStore { + return &userStoreImpl{ db: defaultDB, } } @@ -80,7 +98,7 @@ func (u *User) SetRoles(roles []string) { u.RoleMask = strings.Join(roles, ",") } -func (s *UserStore) Index(ctx context.Context) (users []User, err error) { +func (s *userStoreImpl) Index(ctx context.Context) (users []User, err error) { err = s.db.Operator.Core.NewSelect().Model(&users).Scan(ctx, &users) if err != nil { return @@ -88,7 +106,7 @@ func (s *UserStore) Index(ctx context.Context) (users []User, err error) { return } -func (s *UserStore) IndexWithSearch(ctx context.Context, search string, per, page int) (users []User, count int, err error) { +func (s *userStoreImpl) IndexWithSearch(ctx context.Context, search string, per, page int) (users []User, count int, err error) { search = strings.ToLower(search) query := s.db.Operator.Core.NewSelect(). Model(&users) @@ -107,19 +125,19 @@ func (s *UserStore) IndexWithSearch(ctx context.Context, search string, per, pag return } -func (s *UserStore) FindByUsername(ctx context.Context, username string) (user User, err error) { +func (s *userStoreImpl) FindByUsername(ctx context.Context, username string) (user User, err error) { user.Username = username err = s.db.Operator.Core.NewSelect().Model(&user).Where("username = ?", username).Scan(ctx) return } -func (s *UserStore) FindByID(ctx context.Context, id int) (user User, err error) { +func (s *userStoreImpl) FindByID(ctx context.Context, id int) (user User, err error) { user.ID = int64(id) err = s.db.Operator.Core.NewSelect().Model(&user).WherePK().Scan(ctx) return } -func (s *UserStore) Update(ctx context.Context, user *User) (err error) { +func (s *userStoreImpl) Update(ctx context.Context, user *User) (err error) { err = assertAffectedOneRow(s.db.Operator.Core.NewUpdate(). Model(user). WherePK(). @@ -129,7 +147,7 @@ func (s *UserStore) Update(ctx context.Context, user *User) (err error) { return } -func (s *UserStore) ChangeUserName(ctx context.Context, username string, newUsername string) (err error) { +func (s *userStoreImpl) ChangeUserName(ctx context.Context, username string, newUsername string) (err error) { return s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err = assertAffectedOneRow(tx.NewUpdate().Model((*Namespace)(nil)). Set("path = ?", newUsername). @@ -153,7 +171,7 @@ func (s *UserStore) ChangeUserName(ctx context.Context, username string, newUser }) } -func (s *UserStore) Create(ctx context.Context, user *User, namespace *Namespace) (err error) { +func (s *userStoreImpl) Create(ctx context.Context, user *User, namespace *Namespace) (err error) { err = s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if err = assertAffectedOneRow(tx.NewInsert().Model(user).Exec(ctx)); err != nil { return err @@ -168,7 +186,7 @@ func (s *UserStore) Create(ctx context.Context, user *User, namespace *Namespace return } -func (s *UserStore) IsExist(ctx context.Context, username string) (exists bool, err error) { +func (s *userStoreImpl) IsExist(ctx context.Context, username string) (exists bool, err error) { return s.db.Operator.Core. NewSelect(). Model((*User)(nil)). @@ -176,7 +194,7 @@ func (s *UserStore) IsExist(ctx context.Context, username string) (exists bool, Exists(ctx) } -func (s *UserStore) IsExistByUUID(ctx context.Context, uuid string) (exists bool, err error) { +func (s *userStoreImpl) IsExistByUUID(ctx context.Context, uuid string) (exists bool, err error) { return s.db.Operator.Core. NewSelect(). Model((*User)(nil)). @@ -185,7 +203,7 @@ func (s *UserStore) IsExistByUUID(ctx context.Context, uuid string) (exists bool } // FindByAccessToken retrieves user information based on the access token. The access token must be active and not expired. -func (s *UserStore) FindByAccessToken(ctx context.Context, token string) (*User, error) { +func (s *userStoreImpl) FindByAccessToken(ctx context.Context, token string) (*User, error) { var user User _, err := s.db.Operator.Core. NewSelect(). @@ -200,7 +218,7 @@ func (s *UserStore) FindByAccessToken(ctx context.Context, token string) (*User, return &user, nil } -func (s *UserStore) FindByGitAccessToken(ctx context.Context, token string) (*User, error) { +func (s *userStoreImpl) FindByGitAccessToken(ctx context.Context, token string) (*User, error) { var user User _, err := s.db.Operator.Core. NewSelect(). @@ -216,7 +234,7 @@ func (s *UserStore) FindByGitAccessToken(ctx context.Context, token string) (*Us return &user, nil } -func (s *UserStore) FindByUUID(ctx context.Context, uuid string) (*User, error) { +func (s *userStoreImpl) FindByUUID(ctx context.Context, uuid string) (*User, error) { var user User err := s.db.Operator.Core.NewSelect().Model(&user).Where("uuid = ?", uuid).Scan(ctx) if err != nil { @@ -225,14 +243,14 @@ func (s *UserStore) FindByUUID(ctx context.Context, uuid string) (*User, error) return &user, nil } -func (s *UserStore) GetActiveUserCount(ctx context.Context) (int, error) { +func (s *userStoreImpl) GetActiveUserCount(ctx context.Context) (int, error) { return s.db.Operator.Core. NewSelect(). Model(&User{}). Count(ctx) } -func (s *UserStore) DeleteUserAndRelations(ctx context.Context, input User) (err error) { +func (s *userStoreImpl) DeleteUserAndRelations(ctx context.Context, input User) (err error) { exists, err := s.IsExist(ctx, input.Username) if err != nil { return fmt.Errorf("error checking if user exists: %v", err) diff --git a/builder/store/database/user_like.go b/builder/store/database/user_like.go index e7178f81..1293203d 100644 --- a/builder/store/database/user_like.go +++ b/builder/store/database/user_like.go @@ -6,12 +6,21 @@ import ( "github.com/uptrace/bun" ) -type UserLikesStore struct { +type userLikesStoreImpl struct { db *DB } -func NewUserLikesStore() *UserLikesStore { - return &UserLikesStore{ +type UserLikesStore interface { + Add(ctx context.Context, userId, repoId int64) error + LikeCollection(ctx context.Context, userId, collectionId int64) error + UnLikeCollection(ctx context.Context, userId, collectionId int64) error + Delete(ctx context.Context, userId, repoId int64) error + IsExist(ctx context.Context, username string, repoId int64) (exists bool, err error) + IsExistCollection(ctx context.Context, username string, collectionId int64) (exists bool, err error) +} + +func NewUserLikesStore() UserLikesStore { + return &userLikesStoreImpl{ db: defaultDB, } } @@ -23,7 +32,7 @@ type UserLike struct { CollectionID int64 `bun:",notnull" json:"collection_id"` } -func (r *UserLikesStore) Add(ctx context.Context, userId, repoId int64) error { +func (r *userLikesStoreImpl) Add(ctx context.Context, userId, repoId int64) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { userLikes := &UserLike{ UserID: userId, @@ -41,7 +50,7 @@ func (r *UserLikesStore) Add(ctx context.Context, userId, repoId int64) error { return err } -func (r *UserLikesStore) LikeCollection(ctx context.Context, userId, collectionId int64) error { +func (r *userLikesStoreImpl) LikeCollection(ctx context.Context, userId, collectionId int64) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { userLikes := &UserLike{ UserID: userId, @@ -59,7 +68,7 @@ func (r *UserLikesStore) LikeCollection(ctx context.Context, userId, collectionI return err } -func (r *UserLikesStore) UnLikeCollection(ctx context.Context, userId, collectionId int64) error { +func (r *userLikesStoreImpl) UnLikeCollection(ctx context.Context, userId, collectionId int64) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { var userLikes UserLike if err := assertAffectedOneRow(r.db.Core.NewDelete().Model(&userLikes).Where("user_id = ? and collection_id = ?", userId, collectionId).Exec(ctx)); err != nil { @@ -74,7 +83,7 @@ func (r *UserLikesStore) UnLikeCollection(ctx context.Context, userId, collectio return err } -func (r *UserLikesStore) Delete(ctx context.Context, userId, repoId int64) error { +func (r *userLikesStoreImpl) Delete(ctx context.Context, userId, repoId int64) error { err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { var userLikes UserLike if err := assertAffectedOneRow(r.db.Core.NewDelete().Model(&userLikes).Where("user_id = ? and repo_id = ?", userId, repoId).Exec(ctx)); err != nil { @@ -89,7 +98,7 @@ func (r *UserLikesStore) Delete(ctx context.Context, userId, repoId int64) error return err } -func (r *UserLikesStore) IsExist(ctx context.Context, username string, repoId int64) (exists bool, err error) { +func (r *userLikesStoreImpl) IsExist(ctx context.Context, username string, repoId int64) (exists bool, err error) { var userLike UserLike exists, err = r.db.Operator.Core. NewSelect(). @@ -100,7 +109,7 @@ func (r *UserLikesStore) IsExist(ctx context.Context, username string, repoId in return } -func (r *UserLikesStore) IsExistCollection(ctx context.Context, username string, collectionId int64) (exists bool, err error) { +func (r *userLikesStoreImpl) IsExistCollection(ctx context.Context, username string, collectionId int64) (exists bool, err error) { var userLike UserLike exists, err = r.db.Operator.Core. NewSelect(). diff --git a/cmd/csghub-server/cmd/git/generate_lfs_meta_objects.go b/cmd/csghub-server/cmd/git/generate_lfs_meta_objects.go index cc829e55..9553e3df 100644 --- a/cmd/csghub-server/cmd/git/generate_lfs_meta_objects.go +++ b/cmd/csghub-server/cmd/git/generate_lfs_meta_objects.go @@ -94,7 +94,7 @@ var generateLfsMetaObjectsCmd = &cobra.Command{ }, } -func fetchAllPointersForRepo(config *config.Config, gitServer gitserver.GitServer, s3Client *s3.Client, lfsMetaObjectStore *database.LfsMetaObjectStore, repo database.Repository) error { +func fetchAllPointersForRepo(config *config.Config, gitServer gitserver.GitServer, s3Client *s3.Client, lfsMetaObjectStore database.LfsMetaObjectStore, repo database.Repository) error { namespace := strings.Split(repo.Path, "/")[0] name := strings.Split(repo.Path, "/")[1] ref := repo.DefaultBranch @@ -124,7 +124,7 @@ func fetchAllPointersForRepo(config *config.Config, gitServer gitserver.GitServe return nil } -func checkAndUpdateLfsMetaObjects(config *config.Config, s3Client *s3.Client, lfsMetaObjectStore *database.LfsMetaObjectStore, repo database.Repository, pointer *types.Pointer) { +func checkAndUpdateLfsMetaObjects(config *config.Config, s3Client *s3.Client, lfsMetaObjectStore database.LfsMetaObjectStore, repo database.Repository, pointer *types.Pointer) { var exists bool ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() diff --git a/cmd/csghub-server/cmd/logscan/logscan.go b/cmd/csghub-server/cmd/logscan/logscan.go index 0a020d3a..5d5b3927 100644 --- a/cmd/csghub-server/cmd/logscan/logscan.go +++ b/cmd/csghub-server/cmd/logscan/logscan.go @@ -16,7 +16,7 @@ import ( var logPath string var ( - repoStore *database.RepoStore + repoStore database.RepoStore ) func init() { diff --git a/cmd/csghub-server/cmd/trigger/git_callback.go b/cmd/csghub-server/cmd/trigger/git_callback.go index 6c9e3bc8..c939d316 100644 --- a/cmd/csghub-server/cmd/trigger/git_callback.go +++ b/cmd/csghub-server/cmd/trigger/git_callback.go @@ -18,7 +18,7 @@ import ( var ( callbackComponent *callback.GitCallbackComponent - rs *database.RepoStore + rs database.RepoStore gs gitserver.GitServer repoPaths []string diff --git a/component/accesstoken.go b/component/accesstoken.go index d8c12f23..9b1b7071 100644 --- a/component/accesstoken.go +++ b/component/accesstoken.go @@ -15,8 +15,16 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewAccessTokenComponent(config *config.Config) (*AccessTokenComponent, error) { - c := &AccessTokenComponent{} +type AccessTokenComponent interface { + Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) + Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error + Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) + GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) + RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) +} + +func NewAccessTokenComponent(config *config.Config) (AccessTokenComponent, error) { + c := &accessTokenComponentImpl{} c.ts = database.NewAccessTokenStore() c.us = database.NewUserStore() var err error @@ -29,13 +37,13 @@ func NewAccessTokenComponent(config *config.Config) (*AccessTokenComponent, erro return c, nil } -type AccessTokenComponent struct { - ts *database.AccessTokenStore - us *database.UserStore +type accessTokenComponentImpl struct { + ts database.AccessTokenStore + us database.UserStore gs gitserver.GitServer } -func (c *AccessTokenComponent) Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) { +func (c *accessTokenComponentImpl) Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) { user, err := c.us.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("fail to find user,error:%w", err) @@ -81,12 +89,12 @@ func (c *AccessTokenComponent) Create(ctx context.Context, req *types.CreateUser return token, nil } -func (c *AccessTokenComponent) genUnique() string { +func (c *accessTokenComponentImpl) genUnique() string { // TODO:change return strings.ReplaceAll(uuid.NewString(), "-", "") } -func (c *AccessTokenComponent) Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error { +func (c *accessTokenComponentImpl) Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error { ue, err := c.us.IsExist(ctx, req.Username) if !ue { return fmt.Errorf("user does not exists,error:%w", err) @@ -110,7 +118,7 @@ func (c *AccessTokenComponent) Delete(ctx context.Context, req *types.DeleteUser return nil } -func (c *AccessTokenComponent) Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) { var resp types.CheckAccessTokenResp t, err := c.ts.FindByToken(ctx, req.Token, req.Application) if err != nil { @@ -127,7 +135,7 @@ func (c *AccessTokenComponent) Check(ctx context.Context, req *types.CheckAccess return resp, nil } -func (c *AccessTokenComponent) GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) { var resps []types.CheckAccessTokenResp tokens, err := c.ts.FindByUser(ctx, username, app) if err != nil { @@ -149,7 +157,7 @@ func (c *AccessTokenComponent) GetTokens(ctx context.Context, username, app stri return resps, nil } -func (c *AccessTokenComponent) RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) { var resp types.CheckAccessTokenResp t, err := c.ts.FindByTokenName(ctx, userName, tokenName, app) if err != nil { diff --git a/component/accounting.go b/component/accounting.go index ef810ce3..e03d2c46 100644 --- a/component/accounting.go +++ b/component/accounting.go @@ -11,25 +11,29 @@ import ( "opencsg.com/csghub-server/common/types" ) -type AccountingComponent struct { +type accountingComponentImpl struct { acctClient *accounting.AccountingClient - user *database.UserStore - deploy *database.DeployTaskStore + user database.UserStore + deploy database.DeployTaskStore } -func NewAccountingComponent(config *config.Config) (*AccountingComponent, error) { +type AccountingComponent interface { + ListMeteringsByUserIDAndTime(ctx context.Context, req types.ACCT_STATEMENTS_REQ) (interface{}, error) +} + +func NewAccountingComponent(config *config.Config) (AccountingComponent, error) { c, err := accounting.NewAccountingClient(config) if err != nil { return nil, err } - return &AccountingComponent{ + return &accountingComponentImpl{ acctClient: c, user: database.NewUserStore(), deploy: database.NewDeployTaskStore(), }, nil } -func (ac *AccountingComponent) ListMeteringsByUserIDAndTime(ctx context.Context, req types.ACCT_STATEMENTS_REQ) (interface{}, error) { +func (ac *accountingComponentImpl) ListMeteringsByUserIDAndTime(ctx context.Context, req types.ACCT_STATEMENTS_REQ) (interface{}, error) { user, err := ac.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, fmt.Errorf("user does not exist, %w", err) diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index 8fdcbf00..1568e7c7 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -22,23 +22,23 @@ import ( type GitCallbackComponent struct { config *config.Config gs gitserver.GitServer - tc *component.TagComponent + tc component.TagComponent modSvcClient rpc.ModerationSvcClient - ms *database.ModelStore - ds *database.DatasetStore - sc *component.SpaceComponent - ss *database.SpaceStore - rs *database.RepoStore - rrs *database.RepoRelationsStore - mirrorStore *database.MirrorStore + ms database.ModelStore + ds database.DatasetStore + sc component.SpaceComponent + ss database.SpaceStore + rs database.RepoStore + rrs database.RepoRelationsStore + mirrorStore database.MirrorStore rrf *database.RepositoriesRuntimeFrameworkStore - rac *component.RuntimeArchitectureComponent - ras *database.RuntimeArchitecturesStore - rfs *database.RuntimeFrameworksStore - ts *database.TagStore + rac component.RuntimeArchitectureComponent + ras database.RuntimeArchitecturesStore + rfs database.RuntimeFrameworksStore + ts database.TagStore // set visibility if file content is sensitive setRepoVisibility bool - pp *component.PromptComponent + pp component.PromptComponent maxPromptFS int64 } diff --git a/component/callback/repo_relation_watcher.go b/component/callback/repo_relation_watcher.go index 06d90b51..83afa623 100644 --- a/component/callback/repo_relation_watcher.go +++ b/component/callback/repo_relation_watcher.go @@ -17,15 +17,15 @@ import ( type repoRelationWatcher struct { ops []func() error - rs *database.RepoStore - rrs *database.RepoRelationsStore + rs database.RepoStore + rrs database.RepoRelationsStore gs gitserver.GitServer readmeStatus string } -func WatchRepoRelation(req *types.GiteaCallbackPushReq, ss *database.RepoStore, - rrs *database.RepoRelationsStore, +func WatchRepoRelation(req *types.GiteaCallbackPushReq, ss database.RepoStore, + rrs database.RepoRelationsStore, gs gitserver.GitServer) Watcher { watcher := new(repoRelationWatcher) watcher.rs = ss diff --git a/component/callback/space_deploy_watcher.go b/component/callback/space_deploy_watcher.go index 374b7970..f0a07ccf 100644 --- a/component/callback/space_deploy_watcher.go +++ b/component/callback/space_deploy_watcher.go @@ -14,11 +14,11 @@ import ( type spaceDeployWatcher struct { ops []func() error - ss *database.SpaceStore - sc *component.SpaceComponent + ss database.SpaceStore + sc component.SpaceComponent } -func WatchSpaceChange(req *types.GiteaCallbackPushReq, ss *database.SpaceStore, sc *component.SpaceComponent) Watcher { +func WatchSpaceChange(req *types.GiteaCallbackPushReq, ss database.SpaceStore, sc component.SpaceComponent) Watcher { watcher := new(spaceDeployWatcher) watcher.ss = ss watcher.sc = sc diff --git a/component/cluster.go b/component/cluster.go index 9119d609..8439b156 100644 --- a/component/cluster.go +++ b/component/cluster.go @@ -8,25 +8,31 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewClusterComponent(config *config.Config) (*ClusterComponent, error) { - c := &ClusterComponent{} +type ClusterComponent interface { + Index(ctx context.Context) ([]types.ClusterRes, error) + GetClusterById(ctx context.Context, clusterId string) (*types.ClusterRes, error) + Update(ctx context.Context, data types.ClusterRequest) (*types.UpdateClusterResponse, error) +} + +func NewClusterComponent(config *config.Config) (ClusterComponent, error) { + c := &clusterComponentImpl{} c.deployer = deploy.NewDeployer() return c, nil } -type ClusterComponent struct { +type clusterComponentImpl struct { deployer deploy.Deployer } -func (c *ClusterComponent) Index(ctx context.Context) ([]types.ClusterRes, error) { +func (c *clusterComponentImpl) Index(ctx context.Context) ([]types.ClusterRes, error) { return c.deployer.ListCluster(ctx) } -func (c *ClusterComponent) GetClusterById(ctx context.Context, clusterId string) (*types.ClusterRes, error) { +func (c *clusterComponentImpl) GetClusterById(ctx context.Context, clusterId string) (*types.ClusterRes, error) { return c.deployer.GetClusterById(ctx, clusterId) } -func (c *ClusterComponent) Update(ctx context.Context, data types.ClusterRequest) (*types.UpdateClusterResponse, error) { +func (c *clusterComponentImpl) Update(ctx context.Context, data types.ClusterRequest) (*types.UpdateClusterResponse, error) { return c.deployer.UpdateCluster(ctx, data) } diff --git a/component/code.go b/component/code.go index e317719b..57415f69 100644 --- a/component/code.go +++ b/component/code.go @@ -14,10 +14,20 @@ import ( const codeGitattributesContent = modelGitattributesContent -func NewCodeComponent(config *config.Config) (*CodeComponent, error) { - c := &CodeComponent{} +type CodeComponent interface { + Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) + Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Code, int, error) + Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) + Delete(ctx context.Context, namespace, name, currentUser string) error + Show(ctx context.Context, namespace, name, currentUser string) (*types.Code, error) + Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) + OrgCodes(ctx context.Context, req *types.OrgCodesReq) ([]types.Code, int, error) +} + +func NewCodeComponent(config *config.Config) (CodeComponent, error) { + c := &codeComponentImpl{} var err error - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, err } @@ -26,13 +36,13 @@ func NewCodeComponent(config *config.Config) (*CodeComponent, error) { return c, nil } -type CodeComponent struct { - *RepoComponent - cs *database.CodeStore - rs *database.RepoStore +type codeComponentImpl struct { + *repoComponentImpl + cs database.CodeStore + rs database.RepoStore } -func (c *CodeComponent) Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) { +func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) { var ( nickname string tags []types.RepoTag @@ -134,7 +144,7 @@ func (c *CodeComponent) Create(ctx context.Context, req *types.CreateCodeReq) (* return resCode, nil } -func (c *CodeComponent) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Code, int, error) { +func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Code, int, error) { var ( err error resCodes []types.Code @@ -198,7 +208,7 @@ func (c *CodeComponent) Index(ctx context.Context, filter *types.RepoFilter, per return resCodes, total, nil } -func (c *CodeComponent) Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) { +func (c *codeComponentImpl) Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) { req.RepoType = types.CodeRepo dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { @@ -233,7 +243,7 @@ func (c *CodeComponent) Update(ctx context.Context, req *types.UpdateCodeReq) (* return resCode, nil } -func (c *CodeComponent) Delete(ctx context.Context, namespace, name, currentUser string) error { +func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { code, err := c.cs.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find code, error: %w", err) @@ -257,7 +267,7 @@ func (c *CodeComponent) Delete(ctx context.Context, namespace, name, currentUser return nil } -func (c *CodeComponent) Show(ctx context.Context, namespace, name, currentUser string) (*types.Code, error) { +func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Code, error) { var tags []types.RepoTag code, err := c.cs.FindByPath(ctx, namespace, name) if err != nil { @@ -327,7 +337,7 @@ func (c *CodeComponent) Show(ctx context.Context, namespace, name, currentUser s return resCode, nil } -func (c *CodeComponent) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { +func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { code, err := c.cs.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find code repo, error: %w", err) @@ -341,7 +351,7 @@ func (c *CodeComponent) Relations(ctx context.Context, namespace, name, currentU return c.getRelations(ctx, code.RepositoryID, currentUser) } -func (c *CodeComponent) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { +func (c *codeComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { res, err := c.relatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err @@ -363,7 +373,7 @@ func (c *CodeComponent) getRelations(ctx context.Context, repoID int64, currentU return rels, nil } -func (c *CodeComponent) OrgCodes(ctx context.Context, req *types.OrgCodesReq) ([]types.Code, int, error) { +func (c *codeComponentImpl) OrgCodes(ctx context.Context, req *types.OrgCodesReq) ([]types.Code, int, error) { var resCodes []types.Code var err error r := membership.RoleUnknown diff --git a/component/collection.go b/component/collection.go index ae9a2049..47259288 100644 --- a/component/collection.go +++ b/component/collection.go @@ -16,8 +16,21 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewCollectionComponent(config *config.Config) (*CollectionComponent, error) { - cc := &CollectionComponent{} +type CollectionComponent interface { + GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int) ([]types.Collection, int, error) + CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) + GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) + // get non private repositories of the collection + GetPublicRepos(collection types.Collection) []types.CollectionRepository + UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) + DeleteCollection(ctx context.Context, id int64, userName string) error + AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error + RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error + OrgCollections(ctx context.Context, req *types.OrgCollectionsReq) ([]types.Collection, int, error) +} + +func NewCollectionComponent(config *config.Config) (CollectionComponent, error) { + cc := &collectionComponentImpl{} cc.cs = database.NewCollectionStore() cc.rs = database.NewRepoStore() cc.us = database.NewUserStore() @@ -33,17 +46,17 @@ func NewCollectionComponent(config *config.Config) (*CollectionComponent, error) return cc, nil } -type CollectionComponent struct { - os *database.OrgStore - cs *database.CollectionStore - rs *database.RepoStore - us *database.UserStore - uls *database.UserLikesStore +type collectionComponentImpl struct { + os database.OrgStore + cs database.CollectionStore + rs database.RepoStore + us database.UserStore + uls database.UserLikesStore userSvcClient rpc.UserSvcClient - spaceComponent *SpaceComponent + spaceComponent SpaceComponent } -func (cc *CollectionComponent) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int) ([]types.Collection, int, error) { +func (cc *collectionComponentImpl) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int) ([]types.Collection, int, error) { collections, total, err := cc.cs.GetCollections(ctx, filter, per, page, true) if err != nil { return nil, 0, err @@ -58,7 +71,7 @@ func (cc *CollectionComponent) GetCollections(ctx context.Context, filter *types } -func (cc *CollectionComponent) CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { +func (cc *collectionComponentImpl) CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { // find by user name user, err := cc.us.FindByUsername(ctx, input.Username) if err != nil { @@ -82,7 +95,7 @@ func (cc *CollectionComponent) CreateCollection(ctx context.Context, input types return cc.cs.CreateCollection(ctx, collection) } -func (cc *CollectionComponent) GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) { +func (cc *collectionComponentImpl) GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) { collection, err := cc.cs.GetCollection(ctx, id) if err != nil { return nil, err @@ -141,7 +154,7 @@ func (cc *CollectionComponent) GetCollection(ctx context.Context, currentUser st } // get non private repositories of the collection -func (cc *CollectionComponent) GetPublicRepos(collection types.Collection) []types.CollectionRepository { +func (cc *collectionComponentImpl) GetPublicRepos(collection types.Collection) []types.CollectionRepository { var filtered []types.CollectionRepository for _, repo := range collection.Repositories { if !repo.Private { @@ -151,7 +164,7 @@ func (cc *CollectionComponent) GetPublicRepos(collection types.Collection) []typ return filtered } -func (cc *CollectionComponent) UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { +func (cc *collectionComponentImpl) UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { collection, err := cc.cs.GetCollection(ctx, input.ID) if err != nil { return nil, fmt.Errorf("cannot find collection to update, %w", err) @@ -165,7 +178,7 @@ func (cc *CollectionComponent) UpdateCollection(ctx context.Context, input types return cc.cs.UpdateCollection(ctx, *collection) } -func (cc *CollectionComponent) DeleteCollection(ctx context.Context, id int64, userName string) error { +func (cc *collectionComponentImpl) DeleteCollection(ctx context.Context, id int64, userName string) error { // find by user name user, err := cc.us.FindByUsername(ctx, userName) if err != nil { @@ -174,7 +187,7 @@ func (cc *CollectionComponent) DeleteCollection(ctx context.Context, id int64, u return cc.cs.DeleteCollection(ctx, id, user.ID) } -func (cc *CollectionComponent) AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { +func (cc *collectionComponentImpl) AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name user, err := cc.us.FindByUsername(ctx, req.Username) if err != nil { @@ -197,7 +210,7 @@ func (cc *CollectionComponent) AddReposToCollection(ctx context.Context, req typ return cc.cs.AddCollectionRepos(ctx, collectionRepos) } -func (cc *CollectionComponent) RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { +func (cc *collectionComponentImpl) RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name user, err := cc.us.FindByUsername(ctx, req.Username) if err != nil { @@ -220,7 +233,7 @@ func (cc *CollectionComponent) RemoveReposFromCollection(ctx context.Context, re return cc.cs.RemoveCollectionRepos(ctx, collectionRepos) } -func (cc *CollectionComponent) getUserCollectionPermission(ctx context.Context, userName string, collection *database.Collection) (*types.UserRepoPermission, error) { +func (cc *collectionComponentImpl) getUserCollectionPermission(ctx context.Context, userName string, collection *database.Collection) (*types.UserRepoPermission, error) { if userName == "" { //anonymous user only has read permission to public repo return &types.UserRepoPermission{CanRead: !collection.Private, CanWrite: false, CanAdmin: false}, nil @@ -264,7 +277,7 @@ func (cc *CollectionComponent) getUserCollectionPermission(ctx context.Context, } } -func (c *CollectionComponent) OrgCollections(ctx context.Context, req *types.OrgCollectionsReq) ([]types.Collection, int, error) { +func (c *collectionComponentImpl) OrgCollections(ctx context.Context, req *types.OrgCollectionsReq) ([]types.Collection, int, error) { var err error r := membership.RoleUnknown if req.CurrentUser != "" { diff --git a/component/dataset.go b/component/dataset.go index 1081632b..9f42a11b 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -79,13 +79,23 @@ const ( gitattributesFileName = ".gitattributes" ) -func NewDatasetComponent(config *config.Config) (*DatasetComponent, error) { - c := &DatasetComponent{} +type DatasetComponent interface { + Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) + Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Dataset, int, error) + Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) + Delete(ctx context.Context, namespace, name, currentUser string) error + Show(ctx context.Context, namespace, name, currentUser string) (*types.Dataset, error) + Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) + OrgDatasets(ctx context.Context, req *types.OrgDatasetsReq) ([]types.Dataset, int, error) +} + +func NewDatasetComponent(config *config.Config) (DatasetComponent, error) { + c := &datasetComponentImpl{} c.ts = database.NewTagStore() c.ds = database.NewDatasetStore() c.rs = database.NewRepoStore() var err error - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("failed to create repo component, error: %w", err) } @@ -96,15 +106,15 @@ func NewDatasetComponent(config *config.Config) (*DatasetComponent, error) { return c, nil } -type DatasetComponent struct { - *RepoComponent - ts *database.TagStore - ds *database.DatasetStore - rs *database.RepoStore - sc *SensitiveComponent +type datasetComponentImpl struct { + *repoComponentImpl + ts database.TagStore + ds database.DatasetStore + rs database.RepoStore + sc SensitiveComponent } -func (c *DatasetComponent) Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) { +func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) { var ( nickname string tags []types.RepoTag @@ -239,11 +249,11 @@ license: ` + license + ` ` } -func (c *DatasetComponent) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Dataset, int, error) { +func (c *datasetComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Dataset, int, error) { return c.commonIndex(ctx, filter, per, page) } -func (c *DatasetComponent) commonIndex(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Dataset, int, error) { +func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Dataset, int, error) { var ( err error resDatasets []types.Dataset @@ -316,7 +326,7 @@ func (c *DatasetComponent) commonIndex(ctx context.Context, filter *types.RepoFi return resDatasets, total, nil } -func (c *DatasetComponent) Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) { +func (c *datasetComponentImpl) Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) { req.RepoType = types.DatasetRepo dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { @@ -351,7 +361,7 @@ func (c *DatasetComponent) Update(ctx context.Context, req *types.UpdateDatasetR return resDataset, nil } -func (c *DatasetComponent) Delete(ctx context.Context, namespace, name, currentUser string) error { +func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { dataset, err := c.ds.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find dataset, error: %w", err) @@ -375,7 +385,7 @@ func (c *DatasetComponent) Delete(ctx context.Context, namespace, name, currentU return nil } -func (c *DatasetComponent) Show(ctx context.Context, namespace, name, currentUser string) (*types.Dataset, error) { +func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Dataset, error) { var tags []types.RepoTag dataset, err := c.ds.FindByPath(ctx, namespace, name) if err != nil { @@ -447,7 +457,7 @@ func (c *DatasetComponent) Show(ctx context.Context, namespace, name, currentUse return resDataset, nil } -func (c *DatasetComponent) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { +func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { dataset, err := c.ds.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find dataset repo, error: %w", err) @@ -461,7 +471,7 @@ func (c *DatasetComponent) Relations(ctx context.Context, namespace, name, curre return c.getRelations(ctx, dataset.RepositoryID, currentUser) } -func (c *DatasetComponent) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { +func (c *datasetComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { res, err := c.relatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err @@ -483,7 +493,7 @@ func (c *DatasetComponent) getRelations(ctx context.Context, repoID int64, curre return rels, nil } -func (c *DatasetComponent) OrgDatasets(ctx context.Context, req *types.OrgDatasetsReq) ([]types.Dataset, int, error) { +func (c *datasetComponentImpl) OrgDatasets(ctx context.Context, req *types.OrgDatasetsReq) ([]types.Dataset, int, error) { var resDatasets []types.Dataset var err error r := membership.RoleUnknown diff --git a/component/dataset_viewer.go b/component/dataset_viewer.go index 990d1238..f3f031c3 100644 --- a/component/dataset_viewer.go +++ b/component/dataset_viewer.go @@ -28,26 +28,30 @@ type ViewParquetFileResp struct { Columns []string `json:"columns"` Rows [][]interface{} `json:"rows"` } -type DatasetViewerComponent struct { +type datasetViewerComponentImpl struct { gs gitserver.GitServer preader parquet.Reader once *sync.Once cfg *config.Config } -func NewDatasetViewerComponent(cfg *config.Config) (*DatasetViewerComponent, error) { +type DatasetViewerComponent interface { + ViewParquetFile(ctx context.Context, req *ViewParquetFileReq) (*ViewParquetFileResp, error) +} + +func NewDatasetViewerComponent(cfg *config.Config) (DatasetViewerComponent, error) { gs, err := git.NewGitServer(cfg) if err != nil { return nil, fmt.Errorf("failed to create git server,cause:%w", err) } - return &DatasetViewerComponent{ + return &datasetViewerComponentImpl{ gs: gs, once: new(sync.Once), cfg: cfg, }, nil } -func (c *DatasetViewerComponent) lazyInit() { +func (c *datasetViewerComponentImpl) lazyInit() { c.once.Do(func() { r, err := parquet.NewS3Reader(c.cfg) if err != nil { @@ -57,7 +61,7 @@ func (c *DatasetViewerComponent) lazyInit() { }) } -func (c *DatasetViewerComponent) ViewParquetFile(ctx context.Context, req *ViewParquetFileReq) (*ViewParquetFileResp, error) { +func (c *datasetViewerComponentImpl) ViewParquetFile(ctx context.Context, req *ViewParquetFileReq) (*ViewParquetFileResp, error) { c.lazyInit() objName, err := c.getParquetObject(req) @@ -83,7 +87,7 @@ func (c *DatasetViewerComponent) ViewParquetFile(ctx context.Context, req *ViewP return resp, nil } -func (c *DatasetViewerComponent) getParquetObject(req *ViewParquetFileReq) (string, error) { +func (c *datasetViewerComponentImpl) getParquetObject(req *ViewParquetFileReq) (string, error) { getFileContentReq := gitserver.GetRepoInfoByPathReq{ Namespace: req.Namespace, Name: req.RepoName, diff --git a/component/discussion.go b/component/discussion.go index da2149bd..4250d72a 100644 --- a/component/discussion.go +++ b/component/discussion.go @@ -11,20 +11,32 @@ import ( "opencsg.com/csghub-server/common/types" ) -type DiscussionComponent struct { - ds *database.DiscussionStore - rs *database.RepoStore - us *database.UserStore +type discussionComponentImpl struct { + ds database.DiscussionStore + rs database.RepoStore + us database.UserStore } -func NewDiscussionComponent() *DiscussionComponent { +type DiscussionComponent interface { + CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) + GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) + UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error + DeleteDiscussion(ctx context.Context, currentUser string, id int64) error + ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) + CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) + UpdateComment(ctx context.Context, currentUser string, id int64, content string) error + DeleteComment(ctx context.Context, currentUser string, id int64) error + ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) +} + +func NewDiscussionComponent() DiscussionComponent { ds := database.NewDiscussionStore() rs := database.NewRepoStore() us := database.NewUserStore() - return &DiscussionComponent{ds: ds, rs: rs, us: us} + return &discussionComponentImpl{ds: ds, rs: rs, us: us} } -func (c *DiscussionComponent) CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) { +func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) { //TODO:check if the user can access the repo //get repo by namespace and name @@ -59,7 +71,7 @@ func (c *DiscussionComponent) CreateRepoDiscussion(ctx context.Context, req Crea return resp, nil } -func (c *DiscussionComponent) GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) { +func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) { discussion, err := c.ds.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", id, err) @@ -91,7 +103,7 @@ func (c *DiscussionComponent) GetDiscussion(ctx context.Context, id int64) (*Sho return resp, nil } -func (c *DiscussionComponent) UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error { +func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error { //check if the user is the owner of the discussion user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -111,7 +123,7 @@ func (c *DiscussionComponent) UpdateDiscussion(ctx context.Context, req UpdateDi return nil } -func (c *DiscussionComponent) DeleteDiscussion(ctx context.Context, currentUser string, id int64) error { +func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentUser string, id int64) error { discussion, err := c.ds.FindByID(ctx, id) if err != nil { return fmt.Errorf("failed to find discussion by id '%d': %w", id, err) @@ -126,7 +138,7 @@ func (c *DiscussionComponent) DeleteDiscussion(ctx context.Context, currentUser return nil } -func (c *DiscussionComponent) ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) { +func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) { //TODO:check if the user can access the repo repo, err := c.rs.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { @@ -153,7 +165,7 @@ func (c *DiscussionComponent) ListRepoDiscussions(ctx context.Context, req ListR return resp, nil } -func (c *DiscussionComponent) CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) { +func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) { req.CommentableType = database.CommentableTypeDiscussion // get discussion by id _, err := c.ds.FindByID(ctx, req.CommentableID) @@ -189,7 +201,7 @@ func (c *DiscussionComponent) CreateDiscussionComment(ctx context.Context, req C }, nil } -func (c *DiscussionComponent) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error { +func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error { user, err := c.us.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) @@ -210,7 +222,7 @@ func (c *DiscussionComponent) UpdateComment(ctx context.Context, currentUser str return nil } -func (c *DiscussionComponent) DeleteComment(ctx context.Context, currentUser string, id int64) error { +func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser string, id int64) error { user, err := c.us.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) @@ -231,7 +243,7 @@ func (c *DiscussionComponent) DeleteComment(ctx context.Context, currentUser str return nil } -func (c *DiscussionComponent) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) { +func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) { comments, err := c.ds.FindDiscussionComments(ctx, discussionID) if err != nil { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussionID, err) diff --git a/component/event.go b/component/event.go index b0957a89..2ea4a1ec 100644 --- a/component/event.go +++ b/component/event.go @@ -7,18 +7,23 @@ import ( "opencsg.com/csghub-server/common/types" ) -type EventComponent struct { - es *database.EventStore +type eventComponentImpl struct { + es database.EventStore } // NewEventComponent creates a new EventComponent -func NewEventComponent() *EventComponent { - return &EventComponent{ + +type EventComponent interface { + NewEvents(ctx context.Context, events []types.Event) error +} + +func NewEventComponent() EventComponent { + return &eventComponentImpl{ es: database.NewEventStore(), } } -func (ec *EventComponent) NewEvents(ctx context.Context, events []types.Event) error { +func (ec *eventComponentImpl) NewEvents(ctx context.Context, events []types.Event) error { var dbevents []database.Event for _, e := range events { dbevents = append(dbevents, database.Event{ diff --git a/component/git_http.go b/component/git_http.go index 3d589103..62b2743a 100644 --- a/component/git_http.go +++ b/component/git_http.go @@ -25,18 +25,32 @@ import ( "opencsg.com/csghub-server/common/types" ) -type GitHTTPComponent struct { +type gitHTTPComponentImpl struct { git gitserver.GitServer config *config.Config s3Client *s3.Client - lfsMetaObjectStore *database.LfsMetaObjectStore - lfsLockStore *database.LfsLockStore - repo *database.RepoStore - *RepoComponent + lfsMetaObjectStore database.LfsMetaObjectStore + lfsLockStore database.LfsLockStore + repo database.RepoStore + *repoComponentImpl } -func NewGitHTTPComponent(config *config.Config) (*GitHTTPComponent, error) { - c := &GitHTTPComponent{} +type GitHTTPComponent interface { + InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) + GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error + GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error + BuildObjectResponse(ctx context.Context, req types.BatchRequest, isUpload bool) (*types.BatchResponse, error) + LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error + LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error + CreateLock(ctx context.Context, req types.LfsLockReq) (*database.LfsLock, error) + ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) + UnLock(ctx context.Context, req types.UnlockLFSReq) (*database.LfsLock, error) + VerifyLock(ctx context.Context, req types.VerifyLFSLockReq) (*types.LFSLockListVerify, error) + LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) +} + +func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { + c := &gitHTTPComponentImpl{} c.config = config var err error c.git, err = git.NewGitServer(config) @@ -54,14 +68,14 @@ func NewGitHTTPComponent(config *config.Config) (*GitHTTPComponent, error) { c.lfsMetaObjectStore = database.NewLfsMetaObjectStore() c.repo = database.NewRepoStore() c.lfsLockStore = database.NewLfsLockStore() - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, err } return c, nil } -func (c *GitHTTPComponent) InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) { +func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -98,7 +112,7 @@ func (c *GitHTTPComponent) InfoRefs(ctx context.Context, req types.InfoRefsReq) return reader, err } -func (c *GitHTTPComponent) GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error { +func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -125,7 +139,7 @@ func (c *GitHTTPComponent) GitUploadPack(ctx context.Context, req types.GitUploa return err } -func (c *GitHTTPComponent) GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error { +func (c *gitHTTPComponentImpl) GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error { _, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -157,7 +171,7 @@ func (c *GitHTTPComponent) GitReceivePack(ctx context.Context, req types.GitRece return err } -func (c *GitHTTPComponent) BuildObjectResponse(ctx context.Context, req types.BatchRequest, isUpload bool) (*types.BatchResponse, error) { +func (c *gitHTTPComponentImpl) BuildObjectResponse(ctx context.Context, req types.BatchRequest, isUpload bool) (*types.BatchResponse, error) { var ( respObjects []*types.ObjectResponse exists bool @@ -251,7 +265,7 @@ func (c *GitHTTPComponent) BuildObjectResponse(ctx context.Context, req types.Ba return respobj, nil } -func (c *GitHTTPComponent) buildObjectResponse(ctx context.Context, req types.BatchRequest, pointer types.Pointer, download, upload bool, err *types.ObjectError) *types.ObjectResponse { +func (c *gitHTTPComponentImpl) buildObjectResponse(ctx context.Context, req types.BatchRequest, pointer types.Pointer, download, upload bool, err *types.ObjectError) *types.ObjectResponse { rep := &types.ObjectResponse{Pointer: pointer} if err != nil { rep.Error = err @@ -294,7 +308,7 @@ func (c *GitHTTPComponent) buildObjectResponse(ctx context.Context, req types.Ba return rep } -func (c *GitHTTPComponent) LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error { +func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error { var exists bool repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { @@ -402,7 +416,7 @@ func (c *GitHTTPComponent) LfsUpload(ctx context.Context, body io.ReadCloser, re return nil } -func (c *GitHTTPComponent) LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error { +func (c *gitHTTPComponentImpl) LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -436,7 +450,7 @@ func (c *GitHTTPComponent) LfsVerify(ctx context.Context, req types.VerifyReques return nil } -func (c *GitHTTPComponent) CreateLock(ctx context.Context, req types.LfsLockReq) (*database.LfsLock, error) { +func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLockReq) (*database.LfsLock, error) { var ( lock *database.LfsLock ) @@ -480,7 +494,7 @@ func (c *GitHTTPComponent) CreateLock(ctx context.Context, req types.LfsLockReq) return lock, ErrAlreadyExists } -func (c *GitHTTPComponent) ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) { +func (c *gitHTTPComponentImpl) ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -541,7 +555,7 @@ func (c *GitHTTPComponent) ListLocks(ctx context.Context, req types.ListLFSLockR return res, nil } -func (c *GitHTTPComponent) UnLock(ctx context.Context, req types.UnlockLFSReq) (*database.LfsLock, error) { +func (c *gitHTTPComponentImpl) UnLock(ctx context.Context, req types.UnlockLFSReq) (*database.LfsLock, error) { var ( lock *database.LfsLock err error @@ -585,7 +599,7 @@ func (c *GitHTTPComponent) UnLock(ctx context.Context, req types.UnlockLFSReq) ( return lock, nil } -func (c *GitHTTPComponent) VerifyLock(ctx context.Context, req types.VerifyLFSLockReq) (*types.LFSLockListVerify, error) { +func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyLFSLockReq) (*types.LFSLockListVerify, error) { var ( ourLocks []*types.LFSLock theirLocks []*types.LFSLock @@ -643,7 +657,7 @@ func (c *GitHTTPComponent) VerifyLock(ctx context.Context, req types.VerifyLFSLo return &res, nil } -func (c *GitHTTPComponent) LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) { +func (c *gitHTTPComponentImpl) LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) { pointer := types.Pointer{Oid: req.Oid} repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { @@ -676,15 +690,15 @@ func (c *GitHTTPComponent) LfsDownload(ctx context.Context, req types.DownloadRe return signedUrl, nil } -func (c *GitHTTPComponent) buildDownloadLink(req types.BatchRequest, pointer types.Pointer) string { +func (c *gitHTTPComponentImpl) buildDownloadLink(req types.BatchRequest, pointer types.Pointer) string { return c.config.APIServer.PublicDomain + "/" + path.Join(fmt.Sprintf("%ss", req.RepoType), url.PathEscape(req.Namespace), url.PathEscape(req.Name+".git"), "info/lfs/objects", url.PathEscape(pointer.Oid)) } -// func (c *GitHTTPComponent) buildUploadLink(req types.BatchRequest, pointer types.Pointer) string { +// func (c *gitHTTPComponentImpl) buildUploadLink(req types.BatchRequest, pointer types.Pointer) string { // return c.config.APIServer.PublicDomain + "/" + path.Join(fmt.Sprintf("%ss", req.RepoType), url.PathEscape(req.Namespace), url.PathEscape(req.Name+".git"), "info/lfs/objects", url.PathEscape(pointer.Oid), strconv.FormatInt(pointer.Size, 10)) // } -func (c *GitHTTPComponent) buildUploadLink(req types.BatchRequest, pointer types.Pointer) string { +func (c *gitHTTPComponentImpl) buildUploadLink(req types.BatchRequest, pointer types.Pointer) string { objectKey := path.Join("lfs", pointer.RelativePath()) u, err := c.s3Client.PresignedPutObject(context.Background(), c.config.S3.Bucket, objectKey, time.Hour*24) if err != nil { @@ -693,7 +707,7 @@ func (c *GitHTTPComponent) buildUploadLink(req types.BatchRequest, pointer types return u.String() } -func (c *GitHTTPComponent) buildVerifyLink(req types.BatchRequest) string { +func (c *gitHTTPComponentImpl) buildVerifyLink(req types.BatchRequest) string { return c.config.APIServer.PublicDomain + "/" + path.Join(fmt.Sprintf("%ss", req.RepoType), url.PathEscape(req.Namespace), url.PathEscape(req.Name+".git"), "info/lfs/verify") } diff --git a/component/hf_dataset.go b/component/hf_dataset.go index ef14d15d..03c48e89 100644 --- a/component/hf_dataset.go +++ b/component/hf_dataset.go @@ -12,31 +12,36 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewHFDatasetComponent(config *config.Config) (*HFDatasetComponent, error) { - c := &HFDatasetComponent{} +type HFDatasetComponent interface { + GetPathsInfo(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) + GetDatasetTree(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) +} + +func NewHFDatasetComponent(config *config.Config) (HFDatasetComponent, error) { + c := &hFDatasetComponentImpl{} c.ts = database.NewTagStore() c.ds = database.NewDatasetStore() c.rs = database.NewRepoStore() var err error - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, err } return c, nil } -type HFDatasetComponent struct { - *RepoComponent - ts *database.TagStore - ds *database.DatasetStore - rs *database.RepoStore +type hFDatasetComponentImpl struct { + *repoComponentImpl + ts database.TagStore + ds database.DatasetStore + rs database.RepoStore } func convertFilePathFromRoute(path string) string { return strings.TrimLeft(path, "/") } -func (h *HFDatasetComponent) GetPathsInfo(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { +func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) @@ -75,7 +80,7 @@ func (h *HFDatasetComponent) GetPathsInfo(ctx context.Context, req types.PathReq return paths, nil } -func (h *HFDatasetComponent) GetDatasetTree(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { +func (h *hFDatasetComponentImpl) GetDatasetTree(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset tree, error: %w", err) diff --git a/component/internal.go b/component/internal.go index 0652445e..2ca87b97 100644 --- a/component/internal.go +++ b/component/internal.go @@ -16,20 +16,28 @@ import ( "opencsg.com/csghub-server/common/utils/common" ) -type InternalComponent struct { +type internalComponentImpl struct { config *config.Config - sshKeyStore *database.SSHKeyStore - repoStore *database.RepoStore - *RepoComponent + sshKeyStore database.SSHKeyStore + repoStore database.RepoStore + *repoComponentImpl } -func NewInternalComponent(config *config.Config) (*InternalComponent, error) { +type InternalComponent interface { + Allowed(ctx context.Context) (bool, error) + SSHAllowed(ctx context.Context, req types.SSHAllowedReq) (*types.SSHAllowedResp, error) + GetAuthorizedKeys(ctx context.Context, key string) (*database.SSHKey, error) + GetCommitDiff(ctx context.Context, req types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error) + LfsAuthenticate(ctx context.Context, req types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error) +} + +func NewInternalComponent(config *config.Config) (InternalComponent, error) { var err error - c := &InternalComponent{} + c := &internalComponentImpl{} c.config = config c.sshKeyStore = database.NewSSHKeyStore() c.repoStore = database.NewRepoStore() - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) c.tokenStore = database.NewAccessTokenStore() if err != nil { return nil, err @@ -37,11 +45,11 @@ func NewInternalComponent(config *config.Config) (*InternalComponent, error) { return c, nil } -func (c *InternalComponent) Allowed(ctx context.Context) (bool, error) { +func (c *internalComponentImpl) Allowed(ctx context.Context) (bool, error) { return true, nil } -func (c *InternalComponent) SSHAllowed(ctx context.Context, req types.SSHAllowedReq) (*types.SSHAllowedResp, error) { +func (c *internalComponentImpl) SSHAllowed(ctx context.Context, req types.SSHAllowedReq) (*types.SSHAllowedResp, error) { namespace, err := c.namespace.FindByPath(ctx, req.Namespace) if err != nil { return nil, fmt.Errorf("failed to find namespace %s: %v", req.Namespace, err) @@ -105,7 +113,7 @@ func (c *InternalComponent) SSHAllowed(ctx context.Context, req types.SSHAllowed }, nil } -func (c *InternalComponent) GetAuthorizedKeys(ctx context.Context, key string) (*database.SSHKey, error) { +func (c *internalComponentImpl) GetAuthorizedKeys(ctx context.Context, key string) (*database.SSHKey, error) { fingerprint, err := common.CalculateAuthorizedSSHKeyFingerprint(key) if err != nil { return nil, fmt.Errorf("failed to calculate authorized keys fingerprint, error: %v", err) @@ -117,7 +125,7 @@ func (c *InternalComponent) GetAuthorizedKeys(ctx context.Context, key string) ( return sshKey, nil } -func (c *InternalComponent) GetCommitDiff(ctx context.Context, req types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error) { +func (c *internalComponentImpl) GetCommitDiff(ctx context.Context, req types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error) { repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, err: %v", err) @@ -140,7 +148,7 @@ func (c *InternalComponent) GetCommitDiff(ctx context.Context, req types.GetDiff return diffs, nil } -func (c *InternalComponent) LfsAuthenticate(ctx context.Context, req types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error) { +func (c *internalComponentImpl) LfsAuthenticate(ctx context.Context, req types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error) { repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, err: %v", err) @@ -176,4 +184,3 @@ func (c *InternalComponent) LfsAuthenticate(ctx context.Context, req types.LfsAu RepoPath: c.config.APIServer.PublicDomain + "/" + filepath.Join(repoType, req.Namespace, req.Name+".git"), }, nil } - diff --git a/component/list.go b/component/list.go index 7f710382..46b4a6c0 100644 --- a/component/list.go +++ b/component/list.go @@ -9,21 +9,26 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewListComponent(config *config.Config) (*ListComponent, error) { - c := &ListComponent{} +type ListComponent interface { + ListModelsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.ModelResp, error) + ListDatasetsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.DatasetResp, error) +} + +func NewListComponent(config *config.Config) (ListComponent, error) { + c := &listComponentImpl{} c.ds = database.NewDatasetStore() c.ms = database.NewModelStore() c.ss = database.NewSpaceStore() return c, nil } -type ListComponent struct { - ms *database.ModelStore - ds *database.DatasetStore - ss *database.SpaceStore +type listComponentImpl struct { + ms database.ModelStore + ds database.DatasetStore + ss database.SpaceStore } -func (c *ListComponent) ListModelsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.ModelResp, error) { +func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.ModelResp, error) { var modelResp []*types.ModelResp models, err := c.ms.ListByPath(ctx, req.Paths) @@ -59,7 +64,7 @@ func (c *ListComponent) ListModelsByPath(ctx context.Context, req *types.ListByP return modelResp, nil } -func (c *ListComponent) ListDatasetsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.DatasetResp, error) { +func (c *listComponentImpl) ListDatasetsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.DatasetResp, error) { var datasetResp []*types.DatasetResp datasets, err := c.ds.ListByPath(ctx, req.Paths) diff --git a/component/mirror.go b/component/mirror.go index c21571e1..20f9e88f 100644 --- a/component/mirror.go +++ b/component/mirror.go @@ -21,30 +21,39 @@ import ( "opencsg.com/csghub-server/mirror/queue" ) -type MirrorComponent struct { - tokenStore *database.GitServerAccessTokenStore +type mirrorComponentImpl struct { + tokenStore database.GitServerAccessTokenStore mirrorServer mirrorserver.MirrorServer saas bool - repoComp *RepoComponent + repoComp RepoComponent git gitserver.GitServer s3Client *s3.Client lfsBucket string - modelStore *database.ModelStore - datasetStore *database.DatasetStore - codeStore *database.CodeStore - repoStore *database.RepoStore - mirrorStore *database.MirrorStore - mirrorSourceStore *database.MirrorSourceStore - namespaceStore *database.NamespaceStore - lfsMetaObjectStore *database.LfsMetaObjectStore - userStore *database.UserStore + modelStore database.ModelStore + datasetStore database.DatasetStore + codeStore database.CodeStore + repoStore database.RepoStore + mirrorStore database.MirrorStore + mirrorSourceStore database.MirrorSourceStore + namespaceStore database.NamespaceStore + lfsMetaObjectStore database.LfsMetaObjectStore + userStore database.UserStore config *config.Config mq *queue.PriorityQueue } -func NewMirrorComponent(config *config.Config) (*MirrorComponent, error) { +type MirrorComponent interface { + CreatePushMirrorForFinishedMirrorTask(ctx context.Context) error + // CreateMirrorRepo often called by the crawler server to create new repo which will then be mirrored from other sources + CreateMirrorRepo(ctx context.Context, req types.CreateMirrorRepoReq) (*database.Mirror, error) + CheckMirrorProgress(ctx context.Context) error + Repos(ctx context.Context, currentUser string, per, page int) ([]types.MirrorRepo, int, error) + Index(ctx context.Context, currentUser string, per, page int) ([]types.Mirror, int, error) +} + +func NewMirrorComponent(config *config.Config) (MirrorComponent, error) { var err error - c := &MirrorComponent{} + c := &mirrorComponentImpl{} if config.GitServer.Type == types.GitServerTypeGitea { c.mirrorServer, err = git.NewMirrorServer(config) if err != nil { @@ -89,7 +98,7 @@ func NewMirrorComponent(config *config.Config) (*MirrorComponent, error) { return c, nil } -func (c *MirrorComponent) CreatePushMirrorForFinishedMirrorTask(ctx context.Context) error { +func (c *mirrorComponentImpl) CreatePushMirrorForFinishedMirrorTask(ctx context.Context) error { mirrors, err := c.mirrorStore.NoPushMirror(ctx) if err != nil { return fmt.Errorf("fail to find all mirrors, %w", err) @@ -127,7 +136,7 @@ func (c *MirrorComponent) CreatePushMirrorForFinishedMirrorTask(ctx context.Cont } // CreateMirrorRepo often called by the crawler server to create new repo which will then be mirrored from other sources -func (c *MirrorComponent) CreateMirrorRepo(ctx context.Context, req types.CreateMirrorRepoReq) (*database.Mirror, error) { +func (c *mirrorComponentImpl) CreateMirrorRepo(ctx context.Context, req types.CreateMirrorRepoReq) (*database.Mirror, error) { var username string namespace := c.mapNamespaceAndName(req.SourceNamespace) name := req.SourceName @@ -276,7 +285,7 @@ func (c *MirrorComponent) CreateMirrorRepo(ctx context.Context, req types.Create return reqMirror, nil } -func (m *MirrorComponent) mapNamespaceAndName(sourceNamespace string) string { +func (m *mirrorComponentImpl) mapNamespaceAndName(sourceNamespace string) string { namespace := sourceNamespace if ns, found := mirrorOrganizationMap[sourceNamespace]; found { namespace = ns @@ -288,7 +297,7 @@ func (m *MirrorComponent) mapNamespaceAndName(sourceNamespace string) string { return namespace } -func (c *MirrorComponent) CheckMirrorProgress(ctx context.Context) error { +func (c *mirrorComponentImpl) CheckMirrorProgress(ctx context.Context) error { mirrors, err := c.mirrorStore.Unfinished(ctx) if err != nil { return fmt.Errorf("failed to get unfinished mirrors: %v", err) @@ -333,7 +342,7 @@ var mirrorStatusAndRepoSyncStatusMapping = map[types.MirrorTaskStatus]types.Repo types.MirrorIncomplete: types.SyncStatusFailed, } -func (c *MirrorComponent) checkAndUpdateMirrorStatus(ctx context.Context, mirror database.Mirror) error { +func (c *mirrorComponentImpl) checkAndUpdateMirrorStatus(ctx context.Context, mirror database.Mirror) error { var statusAndProgressFunc func(ctx context.Context, mirror database.Mirror) (types.MirrorResp, error) if mirror.Repository == nil { return nil @@ -409,7 +418,7 @@ func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryTy return files, nil } -func (c *MirrorComponent) getMirrorStatusAndProgressOnPremise(ctx context.Context, mirror database.Mirror) (types.MirrorResp, error) { +func (c *mirrorComponentImpl) getMirrorStatusAndProgressOnPremise(ctx context.Context, mirror database.Mirror) (types.MirrorResp, error) { task, err := c.git.GetMirrorTaskInfo(ctx, mirror.MirrorTaskID) if err != nil { slog.Error("fail to get mirror task info", slog.Int64("taskId", mirror.MirrorTaskID), slog.String("error", err.Error())) @@ -476,7 +485,7 @@ func (c *MirrorComponent) getMirrorStatusAndProgressOnPremise(ctx context.Contex } } -func (c *MirrorComponent) getMirrorStatusAndProgressSaas(ctx context.Context, mirror database.Mirror) (types.MirrorResp, error) { +func (c *mirrorComponentImpl) getMirrorStatusAndProgressSaas(ctx context.Context, mirror database.Mirror) (types.MirrorResp, error) { task, err := c.mirrorServer.GetMirrorTaskInfo(ctx, mirror.MirrorTaskID) if err != nil { slog.Error("fail to get mirror task info", slog.Int64("taskId", mirror.MirrorTaskID), slog.String("error", err.Error())) @@ -543,7 +552,7 @@ func (c *MirrorComponent) getMirrorStatusAndProgressSaas(ctx context.Context, mi } } -func (c *MirrorComponent) countMirrorProgress(ctx context.Context, mirror database.Mirror) (int8, error) { +func (c *mirrorComponentImpl) countMirrorProgress(ctx context.Context, mirror database.Mirror) (int8, error) { var ( lfsFiles []*types.File finishedFileCount int @@ -586,7 +595,7 @@ func (c *MirrorComponent) countMirrorProgress(ctx context.Context, mirror databa return int8(progress), nil } -func (c *MirrorComponent) Repos(ctx context.Context, currentUser string, per, page int) ([]types.MirrorRepo, int, error) { +func (c *mirrorComponentImpl) Repos(ctx context.Context, currentUser string, per, page int) ([]types.MirrorRepo, int, error) { var mirrorRepos []types.MirrorRepo user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { @@ -610,7 +619,7 @@ func (c *MirrorComponent) Repos(ctx context.Context, currentUser string, per, pa return mirrorRepos, total, nil } -func (c *MirrorComponent) Index(ctx context.Context, currentUser string, per, page int) ([]types.Mirror, int, error) { +func (c *mirrorComponentImpl) Index(ctx context.Context, currentUser string, per, page int) ([]types.Mirror, int, error) { var mirrorsResp []types.Mirror user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { diff --git a/component/mirror_source.go b/component/mirror_source.go index e1b04b59..136c16a0 100644 --- a/component/mirror_source.go +++ b/component/mirror_source.go @@ -10,19 +10,27 @@ import ( "opencsg.com/csghub-server/common/types" ) -type MirrorSourceComponent struct { - msStore *database.MirrorSourceStore - userStore *database.UserStore +type mirrorSourceComponentImpl struct { + msStore database.MirrorSourceStore + userStore database.UserStore } -func NewMirrorSourceComponent(config *config.Config) (*MirrorSourceComponent, error) { - return &MirrorSourceComponent{ +type MirrorSourceComponent interface { + Create(ctx context.Context, req types.CreateMirrorSourceReq) (*database.MirrorSource, error) + Get(ctx context.Context, id int64, currentUser string) (*database.MirrorSource, error) + Index(ctx context.Context, currentUser string) ([]database.MirrorSource, error) + Update(ctx context.Context, req types.UpdateMirrorSourceReq) (*database.MirrorSource, error) + Delete(ctx context.Context, id int64, currentUser string) error +} + +func NewMirrorSourceComponent(config *config.Config) (MirrorSourceComponent, error) { + return &mirrorSourceComponentImpl{ msStore: database.NewMirrorSourceStore(), userStore: database.NewUserStore(), }, nil } -func (c *MirrorSourceComponent) Create(ctx context.Context, req types.CreateMirrorSourceReq) (*database.MirrorSource, error) { +func (c *mirrorSourceComponentImpl) Create(ctx context.Context, req types.CreateMirrorSourceReq) (*database.MirrorSource, error) { var ms database.MirrorSource user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -40,7 +48,7 @@ func (c *MirrorSourceComponent) Create(ctx context.Context, req types.CreateMirr return res, nil } -func (c *MirrorSourceComponent) Get(ctx context.Context, id int64, currentUser string) (*database.MirrorSource, error) { +func (c *mirrorSourceComponentImpl) Get(ctx context.Context, id int64, currentUser string) (*database.MirrorSource, error) { user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return nil, errors.New("user does not exist") @@ -55,7 +63,7 @@ func (c *MirrorSourceComponent) Get(ctx context.Context, id int64, currentUser s return ms, nil } -func (c *MirrorSourceComponent) Index(ctx context.Context, currentUser string) ([]database.MirrorSource, error) { +func (c *mirrorSourceComponentImpl) Index(ctx context.Context, currentUser string) ([]database.MirrorSource, error) { user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return nil, errors.New("user does not exist") @@ -69,7 +77,7 @@ func (c *MirrorSourceComponent) Index(ctx context.Context, currentUser string) ( } return ms, nil } -func (c *MirrorSourceComponent) Update(ctx context.Context, req types.UpdateMirrorSourceReq) (*database.MirrorSource, error) { +func (c *mirrorSourceComponentImpl) Update(ctx context.Context, req types.UpdateMirrorSourceReq) (*database.MirrorSource, error) { var ms database.MirrorSource user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -88,7 +96,7 @@ func (c *MirrorSourceComponent) Update(ctx context.Context, req types.UpdateMirr return &ms, nil } -func (c *MirrorSourceComponent) Delete(ctx context.Context, id int64, currentUser string) error { +func (c *mirrorSourceComponentImpl) Delete(ctx context.Context, id int64, currentUser string) error { user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return errors.New("user does not exist") diff --git a/component/model.go b/component/model.go index ffb4eef4..536e0306 100644 --- a/component/model.go +++ b/component/model.go @@ -61,10 +61,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text ` const LFSPrefix = "version https://git-lfs.github.com/spec/v1" -func NewModelComponent(config *config.Config) (*ModelComponent, error) { - c := &ModelComponent{} +type ModelComponent interface { + Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Model, int, error) + Create(ctx context.Context, req *types.CreateModelReq) (*types.Model, error) + Update(ctx context.Context, req *types.UpdateModelReq) (*types.Model, error) + Delete(ctx context.Context, namespace, name, currentUser string) error + Show(ctx context.Context, namespace, name, currentUser string) (*types.Model, error) + GetServerless(ctx context.Context, namespace, name, currentUser string) (*types.DeployRepo, error) + SDKModelInfo(ctx context.Context, namespace, name, ref, currentUser string) (*types.SDKModelInfo, error) + Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) + SetRelationDatasets(ctx context.Context, req types.RelationDatasets) error + AddRelationDataset(ctx context.Context, req types.RelationDataset) error + DelRelationDataset(ctx context.Context, req types.RelationDataset) error + Predict(ctx context.Context, req *types.ModelPredictReq) (*types.ModelPredictResp, error) + // create model deploy as inference/serverless + Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) + ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per, page int, id int64, deployType int) ([]types.Model, int, error) + ListAllByRuntimeFramework(ctx context.Context, currentUser string) ([]database.RuntimeFramework, error) + SetRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) + DeleteRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) + ListModelsOfRuntimeFrameworks(ctx context.Context, currentUser, search, sort string, per, page int, deployType int) ([]types.Model, int, error) + OrgModels(ctx context.Context, req *types.OrgModelsReq) ([]types.Model, int, error) +} + +func NewModelComponent(config *config.Config) (ModelComponent, error) { + c := &modelComponentImpl{} var err error - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, err } @@ -88,22 +111,22 @@ func NewModelComponent(config *config.Config) (*ModelComponent, error) { return c, nil } -type ModelComponent struct { - *RepoComponent - spaceComonent *SpaceComponent - ms *database.ModelStore - rs *database.RepoStore - SS *database.SpaceResourceStore +type modelComponentImpl struct { + *repoComponentImpl + spaceComonent SpaceComponent + ms database.ModelStore + rs database.RepoStore + SS database.SpaceResourceStore infer inference.Client - us *database.UserStore + us database.UserStore deployer deploy.Deployer - ac *AccountingComponent - ts *database.TagStore - rac *RuntimeArchitectureComponent - ds *database.DatasetStore + ac AccountingComponent + ts database.TagStore + rac RuntimeArchitectureComponent + ds database.DatasetStore } -func (c *ModelComponent) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Model, int, error) { +func (c *modelComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Model, int, error) { var ( err error resModels []types.Model @@ -169,7 +192,7 @@ func (c *ModelComponent) Index(ctx context.Context, filter *types.RepoFilter, pe return resModels, total, nil } -func (c *ModelComponent) Create(ctx context.Context, req *types.CreateModelReq) (*types.Model, error) { +func (c *modelComponentImpl) Create(ctx context.Context, req *types.CreateModelReq) (*types.Model, error) { var ( nickname string tags []types.RepoTag @@ -292,7 +315,7 @@ func buildCreateFileReq(p *types.CreateFileParams, repoType types.RepositoryType } } -func (c *ModelComponent) Update(ctx context.Context, req *types.UpdateModelReq) (*types.Model, error) { +func (c *modelComponentImpl) Update(ctx context.Context, req *types.UpdateModelReq) (*types.Model, error) { req.RepoType = types.ModelRepo dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { @@ -329,7 +352,7 @@ func (c *ModelComponent) Update(ctx context.Context, req *types.UpdateModelReq) return resModel, nil } -func (c *ModelComponent) Delete(ctx context.Context, namespace, name, currentUser string) error { +func (c *modelComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { model, err := c.ms.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find model, error: %w", err) @@ -353,7 +376,7 @@ func (c *ModelComponent) Delete(ctx context.Context, namespace, name, currentUse return nil } -func (c *ModelComponent) Show(ctx context.Context, namespace, name, currentUser string) (*types.Model, error) { +func (c *modelComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Model, error) { var tags []types.RepoTag model, err := c.ms.FindByPath(ctx, namespace, name) if err != nil { @@ -435,7 +458,7 @@ func (c *ModelComponent) Show(ctx context.Context, namespace, name, currentUser return resModel, nil } -func (c *ModelComponent) GetServerless(ctx context.Context, namespace, name, currentUser string) (*types.DeployRepo, error) { +func (c *modelComponentImpl) GetServerless(ctx context.Context, namespace, name, currentUser string) (*types.DeployRepo, error) { model, err := c.ms.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find model, error: %w", err) @@ -474,7 +497,7 @@ func (c *ModelComponent) GetServerless(ctx context.Context, namespace, name, cur return &resDeploy, nil } -func (c *ModelComponent) SDKModelInfo(ctx context.Context, namespace, name, ref, currentUser string) (*types.SDKModelInfo, error) { +func (c *modelComponentImpl) SDKModelInfo(ctx context.Context, namespace, name, ref, currentUser string) (*types.SDKModelInfo, error) { model, err := c.ms.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find model, error: %w", err) @@ -561,7 +584,7 @@ func (c *ModelComponent) SDKModelInfo(ctx context.Context, namespace, name, ref, return resModel, nil } -func (c *ModelComponent) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { +func (c *modelComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { model, err := c.ms.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find model, error: %w", err) @@ -575,7 +598,7 @@ func (c *ModelComponent) Relations(ctx context.Context, namespace, name, current return c.getRelations(ctx, model.RepositoryID, currentUser) } -func (c *ModelComponent) SetRelationDatasets(ctx context.Context, req types.RelationDatasets) error { +func (c *modelComponentImpl) SetRelationDatasets(ctx context.Context, req types.RelationDatasets) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -627,7 +650,7 @@ func (c *ModelComponent) SetRelationDatasets(ctx context.Context, req types.Rela return nil } -func (c *ModelComponent) AddRelationDataset(ctx context.Context, req types.RelationDataset) error { +func (c *modelComponentImpl) AddRelationDataset(ctx context.Context, req types.RelationDataset) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -684,7 +707,7 @@ func (c *ModelComponent) AddRelationDataset(ctx context.Context, req types.Relat return nil } -func (c *ModelComponent) DelRelationDataset(ctx context.Context, req types.RelationDataset) error { +func (c *modelComponentImpl) DelRelationDataset(ctx context.Context, req types.RelationDataset) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -746,7 +769,7 @@ func (c *ModelComponent) DelRelationDataset(ctx context.Context, req types.Relat return nil } -func (c *ModelComponent) getRelations(ctx context.Context, fromRepoID int64, currentUser string) (*types.Relations, error) { +func (c *modelComponentImpl) getRelations(ctx context.Context, fromRepoID int64, currentUser string) (*types.Relations, error) { res, err := c.relatedRepos(ctx, fromRepoID, currentUser) if err != nil { return nil, err @@ -823,7 +846,7 @@ func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryT return filePaths, nil } -func (c *ModelComponent) Predict(ctx context.Context, req *types.ModelPredictReq) (*types.ModelPredictResp, error) { +func (c *modelComponentImpl) Predict(ctx context.Context, req *types.ModelPredictReq) (*types.ModelPredictResp, error) { mid := inference.ModelID{ Owner: req.Namespace, Name: req.Name, @@ -843,7 +866,7 @@ func (c *ModelComponent) Predict(ctx context.Context, req *types.ModelPredictReq } // create model deploy as inference/serverless -func (c *ModelComponent) Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) { +func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) { m, err := c.ms.FindByPath(ctx, deployReq.Namespace, deployReq.Name) if err != nil { return -1, fmt.Errorf("cannot find model, %w", err) @@ -935,7 +958,7 @@ func (c *ModelComponent) Deploy(ctx context.Context, deployReq types.DeployActRe }) } -func (c *ModelComponent) ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per, page int, id int64, deployType int) ([]types.Model, int, error) { +func (c *modelComponentImpl) ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per, page int, id int64, deployType int) ([]types.Model, int, error) { var ( user database.User err error @@ -981,7 +1004,7 @@ func (c *ModelComponent) ListModelsByRuntimeFrameworkID(ctx context.Context, cur return resModels, total, nil } -func (c *ModelComponent) ListAllByRuntimeFramework(ctx context.Context, currentUser string) ([]database.RuntimeFramework, error) { +func (c *modelComponentImpl) ListAllByRuntimeFramework(ctx context.Context, currentUser string) ([]database.RuntimeFramework, error) { runtimes, err := c.runFrame.ListAll(ctx) if err != nil { newError := fmt.Errorf("failed to get public model repos,error:%w", err) @@ -991,7 +1014,7 @@ func (c *ModelComponent) ListAllByRuntimeFramework(ctx context.Context, currentU return runtimes, nil } -func (c *ModelComponent) SetRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { +func (c *modelComponentImpl) SetRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { runtimeRepos, err := c.rtfm.FindByID(ctx, id) if err != nil { return nil, err @@ -1034,7 +1057,7 @@ func (c *ModelComponent) SetRuntimeFrameworkModes(ctx context.Context, deployTyp return failedModels, nil } -func (c *ModelComponent) DeleteRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { +func (c *modelComponentImpl) DeleteRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { models, err := c.ms.ListByPath(ctx, paths) if err != nil { return nil, err @@ -1051,7 +1074,7 @@ func (c *ModelComponent) DeleteRuntimeFrameworkModes(ctx context.Context, deploy return failedModels, nil } -func (c *ModelComponent) ListModelsOfRuntimeFrameworks(ctx context.Context, currentUser, search, sort string, per, page int, deployType int) ([]types.Model, int, error) { +func (c *modelComponentImpl) ListModelsOfRuntimeFrameworks(ctx context.Context, currentUser, search, sort string, per, page int, deployType int) ([]types.Model, int, error) { var ( user database.User err error @@ -1101,7 +1124,7 @@ func (c *ModelComponent) ListModelsOfRuntimeFrameworks(ctx context.Context, curr return resModels, total, nil } -func (c *ModelComponent) OrgModels(ctx context.Context, req *types.OrgModelsReq) ([]types.Model, int, error) { +func (c *modelComponentImpl) OrgModels(ctx context.Context, req *types.OrgModelsReq) ([]types.Model, int, error) { var resModels []types.Model var err error r := membership.RoleUnknown diff --git a/component/multi_sync.go b/component/multi_sync.go index f6574392..22c287d0 100644 --- a/component/multi_sync.go +++ b/component/multi_sync.go @@ -19,25 +19,30 @@ import ( "opencsg.com/csghub-server/common/utils/common" ) -type MultiSyncComponent struct { - s *database.MultiSyncStore - repo *database.RepoStore - model *database.ModelStore - dataset *database.DatasetStore - namespace *database.NamespaceStore - user *database.UserStore - versionStore *database.SyncVersionStore - tag *database.TagStore - file *database.FileStore +type multiSyncComponentImpl struct { + s database.MultiSyncStore + repo database.RepoStore + model database.ModelStore + dataset database.DatasetStore + namespace database.NamespaceStore + user database.UserStore + versionStore database.SyncVersionStore + tag database.TagStore + file database.FileStore git gitserver.GitServer } -func NewMultiSyncComponent(config *config.Config) (*MultiSyncComponent, error) { +type MultiSyncComponent interface { + More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) + SyncAsClient(ctx context.Context, sc multisync.Client) error +} + +func NewMultiSyncComponent(config *config.Config) (MultiSyncComponent, error) { git, err := git.NewGitServer(config) if err != nil { return nil, fmt.Errorf("failed to create git server: %w", err) } - return &MultiSyncComponent{ + return &multiSyncComponentImpl{ s: database.NewMultiSyncStore(), repo: database.NewRepoStore(), model: database.NewModelStore(), @@ -51,7 +56,7 @@ func NewMultiSyncComponent(config *config.Config) (*MultiSyncComponent, error) { }, nil } -func (c *MultiSyncComponent) More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) { +func (c *multiSyncComponentImpl) More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) { dbVersions, err := c.s.GetAfter(ctx, cur, limit) if err != nil { return nil, fmt.Errorf("failed to get sync versions after %d from db: %w", cur, err) @@ -70,7 +75,7 @@ func (c *MultiSyncComponent) More(ctx context.Context, cur int64, limit int64) ( return versions, nil } -func (c *MultiSyncComponent) SyncAsClient(ctx context.Context, sc multisync.Client) error { +func (c *multiSyncComponentImpl) SyncAsClient(ctx context.Context, sc multisync.Client) error { var currentVersion int64 v, err := c.s.GetLatest(ctx) if err != nil { @@ -166,7 +171,7 @@ func (c *MultiSyncComponent) SyncAsClient(ctx context.Context, sc multisync.Clie return nil } -func (c *MultiSyncComponent) createLocalDataset(ctx context.Context, m *types.Dataset, s types.SyncVersion, sc multisync.Client) error { +func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *types.Dataset, s types.SyncVersion, sc multisync.Client) error { namespace, name, _ := strings.Cut(m.Path, "/") //add prefix to avoid namespace conflict namespace = common.AddPrefixBySourceID(s.SourceID, namespace) @@ -292,7 +297,7 @@ func (c *MultiSyncComponent) createLocalDataset(ctx context.Context, m *types.Da return nil } -func (c *MultiSyncComponent) createLocalModel(ctx context.Context, m *types.Model, s types.SyncVersion, sc multisync.Client) error { +func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types.Model, s types.SyncVersion, sc multisync.Client) error { namespace, name, _ := strings.Cut(m.Path, "/") //add prefix to avoid namespace conflict namespace = common.AddPrefixBySourceID(s.SourceID, namespace) @@ -417,7 +422,7 @@ func (c *MultiSyncComponent) createLocalModel(ctx context.Context, m *types.Mode return nil } -func (c *MultiSyncComponent) createUser(ctx context.Context, req types.CreateUserRequest) (database.User, error) { +func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.CreateUserRequest) (database.User, error) { gsUserReq := gitserver.CreateUserRequest{ Nickname: req.Name, Username: req.Username, @@ -449,11 +454,11 @@ func (c *MultiSyncComponent) createUser(ctx context.Context, req types.CreateUse return *user, err } -func (c *MultiSyncComponent) getUser(ctx context.Context, userName string) (database.User, error) { +func (c *multiSyncComponentImpl) getUser(ctx context.Context, userName string) (database.User, error) { return c.user.FindByUsername(ctx, userName) } -func (c *MultiSyncComponent) createLocalSyncVersion(ctx context.Context, v types.SyncVersion) error { +func (c *multiSyncComponentImpl) createLocalSyncVersion(ctx context.Context, v types.SyncVersion) error { syncVersion := database.SyncVersion{ Version: v.Version, SourceID: v.SourceID, diff --git a/component/prompt.go b/component/prompt.go index 3d2ee262..7c77be5b 100644 --- a/component/prompt.go +++ b/component/prompt.go @@ -29,20 +29,48 @@ var ( AssistantRole string = "assistant" ) -type PromptComponent struct { +type promptComponentImpl struct { gs gitserver.GitServer - user *database.UserStore - pc *database.PromptConversationStore - pp *database.PromptPrefixStore - lc *database.LLMConfigStore - pt *database.PromptStore + user database.UserStore + pc database.PromptConversationStore + pp database.PromptPrefixStore + lc database.LLMConfigStore + pt database.PromptStore llm *llm.Client - *RepoComponent + *repoComponentImpl maxPromptFS int64 } -func NewPromptComponent(cfg *config.Config) (*PromptComponent, error) { - r, err := NewRepoComponent(cfg) +type PromptComponent interface { + ListPrompt(ctx context.Context, req types.PromptReq) ([]PromptOutput, error) + GetPrompt(ctx context.Context, req types.PromptReq) (*PromptOutput, error) + ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*PromptOutput, error) + CreatePrompt(ctx context.Context, req types.PromptReq, body *CreatePromptReq) (*Prompt, error) + UpdatePrompt(ctx context.Context, req types.PromptReq, body *UpdatePromptReq) (*Prompt, error) + DeletePrompt(ctx context.Context, req types.PromptReq) error + NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) + ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) + GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) + SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) + SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) + RemoveConversation(ctx context.Context, req types.ConversationReq) error + UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) + LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error + HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error + SetRelationModels(ctx context.Context, req types.RelationModels) error + AddRelationModel(ctx context.Context, req types.RelationModel) error + DelRelationModel(ctx context.Context, req types.RelationModel) error + CreatePromptRepo(ctx context.Context, req *types.CreatePromptRepoReq) (*types.PromptRes, error) + IndexPromptRepo(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.PromptRes, int, error) + UpdatePromptRepo(ctx context.Context, req *types.UpdatePromptRepoReq) (*types.PromptRes, error) + RemoveRepo(ctx context.Context, namespace, name, currentUser string) error + Show(ctx context.Context, namespace, name, currentUser string) (*types.PromptRes, error) + Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) + OrgPrompts(ctx context.Context, req *types.OrgPromptsReq) ([]types.PromptRes, int, error) +} + +func NewPromptComponent(cfg *config.Config) (PromptComponent, error) { + r, err := NewRepoComponentImpl(cfg) if err != nil { return nil, fmt.Errorf("failed to create repo component,cause:%w", err) } @@ -50,20 +78,20 @@ func NewPromptComponent(cfg *config.Config) (*PromptComponent, error) { if err != nil { return nil, fmt.Errorf("failed to create git server,cause:%w", err) } - return &PromptComponent{ - gs: gs, - user: database.NewUserStore(), - pc: database.NewPromptConversationStore(), - pp: database.NewPromptPrefixStore(), - lc: database.NewLLMConfigStore(), - pt: database.NewPromptStore(), - llm: llm.NewClient(), - RepoComponent: r, - maxPromptFS: cfg.Dataset.PromptMaxJsonlFileSize, + return &promptComponentImpl{ + gs: gs, + user: database.NewUserStore(), + pc: database.NewPromptConversationStore(), + pp: database.NewPromptPrefixStore(), + lc: database.NewLLMConfigStore(), + pt: database.NewPromptStore(), + llm: llm.NewClient(), + repoComponentImpl: r, + maxPromptFS: cfg.Dataset.PromptMaxJsonlFileSize, }, nil } -func (c *PromptComponent) ListPrompt(ctx context.Context, req types.PromptReq) ([]PromptOutput, error) { +func (c *promptComponentImpl) ListPrompt(ctx context.Context, req types.PromptReq) ([]PromptOutput, error) { r, err := c.repo.FindByPath(ctx, types.PromptRepo, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) @@ -135,7 +163,7 @@ func (c *PromptComponent) ListPrompt(ctx context.Context, req types.PromptReq) ( } -func (c *PromptComponent) GetPrompt(ctx context.Context, req types.PromptReq) (*PromptOutput, error) { +func (c *promptComponentImpl) GetPrompt(ctx context.Context, req types.PromptReq) (*PromptOutput, error) { r, err := c.repo.FindByPath(ctx, types.PromptRepo, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find prompt repo, error: %w", err) @@ -165,7 +193,7 @@ func (c *PromptComponent) GetPrompt(ctx context.Context, req types.PromptReq) (* return p, nil } -func (c *PromptComponent) ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*PromptOutput, error) { +func (c *promptComponentImpl) ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*PromptOutput, error) { f, err := c.gs.GetRepoFileContents(ctx, req) if err != nil { return nil, fmt.Errorf("failed to get %s contents, cause:%w", req.Path, err) @@ -189,7 +217,7 @@ func (c *PromptComponent) ParseJsonFile(ctx context.Context, req gitserver.GetRe return &po, nil } -func (c *PromptComponent) CreatePrompt(ctx context.Context, req types.PromptReq, body *CreatePromptReq) (*Prompt, error) { +func (c *promptComponentImpl) CreatePrompt(ctx context.Context, req types.PromptReq, body *CreatePromptReq) (*Prompt, error) { u, err := c.checkPromptRepoPermission(ctx, req) if err != nil { return nil, fmt.Errorf("user do not allowed create prompt") @@ -225,7 +253,7 @@ func (c *PromptComponent) CreatePrompt(ctx context.Context, req types.PromptReq, return &body.Prompt, nil } -func (c *PromptComponent) UpdatePrompt(ctx context.Context, req types.PromptReq, body *UpdatePromptReq) (*Prompt, error) { +func (c *promptComponentImpl) UpdatePrompt(ctx context.Context, req types.PromptReq, body *UpdatePromptReq) (*Prompt, error) { u, err := c.checkPromptRepoPermission(ctx, req) if err != nil { return nil, fmt.Errorf("user do not allowed update prompt") @@ -262,7 +290,7 @@ func (c *PromptComponent) UpdatePrompt(ctx context.Context, req types.PromptReq, return &body.Prompt, nil } -func (c *PromptComponent) DeletePrompt(ctx context.Context, req types.PromptReq) error { +func (c *promptComponentImpl) DeletePrompt(ctx context.Context, req types.PromptReq) error { u, err := c.checkPromptRepoPermission(ctx, req) if err != nil { return fmt.Errorf("user do not allowed delete prompt") @@ -292,7 +320,7 @@ func (c *PromptComponent) DeletePrompt(ctx context.Context, req types.PromptReq) return nil } -func (c *PromptComponent) checkFileExist(ctx context.Context, req types.PromptReq) (bool, error) { +func (c *promptComponentImpl) checkFileExist(ctx context.Context, req types.PromptReq) (bool, error) { getFileRawReq := gitserver.GetRepoInfoByPathReq{ Namespace: req.Namespace, Name: req.Name, @@ -307,7 +335,7 @@ func (c *PromptComponent) checkFileExist(ctx context.Context, req types.PromptRe return true, nil } -func (c *PromptComponent) checkPromptRepoPermission(ctx context.Context, req types.PromptReq) (*database.User, error) { +func (c *promptComponentImpl) checkPromptRepoPermission(ctx context.Context, req types.PromptReq) (*database.User, error) { namespace, err := c.namespace.FindByPath(ctx, req.Namespace) if err != nil { return nil, errors.New("namespace does not exist") @@ -336,7 +364,7 @@ func (c *PromptComponent) checkPromptRepoPermission(ctx context.Context, req typ return &user, nil } -func (c *PromptComponent) NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { +func (c *promptComponentImpl) NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, errors.New("user does not exist") @@ -355,7 +383,7 @@ func (c *PromptComponent) NewConversation(ctx context.Context, req types.Convers return &conversation, nil } -func (c *PromptComponent) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) { +func (c *promptComponentImpl) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) { user, err := c.user.FindByUsername(ctx, currentUser) if err != nil { return nil, errors.New("user does not exist") @@ -367,7 +395,7 @@ func (c *PromptComponent) ListConversationsByUserID(ctx context.Context, current return conversations, nil } -func (c *PromptComponent) GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) { +func (c *promptComponentImpl) GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, errors.New("user does not exist") @@ -379,7 +407,7 @@ func (c *PromptComponent) GetConversation(ctx context.Context, req types.Convers return conversation, nil } -func (c *PromptComponent) SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) { +func (c *promptComponentImpl) SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, errors.New("user does not exist") @@ -445,7 +473,7 @@ func (c *PromptComponent) SubmitMessage(ctx context.Context, req types.Conversat return ch, nil } -func (c *PromptComponent) SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) { +func (c *promptComponentImpl) SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) { respMsg := database.PromptConversationMessage{ ConversationID: req.Uuid, Role: AssistantRole, @@ -458,7 +486,7 @@ func (c *PromptComponent) SaveGeneratedText(ctx context.Context, req types.Conve return msg, nil } -func (c *PromptComponent) RemoveConversation(ctx context.Context, req types.ConversationReq) error { +func (c *promptComponentImpl) RemoveConversation(ctx context.Context, req types.ConversationReq) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return errors.New("user does not exist") @@ -471,7 +499,7 @@ func (c *PromptComponent) RemoveConversation(ctx context.Context, req types.Conv return nil } -func (c *PromptComponent) UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { +func (c *promptComponentImpl) UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, errors.New("user does not exist") @@ -493,7 +521,7 @@ func (c *PromptComponent) UpdateConversation(ctx context.Context, req types.Conv return resp, nil } -func (c *PromptComponent) LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { +func (c *promptComponentImpl) LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return errors.New("user does not exist") @@ -509,7 +537,7 @@ func (c *PromptComponent) LikeConversationMessage(ctx context.Context, req types return nil } -func (c *PromptComponent) HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { +func (c *promptComponentImpl) HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return errors.New("user does not exist") @@ -530,7 +558,7 @@ func isChinese(s string) bool { return re.MatchString(s) } -func (c *PromptComponent) SetRelationModels(ctx context.Context, req types.RelationModels) error { +func (c *promptComponentImpl) SetRelationModels(ctx context.Context, req types.RelationModels) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -628,7 +656,7 @@ func GetOutputForReadme(metaMap map[string]any, splits []string) (string, error) return output, nil } -func (c *PromptComponent) AddRelationModel(ctx context.Context, req types.RelationModel) error { +func (c *promptComponentImpl) AddRelationModel(ctx context.Context, req types.RelationModel) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -685,7 +713,7 @@ func (c *PromptComponent) AddRelationModel(ctx context.Context, req types.Relati return nil } -func (c *PromptComponent) DelRelationModel(ctx context.Context, req types.RelationModel) error { +func (c *promptComponentImpl) DelRelationModel(ctx context.Context, req types.RelationModel) error { user, err := c.user.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("user does not exist, %w", err) @@ -747,7 +775,7 @@ func (c *PromptComponent) DelRelationModel(ctx context.Context, req types.Relati return nil } -func (c *PromptComponent) CreatePromptRepo(ctx context.Context, req *types.CreatePromptRepoReq) (*types.PromptRes, error) { +func (c *promptComponentImpl) CreatePromptRepo(ctx context.Context, req *types.CreatePromptRepoReq) (*types.PromptRes, error) { var ( nickname string tags []types.RepoTag @@ -874,7 +902,7 @@ func (c *PromptComponent) CreatePromptRepo(ctx context.Context, req *types.Creat return resPrompt, nil } -func (c *PromptComponent) IndexPromptRepo(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.PromptRes, int, error) { +func (c *promptComponentImpl) IndexPromptRepo(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.PromptRes, int, error) { var ( err error resPrompts []types.PromptRes @@ -948,7 +976,7 @@ func (c *PromptComponent) IndexPromptRepo(ctx context.Context, filter *types.Rep return resPrompts, total, nil } -func (c *PromptComponent) UpdatePromptRepo(ctx context.Context, req *types.UpdatePromptRepoReq) (*types.PromptRes, error) { +func (c *promptComponentImpl) UpdatePromptRepo(ctx context.Context, req *types.UpdatePromptRepoReq) (*types.PromptRes, error) { req.RepoType = types.PromptRepo dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { @@ -983,7 +1011,7 @@ func (c *PromptComponent) UpdatePromptRepo(ctx context.Context, req *types.Updat return resPrompt, nil } -func (c *PromptComponent) RemoveRepo(ctx context.Context, namespace, name, currentUser string) error { +func (c *promptComponentImpl) RemoveRepo(ctx context.Context, namespace, name, currentUser string) error { prompt, err := c.pt.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find prompt, error: %w", err) @@ -1007,7 +1035,7 @@ func (c *PromptComponent) RemoveRepo(ctx context.Context, namespace, name, curre return nil } -func (c *PromptComponent) Show(ctx context.Context, namespace, name, currentUser string) (*types.PromptRes, error) { +func (c *promptComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.PromptRes, error) { var tags []types.RepoTag prompt, err := c.pt.FindByPath(ctx, namespace, name) if err != nil { @@ -1078,7 +1106,7 @@ func (c *PromptComponent) Show(ctx context.Context, namespace, name, currentUser return resPrompt, nil } -func (c *PromptComponent) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { +func (c *promptComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { prompt, err := c.pt.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find prompt repo, error: %w", err) @@ -1092,7 +1120,7 @@ func (c *PromptComponent) Relations(ctx context.Context, namespace, name, curren return c.getRelations(ctx, prompt.RepositoryID, currentUser) } -func (c *PromptComponent) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { +func (c *promptComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { res, err := c.relatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err @@ -1172,7 +1200,7 @@ func (req *Prompt) GetSensitiveFields() []types.SensitiveField { return fields } -func (c *PromptComponent) OrgPrompts(ctx context.Context, req *types.OrgPromptsReq) ([]types.PromptRes, int, error) { +func (c *promptComponentImpl) OrgPrompts(ctx context.Context, req *types.OrgPromptsReq) ([]types.PromptRes, int, error) { var resPrompts []types.PromptRes var err error r := membership.RoleUnknown diff --git a/component/recom.go b/component/recom.go index 72b4f0fc..6c344abf 100644 --- a/component/recom.go +++ b/component/recom.go @@ -13,26 +13,33 @@ import ( "opencsg.com/csghub-server/common/config" ) -type RecomComponent struct { - rs *database.RecomStore - repos *database.RepoStore +type recomComponentImpl struct { + rs database.RecomStore + repos database.RepoStore gs gitserver.GitServer } -func NewRecomComponent(cfg *config.Config) (*RecomComponent, error) { +type RecomComponent interface { + SetOpWeight(ctx context.Context, repoID, weight int64) error + // loop through repositories and calculate the recom score of the repository + CalculateRecomScore(ctx context.Context) + CalcTotalScore(ctx context.Context, repo *database.Repository, weights map[string]string) float64 +} + +func NewRecomComponent(cfg *config.Config) (RecomComponent, error) { gs, err := git.NewGitServer(cfg) if err != nil { return nil, fmt.Errorf("failed to init git server,%w", err) } - return &RecomComponent{ + return &recomComponentImpl{ rs: database.NewRecomStore(), repos: database.NewRepoStore(), gs: gs, }, nil } -func (rc *RecomComponent) SetOpWeight(ctx context.Context, repoID, weight int64) error { +func (rc *recomComponentImpl) SetOpWeight(ctx context.Context, repoID, weight int64) error { _, err := rc.repos.FindById(ctx, repoID) if err != nil { return fmt.Errorf("failed to find repository with id %d, err:%w", repoID, err) @@ -41,7 +48,7 @@ func (rc *RecomComponent) SetOpWeight(ctx context.Context, repoID, weight int64) } // loop through repositories and calculate the recom score of the repository -func (rc *RecomComponent) CalculateRecomScore(ctx context.Context) { +func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context) { weights, err := rc.loadWeights() if err != nil { slog.Error("Error loading weights", err) @@ -54,7 +61,7 @@ func (rc *RecomComponent) CalculateRecomScore(ctx context.Context) { } for _, repo := range repos { repoID := repo.ID - score := rc.calcTotalScore(ctx, repo, weights) + score := rc.CalcTotalScore(ctx, repo, weights) err := rc.rs.UpsertScore(ctx, repoID, score) if err != nil { slog.Error("Error updating recom score", slog.Int64("repo_id", repoID), slog.Float64("score", score), @@ -63,7 +70,7 @@ func (rc *RecomComponent) CalculateRecomScore(ctx context.Context) { } } -func (rc *RecomComponent) calcTotalScore(ctx context.Context, repo *database.Repository, weights map[string]string) float64 { +func (rc *recomComponentImpl) CalcTotalScore(ctx context.Context, repo *database.Repository, weights map[string]string) float64 { score := float64(0) if freshness, ok := weights["freshness"]; ok { @@ -84,7 +91,7 @@ func (rc *RecomComponent) calcTotalScore(ctx context.Context, repo *database.Rep return score } -func (rc *RecomComponent) calcFreshnessScore(createdAt time.Time, weightExp string) float64 { +func (rc *recomComponentImpl) calcFreshnessScore(createdAt time.Time, weightExp string) float64 { // TODO:cache compiled script hours := time.Since(createdAt).Hours() scriptFreshness := tengo.NewScript([]byte(weightExp)) @@ -103,7 +110,7 @@ func (rc *RecomComponent) calcFreshnessScore(createdAt time.Time, weightExp stri return sc.Get("score").Float() } -func (rc *RecomComponent) calcDownloadsScore(downloads int64, weightExp string) float64 { +func (rc *recomComponentImpl) calcDownloadsScore(downloads int64, weightExp string) float64 { // TODO:cache compiled script scriptFreshness := tengo.NewScript([]byte(weightExp)) scriptFreshness.Add("score", 0.0) @@ -121,7 +128,7 @@ func (rc *RecomComponent) calcDownloadsScore(downloads int64, weightExp string) return sc.Get("score").Float() } -func (rc *RecomComponent) calcQualityScore(ctx context.Context, repo *database.Repository) (float64, error) { +func (rc *recomComponentImpl) calcQualityScore(ctx context.Context, repo *database.Repository) (float64, error) { score := 0.0 // get file counts from git server namespace, name := repo.NamespaceAndName() @@ -148,7 +155,7 @@ func (rc *RecomComponent) calcQualityScore(ctx context.Context, repo *database.R return score, nil } -func (rc *RecomComponent) loadWeights() (map[string]string, error) { +func (rc *recomComponentImpl) loadWeights() (map[string]string, error) { ctx := context.Background() items, err := rc.rs.LoadWeights(ctx) if err != nil { @@ -162,7 +169,7 @@ func (rc *RecomComponent) loadWeights() (map[string]string, error) { return weights, nil } -func (rc *RecomComponent) loadOpWeights() (map[int64]int, error) { +func (rc *recomComponentImpl) loadOpWeights() (map[int64]int, error) { ctx := context.Background() items, err := rc.rs.LoadOpWeights(ctx) if err != nil { diff --git a/component/recom_test.go b/component/recom_test.go index 38f148b5..dcd2cc0d 100644 --- a/component/recom_test.go +++ b/component/recom_test.go @@ -23,7 +23,7 @@ func TestCalculateRecomScore(t *testing.T) { weights1 := map[string]string{ "freshness": expFreshness, } - score1 := rc.calcTotalScore(ctx, repo1, weights1) + score1 := rc.CalcTotalScore(ctx, repo1, weights1) if score1 > 100 || score1 < 98 { t.Errorf("Expected score1 should in range [98,100], got: %f", score1) } @@ -34,7 +34,7 @@ func TestCalculateRecomScore(t *testing.T) { weights2 := map[string]string{ "freshness": expFreshness, } - score2 := rc.calcTotalScore(ctx, repo2, weights2) + score2 := rc.CalcTotalScore(ctx, repo2, weights2) if score2 > 98.0 || score2 < 60.0 { t.Errorf("Expected score1 should in range [60,98), got: %f", score2) } @@ -45,7 +45,7 @@ func TestCalculateRecomScore(t *testing.T) { weights3 := map[string]string{ "freshness": expFreshness, } - score3 := rc.calcTotalScore(ctx, repo3, weights3) + score3 := rc.CalcTotalScore(ctx, repo3, weights3) if score3 < 0 || score3 > 60 { t.Errorf("Expected score1 should in range [0,60), got: %f", score2) } diff --git a/component/repo.go b/component/repo.go index d07a99a3..fdf71fa0 100644 --- a/component/repo.go +++ b/component/repo.go @@ -41,44 +41,116 @@ const ( GitAttributesFileName = ".gitattributes" ) -type RepoComponent struct { - tc *TagComponent - user *database.UserStore - org *database.OrgStore - namespace *database.NamespaceStore - repo *database.RepoStore - repoFile *database.RepoFileStore - rel *database.RepoRelationsStore - mirror *database.MirrorStore +type repoComponentImpl struct { + tc TagComponent + user database.UserStore + org database.OrgStore + namespace database.NamespaceStore + repo database.RepoStore + repoFile database.RepoFileStore + rel database.RepoRelationsStore + mirror database.MirrorStore git gitserver.GitServer s3Client *s3.Client userSvcClient rpc.UserSvcClient lfsBucket string - uls *database.UserLikesStore + uls database.UserLikesStore mirrorServer mirrorserver.MirrorServer - runFrame *database.RuntimeFrameworksStore - deploy *database.DeployTaskStore + runFrame database.RuntimeFrameworksStore + deploy database.DeployTaskStore deployer deploy.Deployer publicRootDomain string serverBaseUrl string - cluster *database.ClusterInfoStore - mirrorSource *database.MirrorSourceStore - tokenStore *database.AccessTokenStore - rtfm *database.RuntimeFrameworksStore + cluster database.ClusterInfoStore + mirrorSource database.MirrorSourceStore + tokenStore database.AccessTokenStore + rtfm database.RuntimeFrameworksStore rrtfms *database.RepositoriesRuntimeFrameworkStore - syncVersion *database.SyncVersionStore - syncClientSetting *database.SyncClientSettingStore - file *database.FileStore + syncVersion database.SyncVersionStore + syncClientSetting database.SyncClientSettingStore + file database.FileStore config *config.Config - ac *AccountingComponent - srs *database.SpaceResourceStore - lfsMetaObjectStore *database.LfsMetaObjectStore - recom *database.RecomStore + ac AccountingComponent + srs database.SpaceResourceStore + lfsMetaObjectStore database.LfsMetaObjectStore + recom database.RecomStore mq *queue.PriorityQueue } -func NewRepoComponent(config *config.Config) (*RepoComponent, error) { - c := &RepoComponent{} +type RepoComponent interface { + CreateRepo(ctx context.Context, req types.CreateRepoReq) (*gitserver.CreateRepoResp, *database.Repository, error) + UpdateRepo(ctx context.Context, req types.UpdateRepoReq) (*database.Repository, error) + DeleteRepo(ctx context.Context, req types.DeleteRepoReq) (*database.Repository, error) + // PublicToUser gets visible repos of the given user and user's orgs + PublicToUser(ctx context.Context, repoType types.RepositoryType, userName string, filter *types.RepoFilter, per, page int) (repos []*database.Repository, count int, err error) + CreateFile(ctx context.Context, req *types.CreateFileReq) (*types.CreateFileResp, error) + UpdateFile(ctx context.Context, req *types.UpdateFileReq) (*types.UpdateFileResp, error) + DeleteFile(ctx context.Context, req *types.DeleteFileReq) (*types.DeleteFileResp, error) + Commits(ctx context.Context, req *types.GetCommitsReq) ([]types.Commit, *types.RepoPageOpts, error) + LastCommit(ctx context.Context, req *types.GetCommitsReq) (*types.Commit, error) + FileRaw(ctx context.Context, req *types.GetFileReq) (string, error) + DownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) + Branches(ctx context.Context, req *types.GetBranchesReq) ([]types.Branch, error) + Tags(ctx context.Context, req *types.GetTagsReq) ([]database.Tag, error) + UpdateTags(ctx context.Context, namespace, name string, repoType types.RepositoryType, category, currentUser string, tags []string) error + Tree(ctx context.Context, req *types.GetFileReq) ([]*types.File, error) + UploadFile(ctx context.Context, req *types.CreateFileReq) error + SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) + IsLfs(ctx context.Context, req *types.GetFileReq) (bool, int64, error) + HeadDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (*types.File, error) + SDKDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) + // UpdateDownloads increase clone download count for repo by given count + UpdateDownloads(ctx context.Context, req *types.UpdateDownloadsReq) error + // IncrDownloads increase the click download count for repo by 1 + IncrDownloads(ctx context.Context, repoType types.RepositoryType, namespace, name string) error + FileInfo(ctx context.Context, req *types.GetFileReq) (*types.File, error) + AllowReadAccessRepo(ctx context.Context, repo *database.Repository, username string) (bool, error) + AllowReadAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) + AllowWriteAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) + AllowAdminAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) + GetCommitWithDiff(ctx context.Context, req *types.GetCommitsReq) (*types.CommitResponse, error) + CreateMirror(ctx context.Context, req types.CreateMirrorReq) (*database.Mirror, error) + MirrorFromSaas(ctx context.Context, namespace, name, currentUser string, repoType types.RepositoryType) error + GetMirror(ctx context.Context, req types.GetMirrorReq) (*database.Mirror, error) + UpdateMirror(ctx context.Context, req types.UpdateMirrorReq) (*database.Mirror, error) + DeleteMirror(ctx context.Context, req types.DeleteMirrorReq) error + // get runtime framework list with type + ListRuntimeFrameworkWithType(ctx context.Context, deployType int) ([]types.RuntimeFramework, error) + // get runtime framework list + ListRuntimeFramework(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployType int) ([]types.RuntimeFramework, error) + CreateRuntimeFramework(ctx context.Context, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) + UpdateRuntimeFramework(ctx context.Context, id int64, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) + DeleteRuntimeFramework(ctx context.Context, id int64) error + ListDeploy(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) ([]types.DeployRepo, error) + DeleteDeploy(ctx context.Context, delReq types.DeployActReq) error + DeployDetail(ctx context.Context, detailReq types.DeployActReq) (*types.DeployRepo, error) + DeployInstanceLogs(ctx context.Context, logReq types.DeployActReq) (*deploy.MultiLogReader, error) + // check access repo permission by repo id + AllowAccessByRepoID(ctx context.Context, repoID int64, username string) (bool, error) + // check access endpoint for rproxy + AllowAccessEndpoint(ctx context.Context, currentUser string, deploy *database.Deploy) (bool, error) + // check access deploy permission + AllowAccessDeploy(ctx context.Context, req types.DeployActReq) (bool, error) + DeployStop(ctx context.Context, stopReq types.DeployActReq) error + AllowReadAccessByDeployID(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string, deployID int64) (bool, error) + DeployStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployID int64) (string, string, []types.Instance, error) + GetDeployBySvcName(ctx context.Context, svcName string) (*database.Deploy, error) + SyncMirror(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) error + DeployUpdate(ctx context.Context, updateReq types.DeployActReq, req *types.DeployUpdateReq) error + DeployStart(ctx context.Context, startReq types.DeployActReq) error + AllFiles(ctx context.Context, req types.GetAllFilesReq) ([]*types.File, error) +} + +func NewRepoComponentImpl(config *config.Config) (*repoComponentImpl, error) { + r, err := NewRepoComponent(config) + if err != nil { + return nil, err + } + return r.(*repoComponentImpl), nil +} + +func NewRepoComponent(config *config.Config) (RepoComponent, error) { + c := &repoComponentImpl{} c.namespace = database.NewNamespaceStore() c.user = database.NewUserStore() c.org = database.NewOrgStore() @@ -144,7 +216,7 @@ func NewRepoComponent(config *config.Config) (*RepoComponent, error) { return c, nil } -func (c *RepoComponent) CreateRepo(ctx context.Context, req types.CreateRepoReq) (*gitserver.CreateRepoResp, *database.Repository, error) { +func (c *repoComponentImpl) CreateRepo(ctx context.Context, req types.CreateRepoReq) (*gitserver.CreateRepoResp, *database.Repository, error) { namespace, err := c.namespace.FindByPath(ctx, req.Namespace) if err != nil { return nil, nil, errors.New("namespace does not exist") @@ -219,7 +291,7 @@ func (c *RepoComponent) CreateRepo(ctx context.Context, req types.CreateRepoReq) return gitRepo, newDBRepo, nil } -func (c *RepoComponent) UpdateRepo(ctx context.Context, req types.UpdateRepoReq) (*database.Repository, error) { +func (c *repoComponentImpl) UpdateRepo(ctx context.Context, req types.UpdateRepoReq) (*database.Repository, error) { repo, err := c.repo.Find(ctx, req.Namespace, string(req.RepoType), req.Name) if err != nil { return nil, errors.New("repository does not exist") @@ -285,7 +357,7 @@ func (c *RepoComponent) UpdateRepo(ctx context.Context, req types.UpdateRepoReq) return resRepo, nil } -func (c *RepoComponent) DeleteRepo(ctx context.Context, req types.DeleteRepoReq) (*database.Repository, error) { +func (c *repoComponentImpl) DeleteRepo(ctx context.Context, req types.DeleteRepoReq) (*database.Repository, error) { repo, err := c.repo.Find(ctx, req.Namespace, string(req.RepoType), req.Name) if err != nil { return nil, errors.New("repository does not exist") @@ -341,7 +413,7 @@ func (c *RepoComponent) DeleteRepo(ctx context.Context, req types.DeleteRepoReq) } // PublicToUser gets visible repos of the given user and user's orgs -func (c *RepoComponent) PublicToUser(ctx context.Context, repoType types.RepositoryType, userName string, filter *types.RepoFilter, per, page int) (repos []*database.Repository, count int, err error) { +func (c *repoComponentImpl) PublicToUser(ctx context.Context, repoType types.RepositoryType, userName string, filter *types.RepoFilter, per, page int) (repos []*database.Repository, count int, err error) { var repoOwnerIDs []int64 if len(userName) > 0 { // get user orgs from user service @@ -356,7 +428,7 @@ func (c *RepoComponent) PublicToUser(ctx context.Context, repoType types.Reposit repoOwnerIDs = append(repoOwnerIDs, org.UserID) } } - repos, count, err = c.tc.rs.PublicToUser(ctx, repoType, repoOwnerIDs, filter, per, page) + repos, count, err = c.repo.PublicToUser(ctx, repoType, repoOwnerIDs, filter, per, page) if err != nil { return nil, 0, fmt.Errorf("failed to get user public repos, error: %w", err) } @@ -365,7 +437,7 @@ func (c *RepoComponent) PublicToUser(ctx context.Context, repoType types.Reposit } // relatedRepos gets all repos related to the given repo, and return them by repo type -func (c *RepoComponent) relatedRepos(ctx context.Context, repoID int64, currentUser string) (map[types.RepositoryType][]*database.Repository, error) { +func (c *repoComponentImpl) relatedRepos(ctx context.Context, repoID int64, currentUser string) (map[types.RepositoryType][]*database.Repository, error) { fromRelations, err := c.rel.From(ctx, repoID) if err != nil { return nil, fmt.Errorf("failed to get repo relation from, error: %w", err) @@ -411,7 +483,7 @@ func (c *RepoComponent) relatedRepos(ctx context.Context, repoID int64, currentU return res, nil } -func (c *RepoComponent) visiableToUser(ctx context.Context, repos []*database.Repository, currentUser string) ([]*database.Repository, error) { +func (c *repoComponentImpl) visiableToUser(ctx context.Context, repos []*database.Repository, currentUser string) ([]*database.Repository, error) { var res []*database.Repository for _, repo := range repos { if repo.Private { @@ -433,7 +505,7 @@ func (c *RepoComponent) visiableToUser(ctx context.Context, repos []*database.Re return res, nil } -func (c *RepoComponent) CreateFile(ctx context.Context, req *types.CreateFileReq) (*types.CreateFileResp, error) { +func (c *repoComponentImpl) CreateFile(ctx context.Context, req *types.CreateFileReq) (*types.CreateFileResp, error) { slog.Debug("creating file get request", slog.String("namespace", req.Namespace), slog.String("filepath", req.FilePath)) var ( err error @@ -512,7 +584,7 @@ func (c *RepoComponent) CreateFile(ctx context.Context, req *types.CreateFileReq return &resp, nil } -func (c *RepoComponent) createReadmeFile(ctx context.Context, req *types.CreateFileReq) error { +func (c *repoComponentImpl) createReadmeFile(ctx context.Context, req *types.CreateFileReq) error { var err error contentDecoded, _ := base64.RawStdEncoding.DecodeString(req.Content) _, err = c.tc.UpdateMetaTags(ctx, getTagScopeByRepoType(req.RepoType), req.Namespace, req.Name, string(contentDecoded)) @@ -528,7 +600,7 @@ func (c *RepoComponent) createReadmeFile(ctx context.Context, req *types.CreateF return err } -func (c *RepoComponent) createLibraryFile(ctx context.Context, req *types.CreateFileReq) error { +func (c *repoComponentImpl) createLibraryFile(ctx context.Context, req *types.CreateFileReq) error { var err error err = c.tc.UpdateLibraryTags(ctx, getTagScopeByRepoType(req.RepoType), req.Namespace, req.Name, "", req.FilePath) if err != nil { @@ -544,7 +616,7 @@ func (c *RepoComponent) createLibraryFile(ctx context.Context, req *types.Create return err } -func (c *RepoComponent) UpdateFile(ctx context.Context, req *types.UpdateFileReq) (*types.UpdateFileResp, error) { +func (c *repoComponentImpl) UpdateFile(ctx context.Context, req *types.UpdateFileReq) (*types.UpdateFileResp, error) { slog.Debug("update file get request", slog.String("namespace", req.Namespace), slog.String("filePath", req.FilePath), slog.String("origin_path", req.OriginPath)) @@ -631,7 +703,7 @@ func (c *RepoComponent) UpdateFile(ctx context.Context, req *types.UpdateFileReq return resp, nil } -func (c *RepoComponent) DeleteFile(ctx context.Context, req *types.DeleteFileReq) (*types.DeleteFileResp, error) { +func (c *repoComponentImpl) DeleteFile(ctx context.Context, req *types.DeleteFileReq) (*types.DeleteFileResp, error) { slog.Debug("delete file get request", slog.String("namespace", req.Namespace), slog.String("filePath", req.FilePath), slog.String("origin_path", req.OriginPath)) @@ -693,17 +765,17 @@ func (c *RepoComponent) DeleteFile(ctx context.Context, req *types.DeleteFileReq return resp, nil } -func (c *RepoComponent) updateLibraryFile(ctx context.Context, req *types.UpdateFileReq) error { +func (c *repoComponentImpl) updateLibraryFile(ctx context.Context, req *types.UpdateFileReq) error { err := c.changeLibraryFile(ctx, req.FilePath, req.OriginPath, req.Namespace, req.Name, req.RepoType) return err } -func (c *RepoComponent) deleteLibraryFile(ctx context.Context, req *types.DeleteFileReq) error { +func (c *repoComponentImpl) deleteLibraryFile(ctx context.Context, req *types.DeleteFileReq) error { err := c.changeLibraryFile(ctx, req.FilePath, req.OriginPath, req.Namespace, req.Name, req.RepoType) return err } -func (c *RepoComponent) changeLibraryFile(ctx context.Context, filePath, originPath, namespace, name string, repoType types.RepositoryType) error { +func (c *repoComponentImpl) changeLibraryFile(ctx context.Context, filePath, originPath, namespace, name string, repoType types.RepositoryType) error { var err error isFileRenamed := filePath != originPath @@ -720,7 +792,7 @@ func (c *RepoComponent) changeLibraryFile(ctx context.Context, filePath, originP return err } -func (c *RepoComponent) updateReadmeFile(ctx context.Context, req *types.UpdateFileReq) error { +func (c *repoComponentImpl) updateReadmeFile(ctx context.Context, req *types.UpdateFileReq) error { slog.Debug("file is readme", slog.String("content", req.Content)) err := c.changeReadmeFile(ctx, req.Content, req.Namespace, req.Name, req.RepoType) if err != nil { @@ -729,7 +801,7 @@ func (c *RepoComponent) updateReadmeFile(ctx context.Context, req *types.UpdateF return err } -func (c *RepoComponent) deleteReadmeFile(ctx context.Context, req *types.DeleteFileReq) error { +func (c *repoComponentImpl) deleteReadmeFile(ctx context.Context, req *types.DeleteFileReq) error { err := c.changeReadmeFile(ctx, req.Content, req.Namespace, req.Name, req.RepoType) if err != nil { return fmt.Errorf("failed to update meta tags for delete readme, cause: %w", err) @@ -737,7 +809,7 @@ func (c *RepoComponent) deleteReadmeFile(ctx context.Context, req *types.DeleteF return err } -func (c *RepoComponent) changeReadmeFile(ctx context.Context, content, namespace, name string, repoType types.RepositoryType) error { +func (c *repoComponentImpl) changeReadmeFile(ctx context.Context, content, namespace, name string, repoType types.RepositoryType) error { contentDecoded, _ := base64.RawStdEncoding.DecodeString(content) _, err := c.tc.UpdateMetaTags(ctx, getTagScopeByRepoType(repoType), namespace, name, string(contentDecoded)) if err != nil { @@ -746,7 +818,7 @@ func (c *RepoComponent) changeReadmeFile(ctx context.Context, content, namespace return err } -func (c *RepoComponent) Commits(ctx context.Context, req *types.GetCommitsReq) ([]types.Commit, *types.RepoPageOpts, error) { +func (c *repoComponentImpl) Commits(ctx context.Context, req *types.GetCommitsReq) ([]types.Commit, *types.RepoPageOpts, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -778,7 +850,7 @@ func (c *RepoComponent) Commits(ctx context.Context, req *types.GetCommitsReq) ( return commits, pageOpt, nil } -func (c *RepoComponent) LastCommit(ctx context.Context, req *types.GetCommitsReq) (*types.Commit, error) { +func (c *repoComponentImpl) LastCommit(ctx context.Context, req *types.GetCommitsReq) (*types.Commit, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -808,7 +880,7 @@ func (c *RepoComponent) LastCommit(ctx context.Context, req *types.GetCommitsReq return commit, nil } -func (c *RepoComponent) FileRaw(ctx context.Context, req *types.GetFileReq) (string, error) { +func (c *repoComponentImpl) FileRaw(ctx context.Context, req *types.GetFileReq) (string, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil || repo == nil { return "", fmt.Errorf("failed to find repo, error: %w", err) @@ -847,7 +919,7 @@ func (c *RepoComponent) FileRaw(ctx context.Context, req *types.GetFileReq) (str return raw, nil } -func (c *RepoComponent) DownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) { +func (c *repoComponentImpl) DownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) { var ( reader io.ReadCloser downloadUrl string @@ -902,7 +974,7 @@ func (c *RepoComponent) DownloadFile(ctx context.Context, req *types.GetFileReq, } } -func (c *RepoComponent) Branches(ctx context.Context, req *types.GetBranchesReq) ([]types.Branch, error) { +func (c *repoComponentImpl) Branches(ctx context.Context, req *types.GetBranchesReq) ([]types.Branch, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -933,7 +1005,7 @@ func (c *RepoComponent) Branches(ctx context.Context, req *types.GetBranchesReq) return bs, nil } -func (c *RepoComponent) Tags(ctx context.Context, req *types.GetTagsReq) ([]database.Tag, error) { +func (c *repoComponentImpl) Tags(ctx context.Context, req *types.GetTagsReq) ([]database.Tag, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find %s, error: %w", req.RepoType, err) @@ -954,7 +1026,7 @@ func (c *RepoComponent) Tags(ctx context.Context, req *types.GetTagsReq) ([]data return tags, nil } -func (c *RepoComponent) UpdateTags(ctx context.Context, namespace, name string, repoType types.RepositoryType, category, currentUser string, tags []string) error { +func (c *repoComponentImpl) UpdateTags(ctx context.Context, namespace, name string, repoType types.RepositoryType, category, currentUser string, tags []string) error { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -974,7 +1046,7 @@ func (c *RepoComponent) UpdateTags(ctx context.Context, namespace, name string, return err } -func (c *RepoComponent) Tree(ctx context.Context, req *types.GetFileReq) ([]*types.File, error) { +func (c *repoComponentImpl) Tree(ctx context.Context, req *types.GetFileReq) ([]*types.File, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -1040,7 +1112,7 @@ func (c *RepoComponent) Tree(ctx context.Context, req *types.GetFileReq) ([]*typ return tree, nil } -func (c *RepoComponent) UploadFile(ctx context.Context, req *types.CreateFileReq) error { +func (c *repoComponentImpl) UploadFile(ctx context.Context, req *types.CreateFileReq) error { parentPath := filepath.Dir(req.FilePath) if parentPath == "." { parentPath = "/" @@ -1079,7 +1151,7 @@ func (c *RepoComponent) UploadFile(ctx context.Context, req *types.CreateFileReq return err } -func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) { +func (c *repoComponentImpl) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) { var sdkFiles []types.SDKFile repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil || repo == nil { @@ -1117,7 +1189,7 @@ func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.Reposit }, nil } -func (c *RepoComponent) IsLfs(ctx context.Context, req *types.GetFileReq) (bool, int64, error) { +func (c *repoComponentImpl) IsLfs(ctx context.Context, req *types.GetFileReq) (bool, int64, error) { getFileRawReq := gitserver.GetRepoInfoByPathReq{ Namespace: req.Namespace, Name: req.Name, @@ -1137,7 +1209,7 @@ func (c *RepoComponent) IsLfs(ctx context.Context, req *types.GetFileReq) (bool, return strings.HasPrefix(content, LFSPrefix), int64(len(content)), nil } -func (c *RepoComponent) HeadDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (*types.File, error) { +func (c *repoComponentImpl) HeadDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (*types.File, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -1170,7 +1242,7 @@ func (c *RepoComponent) HeadDownloadFile(ctx context.Context, req *types.GetFile return file, nil } -func (c *RepoComponent) SDKDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) { +func (c *repoComponentImpl) SDKDownloadFile(ctx context.Context, req *types.GetFileReq, userName string) (io.ReadCloser, int64, string, error) { var downloadUrl string repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { @@ -1234,7 +1306,7 @@ func (c *RepoComponent) SDKDownloadFile(ctx context.Context, req *types.GetFileR } // UpdateDownloads increase clone download count for repo by given count -func (c *RepoComponent) UpdateDownloads(ctx context.Context, req *types.UpdateDownloadsReq) error { +func (c *repoComponentImpl) UpdateDownloads(ctx context.Context, req *types.UpdateDownloadsReq) error { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find %s, error: %w", req.RepoType, err) @@ -1248,7 +1320,7 @@ func (c *RepoComponent) UpdateDownloads(ctx context.Context, req *types.UpdateDo } // IncrDownloads increase the click download count for repo by 1 -func (c *RepoComponent) IncrDownloads(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { +func (c *repoComponentImpl) IncrDownloads(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find %s, error: %w", repoType, err) @@ -1261,7 +1333,7 @@ func (c *RepoComponent) IncrDownloads(ctx context.Context, repoType types.Reposi return err } -func (c *RepoComponent) FileInfo(ctx context.Context, req *types.GetFileReq) (*types.File, error) { +func (c *repoComponentImpl) FileInfo(ctx context.Context, req *types.GetFileReq) (*types.File, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -1324,7 +1396,7 @@ func getTagScopeByRepoType(repoType types.RepositoryType) database.TagScope { } } -func (c *RepoComponent) AllowReadAccessRepo(ctx context.Context, repo *database.Repository, username string) (bool, error) { +func (c *repoComponentImpl) AllowReadAccessRepo(ctx context.Context, repo *database.Repository, username string) (bool, error) { if !repo.Private { return true, nil } @@ -1337,7 +1409,7 @@ func (c *RepoComponent) AllowReadAccessRepo(ctx context.Context, repo *database. return c.checkCurrentUserPermission(ctx, username, namespace, membership.RoleRead) } -func (c *RepoComponent) AllowReadAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { +func (c *repoComponentImpl) AllowReadAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) @@ -1345,7 +1417,7 @@ func (c *RepoComponent) AllowReadAccess(ctx context.Context, repoType types.Repo return c.AllowReadAccessRepo(ctx, repo, username) } -func (c *RepoComponent) AllowWriteAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { +func (c *repoComponentImpl) AllowWriteAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) @@ -1361,7 +1433,7 @@ func (c *RepoComponent) AllowWriteAccess(ctx context.Context, repoType types.Rep return c.checkCurrentUserPermission(ctx, username, namespace, membership.RoleWrite) } -func (c *RepoComponent) AllowAdminAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { +func (c *repoComponentImpl) AllowAdminAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) @@ -1377,7 +1449,7 @@ func (c *RepoComponent) AllowAdminAccess(ctx context.Context, repoType types.Rep return c.checkCurrentUserPermission(ctx, username, namespace, membership.RoleAdmin) } -func (c *RepoComponent) getUserRepoPermission(ctx context.Context, userName string, repo *database.Repository) (*types.UserRepoPermission, error) { +func (c *repoComponentImpl) getUserRepoPermission(ctx context.Context, userName string, repo *database.Repository) (*types.UserRepoPermission, error) { if userName == "" { //anonymous user only has read permission to public repo return &types.UserRepoPermission{CanRead: !repo.Private, CanWrite: false, CanAdmin: false}, nil @@ -1417,7 +1489,7 @@ func (c *RepoComponent) getUserRepoPermission(ctx context.Context, userName stri } } -func (c *RepoComponent) checkCurrentUserPermission(ctx context.Context, userName string, namespace string, role membership.Role) (bool, error) { +func (c *repoComponentImpl) checkCurrentUserPermission(ctx context.Context, userName string, namespace string, role membership.Role) (bool, error) { ns, err := c.namespace.FindByPath(ctx, namespace) if err != nil { return false, err @@ -1443,7 +1515,7 @@ func (c *RepoComponent) checkCurrentUserPermission(ctx context.Context, userName } } -func (c *RepoComponent) GetCommitWithDiff(ctx context.Context, req *types.GetCommitsReq) (*types.CommitResponse, error) { +func (c *repoComponentImpl) GetCommitWithDiff(ctx context.Context, req *types.GetCommitsReq) (*types.CommitResponse, error) { // get commit diff by commit id if req.Ref == "" { return nil, fmt.Errorf("failed to find request commit id") @@ -1474,7 +1546,7 @@ func (c *RepoComponent) GetCommitWithDiff(ctx context.Context, req *types.GetCom return resp, nil } -func (c *RepoComponent) CreateMirror(ctx context.Context, req types.CreateMirrorReq) (*database.Mirror, error) { +func (c *repoComponentImpl) CreateMirror(ctx context.Context, req types.CreateMirrorReq) (*database.Mirror, error) { var ( mirror database.Mirror taskId int64 @@ -1572,7 +1644,7 @@ func (c *RepoComponent) CreateMirror(ctx context.Context, req types.CreateMirror return reqMirror, nil } -func (c *RepoComponent) MirrorFromSaas(ctx context.Context, namespace, name, currentUser string, repoType types.RepositoryType) error { +func (c *repoComponentImpl) MirrorFromSaas(ctx context.Context, namespace, name, currentUser string, repoType types.RepositoryType) error { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -1655,7 +1727,7 @@ func (c *RepoComponent) MirrorFromSaas(ctx context.Context, namespace, name, cur return nil } -func (c *RepoComponent) mirrorFromSaasSync(ctx context.Context, mirror *database.Mirror, namespace, name string, repoType types.RepositoryType) error { +func (c *repoComponentImpl) mirrorFromSaasSync(ctx context.Context, mirror *database.Mirror, namespace, name string, repoType types.RepositoryType) error { var err error syncClientSetting, err := c.syncClientSetting.First(ctx) if err != nil { @@ -1693,7 +1765,7 @@ func (c *RepoComponent) mirrorFromSaasSync(ctx context.Context, mirror *database return nil } -func (c *RepoComponent) GetMirror(ctx context.Context, req types.GetMirrorReq) (*database.Mirror, error) { +func (c *repoComponentImpl) GetMirror(ctx context.Context, req types.GetMirrorReq) (*database.Mirror, error) { admin, err := c.checkCurrentUserPermission(ctx, req.CurrentUser, req.Namespace, membership.RoleAdmin) if err != nil { return nil, fmt.Errorf("failed to check permission to create mirror, error: %w", err) @@ -1713,7 +1785,7 @@ func (c *RepoComponent) GetMirror(ctx context.Context, req types.GetMirrorReq) ( return mirror, nil } -func (c *RepoComponent) UpdateMirror(ctx context.Context, req types.UpdateMirrorReq) (*database.Mirror, error) { +func (c *repoComponentImpl) UpdateMirror(ctx context.Context, req types.UpdateMirrorReq) (*database.Mirror, error) { admin, err := c.checkCurrentUserPermission(ctx, req.CurrentUser, req.Namespace, membership.RoleAdmin) if err != nil { return nil, fmt.Errorf("failed to check permission to create mirror, error: %w", err) @@ -1757,7 +1829,7 @@ func (c *RepoComponent) UpdateMirror(ctx context.Context, req types.UpdateMirror return mirror, nil } -func (c *RepoComponent) DeleteMirror(ctx context.Context, req types.DeleteMirrorReq) error { +func (c *repoComponentImpl) DeleteMirror(ctx context.Context, req types.DeleteMirrorReq) error { admin, err := c.checkCurrentUserPermission(ctx, req.CurrentUser, req.Namespace, membership.RoleAdmin) if err != nil { return fmt.Errorf("failed to check permission to create mirror, error: %w", err) @@ -1782,7 +1854,7 @@ func (c *RepoComponent) DeleteMirror(ctx context.Context, req types.DeleteMirror } // get runtime framework list with type -func (c *RepoComponent) ListRuntimeFrameworkWithType(ctx context.Context, deployType int) ([]types.RuntimeFramework, error) { +func (c *repoComponentImpl) ListRuntimeFrameworkWithType(ctx context.Context, deployType int) ([]types.RuntimeFramework, error) { frames, err := c.runFrame.List(ctx, deployType) if err != nil { return nil, fmt.Errorf("failed to list runtime frameworks, error: %w", err) @@ -1804,7 +1876,7 @@ func (c *RepoComponent) ListRuntimeFrameworkWithType(ctx context.Context, deploy } // get runtime framework list -func (c *RepoComponent) ListRuntimeFramework(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployType int) ([]types.RuntimeFramework, error) { +func (c *repoComponentImpl) ListRuntimeFramework(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployType int) ([]types.RuntimeFramework, error) { repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -1830,7 +1902,7 @@ func (c *RepoComponent) ListRuntimeFramework(ctx context.Context, repoType types return frameList, nil } -func (c *RepoComponent) CreateRuntimeFramework(ctx context.Context, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) { +func (c *repoComponentImpl) CreateRuntimeFramework(ctx context.Context, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) { newFrame := database.RuntimeFramework{ FrameName: req.FrameName, FrameVersion: req.FrameVersion, @@ -1856,7 +1928,7 @@ func (c *RepoComponent) CreateRuntimeFramework(ctx context.Context, req *types.R return frame, nil } -func (c *RepoComponent) UpdateRuntimeFramework(ctx context.Context, id int64, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) { +func (c *repoComponentImpl) UpdateRuntimeFramework(ctx context.Context, id int64, req *types.RuntimeFrameworkReq) (*types.RuntimeFramework, error) { newFrame := database.RuntimeFramework{ ID: id, FrameName: req.FrameName, @@ -1883,7 +1955,7 @@ func (c *RepoComponent) UpdateRuntimeFramework(ctx context.Context, id int64, re }, nil } -func (c *RepoComponent) DeleteRuntimeFramework(ctx context.Context, id int64) error { +func (c *repoComponentImpl) DeleteRuntimeFramework(ctx context.Context, id int64) error { frame, err := c.runFrame.FindByID(ctx, id) if err != nil { return fmt.Errorf("failed to find runtime frameworks, error: %w", err) @@ -1892,7 +1964,7 @@ func (c *RepoComponent) DeleteRuntimeFramework(ctx context.Context, id int64) er return err } -func (c *RepoComponent) ListDeploy(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) ([]types.DeployRepo, error) { +func (c *repoComponentImpl) ListDeploy(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) ([]types.DeployRepo, error) { user, err := c.user.FindByUsername(ctx, currentUser) if err != nil { return nil, errors.New("user does not exist") @@ -1934,7 +2006,7 @@ func (c *RepoComponent) ListDeploy(ctx context.Context, repoType types.Repositor return resDeploys, nil } -func (c *RepoComponent) DeleteDeploy(ctx context.Context, delReq types.DeployActReq) error { +func (c *repoComponentImpl) DeleteDeploy(ctx context.Context, delReq types.DeployActReq) error { user, deploy, err := c.checkDeployPermissionForUser(ctx, delReq) if err != nil { return err @@ -1975,7 +2047,7 @@ func (c *RepoComponent) DeleteDeploy(ctx context.Context, delReq types.DeployAct return err } -func (c *RepoComponent) DeployDetail(ctx context.Context, detailReq types.DeployActReq) (*types.DeployRepo, error) { +func (c *repoComponentImpl) DeployDetail(ctx context.Context, detailReq types.DeployActReq) (*types.DeployRepo, error) { var ( deploy *database.Deploy = nil err error = nil @@ -2044,7 +2116,7 @@ func (c *RepoComponent) DeployDetail(ctx context.Context, detailReq types.Deploy } // generate endpoint -func (c *RepoComponent) generateEndpoint(ctx context.Context, deploy *database.Deploy) (string, string) { +func (c *repoComponentImpl) generateEndpoint(ctx context.Context, deploy *database.Deploy) (string, string) { var endpoint string provider := "" cls, err := c.cluster.ByClusterID(ctx, deploy.ClusterID) @@ -2121,7 +2193,7 @@ func deployStatusCodeToString(code int) string { return txt } -func (c *RepoComponent) DeployInstanceLogs(ctx context.Context, logReq types.DeployActReq) (*deploy.MultiLogReader, error) { +func (c *repoComponentImpl) DeployInstanceLogs(ctx context.Context, logReq types.DeployActReq) (*deploy.MultiLogReader, error) { var ( deploy *database.Deploy = nil err error = nil @@ -2148,7 +2220,7 @@ func (c *RepoComponent) DeployInstanceLogs(ctx context.Context, logReq types.Dep } // check access repo permission by repo id -func (c *RepoComponent) AllowAccessByRepoID(ctx context.Context, repoID int64, username string) (bool, error) { +func (c *repoComponentImpl) AllowAccessByRepoID(ctx context.Context, repoID int64, username string) (bool, error) { r, err := c.repo.FindById(ctx, repoID) if err != nil { return false, fmt.Errorf("failed to get repository by repo_id: %d, %w", repoID, err) @@ -2161,7 +2233,7 @@ func (c *RepoComponent) AllowAccessByRepoID(ctx context.Context, repoID int64, u } // check access endpoint for rproxy -func (c *RepoComponent) AllowAccessEndpoint(ctx context.Context, currentUser string, deploy *database.Deploy) (bool, error) { +func (c *repoComponentImpl) AllowAccessEndpoint(ctx context.Context, currentUser string, deploy *database.Deploy) (bool, error) { if deploy.SecureLevel == types.EndpointPublic { // public endpoint return true, nil @@ -2170,7 +2242,7 @@ func (c *RepoComponent) AllowAccessEndpoint(ctx context.Context, currentUser str } // check access deploy permission -func (c *RepoComponent) AllowAccessDeploy(ctx context.Context, req types.DeployActReq) (bool, error) { +func (c *repoComponentImpl) AllowAccessDeploy(ctx context.Context, req types.DeployActReq) (bool, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) @@ -2193,7 +2265,7 @@ func (c *RepoComponent) AllowAccessDeploy(ctx context.Context, req types.DeployA } // common check function for apiserver and rproxy -func (c *RepoComponent) checkAccessDeployForUser(ctx context.Context, repoID int64, currentUser string, deploy *database.Deploy) (bool, error) { +func (c *repoComponentImpl) checkAccessDeployForUser(ctx context.Context, repoID int64, currentUser string, deploy *database.Deploy) (bool, error) { user, err := c.user.FindByUsername(ctx, currentUser) if err != nil { return false, errors.New("user does not exist") @@ -2209,7 +2281,7 @@ func (c *RepoComponent) checkAccessDeployForUser(ctx context.Context, repoID int return true, nil } -func (c *RepoComponent) checkAccessDeployForServerless(ctx context.Context, repoID int64, currentUser string, deploy *database.Deploy) (bool, error) { +func (c *repoComponentImpl) checkAccessDeployForServerless(ctx context.Context, repoID int64, currentUser string, deploy *database.Deploy) (bool, error) { user, err := c.user.FindByUsername(ctx, currentUser) if err != nil { return false, fmt.Errorf("user %s does not exist", currentUser) @@ -2225,7 +2297,7 @@ func (c *RepoComponent) checkAccessDeployForServerless(ctx context.Context, repo return true, nil } -func (c *RepoComponent) DeployStop(ctx context.Context, stopReq types.DeployActReq) error { +func (c *repoComponentImpl) DeployStop(ctx context.Context, stopReq types.DeployActReq) error { var ( user *database.User = nil deploy *database.Deploy = nil @@ -2275,7 +2347,7 @@ func (c *RepoComponent) DeployStop(ctx context.Context, stopReq types.DeployActR return err } -func (c *RepoComponent) AllowReadAccessByDeployID(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string, deployID int64) (bool, error) { +func (c *repoComponentImpl) AllowReadAccessByDeployID(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string, deployID int64) (bool, error) { user, err := c.user.FindByUsername(ctx, currentUser) if err != nil { return false, errors.New("user does not exist") @@ -2300,7 +2372,7 @@ func (c *RepoComponent) AllowReadAccessByDeployID(ctx context.Context, repoType return c.AllowReadAccessRepo(ctx, repo, currentUser) } -func (c *RepoComponent) DeployStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployID int64) (string, string, []types.Instance, error) { +func (c *repoComponentImpl) DeployStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string, deployID int64) (string, string, []types.Instance, error) { deploy, err := c.deploy.GetDeployByID(ctx, deployID) if err != nil { return "", SpaceStatusStopped, nil, err @@ -2322,7 +2394,7 @@ func (c *RepoComponent) DeployStatus(ctx context.Context, repoType types.Reposit return srvName, deployStatusCodeToString(code), instances, nil } -func (c *RepoComponent) GetDeployBySvcName(ctx context.Context, svcName string) (*database.Deploy, error) { +func (c *repoComponentImpl) GetDeployBySvcName(ctx context.Context, svcName string) (*database.Deploy, error) { d, err := c.deploy.GetDeployBySvcName(ctx, svcName) if err != nil { return nil, fmt.Errorf("failed to get deploy by svc name:%s, %w", svcName, err) @@ -2333,7 +2405,7 @@ func (c *RepoComponent) GetDeployBySvcName(ctx context.Context, svcName string) return d, nil } -func (c *RepoComponent) SyncMirror(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) error { +func (c *repoComponentImpl) SyncMirror(ctx context.Context, repoType types.RepositoryType, namespace, name, currentUser string) error { admin, err := c.checkCurrentUserPermission(ctx, currentUser, namespace, membership.RoleAdmin) if err != nil { return fmt.Errorf("failed to check permission to create mirror, error: %w", err) @@ -2375,7 +2447,7 @@ func (c *RepoComponent) SyncMirror(ctx context.Context, repoType types.Repositor return nil } -func (c *RepoComponent) checkDeployPermissionForUser(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { +func (c *repoComponentImpl) checkDeployPermissionForUser(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { user, err := c.user.FindByUsername(ctx, deployReq.CurrentUser) if err != nil { return nil, nil, &types.PermissionError{Message: "user does not exist"} @@ -2393,7 +2465,7 @@ func (c *RepoComponent) checkDeployPermissionForUser(ctx context.Context, deploy return &user, deploy, nil } -func (c *RepoComponent) checkDeployPermissionForServerless(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { +func (c *repoComponentImpl) checkDeployPermissionForServerless(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { user, err := c.user.FindByUsername(ctx, deployReq.CurrentUser) if err != nil { return nil, nil, fmt.Errorf("user does not exist, %w", err) @@ -2412,7 +2484,7 @@ func (c *RepoComponent) checkDeployPermissionForServerless(ctx context.Context, return &user, deploy, nil } -func (c *RepoComponent) DeployUpdate(ctx context.Context, updateReq types.DeployActReq, req *types.DeployUpdateReq) error { +func (c *repoComponentImpl) DeployUpdate(ctx context.Context, updateReq types.DeployActReq, req *types.DeployUpdateReq) error { var ( deploy *database.Deploy = nil err error = nil @@ -2467,7 +2539,7 @@ func (c *RepoComponent) DeployUpdate(ctx context.Context, updateReq types.Deploy return err } -func (c *RepoComponent) DeployStart(ctx context.Context, startReq types.DeployActReq) error { +func (c *repoComponentImpl) DeployStart(ctx context.Context, startReq types.DeployActReq) error { var ( deploy *database.Deploy = nil err error = nil @@ -2510,7 +2582,7 @@ func (c *RepoComponent) DeployStart(ctx context.Context, startReq types.DeployAc return err } -func (c *RepoComponent) AllFiles(ctx context.Context, req types.GetAllFilesReq) ([]*types.File, error) { +func (c *repoComponentImpl) AllFiles(ctx context.Context, req types.GetAllFilesReq) ([]*types.File, error) { repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) @@ -2536,12 +2608,12 @@ func (c *RepoComponent) AllFiles(ctx context.Context, req types.GetAllFilesReq) return allFiles, nil } -func (c *RepoComponent) isAdminRole(user database.User) bool { +func (c *repoComponentImpl) isAdminRole(user database.User) bool { slog.Debug("Check if user is admin", slog.Any("user", user)) return user.CanAdmin() } -func (c *RepoComponent) getNameSpaceInfo(ctx context.Context, path string) (*types.Namespace, error) { +func (c *repoComponentImpl) getNameSpaceInfo(ctx context.Context, path string) (*types.Namespace, error) { nsResp, err := c.userSvcClient.GetNameSpaceInfo(ctx, path) if err != nil { return nil, fmt.Errorf("failed to get namespace infor from user service, path: %s, error: %w", path, err) @@ -2554,7 +2626,7 @@ func (c *RepoComponent) getNameSpaceInfo(ctx context.Context, path string) (*typ return ns, nil } -func (c *RepoComponent) checkIfShouldUseLfs(ctx context.Context, req *types.CreateFileReq) (bool, *types.CreateFileReq) { +func (c *repoComponentImpl) checkIfShouldUseLfs(ctx context.Context, req *types.CreateFileReq) (bool, *types.CreateFileReq) { gFile, err := c.git.GetRepoFileContents(ctx, gitserver.GetRepoInfoByPathReq{ RepoType: req.RepoType, Namespace: req.Namespace, @@ -2578,7 +2650,7 @@ func (c *RepoComponent) checkIfShouldUseLfs(ctx context.Context, req *types.Crea return true, req } -func (c *RepoComponent) checkIfShouldUseLfsUpdate(ctx context.Context, req *types.UpdateFileReq) (bool, *types.UpdateFileReq) { +func (c *repoComponentImpl) checkIfShouldUseLfsUpdate(ctx context.Context, req *types.UpdateFileReq) (bool, *types.UpdateFileReq) { gFile, err := c.git.GetRepoFileContents(ctx, gitserver.GetRepoInfoByPathReq{ RepoType: req.RepoType, Namespace: req.Namespace, diff --git a/component/repo_file.go b/component/repo_file.go index 3cd15e42..22481ca6 100644 --- a/component/repo_file.go +++ b/component/repo_file.go @@ -13,14 +13,19 @@ import ( "opencsg.com/csghub-server/common/types" ) -type RepoFileComponent struct { - rfs *database.RepoFileStore - rs *database.RepoStore +type repoFileComponentImpl struct { + rfs database.RepoFileStore + rs database.RepoStore gs gitserver.GitServer } -func NewRepoFileComponent(conf *config.Config) (*RepoFileComponent, error) { - c := &RepoFileComponent{ +type RepoFileComponent interface { + GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error + GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error +} + +func NewRepoFileComponent(conf *config.Config) (RepoFileComponent, error) { + c := &repoFileComponentImpl{ rfs: database.NewRepoFileStore(), rs: database.NewRepoStore(), } @@ -32,7 +37,7 @@ func NewRepoFileComponent(conf *config.Config) (*RepoFileComponent, error) { c.gs = gs return c, nil } -func (c *RepoFileComponent) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { +func (c *repoFileComponentImpl) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -40,7 +45,7 @@ func (c *RepoFileComponent) GenRepoFileRecords(ctx context.Context, repoType typ return c.createRepoFileRecords(ctx, *repo, "", c.gs.GetRepoFileTree) } -func (c *RepoFileComponent) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { +func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { tokens := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { tokens <- struct{}{} @@ -82,7 +87,7 @@ func (c *RepoFileComponent) GenRepoFileRecordsBatch(ctx context.Context, repoTyp return nil } -func (c *RepoFileComponent) createRepoFileRecords(ctx context.Context, repo database.Repository, folder string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) error { +func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo database.Repository, folder string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) error { namespace, name := repo.NamespaceAndName() var files []*types.File diff --git a/component/runtime_architecture.go b/component/runtime_architecture.go index 8ff6044f..50c65dd3 100644 --- a/component/runtime_architecture.go +++ b/component/runtime_architecture.go @@ -20,21 +20,37 @@ var ( ScanLock sync.Mutex ) -type RuntimeArchitectureComponent struct { - r *RepoComponent - ras *database.RuntimeArchitecturesStore - rfs *database.RuntimeFrameworksStore - ts *database.TagStore - rms *database.ResourceModelStore +type runtimeArchitectureComponentImpl struct { + r *repoComponentImpl + ras database.RuntimeArchitecturesStore + rfs database.RuntimeFrameworksStore + ts database.TagStore + rms database.ResourceModelStore } -func NewRuntimeArchitectureComponent(config *config.Config) (*RuntimeArchitectureComponent, error) { - c := &RuntimeArchitectureComponent{} +type RuntimeArchitectureComponent interface { + ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]database.RuntimeArchitecture, error) + SetArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) + DeleteArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) + ScanArchitecture(ctx context.Context, id int64, scanType int, models []string) error + // check if it's supported model resource by name + IsSupportedModelResource(ctx context.Context, modelName string, rf *database.RuntimeFramework, id int64) (bool, error) + GetArchitectureFromConfig(ctx context.Context, namespace, name string) (string, error) + // remove runtime_framework tag from model + RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) + // add runtime_framework tag to model + AddRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) error + // add resource tag to model + AddResourceTag(ctx context.Context, rstags []*database.Tag, modelname string, repoId int64) error +} + +func NewRuntimeArchitectureComponent(config *config.Config) (RuntimeArchitectureComponent, error) { + c := &runtimeArchitectureComponentImpl{} c.rfs = database.NewRuntimeFrameworksStore() c.ras = database.NewRuntimeArchitecturesStore() c.ts = database.NewTagStore() c.rms = database.NewResourceModelStore() - repo, err := NewRepoComponent(config) + repo, err := NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("fail to create repo component, %w", err) } @@ -42,7 +58,7 @@ func NewRuntimeArchitectureComponent(config *config.Config) (*RuntimeArchitectur return c, nil } -func (c *RuntimeArchitectureComponent) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]database.RuntimeArchitecture, error) { +func (c *runtimeArchitectureComponentImpl) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]database.RuntimeArchitecture, error) { archs, err := c.ras.ListByRuntimeFrameworkID(ctx, id) if err != nil { return nil, fmt.Errorf("list runtime arch failed, %w", err) @@ -50,7 +66,7 @@ func (c *RuntimeArchitectureComponent) ListByRuntimeFrameworkID(ctx context.Cont return archs, nil } -func (c *RuntimeArchitectureComponent) SetArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { +func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { _, err := c.r.rtfm.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) @@ -71,7 +87,7 @@ func (c *RuntimeArchitectureComponent) SetArchitectures(ctx context.Context, id return failedArchs, nil } -func (c *RuntimeArchitectureComponent) DeleteArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { +func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { _, err := c.r.rtfm.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) @@ -89,7 +105,7 @@ func (c *RuntimeArchitectureComponent) DeleteArchitectures(ctx context.Context, return failedDeletes, nil } -func (c *RuntimeArchitectureComponent) ScanArchitecture(ctx context.Context, id int64, scanType int, models []string) error { +func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, id int64, scanType int, models []string) error { frame, err := c.r.rtfm.FindByID(ctx, id) if err != nil { return fmt.Errorf("invalid runtime framework id, %w", err) @@ -138,7 +154,7 @@ func (c *RuntimeArchitectureComponent) ScanArchitecture(ctx context.Context, id return nil } -func (c *RuntimeArchitectureComponent) scanNewModels(ctx context.Context, req types.ScanReq) error { +func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, req types.ScanReq) error { repos, err := c.r.repo.GetRepoWithoutRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("failed to get repos without runtime by ID, %w", err) @@ -188,7 +204,7 @@ func (c *RuntimeArchitectureComponent) scanNewModels(ctx context.Context, req ty } // check if it's supported model resource by name -func (c *RuntimeArchitectureComponent) IsSupportedModelResource(ctx context.Context, modelName string, rf *database.RuntimeFramework, id int64) (bool, error) { +func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context.Context, modelName string, rf *database.RuntimeFramework, id int64) (bool, error) { trimModel := strings.Replace(strings.ToLower(modelName), "meta-", "", 1) rm, err := c.rms.CheckModelNameNotInRFRepo(ctx, trimModel, id) if err != nil || rm == nil { @@ -212,7 +228,7 @@ func (c *RuntimeArchitectureComponent) IsSupportedModelResource(ctx context.Cont return false, nil } -func (c *RuntimeArchitectureComponent) scanExistModels(ctx context.Context, req types.ScanReq) error { +func (c *runtimeArchitectureComponentImpl) scanExistModels(ctx context.Context, req types.ScanReq) error { repos, err := c.r.repo.GetRepoWithRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("fail to get repos with runtime by ID, %w", err) @@ -242,7 +258,7 @@ func (c *RuntimeArchitectureComponent) scanExistModels(ctx context.Context, req return nil } -func (c *RuntimeArchitectureComponent) GetArchitectureFromConfig(ctx context.Context, namespace, name string) (string, error) { +func (c *runtimeArchitectureComponentImpl) GetArchitectureFromConfig(ctx context.Context, namespace, name string) (string, error) { content, err := c.getConfigContent(ctx, namespace, name) if err != nil { return "", fmt.Errorf("fail to read config.json for relation, %w", err) @@ -264,7 +280,7 @@ func (c *RuntimeArchitectureComponent) GetArchitectureFromConfig(ctx context.Con return config.Architectures[0], nil } -func (c *RuntimeArchitectureComponent) getConfigContent(ctx context.Context, namespace, name string) (string, error) { +func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, namespace, name string) (string, error) { content, err := c.r.git.GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ Namespace: namespace, Name: name, @@ -279,7 +295,7 @@ func (c *RuntimeArchitectureComponent) getConfigContent(ctx context.Context, nam } // remove runtime_framework tag from model -func (c *RuntimeArchitectureComponent) RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) { +func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) { rfw, _ := c.rfs.FindByID(ctx, rfId) for _, tag := range rftags { if strings.Contains(rfw.FrameImage, tag.Name) { @@ -292,7 +308,7 @@ func (c *RuntimeArchitectureComponent) RemoveRuntimeFrameworkTag(ctx context.Con } // add runtime_framework tag to model -func (c *RuntimeArchitectureComponent) AddRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) error { +func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) error { rfw, err := c.rfs.FindByID(ctx, rfId) if err != nil { return err @@ -309,7 +325,7 @@ func (c *RuntimeArchitectureComponent) AddRuntimeFrameworkTag(ctx context.Contex } // add resource tag to model -func (c *RuntimeArchitectureComponent) AddResourceTag(ctx context.Context, rstags []*database.Tag, modelname string, repoId int64) error { +func (c *runtimeArchitectureComponentImpl) AddResourceTag(ctx context.Context, rstags []*database.Tag, modelname string, repoId int64) error { rms, err := c.rms.FindByModelName(ctx, modelname) if err != nil { return err diff --git a/component/sensitive.go b/component/sensitive.go index 1df9a49f..939481cc 100644 --- a/component/sensitive.go +++ b/component/sensitive.go @@ -11,13 +11,19 @@ import ( "opencsg.com/csghub-server/common/types" ) -type SensitiveComponent struct { +type sensitiveComponentImpl struct { checker rpc.ModerationSvcClient enable bool } -func NewSensitiveComponent(cfg *config.Config) (*SensitiveComponent, error) { - c := &SensitiveComponent{} +type SensitiveComponent interface { + CheckText(ctx context.Context, scenario, text string) (bool, error) + CheckImage(ctx context.Context, scenario, ossBucketName, ossObjectName string) (bool, error) + CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) +} + +func NewSensitiveComponent(cfg *config.Config) (SensitiveComponent, error) { + c := &sensitiveComponentImpl{} c.enable = cfg.SensitiveCheck.Enable if c.enable { @@ -26,7 +32,7 @@ func NewSensitiveComponent(cfg *config.Config) (*SensitiveComponent, error) { return c, nil } -func (c SensitiveComponent) CheckText(ctx context.Context, scenario, text string) (bool, error) { +func (c *sensitiveComponentImpl) CheckText(ctx context.Context, scenario, text string) (bool, error) { if !c.enable { return true, nil } @@ -39,7 +45,7 @@ func (c SensitiveComponent) CheckText(ctx context.Context, scenario, text string return !result.IsSensitive, nil } -func (c SensitiveComponent) CheckImage(ctx context.Context, scenario, ossBucketName, ossObjectName string) (bool, error) { +func (c *sensitiveComponentImpl) CheckImage(ctx context.Context, scenario, ossBucketName, ossObjectName string) (bool, error) { if !c.enable { return true, nil } @@ -51,7 +57,7 @@ func (c SensitiveComponent) CheckImage(ctx context.Context, scenario, ossBucketN return !result.IsSensitive, nil } -func (c SensitiveComponent) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { +func (c *sensitiveComponentImpl) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { if !c.enable { return true, nil } diff --git a/component/space.go b/component/space.go index 1e895773..0b447d42 100644 --- a/component/space.go +++ b/component/space.go @@ -29,14 +29,37 @@ enableXsrfProtection = false streamlitConfig = ".streamlit/config.toml" ) -func NewSpaceComponent(config *config.Config) (*SpaceComponent, error) { - c := &SpaceComponent{} +type SpaceComponent interface { + Create(ctx context.Context, req types.CreateSpaceReq) (*types.Space, error) + Show(ctx context.Context, namespace, name, currentUser string) (*types.Space, error) + Update(ctx context.Context, req *types.UpdateSpaceReq) (*types.Space, error) + Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Space, int, error) + OrgSpaces(ctx context.Context, req *types.OrgSpacesReq) ([]types.Space, int, error) + // UserSpaces get spaces of owner and visible to current user + UserSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) + UserLikesSpaces(ctx context.Context, req *types.UserSpacesReq, userID int64) ([]types.Space, int, error) + ListByPath(ctx context.Context, paths []string) ([]*types.Space, error) + AllowCallApi(ctx context.Context, spaceID int64, username string) (bool, error) + Delete(ctx context.Context, namespace, name, currentUser string) error + Deploy(ctx context.Context, namespace, name, currentUser string) (int64, error) + Wakeup(ctx context.Context, namespace, name string) error + Stop(ctx context.Context, namespace, name string) error + // FixHasEntryFile checks whether git repo has entry point file and update space's HasAppFile property in db + FixHasEntryFile(ctx context.Context, s *database.Space) *database.Space + Status(ctx context.Context, namespace, name string) (string, string, error) + Logs(ctx context.Context, namespace, name string) (*deploy.MultiLogReader, error) + // HasEntryFile checks whether space repo has entry point file to run with + HasEntryFile(ctx context.Context, space *database.Space) bool +} + +func NewSpaceComponent(config *config.Config) (SpaceComponent, error) { + c := &spaceComponentImpl{} c.ss = database.NewSpaceStore() var err error c.sss = database.NewSpaceSdkStore() c.srs = database.NewSpaceResourceStore() c.rs = database.NewRepoStore() - c.RepoComponent, err = NewRepoComponent(config) + c.repoComponentImpl, err = NewRepoComponentImpl(config) if err != nil { return nil, err } @@ -51,19 +74,19 @@ func NewSpaceComponent(config *config.Config) (*SpaceComponent, error) { return c, nil } -type SpaceComponent struct { - *RepoComponent - ss *database.SpaceStore - sss *database.SpaceSdkStore - srs *database.SpaceResourceStore - rs *database.RepoStore - us *database.UserStore +type spaceComponentImpl struct { + *repoComponentImpl + ss database.SpaceStore + sss database.SpaceSdkStore + srs database.SpaceResourceStore + rs database.RepoStore + us database.UserStore deployer deploy.Deployer publicRootDomain string - ac *AccountingComponent + ac AccountingComponent } -func (c *SpaceComponent) Create(ctx context.Context, req types.CreateSpaceReq) (*types.Space, error) { +func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceReq) (*types.Space, error) { var nickname string if req.Nickname != "" { nickname = req.Nickname @@ -182,7 +205,7 @@ func (c *SpaceComponent) Create(ctx context.Context, req types.CreateSpaceReq) ( return space, nil } -func (c *SpaceComponent) Show(ctx context.Context, namespace, name, currentUser string) (*types.Space, error) { +func (c *spaceComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Space, error) { var tags []types.RepoTag space, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { @@ -266,7 +289,7 @@ func (c *SpaceComponent) Show(ctx context.Context, namespace, name, currentUser return resModel, nil } -func (c *SpaceComponent) Update(ctx context.Context, req *types.UpdateSpaceReq) (*types.Space, error) { +func (c *spaceComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceReq) (*types.Space, error) { req.RepoType = types.SpaceRepo dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { @@ -307,7 +330,7 @@ func (c *SpaceComponent) Update(ctx context.Context, req *types.UpdateSpaceReq) return resDataset, nil } -func (c *SpaceComponent) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Space, int, error) { +func (c *spaceComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Space, int, error) { var ( resSpaces []types.Space err error @@ -379,7 +402,7 @@ func (c *SpaceComponent) Index(ctx context.Context, filter *types.RepoFilter, pe return resSpaces, total, nil } -func (c *SpaceComponent) OrgSpaces(ctx context.Context, req *types.OrgSpacesReq) ([]types.Space, int, error) { +func (c *spaceComponentImpl) OrgSpaces(ctx context.Context, req *types.OrgSpacesReq) ([]types.Space, int, error) { var resSpaces []types.Space var err error r := membership.RoleUnknown @@ -422,7 +445,7 @@ func (c *SpaceComponent) OrgSpaces(ctx context.Context, req *types.OrgSpacesReq) } // UserSpaces get spaces of owner and visible to current user -func (c *SpaceComponent) UserSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { +func (c *spaceComponentImpl) UserSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { onlyPublic := req.Owner != req.CurrentUser ms, total, err := c.ss.ByUsername(ctx, req.Owner, req.PageSize, req.Page, onlyPublic) if err != nil { @@ -453,7 +476,7 @@ func (c *SpaceComponent) UserSpaces(ctx context.Context, req *types.UserSpacesRe return resSpaces, total, nil } -func (c *SpaceComponent) UserLikesSpaces(ctx context.Context, req *types.UserSpacesReq, userID int64) ([]types.Space, int, error) { +func (c *spaceComponentImpl) UserLikesSpaces(ctx context.Context, req *types.UserSpacesReq, userID int64) ([]types.Space, int, error) { ms, total, err := c.ss.ByUserLikes(ctx, userID, req.PageSize, req.Page) if err != nil { newError := fmt.Errorf("failed to get spaces by username,%w", err) @@ -482,7 +505,7 @@ func (c *SpaceComponent) UserLikesSpaces(ctx context.Context, req *types.UserSpa return resSpaces, total, nil } -func (c *SpaceComponent) ListByPath(ctx context.Context, paths []string) ([]*types.Space, error) { +func (c *spaceComponentImpl) ListByPath(ctx context.Context, paths []string) ([]*types.Space, error) { var spaces []*types.Space spacesData, err := c.ss.ListByPath(ctx, paths) @@ -527,7 +550,7 @@ func (c *SpaceComponent) ListByPath(ctx context.Context, paths []string) ([]*typ return spaces, nil } -func (c *SpaceComponent) AllowCallApi(ctx context.Context, spaceID int64, username string) (bool, error) { +func (c *spaceComponentImpl) AllowCallApi(ctx context.Context, spaceID int64, username string) (bool, error) { if username == "" { return false, ErrUserNotFound } @@ -539,7 +562,7 @@ func (c *SpaceComponent) AllowCallApi(ctx context.Context, spaceID int64, userna return c.AllowReadAccess(ctx, s.Repository.RepositoryType, fields[0], fields[1], username) } -func (c *SpaceComponent) Delete(ctx context.Context, namespace, name, currentUser string) error { +func (c *spaceComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { space, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find space, error: %w", err) @@ -567,7 +590,7 @@ func (c *SpaceComponent) Delete(ctx context.Context, namespace, name, currentUse return nil } -func (c *SpaceComponent) Deploy(ctx context.Context, namespace, name, currentUser string) (int64, error) { +func (c *spaceComponentImpl) Deploy(ctx context.Context, namespace, name, currentUser string) (int64, error) { s, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { slog.Error("can't deploy space", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name)) @@ -623,7 +646,7 @@ func (c *SpaceComponent) Deploy(ctx context.Context, namespace, name, currentUse }) } -func (c *SpaceComponent) Wakeup(ctx context.Context, namespace, name string) error { +func (c *spaceComponentImpl) Wakeup(ctx context.Context, namespace, name string) error { s, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { slog.Error("can't wakeup space", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name)) @@ -642,7 +665,7 @@ func (c *SpaceComponent) Wakeup(ctx context.Context, namespace, name string) err }) } -func (c *SpaceComponent) Stop(ctx context.Context, namespace, name string) error { +func (c *spaceComponentImpl) Stop(ctx context.Context, namespace, name string) error { s, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { slog.Error("can't stop space", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name)) @@ -676,7 +699,7 @@ func (c *SpaceComponent) Stop(ctx context.Context, namespace, name string) error } // FixHasEntryFile checks whether git repo has entry point file and update space's HasAppFile property in db -func (c *SpaceComponent) FixHasEntryFile(ctx context.Context, s *database.Space) *database.Space { +func (c *spaceComponentImpl) FixHasEntryFile(ctx context.Context, s *database.Space) *database.Space { hasAppFile := c.HasEntryFile(ctx, s) if s.HasAppFile != hasAppFile { s.HasAppFile = hasAppFile @@ -686,7 +709,7 @@ func (c *SpaceComponent) FixHasEntryFile(ctx context.Context, s *database.Space) return s } -func (c *SpaceComponent) status(ctx context.Context, s *database.Space) (string, string, error) { +func (c *spaceComponentImpl) status(ctx context.Context, s *database.Space) (string, string, error) { if !s.HasAppFile { if s.Sdk == scheduler.NGINX.Name { return "", SpaceStatusNoNGINXConf, nil @@ -717,7 +740,7 @@ func (c *SpaceComponent) status(ctx context.Context, s *database.Space) (string, return srvName, deployStatusCodeToString(code), nil } -func (c *SpaceComponent) Status(ctx context.Context, namespace, name string) (string, string, error) { +func (c *spaceComponentImpl) Status(ctx context.Context, namespace, name string) (string, string, error) { s, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { return "", SpaceStatusStopped, fmt.Errorf("can't find space by path:%w", err) @@ -725,7 +748,7 @@ func (c *SpaceComponent) Status(ctx context.Context, namespace, name string) (st return c.status(ctx, s) } -func (c *SpaceComponent) Logs(ctx context.Context, namespace, name string) (*deploy.MultiLogReader, error) { +func (c *spaceComponentImpl) Logs(ctx context.Context, namespace, name string) (*deploy.MultiLogReader, error) { s, err := c.ss.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("can't find space by path:%w", err) @@ -738,7 +761,7 @@ func (c *SpaceComponent) Logs(ctx context.Context, namespace, name string) (*dep } // HasEntryFile checks whether space repo has entry point file to run with -func (c *SpaceComponent) HasEntryFile(ctx context.Context, space *database.Space) bool { +func (c *spaceComponentImpl) HasEntryFile(ctx context.Context, space *database.Space) bool { namespace, name := space.Repository.NamespaceAndName() entryFile := "app.py" if space.Sdk == scheduler.NGINX.Name { @@ -748,7 +771,7 @@ func (c *SpaceComponent) HasEntryFile(ctx context.Context, space *database.Space return c.hasEntryFile(ctx, namespace, name, entryFile) } -func (c *SpaceComponent) hasEntryFile(ctx context.Context, namespace, name, entryFile string) bool { +func (c *spaceComponentImpl) hasEntryFile(ctx context.Context, namespace, name, entryFile string) bool { var req gitserver.GetRepoInfoByPathReq req.Namespace = namespace req.Name = name @@ -770,7 +793,7 @@ func (c *SpaceComponent) hasEntryFile(ctx context.Context, namespace, name, entr return false } -func (c *SpaceComponent) mergeUpdateSpaceRequest(ctx context.Context, space *database.Space, req *types.UpdateSpaceReq) error { +func (c *spaceComponentImpl) mergeUpdateSpaceRequest(ctx context.Context, space *database.Space, req *types.UpdateSpaceReq) error { // Do not update column value if request body do not have it if req.Sdk != nil { space.Sdk = *req.Sdk diff --git a/component/space_resource.go b/component/space_resource.go index ed4a503c..b40939e0 100644 --- a/component/space_resource.go +++ b/component/space_resource.go @@ -12,19 +12,26 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewSpaceResourceComponent(config *config.Config) (*SpaceResourceComponent, error) { - c := &SpaceResourceComponent{} +type SpaceResourceComponent interface { + Index(ctx context.Context, clusterId string, deployType int) ([]types.SpaceResource, error) + Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) + Create(ctx context.Context, req *types.CreateSpaceResourceReq) (*types.SpaceResource, error) + Delete(ctx context.Context, id int64) error +} + +func NewSpaceResourceComponent(config *config.Config) (SpaceResourceComponent, error) { + c := &spaceResourceComponentImpl{} c.srs = database.NewSpaceResourceStore() c.deployer = deploy.NewDeployer() return c, nil } -type SpaceResourceComponent struct { - srs *database.SpaceResourceStore +type spaceResourceComponentImpl struct { + srs database.SpaceResourceStore deployer deploy.Deployer } -func (c *SpaceResourceComponent) Index(ctx context.Context, clusterId string, deployType int) ([]types.SpaceResource, error) { +func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string, deployType int) ([]types.SpaceResource, error) { // backward compatibility for old api if clusterId == "" { clusters, err := c.deployer.ListCluster(ctx) @@ -76,7 +83,7 @@ func (c *SpaceResourceComponent) Index(ctx context.Context, clusterId string, de return result, nil } -func (c *SpaceResourceComponent) Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) { +func (c *spaceResourceComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) { sr, err := c.srs.FindByID(ctx, req.ID) if err != nil { slog.Error("error getting space resource", slog.Any("error", err)) @@ -100,7 +107,7 @@ func (c *SpaceResourceComponent) Update(ctx context.Context, req *types.UpdateSp return result, nil } -func (c *SpaceResourceComponent) Create(ctx context.Context, req *types.CreateSpaceResourceReq) (*types.SpaceResource, error) { +func (c *spaceResourceComponentImpl) Create(ctx context.Context, req *types.CreateSpaceResourceReq) (*types.SpaceResource, error) { sr := database.SpaceResource{ Name: req.Name, Resources: req.Resources, @@ -121,7 +128,7 @@ func (c *SpaceResourceComponent) Create(ctx context.Context, req *types.CreateSp return result, nil } -func (c *SpaceResourceComponent) Delete(ctx context.Context, id int64) error { +func (c *spaceResourceComponentImpl) Delete(ctx context.Context, id int64) error { sr, err := c.srs.FindByID(ctx, id) if err != nil { slog.Error("error finding space resource", slog.Any("error", err)) diff --git a/component/space_sdk.go b/component/space_sdk.go index 7ba9d27a..1c0eca17 100644 --- a/component/space_sdk.go +++ b/component/space_sdk.go @@ -9,18 +9,25 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewSpaceSdkComponent(config *config.Config) (*SpaceSdkComponent, error) { - c := &SpaceSdkComponent{} +type SpaceSdkComponent interface { + Index(ctx context.Context) ([]types.SpaceSdk, error) + Update(ctx context.Context, req *types.UpdateSpaceSdkReq) (*types.SpaceSdk, error) + Create(ctx context.Context, req *types.CreateSpaceSdkReq) (*types.SpaceSdk, error) + Delete(ctx context.Context, id int64) error +} + +func NewSpaceSdkComponent(config *config.Config) (SpaceSdkComponent, error) { + c := &spaceSdkComponentImpl{} c.sss = database.NewSpaceSdkStore() return c, nil } -type SpaceSdkComponent struct { - sss *database.SpaceSdkStore +type spaceSdkComponentImpl struct { + sss database.SpaceSdkStore } -func (c *SpaceSdkComponent) Index(ctx context.Context) ([]types.SpaceSdk, error) { +func (c *spaceSdkComponentImpl) Index(ctx context.Context) ([]types.SpaceSdk, error) { var result []types.SpaceSdk databaseSpaceSdks, err := c.sss.Index(ctx) if err != nil { @@ -37,7 +44,7 @@ func (c *SpaceSdkComponent) Index(ctx context.Context) ([]types.SpaceSdk, error) return result, nil } -func (c *SpaceSdkComponent) Update(ctx context.Context, req *types.UpdateSpaceSdkReq) (*types.SpaceSdk, error) { +func (c *spaceSdkComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceSdkReq) (*types.SpaceSdk, error) { ss, err := c.sss.FindByID(ctx, req.ID) if err != nil { slog.Error("error getting space sdk", slog.Any("error", err)) @@ -61,7 +68,7 @@ func (c *SpaceSdkComponent) Update(ctx context.Context, req *types.UpdateSpaceSd return result, nil } -func (c *SpaceSdkComponent) Create(ctx context.Context, req *types.CreateSpaceSdkReq) (*types.SpaceSdk, error) { +func (c *spaceSdkComponentImpl) Create(ctx context.Context, req *types.CreateSpaceSdkReq) (*types.SpaceSdk, error) { ss := database.SpaceSdk{ Name: req.Name, Version: req.Version, @@ -81,7 +88,7 @@ func (c *SpaceSdkComponent) Create(ctx context.Context, req *types.CreateSpaceSd return result, nil } -func (c *SpaceSdkComponent) Delete(ctx context.Context, id int64) error { +func (c *spaceSdkComponentImpl) Delete(ctx context.Context, id int64) error { ss, err := c.sss.FindByID(ctx, id) if err != nil { slog.Error("error finding space sdk", slog.Any("error", err)) diff --git a/component/sshkey.go b/component/sshkey.go index c49985cf..a0f4cd10 100644 --- a/component/sshkey.go +++ b/component/sshkey.go @@ -15,8 +15,14 @@ import ( "opencsg.com/csghub-server/common/utils/common" ) -func NewSSHKeyComponent(config *config.Config) (*SSHKeyComponent, error) { - c := &SSHKeyComponent{} +type SSHKeyComponent interface { + Create(ctx context.Context, req *types.CreateSSHKeyRequest) (*database.SSHKey, error) + Index(ctx context.Context, username string, per, page int) ([]database.SSHKey, error) + Delete(ctx context.Context, username, name string) error +} + +func NewSSHKeyComponent(config *config.Config) (SSHKeyComponent, error) { + c := &sSHKeyComponentImpl{} c.ss = database.NewSSHKeyStore() c.us = database.NewUserStore() var err error @@ -29,13 +35,13 @@ func NewSSHKeyComponent(config *config.Config) (*SSHKeyComponent, error) { return c, nil } -type SSHKeyComponent struct { - ss *database.SSHKeyStore - us *database.UserStore +type sSHKeyComponentImpl struct { + ss database.SSHKeyStore + us database.UserStore gs gitserver.GitServer } -func (c *SSHKeyComponent) Create(ctx context.Context, req *types.CreateSSHKeyRequest) (*database.SSHKey, error) { +func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKeyRequest) (*database.SSHKey, error) { user, err := c.us.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("failed to find user,error:%w", err) @@ -81,7 +87,7 @@ func (c *SSHKeyComponent) Create(ctx context.Context, req *types.CreateSSHKeyReq return resSk, nil } -func (c *SSHKeyComponent) Index(ctx context.Context, username string, per, page int) ([]database.SSHKey, error) { +func (c *sSHKeyComponentImpl) Index(ctx context.Context, username string, per, page int) ([]database.SSHKey, error) { sks, err := c.ss.Index(ctx, username, per, page) if err != nil { return nil, fmt.Errorf("failed to get database SSH keys,error:%w", err) @@ -89,7 +95,7 @@ func (c *SSHKeyComponent) Index(ctx context.Context, username string, per, page return sks, nil } -func (c *SSHKeyComponent) Delete(ctx context.Context, username, name string) error { +func (c *sSHKeyComponentImpl) Delete(ctx context.Context, username, name string) error { sshKey, err := c.ss.FindByUsernameAndName(ctx, username, name) if err != nil { return fmt.Errorf("failed to get database SSH keys,error:%w", err) diff --git a/component/sync_client_setting.go b/component/sync_client_setting.go index 7cd49cb7..ba1cc9cc 100644 --- a/component/sync_client_setting.go +++ b/component/sync_client_setting.go @@ -11,17 +11,22 @@ import ( "opencsg.com/csghub-server/common/types" ) -type SyncClientSettingComponent struct { - settingStore *database.SyncClientSettingStore +type syncClientSettingComponentImpl struct { + settingStore database.SyncClientSettingStore } -func NewSyncClientSettingComponent(config *config.Config) (*SyncClientSettingComponent, error) { - return &SyncClientSettingComponent{ +type SyncClientSettingComponent interface { + Create(ctx context.Context, req types.CreateSyncClientSettingReq) (*database.SyncClientSetting, error) + Show(ctx context.Context) (*database.SyncClientSetting, error) +} + +func NewSyncClientSettingComponent(config *config.Config) (SyncClientSettingComponent, error) { + return &syncClientSettingComponentImpl{ settingStore: database.NewSyncClientSettingStore(), }, nil } -func (c *SyncClientSettingComponent) Create(ctx context.Context, req types.CreateSyncClientSettingReq) (*database.SyncClientSetting, error) { +func (c *syncClientSettingComponentImpl) Create(ctx context.Context, req types.CreateSyncClientSettingReq) (*database.SyncClientSetting, error) { exists, err := c.settingStore.SyncClientSettingExists(ctx) if err != nil { return nil, fmt.Errorf("failed to check sync client setting if exists, error: %w", err) @@ -43,7 +48,7 @@ func (c *SyncClientSettingComponent) Create(ctx context.Context, req types.Creat return res, nil } -func (c *SyncClientSettingComponent) Show(ctx context.Context) (*database.SyncClientSetting, error) { +func (c *syncClientSettingComponentImpl) Show(ctx context.Context) (*database.SyncClientSetting, error) { res, err := c.settingStore.First(ctx) if err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/component/tag.go b/component/tag.go index 636a085a..495a4483 100644 --- a/component/tag.go +++ b/component/tag.go @@ -14,8 +14,16 @@ import ( "opencsg.com/csghub-server/component/tagparser" ) -func NewTagComponent(config *config.Config) (*TagComponent, error) { - tc := &TagComponent{} +type TagComponent interface { + AllTags(ctx context.Context) ([]database.Tag, error) + ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error + UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) + UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error + UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error +} + +func NewTagComponent(config *config.Config) (TagComponent, error) { + tc := &tagComponentImpl{} tc.ts = database.NewTagStore() tc.rs = database.NewRepoStore() if config.SensitiveCheck.Enable { @@ -24,23 +32,23 @@ func NewTagComponent(config *config.Config) (*TagComponent, error) { return tc, nil } -type TagComponent struct { - ts *database.TagStore - rs *database.RepoStore +type tagComponentImpl struct { + ts database.TagStore + rs database.RepoStore sensitiveChecker rpc.ModerationSvcClient } -func (tc *TagComponent) AllTags(ctx context.Context) ([]database.Tag, error) { +func (tc *tagComponentImpl) AllTags(ctx context.Context) ([]database.Tag, error) { // TODO: query cache for tags at first return tc.ts.AllTags(ctx) } -func (c *TagComponent) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { +func (c *tagComponentImpl) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { _, err := c.ts.SetMetaTags(ctx, repoType, namespace, name, nil) return err } -func (c *TagComponent) UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) { +func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) { var ( tp tagparser.TagProcessor repoType types.RepositoryType @@ -108,7 +116,7 @@ func (c *TagComponent) UpdateMetaTags(ctx context.Context, tagScope database.Tag return repoTags, nil } -func (c *TagComponent) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error { +func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error { oldLibTagName := tagparser.LibraryTag(oldFilePath) newLibTagName := tagparser.LibraryTag(newFilePath) // TODO:load from cache @@ -153,7 +161,7 @@ func (c *TagComponent) UpdateLibraryTags(ctx context.Context, tagScope database. return nil } -func (c *TagComponent) UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error { +func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error { allTags, err := c.ts.AllTagsByScopeAndCategory(ctx, tagScope, category) if err != nil { return fmt.Errorf("failed to get all tags of scope `%s`, error: %w", tagScope, err) diff --git a/component/telemetry.go b/component/telemetry.go index 349f6ae5..d1aa5512 100644 --- a/component/telemetry.go +++ b/component/telemetry.go @@ -11,21 +11,26 @@ import ( "opencsg.com/csghub-server/common/types/telemetry" ) -type TelemetryComponent struct { +type telemetryComponentImpl struct { // Add telemetry related fields and methods here - ts *database.TelemetryStore - us *database.UserStore - rs *database.RepoStore + ts database.TelemetryStore + us database.UserStore + rs database.RepoStore } -func NewTelemetryComponent() (*TelemetryComponent, error) { +type TelemetryComponent interface { + SaveUsageData(ctx context.Context, usage telemetry.Usage) error + GenUsageData(ctx context.Context) (telemetry.Usage, error) +} + +func NewTelemetryComponent() (TelemetryComponent, error) { ts := database.NewTelemetryStore() us := database.NewUserStore() rs := database.NewRepoStore() - return &TelemetryComponent{ts: ts, us: us, rs: rs}, nil + return &telemetryComponentImpl{ts: ts, us: us, rs: rs}, nil } -func (tc *TelemetryComponent) SaveUsageData(ctx context.Context, usage telemetry.Usage) error { +func (tc *telemetryComponentImpl) SaveUsageData(ctx context.Context, usage telemetry.Usage) error { t := database.Telemetry{ UUID: usage.UUID, RecordedAt: usage.RecordedAt, @@ -55,7 +60,7 @@ func (tc *TelemetryComponent) SaveUsageData(ctx context.Context, usage telemetry return nil } -func (tc *TelemetryComponent) GenUsageData(ctx context.Context) (telemetry.Usage, error) { +func (tc *telemetryComponentImpl) GenUsageData(ctx context.Context) (telemetry.Usage, error) { var usage telemetry.Usage uuid, err := uuid.NewV7() @@ -99,11 +104,11 @@ func (tc *TelemetryComponent) GenUsageData(ctx context.Context) (telemetry.Usage return usage, nil } -func (tc *TelemetryComponent) getUserCnt(ctx context.Context) (int, error) { +func (tc *telemetryComponentImpl) getUserCnt(ctx context.Context) (int, error) { return tc.us.GetActiveUserCount(ctx) } -func (tc *TelemetryComponent) getCounts(ctx context.Context) (telemetry.Counts, error) { +func (tc *telemetryComponentImpl) getCounts(ctx context.Context) (telemetry.Counts, error) { var counts telemetry.Counts modelCnt, err := tc.rs.CountByRepoType(ctx, types.ModelRepo) if err != nil { diff --git a/component/user.go b/component/user.go index 4ffb7ac4..c3146a8a 100644 --- a/component/user.go +++ b/component/user.go @@ -16,8 +16,32 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewUserComponent(config *config.Config) (*UserComponent, error) { - c := &UserComponent{} +type UserComponent interface { + Datasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) + Models(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) + Codes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) + Spaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) + AddLikes(ctx context.Context, req *types.UserLikesRequest) error + // user likes collection + LikesCollection(ctx context.Context, req *types.UserSpacesReq) ([]types.Collection, int, error) + // UserCollections get collections of owner or visible to current user + Collections(ctx context.Context, req *types.UserCollectionReq) ([]types.Collection, int, error) + LikeCollection(ctx context.Context, req *types.UserLikesRequest) error + UnLikeCollection(ctx context.Context, req *types.UserLikesRequest) error + DeleteLikes(ctx context.Context, req *types.UserLikesRequest) error + LikesSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) + LikesCodes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) + LikesModels(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) + LikesDatasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) + ListDeploys(ctx context.Context, repoType types.RepositoryType, req *types.DeployReq) ([]types.DeployRepo, int, error) + ListInstances(ctx context.Context, req *types.UserRepoReq) ([]types.DeployRepo, int, error) + ListServerless(ctx context.Context, req types.DeployReq) ([]types.DeployRepo, int, error) + GetUserByName(ctx context.Context, userName string) (*database.User, error) + Prompts(ctx context.Context, req *types.UserPromptsReq) ([]types.PromptRes, int, error) +} + +func NewUserComponent(config *config.Config) (UserComponent, error) { + c := &userComponentImpl{} c.ms = database.NewModelStore() c.us = database.NewUserStore() c.ds = database.NewDatasetStore() @@ -36,7 +60,7 @@ func NewUserComponent(config *config.Config) (*UserComponent, error) { newError := fmt.Errorf("failed to create git server,error:%w", err) return nil, newError } - c.repoComponent, err = NewRepoComponent(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { newError := fmt.Errorf("failed to create repo component,error:%w", err) return nil, newError @@ -53,28 +77,28 @@ func NewUserComponent(config *config.Config) (*UserComponent, error) { return c, nil } -type UserComponent struct { - us *database.UserStore - ms *database.ModelStore - ds *database.DatasetStore - cs *database.CodeStore - ss *database.SpaceStore - ns *database.NamespaceStore +type userComponentImpl struct { + us database.UserStore + ms database.ModelStore + ds database.DatasetStore + cs database.CodeStore + ss database.SpaceStore + ns database.NamespaceStore gs gitserver.GitServer - spaceComponent *SpaceComponent - repoComponent *RepoComponent + spaceComponent SpaceComponent + repoComponent *repoComponentImpl deployer deploy.Deployer - uls *database.UserLikesStore - repo *database.RepoStore - deploy *database.DeployTaskStore - cos *database.CollectionStore - ac *AccountingComponent - srs *database.SpaceResourceStore + uls database.UserLikesStore + repo database.RepoStore + deploy database.DeployTaskStore + cos database.CollectionStore + ac AccountingComponent + srs database.SpaceResourceStore // urs *database.UserResourcesStore - pt *database.PromptStore + pt database.PromptStore } -func (c *UserComponent) Datasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { +func (c *userComponentImpl) Datasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { var resDatasets []types.Dataset userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { @@ -127,7 +151,7 @@ func (c *UserComponent) Datasets(ctx context.Context, req *types.UserDatasetsReq return resDatasets, total, nil } -func (c *UserComponent) Models(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { +func (c *userComponentImpl) Models(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { var resModels []types.Model userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { @@ -180,7 +204,7 @@ func (c *UserComponent) Models(ctx context.Context, req *types.UserModelsReq) ([ return resModels, total, nil } -func (c *UserComponent) Codes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { +func (c *userComponentImpl) Codes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { var resCodes []types.Code userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { @@ -233,7 +257,7 @@ func (c *UserComponent) Codes(ctx context.Context, req *types.UserModelsReq) ([] return resCodes, total, nil } -func (c *UserComponent) Spaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { +func (c *userComponentImpl) Spaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) @@ -261,7 +285,7 @@ func (c *UserComponent) Spaces(ctx context.Context, req *types.UserSpacesReq) ([ return c.spaceComponent.UserSpaces(ctx, req) } -func (c *UserComponent) AddLikes(ctx context.Context, req *types.UserLikesRequest) error { +func (c *userComponentImpl) AddLikes(ctx context.Context, req *types.UserLikesRequest) error { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user, error:%w", err) @@ -291,7 +315,7 @@ func (c *UserComponent) AddLikes(ctx context.Context, req *types.UserLikesReques } // user likes collection -func (c *UserComponent) LikesCollection(ctx context.Context, req *types.UserSpacesReq) ([]types.Collection, int, error) { +func (c *userComponentImpl) LikesCollection(ctx context.Context, req *types.UserSpacesReq) ([]types.Collection, int, error) { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user, error:%w", err) @@ -314,7 +338,7 @@ func (c *UserComponent) LikesCollection(ctx context.Context, req *types.UserSpac } // UserCollections get collections of owner or visible to current user -func (c *UserComponent) Collections(ctx context.Context, req *types.UserCollectionReq) ([]types.Collection, int, error) { +func (c *userComponentImpl) Collections(ctx context.Context, req *types.UserCollectionReq) ([]types.Collection, int, error) { userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) @@ -356,7 +380,7 @@ func (c *UserComponent) Collections(ctx context.Context, req *types.UserCollecti return newCollection, total, nil } -func (c *UserComponent) LikeCollection(ctx context.Context, req *types.UserLikesRequest) error { +func (c *userComponentImpl) LikeCollection(ctx context.Context, req *types.UserLikesRequest) error { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user, error:%w", err) @@ -376,7 +400,7 @@ func (c *UserComponent) LikeCollection(ctx context.Context, req *types.UserLikes return err } -func (c *UserComponent) UnLikeCollection(ctx context.Context, req *types.UserLikesRequest) error { +func (c *userComponentImpl) UnLikeCollection(ctx context.Context, req *types.UserLikesRequest) error { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) @@ -386,7 +410,7 @@ func (c *UserComponent) UnLikeCollection(ctx context.Context, req *types.UserLik return err } -func (c *UserComponent) DeleteLikes(ctx context.Context, req *types.UserLikesRequest) error { +func (c *userComponentImpl) DeleteLikes(ctx context.Context, req *types.UserLikesRequest) error { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) @@ -396,7 +420,7 @@ func (c *UserComponent) DeleteLikes(ctx context.Context, req *types.UserLikesReq return err } -func (c *UserComponent) LikesSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { +func (c *userComponentImpl) LikesSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user, error:%w", err) @@ -405,7 +429,7 @@ func (c *UserComponent) LikesSpaces(ctx context.Context, req *types.UserSpacesRe return c.spaceComponent.UserLikesSpaces(ctx, req, user.ID) } -func (c *UserComponent) LikesCodes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { +func (c *userComponentImpl) LikesCodes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { var resCodes []types.Code user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -438,7 +462,7 @@ func (c *UserComponent) LikesCodes(ctx context.Context, req *types.UserModelsReq return resCodes, total, nil } -func (c *UserComponent) LikesModels(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { +func (c *userComponentImpl) LikesModels(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { var resModels []types.Model user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -471,7 +495,7 @@ func (c *UserComponent) LikesModels(ctx context.Context, req *types.UserModelsRe return resModels, total, nil } -func (c *UserComponent) LikesDatasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { +func (c *userComponentImpl) LikesDatasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { var resDatasets []types.Dataset user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -504,7 +528,7 @@ func (c *UserComponent) LikesDatasets(ctx context.Context, req *types.UserDatase return resDatasets, total, nil } -func (c *UserComponent) ListDeploys(ctx context.Context, repoType types.RepositoryType, req *types.DeployReq) ([]types.DeployRepo, int, error) { +func (c *userComponentImpl) ListDeploys(ctx context.Context, repoType types.RepositoryType, req *types.DeployReq) ([]types.DeployRepo, int, error) { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user:%s, error:%w", req.CurrentUser, err) @@ -566,7 +590,7 @@ func (c *UserComponent) ListDeploys(ctx context.Context, repoType types.Reposito return resDeploys, total, nil } -func (c *UserComponent) ListInstances(ctx context.Context, req *types.UserRepoReq) ([]types.DeployRepo, int, error) { +func (c *userComponentImpl) ListInstances(ctx context.Context, req *types.UserRepoReq) ([]types.DeployRepo, int, error) { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user:%s, error:%w", req.CurrentUser, err) @@ -606,7 +630,7 @@ func (c *UserComponent) ListInstances(ctx context.Context, req *types.UserRepoRe return resDeploys, total, nil } -func (c *UserComponent) ListServerless(ctx context.Context, req types.DeployReq) ([]types.DeployRepo, int, error) { +func (c *userComponentImpl) ListServerless(ctx context.Context, req types.DeployReq) ([]types.DeployRepo, int, error) { user, err := c.us.FindByUsername(ctx, req.CurrentUser) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user:%s, error:%w", req.CurrentUser, err) @@ -651,7 +675,7 @@ func (c *UserComponent) ListServerless(ctx context.Context, req types.DeployReq) return resDeploys, total, nil } -func (c *UserComponent) GetUserByName(ctx context.Context, userName string) (*database.User, error) { +func (c *userComponentImpl) GetUserByName(ctx context.Context, userName string) (*database.User, error) { user, err := c.us.FindByUsername(ctx, userName) if err != nil { return nil, fmt.Errorf("failed to check for the presence of the user %s,error:%w", userName, err) @@ -659,7 +683,7 @@ func (c *UserComponent) GetUserByName(ctx context.Context, userName string) (*da return &user, nil } -func (c *UserComponent) Prompts(ctx context.Context, req *types.UserPromptsReq) ([]types.PromptRes, int, error) { +func (c *userComponentImpl) Prompts(ctx context.Context, req *types.UserPromptsReq) ([]types.PromptRes, int, error) { var resPrompts []types.PromptRes userExists, err := c.us.IsExist(ctx, req.Owner) if err != nil { diff --git a/mirror/lfssyncer/minio.go b/mirror/lfssyncer/minio.go index c168f31d..8894e113 100644 --- a/mirror/lfssyncer/minio.go +++ b/mirror/lfssyncer/minio.go @@ -25,11 +25,11 @@ type MinioLFSSyncWorker struct { mq *queue.PriorityQueue tasks chan queue.MirrorTask wg sync.WaitGroup - mirrorStore *database.MirrorStore - lfsMetaObjectStore *database.LfsMetaObjectStore + mirrorStore database.MirrorStore + lfsMetaObjectStore database.LfsMetaObjectStore s3Client *s3.Client config *config.Config - repoStore *database.RepoStore + repoStore database.RepoStore numWorkers int } diff --git a/mirror/reposyncer/local_woker.go b/mirror/reposyncer/local_woker.go index 8c21549c..3e7e54d4 100644 --- a/mirror/reposyncer/local_woker.go +++ b/mirror/reposyncer/local_woker.go @@ -26,9 +26,9 @@ type LocalMirrorWoker struct { numWorkers int wg sync.WaitGroup saas bool - mirrorStore *database.MirrorStore - lfsMetaObjectStore *database.LfsMetaObjectStore - repoStore *database.RepoStore + mirrorStore database.MirrorStore + lfsMetaObjectStore database.LfsMetaObjectStore + repoStore database.RepoStore git gitserver.GitServer config *config.Config } diff --git a/moderation/component/repo.go b/moderation/component/repo.go index 30fa6dde..4a2b74f3 100644 --- a/moderation/component/repo.go +++ b/moderation/component/repo.go @@ -16,16 +16,22 @@ import ( "opencsg.com/csghub-server/moderation/checker" ) -type RepoComponent struct { +type repoComponentImpl struct { checker sensitive.SensitiveChecker - rs *database.RepoStore - rfs *database.RepoFileStore - rfcs *database.RepoFileCheckStore + rs database.RepoStore + rfs database.RepoFileStore + rfcs database.RepoFileCheckStore git gitserver.GitServer } -func NewRepoComponent(cfg *config.Config) (*RepoComponent, error) { - c := &RepoComponent{checker: sensitive.NewAliyunGreenChecker(cfg)} +type RepoComponent interface { + UpdateRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace string, name string, status types.SensitiveCheckStatus) error + CheckRepoFiles(ctx context.Context, repoType types.RepositoryType, namespace string, name string, options CheckOption) error + CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) +} + +func NewRepoComponent(cfg *config.Config) (RepoComponent, error) { + c := &repoComponentImpl{checker: sensitive.NewAliyunGreenChecker(cfg)} gs, err := git.NewGitServer(cfg) if err != nil { return nil, fmt.Errorf("failed to create git server for sensitive component: %w", err) @@ -47,7 +53,7 @@ type CheckOption struct { // MaxConcurrent int } -func (c *RepoComponent) UpdateRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace string, name string, status types.SensitiveCheckStatus) error { +func (c *repoComponentImpl) UpdateRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace string, name string, status types.SensitiveCheckStatus) error { repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to get repo, error: %w", err) @@ -58,7 +64,7 @@ func (c *RepoComponent) UpdateRepoSensitiveCheckStatus(ctx context.Context, repo return err } -func (c *RepoComponent) CheckRepoFiles(ctx context.Context, repoType types.RepositoryType, namespace string, name string, options CheckOption) error { +func (c *repoComponentImpl) CheckRepoFiles(ctx context.Context, repoType types.RepositoryType, namespace string, name string, options CheckOption) error { if options.BatchSize == 0 { options.BatchSize = 10 } @@ -96,7 +102,7 @@ func (c *RepoComponent) CheckRepoFiles(ctx context.Context, repoType types.Repos return nil } -func (c *RepoComponent) processFile(ctx context.Context, file *database.RepositoryFile) { +func (c *repoComponentImpl) processFile(ctx context.Context, file *database.RepositoryFile) { reader := NewRepoFileContentReader(file, c.git) checker := checker.GetFileChecker(file.FileType, file.Path, file.LfsRelativePath) status, msg := checker.Run(reader) @@ -108,7 +114,7 @@ func (c *RepoComponent) processFile(ctx context.Context, file *database.Reposito c.saveCheckResult(ctx, file, status, msg) } -func (c *RepoComponent) saveCheckResult(ctx context.Context, file *database.RepositoryFile, status types.SensitiveCheckStatus, msg string) error { +func (c *repoComponentImpl) saveCheckResult(ctx context.Context, file *database.RepositoryFile, status types.SensitiveCheckStatus, msg string) error { fcr := &database.RepositoryFileCheck{ RepoFileID: file.ID, Status: status, @@ -141,7 +147,7 @@ func (c *RepoComponent) saveCheckResult(ctx context.Context, file *database.Repo return err } -func (cc *RepoComponent) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { +func (cc *repoComponentImpl) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { fields := req.GetSensitiveFields() for _, field := range fields { pass, err := cc.checker.PassTextCheck(ctx, sensitive.Scenario(field.Scenario), field.Value()) diff --git a/moderation/component/repo_file.go b/moderation/component/repo_file.go index 0278f453..2a515d7c 100644 --- a/moderation/component/repo_file.go +++ b/moderation/component/repo_file.go @@ -13,14 +13,20 @@ import ( "opencsg.com/csghub-server/common/types" ) -type RepoFileComponent struct { - rfs *database.RepoFileStore - rs *database.RepoStore +type repoFileComponentImpl struct { + rfs database.RepoFileStore + rs database.RepoStore gs gitserver.GitServer } -func NewRepoFileComponent(conf *config.Config) (*RepoFileComponent, error) { - c := &RepoFileComponent{ +type RepoFileComponent interface { + GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error + GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error + DetectRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string) error +} + +func NewRepoFileComponent(conf *config.Config) (RepoFileComponent, error) { + c := &repoFileComponentImpl{ rfs: database.NewRepoFileStore(), rs: database.NewRepoStore(), } @@ -32,7 +38,7 @@ func NewRepoFileComponent(conf *config.Config) (*RepoFileComponent, error) { c.gs = gs return c, nil } -func (c *RepoFileComponent) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { +func (c *repoFileComponentImpl) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) @@ -40,7 +46,7 @@ func (c *RepoFileComponent) GenRepoFileRecords(ctx context.Context, repoType typ return c.createRepoFileRecords(ctx, *repo, "", c.gs.GetRepoFileTree) } -func (c *RepoFileComponent) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { +func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { tokens := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { tokens <- struct{}{} @@ -82,7 +88,7 @@ func (c *RepoFileComponent) GenRepoFileRecordsBatch(ctx context.Context, repoTyp return nil } -func (c *RepoFileComponent) createRepoFileRecords(ctx context.Context, repo database.Repository, folder string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) error { +func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo database.Repository, folder string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) error { namespace, name := repo.NamespaceAndName() var files []*types.File @@ -141,7 +147,7 @@ func (c *RepoFileComponent) createRepoFileRecords(ctx context.Context, repo data return nil } -func (c *RepoFileComponent) DetectRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { +func (c *repoFileComponentImpl) DetectRepoSensitiveCheckStatus(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) diff --git a/moderation/handler/repo.go b/moderation/handler/repo.go index c4fb5f78..cb6b72b2 100644 --- a/moderation/handler/repo.go +++ b/moderation/handler/repo.go @@ -16,7 +16,7 @@ import ( ) type RepoHandler struct { - rc *component.RepoComponent + rc component.RepoComponent config *config.Config } diff --git a/user/component/access_token.go b/user/component/access_token.go index 2c4b4744..a191213a 100644 --- a/user/component/access_token.go +++ b/user/component/access_token.go @@ -18,8 +18,16 @@ import ( var ErrUserNotFound = errors.New("user not found, please login first") -func NewAccessTokenComponent(config *config.Config) (*AccessTokenComponent, error) { - c := &AccessTokenComponent{} +type AccessTokenComponent interface { + Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) + Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error + Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) + GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) + RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) +} + +func NewAccessTokenComponent(config *config.Config) (AccessTokenComponent, error) { + c := &accessTokenComponentImpl{} c.ts = database.NewAccessTokenStore() c.us = database.NewUserStore() var err error @@ -32,13 +40,13 @@ func NewAccessTokenComponent(config *config.Config) (*AccessTokenComponent, erro return c, nil } -type AccessTokenComponent struct { - ts *database.AccessTokenStore - us *database.UserStore +type accessTokenComponentImpl struct { + ts database.AccessTokenStore + us database.UserStore gs gitserver.GitServer } -func (c *AccessTokenComponent) Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) { +func (c *accessTokenComponentImpl) Create(ctx context.Context, req *types.CreateUserTokenRequest) (*database.AccessToken, error) { user, err := c.us.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("fail to find user,error:%w", err) @@ -96,12 +104,12 @@ func (c *AccessTokenComponent) Create(ctx context.Context, req *types.CreateUser return token, nil } -func (c *AccessTokenComponent) genUnique() string { +func (c *accessTokenComponentImpl) genUnique() string { // TODO:change return strings.ReplaceAll(uuid.NewString(), "-", "") } -func (c *AccessTokenComponent) Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error { +func (c *accessTokenComponentImpl) Delete(ctx context.Context, req *types.DeleteUserTokenRequest) error { ue, err := c.us.IsExist(ctx, req.Username) if !ue { return fmt.Errorf("user does not exists,error:%w", err) @@ -125,7 +133,7 @@ func (c *AccessTokenComponent) Delete(ctx context.Context, req *types.DeleteUser return nil } -func (c *AccessTokenComponent) Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) Check(ctx context.Context, req *types.CheckAccessTokenReq) (types.CheckAccessTokenResp, error) { var resp types.CheckAccessTokenResp t, err := c.ts.FindByToken(ctx, req.Token, req.Application) if err != nil { @@ -142,7 +150,7 @@ func (c *AccessTokenComponent) Check(ctx context.Context, req *types.CheckAccess return resp, nil } -func (c *AccessTokenComponent) GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) GetTokens(ctx context.Context, username, app string) ([]types.CheckAccessTokenResp, error) { var resps []types.CheckAccessTokenResp tokens, err := c.ts.FindByUser(ctx, username, app) if err != nil { @@ -164,7 +172,7 @@ func (c *AccessTokenComponent) GetTokens(ctx context.Context, username, app stri return resps, nil } -func (c *AccessTokenComponent) RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) { +func (c *accessTokenComponentImpl) RefreshToken(ctx context.Context, userName, tokenName, app string, newExpiredAt time.Time) (types.CheckAccessTokenResp, error) { var resp types.CheckAccessTokenResp t, err := c.ts.FindByTokenName(ctx, userName, tokenName, app) if err != nil { diff --git a/user/component/jwt.go b/user/component/jwt.go index 72df74f7..bfedd004 100644 --- a/user/component/jwt.go +++ b/user/component/jwt.go @@ -10,14 +10,20 @@ import ( "opencsg.com/csghub-server/common/types" ) -type JwtComponent struct { +type jwtComponentImpl struct { SigningKey []byte ValidTime time.Duration - us *database.UserStore + us database.UserStore } -func NewJwtComponent(signKey string, validHour int) *JwtComponent { - return &JwtComponent{ +type JwtComponent interface { + // GenerateToken generate a jwt token, and return the token and signed string + GenerateToken(ctx context.Context, req types.CreateJWTReq) (claims *types.JWTClaims, signed string, err error) + ParseToken(ctx context.Context, token string) (user *types.User, err error) +} + +func NewJwtComponent(signKey string, validHour int) JwtComponent { + return &jwtComponentImpl{ SigningKey: []byte(signKey), ValidTime: time.Duration(validHour) * time.Hour, us: database.NewUserStore(), @@ -25,7 +31,7 @@ func NewJwtComponent(signKey string, validHour int) *JwtComponent { } // GenerateToken generate a jwt token, and return the token and signed string -func (c *JwtComponent) GenerateToken(ctx context.Context, req types.CreateJWTReq) (claims *types.JWTClaims, signed string, err error) { +func (c *jwtComponentImpl) GenerateToken(ctx context.Context, req types.CreateJWTReq) (claims *types.JWTClaims, signed string, err error) { u, err := c.us.FindByUUID(ctx, req.UUID) if err != nil { return nil, "", fmt.Errorf("failed to find user by uuid '%s',error: %w", req.UUID, err) @@ -49,7 +55,7 @@ func (c *JwtComponent) GenerateToken(ctx context.Context, req types.CreateJWTReq return claims, signed, nil } -func (c *JwtComponent) ParseToken(ctx context.Context, token string) (user *types.User, err error) { +func (c *jwtComponentImpl) ParseToken(ctx context.Context, token string) (user *types.User, err error) { claims := &types.JWTClaims{} _, err = jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { diff --git a/user/component/member.go b/user/component/member.go index 3d2f3b7f..94ab410a 100644 --- a/user/component/member.go +++ b/user/component/member.go @@ -15,16 +15,28 @@ import ( "opencsg.com/csghub-server/common/types" ) -type MemberComponent struct { - memberStore *database.MemberStore - orgStore *database.OrgStore - userStore *database.UserStore +type memberComponentImpl struct { + memberStore database.MemberStore + orgStore database.OrgStore + userStore database.UserStore gitServer gitserver.GitServer gitMemberShip membership.GitMemerShip config *config.Config } -func NewMemberComponent(config *config.Config) (*MemberComponent, error) { +type MemberComponent interface { + OrgMembers(ctx context.Context, orgName, currentUser string, pageSize, page int) ([]types.Member, int, error) + InitRoles(ctx context.Context, org *database.Organization) error + SetAdmin(ctx context.Context, org *database.Organization, user *database.User) error + ChangeMemberRole(ctx context.Context, orgName, userName, operatorName, oldRole, newRole string) error + GetMemberRole(ctx context.Context, orgName, userName string) (membership.Role, error) + AddMembers(ctx context.Context, orgName string, users []string, operatorName string, role string) error + AddMember(ctx context.Context, orgName, userName, operatorName string, role string) error + Update(ctx context.Context) (org *database.Member, err error) + Delete(ctx context.Context, orgName, userName, operatorName string, role string) error +} + +func NewMemberComponent(config *config.Config) (MemberComponent, error) { var gms membership.GitMemerShip gs, err := git.NewGitServer(config) if err != nil { @@ -36,7 +48,7 @@ func NewMemberComponent(config *config.Config) (*MemberComponent, error) { return nil, fmt.Errorf("failed to create git membership:%w", err) } } - return &MemberComponent{ + return &memberComponentImpl{ memberStore: database.NewMemberStore(), orgStore: database.NewOrgStore(), userStore: database.NewUserStore(), @@ -46,7 +58,7 @@ func NewMemberComponent(config *config.Config) (*MemberComponent, error) { }, nil } -func (c *MemberComponent) OrgMembers(ctx context.Context, orgName, currentUser string, pageSize, page int) ([]types.Member, int, error) { +func (c *memberComponentImpl) OrgMembers(ctx context.Context, orgName, currentUser string, pageSize, page int) ([]types.Member, int, error) { var ( org database.Organization user database.User @@ -90,7 +102,7 @@ func (c *MemberComponent) OrgMembers(ctx context.Context, orgName, currentUser s return members, total, nil } -func (c *MemberComponent) InitRoles(ctx context.Context, org *database.Organization) error { +func (c *memberComponentImpl) InitRoles(ctx context.Context, org *database.Organization) error { if c.config.GitServer.Type == types.GitServerTypeGitea { return c.gitMemberShip.AddRoles(ctx, org.Name, []membership.Role{membership.RoleAdmin, membership.RoleRead, membership.RoleWrite}) @@ -99,7 +111,7 @@ func (c *MemberComponent) InitRoles(ctx context.Context, org *database.Organizat } } -func (c *MemberComponent) SetAdmin(ctx context.Context, org *database.Organization, user *database.User) error { +func (c *memberComponentImpl) SetAdmin(ctx context.Context, org *database.Organization, user *database.User) error { var ( err error ) @@ -115,7 +127,7 @@ func (c *MemberComponent) SetAdmin(ctx context.Context, org *database.Organizati } } -func (c *MemberComponent) ChangeMemberRole(ctx context.Context, orgName, userName, operatorName, oldRole, newRole string) error { +func (c *memberComponentImpl) ChangeMemberRole(ctx context.Context, orgName, userName, operatorName, oldRole, newRole string) error { err := c.Delete(ctx, orgName, userName, operatorName, oldRole) if err != nil { return fmt.Errorf("failed to delete old role,error:%w", err) @@ -128,7 +140,7 @@ func (c *MemberComponent) ChangeMemberRole(ctx context.Context, orgName, userNam return nil } -func (c *MemberComponent) GetMemberRole(ctx context.Context, orgName, userName string) (membership.Role, error) { +func (c *memberComponentImpl) GetMemberRole(ctx context.Context, orgName, userName string) (membership.Role, error) { var ( org database.Organization user database.User @@ -152,7 +164,7 @@ func (c *MemberComponent) GetMemberRole(ctx context.Context, orgName, userName s return c.toGitRole(m.Role), nil } -func (c *MemberComponent) AddMembers(ctx context.Context, orgName string, users []string, operatorName string, role string) error { +func (c *memberComponentImpl) AddMembers(ctx context.Context, orgName string, users []string, operatorName string, role string) error { var ( org database.Organization op database.User @@ -204,15 +216,15 @@ func (c *MemberComponent) AddMembers(ctx context.Context, orgName string, users return nil } -func (c *MemberComponent) AddMember(ctx context.Context, orgName, userName, operatorName string, role string) error { +func (c *memberComponentImpl) AddMember(ctx context.Context, orgName, userName, operatorName string, role string) error { return c.AddMembers(ctx, orgName, []string{userName}, operatorName, role) } -func (c *MemberComponent) Update(ctx context.Context) (org *database.Member, err error) { +func (c *memberComponentImpl) Update(ctx context.Context) (org *database.Member, err error) { return } -func (c *MemberComponent) Delete(ctx context.Context, orgName, userName, operatorName string, role string) error { +func (c *memberComponentImpl) Delete(ctx context.Context, orgName, userName, operatorName string, role string) error { var ( org database.Organization op database.User @@ -258,12 +270,12 @@ func (c *MemberComponent) Delete(ctx context.Context, orgName, userName, operato } } -func (c *MemberComponent) allowAddMember(u *database.Member) bool { +func (c *memberComponentImpl) allowAddMember(u *database.Member) bool { //TODO: check more roles return u != nil && u.Role == string(membership.RoleAdmin) } -func (c *MemberComponent) toGitRole(role string) membership.Role { +func (c *memberComponentImpl) toGitRole(role string) membership.Role { switch role { case "admin": return membership.RoleAdmin diff --git a/user/component/namespace.go b/user/component/namespace.go index ac09ac44..0e250337 100644 --- a/user/component/namespace.go +++ b/user/component/namespace.go @@ -9,19 +9,23 @@ import ( "opencsg.com/csghub-server/common/types" ) -type NamespaceComponent struct { - ns *database.NamespaceStore - os *database.OrgStore +type namespaceComponentImpl struct { + ns database.NamespaceStore + os database.OrgStore } -func NewNamespaceComponent(config *config.Config) (*NamespaceComponent, error) { - return &NamespaceComponent{ +type NamespaceComponent interface { + GetInfo(ctx context.Context, path string) (*types.Namespace, error) +} + +func NewNamespaceComponent(config *config.Config) (NamespaceComponent, error) { + return &namespaceComponentImpl{ ns: database.NewNamespaceStore(), os: database.NewOrgStore(), }, nil } -func (c *NamespaceComponent) GetInfo(ctx context.Context, path string) (*types.Namespace, error) { +func (c *namespaceComponentImpl) GetInfo(ctx context.Context, path string) (*types.Namespace, error) { dbns, err := c.ns.FindByPath(ctx, path) ns := &types.Namespace{ Path: dbns.Path, diff --git a/user/component/organization.go b/user/component/organization.go index d8c440d9..292e8f47 100644 --- a/user/component/organization.go +++ b/user/component/organization.go @@ -13,8 +13,17 @@ import ( "opencsg.com/csghub-server/common/types" ) -func NewOrganizationComponent(config *config.Config) (*OrganizationComponent, error) { - c := &OrganizationComponent{} +type OrganizationComponent interface { + FixOrgData(ctx context.Context, org *database.Organization) (*database.Organization, error) + Create(ctx context.Context, req *types.CreateOrgReq) (*types.Organization, error) + Index(ctx context.Context, username string) ([]types.Organization, error) + Get(ctx context.Context, orgName string) (*types.Organization, error) + Delete(ctx context.Context, req *types.DeleteOrgReq) error + Update(ctx context.Context, req *types.EditOrgReq) (*database.Organization, error) +} + +func NewOrganizationComponent(config *config.Config) (OrganizationComponent, error) { + c := &organizationComponentImpl{} c.os = database.NewOrgStore() c.ns = database.NewNamespaceStore() c.us = database.NewUserStore() @@ -34,16 +43,16 @@ func NewOrganizationComponent(config *config.Config) (*OrganizationComponent, er return c, nil } -type OrganizationComponent struct { - os *database.OrgStore - ns *database.NamespaceStore - us *database.UserStore +type organizationComponentImpl struct { + os database.OrgStore + ns database.NamespaceStore + us database.UserStore gs gitserver.GitServer - msc *MemberComponent + msc MemberComponent } -func (c *OrganizationComponent) FixOrgData(ctx context.Context, org *database.Organization) (*database.Organization, error) { +func (c *organizationComponentImpl) FixOrgData(ctx context.Context, org *database.Organization) (*database.Organization, error) { user := org.User req := new(types.CreateOrgReq) req.Name = org.Name @@ -64,7 +73,7 @@ func (c *OrganizationComponent) FixOrgData(ctx context.Context, org *database.Or return org, err } -func (c *OrganizationComponent) Create(ctx context.Context, req *types.CreateOrgReq) (*types.Organization, error) { +func (c *organizationComponentImpl) Create(ctx context.Context, req *types.CreateOrgReq) (*types.Organization, error) { user, err := c.us.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("failed to find user, error: %w", err) @@ -116,7 +125,7 @@ func (c *OrganizationComponent) Create(ctx context.Context, req *types.CreateOrg return org, err } -func (c *OrganizationComponent) Index(ctx context.Context, username string) ([]types.Organization, error) { +func (c *organizationComponentImpl) Index(ctx context.Context, username string) ([]types.Organization, error) { dborgs, err := c.os.GetUserOwnOrgs(ctx, username) if err != nil { return nil, fmt.Errorf("failed to get organizations, error: %w", err) @@ -136,7 +145,7 @@ func (c *OrganizationComponent) Index(ctx context.Context, username string) ([]t return orgs, nil } -func (c *OrganizationComponent) Get(ctx context.Context, orgName string) (*types.Organization, error) { +func (c *organizationComponentImpl) Get(ctx context.Context, orgName string) (*types.Organization, error) { dborg, err := c.os.FindByPath(ctx, orgName) if err != nil { return nil, fmt.Errorf("failed to get organizations by name, error: %w", err) @@ -152,7 +161,7 @@ func (c *OrganizationComponent) Get(ctx context.Context, orgName string) (*types return org, nil } -func (c *OrganizationComponent) Delete(ctx context.Context, req *types.DeleteOrgReq) error { +func (c *organizationComponentImpl) Delete(ctx context.Context, req *types.DeleteOrgReq) error { r, err := c.msc.GetMemberRole(ctx, req.Name, req.CurrentUser) if err != nil { slog.Error("faild to get member role", @@ -173,7 +182,7 @@ func (c *OrganizationComponent) Delete(ctx context.Context, req *types.DeleteOrg return nil } -func (c *OrganizationComponent) Update(ctx context.Context, req *types.EditOrgReq) (*database.Organization, error) { +func (c *organizationComponentImpl) Update(ctx context.Context, req *types.EditOrgReq) (*database.Organization, error) { r, err := c.msc.GetMemberRole(ctx, req.Name, req.CurrentUser) if err != nil { slog.Error("faild to get member role", diff --git a/user/component/user.go b/user/component/user.go index 1edf3d31..a541c1d7 100644 --- a/user/component/user.go +++ b/user/component/user.go @@ -23,17 +23,17 @@ import ( const GitalyRepoNotFoundErr = "rpc error: code = NotFound desc = repository does not exist" -type UserComponent struct { - us *database.UserStore - os *database.OrgStore - ns *database.NamespaceStore - repo *database.RepoStore - ds *database.DeployTaskStore - ams *database.AccountMeteringStore +type userComponentImpl struct { + us database.UserStore + os database.OrgStore + ns database.NamespaceStore + repo database.RepoStore + ds database.DeployTaskStore + ams database.AccountMeteringStore gs gitserver.GitServer - jwtc *JwtComponent - tokenc *AccessTokenComponent + jwtc JwtComponent + tokenc AccessTokenComponent casc *casdoorsdk.Client casConfig *casdoorsdk.AuthConfig @@ -42,9 +42,37 @@ type UserComponent struct { config *config.Config } -func NewUserComponent(config *config.Config) (*UserComponent, error) { +type UserComponent interface { + ChangeUserName(ctx context.Context, oldUserName, newUserName, opUser string) error + Update(ctx context.Context, req *types.UpdateUserRequest, opUser string) error + Delete(ctx context.Context, operator, username string) error + // CanAdmin checks if a user has admin privileges. + // + // Parameters: + // - ctx: The context.Context object for the function. + // - username: The username of the user to check. + // + // Returns: + // - bool: True if the user has admin privileges, false otherwise. + // - error: An error if the user cannot be found in the database. + CanAdmin(ctx context.Context, username string) (bool, error) + // GetInternal get *full* user info by username or uuid + // + // should only be called by other *internal* services + GetInternal(ctx context.Context, userNameOrUUID string, useUUID bool) (*types.User, error) + Get(ctx context.Context, userNameOrUUID, visitorName string, useUUID bool) (*types.User, error) + CheckOperatorAndUser(ctx context.Context, operator, username string) (bool, error) + CheckIfUserHasOrgs(ctx context.Context, userName string) (bool, error) + CheckIffUserHasRunningOrBuildingDeployments(ctx context.Context, userName string) (bool, error) + CheckIfUserHasBills(ctx context.Context, userName string) (bool, error) + Index(ctx context.Context, visitorName, search string, per, page int) ([]*types.User, int, error) + Signin(ctx context.Context, code, state string) (*types.JWTClaims, string, error) + FixUserData(ctx context.Context, userName string) error +} + +func NewUserComponent(config *config.Config) (UserComponent, error) { var err error - c := &UserComponent{} + c := &userComponentImpl{} c.us = database.NewUserStore() c.os = database.NewOrgStore() c.ns = database.NewNamespaceStore() @@ -80,12 +108,12 @@ func NewUserComponent(config *config.Config) (*UserComponent, error) { } // This function creates a user when user register from portal, without casdoor -func (c *UserComponent) createFromPortalRegistry(ctx context.Context, req types.CreateUserRequest) (*database.User, error) { +func (c *userComponentImpl) createFromPortalRegistry(ctx context.Context, req types.CreateUserRequest) (*database.User, error) { // Panic if the function has not been implemented panic("implement me later") } -func (c *UserComponent) createFromCasdoorUser(ctx context.Context, cu casdoorsdk.User) (*database.User, error) { +func (c *userComponentImpl) createFromCasdoorUser(ctx context.Context, cu casdoorsdk.User) (*database.User, error) { var ( gsUserResp *gitserver.CreateUserResponse err error @@ -156,7 +184,7 @@ func (c *UserComponent) createFromCasdoorUser(ctx context.Context, cu casdoorsdk return user, nil } -func (c *UserComponent) ChangeUserName(ctx context.Context, oldUserName, newUserName, opUser string) error { +func (c *userComponentImpl) ChangeUserName(ctx context.Context, oldUserName, newUserName, opUser string) error { if oldUserName != opUser { return fmt.Errorf("user name can only be changed by user self, user: '%s', op user: '%s'", oldUserName, opUser) } @@ -201,7 +229,7 @@ func (c *UserComponent) ChangeUserName(ctx context.Context, oldUserName, newUser return nil } -func (c *UserComponent) Update(ctx context.Context, req *types.UpdateUserRequest, opUser string) error { +func (c *userComponentImpl) Update(ctx context.Context, req *types.UpdateUserRequest, opUser string) error { c.lazyInit() user, err := c.us.FindByUsername(ctx, req.Username) @@ -254,7 +282,7 @@ func (c *UserComponent) Update(ctx context.Context, req *types.UpdateUserRequest // user registery with wechat does not have email, so git user is not created after signin // when user set email, a git user needs to be created -func (c *UserComponent) upsertGitUser(username string, nickname *string, oldEmail, newEmail string) error { +func (c *userComponentImpl) upsertGitUser(username string, nickname *string, oldEmail, newEmail string) error { var err error if nickname == nil { nickname = &username @@ -287,7 +315,7 @@ func (c *UserComponent) upsertGitUser(username string, nickname *string, oldEmai return nil } -func (c *UserComponent) setChangedProps(user *database.User, req *types.UpdateUserRequest) { +func (c *userComponentImpl) setChangedProps(user *database.User, req *types.UpdateUserRequest) { if req.Email != nil { user.Email = *req.Email if user.CanChangeUserName { @@ -319,7 +347,7 @@ func (c *UserComponent) setChangedProps(user *database.User, req *types.UpdateUs } } -func (c *UserComponent) Delete(ctx context.Context, operator, username string) error { +func (c *userComponentImpl) Delete(ctx context.Context, operator, username string) error { user, err := c.us.FindByUsername(ctx, username) if err != nil { newError := fmt.Errorf("failed to find user by name in db,error:%w", err) @@ -375,7 +403,7 @@ func (c *UserComponent) Delete(ctx context.Context, operator, username string) e // Returns: // - bool: True if the user has admin privileges, false otherwise. // - error: An error if the user cannot be found in the database. -func (c *UserComponent) CanAdmin(ctx context.Context, username string) (bool, error) { +func (c *userComponentImpl) CanAdmin(ctx context.Context, username string) (bool, error) { user, err := c.us.FindByUsername(ctx, username) if err != nil { newError := fmt.Errorf("failed to find user by name '%s' in db,error:%w", username, err) @@ -387,7 +415,7 @@ func (c *UserComponent) CanAdmin(ctx context.Context, username string) (bool, er // GetInternal get *full* user info by username or uuid // // should only be called by other *internal* services -func (c *UserComponent) GetInternal(ctx context.Context, userNameOrUUID string, useUUID bool) (*types.User, error) { +func (c *userComponentImpl) GetInternal(ctx context.Context, userNameOrUUID string, useUUID bool) (*types.User, error) { var dbuser = new(database.User) var err error if useUUID { @@ -401,7 +429,7 @@ func (c *UserComponent) GetInternal(ctx context.Context, userNameOrUUID string, return c.buildUserInfo(ctx, dbuser, false) } -func (c *UserComponent) Get(ctx context.Context, userNameOrUUID, visitorName string, useUUID bool) (*types.User, error) { +func (c *userComponentImpl) Get(ctx context.Context, userNameOrUUID, visitorName string, useUUID bool) (*types.User, error) { var dbuser = new(database.User) var err error if useUUID { @@ -431,7 +459,7 @@ func (c *UserComponent) Get(ctx context.Context, userNameOrUUID, visitorName str return c.buildUserInfo(ctx, dbuser, onlyBasicInfo) } -func (c *UserComponent) CheckOperatorAndUser(ctx context.Context, operator, username string) (bool, error) { +func (c *userComponentImpl) CheckOperatorAndUser(ctx context.Context, operator, username string) (bool, error) { opUser, err := c.us.FindByUsername(ctx, operator) if err != nil { newError := fmt.Errorf("failed to find operator by name in db,error:%w", err) @@ -453,7 +481,7 @@ func (c *UserComponent) CheckOperatorAndUser(ctx context.Context, operator, user return false, nil } -func (c *UserComponent) CheckIfUserHasOrgs(ctx context.Context, userName string) (bool, error) { +func (c *userComponentImpl) CheckIfUserHasOrgs(ctx context.Context, userName string) (bool, error) { var ( err error orgs []database.Organization @@ -467,7 +495,7 @@ func (c *UserComponent) CheckIfUserHasOrgs(ctx context.Context, userName string) return true, nil } -func (c *UserComponent) CheckIffUserHasRunningOrBuildingDeployments(ctx context.Context, userName string) (bool, error) { +func (c *userComponentImpl) CheckIffUserHasRunningOrBuildingDeployments(ctx context.Context, userName string) (bool, error) { user, err := c.us.FindByUsername(ctx, userName) if err != nil { return false, fmt.Errorf("failed to find user by username in db, error: %v", err) @@ -482,7 +510,7 @@ func (c *UserComponent) CheckIffUserHasRunningOrBuildingDeployments(ctx context. return false, nil } -func (c *UserComponent) CheckIfUserHasBills(ctx context.Context, userName string) (bool, error) { +func (c *userComponentImpl) CheckIfUserHasBills(ctx context.Context, userName string) (bool, error) { user, err := c.us.FindByUsername(ctx, userName) if err != nil { return false, fmt.Errorf("failed to find user by username in db, error: %v", err) @@ -498,7 +526,7 @@ func (c *UserComponent) CheckIfUserHasBills(ctx context.Context, userName string return false, nil } -func (c *UserComponent) buildUserInfo(ctx context.Context, dbuser *database.User, onlyBasicInfo bool) (*types.User, error) { +func (c *userComponentImpl) buildUserInfo(ctx context.Context, dbuser *database.User, onlyBasicInfo bool) (*types.User, error) { u := types.User{ Username: dbuser.Username, Nickname: dbuser.NickName, @@ -537,7 +565,7 @@ func (c *UserComponent) buildUserInfo(ctx context.Context, dbuser *database.User return &u, nil } -func (c *UserComponent) Index(ctx context.Context, visitorName, search string, per, page int) ([]*types.User, int, error) { +func (c *userComponentImpl) Index(ctx context.Context, visitorName, search string, per, page int) ([]*types.User, int, error) { var ( respUsers []*types.User onlyBasicInfo bool @@ -578,7 +606,7 @@ func (c *UserComponent) Index(ctx context.Context, visitorName, search string, p return respUsers, count, nil } -func (c *UserComponent) Signin(ctx context.Context, code, state string) (*types.JWTClaims, string, error) { +func (c *userComponentImpl) Signin(ctx context.Context, code, state string) (*types.JWTClaims, string, error) { c.lazyInit() casToken, err := c.casc.GetOAuthToken(code, state) @@ -640,7 +668,7 @@ func (c *UserComponent) Signin(ctx context.Context, code, state string) (*types. return hubToken, signed, nil } -func (c *UserComponent) genUniqueName() (string, error) { +func (c *userComponentImpl) genUniqueName() (string, error) { c.lazyInit() if c.sfnode == nil { @@ -650,7 +678,7 @@ func (c *UserComponent) genUniqueName() (string, error) { return "user_" + id, nil } -func (c *UserComponent) updateCasdoorUser(req *types.UpdateUserRequest) error { +func (c *userComponentImpl) updateCasdoorUser(req *types.UpdateUserRequest) error { c.lazyInit() casu, err := c.casc.GetUserByUserId(*req.UUID) @@ -683,7 +711,7 @@ func (c *UserComponent) updateCasdoorUser(req *types.UpdateUserRequest) error { return err } -func (c *UserComponent) lazyInit() { +func (c *userComponentImpl) lazyInit() { c.once.Do(func() { var err error c.casc = casdoorsdk.NewClientWithConf(c.casConfig) @@ -694,7 +722,7 @@ func (c *UserComponent) lazyInit() { }) } -func (c *UserComponent) FixUserData(ctx context.Context, userName string) error { +func (c *userComponentImpl) FixUserData(ctx context.Context, userName string) error { err := c.gs.FixUserData(ctx, userName) if err != nil { return err diff --git a/user/handler/access_token.go b/user/handler/access_token.go index d42fd654..e79e85d2 100644 --- a/user/handler/access_token.go +++ b/user/handler/access_token.go @@ -30,8 +30,8 @@ func NewAccessTokenHandler(config *config.Config) (*AccessTokenHandler, error) { } type AccessTokenHandler struct { - c *component.AccessTokenComponent - sc *apicomponent.SensitiveComponent + c component.AccessTokenComponent + sc apicomponent.SensitiveComponent } // CreateAccessToken godoc diff --git a/user/handler/jwt.go b/user/handler/jwt.go index 034f9fc9..98c5e7bc 100644 --- a/user/handler/jwt.go +++ b/user/handler/jwt.go @@ -18,7 +18,7 @@ func NewJWTHandler(config *config.Config) (*JWTHandler, error) { } type JWTHandler struct { - c *component.JwtComponent + c component.JwtComponent } // CreateJWTToken godoc diff --git a/user/handler/member.go b/user/handler/member.go index 605a9ab0..bb796d77 100644 --- a/user/handler/member.go +++ b/user/handler/member.go @@ -13,7 +13,7 @@ import ( ) type MemberHandler struct { - c *component.MemberComponent + c component.MemberComponent } func NewMemberHandler(config *config.Config) (*MemberHandler, error) { diff --git a/user/handler/namespace.go b/user/handler/namespace.go index 592cceed..d7941e17 100644 --- a/user/handler/namespace.go +++ b/user/handler/namespace.go @@ -8,7 +8,7 @@ import ( ) type NamespaceHandler struct { - c *component.NamespaceComponent + c component.NamespaceComponent } func NewNamespaceHandler(config *config.Config) (*NamespaceHandler, error) { diff --git a/user/handler/organization.go b/user/handler/organization.go index 1a8b34e5..46806816 100644 --- a/user/handler/organization.go +++ b/user/handler/organization.go @@ -29,8 +29,8 @@ func NewOrganizationHandler(config *config.Config) (*OrganizationHandler, error) } type OrganizationHandler struct { - c *component.OrganizationComponent - sc *apicomponent.SensitiveComponent + c component.OrganizationComponent + sc apicomponent.SensitiveComponent } // CreateOrganization godoc diff --git a/user/handler/user.go b/user/handler/user.go index 2b915931..3ab508c3 100644 --- a/user/handler/user.go +++ b/user/handler/user.go @@ -20,8 +20,8 @@ import ( ) type UserHandler struct { - c *component.UserComponent - sc *apicomponent.SensitiveComponent + c component.UserComponent + sc apicomponent.SensitiveComponent publicDomain string EnableHTTPS bool signinSuccessRedirectURL string