|
| 1 | +package awscommons |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "math" |
| 6 | + "strings" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/aws/aws-sdk-go-v2/aws" |
| 10 | + "github.com/aws/aws-sdk-go-v2/service/ecs" |
| 11 | + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" |
| 12 | + "github.com/gruntwork-io/go-commons/collections" |
| 13 | + "github.com/gruntwork-io/go-commons/errors" |
| 14 | + "github.com/gruntwork-io/go-commons/logging" |
| 15 | + "github.com/gruntwork-io/go-commons/retry" |
| 16 | +) |
| 17 | + |
| 18 | +// GetContainerInstanceArns gets the container instance ARNs of all the EC2 instances in an ECS Cluster. |
| 19 | +// ECS container instance ARNs are different from EC2 instance IDs! |
| 20 | +// An ECS container instance is an EC2 instance that runs the ECS container agent and has been registered into |
| 21 | +// an ECS cluster. |
| 22 | +// Example identifiers: |
| 23 | +// - EC2 instance ID: i-08e8cfc073db135a9 |
| 24 | +// - container instance ID: 2db66342-5f69-4782-89a3-f9b707f979ab |
| 25 | +// - container instance ARN: arn:aws:ecs:us-east-1:012345678910:container-instance/2db66342-5f69-4782-89a3-f9b707f979ab |
| 26 | +func GetContainerInstanceArns(opts *Options, clusterName string) ([]string, error) { |
| 27 | + client, err := NewECSClient(opts) |
| 28 | + if err != nil { |
| 29 | + return nil, err |
| 30 | + } |
| 31 | + |
| 32 | + logger := logging.GetProjectLogger() |
| 33 | + logger.Infof("Looking up Container Instance ARNs for ECS cluster %s", clusterName) |
| 34 | + |
| 35 | + input := &ecs.ListContainerInstancesInput{Cluster: aws.String(clusterName)} |
| 36 | + arns := []string{} |
| 37 | + // Handle pagination by repeatedly making the API call while there is a next token set. |
| 38 | + for { |
| 39 | + result, err := client.ListContainerInstances(opts.Context, input) |
| 40 | + if err != nil { |
| 41 | + return nil, errors.WithStackTrace(err) |
| 42 | + } |
| 43 | + arns = append(arns, result.ContainerInstanceArns...) |
| 44 | + if result.NextToken == nil { |
| 45 | + break |
| 46 | + } |
| 47 | + input.NextToken = result.NextToken |
| 48 | + } |
| 49 | + |
| 50 | + return arns, nil |
| 51 | +} |
| 52 | + |
| 53 | +// StartDrainingContainerInstances puts ECS container instances in DRAINING state so that all ECS Tasks running on |
| 54 | +// them are migrated to other container instances. Batches into chunks of 10 because of AWS API limitations. |
| 55 | +// (An error occurred InvalidParameterException when calling the UpdateContainerInstancesState |
| 56 | +// operation: instanceIds can have at most 10 items.) |
| 57 | +func StartDrainingContainerInstances(opts *Options, clusterName string, containerInstanceArns []string) error { |
| 58 | + client, err := NewECSClient(opts) |
| 59 | + if err != nil { |
| 60 | + return err |
| 61 | + } |
| 62 | + |
| 63 | + logger := logging.GetProjectLogger() |
| 64 | + batchSize := 10 |
| 65 | + numBatches := int(math.Ceil(float64(len(containerInstanceArns) / batchSize))) |
| 66 | + |
| 67 | + errList := NewMultipleDrainContainerInstanceErrors() |
| 68 | + for batchIdx, batchedArnList := range collections.BatchListIntoGroupsOf(containerInstanceArns, batchSize) { |
| 69 | + batchedArns := aws.StringSlice(batchedArnList) |
| 70 | + |
| 71 | + logger.Infof("Putting batch %d/%d of container instances in cluster %s into DRAINING state", batchIdx, numBatches, clusterName) |
| 72 | + input := &ecs.UpdateContainerInstancesStateInput{ |
| 73 | + Cluster: aws.String(clusterName), |
| 74 | + ContainerInstances: aws.ToStringSlice(batchedArns), |
| 75 | + Status: "DRAINING", |
| 76 | + } |
| 77 | + _, err := client.UpdateContainerInstancesState(opts.Context, input) |
| 78 | + if err != nil { |
| 79 | + errList.AddError(err) |
| 80 | + logger.Errorf("Encountered error starting to drain container instances in batch %d: %s", batchIdx, err) |
| 81 | + logger.Errorf("Container Instance ARNs: %s", strings.Join(batchedArnList, ",")) |
| 82 | + continue |
| 83 | + } |
| 84 | + |
| 85 | + logger.Infof("Started draining %d container instances from batch %d", len(batchedArnList), batchIdx) |
| 86 | + } |
| 87 | + |
| 88 | + if !errList.IsEmpty() { |
| 89 | + return errors.WithStackTrace(errList) |
| 90 | + } |
| 91 | + logger.Infof("Successfully started draining all %d container instances", len(containerInstanceArns)) |
| 92 | + return nil |
| 93 | +} |
| 94 | + |
| 95 | +// WaitForContainerInstancesToDrain waits until there are no more ECS Tasks running on any of the ECS container |
| 96 | +// instances. Batches container instances in groups of 100 because of AWS API limitations. |
| 97 | +func WaitForContainerInstancesToDrain(opts *Options, clusterName string, containerInstanceArns []string, start time.Time, timeout time.Duration, maxRetries int, sleepBetweenRetries time.Duration) error { |
| 98 | + client, err := NewECSClient(opts) |
| 99 | + if err != nil { |
| 100 | + return err |
| 101 | + } |
| 102 | + |
| 103 | + logger := logging.GetProjectLogger() |
| 104 | + logger.Infof("Checking if all ECS Tasks have been drained from the ECS Container Instances in Cluster %s.", clusterName) |
| 105 | + |
| 106 | + batchSize := 100 |
| 107 | + numBatches := int(math.Ceil(float64(len(containerInstanceArns) / batchSize))) |
| 108 | + |
| 109 | + err = retry.DoWithRetry( |
| 110 | + logger.Logger, |
| 111 | + "Wait for Container Instances to be Drained", |
| 112 | + maxRetries, sleepBetweenRetries, |
| 113 | + func() error { |
| 114 | + responses := []*ecs.DescribeContainerInstancesOutput{} |
| 115 | + for batchIdx, batchedArnList := range collections.BatchListIntoGroupsOf(containerInstanceArns, batchSize) { |
| 116 | + batchedArns := aws.StringSlice(batchedArnList) |
| 117 | + |
| 118 | + logger.Infof("Fetching description of batch %d/%d of ECS Instances in Cluster %s.", batchIdx, numBatches, clusterName) |
| 119 | + input := &ecs.DescribeContainerInstancesInput{ |
| 120 | + Cluster: aws.String(clusterName), |
| 121 | + ContainerInstances: aws.ToStringSlice(batchedArns), |
| 122 | + } |
| 123 | + result, err := client.DescribeContainerInstances(opts.Context, input) |
| 124 | + if err != nil { |
| 125 | + return errors.WithStackTrace(err) |
| 126 | + } |
| 127 | + responses = append(responses, result) |
| 128 | + } |
| 129 | + |
| 130 | + // If we exceeded the timeout, halt with error. |
| 131 | + if timeoutExceeded(start, timeout) { |
| 132 | + return retry.FatalError{Underlying: fmt.Errorf("maximum drain timeout of %s seconds has elapsed and instances are still draining", timeout)} |
| 133 | + } |
| 134 | + |
| 135 | + // Yay, all done. |
| 136 | + if drained, _ := allInstancesFullyDrained(responses); drained == true { |
| 137 | + logger.Infof("All container instances have been drained in Cluster %s!", clusterName) |
| 138 | + return nil |
| 139 | + } |
| 140 | + |
| 141 | + // If there's no error, retry. |
| 142 | + if err == nil { |
| 143 | + return errors.WithStackTrace(fmt.Errorf("container instances still draining")) |
| 144 | + } |
| 145 | + |
| 146 | + // Else, there's an error, halt and fail. |
| 147 | + return retry.FatalError{Underlying: err} |
| 148 | + }) |
| 149 | + return errors.WithStackTrace(err) |
| 150 | +} |
| 151 | + |
| 152 | +// timeoutExceeded returns true if the amount of time since start has exceeded the timeout. |
| 153 | +func timeoutExceeded(start time.Time, timeout time.Duration) bool { |
| 154 | + timeElapsed := time.Now().Sub(start) |
| 155 | + return timeElapsed > timeout |
| 156 | +} |
| 157 | + |
| 158 | +// NewECSClient returns a new AWS SDK client for interacting with AWS ECS. |
| 159 | +func NewECSClient(opts *Options) (*ecs.Client, error) { |
| 160 | + cfg, err := NewDefaultConfig(opts) |
| 161 | + if err != nil { |
| 162 | + return nil, errors.WithStackTrace(err) |
| 163 | + } |
| 164 | + return ecs.NewFromConfig(cfg), nil |
| 165 | +} |
| 166 | + |
| 167 | +func allInstancesFullyDrained(responses []*ecs.DescribeContainerInstancesOutput) (bool, error) { |
| 168 | + for _, response := range responses { |
| 169 | + instances := response.ContainerInstances |
| 170 | + if len(instances) == 0 { |
| 171 | + return false, errors.WithStackTrace(fmt.Errorf("querying DescribeContainerInstances returned no instances")) |
| 172 | + } |
| 173 | + |
| 174 | + for _, instance := range instances { |
| 175 | + if !instanceFullyDrained(instance) { |
| 176 | + return false, nil |
| 177 | + } |
| 178 | + } |
| 179 | + } |
| 180 | + return true, nil |
| 181 | +} |
| 182 | + |
| 183 | +func instanceFullyDrained(instance ecsTypes.ContainerInstance) bool { |
| 184 | + logger := logging.GetProjectLogger() |
| 185 | + instanceArn := instance.ContainerInstanceArn |
| 186 | + |
| 187 | + if *instance.Status == "ACTIVE" { |
| 188 | + logger.Infof("The ECS Container Instance %s is still in ACTIVE status", *instanceArn) |
| 189 | + return false |
| 190 | + } |
| 191 | + if instance.PendingTasksCount > 0 { |
| 192 | + logger.Infof("The ECS Container Instance %s still has pending tasks", *instanceArn) |
| 193 | + return false |
| 194 | + } |
| 195 | + if instance.RunningTasksCount > 0 { |
| 196 | + logger.Infof("The ECS Container Instance %s still has running tasks", *instanceArn) |
| 197 | + return false |
| 198 | + } |
| 199 | + |
| 200 | + return true |
| 201 | +} |
0 commit comments