Skip to content

Commit

Permalink
Merge pull request #7286 from mook-as/win32/wsl-version-without-appx
Browse files Browse the repository at this point in the history
WSL-helper: Try to get WSL version without appx
  • Loading branch information
mook-as authored Aug 2, 2024
2 parents c4b5321 + f6a6073 commit 00ca4ea
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 240 deletions.
183 changes: 16 additions & 167 deletions src/go/wsl-helper/pkg/wsl-utils/version_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ func (i WSLInfo) String() string {
}

const (
// kPackageFamily is the package family for the WSL app (MSIX).
kPackageFamily = "MicrosoftCorporationII.WindowsSubsystemForLinux_8wekyb3d8bbwe" // spellcheck-ignore-line
// kMsiUpgradeCode is the upgrade code for the WSL kernel (for in-box WSL2)
kMsiUpgradeCode = "{1C3DB5B6-65A5-4EBC-A5B9-2F2D6F665F48}"
// Number of characters in a GUID string, including spaces
Expand All @@ -82,20 +80,9 @@ const (
//nolint:stylecheck // Win32 constants
const (
INSTALLPROPERTY_VERSIONSTRING = "VersionString"
PACKAGE_INFORMATION_BASIC = 0x00000000
PACKAGE_INFORMATION_FULL = 0x00000100
PACKAGE_FILTER_STATIC = 0x00080000
PACKAGE_FILTER_DYNAMIC = 0x00100000
PackagePathType_Effective = 2
)

var (
dllKernel32 = windows.NewLazySystemDLL("kernel32.dll")
getPackagesByPackageFamily = dllKernel32.NewProc("GetPackagesByPackageFamily")
openPackageInfoByFullName = dllKernel32.NewProc("OpenPackageInfoByFullName")
closePackageInfo = dllKernel32.NewProc("ClosePackageInfo")
getPackageInfo = dllKernel32.NewProc("GetPackageInfo")

dllMsi = windows.NewLazySystemDLL("msi.dll")
msiEnumRelatedProducts = dllMsi.NewProc("MsiEnumRelatedProductsW")
msiGetProductInfo = dllMsi.NewProc("MsiGetProductInfoW")
Expand All @@ -119,53 +106,6 @@ func errorFromWin32(msg string, rv uintptr, err error) error {
return fmt.Errorf("%s: %w", msg, windows.Errno(rv))
}

// getPackageNames returns the package names for the given package family.
func getPackageNames(packageFamily string) ([]string, error) {
var count, bufferLength uint32
packageFamilyBytes, err := windows.UTF16PtrFromString(packageFamily)
if err != nil {
return nil, fmt.Errorf("error allocating package family name: %w", err)
}
rv, _, err := getPackagesByPackageFamily.Call(
uintptr(unsafe.Pointer(packageFamilyBytes)),
uintptr(unsafe.Pointer(&count)),
uintptr(unsafe.Pointer(nil)),
uintptr(unsafe.Pointer(&bufferLength)),
uintptr(unsafe.Pointer(nil)),
)
switch rv {
case uintptr(windows.ERROR_SUCCESS):
break
case uintptr(windows.ERROR_INSUFFICIENT_BUFFER):
// This is expected: we didn't provide any buffer
break
default:
return nil, errorFromWin32("error getting buffer size", rv, err)
}

packageNames := make([]uintptr, count)
packageNameBuffer := make([]uint16, bufferLength)

rv, _, err = getPackagesByPackageFamily.Call(
uintptr(unsafe.Pointer(packageFamilyBytes)),
uintptr(unsafe.Pointer(&count)),
uintptr(unsafe.Pointer(unsafe.SliceData(packageNames))),
uintptr(unsafe.Pointer(&bufferLength)),
uintptr(unsafe.Pointer(unsafe.SliceData(packageNameBuffer))),
)
if rv != uintptr(windows.ERROR_SUCCESS) {
return nil, errorFromWin32("error getting package names", rv, err)
}

result := make([]string, count)
slice := unsafe.Slice((**uint16)(unsafe.Pointer(unsafe.SliceData(packageNames))), count)
for i, ptr := range slice {
result[i] = windows.UTF16PtrToString(ptr)
}

return result, nil
}

// PackageVersion corresponds to the PACKAGE_VERSION structure.
type PackageVersion struct {
Revision uint16 `json:"revision"`
Expand Down Expand Up @@ -232,85 +172,9 @@ func (v PackageVersion) Less(other PackageVersion) bool {
return false
}

// packageInfo corresponds to the PACKAGE_INFO structure.
type packageInfo struct {
reserved uint32
flags uint32
path *uint16
packageFullName *uint16
packageFamilyName *uint16
packageID struct {
reserved uint32
processorArchitecture uint32
version PackageVersion
name *uint16
publisher *uint16
resourceID *uint16
publisherID *uint16
}
}

// getPackageVersion gets the package version of the package with the given
// full name.
func getPackageVersion(fullName string) (*PackageVersion, error) {
nameBuffer, err := windows.UTF16PtrFromString(fullName)
if err != nil {
return nil, err
}
var packageInfoReference uintptr
rv, _, err := openPackageInfoByFullName.Call(
uintptr(unsafe.Pointer(nameBuffer)),
0, // reserved
uintptr(unsafe.Pointer(&packageInfoReference)),
)
if rv != uintptr(windows.ERROR_SUCCESS) {
return nil, errorFromWin32("error opening package info", rv, err)
}
defer func() { _, _, _ = closePackageInfo.Call(packageInfoReference) }()

var bufferLength, count uint32
rv, _, err = getPackageInfo.Call(
packageInfoReference,
uintptr(PACKAGE_INFORMATION_BASIC|PACKAGE_FILTER_STATIC|PACKAGE_FILTER_DYNAMIC),
uintptr(unsafe.Pointer(&bufferLength)),
uintptr(unsafe.Pointer(nil)),
uintptr(unsafe.Pointer(nil)),
)
switch rv {
case uintptr(windows.ERROR_SUCCESS):
break
case uintptr(windows.ERROR_INSUFFICIENT_BUFFER):
// This is expected: we didn't provide any buffer
break
default:
return nil, errorFromWin32("error getting buffer size", rv, err)
}

buf := make([]byte, bufferLength)
rv, _, err = getPackageInfo.Call(
packageInfoReference,
uintptr(PACKAGE_INFORMATION_BASIC|PACKAGE_FILTER_STATIC|PACKAGE_FILTER_DYNAMIC),
uintptr(unsafe.Pointer(&bufferLength)),
uintptr(unsafe.Pointer(unsafe.SliceData(buf))),
uintptr(unsafe.Pointer(&count)),
)
if rv != uintptr(windows.ERROR_SUCCESS) {
return nil, errorFromWin32("error getting package info", rv, err)
}
infos := unsafe.Slice((*packageInfo)(unsafe.Pointer(unsafe.SliceData(buf))), count)
for _, info := range infos {
// `info` is a pointer to an unsafe slice; make a copy of the version
// on the stack and then return that so the GC knows about it.
versionCopy := info.packageID.version
return &versionCopy, nil
}

return nil, fmt.Errorf("no info found for %s", fullName)
}

// Get the component versions for an AppX-based installation. Returns the WSL
// Get the component versions by asking the CLI. Returns the WSL
// version, followed by the kernel version.
func getAppxVersion(ctx context.Context, log *logrus.Entry) (*PackageVersion, *PackageVersion, error) {
func getVersionFromCLI(ctx context.Context, log *logrus.Entry) (*PackageVersion, *PackageVersion, error) {
newRunnerFunc := NewWSLRunner
if f := ctx.Value(&kWSLExeOverride); f != nil {
newRunnerFunc = f.(func() WSLRunner)
Expand Down Expand Up @@ -352,9 +216,9 @@ func getAppxVersion(ctx context.Context, log *logrus.Entry) (*PackageVersion, *P
}
}
if len(errorList) > 0 {
return nil, nil, fmt.Errorf("error getting AppX version: %w", errors.Join(errorList...))
return nil, nil, fmt.Errorf("error getting WSL version from CLI: %w", errors.Join(errorList...))
}
log.WithFields(logrus.Fields{"wsl": wslVersion, "kernel": kernelVersion}).Trace("got AppX version")
log.WithFields(logrus.Fields{"wsl": wslVersion, "kernel": kernelVersion}).Trace("got version from CLI")
return &wslVersion, &kernelVersion, nil
}

Expand Down Expand Up @@ -480,34 +344,19 @@ func getMSIVersion(productCode []uint16, log *logrus.Entry) (*PackageVersion, er
}

func GetWSLInfo(ctx context.Context, log *logrus.Entry) (*WSLInfo, error) {
names, err := getPackageNames(kPackageFamily)
if err != nil {
log.WithError(err).Trace("Error getting appx packages")
return nil, err
}

log.Tracef("Got %d appx packages", len(names))
for _, name := range names {
if version, err := getPackageVersion(name); err == nil {
log.Tracef("Got appx package %s with version %s", name, version)
wslVersion, kernelVersion, err := getAppxVersion(ctx, log)
if err != nil {
return nil, err
}
return &WSLInfo{
Installed: true,
Inbox: false,
Version: *wslVersion,
KernelVersion: *kernelVersion,
HasKernel: PackageVersion{}.Less(*kernelVersion),
OutdatedKernel: kernelVersion.Less(MinimumKernelVersion),
}, nil
} else {
log.WithError(err).Trace("Failed to get package version")
}
}
wslVersion, kernelVersion, err := getVersionFromCLI(ctx, log)
if err == nil {
return &WSLInfo{
Installed: true,
Inbox: false,
Version: *wslVersion,
KernelVersion: *kernelVersion,
HasKernel: PackageVersion{}.Less(*kernelVersion),
OutdatedKernel: kernelVersion.Less(MinimumKernelVersion),
}, nil
}
log.WithError(err).Trace("Could not get version from `wsl --version`, trying inbox versions...")

log.Trace("Failed to get WSL appx package, trying inbox versions...")
hasWSL, kernelVersion, err := getInboxWSLInfo(ctx, log)
if err != nil {
return nil, err
Expand Down
75 changes: 2 additions & 73 deletions src/go/wsl-helper/pkg/wsl-utils/version_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,6 @@ import (
"golang.org/x/sys/windows"
)

func TestGetPackageNames(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

t.Run("valid package family", func(t *testing.T) {
// Get a random package family
command := `Get-AppxPackage | Select-Object -First 1 -ExpandProperty PackageFamilyName`
output, err := runPowerShell(ctx, command)
if err != nil {
t.Skipf("Failed to get sample family, skipping test: %s", err)
}
family := strings.TrimSpace(output.String())

// Get all packages in that family; not that we can't pipe single element
// arrays (they get flattened to just the element), so we need to pass the
// list as an argument to ConvertTo-JSON
command = fmt.Sprintf(strings.NewReplacer("\r", " ", "\n", " ", "\t", " ").Replace(`
ConvertTo-JSON -InputObject @(
Get-AppxPackage
| Where-Object { $_.PackageFamilyName -eq "%s" }
| Select-Object -ExpandProperty PackageFullName
)
`), family)
output, err = runPowerShell(ctx, command)
if err != nil {
t.Skipf("Failed to get packages in family %s, skipping test: %s", family, err)
}
var expected []string
assert.NoErrorf(t, json.Unmarshal(output.Bytes(), &expected), "failed to read package names: %s", output.String())

names, err := getPackageNames(family)
require.NoError(t, err, "Error getting package names")
assert.ElementsMatch(t, expected, names, "Failed to get packages")
})

t.Run("invalid package family", func(t *testing.T) {
_, err := getPackageNames("invalid package family")
assert.Error(t, err, "should not get packages for invalid family")
})
}

func TestPackageVersion(t *testing.T) {
t.Run("UnmarshalText", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -155,36 +114,6 @@ func TestPackageVersion(t *testing.T) {
})
}

func TestGetPackageVersion(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

t.Run("valid package", func(t *testing.T) {
// Get a random package and version
info := struct {
PackageFullName string
Version *PackageVersion
}{}
command := `Get-AppxPackage | Select-Object -First 1 -Property PackageFullName, Version | ConvertTo-JSON`
output, err := runPowerShell(ctx, command)
if err != nil {
t.Skipf("Failed to get sample package, skipping test: %s", err)
}
require.NoError(t, json.Unmarshal(output.Bytes(), &info), "failed to get package")
require.NotNil(t, info.Version, "failed to get package version")

version, err := getPackageVersion(info.PackageFullName)
require.NoError(t, err)
require.NotNil(t, version)
assert.Equal(t, info.Version, version, "unexpected version")
})

t.Run("invalid package", func(t *testing.T) {
_, err := getPackageVersion("not a valid package name")
assert.Error(t, err, "should error with invalid package name")
})
}

// TestWithExitCode is a dummy test function to let us exit with a given exit
// code. See TestIsInboxWSLInstalled/not_installed.
func TestWithExitCode(t *testing.T) {
Expand Down Expand Up @@ -220,7 +149,7 @@ func runPowerShell(ctx context.Context, command string) (*bytes.Buffer, error) {
return stdout, nil
}

func TestGetAppxVersion(t *testing.T) {
func TestGetVersionFromCLI(t *testing.T) {
outputs := map[string]struct {
lines []string
wsl string
Expand Down Expand Up @@ -261,7 +190,7 @@ func TestGetAppxVersion(t *testing.T) {
return nil
})
var expectedWSL, expectedKernel PackageVersion
wsl, kernel, err := getAppxVersion(ctx, logrus.NewEntry(logger))
wsl, kernel, err := getVersionFromCLI(ctx, logrus.NewEntry(logger))
assert.NoError(t, err)
assert.NoError(t, expectedWSL.UnmarshalText([]byte(input.wsl)))
assert.NoError(t, expectedKernel.UnmarshalText([]byte(input.kernel)))
Expand Down

0 comments on commit 00ca4ea

Please sign in to comment.