-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcvekiller.go
140 lines (114 loc) · 3.85 KB
/
cvekiller.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package woodpecker
import (
"context"
"fmt"
"github.com/JackKCWong/go-woodpecker/api"
"github.com/JackKCWong/go-woodpecker/internal/util"
"os"
"strings"
"time"
)
// KillOpts configures Woodpecker.Kill.
type KillOpts struct {
	// Opts holds the shared options; Kill reads BranchNamePrefix from it
	// when naming the branch it creates.
	Opts
	// SendPR, when true, makes Kill push the new branch and open a pull
	// request after the fix has been verified and committed.
	SendPR bool
}
// Kill removes a single CVE from the current (non multi-module) project.
//
// args[0] must be the CVE ID. Kill finds the first dependency subtree that
// contains the CVE and creates a git branch named after it. It then upgrades
// dependencies until no subtree reports the CVE any more, verifies the build,
// stages and commits the change, and, when opts.SendPR is set, pushes the
// branch and opens a pull request against "master".
//
// It returns an error in these cases:
//   - args is empty
//   - the project is a multi-module parent
//   - the CVE is not in the dependency tree
//   - an upgrade does not change the version (no CVE-free version exists)
//   - verification fails
//   - any dependency-manager, git or git-server call fails
func (w Woodpecker) Kill(args []string, opts KillOpts) error {
	// Validate input before touching the project; indexing args[0] on an
	// empty slice would otherwise panic.
	if len(args) == 0 {
		return fmt.Errorf("missing CVE ID argument")
	}
	cveID := args[0]

	multiModules, err := w.DepMgr.IsMultiModules()
	if err != nil {
		return fmt.Errorf("failed to check if multi-modules: %w", err)
	}
	if multiModules {
		return fmt.Errorf("kill does not work on the parent project. Please run it on the child project")
	}

	tree, err := w.DepMgr.DependencyTree()
	if err != nil {
		return err
	}

	subtreeToUpdate, found := tree.FirstChildWithCVE(cveID)
	if !found {
		return fmt.Errorf("CVE %s not found in the dependency tree", cveID)
	}

	originalPackageID := subtreeToUpdate.Root().ID
	var newPackageID string

	// ':' is not allowed in git ref names, so package IDs containing it are
	// turned into path segments instead.
	branchName := strings.ReplaceAll(fmt.Sprintf("%s/%s/%s", opts.BranchNamePrefix, originalPackageID, cveID), ":", "/")
	if err := w.GitClient.Branch(branchName); err != nil {
		return err
	}

	// Keep upgrading until no subtree reports the CVE. After every upgrade the
	// dependency tree is re-read, because versions (and thus children) changed.
	//
	// NOTE(review): the init step calls FirstChildWithCVE on the subtree found
	// above, while every post step calls it on the freshly re-read full tree,
	// so the first iteration looks one level deeper than later ones. Confirm
	// this asymmetry is intended.
	for subtreeToUpdate, found = subtreeToUpdate.FirstChildWithCVE(cveID); found; subtreeToUpdate, found = subtreeToUpdate.FirstChildWithCVE(cveID) {
		util.Printfln(os.Stdout, "%s found in %s, upgrading...", cveID, subtreeToUpdate.Root().ID)
		lastPackageID := subtreeToUpdate.Root().ID

		newPackageID, err = w.DepMgr.UpdateDependency(subtreeToUpdate.Root())
		if err != nil {
			return fmt.Errorf("failed to update dependency %s: %w", subtreeToUpdate.Root().ID, err)
		}

		// An unchanged ID means no newer version exists, so looping again
		// would never terminate.
		if lastPackageID == newPackageID {
			util.Printfln(os.Stdout, "already the latest version: %s, exiting...", newPackageID)
			return fmt.Errorf("no version available without %s", cveID)
		}
		util.Printfln(os.Stdout, "upgraded to %s", newPackageID)

		subtreeToUpdate, err = w.DepMgr.DependencyTree()
		if err != nil {
			return fmt.Errorf("failed to get dependency tree: %w", err)
		}
	}

	util.Printfln(os.Stdout, "%s is killed.", cveID)
	util.Printfln(os.Stdout, "start verifying...")

	result, err := w.DepMgr.Verify()
	// A non-nil error is a failure even if the result claims to have passed.
	if err != nil || !result.Passed {
		if err == nil {
			err = fmt.Errorf("unknown error")
		}
		return fmt.Errorf("verification failed: %w\n%s", err, result.Summary)
	}

	verificationResult := "build passed but you don't seem to have any test! good luck!"
	if result.Summary != "" {
		verificationResult = fmt.Sprintf("verification passed: \n%s", result.Summary)
	}
	// The summary is build/test output and may contain '%', so it must be an
	// argument, never the format string itself.
	util.Printfln(os.Stdout, "%s", verificationResult)

	if err := w.DepMgr.StageUpdate(); err != nil {
		return fmt.Errorf("failed to apply change: %w", err)
	}

	commitMessage := "removing " + cveID + " in " + originalPackageID
	hash, err := w.GitClient.Commit(commitMessage)
	if err != nil {
		return err
	}
	util.Printfln(os.Stdout, "committed %s", hash)

	if !opts.SendPR {
		return nil
	}

	// One deadline bounds both the push and the pull-request creation.
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()

	if err := w.GitClient.Push(ctx); err != nil {
		return err
	}

	origin, err := w.GitClient.Origin()
	if err != nil {
		return err
	}

	prComment := fmt.Sprintf("### updated to %s, verification result:\n> %s", newPackageID, verificationResult)
	// The original (pre-upgrade) tree still describes the CVE, so take its
	// details from there.
	if originalSubtree, ok := tree.FirstChildWithCVE(cveID); ok {
		// Second result deliberately ignored: FirstChildWithCVE just selected
		// this subtree for cveID, so the lookup is assumed to succeed.
		cve, _ := originalSubtree.FindCVE(cveID)
		prComment = fmt.Sprintf("%s\n\n### CVE details\n%s\n>%s",
			prComment, formatCVESummary(cve), cve.Description)
	} else {
		util.Printfln(os.Stderr, "weird...CVE %s not found in the original dependency tree...", cveID)
	}

	// NOTE(review): the base branch is hard-coded to "master"; repositories
	// using "main" (or another default) will get a PR against the wrong base.
	pullRequestURL, err := w.GitServer.CreatePullRequest(ctx,
		origin, branchName, "master",
		commitMessage,
		prComment)
	if err != nil {
		return err
	}
	util.Printfln(os.Stdout, "Pull request created: %s", pullRequestURL)
	return nil
}
// formatCVESummary renders a one-line markdown summary of a vulnerability.
// The summary is a link whose text is the CVE ID and whose target is the
// vulnerability's NVD URL, followed by the severity and the CVSS score
// rounded to one decimal place.
func formatCVESummary(cve api.Vulnerability) string {
	link := fmt.Sprintf("[%s](%s)", cve.Cve, cve.NVDUrl())
	rating := fmt.Sprintf("%s %.1f", cve.Severity, cve.CvssScore)
	return link + " - " + rating
}