diff --git a/controllers/aga/globalaccelerator_controller.go b/controllers/aga/globalaccelerator_controller.go index 0f4ed7f1c..07277cf84 100644 --- a/controllers/aga/globalaccelerator_controller.go +++ b/controllers/aga/globalaccelerator_controller.go @@ -64,9 +64,15 @@ const ( agaResourcesGroupVersion = "aga.k8s.aws/v1beta1" globalAcceleratorKind = "GlobalAccelerator" - // Requeue constants for provisioning state monitoring - requeueMessage = "Monitoring provisioning state" - statusUpdateRequeueTime = 1 * time.Minute + // Requeue constants for state monitoring + // requeueReasonAcceleratorInProgress indicates that the reconciliation is being requeued because + // the Global Accelerator is still in progress state + requeueReasonAcceleratorInProgress = "Waiting for Global Accelerator %s with status 'IN_PROGRESS' to complete" + + // requeueReasonEndpointsInWarningState indicates that the reconciliation is being requeued because + // there are endpoints in warning state that need to be periodically rechecked + requeueReasonEndpointsInWarningState = "Retrying endpoints for Global Accelerator %s which did load successfully - will check availability again soon" + statusUpdateRequeueTime = 1 * time.Minute // Metric stage constants MetricStageFetchGlobalAccelerator = "fetch_globalAccelerator" @@ -251,8 +257,8 @@ func (r *globalAcceleratorReconciler) cleanupGlobalAccelerator(ctx context.Conte return nil } -func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi.GlobalAccelerator) (core.Stack, *agamodel.Accelerator, error) { - stack, accelerator, err := r.modelBuilder.Build(ctx, ga) +func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi.GlobalAccelerator, loadedEndpoints []*aga.LoadedEndpoint) (core.Stack, *agamodel.Accelerator, error) { + stack, accelerator, err := r.modelBuilder.Build(ctx, ga, loadedEndpoints) if err != nil { r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err)) return nil, nil, err @@ -279,7 +285,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co r.endpointResourcesManager.MonitorEndpointResources(ga, endpoints) // Validate and load endpoint status using the endpoint loader - _, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints) + loadedEndpoints, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints) if len(fatalErrors) > 0 { err := fmt.Errorf("failed to load endpoints: %v", fatalErrors[0]) r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedEndpointLoad, fmt.Sprintf("Failed to reconcile due to %v", err)) @@ -295,7 +301,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co var accelerator *agamodel.Accelerator var err error buildModelFn := func() { - stack, accelerator, err = r.buildModel(ctx, ga) + stack, accelerator, err = r.buildModel(ctx, ga, loadedEndpoints) } r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageBuildModel, buildModelFn) if err != nil { @@ -326,14 +332,37 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co r.logger.Info("Successfully deployed GlobalAccelerator stack", "stackID", stack.StackID()) - // Update GlobalAccelerator status after successful deployment + // Check if any endpoints have warning status and collect them + hasWarningEndpoints := false + for _, ep := range loadedEndpoints { + if ep.Status == aga.EndpointStatusWarning { + hasWarningEndpoints = true + } + } + + // Update GlobalAccelerator status after successful deployment, including warning endpoints requeueNeeded, err := r.statusUpdater.UpdateStatusSuccess(ctx, ga, accelerator) if err != nil { r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedUpdateStatus, fmt.Sprintf("Failed update status due to %v", err)) return err } - if requeueNeeded { - return ctrlerrors.NewRequeueNeededAfter(requeueMessage, statusUpdateRequeueTime) + + // If we have warning endpoints, add a separate condition for them and requeue + if hasWarningEndpoints { + r.logger.V(1).Info("Detected endpoints in warning state, will requeue", + "Global Accelerator", k8s.NamespacedName(ga)) + + // Add event to notify about warning endpoints + warningMessage := fmt.Sprintf("Detected endpoints which did not load successfully. These endpoints will be rechecked shortly.") + r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonWarningEndpoints, warningMessage) + } + + if requeueNeeded || hasWarningEndpoints { + message := fmt.Sprintf(requeueReasonAcceleratorInProgress, k8s.NamespacedName(ga)) + if hasWarningEndpoints { + message = fmt.Sprintf(requeueReasonEndpointsInWarningState, k8s.NamespacedName(ga)) + } + return ctrlerrors.NewRequeueNeededAfter(message, statusUpdateRequeueTime) } r.eventRecorder.Event(ga, corev1.EventTypeNormal, k8s.GlobalAcceleratorEventReasonSuccessfullyReconciled, "Successfully reconciled") @@ -379,7 +408,7 @@ func (r *globalAcceleratorReconciler) cleanupGlobalAcceleratorResources(ctx cont if updateErr := r.statusUpdater.UpdateStatusDeletion(ctx, ga); updateErr != nil { r.logger.Error(updateErr, "Failed to update status during accelerator deletion") } - return ctrlerrors.NewRequeueNeeded("Waiting for accelerator to be disabled") + return ctrlerrors.NewRequeueNeeded(fmt.Sprintf(requeueReasonAcceleratorInProgress, k8s.NamespacedName(ga))) } // Any other error diff --git a/pkg/aga/model_build_endpoint_group.go b/pkg/aga/model_build_endpoint_group.go index 90d987b86..63ee9f03e 100644 --- a/pkg/aga/model_build_endpoint_group.go +++ b/pkg/aga/model_build_endpoint_group.go @@ -4,6 +4,7 @@ import ( "context" "fmt" awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/go-logr/logr" agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" @@ -12,16 +13,21 @@ import ( // endpointGroupBuilder builds EndpointGroup model resources type endpointGroupBuilder interface { // Build builds all endpoint groups for all listeners - Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) + Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, + listenerConfigs []agaapi.GlobalAcceleratorListener, loadedEndpoints []*LoadedEndpoint) ([]*agamodel.EndpointGroup, error) // buildEndpointGroupsForListener builds endpoint groups for a specific listener - buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) + buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, + endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int, + loadedEndpoints []*LoadedEndpoint) ([]*agamodel.EndpointGroup, error) } // NewEndpointGroupBuilder constructs new endpointGroupBuilder -func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder { +func NewEndpointGroupBuilder(clusterRegion string, gaNamespace string, logger logr.Logger) endpointGroupBuilder { return &defaultEndpointGroupBuilder{ clusterRegion: clusterRegion, + gaNamespace: gaNamespace, + logger: logger, } } @@ -29,10 +35,13 @@ var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{} type defaultEndpointGroupBuilder struct { clusterRegion string + gaNamespace string + logger logr.Logger } // Build builds EndpointGroup model resources -func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) { +func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, + listenerConfigs []agaapi.GlobalAcceleratorListener, loadedEndpoints []*LoadedEndpoint) ([]*agamodel.EndpointGroup, error) { if listeners == nil || len(listeners) == 0 { return nil, nil } @@ -51,7 +60,7 @@ func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stac continue } - listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i) + listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i, loadedEndpoints) if err != nil { return nil, err } @@ -114,11 +123,13 @@ func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListene } // buildEndpointGroupsForListener builds EndpointGroup models for a specific listener -func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) { +func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, + listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, + listenerIndex int, loadedEndpoints []*LoadedEndpoint) ([]*agamodel.EndpointGroup, error) { var result []*agamodel.EndpointGroup for i, endpointGroup := range endpointGroups { - spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup) + spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup, loadedEndpoints) if err != nil { return nil, err } @@ -132,7 +143,9 @@ func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context } // buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource -func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) { +func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, + listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup, + loadedEndpoints []*LoadedEndpoint) (agamodel.EndpointGroupSpec, error) { region, err := b.determineRegion(endpointGroup) if err != nil { return agamodel.EndpointGroupSpec{}, err @@ -146,14 +159,90 @@ func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context return agamodel.EndpointGroupSpec{}, err } + // Build endpoint configurations from both static configurations and loaded endpoints + endpointConfigurations, err := b.buildEndpointConfigurations(ctx, endpointGroup, loadedEndpoints) + if err != nil { + return agamodel.EndpointGroupSpec{}, err + } + return agamodel.EndpointGroupSpec{ - ListenerARN: listener.ListenerARN(), - Region: region, - TrafficDialPercentage: trafficDialPercentage, - PortOverrides: portOverrides, + ListenerARN: listener.ListenerARN(), + Region: region, + TrafficDialPercentage: trafficDialPercentage, + PortOverrides: portOverrides, + EndpointConfigurations: endpointConfigurations, }, nil } +// generateEndpointKey creates a consistent string key for endpoint lookup +func generateEndpointKey(ep agaapi.GlobalAcceleratorEndpoint, gaNamespace string) string { + namespace := gaNamespace + if ep.Namespace != nil { + namespace = awssdk.ToString(ep.Namespace) + } + name := awssdk.ToString(ep.Name) + + if ep.Type == agaapi.GlobalAcceleratorEndpointTypeEndpointID { + return fmt.Sprintf("%s/%s", ep.Type, awssdk.ToString(ep.EndpointID)) + } + return fmt.Sprintf("%s/%s/%s", ep.Type, namespace, name) +} + +// buildEndpointConfigurations builds endpoint configurations from both static configurations in the API struct +// and from successfully loaded endpoints +func (b *defaultEndpointGroupBuilder) buildEndpointConfigurations(_ context.Context, + endpointGroup agaapi.GlobalAcceleratorEndpointGroup, loadedEndpoints []*LoadedEndpoint) ([]agamodel.EndpointConfiguration, error) { + + var endpointConfigurations []agamodel.EndpointConfiguration + + // Skip if no endpoints defined in the endpoint group + if endpointGroup.Endpoints == nil { + return nil, nil + } + + // Build a map of loaded endpoints with for quick lookup + loadedEndpointsMap := make(map[string]*LoadedEndpoint) + for _, le := range loadedEndpoints { + key := le.GetKey() + loadedEndpointsMap[key] = le + + } + + // Process the endpoints defined in the CRD and match with loaded endpoints + for _, ep := range *endpointGroup.Endpoints { + // Create key for lookup using the helper function + lookupKey := generateEndpointKey(ep, b.gaNamespace) + + // Find the loaded endpoint + if loadedEndpoint, found := loadedEndpointsMap[lookupKey]; found { + // Add endpoint to model stack only if its in Loaded status and has valid ARN + if loadedEndpoint.Status == EndpointStatusLoaded { + // Create a base configuration with the loaded endpoint's ARN + endpointConfig := agamodel.EndpointConfiguration{ + EndpointID: loadedEndpoint.ARN, + } + endpointConfig.Weight = awssdk.Int32(loadedEndpoint.Weight) + endpointConfig.ClientIPPreservationEnabled = ep.ClientIPPreservationEnabled + endpointConfigurations = append(endpointConfigurations, endpointConfig) + } else { + // Log warning for endpoints which are not loaded successfully during loading and has Warning status + b.logger.Info("Endpoint not added to endpoint group as no valid ARN was found during loading", + "endpoint", lookupKey, + "message", loadedEndpoint.Message, + "error", loadedEndpoint.Error) + } + } else { + b.logger.Info("Endpoint not found in loaded endpoints", + "endpoint", lookupKey) + } + } + + return endpointConfigurations, nil +} + +// Note: The TargetsEndpointGroup method is no longer needed since we match endpoints based on +// the explicit references in the GlobalAcceleratorEndpoint resources under each endpoint group + // validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are // contained within the listener's port ranges func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { @@ -248,3 +337,23 @@ func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.L return nil } + +// buildEndpointConfiguration creates an EndpointConfiguration from a GlobalAcceleratorEndpoint +// This helper function consolidates the repeated code for creating endpoint configurations +func buildEndpointConfigurationFromEndpoint(endpoint *agaapi.GlobalAcceleratorEndpoint) agamodel.EndpointConfiguration { + endpointConfig := agamodel.EndpointConfiguration{ + EndpointID: awssdk.ToString(endpoint.EndpointID), + } + + // Add weight if specified + if endpoint.Weight != nil { + endpointConfig.Weight = endpoint.Weight + } + + // Add client IP preservation setting if specified + if endpoint.ClientIPPreservationEnabled != nil { + endpointConfig.ClientIPPreservationEnabled = endpoint.ClientIPPreservationEnabled + } + + return endpointConfig +} diff --git a/pkg/aga/model_build_endpoint_group_test.go b/pkg/aga/model_build_endpoint_group_test.go index 5e51e389b..4905f7854 100644 --- a/pkg/aga/model_build_endpoint_group_test.go +++ b/pkg/aga/model_build_endpoint_group_test.go @@ -2,15 +2,337 @@ package aga import ( "context" + "fmt" awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" "testing" - - "github.com/stretchr/testify/assert" - agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) +func Test_generateEndpointKey(t *testing.T) { + tests := []struct { + name string + endpoint agaapi.GlobalAcceleratorEndpoint + gaNamespace string + want string + }{ + { + name: "endpoint with EndpointID type", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-alb/1234567890"), + }, + gaNamespace: "default", + want: "EndpointID/arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-alb/1234567890", + }, + { + name: "endpoint with Service type and explicit namespace", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Namespace: awssdk.String("test-namespace"), + Name: awssdk.String("test-service"), + }, + gaNamespace: "default", + want: "Service/test-namespace/test-service", + }, + { + name: "endpoint with Service type and default namespace", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + }, + gaNamespace: "default", + want: "Service/default/test-service", + }, + { + name: "endpoint with Ingress type", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Namespace: awssdk.String("ingress-ns"), + Name: awssdk.String("test-ingress"), + }, + gaNamespace: "default", + want: "Ingress/ingress-ns/test-ingress", + }, + { + name: "endpoint with Gateway type", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Namespace: awssdk.String("gateway-ns"), + Name: awssdk.String("test-gateway"), + }, + gaNamespace: "default", + want: "Gateway/gateway-ns/test-gateway", + }, + { + name: "endpoint with nil name (should still work)", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Namespace: awssdk.String("test-namespace"), + Name: nil, + }, + gaNamespace: "default", + want: "Service/test-namespace/", + }, + { + name: "endpoint with both nil namespace and name", + endpoint: agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Namespace: nil, + Name: nil, + }, + gaNamespace: "default", + want: "Service/default/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generateEndpointKey(tt.endpoint, tt.gaNamespace) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultEndpointGroupBuilder_buildEndpointConfigurations(t *testing.T) { + testLogger := logr.Discard() + + // Create test LoadedEndpoints + createTestEndpoints := func() []*LoadedEndpoint { + return []*LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890", + DNSName: "test-service.default.svc.cluster.local", + Status: EndpointStatusLoaded, + EndpointRef: &agaapi.GlobalAcceleratorEndpoint{Type: agaapi.GlobalAcceleratorEndpointTypeService, Name: awssdk.String("test-service")}, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress", + Namespace: "ingress-ns", + Weight: 200, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ingress/0987654321", + DNSName: "test-ingress.example.com", + Status: EndpointStatusLoaded, + EndpointRef: &agaapi.GlobalAcceleratorEndpoint{Type: agaapi.GlobalAcceleratorEndpointTypeIngress, Name: awssdk.String("test-ingress"), Namespace: awssdk.String("ingress-ns")}, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: "test-gateway", + Namespace: "gateway-ns", + Weight: 150, + ARN: "", + DNSName: "", + Status: EndpointStatusWarning, + Error: fmt.Errorf("gateway not found"), + Message: "Gateway resource not found", + EndpointRef: &agaapi.GlobalAcceleratorEndpoint{Type: agaapi.GlobalAcceleratorEndpointTypeGateway, Name: awssdk.String("test-gateway"), Namespace: awssdk.String("gateway-ns")}, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + Name: "", + Namespace: "", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-lb/abcdef1234", + Status: EndpointStatusLoaded, + EndpointRef: &agaapi.GlobalAcceleratorEndpoint{Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, EndpointID: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-lb/abcdef1234")}, + }, + } + } + + tests := []struct { + name string + endpointGroup agaapi.GlobalAcceleratorEndpointGroup + loadedEndpoints []*LoadedEndpoint + want []agamodel.EndpointConfiguration + wantErr bool + }{ + { + name: "nil endpoints in endpoint group", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: nil, + }, + loadedEndpoints: createTestEndpoints(), + want: nil, + wantErr: false, + }, + { + name: "empty endpoints array in endpoint group", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{}, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{}, + wantErr: false, + }, + { + name: "endpoint with EndpointID reference", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-lb/abcdef1234"), + ClientIPPreservationEnabled: awssdk.Bool(true), + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{ + { + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-lb/abcdef1234", + Weight: awssdk.Int32(100), + ClientIPPreservationEnabled: awssdk.Bool(true), + }, + }, + wantErr: false, + }, + { + name: "endpoint with Service reference", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + ClientIPPreservationEnabled: awssdk.Bool(false), + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{ + { + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890", + Weight: awssdk.Int32(100), + ClientIPPreservationEnabled: awssdk.Bool(false), + }, + }, + wantErr: false, + }, + { + name: "endpoint with Ingress reference, no override weight", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Namespace: awssdk.String("ingress-ns"), + Name: awssdk.String("test-ingress"), + ClientIPPreservationEnabled: awssdk.Bool(true), + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{ + { + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ingress/0987654321", + Weight: awssdk.Int32(200), // From the loaded endpoint, no override + ClientIPPreservationEnabled: awssdk.Bool(true), // From the endpoint definition + }, + }, + wantErr: false, + }, + { + name: "endpoint with warning status - not included", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Namespace: awssdk.String("gateway-ns"), + Name: awssdk.String("test-gateway"), + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{}, // No endpoints should be added + wantErr: false, + }, + { + name: "endpoint not found in loaded endpoints", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Namespace: awssdk.String("non-existent"), + Name: awssdk.String("non-existent-service"), + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{}, // No endpoints should be added + wantErr: false, + }, + { + name: "multiple endpoints with mixed types", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Namespace: awssdk.String("ingress-ns"), + Name: awssdk.String("test-ingress"), + ClientIPPreservationEnabled: awssdk.Bool(true), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Namespace: awssdk.String("gateway-ns"), + Name: awssdk.String("test-gateway"), // Has warning status, should be skipped + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("non-existent-service"), // Not in loaded endpoints, should be skipped + }, + }, + }, + loadedEndpoints: createTestEndpoints(), + want: []agamodel.EndpointConfiguration{ + { + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890", + Weight: awssdk.Int32(100), + ClientIPPreservationEnabled: nil, + }, + { + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ingress/0987654321", + Weight: awssdk.Int32(200), // From the loaded endpoint + ClientIPPreservationEnabled: awssdk.Bool(true), + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + gaNamespace: "default", + logger: testLogger, + } + + // Call buildEndpointConfigurations + got, err := builder.buildEndpointConfigurations(ctx, tt.endpointGroup, tt.loadedEndpoints) + + // Check for expected error + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.ElementsMatch(t, tt.want, got) // Use ElementsMatch to ignore order + } + }) + } +} + func Test_defaultEndpointGroupBuilder_determineRegion(t *testing.T) { tests := []struct { name string diff --git a/pkg/aga/model_builder.go b/pkg/aga/model_builder.go index c3e1caa5b..e9b492929 100644 --- a/pkg/aga/model_builder.go +++ b/pkg/aga/model_builder.go @@ -17,7 +17,7 @@ import ( // ModelBuilder is responsible for building model stack for a GlobalAccelerator. type ModelBuilder interface { // Build model stack for a GlobalAccelerator. - Build(ctx context.Context, ga *agaapi.GlobalAccelerator) (core.Stack, *agamodel.Accelerator, error) + Build(ctx context.Context, ga *agaapi.GlobalAccelerator, loadedEndpoints []*LoadedEndpoint) (core.Stack, *agamodel.Accelerator, error) } // NewDefaultModelBuilder constructs new defaultModelBuilder. @@ -56,15 +56,14 @@ type defaultModelBuilder struct { } // Build model stack for a GlobalAccelerator. -func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccelerator) (core.Stack, *agamodel.Accelerator, error) { +func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccelerator, loadedEndpoints []*LoadedEndpoint) (core.Stack, *agamodel.Accelerator, error) { stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(ga))) // Create fresh builder instances for each reconciliation acceleratorBuilder := NewAcceleratorBuilder(b.trackingProvider, b.clusterName, b.clusterRegion, b.defaultTags, b.externalManagedTags, b.featureGates.Enabled(config.EnableDefaultTagsLowPriority)) listenerBuilder := NewListenerBuilder() - endpointGroupBuilder := NewEndpointGroupBuilder(b.clusterRegion) - // TODO - // endpointBuilder := NewEndpointBuilder() + endpointGroupBuilder := NewEndpointGroupBuilder(b.clusterRegion, ga.Namespace, b.logger) + // Build Accelerator accelerator, err := acceleratorBuilder.Build(ctx, stack, ga) if err != nil { @@ -78,15 +77,13 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele if err != nil { return nil, nil, err } - endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, *ga.Spec.Listeners) + + // Build endpoint groups with loaded endpoints + _, err := endpointGroupBuilder.Build(ctx, stack, listeners, *ga.Spec.Listeners, loadedEndpoints) if err != nil { return nil, nil, err } - b.logger.V(1).Info("Listener and endpoint groups built", "listeners", listeners, "endpointGroups", endpointGroups) } - // TODO: Add endpoint builder - // endpoints, err := endpointBuilder.Build(ctx, stack, endpointGroups, ga.Spec.Listeners) - return stack, accelerator, nil } diff --git a/pkg/aws/services/globalaccelerator.go b/pkg/aws/services/globalaccelerator.go index 30e627307..be4736e85 100644 --- a/pkg/aws/services/globalaccelerator.go +++ b/pkg/aws/services/globalaccelerator.go @@ -64,6 +64,12 @@ type GlobalAccelerator interface { // ListTagsForResource lists tags for a resource. ListTagsForResourceWithContext(ctx context.Context, input *globalaccelerator.ListTagsForResourceInput) (*globalaccelerator.ListTagsForResourceOutput, error) + + // AddEndpoints adds endpoints to an endpoint group. + AddEndpointsWithContext(ctx context.Context, input *globalaccelerator.AddEndpointsInput) (*globalaccelerator.AddEndpointsOutput, error) + + // RemoveEndpoints removes endpoints from an endpoint group. + RemoveEndpointsWithContext(ctx context.Context, input *globalaccelerator.RemoveEndpointsInput) (*globalaccelerator.RemoveEndpointsOutput, error) } // NewGlobalAccelerator constructs new GlobalAccelerator implementation. @@ -256,3 +262,19 @@ func (c *defaultGlobalAccelerator) ListEndpointGroupsAsList(ctx context.Context, } return result, nil } + +func (c *defaultGlobalAccelerator) AddEndpointsWithContext(ctx context.Context, input *globalaccelerator.AddEndpointsInput) (*globalaccelerator.AddEndpointsOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "AddEndpoints") + if err != nil { + return nil, err + } + return client.AddEndpoints(ctx, input) +} + +func (c *defaultGlobalAccelerator) RemoveEndpointsWithContext(ctx context.Context, input *globalaccelerator.RemoveEndpointsInput) (*globalaccelerator.RemoveEndpointsOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "RemoveEndpoints") + if err != nil { + return nil, err + } + return client.RemoveEndpoints(ctx, input) +} diff --git a/pkg/aws/services/globalaccelerator_mocks.go b/pkg/aws/services/globalaccelerator_mocks.go index 0bad79b37..f267f5100 100644 --- a/pkg/aws/services/globalaccelerator_mocks.go +++ b/pkg/aws/services/globalaccelerator_mocks.go @@ -36,6 +36,21 @@ func (m *MockGlobalAccelerator) EXPECT() *MockGlobalAcceleratorMockRecorder { return m.recorder } +// AddEndpointsWithContext mocks base method. +func (m *MockGlobalAccelerator) AddEndpointsWithContext(arg0 context.Context, arg1 *globalaccelerator.AddEndpointsInput) (*globalaccelerator.AddEndpointsOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEndpointsWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.AddEndpointsOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEndpointsWithContext indicates an expected call of AddEndpointsWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) AddEndpointsWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEndpointsWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).AddEndpointsWithContext), arg0, arg1) +} + // CreateAcceleratorWithContext mocks base method. func (m *MockGlobalAccelerator) CreateAcceleratorWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateAcceleratorInput) (*globalaccelerator.CreateAcceleratorOutput, error) { m.ctrl.T.Helper() @@ -246,6 +261,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) ListTagsForResourceWithContext(arg0 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTagsForResourceWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListTagsForResourceWithContext), arg0, arg1) } +// RemoveEndpointsWithContext mocks base method. +func (m *MockGlobalAccelerator) RemoveEndpointsWithContext(arg0 context.Context, arg1 *globalaccelerator.RemoveEndpointsInput) (*globalaccelerator.RemoveEndpointsOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveEndpointsWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.RemoveEndpointsOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RemoveEndpointsWithContext indicates an expected call of RemoveEndpointsWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) RemoveEndpointsWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveEndpointsWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).RemoveEndpointsWithContext), arg0, arg1) +} + // TagResourceWithContext mocks base method. func (m *MockGlobalAccelerator) TagResourceWithContext(arg0 context.Context, arg1 *globalaccelerator.TagResourceInput) (*globalaccelerator.TagResourceOutput, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/endpoint_group_manager.go b/pkg/deploy/aga/endpoint_group_manager.go index 263056e48..c6d2529ab 100644 --- a/pkg/deploy/aga/endpoint_group_manager.go +++ b/pkg/deploy/aga/endpoint_group_manager.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/util/sets" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) @@ -22,6 +23,9 @@ type EndpointGroupManager interface { // Delete deletes an endpoint group. Delete(ctx context.Context, endpointGroupARN string) error + + // ManageEndpoints manages endpoints in an endpoint group based on the desired state. + ManageEndpoints(ctx context.Context, endpointGroupARN string, resEndpointConfigs []agamodel.EndpointConfiguration, sdkEndpoints []agatypes.EndpointDescription) error } // NewDefaultEndpointGroupManager constructs new defaultEndpointGroupManager. @@ -105,6 +109,16 @@ func (m *defaultEndpointGroupManager) Create(ctx context.Context, resEndpointGro "resourceID", resEndpointGroup.ID(), "endpointGroupARN", *endpointGroup.EndpointGroupArn) + // Manage endpoints for newly created endpoint group + // For new endpoint groups, there are no existing endpoints + var noEndpoints []agatypes.EndpointDescription + if err := m.ManageEndpoints(ctx, *endpointGroup.EndpointGroupArn, resEndpointGroup.Spec.EndpointConfigurations, noEndpoints); err != nil { + m.logger.Error(err, "Failed to manage endpoints for newly created endpoint group", + "endpointGroupARN", *endpointGroup.EndpointGroupArn, + "endpointCount", len(resEndpointGroup.Spec.EndpointConfigurations)) + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to manage endpoints for endpoint group %s: %w", *endpointGroup.EndpointGroupArn, err) + } + return agamodel.EndpointGroupStatus{ EndpointGroupARN: *endpointGroup.EndpointGroupArn, }, nil @@ -137,6 +151,15 @@ func (m *defaultEndpointGroupManager) Update(ctx context.Context, resEndpointGro "resourceID", resEndpointGroup.ID(), "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn) + // Even if the endpoint group itself doesn't need an update, we still need to check endpoints + if err := m.ManageEndpoints(ctx, *sdkEndpointGroup.EndpointGroupArn, resEndpointGroup.Spec.EndpointConfigurations, sdkEndpointGroup.EndpointDescriptions); err != nil { + m.logger.Error(err, "Failed to manage endpoints for endpoint group", + "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn, + "desiredEndpointCount", len(resEndpointGroup.Spec.EndpointConfigurations), + "currentEndpointCount", len(sdkEndpointGroup.EndpointDescriptions)) + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to manage endpoints for endpoint group %s: %w", *sdkEndpointGroup.EndpointGroupArn, err) + } + return agamodel.EndpointGroupStatus{ EndpointGroupARN: *sdkEndpointGroup.EndpointGroupArn, }, nil @@ -165,6 +188,15 @@ func (m *defaultEndpointGroupManager) Update(ctx context.Context, resEndpointGro "resourceID", resEndpointGroup.ID(), "endpointGroupARN", *updatedEndpointGroup.EndpointGroupArn) + // After updating the endpoint group, manage endpoints + if err := m.ManageEndpoints(ctx, *updatedEndpointGroup.EndpointGroupArn, resEndpointGroup.Spec.EndpointConfigurations, updatedEndpointGroup.EndpointDescriptions); err != nil { + m.logger.Error(err, "Failed to manage endpoints for updated endpoint group", + "endpointGroupARN", *updatedEndpointGroup.EndpointGroupArn, + "desiredEndpointCount", len(resEndpointGroup.Spec.EndpointConfigurations), + "currentEndpointCount", len(updatedEndpointGroup.EndpointDescriptions)) + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to manage endpoints for updated endpoint group %s: %w", *updatedEndpointGroup.EndpointGroupArn, err) + } + return agamodel.EndpointGroupStatus{ EndpointGroupARN: *updatedEndpointGroup.EndpointGroupArn, }, nil @@ -238,3 +270,299 @@ func (m *defaultEndpointGroupManager) arePortOverridesEqual(modelPortOverrides [ return true } + +// isEndpointConfigurationDrifted checks if the endpoint settings have drifted between desired and existing configuration +func (m *defaultEndpointGroupManager) isEndpointConfigurationDrifted( + desiredConfig agamodel.EndpointConfiguration, + existingEndpoint agatypes.EndpointDescription) bool { + + // Check weight drift + if (desiredConfig.Weight == nil) != (existingEndpoint.Weight == nil) { + return true + } else if desiredConfig.Weight != nil && awssdk.ToInt32(desiredConfig.Weight) != awssdk.ToInt32(existingEndpoint.Weight) { + return true + } + + // Check client IP preservation drift + if (desiredConfig.ClientIPPreservationEnabled == nil) != (existingEndpoint.ClientIPPreservationEnabled == nil) { + return true + } else if desiredConfig.ClientIPPreservationEnabled != nil && + awssdk.ToBool(desiredConfig.ClientIPPreservationEnabled) != awssdk.ToBool(existingEndpoint.ClientIPPreservationEnabled) { + return true + } + + return false +} + +// buildSDKEndpointConfiguration converts a model endpoint configuration to an AWS SDK endpoint configuration +func (m *defaultEndpointGroupManager) buildSDKEndpointConfiguration(config agamodel.EndpointConfiguration) agatypes.EndpointConfiguration { + endpointConfig := agatypes.EndpointConfiguration{ + EndpointId: awssdk.String(config.EndpointID), + } + + // Add weight if specified + if config.Weight != nil { + endpointConfig.Weight = config.Weight + } + + // Add client IP preservation if specified + if config.ClientIPPreservationEnabled != nil { + endpointConfig.ClientIPPreservationEnabled = config.ClientIPPreservationEnabled + } + + return endpointConfig +} + +// detectEndpointDrift compares existing endpoints with desired endpoint configurations +// It efficiently determines which endpoints need to be added, updated or removed using set operations. +// Returns: +// - configsToAdd: Endpoint configurations that need to be added (present in desired but not in existing) +// - configsToUpdate: Endpoint configurations present in both desired and existing +// - endpointsToRemove: Endpoint IDs that need to be removed (present in existing but not in desired) +// - isUpdateRequired: Returns true if any existing endpoint needs property updates (weight or clientIPPreservation) +// This flag is used to determine the optimal API call strategy +func (m *defaultEndpointGroupManager) detectEndpointDrift( + existingEndpoints []agatypes.EndpointDescription, + desiredConfigs []agamodel.EndpointConfiguration) (configsToAdd []agamodel.EndpointConfiguration, configsToUpdate []agamodel.EndpointConfiguration, endpointsToRemove []string, isUpdateRequired bool) { + + // Extract all endpoint IDs from existing endpoints + existingEndpointIDs := sets.NewString() + existingIDToEndpoint := make(map[string]agatypes.EndpointDescription) + for _, endpoint := range existingEndpoints { + if endpoint.EndpointId != nil { + id := awssdk.ToString(endpoint.EndpointId) + existingEndpointIDs.Insert(id) + existingIDToEndpoint[id] = endpoint + } + } + + // Extract all endpoint IDs from desired configs and create a lookup map + desiredEndpointIDs := sets.NewString() + idToConfig := make(map[string]agamodel.EndpointConfiguration) + for _, config := range desiredConfigs { + desiredEndpointIDs.Insert(config.EndpointID) + idToConfig[config.EndpointID] = config + } + + // Find endpoints to update (present in both desired and existing) + endpointsToUpdateIDs := desiredEndpointIDs.Intersection(existingEndpointIDs) + isUpdateRequired = false + for id := range endpointsToUpdateIDs { + resConfig, _ := idToConfig[id] + sdkConfig, _ := existingIDToEndpoint[id] + if m.isEndpointConfigurationDrifted(resConfig, sdkConfig) { + isUpdateRequired = true + } + configsToUpdate = append(configsToUpdate, resConfig) + + } + + // Find endpoints to add (in desired but not in existing) + endpointsToAddIDs := desiredEndpointIDs.Difference(existingEndpointIDs) + for id := range endpointsToAddIDs { + config, _ := idToConfig[id] + configsToAdd = append(configsToAdd, config) + } + + // Find endpoints to remove (in existing but not in desired) + endpointsToRemove = existingEndpointIDs.Difference(desiredEndpointIDs).List() + + return configsToAdd, configsToUpdate, endpointsToRemove, isUpdateRequired +} + +// ManageEndpoints manages endpoints in an endpoint group based on the desired state. +// It implements drift detection by comparing existing endpoints with desired ones, +// then performs necessary additions, updates, and removals to reconcile the state. +// +// This implementation optimizes API usage based on the type of changes needed: +// 1. For updates to existing endpoints: Uses UpdateEndpointGroup API which can handle both +// new and updated endpoints in a single call (since AddEndpoints API doesn't support updates) +// 2. For simple additions/removals: Uses more efficient AddEndpoints and RemoveEndpoints APIs +// +// Following AWS Global Accelerator best practices, this implementation: +// 1. Adds endpoints first, then removes later to minimize connection disruption +// 2. Handles LimitExceededException by implementing a flip-flop Delete-Create pattern +// where some existing endpoints are removed first to make room for new additions +func (m *defaultEndpointGroupManager) ManageEndpoints( + ctx context.Context, + endpointGroupARN string, + resEndpointConfigs []agamodel.EndpointConfiguration, + sdkEndpoints []agatypes.EndpointDescription) error { + + // Early return if there are no endpoints to manage + if len(resEndpointConfigs) == 0 && len(sdkEndpoints) == 0 { + m.logger.V(1).Info("No endpoint configurations found for endpoint group", "endpointGroupARN", endpointGroupARN) + return nil + } + + // Determine drift (endpoints to add/update/remove) + configsToAdd, configsToUpdate, endpointsToRemove, isUpdateRequired := m.detectEndpointDrift(sdkEndpoints, resEndpointConfigs) + + if len(configsToAdd) == 0 && len(endpointsToRemove) == 0 && !isUpdateRequired { + m.logger.V(1).Info("No drift found for endpoint group", "endpointGroupARN", endpointGroupARN) + return nil + } + + m.logger.V(1).Info("Managing endpoints for endpoint group", + "endpointGroupARN", endpointGroupARN, + "addCount", len(configsToAdd), + "updateCount", len(configsToUpdate), + "removeCount", len(endpointsToRemove), + "updateRequired", isUpdateRequired) + + // add-endpoints API doesn't support updating existing endpoints so we need to use update-endpoint-groups API for updates and add + if isUpdateRequired { + // Use UpdateEndpointGroup API to handle both adds and updates + updatedConfigs := append(configsToAdd, configsToUpdate...) + + endpointConfigs := make([]agatypes.EndpointConfiguration, 0, len(updatedConfigs)) + for _, config := range updatedConfigs { + endpointConfigs = append(endpointConfigs, m.buildSDKEndpointConfiguration(config)) + } + + // Call UpdateEndpointGroup with all configs + updateInput := &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(endpointGroupARN), + EndpointConfigurations: endpointConfigs, + } + + if _, err := m.gaService.UpdateEndpointGroupWithContext(ctx, updateInput); err != nil { + return fmt.Errorf("failed to update endpoint group %s: %w", endpointGroupARN, err) + } + return nil + } + + // This is pure add and remove case. So we can use faster and efficient APIs + // Try adding endpoints first - this follows AWS best practice to minimize connection disruption + if len(configsToAdd) > 0 { + err := m.addEndpoints(ctx, endpointGroupARN, configsToAdd) + // If we hit a limit exception, we need to use flip-flop Delete-Create pattern + var apiErr *agatypes.LimitExceededException + if errors.As(err, &apiErr) { + m.logger.V(1).Info("Hit endpoint limit, will remove some endpoints first and retry additions", + "endpointGroupARN", endpointGroupARN) + // Only proceed with flip-flop if we have endpoints to remove + if len(endpointsToRemove) > 0 { + if err := m.flipFlopEndpoints(ctx, endpointGroupARN, configsToAdd, endpointsToRemove); err != nil { + return err + } + // All endpoints processed with flip-flop, so we're done + return nil + } + // If no endpoints to remove but hit limit, just return the original error + return fmt.Errorf("failed to add endpoints due to limit and no endpoints available to remove for endpoint group %s: %w", endpointGroupARN, err) + } else if err != nil { + // For any other error, return it directly + return fmt.Errorf("failed to add endpoints for endpoint group %s: %w", endpointGroupARN, err) + } + } + + // Now remove endpoints that are no longer needed + // (We do this after successfully adding to minimize connection disruption) + if len(endpointsToRemove) > 0 { + if err := m.removeEndpoints(ctx, endpointGroupARN, endpointsToRemove); err != nil { + return err + } + } + + return nil +} + +// addEndpoints adds endpoints to the endpoint group +func (m *defaultEndpointGroupManager) addEndpoints( + ctx context.Context, + endpointGroupARN string, + configsToAdd []agamodel.EndpointConfiguration) error { + + // Skip if no endpoints to add + if len(configsToAdd) == 0 { + return nil + } + + // Convert endpoint configurations to SDK format + endpointConfigs := make([]agatypes.EndpointConfiguration, 0, len(configsToAdd)) + for _, config := range configsToAdd { + endpointConfigs = append(endpointConfigs, m.buildSDKEndpointConfiguration(config)) + } + + // Prepare and execute the request + addInput := &globalaccelerator.AddEndpointsInput{ + EndpointGroupArn: awssdk.String(endpointGroupARN), + EndpointConfigurations: endpointConfigs, + } + + if _, err := m.gaService.AddEndpointsWithContext(ctx, addInput); err != nil { + return fmt.Errorf("failed to add endpoints to endpoint group %s: %w", endpointGroupARN, err) + } + + m.logger.V(1).Info("Successfully added endpoints", + "endpointGroupARN", endpointGroupARN, + "count", len(endpointConfigs)) + return nil +} + +// removeEndpoints removes endpoints from the endpoint group +func (m *defaultEndpointGroupManager) removeEndpoints( + ctx context.Context, + endpointGroupARN string, + endpointsToRemove []string) error { + + // Skip if no endpoints to remove + if len(endpointsToRemove) == 0 { + return nil + } + + // Convert string endpoint IDs to EndpointIdentifier objects + endpointIdentifiers := make([]agatypes.EndpointIdentifier, len(endpointsToRemove)) + for i, endpointID := range endpointsToRemove { + endpointIdentifiers[i] = agatypes.EndpointIdentifier{ + EndpointId: awssdk.String(endpointID), + } + } + + // Create and execute the request + removeInput := &globalaccelerator.RemoveEndpointsInput{ + EndpointGroupArn: awssdk.String(endpointGroupARN), + EndpointIdentifiers: endpointIdentifiers, + } + + if _, err := m.gaService.RemoveEndpointsWithContext(ctx, removeInput); err != nil { + return fmt.Errorf("failed to remove endpoints from endpoint group %s: %w", endpointGroupARN, err) + } + + m.logger.V(1).Info("Successfully removed endpoints", + "endpointGroupARN", endpointGroupARN, + "count", len(endpointIdentifiers)) + return nil +} + +// flipFlopEndpoints implements a simplified flip-flop Delete-Create pattern: +// 1. Remove all existing endpoints that need to be removed +// 2. Add all new endpoints at once +// This simple approach ensures we have room to add all new endpoints by removing old ones first. +func (m *defaultEndpointGroupManager) flipFlopEndpoints( + ctx context.Context, + endpointGroupARN string, + configsToAdd []agamodel.EndpointConfiguration, + endpointsToRemove []string) error { + + // First, remove all endpoints that need to be removed + m.logger.V(1).Info("Flip-flop: Removing all endpoints to make room", + "endpointGroupARN", endpointGroupARN, + "removingCount", len(endpointsToRemove)) + + if err := m.removeEndpoints(ctx, endpointGroupARN, endpointsToRemove); err != nil { + return fmt.Errorf("flip-flop: failed to remove endpoints for endpoint group %s: %w", endpointGroupARN, err) + } + + // Then, add all new endpoints at once + m.logger.V(1).Info("Flip-flop: Adding all new endpoints", + "endpointGroupARN", endpointGroupARN, + "addingCount", len(configsToAdd)) + + if err := m.addEndpoints(ctx, endpointGroupARN, configsToAdd); err != nil { + return fmt.Errorf("flip-flop: failed to add endpoints after removing old ones for endpoint group %s: %w", endpointGroupARN, err) + } + + return nil +} diff --git a/pkg/deploy/aga/endpoint_group_manager_mocks.go b/pkg/deploy/aga/endpoint_group_manager_mocks.go index d3109662c..0020771b2 100644 --- a/pkg/deploy/aga/endpoint_group_manager_mocks.go +++ b/pkg/deploy/aga/endpoint_group_manager_mocks.go @@ -65,6 +65,20 @@ func (mr *MockEndpointGroupManagerMockRecorder) Delete(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockEndpointGroupManager)(nil).Delete), arg0, arg1) } +// ManageEndpoints mocks base method. +func (m *MockEndpointGroupManager) ManageEndpoints(arg0 context.Context, arg1 string, arg2 []aga.EndpointConfiguration, arg3 []types.EndpointDescription) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ManageEndpoints", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// ManageEndpoints indicates an expected call of ManageEndpoints. +func (mr *MockEndpointGroupManagerMockRecorder) ManageEndpoints(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ManageEndpoints", reflect.TypeOf((*MockEndpointGroupManager)(nil).ManageEndpoints), arg0, arg1, arg2, arg3) +} + // Update mocks base method. func (m *MockEndpointGroupManager) Update(arg0 context.Context, arg1 *aga.EndpointGroup, arg2 *types.EndpointGroup) (aga.EndpointGroupStatus, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/endpoint_group_manager_test.go b/pkg/deploy/aga/endpoint_group_manager_test.go index 53959afa8..7392a14be 100644 --- a/pkg/deploy/aga/endpoint_group_manager_test.go +++ b/pkg/deploy/aga/endpoint_group_manager_test.go @@ -2,7 +2,13 @@ package aga import ( "context" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "sort" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -96,6 +102,209 @@ func Test_defaultEndpointGroupManager_buildSDKPortOverrides(t *testing.T) { } } +func Test_defaultEndpointGroupManager_isEndpointConfigurationDrifted(t *testing.T) { + tests := []struct { + name string + desiredConfig agamodel.EndpointConfiguration + existingEndpoint agatypes.EndpointDescription + want bool + }{ + { + name: "no drift - both weight and client IP preservation nil", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + want: false, + }, + { + name: "weight drift - desired nil, existing not nil", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: nil, + }, + want: true, + }, + { + name: "weight drift - desired not nil, existing nil", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + want: true, + }, + { + name: "weight drift - both not nil but different values", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: aws.Int32(80), + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: nil, + }, + want: true, + }, + { + name: "no weight drift - both not nil with same values", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: nil, + }, + want: false, + }, + { + name: "client IP preservation drift - desired nil, existing not nil", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + want: true, + }, + { + name: "client IP preservation drift - desired not nil, existing nil", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: nil, + }, + want: true, + }, + { + name: "client IP preservation drift - both not nil but different values (true vs false)", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(false), + }, + want: true, + }, + { + name: "client IP preservation drift - both not nil but different values (false vs true)", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(false), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + want: true, + }, + { + name: "no client IP preservation drift - both not nil with same values (both true)", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(true), + }, + want: false, + }, + { + name: "no client IP preservation drift - both not nil with same values (both false)", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(false), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: nil, + ClientIPPreservationEnabled: aws.Bool(false), + }, + want: false, + }, + { + name: "drift in both weight and client IP preservation", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: aws.Int32(80), + ClientIPPreservationEnabled: aws.Bool(true), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(false), + }, + want: true, + }, + { + name: "no drift - both weight and client IP preservation have same non-nil values", + desiredConfig: agamodel.EndpointConfiguration{ + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + existingEndpoint: agatypes.EndpointDescription{ + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.isEndpointConfigurationDrifted(tt.desiredConfig, tt.existingEndpoint) + assert.Equal(t, tt.want, got) + }) + } +} + func Test_defaultEndpointGroupManager_arePortOverridesEqual(t *testing.T) { tests := []struct { name string @@ -391,6 +600,297 @@ func Test_defaultEndpointGroupManager_isSDKEndpointGroupSettingsDrifted(t *testi } } +func Test_defaultEndpointGroupManager_detectEndpointDrift(t *testing.T) { + tests := []struct { + name string + existingEndpoints []agatypes.EndpointDescription + desiredConfigs []agamodel.EndpointConfiguration + wantConfigsToAdd []agamodel.EndpointConfiguration + wantConfigsToUpdate []agamodel.EndpointConfiguration + wantEndpointsToRemove []string + wantIsUpdateRequired bool + }{ + { + name: "no endpoints - empty lists", + existingEndpoints: []agatypes.EndpointDescription{}, + desiredConfigs: []agamodel.EndpointConfiguration{}, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{}, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: false, + }, + { + name: "add new endpoint - no existing endpoints", + existingEndpoints: []agatypes.EndpointDescription{}, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + }, + }, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{}, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: false, + }, + { + name: "remove endpoint - no desired endpoints", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: aws.String("endpoint-1"), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{}, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{}, + wantEndpointsToRemove: []string{"endpoint-1"}, + wantIsUpdateRequired: false, + }, + { + name: "no change - same endpoints with same configuration", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: false, + }, + { + name: "update endpoint - weight drift", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(80), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: true, + }, + { + name: "update endpoint - client IP preservation drift", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(false), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(true), + }, + }, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: true, + }, + { + name: "multiple actions - add, update, remove endpoints", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: aws.String("endpoint-1"), + Weight: aws.Int32(80), + }, + { + EndpointId: aws.String("endpoint-2"), + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(false), + }, + { + EndpointId: aws.String("endpoint-to-remove"), + Weight: aws.Int32(50), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), // Changed weight + }, + { + EndpointID: "endpoint-2", + Weight: aws.Int32(100), // No change + ClientIPPreservationEnabled: aws.Bool(false), + }, + { + EndpointID: "endpoint-new", // New endpoint + Weight: aws.Int32(100), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-new", + Weight: aws.Int32(100), + }, + }, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-1", + Weight: aws.Int32(100), + }, + { + EndpointID: "endpoint-2", + Weight: aws.Int32(100), + ClientIPPreservationEnabled: aws.Bool(false), + }, + }, + wantEndpointsToRemove: []string{"endpoint-to-remove"}, + wantIsUpdateRequired: true, // Because endpoint-1 weight changed + }, + { + name: "nil endpoint ID in existing endpoints", + existingEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: nil, // Should be skipped + Weight: aws.Int32(80), + }, + { + EndpointId: aws.String("endpoint-2"), + Weight: aws.Int32(100), + }, + }, + desiredConfigs: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-2", + Weight: aws.Int32(100), + }, + }, + wantConfigsToAdd: []agamodel.EndpointConfiguration{}, + wantConfigsToUpdate: []agamodel.EndpointConfiguration{ + { + EndpointID: "endpoint-2", + Weight: aws.Int32(100), + }, + }, + wantEndpointsToRemove: []string{}, + wantIsUpdateRequired: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + gotConfigsToAdd, gotConfigsToUpdate, gotEndpointsToRemove, gotIsUpdateRequired := m.detectEndpointDrift(tt.existingEndpoints, tt.desiredConfigs) + + // Sort slices for deterministic comparison + sort.Slice(gotConfigsToAdd, func(i, j int) bool { + return gotConfigsToAdd[i].EndpointID < gotConfigsToAdd[j].EndpointID + }) + sort.Slice(tt.wantConfigsToAdd, func(i, j int) bool { + return tt.wantConfigsToAdd[i].EndpointID < tt.wantConfigsToAdd[j].EndpointID + }) + sort.Slice(gotConfigsToUpdate, func(i, j int) bool { + return gotConfigsToUpdate[i].EndpointID < gotConfigsToUpdate[j].EndpointID + }) + sort.Slice(tt.wantConfigsToUpdate, func(i, j int) bool { + return tt.wantConfigsToUpdate[i].EndpointID < tt.wantConfigsToUpdate[j].EndpointID + }) + sort.Strings(gotEndpointsToRemove) + sort.Strings(tt.wantEndpointsToRemove) + + // Check if configsToAdd matches expected + assert.Equal(t, len(tt.wantConfigsToAdd), len(gotConfigsToAdd), "configsToAdd length mismatch") + for i, config := range tt.wantConfigsToAdd { + if i < len(gotConfigsToAdd) { + assert.Equal(t, config.EndpointID, gotConfigsToAdd[i].EndpointID, "EndpointID mismatch") + + // Check Weight + if config.Weight == nil { + assert.Nil(t, gotConfigsToAdd[i].Weight, "Weight should be nil") + } else if gotConfigsToAdd[i].Weight != nil { + assert.Equal(t, *config.Weight, *gotConfigsToAdd[i].Weight, "Weight value mismatch") + } + + // Check ClientIPPreservationEnabled + if config.ClientIPPreservationEnabled == nil { + assert.Nil(t, gotConfigsToAdd[i].ClientIPPreservationEnabled, "ClientIPPreservationEnabled should be nil") + } else if gotConfigsToAdd[i].ClientIPPreservationEnabled != nil { + assert.Equal(t, *config.ClientIPPreservationEnabled, *gotConfigsToAdd[i].ClientIPPreservationEnabled, "ClientIPPreservationEnabled value mismatch") + } + } + } + + // Check if configsToUpdate matches expected + assert.Equal(t, len(tt.wantConfigsToUpdate), len(gotConfigsToUpdate), "configsToUpdate length mismatch") + for i, config := range tt.wantConfigsToUpdate { + if i < len(gotConfigsToUpdate) { + assert.Equal(t, config.EndpointID, gotConfigsToUpdate[i].EndpointID, "EndpointID mismatch") + + // Check Weight + if config.Weight == nil { + assert.Nil(t, gotConfigsToUpdate[i].Weight, "Weight should be nil") + } else if gotConfigsToUpdate[i].Weight != nil { + assert.Equal(t, *config.Weight, *gotConfigsToUpdate[i].Weight, "Weight value mismatch") + } + + // Check ClientIPPreservationEnabled + if config.ClientIPPreservationEnabled == nil { + assert.Nil(t, gotConfigsToUpdate[i].ClientIPPreservationEnabled, "ClientIPPreservationEnabled should be nil") + } else if gotConfigsToUpdate[i].ClientIPPreservationEnabled != nil { + assert.Equal(t, *config.ClientIPPreservationEnabled, *gotConfigsToUpdate[i].ClientIPPreservationEnabled, "ClientIPPreservationEnabled value mismatch") + } + } + } + + // Check if endpointsToRemove matches expected + assert.Equal(t, tt.wantEndpointsToRemove, gotEndpointsToRemove, "endpointsToRemove mismatch") + + // Check if isUpdateRequired matches expected + assert.Equal(t, tt.wantIsUpdateRequired, gotIsUpdateRequired, "isUpdateRequired mismatch") + }) + } +} + func Test_defaultEndpointGroupManager_buildSDKCreateEndpointGroupInput(t *testing.T) { testListenerARN := "arn:aws:globalaccelerator::123456789012:listener/1234abcd-abcd-1234-abcd-1234abcdefgh/abcdefghi" mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) @@ -478,3 +978,343 @@ func Test_defaultEndpointGroupManager_buildSDKCreateEndpointGroupInput(t *testin }) } } + +func Test_ManageEndpoints(t *testing.T) { + testCases := []struct { + name string + endpointGroupARN string + currentEndpoints []agatypes.EndpointDescription + loadedEndpoints []*aga.LoadedEndpoint + expectError bool + describeEndpointErr error + addEndpointsErr error + addEndpointsErrOnFirstTry bool // For testing limit exceeded with flip-flop pattern + removeEndpointsErr error + expectAddCall bool + expectRemoveCall bool + expectUpdateCall bool // Whether to expect update-endpoint-group API call due to property drift + expectFlipFlopPattern bool // Whether to expect flip-flop delete-create pattern + }{ + { + name: "limit exceeded - flip-flop pattern", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb1/1111111111"), + }, + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb2/2222222222"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "new-service-1", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/new-lb1/3333333333", + Status: aga.EndpointStatusLoaded, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "new-service-2", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/new-lb2/4444444444", + Status: aga.EndpointStatusLoaded, + }, + }, + addEndpointsErrOnFirstTry: true, // First AddEndpoints call fails with LimitExceededException + expectError: false, + expectAddCall: true, + expectRemoveCall: true, + expectUpdateCall: false, + expectFlipFlopPattern: true, + }, + { + name: "endpoint property drift - update endpoint-group API call", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb/1111111111"), + Weight: awssdk.Int32(80), // Different weight + ClientIPPreservationEnabled: awssdk.Bool(false), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "existing-service", + Namespace: "default", + Weight: 100, // Changed weight + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb/1111111111", + Status: aga.EndpointStatusLoaded, + }, + }, + expectError: false, + expectAddCall: false, // Should not call AddEndpoints + expectRemoveCall: false, // Should not call RemoveEndpoints + expectUpdateCall: true, // Should call UpdateEndpointGroup + }, + { + name: "endpoints to remove only", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890"), + }, + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb2/0987654321"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{}, + expectError: false, + expectAddCall: false, + expectRemoveCall: true, + }, + { + name: "both add and remove endpoints", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-2", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb2/0987654321", + Status: aga.EndpointStatusLoaded, + }, + }, + expectError: false, + expectAddCall: true, + expectRemoveCall: true, + }, + { + name: "add endpoints error", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{}, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-1", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890", + Status: aga.EndpointStatusLoaded, + }, + }, + addEndpointsErr: errors.New("add error"), + expectError: true, + expectAddCall: true, + expectRemoveCall: false, + }, + { + name: "remove endpoints error", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{}, + removeEndpointsErr: errors.New("remove error"), + expectError: true, + expectAddCall: false, + expectRemoveCall: true, + }, + { + name: "add and remove with remove error", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-2", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb2/0987654321", + Status: aga.EndpointStatusLoaded, + }, + }, + removeEndpointsErr: errors.New("remove error"), + expectError: true, + expectAddCall: true, + expectRemoveCall: true, + }, + { + name: "endpoint with failed status", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{}, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-1", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb/1234567890", + Status: aga.EndpointStatusFatal, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-2", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/my-lb2/0987654321", + Status: aga.EndpointStatusLoaded, + }, + }, + expectError: false, + expectAddCall: true, // Should call add for the loaded endpoint + expectRemoveCall: false, // No endpoints to remove + }, + { + name: "limit exceeded - flip-flop pattern", + endpointGroupARN: "arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234", + currentEndpoints: []agatypes.EndpointDescription{ + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb1/1111111111"), + }, + { + EndpointId: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/existing-lb2/2222222222"), + }, + }, + loadedEndpoints: []*aga.LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "new-service-1", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/new-lb1/3333333333", + Status: aga.EndpointStatusLoaded, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "new-service-2", + Namespace: "default", + Weight: 100, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/new-lb2/4444444444", + Status: aga.EndpointStatusLoaded, + }, + }, + addEndpointsErrOnFirstTry: true, // First AddEndpoints call fails with LimitExceededException + expectError: false, + expectAddCall: true, + expectRemoveCall: true, + expectFlipFlopPattern: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockGaService := services.NewMockGlobalAccelerator(ctrl) + + // Setup expectations for DescribeEndpointGroup + describeOutput := &globalaccelerator.DescribeEndpointGroupOutput{ + EndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupArn: awssdk.String(tc.endpointGroupARN), + EndpointDescriptions: tc.currentEndpoints, + }, + } + mockGaService.EXPECT(). + DescribeEndpointGroupWithContext(gomock.Any(), gomock.Any()). + Return(describeOutput, tc.describeEndpointErr). + AnyTimes() + + // Setup expectations for UpdateEndpointGroup if applicable (for drift in properties) + if tc.expectUpdateCall { + mockGaService.EXPECT(). + UpdateEndpointGroupWithContext(gomock.Any(), gomock.Any()). + Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil). + Times(1) + } else if tc.expectAddCall { // Don't expect any other API calls if using UpdateEndpointGroup + if tc.expectFlipFlopPattern { + // For flip-flop pattern, we first expect one AddEndpoints call that returns LimitExceededException + limitExceededErr := &agatypes.LimitExceededException{ + Message: awssdk.String("Endpoint limit exceeded"), + } + firstAddCall := mockGaService.EXPECT(). + AddEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(nil, limitExceededErr). + Times(1) + + // Then we expect individual AddEndpoints calls for each endpoint (after removal) + // These calls should succeed + mockGaService.EXPECT(). + AddEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(&globalaccelerator.AddEndpointsOutput{}, nil). + After(firstAddCall). + AnyTimes() + } else if tc.addEndpointsErrOnFirstTry { + // Set up a sequence for error on first try only + limitExceededErr := &agatypes.LimitExceededException{ + Message: awssdk.String("Endpoint limit exceeded"), + } + mockGaService.EXPECT(). + AddEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(nil, limitExceededErr). + Times(1) + + mockGaService.EXPECT(). + AddEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(&globalaccelerator.AddEndpointsOutput{}, nil). + AnyTimes() + } else { + // Standard expectation for normal cases + mockGaService.EXPECT(). + AddEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(&globalaccelerator.AddEndpointsOutput{}, tc.addEndpointsErr). + AnyTimes() + } + } + + // Setup expectations for RemoveEndpoints if applicable + if tc.expectRemoveCall { + mockGaService.EXPECT(). + RemoveEndpointsWithContext(gomock.Any(), gomock.Any()). + Return(&globalaccelerator.RemoveEndpointsOutput{}, tc.removeEndpointsErr). + AnyTimes() + } + + manager := &defaultEndpointGroupManager{ + gaService: mockGaService, + logger: logr.Discard(), + } + + // Convert LoadedEndpoints to EndpointConfigurations + endpointConfigs := []agamodel.EndpointConfiguration{} + for _, loadedEndpoint := range tc.loadedEndpoints { + if loadedEndpoint.Status == aga.EndpointStatusLoaded { + endpointConfig := agamodel.EndpointConfiguration{ + EndpointID: loadedEndpoint.ARN, + } + if loadedEndpoint.Weight > 0 { + weight := int32(loadedEndpoint.Weight) + endpointConfig.Weight = &weight + } + endpointConfigs = append(endpointConfigs, endpointConfig) + } + } + + // Add more test-specific handling if needed here + err := manager.ManageEndpoints(context.Background(), tc.endpointGroupARN, endpointConfigs, tc.currentEndpoints) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/k8s/events.go b/pkg/k8s/events.go index 42c966f2f..5d6c5e178 100644 --- a/pkg/k8s/events.go +++ b/pkg/k8s/events.go @@ -54,5 +54,6 @@ const ( GlobalAcceleratorEventReasonFailedBuildModel = "FailedBuildModel" GlobalAcceleratorEventReasonFailedEndpointLoad = "FailedEndpointLoad" GlobalAcceleratorEventReasonFailedDeploy = "FailedDeploy" + GlobalAcceleratorEventReasonWarningEndpoints = "WarningEndpoints" GlobalAcceleratorEventReasonSuccessfullyReconciled = "SuccessfullyReconciled" ) diff --git a/pkg/model/aga/endpoint_group.go b/pkg/model/aga/endpoint_group.go index dfa5e1ba0..d6384b954 100644 --- a/pkg/model/aga/endpoint_group.go +++ b/pkg/model/aga/endpoint_group.go @@ -74,6 +74,20 @@ type PortOverride struct { EndpointPort int32 `json:"endpointPort"` } +// EndpointConfiguration defines an endpoint configuration for Global Accelerator endpoint groups. +type EndpointConfiguration struct { + // EndpointID is the ID of the endpoint. + EndpointID string `json:"endpointId"` + + // Weight determines the proportion of traffic that is directed to the endpoint. + // +optional + Weight *int32 `json:"weight,omitempty"` + + // ClientIPPreservationEnabled indicates whether client IP preservation is enabled for this endpoint. + // +optional + ClientIPPreservationEnabled *bool `json:"clientIPPreservationEnabled,omitempty"` +} + // EndpointGroupSpec defines the desired state of EndpointGroup type EndpointGroupSpec struct { // ListenerARN is the ARN of the listener for the endpoint group @@ -92,7 +106,7 @@ type EndpointGroupSpec struct { // EndpointConfigurations is a list of endpoint configurations for the endpoint group. // +optional - // This field is not implemented in the initial version as it will be part of a separate endpoint builder. + EndpointConfigurations []EndpointConfiguration `json:"endpointConfigurations,omitempty"` } // EndpointGroupStatus defines the observed state of EndpointGroup