diff --git a/internal/capacity/drainer.go b/internal/capacity/drainer.go index 6ca56a1..ba8c0af 100644 --- a/internal/capacity/drainer.go +++ b/internal/capacity/drainer.go @@ -55,7 +55,7 @@ func (d *drainer) Drain(ctx context.Context, instanceIDs []string) error { processedCount += len(instances) arns := make([]*string, len(instances)) - fmt.Printf("Drain the following container instances in the cluster \"%s\":\n", d.cluster) + fmt.Printf("Drain the following %d container instances in the cluster \"%s\":\n", len(instances), d.cluster) for i, instance := range instances { arns[i] = instance.ContainerInstanceArn fmt.Printf("\t%s (%s)\n", getContainerInstanceID(*instance.ContainerInstanceArn), *instance.Ec2InstanceId) @@ -72,7 +72,7 @@ func (d *drainer) Drain(ctx context.Context, instanceIDs []string) error { return xerrors.Errorf("no target instances exist in the cluster \"%s\"", d.cluster) } if processedCount != len(instanceIDs) { - return xerrors.Errorf("%d instances should be drained but only %d instances was drained", len(instanceIDs), processedCount) + return xerrors.Errorf("%d instances should be drained but only %d instances were drained", len(instanceIDs), processedCount) } return nil @@ -226,11 +226,14 @@ func (d *drainer) drainContainerInstances(ctx context.Context, arns []*string, w func (d *drainer) processContainerInstances(ctx context.Context, instanceIDs []string, callback func([]ecstypes.ContainerInstance) error) error { params := &ecs.ListContainerInstancesInput{ - Cluster: aws.String(d.cluster), - Filter: aws.String(fmt.Sprintf("ec2InstanceId in [%s]", strings.Join(instanceIDs, ","))), - MaxResults: aws.Int32(d.batchSize), + Cluster: aws.String(d.cluster), + Filter: aws.String(fmt.Sprintf("ec2InstanceId in [%s]", strings.Join(instanceIDs, ","))), + } + if err := d.waitUntilContainerInstancesRegistered(ctx, len(instanceIDs), params); err != nil { + return xerrors.Errorf("failed to wait until container instances are registered: %w", err) } + params.MaxResults = aws.Int32(d.batchSize) paginator := ecs.NewListContainerInstancesPaginator(d.ecsSvc, params) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) @@ -250,13 +253,44 @@ func (d *drainer) processContainerInstances(ctx context.Context, instanceIDs []s } if err := callback(resp.ContainerInstances); err != nil { - return xerrors.Errorf("failed to list container instances: %w", err) + return xerrors.Errorf("failed to execute the callback: %w", err) } } return nil } +func (d *drainer) waitUntilContainerInstancesRegistered(ctx context.Context, count int, params *ecs.ListContainerInstancesInput) error { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + timeout := 5 * time.Minute + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + instanceCount := 0 + paginator := ecs.NewListContainerInstancesPaginator(d.ecsSvc, params) + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return xerrors.Errorf("failed to list container instances: %w", err) + } + instanceCount += len(page.ContainerInstanceArns) + } + if instanceCount == count { + return nil + } + + select { + case <-ticker.C: + continue + case <-timer.C: + return xerrors.Errorf("%d container instances expect to be registered but only %d instances were registered within %v", count, instanceCount, timeout) + } + } +} + func getContainerInstanceID(arn string) string { return arn[strings.LastIndex(arn, "/")+1:] }