diff --git a/castai/resource_autoscaler.go b/castai/resource_autoscaler.go index 5cca3a88..c13f66d3 100644 --- a/castai/resource_autoscaler.go +++ b/castai/resource_autoscaler.go @@ -5,12 +5,12 @@ import ( "context" "encoding/json" "fmt" + jsonpatch "github.com/evanphx/json-patch" "io" "log" "net/http" "time" - jsonpatch "github.com/evanphx/json-patch" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" @@ -29,6 +29,7 @@ func resourceAutoscaler() *schema.Resource { CreateContext: resourceCastaiAutoscalerCreate, UpdateContext: resourceCastaiAutoscalerUpdate, DeleteContext: resourceCastaiAutoscalerDelete, + CustomizeDiff: resourceCastaiAutoscalerDiff, Description: "CAST AI autoscaler resource to manage autoscaler settings", Timeouts: &schema.ResourceTimeout{ @@ -74,6 +75,23 @@ func resourceCastaiAutoscalerDelete(ctx context.Context, data *schema.ResourceDa return nil } +func resourceCastaiAutoscalerDiff(ctx context.Context, d *schema.ResourceDiff, meta interface{}) error { + clusterId := getClusterId(d) + if clusterId == "" { + return nil + } + + policies, err := getChangedPolicies(ctx, d, meta, clusterId) + if err != nil { + return err + } + if policies == nil { + return nil + } + + return d.SetNew(FieldAutoscalerPolicies, string(policies)) +} + func resourceCastaiAutoscalerRead(ctx context.Context, data *schema.ResourceData, meta interface{}) diag.Diagnostics { err := readAutoscalerPolicies(ctx, data, meta) if err != nil { @@ -119,15 +137,15 @@ func getCurrentPolicies(ctx context.Context, client *sdk.ClientWithResponses, cl return nil, fmt.Errorf("cluster %s policies do not exist at CAST AI", clusterId) } - bytes, err := io.ReadAll(resp.Body) + responseBytes, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { return nil, fmt.Errorf("reading response body: %w", err) } - log.Printf("[DEBUG] Read autoscaler policies for cluster %s:\n%v\n", clusterId, string(bytes)) + log.Printf("[DEBUG] Read autoscaler policies for cluster %s:\n%v\n", clusterId, string(responseBytes)) - return bytes, nil + return normalizeJSON(responseBytes) } func updateAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, meta interface{}) error { @@ -137,18 +155,17 @@ func updateAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, me return nil } - err := readAutoscalerPolicies(ctx, data, meta) + policies, err := getChangedPolicies(ctx, data, meta, clusterId) if err != nil { return err } - changedPolicies, found := data.GetOk(FieldAutoscalerPolicies) - if !found { - log.Printf("[DEBUG] changed policies json not found. Skipping autoscaler policies changes") + if policies == nil { + log.Printf("[DEBUG] changed policies json not calculated. Skipping autoscaler policies changes") return nil } - changedPoliciesJSON := changedPolicies.(string) + changedPoliciesJSON := string(policies) if changedPoliciesJSON == "" { log.Printf("[DEBUG] changed policies json not found. Skipping autoscaler policies changes") return nil @@ -178,12 +195,13 @@ func readAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, meta return nil } - policies, err := getChangedPolicies(ctx, data, meta, clusterId) + client := meta.(*ProviderConfig).api + currentPolicies, err := getCurrentPolicies(ctx, client, clusterId) if err != nil { return err } - err = data.Set(FieldAutoscalerPolicies, string(policies)) + err = data.Set(FieldAutoscalerPolicies, string(currentPolicies)) if err != nil { log.Printf("[ERROR] Failed to set field: %v", err) return err @@ -192,7 +210,16 @@ func readAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, meta return nil } -func getChangedPolicies(ctx context.Context, data *schema.ResourceData, meta interface{}, clusterId string) ([]byte, error) { +func getClusterId(data resourceProvider) string { + value, found := data.GetOk(FieldClusterId) + if !found { + return "" + } + + return value.(string) +} + +func getChangedPolicies(ctx context.Context, data resourceProvider, meta interface{}, clusterId string) ([]byte, error) { policyChangesJSON, found := data.GetOk(FieldAutoscalerPoliciesJSON) if !found { log.Printf("[DEBUG] policies json not provided. Skipping autoscaler policies changes") @@ -219,16 +246,7 @@ func getChangedPolicies(ctx context.Context, data *schema.ResourceData, meta int return nil, fmt.Errorf("failed to merge policies: %v", err) } - return policies, nil -} - -func getClusterId(data *schema.ResourceData) string { - value, found := data.GetOk(FieldClusterId) - if !found { - return "" - } - - return value.(string) + return normalizeJSON(policies) } func validateAutoscalerPolicyJSON() schema.SchemaValidateDiagFunc { diff --git a/castai/resource_autoscaler_test.go b/castai/resource_autoscaler_test.go index 24324966..7efca54a 100644 --- a/castai/resource_autoscaler_test.go +++ b/castai/resource_autoscaler_test.go @@ -256,17 +256,6 @@ func TestAutoscalerResource_PoliciesUpdateAction_Fail(t *testing.T) { r.Equal(`expected status code 200, received: status=400 body={"message":"policies config: Evictor policy management is not allowed: Evictor installed externally. Uninstall Evictor first and try again.","fieldViolations":[]`, result[0].Summary) } -func JSONBytesEqual(a, b []byte) (bool, error) { - var j, j2 interface{} - if err := json.Unmarshal(a, &j); err != nil { - return false, err - } - if err := json.Unmarshal(b, &j2); err != nil { - return false, err - } - return reflect.DeepEqual(j2, j), nil -} - func Test_validateAutoscalerPolicyJSON(t *testing.T) { type testData struct { json string @@ -392,3 +381,216 @@ func Test_validateAutoscalerPolicyJSON(t *testing.T) { }) } } + +func TestAutoscalerResource_ReadPoliciesAction(t *testing.T) { + r := require.New(t) + mockctrl := gomock.NewController(t) + mockClient := mock_sdk.NewMockClientInterface(mockctrl) + ctx := context.Background() + provider := &ProviderConfig{ + api: &sdk.ClientWithResponses{ + ClientInterface: mockClient, + }, + } + + currentPoliciesBytes, err := normalizeJSON([]byte(` + { + "enabled": true, + "isScopedMode": false, + "unschedulablePods": { + "enabled": true, + "headroom": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "headroomSpot": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "nodeConstraints": { + "minCpuCores": 2, + "maxCpuCores": 32, + "minRamMib": 4096, + "maxRamMib": 262144, + "enabled": false + }, + "diskGibToCpuRatio": 25 + }, + "clusterLimits": { + "enabled": false, + "cpu": { + "minCores": 1, + "maxCores": 20 + } + }, + "nodeDownscaler": { + "emptyNodes": { + "enabled": false, + "delaySeconds": 0 + } + } + }`)) + r.NoError(err) + + currentPolicies := string(currentPoliciesBytes) + resource := resourceAutoscaler() + + clusterId := "cluster_id" + val := cty.ObjectVal(map[string]cty.Value{ + FieldClusterId: cty.StringVal(clusterId), + }) + state := terraform.NewInstanceStateShimmedFromValue(val, 0) + data := resource.Data(state) + + body := io.NopCloser(bytes.NewReader([]byte(currentPolicies))) + response := &http.Response{StatusCode: 200, Body: body} + + mockClient.EXPECT().PoliciesAPIGetClusterPolicies(gomock.Any(), clusterId, gomock.Any()).Return(response, nil).Times(1) + mockClient.EXPECT().PoliciesAPIUpsertClusterPoliciesWithBody(gomock.Any(), clusterId, "application/json", gomock.Any()). + Times(0) + + result := resource.ReadContext(ctx, data, provider) + r.Nil(result) + r.Equal(currentPolicies, data.Get(FieldAutoscalerPolicies)) +} + +func TestAutoscalerResource_CustomizeDiff(t *testing.T) { + r := require.New(t) + mockctrl := gomock.NewController(t) + mockClient := mock_sdk.NewMockClientInterface(mockctrl) + ctx := context.Background() + provider := &ProviderConfig{ + api: &sdk.ClientWithResponses{ + ClientInterface: mockClient, + }, + } + + currentPoliciesBytes, err := normalizeJSON([]byte(` + { + "enabled": true, + "isScopedMode": false, + "unschedulablePods": { + "enabled": true, + "headroom": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "headroomSpot": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "nodeConstraints": { + "minCpuCores": 2, + "maxCpuCores": 32, + "minRamMib": 4096, + "maxRamMib": 262144, + "enabled": false + }, + "diskGibToCpuRatio": 25 + }, + "clusterLimits": { + "enabled": false, + "cpu": { + "minCores": 1, + "maxCores": 20 + } + }, + "nodeDownscaler": { + "emptyNodes": { + "enabled": false, + "delaySeconds": 0 + } + } + }`)) + r.NoError(err) + + policyChangeBytes, err := normalizeJSON([]byte(` + { + "enabled": false, + "unschedulablePods": { + "enabled": false + } + }`)) + r.NoError(err) + + expectedPoliciesBytes, err := normalizeJSON([]byte(` + { + "enabled": false, + "isScopedMode": false, + "unschedulablePods": { + "enabled": false, + "headroom": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "headroomSpot": { + "cpuPercentage": 10, + "memoryPercentage": 10, + "enabled": true + }, + "nodeConstraints": { + "minCpuCores": 2, + "maxCpuCores": 32, + "minRamMib": 4096, + "maxRamMib": 262144, + "enabled": false + }, + "diskGibToCpuRatio": 25 + }, + "clusterLimits": { + "enabled": false, + "cpu": { + "minCores": 1, + "maxCores": 20 + } + }, + "nodeDownscaler": { + "emptyNodes": { + "enabled": false, + "delaySeconds": 0 + } + } + }`)) + r.NoError(err) + + currentPolicies := string(currentPoliciesBytes) + policyChanges := string(policyChangeBytes) + expectedPolicies := string(expectedPoliciesBytes) + resource := resourceAutoscaler() + + clusterId := "cluster_id" + val := cty.ObjectVal(map[string]cty.Value{ + FieldAutoscalerPoliciesJSON: cty.StringVal(policyChanges), + FieldClusterId: cty.StringVal(clusterId), + }) + state := terraform.NewInstanceStateShimmedFromValue(val, 0) + data := resource.Data(state) + r.NoError(err) + + body := io.NopCloser(bytes.NewReader([]byte(currentPolicies))) + response := &http.Response{StatusCode: 200, Body: body} + + mockClient.EXPECT().PoliciesAPIGetClusterPolicies(gomock.Any(), clusterId, gomock.Any()).Return(response, nil).Times(1) + mockClient.EXPECT().PoliciesAPIUpsertClusterPoliciesWithBody(gomock.Any(), clusterId, "application/json", gomock.Any()). + Times(0) + + result, err := getChangedPolicies(ctx, data, provider, clusterId) + r.NoError(err) + r.Equal(expectedPolicies, string(result)) +} + +func JSONBytesEqual(a, b []byte) (bool, error) { + var j, j2 interface{} + if err := json.Unmarshal(a, &j); err != nil { + return false, err + } + if err := json.Unmarshal(b, &j2); err != nil { + return false, err + } + return reflect.DeepEqual(j2, j), nil +} diff --git a/castai/sdk/api.gen.go b/castai/sdk/api.gen.go index c4fd6afa..86331bfe 100644 --- a/castai/sdk/api.gen.go +++ b/castai/sdk/api.gen.go @@ -7,6 +7,8 @@ import ( "encoding/json" "fmt" "time" + + openapi_types "github.com/deepmap/oapi-codegen/pkg/types" ) const ( @@ -2325,6 +2327,9 @@ type ScheduledrebalancingV1TriggerConditions struct { SavingsPercentage *float32 `json:"savingsPercentage,omitempty"` } +// HeaderOrganizationId defines model for headerOrganizationId. +type HeaderOrganizationId = openapi_types.UUID + // AuthTokenAPIListAuthTokensParams defines parameters for AuthTokenAPIListAuthTokens. type AuthTokenAPIListAuthTokensParams struct { UserId *string `form:"userId,omitempty" json:"userId,omitempty"` diff --git a/castai/utils.go b/castai/utils.go index 1f566e32..bab94e5b 100644 --- a/castai/utils.go +++ b/castai/utils.go @@ -7,6 +7,10 @@ import ( "golang.org/x/exp/constraints" ) +type resourceProvider interface { + GetOk(key string) (interface{}, bool) +} + func toPtr[S any](src S) *S { return &src } @@ -107,3 +111,12 @@ func toNilList[T any](l *[]T) *[]T { } return l } + +func normalizeJSON(bytes []byte) ([]byte, error) { + var output interface{} + err := json.Unmarshal(bytes, &output) + if err != nil { + return nil, err + } + return json.Marshal(output) +}