diff --git a/providers/azure/connection/azureinstancesnapshot/provider.go b/providers/azure/connection/azureinstancesnapshot/provider.go index c4ab8df816..6c75ebb03a 100644 --- a/providers/azure/connection/azureinstancesnapshot/provider.go +++ b/providers/azure/connection/azureinstancesnapshot/provider.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/cockroachdb/errors" "github.com/rs/zerolog/log" "go.mondoo.com/cnquery/v9/mrn" @@ -23,8 +24,9 @@ import ( ) type scanTarget struct { - TargetType string - TargetName string + TargetType string + Target string + ResourceGroup string } const ( @@ -85,16 +87,28 @@ func determineScannerInstanceInfo(localConn *connection.LocalConnection, token a }, nil } -func ParseTarget(conf *inventory.Config) scanTarget { - return scanTarget{ - TargetType: conf.Options["type"], - TargetName: conf.Options["target-name"], +func ParseTarget(conf *inventory.Config, scanner *azureScannerInstance) (scanTarget, error) { + target := conf.Options["target"] + if target == "" { + return scanTarget{}, errors.New("target is required") + } + id, err := arm.ParseResourceID(conf.Options["target"]) + if err != nil { + log.Debug().Msg("could not parse target as resource id, assuming it's only the resource name") + return scanTarget{ + TargetType: conf.Options["type"], + Target: conf.Options["target"], + ResourceGroup: scanner.ResourceGroup, + }, nil } + return scanTarget{ + TargetType: conf.Options["type"], + Target: id.Name, + ResourceGroup: id.ResourceGroupName, + }, nil } func NewAzureSnapshotConnection(id uint32, conf *inventory.Config, asset *inventory.Asset) (*AzureSnapshotConnection, error) { - target := ParseTarget(conf) - var cred *vault.Credential if len(conf.Credentials) > 0 { cred = conf.Credentials[0] @@ -111,6 +125,11 @@ func NewAzureSnapshotConnection(id uint32, conf *inventory.Config, asset *invent return nil, err } + target, err := ParseTarget(conf, scanner) + if err != nil { + return nil, err + } + // determine the target sc, err := NewSnapshotCreator(token, scanner.SubscriptionId) if err != nil { @@ -127,19 +146,19 @@ func NewAzureSnapshotConnection(id uint32, conf *inventory.Config, asset *invent // setup disk image so and attach it to the instance mi := mountInfo{} - diskName := "cnspec-" + target.TargetName + "-snapshot-" + time.Now().Format("2006-01-02t15-04-05z00-00") + diskName := "cnspec-" + target.Target + "-snapshot-" + time.Now().Format("2006-01-02t15-04-05z00-00") switch target.TargetType { case "instance": - instanceInfo, err := sc.InstanceInfo(scanner.ResourceGroup, target.TargetName) + instanceInfo, err := sc.InstanceInfo(target.ResourceGroup, target.Target) if err != nil { return nil, err } if instanceInfo.BootDiskId == "" { - return nil, fmt.Errorf("could not find boot disk for instance %s", target.TargetName) + return nil, fmt.Errorf("could not find boot disk for instance %s", target.Target) } log.Debug().Str("boot disk", instanceInfo.BootDiskId).Msg("found boot disk for instance, cloning") - disk, err := sc.cloneDisk(instanceInfo.BootDiskId, scanner.ResourceGroup, diskName, instanceInfo.Location, scanner.Vm.Zones) + disk, err := sc.cloneDisk(instanceInfo.BootDiskId, scanner.ResourceGroup, diskName, scanner.Location, scanner.Vm.Zones) if err != nil { log.Error().Err(err).Msg("could not complete disk cloning") return nil, errors.Wrap(err, "could not complete disk cloning") @@ -150,12 +169,12 @@ func NewAzureSnapshotConnection(id uint32, conf *inventory.Config, asset *invent asset.Name = instanceInfo.InstanceName conf.PlatformId = azcompute.MondooAzureInstanceID(*instanceInfo.Vm.ID) case "snapshot": - snapshotInfo, err := sc.SnapshotInfo(scanner.ResourceGroup, target.TargetName) + snapshotInfo, err := sc.SnapshotInfo(target.ResourceGroup, target.Target) if err != nil { return nil, err } - disk, err := sc.createSnapshotDisk(snapshotInfo.SnapshotId, scanner.ResourceGroup, diskName, snapshotInfo.Location, scanner.Vm.Zones) + disk, err := sc.createSnapshotDisk(snapshotInfo.SnapshotId, scanner.ResourceGroup, diskName, scanner.Location, scanner.Vm.Zones) if err != nil { log.Error().Err(err).Msg("could not complete snapshot disk creation") return nil, errors.Wrap(err, "could not create disk from snapshot") @@ -163,7 +182,7 @@ func NewAzureSnapshotConnection(id uint32, conf *inventory.Config, asset *invent log.Debug().Str("disk", *disk.ID).Msg("created disk from snapshot") mi.diskId = *disk.ID mi.diskName = *disk.Name - asset.Name = target.TargetName + asset.Name = target.Target conf.PlatformId = SnapshotPlatformMrn(snapshotInfo.SnapshotId) default: return nil, errors.New("invalid target type") @@ -262,26 +281,34 @@ func (c *AzureSnapshotConnection) Close() { } } - err := c.volumeMounter.UnmountVolumeFromInstance() - if err != nil { - log.Error().Err(err).Msg("unable to unmount volume") + if c.volumeMounter != nil { + err := c.volumeMounter.UnmountVolumeFromInstance() + if err != nil { + log.Error().Err(err).Msg("unable to unmount volume") + } } if c.snapshotCreator != nil { - err = c.snapshotCreator.detachDisk(c.mountInfo.diskName, c.scanner.instanceInfo) - if err != nil { - log.Error().Err(err).Msg("unable to detach volume") + if c.mountInfo.diskName != "" { + err := c.snapshotCreator.detachDisk(c.mountInfo.diskName, c.scanner.instanceInfo) + if err != nil { + log.Error().Err(err).Msg("unable to detach volume") + } } - err = c.snapshotCreator.deleteCreatedDisk(c.scanner.ResourceGroup, c.mountInfo.diskName) - if err != nil { - log.Error().Err(err).Msg("could not delete created disk") + if c.mountInfo.diskName != "" { + err := c.snapshotCreator.deleteCreatedDisk(c.scanner.ResourceGroup, c.mountInfo.diskName) + if err != nil { + log.Error().Err(err).Msg("could not delete created disk") + } } } - err = c.volumeMounter.RemoveTempScanDir() - if err != nil { - log.Error().Err(err).Msg("unable to remove dir") + if c.volumeMounter != nil { + err := c.volumeMounter.RemoveTempScanDir() + if err != nil { + log.Error().Err(err).Msg("unable to remove dir") + } } } diff --git a/providers/azure/connection/azureinstancesnapshot/provider_test.go b/providers/azure/connection/azureinstancesnapshot/provider_test.go new file mode 100644 index 0000000000..9f739e3cf7 --- /dev/null +++ b/providers/azure/connection/azureinstancesnapshot/provider_test.go @@ -0,0 +1,98 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package azureinstancesnapshot + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.mondoo.com/cnquery/v9/providers-sdk/v1/inventory" +) + +func TestParseTarget(t *testing.T) { + t.Run("parse snapshot target with just a resource name", func(t *testing.T) { + scanner := &azureScannerInstance{ + instanceInfo: instanceInfo{ + ResourceGroup: "my-rg", + InstanceName: "my-instance", + }, + } + target := "my-other-snapshot" + + conf := &inventory.Config{ + Options: map[string]string{ + "target": target, + "type": "snapshot", + }, + } + scanTarget, err := ParseTarget(conf, scanner) + assert.NoError(t, err) + assert.Equal(t, "my-rg", scanTarget.ResourceGroup) + assert.Equal(t, target, scanTarget.Target) + assert.Equal(t, "snapshot", scanTarget.TargetType) + }) + t.Run("parse instance target with just a resource name", func(t *testing.T) { + scanner := &azureScannerInstance{ + instanceInfo: instanceInfo{ + ResourceGroup: "my-rg", + InstanceName: "my-instance", + }, + } + target := "my-other-instance" + + conf := &inventory.Config{ + Options: map[string]string{ + "target": target, + "type": "instance", + }, + } + scanTarget, err := ParseTarget(conf, scanner) + assert.NoError(t, err) + assert.Equal(t, "my-rg", scanTarget.ResourceGroup) + assert.Equal(t, target, scanTarget.Target) + assert.Equal(t, "instance", scanTarget.TargetType) + }) + t.Run("parse snapshot target with a fully qualifed Azure resource id", func(t *testing.T) { + scanner := &azureScannerInstance{ + instanceInfo: instanceInfo{ + ResourceGroup: "my-rg", + InstanceName: "my-instance", + }, + } + target := "/subscriptions/f1a2873a-6c27-4097-aa7c-3df51f103e91/resourceGroups/my-other-rg/providers/Microsoft.Compute/snapshots/test-snp" + + conf := &inventory.Config{ + Options: map[string]string{ + "target": target, + "type": "snapshot", + }, + } + scanTarget, err := ParseTarget(conf, scanner) + assert.NoError(t, err) + assert.Equal(t, "my-other-rg", scanTarget.ResourceGroup) + assert.Equal(t, "test-snp", scanTarget.Target) + assert.Equal(t, "snapshot", scanTarget.TargetType) + }) + t.Run("parse instance target with a fully qualifed Azure resource id", func(t *testing.T) { + scanner := &azureScannerInstance{ + instanceInfo: instanceInfo{ + ResourceGroup: "my-rg", + InstanceName: "my-instance", + }, + } + target := "/subscriptions/f1a2873a-6b27-4097-aa7c-3df51f103e96/resourceGroups/debian_group/providers/Microsoft.Compute/virtualMachines/debian" + + conf := &inventory.Config{ + Options: map[string]string{ + "target": target, + "type": "instance", + }, + } + scanTarget, err := ParseTarget(conf, scanner) + assert.NoError(t, err) + assert.Equal(t, "debian_group", scanTarget.ResourceGroup) + assert.Equal(t, "debian", scanTarget.Target) + assert.Equal(t, "instance", scanTarget.TargetType) + }) +} diff --git a/providers/azure/provider/provider.go b/providers/azure/provider/provider.go index ed0659534d..87e82345a8 100644 --- a/providers/azure/provider/provider.go +++ b/providers/azure/provider/provider.go @@ -117,12 +117,12 @@ func handleAzureComputeSubcommands(args []string, config *inventory.Config) erro config.Type = string(azureinstancesnapshot.SnapshotConnectionType) config.Discover = nil config.Options["type"] = "instance" - config.Options["target-name"] = args[2] + config.Options["target"] = args[2] return nil case "snapshot": config.Type = string(azureinstancesnapshot.SnapshotConnectionType) config.Options["type"] = "snapshot" - config.Options["target-name"] = args[2] + config.Options["target"] = args[2] config.Discover = nil return nil default: