Skip to content

Commit 4829302

Browse files
committed
Change the rollback goroutine wait pattern to use a sync.WaitGroup and a stop channel
1 parent 9da7f20 commit 4829302

File tree

1 file changed

+68
-75
lines changed

1 file changed

+68
-75
lines changed

internal/service/ecs/service.go

Lines changed: 68 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"slices"
1313
"strconv"
1414
"strings"
15+
"sync"
1516
"time"
1617

1718
"github.com/YakDriver/regexache"
@@ -2113,7 +2114,11 @@ func statusService(ctx context.Context, conn *ecs.Client, serviceName, clusterNa
21132114
}
21142115
}
21152116

2116-
func statusServiceWaitForStable(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, primaryDeploymentArn **string, operationTime time.Time, primaryTaskSet **awstypes.Deployment, isNewECSDeployment *bool) retry.StateRefreshFunc {
2117+
func statusServiceWaitForStable(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, sigintConfig *rollbackState, operationTime time.Time) retry.StateRefreshFunc {
2118+
var primaryTaskSet *awstypes.Deployment
2119+
var primaryDeploymentArn *string
2120+
var isNewPrimaryDeployment bool
2121+
21172122
return func() (any, string, error) {
21182123
outputRaw, serviceStatus, err := statusService(ctx, conn, serviceName, clusterNameOrARN)()
21192124
if err != nil {
@@ -2126,36 +2131,41 @@ func statusServiceWaitForStable(ctx context.Context, conn *ecs.Client, serviceNa
21262131

21272132
output := outputRaw.(*awstypes.Service)
21282133

2129-
if *primaryTaskSet == nil {
2130-
*primaryTaskSet = findPrimaryTaskSet(output.Deployments)
2131-
2132-
var isNewPrimaryDeployment bool
2134+
if primaryTaskSet == nil {
2135+
primaryTaskSet = findPrimaryTaskSet(output.Deployments)
21332136

2134-
if *primaryTaskSet != nil && (*primaryTaskSet).CreatedAt != nil {
2137+
if primaryTaskSet != nil && (*primaryTaskSet).CreatedAt != nil {
21352138
createdAtUTC := (*primaryTaskSet).CreatedAt.UTC()
21362139
isNewPrimaryDeployment = createdAtUTC.After(operationTime)
21372140
}
2138-
*isNewECSDeployment = output.DeploymentController != nil &&
2139-
output.DeploymentController.Type == awstypes.DeploymentControllerTypeEcs &&
2140-
isNewPrimaryDeployment
21412141
}
21422142

2143+
isNewECSDeployment := output.DeploymentController != nil &&
2144+
output.DeploymentController.Type == awstypes.DeploymentControllerTypeEcs &&
2145+
isNewPrimaryDeployment
2146+
21432147
// For new deployments with ECS deployment controller, check the deployment status
2144-
if *isNewECSDeployment {
2145-
if *primaryDeploymentArn == nil {
2148+
if isNewECSDeployment {
2149+
if primaryDeploymentArn == nil {
21462150
serviceArn := aws.ToString(output.ServiceArn)
21472151

21482152
var err error
2149-
*primaryDeploymentArn, err = findPrimaryDeploymentARN(ctx, conn, *primaryTaskSet, serviceArn, clusterNameOrARN, operationTime)
2153+
primaryDeploymentArn, err = findPrimaryDeploymentARN(ctx, conn, primaryTaskSet, serviceArn, clusterNameOrARN, operationTime)
21502154
if err != nil {
21512155
return nil, "", err
21522156
}
2153-
if *primaryDeploymentArn == nil {
2157+
if primaryDeploymentArn == nil {
21542158
return output, serviceStatusPending, nil
21552159
}
21562160
}
21572161

2158-
deploymentStatus, err := findDeploymentStatus(ctx, conn, **primaryDeploymentArn)
2162+
if sigintConfig.rollbackRequested && !sigintConfig.rollbackRoutineStarted {
2163+
sigintConfig.waitGroup.Add(1)
2164+
go rollbackRoutine(ctx, conn, sigintConfig, primaryDeploymentArn)
2165+
sigintConfig.rollbackRoutineStarted = true
2166+
}
2167+
2168+
deploymentStatus, err := findDeploymentStatus(ctx, conn, *primaryDeploymentArn)
21592169
if err != nil {
21602170
return nil, "", err
21612171
}
@@ -2246,19 +2256,34 @@ func findDeploymentStatus(ctx context.Context, conn *ecs.Client, deploymentArn s
22462256
}
22472257
}
22482258

2249-
func waitForCancellation(ctx context.Context, conn *ecs.Client, clusterName, serviceName string, primaryDeploymentArn **string, operationTime time.Time) {
2250-
<-ctx.Done()
2251-
log.Printf("[INFO] Detected cancellation. Initiating rollback for deployment.")
2252-
newCtx := context.Background()
2253-
err := rollbackBlueGreenDeployment(newCtx, conn, clusterName, serviceName, *primaryDeploymentArn, operationTime)
2254-
if err != nil {
2255-
log.Printf("[ERROR] Failed to rollback deployment: %s", err)
2256-
} else {
2257-
log.Printf("[INFO] Blue/green deployment cancelled and rolled back successfully.")
2259+
type rollbackState struct {
2260+
rollbackRequested bool
2261+
rollbackRoutineStarted bool
2262+
rollbackRoutineStopped chan struct{}
2263+
waitGroup sync.WaitGroup
2264+
}
2265+
2266+
func rollbackRoutine(ctx context.Context, conn *ecs.Client, rollbackState *rollbackState, primaryDeploymentArn *string) {
2267+
defer rollbackState.waitGroup.Done()
2268+
2269+
select {
2270+
case <-ctx.Done():
2271+
log.Printf("[INFO] SIGINT detected. Initiating rollback for deployment: %s", *primaryDeploymentArn)
2272+
cancelContext, cancelFunc := context.WithTimeout(context.Background(), (1 * time.Hour)) // Maximum time before SIGKILL
2273+
defer cancelFunc()
2274+
2275+
if err := rollbackBlueGreenDeployment(cancelContext, conn, primaryDeploymentArn); err != nil {
2276+
log.Printf("[ERROR] Failed to rollback deployment: %s. Err: %s", *primaryDeploymentArn, err)
2277+
} else {
2278+
log.Printf("[INFO] Blue/green deployment: %s rolled back successfully.", *primaryDeploymentArn)
2279+
}
2280+
2281+
case <-rollbackState.rollbackRoutineStopped:
2282+
return
22582283
}
22592284
}
22602285

2261-
func rollbackBlueGreenDeployment(ctx context.Context, conn *ecs.Client, clusterName, serviceName string, primaryDeploymentArn *string, operationTime time.Time) error {
2286+
func rollbackBlueGreenDeployment(ctx context.Context, conn *ecs.Client, primaryDeploymentArn *string) error {
22622287
// Check if deployment is already in terminal state, meaning rollback is not needed
22632288
deploymentStatus, err := findDeploymentStatus(ctx, conn, *primaryDeploymentArn)
22642289
if err != nil {
@@ -2280,21 +2305,20 @@ func rollbackBlueGreenDeployment(ctx context.Context, conn *ecs.Client, clusterN
22802305
return err
22812306
}
22822307

2283-
err = waitForDeploymentTerminalStatus(ctx, conn, *primaryDeploymentArn)
2284-
if err != nil {
2285-
return err
2286-
}
2287-
2288-
return nil
2308+
return waitForDeploymentTerminalStatus(ctx, conn, *primaryDeploymentArn)
22892309
}
22902310

22912311
func waitForDeploymentTerminalStatus(ctx context.Context, conn *ecs.Client, primaryDeploymentArn string) error {
22922312
stateConf := &retry.StateChangeConf{
2293-
Pending: []string{string(awstypes.ServiceDeploymentStatusInProgress), string(awstypes.ServiceDeploymentStatusPending)},
2294-
Target: deploymentTerminalStates,
2295-
Refresh: func() (interface{}, string, error) {
2313+
Pending: []string{
2314+
string(awstypes.ServiceDeploymentStatusPending),
2315+
string(awstypes.ServiceDeploymentStatusInProgress),
2316+
string(awstypes.ServiceDeploymentStatusRollbackRequested),
2317+
string(awstypes.ServiceDeploymentStatusRollbackInProgress),
2318+
},
2319+
Target: deploymentTerminalStates,
2320+
Refresh: func() (any, string, error) {
22962321
status, err := findDeploymentStatus(ctx, conn, primaryDeploymentArn)
2297-
22982322
return nil, status, err
22992323
},
23002324
Timeout: 1 * time.Hour, // Maximum time before SIGKILL
@@ -2306,27 +2330,26 @@ func waitForDeploymentTerminalStatus(ctx context.Context, conn *ecs.Client, prim
23062330

23072331
// waitServiceStable waits for an ECS Service to reach the status "ACTIVE" and have all desired tasks running.
23082332
// Does not return tags.
2309-
func waitServiceStable(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, operationTime time.Time, sigintCancellation bool, timeout time.Duration) (*awstypes.Service, error) {
2310-
deployment := &deploymentState{
2311-
primaryDeploymentArn: new(*string),
2333+
func waitServiceStable(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, operationTime time.Time, sigintCancellation bool, timeout time.Duration) (*awstypes.Service, error) { //nolint:unparam
2334+
sigintConfig := &rollbackState{
2335+
rollbackRequested: sigintCancellation,
2336+
rollbackRoutineStarted: false,
2337+
rollbackRoutineStopped: make(chan struct{}),
2338+
waitGroup: sync.WaitGroup{},
23122339
}
23132340

2314-
cancellation := &cancellationState{}
2315-
23162341
stateConf := &retry.StateChangeConf{
23172342
Pending: []string{serviceStatusInactive, serviceStatusDraining, serviceStatusPending},
23182343
Target: []string{serviceStatusStable},
2319-
Refresh: func() (any, string, error) {
2320-
return refreshStatusAndHandleCancellation(ctx, conn, serviceName, clusterNameOrARN, operationTime, sigintCancellation, deployment, cancellation)
2321-
},
2344+
Refresh: statusServiceWaitForStable(ctx, conn, serviceName, clusterNameOrARN, sigintConfig, operationTime),
23222345
Timeout: timeout,
23232346
}
23242347

23252348
outputRaw, err := stateConf.WaitForStateContext(ctx)
23262349

2327-
if cancellation.goroutineStarted {
2328-
cancellation.cancelFunc()
2329-
<-cancellation.done
2350+
if sigintConfig.rollbackRoutineStarted {
2351+
close(sigintConfig.rollbackRoutineStopped)
2352+
sigintConfig.waitGroup.Wait()
23302353
}
23312354

23322355
if output, ok := outputRaw.(*awstypes.Service); ok {
@@ -2336,36 +2359,6 @@ func waitServiceStable(ctx context.Context, conn *ecs.Client, serviceName, clust
23362359
return nil, err
23372360
}
23382361

2339-
type deploymentState struct {
2340-
primaryDeploymentArn **string
2341-
primaryTaskSet *awstypes.Deployment
2342-
isNewECSDeployment bool
2343-
}
2344-
2345-
type cancellationState struct {
2346-
cancelFunc context.CancelFunc
2347-
cancelCtx context.Context
2348-
done chan struct{}
2349-
goroutineStarted bool
2350-
}
2351-
2352-
func refreshStatusAndHandleCancellation(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, operationTime time.Time, sigintCancellation bool, deployment *deploymentState, cancellation *cancellationState) (any, string, error) {
2353-
result, status, err := statusServiceWaitForStable(ctx, conn, serviceName, clusterNameOrARN, deployment.primaryDeploymentArn, operationTime, &deployment.primaryTaskSet, &deployment.isNewECSDeployment)()
2354-
2355-
// Create context and start goroutine only when needed
2356-
if sigintCancellation && !cancellation.goroutineStarted && *deployment.primaryDeploymentArn != nil {
2357-
cancellation.cancelCtx, cancellation.cancelFunc = context.WithCancel(ctx)
2358-
cancellation.done = make(chan struct{})
2359-
go func() {
2360-
defer close(cancellation.done)
2361-
waitForCancellation(cancellation.cancelCtx, conn, clusterNameOrARN, serviceName, deployment.primaryDeploymentArn, operationTime)
2362-
}()
2363-
cancellation.goroutineStarted = true
2364-
}
2365-
2366-
return result, status, err
2367-
}
2368-
23692362
// Does not return tags.
23702363
func waitServiceActive(ctx context.Context, conn *ecs.Client, serviceName, clusterNameOrARN string, timeout time.Duration) (*awstypes.Service, error) { //nolint:unparam
23712364
stateConf := &retry.StateChangeConf{

0 commit comments

Comments
 (0)