From 453845debf80c28d13765ac68319901047dde721 Mon Sep 17 00:00:00 2001 From: Fardin Khanjani Date: Wed, 23 Feb 2022 16:25:52 +0100 Subject: [PATCH] feat: This commit adds pull request support to SCM generator so the generator can create ArgoCD apps for PRs as well. Fixes #466 Signed-off-by: Fardin Khanjani --- api/v1alpha1/applicationset_types.go | 8 ++ .../crds/argoproj.io_applicationsets.yaml | 18 +++++ pkg/generators/scm_provider.go | 4 +- pkg/services/scm_provider/github.go | 64 ++++++++++++++-- pkg/services/scm_provider/github_test.go | 18 +++-- pkg/services/scm_provider/gitlab.go | 50 ++++++++++++- pkg/services/scm_provider/gitlab_test.go | 12 +-- pkg/services/scm_provider/mock.go | 20 ++++- pkg/services/scm_provider/types.go | 14 ++-- pkg/services/scm_provider/utils.go | 74 ++++++++++++++++++- 10 files changed, 251 insertions(+), 31 deletions(-) diff --git a/api/v1alpha1/applicationset_types.go b/api/v1alpha1/applicationset_types.go index b6dc6bc1..21a280b7 100644 --- a/api/v1alpha1/applicationset_types.go +++ b/api/v1alpha1/applicationset_types.go @@ -314,6 +314,8 @@ type SCMProviderGeneratorGithub struct { TokenRef *SecretRef `json:"tokenRef,omitempty"` // Scan all branches instead of just the default branch. AllBranches bool `json:"allBranches,omitempty"` + // Scan all pull requests + AllPullRequests bool `json:"allPullRequests,omitempty"` } // SCMProviderGeneratorGitlab defines a connection info specific to Gitlab. @@ -328,6 +330,8 @@ type SCMProviderGeneratorGitlab struct { TokenRef *SecretRef `json:"tokenRef,omitempty"` // Scan all branches instead of just the default branch. AllBranches bool `json:"allBranches,omitempty"` + // Scan all pull requests + AllPullRequests bool `json:"allPullRequests,omitempty"` } // SCMProviderGeneratorFilter is a single repository filter. @@ -342,6 +346,10 @@ type SCMProviderGeneratorFilter struct { LabelMatch *string `json:"labelMatch,omitempty"` // A regex which must match the branch name. BranchMatch *string `json:"branchMatch,omitempty"` + // A regex which must match the pull request tile. + PullRequestTitleMatch *string `json:"pullRequestTitleMatch,omitempty"` + // A regex which must match at least one pull request label. + PullRequestLabelMatch *string `json:"pullRequestLabelMatch,omitempty"` } // PullRequestGenerator defines a generator that scrapes a PullRequest API to find candidate pull requests. diff --git a/manifests/crds/argoproj.io_applicationsets.yaml b/manifests/crds/argoproj.io_applicationsets.yaml index 18033389..d25e90e3 100644 --- a/manifests/crds/argoproj.io_applicationsets.yaml +++ b/manifests/crds/argoproj.io_applicationsets.yaml @@ -2706,6 +2706,10 @@ spec: items: type: string type: array + pullRequestLabelMatch: + type: string + pullRequestTitleMatch: + type: string repositoryMatch: type: string type: object @@ -2714,6 +2718,8 @@ spec: properties: allBranches: type: boolean + allPullRequests: + type: boolean api: type: string organization: @@ -4792,6 +4798,10 @@ spec: items: type: string type: array + pullRequestLabelMatch: + type: string + pullRequestTitleMatch: + type: string repositoryMatch: type: string type: object @@ -4800,6 +4810,8 @@ spec: properties: allBranches: type: boolean + allPullRequests: + type: boolean api: type: string organization: @@ -5699,6 +5711,10 @@ spec: items: type: string type: array + pullRequestLabelMatch: + type: string + pullRequestTitleMatch: + type: string repositoryMatch: type: string type: object @@ -5707,6 +5723,8 @@ spec: properties: allBranches: type: boolean + allPullRequests: + type: boolean api: type: string organization: diff --git a/pkg/generators/scm_provider.go b/pkg/generators/scm_provider.go index b4c145b9..542efa19 100644 --- a/pkg/generators/scm_provider.go +++ b/pkg/generators/scm_provider.go @@ -64,7 +64,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha if err != nil { return nil, fmt.Errorf("error fetching Github token: %v", err) } - provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches) + provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches, providerConfig.Github.AllPullRequests) if err != nil { return nil, fmt.Errorf("error initializing Github service: %v", err) } @@ -73,7 +73,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha if err != nil { return nil, fmt.Errorf("error fetching Gitlab token: %v", err) } - provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups) + provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups, providerConfig.Gitlab.AllPullRequests) if err != nil { return nil, fmt.Errorf("error initializing Gitlab service: %v", err) } diff --git a/pkg/services/scm_provider/github.go b/pkg/services/scm_provider/github.go index 91f9d508..d8a3f132 100644 --- a/pkg/services/scm_provider/github.go +++ b/pkg/services/scm_provider/github.go @@ -11,14 +11,15 @@ import ( ) type GithubProvider struct { - client *github.Client - organization string - allBranches bool + client *github.Client + organization string + allBranches bool + allPullRequests bool } var _ SCMProviderService = &GithubProvider{} -func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool) (*GithubProvider, error) { +func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool, allPullRequests bool) (*GithubProvider, error) { var ts oauth2.TokenSource // Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits. if token == "" { @@ -40,7 +41,7 @@ func NewGithubProvider(ctx context.Context, organization string, token string, u return nil, err } } - return &GithubProvider{client: client, organization: organization, allBranches: allBranches}, nil + return &GithubProvider{client: client, organization: organization, allBranches: allBranches, allPullRequests: allPullRequests}, nil } func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) { @@ -64,6 +65,32 @@ func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]* return repos, nil } +func (g *GithubProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) { + repos := []*Repository{} + pullRequests, err := g.listPullRequests(ctx, repo) + if err != nil { + return nil, fmt.Errorf("error listing pull requests for %s/%s: %v", repo.Organization, repo.Repository, err) + } + + // go-github's PullRequest type does not have a GetLabel() function. + var labels []string + for _, pullRequest := range pullRequests { + for _, label := range pullRequest.Labels { + labels = append(labels, label.GetName()) + } + repos = append(repos, &Repository{ + Organization: repo.Organization, + Repository: repo.Repository, + URL: repo.URL, + Branch: pullRequest.GetTitle(), + SHA: pullRequest.GetHead().GetSHA(), + Labels: labels, + RepositoryId: repo.RepositoryId, + }) + } + return repos, nil +} + func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) { opt := &github.RepositoryListByOrgOptions{ ListOptions: github.ListOptions{PerPage: 100}, @@ -104,7 +131,7 @@ func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([ func (g *GithubProvider) RepoHasPath(ctx context.Context, repo *Repository, path string) (bool, error) { _, _, resp, err := g.client.Repositories.GetContents(ctx, repo.Organization, repo.Repository, path, &github.RepositoryContentGetOptions{ - Ref: repo.Branch, + Ref: repo.SHA, }) // 404s are not an error here, just a normal false. if resp != nil && resp.StatusCode == 404 { @@ -153,3 +180,28 @@ func (g *GithubProvider) listBranches(ctx context.Context, repo *Repository) ([] } return branches, nil } + +func (g *GithubProvider) listPullRequests(ctx context.Context, repo *Repository) ([]github.PullRequest, error) { + opt := &github.PullRequestListOptions{ + ListOptions: github.ListOptions{PerPage: 100}, + } + + githubPullRequests := []github.PullRequest{} + + for { + allPullRequests, resp, err := g.client.PullRequests.List(ctx, repo.Organization, repo.Repository, opt) + if err != nil { + return nil, err + } + + for _, pr := range allPullRequests { + githubPullRequests = append(githubPullRequests, *pr) + } + + if resp.NextPage == 0 { + break + } + opt.Page = resp.NextPage + } + return githubPullRequests, nil +} diff --git a/pkg/services/scm_provider/github_test.go b/pkg/services/scm_provider/github_test.go index 9a692af4..295ef661 100644 --- a/pkg/services/scm_provider/github_test.go +++ b/pkg/services/scm_provider/github_test.go @@ -36,10 +36,10 @@ func checkRateLimit(t *testing.T, err error) { func TestGithubListRepos(t *testing.T) { cases := []struct { - name, proto, url string - hasError, allBranches bool - branches []string - filters []v1alpha1.SCMProviderGeneratorFilter + name, proto, url string + hasError, allBranches, allPullRequests bool + branches []string + filters []v1alpha1.SCMProviderGeneratorFilter }{ { name: "blank protocol", @@ -67,11 +67,17 @@ func TestGithubListRepos(t *testing.T) { url: "git@github.com:argoproj/applicationset.git", branches: []string{"master", "release-0.1.0"}, }, + { + name: "all pull requests", + allPullRequests: true, + url: "git@github.com:argoproj/applicationset.git", + branches: []string{"pr-1", "pr-2"}, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches) + provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches, c.allPullRequests) rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto) if c.hasError { assert.Error(t, err) @@ -98,7 +104,7 @@ func TestGithubListRepos(t *testing.T) { } func TestGithubHasPath(t *testing.T) { - host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false) + host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false, false) repo := &Repository{ Organization: "argoproj", Repository: "applicationset", diff --git a/pkg/services/scm_provider/gitlab.go b/pkg/services/scm_provider/gitlab.go index d9b372a3..74ce67e7 100644 --- a/pkg/services/scm_provider/gitlab.go +++ b/pkg/services/scm_provider/gitlab.go @@ -13,11 +13,12 @@ type GitlabProvider struct { organization string allBranches bool includeSubgroups bool + allPullRequests bool } var _ SCMProviderService = &GitlabProvider{} -func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups bool) (*GitlabProvider, error) { +func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups, allPullRequests bool) (*GitlabProvider, error) { // Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits. if token == "" { token = os.Getenv("GITLAB_TOKEN") @@ -36,7 +37,7 @@ func NewGitlabProvider(ctx context.Context, organization string, token string, u return nil, err } } - return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups}, nil + return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups, allPullRequests: allPullRequests}, nil } func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) { @@ -60,6 +61,28 @@ func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]* return repos, nil } +func (g *GitlabProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) { + repos := []*Repository{} + + pullRequests, err := g.listPullRequests(ctx, repo) + if err != nil { + return nil, err + } + + for _, pullRequest := range pullRequests { + repos = append(repos, &Repository{ + Organization: repo.Organization, + Repository: repo.Repository, + URL: repo.URL, + Branch: pullRequest.Title, + SHA: pullRequest.SHA, + Labels: pullRequest.Labels, + RepositoryId: repo.RepositoryId, + }) + } + return repos, nil +} + func (g *GitlabProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) { opt := &gitlab.ListGroupProjectsOptions{ ListOptions: gitlab.ListOptions{PerPage: 100}, @@ -149,3 +172,26 @@ func (g *GitlabProvider) listBranches(_ context.Context, repo *Repository) ([]gi } return branches, nil } + +func (g *GitlabProvider) listPullRequests(_ context.Context, repo *Repository) ([]gitlab.MergeRequest, error) { + opt := &gitlab.ListProjectMergeRequestsOptions{ + ListOptions: gitlab.ListOptions{PerPage: 100}, + } + + pullRequests := []gitlab.MergeRequest{} + for { + gitlabPullRequests, resp, err := g.client.MergeRequests.ListProjectMergeRequests(repo.RepositoryId, opt) + if err != nil { + return nil, err + } + for _, gitlabPullRequest := range gitlabPullRequests { + pullRequests = append(pullRequests, *gitlabPullRequest) + } + + if resp.NextPage == 0 { + break + } + opt.Page = resp.NextPage + } + return pullRequests, nil +} diff --git a/pkg/services/scm_provider/gitlab_test.go b/pkg/services/scm_provider/gitlab_test.go index d53ad1bd..8bcb657d 100644 --- a/pkg/services/scm_provider/gitlab_test.go +++ b/pkg/services/scm_provider/gitlab_test.go @@ -10,10 +10,10 @@ import ( func TestGitlabListRepos(t *testing.T) { cases := []struct { - name, proto, url string - hasError, allBranches, includeSubgroups bool - branches []string - filters []v1alpha1.SCMProviderGeneratorFilter + name, proto, url string + hasError, allBranches, includeSubgroups, allPullRequests bool + branches []string + filters []v1alpha1.SCMProviderGeneratorFilter }{ { name: "blank protocol", @@ -45,7 +45,7 @@ func TestGitlabListRepos(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups) + provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups, c.allPullRequests) rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto) if c.hasError { assert.NotNil(t, err) @@ -72,7 +72,7 @@ func TestGitlabListRepos(t *testing.T) { } func TestGitlabHasPath(t *testing.T) { - host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true) + host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true, false) repo := &Repository{ Organization: "test-argocd-proton", Repository: "argocd", diff --git a/pkg/services/scm_provider/mock.go b/pkg/services/scm_provider/mock.go index bf7e452c..5249a7af 100644 --- a/pkg/services/scm_provider/mock.go +++ b/pkg/services/scm_provider/mock.go @@ -44,7 +44,25 @@ func (m *MockProvider) GetBranches(_ context.Context, repo *Repository) ([]*Repo branchRepos = append(branchRepos, candidateRepo) } } - } return branchRepos, nil } + +func (m *MockProvider) GetPullRequests(_ context.Context, repo *Repository) ([]*Repository, error) { + pullRequestRepos := []*Repository{} + for _, candidateRepo := range m.Repos { + if candidateRepo.Repository == repo.Repository { + found := false + for _, alreadySetRepo := range pullRequestRepos { + if alreadySetRepo.Branch == candidateRepo.Branch { + found = true + break + } + } + if !found { + pullRequestRepos = append(pullRequestRepos, candidateRepo) + } + } + } + return pullRequestRepos, nil +} diff --git a/pkg/services/scm_provider/types.go b/pkg/services/scm_provider/types.go index b7f90bd8..ebb48b34 100644 --- a/pkg/services/scm_provider/types.go +++ b/pkg/services/scm_provider/types.go @@ -20,15 +20,18 @@ type SCMProviderService interface { ListRepos(context.Context, string) ([]*Repository, error) RepoHasPath(context.Context, *Repository, string) (bool, error) GetBranches(context.Context, *Repository) ([]*Repository, error) + GetPullRequests(context.Context, *Repository) ([]*Repository, error) } // A compiled version of SCMProviderGeneratorFilter for performance. type Filter struct { - RepositoryMatch *regexp.Regexp - PathsExist []string - LabelMatch *regexp.Regexp - BranchMatch *regexp.Regexp - FilterType FilterType + RepositoryMatch *regexp.Regexp + PathsExist []string + LabelMatch *regexp.Regexp + BranchMatch *regexp.Regexp + PullRequestTitleMatch *regexp.Regexp + PullRequestLabelMatch *regexp.Regexp + FilterType FilterType } // A convenience type for indicating where to apply a filter @@ -39,4 +42,5 @@ const ( FilterTypeUndefined FilterType = iota FilterTypeBranch FilterTypeRepo + FilterTypePullRequest ) diff --git a/pkg/services/scm_provider/utils.go b/pkg/services/scm_provider/utils.go index 07f29a59..54fca3e1 100644 --- a/pkg/services/scm_provider/utils.go +++ b/pkg/services/scm_provider/utils.go @@ -38,6 +38,20 @@ func compileFilters(filters []argoprojiov1alpha1.SCMProviderGeneratorFilter) ([] } outFilter.FilterType = FilterTypeBranch } + if filter.PullRequestTitleMatch != nil { + outFilter.PullRequestTitleMatch, err = regexp.Compile(*filter.PullRequestTitleMatch) + if err != nil { + return nil, fmt.Errorf("error compiling PullRequestTitleMatch regexp %q: %v", *filter.PullRequestTitleMatch, err) + } + outFilter.FilterType = FilterTypePullRequest + } + if filter.PullRequestLabelMatch != nil { + outFilter.PullRequestLabelMatch, err = regexp.Compile(*filter.PullRequestLabelMatch) + if err != nil { + return nil, fmt.Errorf("error compiling PullRequestLabelMatch regexp %q: %v", *filter.PullRequestLabelMatch, err) + } + outFilter.FilterType = FilterTypePullRequest + } outFilters = append(outFilters, outFilter) } return outFilters, nil @@ -52,6 +66,23 @@ func matchFilter(ctx context.Context, provider SCMProviderService, repo *Reposit return false, nil } + if filter.PullRequestTitleMatch != nil && !filter.PullRequestTitleMatch.MatchString(repo.Branch) { + return false, nil + } + + if filter.PullRequestLabelMatch != nil { + found := false + for _, label := range repo.Labels { + if filter.PullRequestLabelMatch.MatchString(label) { + found = true + break + } + } + if !found { + return false, nil + } + } + if filter.LabelMatch != nil { found := false for _, label := range repo.Labels { @@ -114,7 +145,12 @@ func ListRepos(ctx context.Context, provider SCMProviderService, filters []argop } } - repos, err = getBranches(ctx, provider, filteredRepos, compiledFilters) + repos1, err := getPullRequests(ctx, provider, filteredRepos, compiledFilters) + + repos2, err := getBranches(ctx, provider, filteredRepos, compiledFilters) + + repos = append(repos1, repos2...) + if err != nil { return nil, err } @@ -150,17 +186,49 @@ func getBranches(ctx context.Context, provider SCMProviderService, repos []*Repo return filteredRepos, nil } +func getPullRequests(ctx context.Context, provider SCMProviderService, repos []*Repository, compiledFilters []*Filter) ([]*Repository, error) { + reposWithPullRequests := []*Repository{} + for _, repo := range repos { + reposFilled, err := provider.GetPullRequests(ctx, repo) + if err != nil { + return nil, err + } + reposWithPullRequests = append(reposWithPullRequests, reposFilled...) + } + pullRequestFilters := getApplicableFilters(compiledFilters)[FilterTypePullRequest] + if len(pullRequestFilters) == 0 { + return reposWithPullRequests, nil + } + filteredRepos := make([]*Repository, 0, len(reposWithPullRequests)) + for _, repo := range reposWithPullRequests { + for _, filter := range pullRequestFilters { + matches, err := matchFilter(ctx, provider, repo, filter) + if err != nil { + return nil, err + } + if matches { + filteredRepos = append(filteredRepos, repo) + break + } + } + } + return filteredRepos, nil +} + // getApplicableFilters returns a map of filters separated by type. func getApplicableFilters(filters []*Filter) map[FilterType][]*Filter { filterMap := map[FilterType][]*Filter{ - FilterTypeBranch: {}, - FilterTypeRepo: {}, + FilterTypeBranch: {}, + FilterTypeRepo: {}, + FilterTypePullRequest: {}, } for _, filter := range filters { if filter.FilterType == FilterTypeBranch { filterMap[FilterTypeBranch] = append(filterMap[FilterTypeBranch], filter) } else if filter.FilterType == FilterTypeRepo { filterMap[FilterTypeRepo] = append(filterMap[FilterTypeRepo], filter) + } else if filter.FilterType == FilterTypePullRequest { + filterMap[FilterTypePullRequest] = append(filterMap[FilterTypePullRequest], filter) } } return filterMap