Skip to content

Commit cbf253f

Browse files
committed
Add compare and matching for OS features
Signed-off-by: Derek McGowan <[email protected]>
1 parent 1dc1164 commit cbf253f

File tree

3 files changed

+156
-3
lines changed

3 files changed

+156
-3
lines changed

compare.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,20 @@ func (c orderedPlatformComparer) Less(p1 specs.Platform, p2 specs.Platform) bool
218218
return true
219219
}
220220
if p1m || p2m {
221+
if p1m && p2m {
222+
// Prefer one with most matching features
223+
if len(p1.OSFeatures) != len(p2.OSFeatures) {
224+
return len(p1.OSFeatures) > len(p2.OSFeatures)
225+
}
226+
}
221227
return false
222228
}
223229
}
230+
if len(p1.OSFeatures) > 0 || len(p2.OSFeatures) > 0 {
231+
p1.OSFeatures = nil
232+
p2.OSFeatures = nil
233+
return c.Less(p1, p2)
234+
}
224235
return false
225236
}
226237

@@ -247,9 +258,20 @@ func (c anyPlatformComparer) Less(p1, p2 specs.Platform) bool {
247258
p2m = true
248259
}
249260
if p1m && p2m {
250-
return false
261+
if len(p1.OSFeatures) != len(p2.OSFeatures) {
262+
return len(p1.OSFeatures) > len(p2.OSFeatures)
263+
}
264+
break
251265
}
252266
}
267+
268+
// If neither match and has features, strip features and compare
269+
if !p1m && !p2m && (len(p1.OSFeatures) > 0 || len(p2.OSFeatures) > 0) {
270+
p1.OSFeatures = nil
271+
p2.OSFeatures = nil
272+
return c.Less(p1, p2)
273+
}
274+
253275
// If one matches, and the other does, sort match first
254276
return p1m && !p2m
255277
}

compare_test.go

+92
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
package platforms
1818

1919
import (
20+
"sort"
2021
"testing"
22+
23+
"github.com/stretchr/testify/assert"
2124
)
2225

2326
func TestOnly(t *testing.T) {
@@ -592,3 +595,92 @@ func TestOnlyStrict(t *testing.T) {
592595
})
593596
}
594597
}
598+
599+
func TestCompareOSFeatures(t *testing.T) {
600+
for _, tc := range []struct {
601+
platform string
602+
platforms []string
603+
expected []string
604+
}{
605+
{
606+
"linux/amd64",
607+
[]string{"windows/amd64", "linux/amd64", "linux(+other)/amd64", "linux/arm64"},
608+
[]string{"linux/amd64", "linux(+other)/amd64", "windows/amd64", "linux/arm64"},
609+
},
610+
{
611+
"linux(+none)/amd64",
612+
[]string{"windows/amd64", "linux/amd64", "linux/arm64", "linux(+other)/amd64"},
613+
[]string{"linux/amd64", "linux(+other)/amd64", "windows/amd64", "linux/arm64"},
614+
},
615+
{
616+
"linux(+other)/amd64",
617+
[]string{"windows/amd64", "linux/amd64", "linux/arm64", "linux(+other)/amd64"},
618+
[]string{"linux(+other)/amd64", "linux/amd64", "windows/amd64", "linux/arm64"},
619+
},
620+
{
621+
"linux(+af+other+zf)/amd64",
622+
[]string{"windows/amd64", "linux/amd64", "linux/arm64", "linux(+other)/amd64"},
623+
[]string{"linux(+other)/amd64", "linux/amd64", "windows/amd64", "linux/arm64"},
624+
},
625+
{
626+
"linux(+f1+f2)/amd64",
627+
[]string{"linux/amd64", "linux(+f2)/amd64", "linux(+f1)/amd64", "linux(+f1+f2)/amd64"},
628+
[]string{"linux(+f1+f2)/amd64", "linux(+f2)/amd64", "linux(+f1)/amd64", "linux/amd64"},
629+
},
630+
{
631+
// This test should likely fail and be updated when os version is considered for linux
632+
"linux(7.2+other)/amd64",
633+
[]string{"linux/amd64", "linux(+other)/amd64", "linux(7.1)/amd64", "linux(7.2+other)/amd64"},
634+
[]string{"linux(+other)/amd64", "linux(7.2+other)/amd64", "linux/amd64", "linux(7.1)/amd64"},
635+
},
636+
} {
637+
testcase := tc
638+
t.Run(testcase.platform, func(t *testing.T) {
639+
t.Parallel()
640+
p, err := Parse(testcase.platform)
641+
if err != nil {
642+
t.Fatal(err)
643+
}
644+
645+
for _, stc := range []struct {
646+
name string
647+
mc MatchComparer
648+
}{
649+
{
650+
name: "only",
651+
mc: Only(p),
652+
},
653+
{
654+
name: "only strict",
655+
mc: OnlyStrict(p),
656+
},
657+
{
658+
name: "ordered",
659+
mc: Ordered(p),
660+
},
661+
{
662+
name: "any",
663+
mc: Any(p),
664+
},
665+
} {
666+
mc := stc.mc
667+
testcase := testcase
668+
t.Run(stc.name, func(t *testing.T) {
669+
p, err := ParseAll(testcase.platforms)
670+
if err != nil {
671+
t.Fatal(err)
672+
}
673+
sort.Slice(p, func(i, j int) bool {
674+
return mc.Less(p[i], p[j])
675+
})
676+
actual := make([]string, len(p))
677+
for i, ps := range p {
678+
actual[i] = FormatAll(ps)
679+
}
680+
681+
assert.Equal(t, testcase.expected, actual)
682+
})
683+
}
684+
})
685+
}
686+
}

platforms.go

+41-2
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ import (
114114
"path"
115115
"regexp"
116116
"runtime"
117+
"sort"
117118
"strconv"
118119
"strings"
119120

@@ -141,6 +142,10 @@ type Matcher interface {
141142
// functionality.
142143
//
143144
// Applications should opt to use `Match` over directly parsing specifiers.
145+
//
146+
// For OSFeatures, this matcher will match if the platform to match has
147+
// OSFeatures which are a subset of the OSFeatures of the platform
148+
// provided to NewMatcher.
144149
func NewMatcher(platform specs.Platform) Matcher {
145150
return newDefaultMatcher(platform)
146151
}
@@ -151,9 +156,40 @@ type matcher struct {
151156

152157
func (m *matcher) Match(platform specs.Platform) bool {
153158
normalized := Normalize(platform)
154-
return m.OS == normalized.OS &&
159+
if m.OS == normalized.OS &&
155160
m.Architecture == normalized.Architecture &&
156-
m.Variant == normalized.Variant
161+
m.Variant == normalized.Variant {
162+
if len(normalized.OSFeatures) == 0 {
163+
return true
164+
}
165+
if len(m.OSFeatures) >= len(normalized.OSFeatures) {
166+
// Ensure that normalized.OSFeatures is a subet of
167+
// m.OSFeatures
168+
j := 0
169+
for _, feature := range normalized.OSFeatures {
170+
for ; j < len(m.OSFeatures); j++ {
171+
if feature == m.OSFeatures[j] {
172+
// Don't increment j since the list is sorted
173+
// but may contain duplicates
174+
// TODO: Deduplicate list during normalize so
175+
// that j can be incremented here
176+
break
177+
}
178+
// Since both lists are ordered, if the feature is less
179+
// than what is seen, it is not in the list
180+
if feature < m.OSFeatures[j] {
181+
return false
182+
}
183+
}
184+
// if we hit the end, then feature was not found
185+
if j == len(m.OSFeatures) {
186+
return false
187+
}
188+
}
189+
return true
190+
}
191+
}
192+
return false
157193
}
158194

159195
func (m *matcher) String() string {
@@ -311,6 +347,9 @@ func FormatAll(platform specs.Platform) string {
311347
func Normalize(platform specs.Platform) specs.Platform {
312348
platform.OS = normalizeOS(platform.OS)
313349
platform.Architecture, platform.Variant = normalizeArch(platform.Architecture, platform.Variant)
350+
if len(platform.OSFeatures) > 0 {
351+
sort.Strings(platform.OSFeatures)
352+
}
314353

315354
return platform
316355
}

0 commit comments

Comments
 (0)