From d22949acb8717711a63288e12687fb582abf569b Mon Sep 17 00:00:00 2001 From: Malte Poll Date: Fri, 6 Oct 2023 17:41:08 +0200 Subject: [PATCH] config validation with rego --- azure/uploader.go | 10 +- config/config.go | 82 +++++--- config/config_test.go | 35 ++-- config/validation.go | 1 - config/validation.rego | 397 +++++++++++++++++++++++++++++++++++++- config/validation_test.go | 361 +++++++++++++++++++++++++++++++++- flake.nix | 2 +- 7 files changed, 825 insertions(+), 63 deletions(-) diff --git a/azure/uploader.go b/azure/uploader.go index 426abe3..eeeb95d 100644 --- a/azure/uploader.go +++ b/azure/uploader.go @@ -335,7 +335,7 @@ func (u *Uploader) ensureManagedImageDeleted(ctx context.Context) error { // ensureSIG creates a SIG if it does not exist yet. func (u *Uploader) ensureSIG(ctx context.Context) error { rg := u.config.Azure.ResourceGroup - sigName := u.config.Azure.SharedImageGalleryName + sigName := u.config.Azure.SharedImageGallery pubNamePrefix := u.config.Azure.SharingNamePrefix sharingProf := sharingProfilePermissionFromString(u.config.Azure.SharingProfile) @@ -406,7 +406,7 @@ func sharingProfilePermissionFromString(s string) *armcomputev5.GallerySharingPe // ensureImageDefinition creates an image definition (component of a SIG) if it does not exist yet. func (u *Uploader) ensureImageDefinition(ctx context.Context) error { rg := u.config.Azure.ResourceGroup - sigName := u.config.Azure.SharedImageGalleryName + sigName := u.config.Azure.SharedImageGallery attestVariant := u.config.Azure.AttestationVariant defName := u.config.Azure.ImageDefinitionName @@ -459,7 +459,7 @@ func (u *Uploader) ensureImageDefinition(ctx context.Context) error { func (u *Uploader) createImageVersion(ctx context.Context, imageID string) (string, error) { rg := u.config.Azure.ResourceGroup - sigName := u.config.Azure.SharedImageGalleryName + sigName := u.config.Azure.SharedImageGallery verName := u.config.ImageVersion defName := u.config.Azure.ImageDefinitionName @@ -500,7 +500,7 @@ func (u *Uploader) createImageVersion(ctx context.Context, imageID string) (stri func (u *Uploader) ensureImageVersionDeleted(ctx context.Context) error { rg := u.config.Azure.ResourceGroup - sigName := u.config.Azure.SharedImageGalleryName + sigName := u.config.Azure.SharedImageGallery verName := u.config.ImageVersion defName := u.config.Azure.ImageDefinitionName @@ -529,7 +529,7 @@ func (u *Uploader) ensureImageVersionDeleted(ctx context.Context) error { func (u *Uploader) getImageReference(ctx context.Context, unsharedID string) (string, error) { rg := u.config.Azure.ResourceGroup location := u.config.Azure.Location - sigName := u.config.Azure.SharedImageGalleryName + sigName := u.config.Azure.SharedImageGallery verName := u.config.ImageVersion defName := u.config.Azure.ImageDefinitionName diff --git a/config/config.go b/config/config.go index b3dc44b..e82ffa3 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,7 @@ SPDX-License-Identifier: Apache-2.0 package config import ( + "context" "errors" "fmt" "html/template" @@ -16,7 +17,6 @@ import ( uplositemplate "github.com/edgelesssys/uplosi/template" "dario.cat/mergo" - "golang.org/x/mod/semver" ) var defaultConfig = Config{ @@ -63,19 +63,6 @@ func (c *Config) SetDefaults() error { return mergo.Merge(c, defaultConfig, mergo.WithTransformers(&OptionTransformer{})) } -func (c *Config) Validate() error { - if len(c.Provider) == 0 { - return errors.New("provider must be set") - } - if !semver.IsValid("v" + c.ImageVersion) { - return errors.New("imageVersion must be of the form MAJOR.MINOR.PATCH") - } - if len(c.Name) == 0 { - return errors.New("name must be set") - } - return nil -} - // Render renders the config by evaluating the version file and all template strings. func (c *Config) Render(fileLookup func(name string) ([]byte, error)) error { if err := c.renderVersion(fileLookup); err != nil { @@ -95,6 +82,12 @@ func (c *Config) Render(fileLookup func(name string) ([]byte, error)) error { return err } + v := Validator{} + + if err := v.Validate(context.TODO(), *c); err != nil { + return err + } + return nil } @@ -180,18 +173,18 @@ type AWSConfig struct { } type AzureConfig struct { - SubscriptionID string `toml:"subscriptionID,omitempty"` - Location string `toml:"location,omitempty"` - ResourceGroup string `toml:"resourceGroup,omitempty" template:"true"` - AttestationVariant string `toml:"attestationVariant,omitempty" template:"true"` - SharedImageGalleryName string `toml:"sharedImageGallery,omitempty" template:"true"` - SharingProfile string `toml:"sharingProfile,omitempty" template:"true"` - SharingNamePrefix string `toml:"sharingNamePrefix,omitempty" template:"true"` - ImageDefinitionName string `toml:"imageDefinitionName,omitempty" template:"true"` - Offer string `toml:"offer,omitempty" template:"true"` - SKU string `toml:"sku,omitempty" template:"true"` - Publisher string `toml:"publisher,omitempty" template:"true"` - DiskName string `toml:"diskName,omitempty" template:"true"` + SubscriptionID string `toml:"subscriptionID,omitempty"` + Location string `toml:"location,omitempty"` + ResourceGroup string `toml:"resourceGroup,omitempty" template:"true"` + AttestationVariant string `toml:"attestationVariant,omitempty" template:"true"` + SharedImageGallery string `toml:"sharedImageGallery,omitempty" template:"true"` + SharingProfile string `toml:"sharingProfile,omitempty" template:"true"` + SharingNamePrefix string `toml:"sharingNamePrefix,omitempty" template:"true"` + ImageDefinitionName string `toml:"imageDefinitionName,omitempty" template:"true"` + Offer string `toml:"offer,omitempty" template:"true"` + SKU string `toml:"sku,omitempty" template:"true"` + Publisher string `toml:"publisher,omitempty" template:"true"` + DiskName string `toml:"diskName,omitempty" template:"true"` } type GCPConfig struct { @@ -254,7 +247,44 @@ func (c *ConfigFile) RenderedVariant(fileLookup fileLookupFn, name string) (Conf return out, nil } +func (c *ConfigFile) validateAll(fileLookup fileLookupFn, filters ...variantFilter) error { + var errs error + if len(c.Variants) == 0 { + _, err := c.RenderedVariant(fileLookup, "") + if err != nil { + return fmt.Errorf("validating config: %w", err) + } + } + + variantNames := make([]string, 0, len(c.Variants)) + for name := range c.Variants { + var filtered bool + for _, filter := range filters { + if !filter(name) { + filtered = true + break + } + } + if filtered { + continue + } + variantNames = append(variantNames, name) + } + slices.Sort(variantNames) + for _, name := range variantNames { + _, err := c.RenderedVariant(fileLookup, name) + if err != nil { + errs = errors.Join(errs, fmt.Errorf("config for variant %s: %w", name, err)) + } + } + return errs +} + func (c *ConfigFile) ForEach(fn func(name string, cfg Config) error, fileLookup fileLookupFn, filters ...variantFilter) error { + if err := c.validateAll(fileLookup, filters...); err != nil { + return err + } + if len(c.Variants) == 0 { cfg, err := c.RenderedVariant(fileLookup, "") if err != nil { diff --git a/config/config_test.go b/config/config_test.go index 98a51e3..df94191 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -17,11 +17,12 @@ func TestConfigRenderVersionFromFile(t *testing.T) { lookup := stubFileLookup{ "image-version.txt": []byte("0.0.2"), } - config := Config{ + config := fullConfig() + config.Merge(Config{ Name: "test", ImageVersion: "0.0.1", // this will be overwritten by the file ImageVersionFile: "image-version.txt", - } + }) assert.NoError(config.Render(lookup.Lookup)) assert.Equal("0.0.2", config.ImageVersion) } @@ -29,13 +30,14 @@ func TestConfigRenderVersionFromFile(t *testing.T) { func TestConfigRenderTemplate(t *testing.T) { assert := assert.New(t) lookup := stubFileLookup{} - config := Config{ + config := fullConfig() + config.Merge(Config{ Name: "name", ImageVersion: "0.0.1", GCP: GCPConfig{ ImageName: "prefix-{{.Name}}-{{replaceAll .Version \".\" \"-\"}}-suffix", }, - } + }) assert.NoError(config.Render(lookup.Lookup)) assert.Equal("prefix-name-0-0-1-suffix", config.GCP.ImageName) } @@ -144,6 +146,7 @@ func (s stubFileLookup) Lookup(name string) ([]byte, error) { func fullConfig() Config { return Config{ + Provider: "aws", ImageVersion: "0.0.1", Name: "test", AWS: AWSConfig{ @@ -157,18 +160,18 @@ func fullConfig() Config { Publish: Some[bool](true), }, Azure: AzureConfig{ - SubscriptionID: "subscription-id", - Location: "location", - ResourceGroup: "resource-group", - AttestationVariant: "attestation-variant", - SharedImageGalleryName: "shared-image-gallery", - SharingProfile: "sharing-profile", - SharingNamePrefix: "sharing-name-prefix", - ImageDefinitionName: "image-definition-name-template", - Offer: "offer", - SKU: "sku", - Publisher: "publisher", - DiskName: "disk-name", + SubscriptionID: "subscription-id", + Location: "location", + ResourceGroup: "resource-group", + AttestationVariant: "attestation-variant", + SharedImageGallery: "shared-image-gallery", + SharingProfile: "sharing-profile", + SharingNamePrefix: "sharing-name-prefix", + ImageDefinitionName: "image-definition-name-template", + Offer: "offer", + SKU: "sku", + Publisher: "publisher", + DiskName: "disk-name", }, GCP: GCPConfig{ Project: "project", diff --git a/config/validation.go b/config/validation.go index 2ce5d79..83f6f0a 100644 --- a/config/validation.go +++ b/config/validation.go @@ -25,7 +25,6 @@ func (v *Validator) Validate(ctx context.Context, config Config) error { if err != nil { return fmt.Errorf("evaluating policy: %w", err) } - fmt.Println(res) var resErr error for _, result := range res { diff --git a/config/validation.rego b/config/validation.rego index 9ae5f52..2ef81e4 100644 --- a/config/validation.rego +++ b/config/validation.rego @@ -3,14 +3,403 @@ package config import future.keywords.in deny[msg] { - not input.Provider in [ "aws", "azure", "gcp" ] + not input.Provider in valid_csps - msg = sprintf("cloud provider %s unknown", [input.Provider]) + msg = sprintf("cloud provider %q unknown", [input.Provider]) +} + +deny[msg] { + not regex.match(`^\d+\.\d+\.\d+$`, input.ImageVersion) + + msg = sprintf("image version %q must be in format ..", [input.ImageVersion]) +} + +deny[msg] { + input.Name == "" + + msg = "required field name empty" +} + +deny[msg] { + input.Provider == "aws" + some "" in input.AWS.ReplicationRegions + + msg = "member of list replicationRegions empty for provider aws" +} + +deny[msg] { + input.Provider == "aws" + input.AWS.AMIName != "" + not length_in_range(input.AWS.AMIName, 3, 128) + + msg = sprintf("field amiName must be between 3 and 128 characters for provider aws, got %d", [count(input.AWS.AMIName)]) +} + +deny[msg] { + input.Provider == "aws" + input.AWS.AMIName != "" + not regex.match(`^[a-zA-Z0-9().\-/_]+$`, input.AWS.AMIName) + + msg = sprintf("ami name %q should only contain letters, numbers, '(', ')', '.', '-', '/' and '_'", [input.AWS.AMIName]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 1 +deny[msg] { + input.Provider == "aws" + input.AWS.Bucket != "" + not length_in_range(input.AWS.Bucket, 3, 63) + + msg = sprintf("field bucket must be between 3 and 63 characters for provider aws, got %d", [count(input.AWS.Bucket)]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 2 +deny[msg] { + input.Provider == "aws" + input.AWS.Bucket != "" + not regex.match(`^[a-z0-9.\-]+$`, input.AWS.Bucket) + + msg = sprintf("bucket name %q should only contain lowercase letters, numbers, dots (.) and hyphens", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 3 +deny[msg] { + input.Provider == "aws" + input.AWS.Bucket != "" + not begin_and_end_with(input.AWS.Bucket, lowercase_letters | digits) + + msg = sprintf("bucket name %q must begin and end with a letter or number", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 4 +deny[msg] { + input.Provider == "aws" + regex.match(`[.]{2}`, input.AWS.Bucket) + + msg = sprintf("bucket name %q must not contain two adjacent periods", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 6 +deny[msg] { + input.Provider == "aws" + regex.match(`^xn--`, input.AWS.Bucket) + + msg = sprintf("bucket name %q must not start with the prefix xn--", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 7 +deny[msg] { + input.Provider == "aws" + regex.match(`^sthree-`, input.AWS.Bucket) + + msg = sprintf("bucket name %q must not start with the prefix sthree- and the prefix sthree-configurator", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 8 +deny[msg] { + input.Provider == "aws" + regex.match(`-s3alias$`, input.AWS.Bucket) + + msg = sprintf("bucket name %q must not end with the suffix -s3alias", [input.AWS.Bucket]) +} + +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - 9 +deny[msg] { + input.Provider == "aws" + regex.match(`--ol-s3$`, input.AWS.Bucket) + + msg = sprintf("bucket name %q must not end with the suffix --ol-s3", [input.AWS.Bucket]) +} + +deny[msg] { + input.Provider == "aws" + not is_boolean(input.AWS.Publish) + + msg = "required field Publish uninitialized for provider aws" +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SubscriptionID != "" + not regex.match(`^(?:\{{0,1}(?:[0-9a-fA-F]){8}-(?:[0-9a-fA-F]){4}-(?:[0-9a-fA-F]){4}-(?:[0-9a-fA-F]){4}-(?:[0-9a-fA-F]){12}\}{0,1})$$`, input.Azure.SubscriptionID) + + msg = sprintf("subscription id %q must be a valid guid for provider azure", [input.Azure.SubscriptionID]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.AttestationVariant != "" + not input.Azure.AttestationVariant in ["azure-sev-snp", "azure-trustedlaunch"] + + msg = sprintf("attestation variant %q must be one of %s for provider azure", [input.Azure.AttestationVariant, ["azure-sev-snp", "azure-trustedlaunch"]]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharedImageGallery != "" + not regex.match(`^[a-zA-Z0-9_.]*$`, input.Azure.SharedImageGallery) + + msg = sprintf("shared image gallery %q must contain only alphanumerics, underscores and periods for provider azure", [input.Azure.SharedImageGallery]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharedImageGallery != "" + not begin_and_end_with(input.Azure.SharedImageGallery, lowercase_letters | uppercase_letters | digits) + + msg = sprintf("shared image gallery %q must begin and end with a letter or number", [input.Azure.SharedImageGallery]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharedImageGallery != "" + not length_in_range(input.Azure.SharedImageGallery, 1, 80) + + msg = sprintf("field sharedImageGallery must be between 1 and 80 characters for provider azure, got %d", [count(input.Azure.SharedImageGallery)]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharingProfile != "" + allowed := ["community", "groups", "private"] + not input.Azure.SharingProfile in allowed + + msg = sprintf("sharing profile %q must be one of %s for provider azure", [input.Azure.SharingProfile, allowed]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharingProfile == "community" + input.Azure.SharingNamePrefix == "" + + msg = "field sharingNamePrefix is required for sharing profile community and provider azure" +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharingNamePrefix != "" + not length_in_range(input.Azure.SharingNamePrefix, 5, 16) + + msg = sprintf("field sharingNamePrefix must be between 5 and 16 characters for provider azure, got %d", [count(input.Azure.SharingNamePrefix)]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.SharingNamePrefix != "" + not regex.match(`^[a-zA-Z0-9]*$`, input.Azure.SharingNamePrefix) + + msg = sprintf("sharing name prefix %q must be alphanumeric for provider azure", [input.Azure.SharingNamePrefix]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.ImageDefinitionName != "" + not regex.match(`^[a-zA-Z0-9_\-.]*$`, input.Azure.ImageDefinitionName) + + msg = sprintf("image definition name %q must contain only alphanumerics, underscores, hyphens, and periods for provider azure", [input.Azure.ImageDefinitionName]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.ImageDefinitionName != "" + not begin_and_end_with(input.Azure.ImageDefinitionName, lowercase_letters | uppercase_letters | digits) + + msg = sprintf("image definition name %q must begin and end with a letter or number", [input.Azure.ImageDefinitionName]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.ImageDefinitionName != "" + not length_in_range(input.Azure.ImageDefinitionName, 1, 80) + + msg = sprintf("field imageDefinitionName must be between 1 and 80 characters for provider azure, got %d", [count(input.Azure.ImageDefinitionName)]) +} + +deny[msg] { + input.Provider == "azure" + input.Azure.DiskName != "" + not regex.match(`^[a-zA-Z0-9_\-.]*$`, input.Azure.DiskName) + + msg = sprintf("disk name %q must contain only alphanumerics, underscores, hyphens, and periods for provider azure", [input.Azure.DiskName]) } deny[msg] { input.Provider == "azure" - input.Azure.SubscriptionID == "" + input.Azure.DiskName != "" + not length_in_range(input.Azure.DiskName, 1, 80) + + msg = sprintf("field diskName must be between 1 and 80 characters for provider azure, got %d", [count(input.Azure.DiskName)]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.Project != "" + not regex.match(`^[a-z0-9\-]*$`, input.GCP.Project) + + msg = sprintf("project name %q must contain only lowercase letters, digits and hyphens for provider gcp", [input.GCP.Project]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.Project != "" + not begins_with(input.GCP.Project, lowercase_letters) + not ends_with(input.GCP.Project, lowercase_letters | digits) + + msg = sprintf("project name %q must begin with a letter and end with a letter or number", [input.GCP.Project]) +} - msg = "required field subscriptionID empty for provider azure" +deny[msg] { + input.Provider == "gcp" + input.GCP.Project != "" + not length_in_range(input.GCP.Project, 6, 30) + + msg = sprintf("field project must be between 6 and 30 characters for provider gcp, got %d", [count(input.GCP.Project)]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageName != "" + not regex.match(`^[a-z0-9\-]*$`, input.GCP.ImageName) + + msg = sprintf("image name %q must contain only alphanumerics, underscores, hyphens, and periods for provider gcp", [input.GCP.ImageName]) } + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageName != "" + not begins_with(input.GCP.ImageName, lowercase_letters) + not ends_with(input.GCP.ImageName, lowercase_letters | digits) + + msg = sprintf("image name %q must begin with a letter and end with a letter or number", [input.GCP.ImageName]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageName != "" + not length_in_range(input.GCP.ImageName, 1, 63) + + msg = sprintf("field imageName must be between 1 and 63 characters for provider gcp, got %d", [count(input.GCP.ImageName)]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageFamily != "" + not regex.match(`^[a-z0-9\-]*$`, input.GCP.ImageFamily) + + msg = sprintf("image family %q must contain only alphanumerics, underscores, hyphens, and periods for provider gcp", [input.GCP.ImageFamily]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageFamily != "" + not begins_with(input.GCP.ImageFamily, lowercase_letters) + not ends_with(input.GCP.ImageFamily, lowercase_letters | digits) + + msg = sprintf("image family %q must begin with a letter and end with a letter or number", [input.GCP.ImageFamily]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.ImageFamily != "" + not length_in_range(input.GCP.ImageFamily, 1, 63) + + msg = sprintf("field imageFamily must be between 1 and 63 characters for provider gcp, got %d", [count(input.GCP.ImageFamily)]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.Bucket != "" + not regex.match(`^[a-z0-9\-_.]*$`, input.GCP.Bucket) + + msg = sprintf("bucket %q must contain only alphanumerics, underscores, hyphens, and periods for provider gcp", [input.GCP.Bucket]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.Bucket != "" + not begin_and_end_with(input.GCP.Bucket, lowercase_letters | digits) + + msg = sprintf("bucket %q must begin with a letter and end with a letter or number", [input.GCP.Bucket]) +} + +deny[msg] { + input.Provider == "gcp" + input.GCP.Bucket != "" + not length_in_range(input.GCP.Bucket, 3, 63) + + msg = sprintf("field bucket must be between 1 and 63 characters for provider gcp, got %d", [count(input.GCP.Bucket)]) +} + + +deny[msg] { + some provider in valid_csps + input.Provider == provider + some fieldName, fieldValue in required_fields[provider] + fieldValue == "" + + msg = sprintf("required field %q empty for provider %s", [fieldName, input.Provider]) +} + +length_in_range(s, min_len, max_len) = in_range { + length := count(s) + in_range := all([min_len <= length, length <= max_len]) +} + +begins_with(s, charset) = begin { + begin := substring(s, 0, 1) in charset +} + +ends_with(s, charset) = end { + end := substring(s, count(s)-1, 1) in charset +} + +begin_and_end_with(s, charset) = begin_and_end { + begin_and_end := all([ + begins_with(s, charset), + ends_with(s, charset), + ]) +} + +valid_csps := [ "aws", "azure", "gcp" ] + +required_fields := { + "aws": { + "region": input.AWS.Region, + "replicationRegions": input.AWS.ReplicationRegions, + "amiName": input.AWS.AMIName, + "bucket": input.AWS.Bucket, + "blobName": input.AWS.BlobName, + "snapshotName": input.AWS.SnapshotName, + }, + "azure": { + "subscriptionID": input.Azure.SubscriptionID, + "location": input.Azure.Location, + "resourceGroup": input.Azure.ResourceGroup, + "attestationVariant": input.Azure.AttestationVariant, + "sharedImageGallery": input.Azure.SharedImageGallery, + "sharingProfile": input.Azure.SharingProfile, + "imageDefinitionName": input.Azure.ImageDefinitionName, + "diskName": input.Azure.DiskName, + "offer": input.Azure.Offer, + "sku": input.Azure.SKU, + "publisher": input.Azure.Publisher, + "diskName": input.Azure.DiskName, + }, + "gcp": { + "project": input.GCP.Project, + "location": input.GCP.Location, + "imageName": input.GCP.ImageName, + "imageFamily": input.GCP.ImageFamily, + "bucket": input.GCP.Bucket, + "blobName": input.GCP.BlobName, + }, +} + +lowercase_letters := { + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z" +} + +uppercase_letters := { + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z" +} + +digits := { "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" } diff --git a/config/validation_test.go b/config/validation_test.go index b9ed9dc..a10a6c0 100644 --- a/config/validation_test.go +++ b/config/validation_test.go @@ -2,6 +2,7 @@ package config import ( "context" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -9,19 +10,312 @@ import ( func TestValidate(t *testing.T) { testCases := map[string]struct { - config Config - wantErr bool + base Config + overrides Config + mutation func(*Config) + wantErr bool }{ - "with subsc": { - config: Config{Provider: "azure", Azure: AzureConfig{SubscriptionID: "test"}}, - wantErr: false, + "empty config": { + wantErr: true, + }, + "full config": { + base: validConfig(), + }, + "valid AWS config": { + base: validConfig(), + overrides: Config{Provider: "aws"}, }, - "without subsc": { - config: Config{Provider: "azure", Azure: AzureConfig{SubscriptionID: "test"}}, - wantErr: false, + "valid Azure config": { + base: validConfig(), + overrides: Config{Provider: "azure"}, + }, + "valid GCP config": { + base: validConfig(), + overrides: Config{Provider: "gcp"}, }, "unknown provider": { - config: Config{Provider: "foo"}, + base: validConfig(), + overrides: Config{Provider: "foo"}, + wantErr: true, + }, + "missing version": { + base: validConfig(), + mutation: func(c *Config) { c.ImageVersion = "" }, + wantErr: true, + }, + "invalid version": { + base: validConfig(), + overrides: Config{ImageVersion: "v1.2.3-dev"}, + wantErr: true, + }, + "missing name": { + base: validConfig(), + mutation: func(c *Config) { c.Name = "" }, + wantErr: true, + }, + "missing AWS region": { + base: validConfig(), + overrides: Config{Provider: "aws"}, + mutation: func(c *Config) { + c.AWS.Region = "" + }, + wantErr: true, + }, + "missing AWS amiName": { + base: validConfig(), + overrides: Config{Provider: "aws"}, + mutation: func(c *Config) { + c.AWS.AMIName = "" + }, + wantErr: true, + }, + "wrong length AWS amiName": { + base: validConfig(), + overrides: Config{ + Provider: "aws", + AWS: AWSConfig{ + AMIName: strings.Repeat("a", 129), + }, + }, + wantErr: true, + }, + "invalid AWS amiName": { + base: validConfig(), + overrides: Config{ + Provider: "aws", + AWS: AWSConfig{ + AMIName: "invalid