|
| 1 | +package aga |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + awssdk "github.com/aws/aws-sdk-go-v2/aws" |
| 7 | + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" |
| 8 | + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" |
| 9 | + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" |
| 10 | +) |
| 11 | + |
| 12 | +// endpointGroupBuilder builds EndpointGroup model resources |
| 13 | +type endpointGroupBuilder interface { |
| 14 | + // Build builds all endpoint groups for all listeners |
| 15 | + Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) |
| 16 | + |
| 17 | + // buildEndpointGroupsForListener builds endpoint groups for a specific listener |
| 18 | + buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) |
| 19 | +} |
| 20 | + |
| 21 | +// NewEndpointGroupBuilder constructs new endpointGroupBuilder |
| 22 | +func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder { |
| 23 | + return &defaultEndpointGroupBuilder{ |
| 24 | + clusterRegion: clusterRegion, |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{} |
| 29 | + |
| 30 | +type defaultEndpointGroupBuilder struct { |
| 31 | + clusterRegion string |
| 32 | +} |
| 33 | + |
| 34 | +// Build builds EndpointGroup model resources |
| 35 | +func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) { |
| 36 | + if listeners == nil || len(listeners) == 0 { |
| 37 | + return nil, nil |
| 38 | + } |
| 39 | + |
| 40 | + var result []*agamodel.EndpointGroup |
| 41 | + |
| 42 | + // Create a map of all listener port ranges |
| 43 | + listenerPortRanges := make(map[string][]agamodel.PortRange) // Maps listener ID to its port ranges |
| 44 | + for _, listener := range listeners { |
| 45 | + listenerPortRanges[listener.ID()] = listener.Spec.PortRanges |
| 46 | + } |
| 47 | + |
| 48 | + for i, listener := range listeners { |
| 49 | + listenerConfig := listenerConfigs[i] |
| 50 | + if listenerConfig.EndpointGroups == nil { |
| 51 | + continue |
| 52 | + } |
| 53 | + |
| 54 | + listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i) |
| 55 | + if err != nil { |
| 56 | + return nil, err |
| 57 | + } |
| 58 | + result = append(result, listenerEndpointGroups...) |
| 59 | + } |
| 60 | + |
| 61 | + // Validate endpoint ports in all port overrides across all listeners |
| 62 | + if err := b.validateEndpointPortOverridesCrossListeners(result, listenerPortRanges); err != nil { |
| 63 | + return nil, err |
| 64 | + } |
| 65 | + |
| 66 | + return result, nil |
| 67 | +} |
| 68 | + |
| 69 | +// validateEndpointPortOverridesCrossListeners performs validations for endpoint port overrides across all listeners |
| 70 | +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesCrossListeners(endpointGroups []*agamodel.EndpointGroup, listenerPortRanges map[string][]agamodel.PortRange) error { |
| 71 | + // Track endpoint port usage across all endpoint groups |
| 72 | + endpointPortUsage := make(map[int32]string) // Maps endpoint port to listener ID |
| 73 | + |
| 74 | + // Check all endpoint groups for port overrides |
| 75 | + for _, endpointGroup := range endpointGroups { |
| 76 | + listenerID := endpointGroup.Listener.ID() |
| 77 | + |
| 78 | + for _, portOverride := range endpointGroup.Spec.PortOverrides { |
| 79 | + endpointPort := portOverride.EndpointPort |
| 80 | + |
| 81 | + // Rule 1: Check if endpoint port is within any listener's port range |
| 82 | + if err := b.validateEndpointPortOverridesWithinListener(endpointPort, listenerPortRanges); err != nil { |
| 83 | + return err |
| 84 | + } |
| 85 | + |
| 86 | + // Rule 2: Check for duplicate endpoint port usage across listeners |
| 87 | + if existingListenerID, exists := endpointPortUsage[endpointPort]; exists && existingListenerID != listenerID { |
| 88 | + return fmt.Errorf("duplicate endpoint port %d: the same endpoint port cannot be used in port overrides from different listeners (used in %s and %s)", |
| 89 | + endpointPort, existingListenerID, listenerID) |
| 90 | + } |
| 91 | + |
| 92 | + // Register this endpoint port usage |
| 93 | + endpointPortUsage[endpointPort] = listenerID |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + return nil |
| 98 | +} |
| 99 | + |
| 100 | +// validateEndpointPortOverridesWithinListener checks if an endpoint port is within any listener's port range |
| 101 | +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListener(endpointPort int32, listenerPortRanges map[string][]agamodel.PortRange) error { |
| 102 | + for listenerID, portRanges := range listenerPortRanges { |
| 103 | + if IsPortInRanges(endpointPort, portRanges) { |
| 104 | + // Find the specific port range for the error message |
| 105 | + for _, portRange := range portRanges { |
| 106 | + if endpointPort >= portRange.FromPort && endpointPort <= portRange.ToPort { |
| 107 | + return fmt.Errorf("endpoint port %d conflicts with listener %s port range %d-%d: endpoint port cannot be included in any listener port range", |
| 108 | + endpointPort, listenerID, portRange.FromPort, portRange.ToPort) |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | + } |
| 113 | + return nil |
| 114 | +} |
| 115 | + |
| 116 | +// buildEndpointGroupsForListener builds EndpointGroup models for a specific listener |
| 117 | +func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) { |
| 118 | + var result []*agamodel.EndpointGroup |
| 119 | + |
| 120 | + for i, endpointGroup := range endpointGroups { |
| 121 | + spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup) |
| 122 | + if err != nil { |
| 123 | + return nil, err |
| 124 | + } |
| 125 | + |
| 126 | + resourceID := fmt.Sprintf("EndpointGroup-%d-%d", listenerIndex, i) |
| 127 | + endpointGroupModel := agamodel.NewEndpointGroup(stack, resourceID, spec, listener) |
| 128 | + result = append(result, endpointGroupModel) |
| 129 | + } |
| 130 | + |
| 131 | + return result, nil |
| 132 | +} |
| 133 | + |
| 134 | +// buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource |
| 135 | +func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) { |
| 136 | + region, err := b.determineRegion(endpointGroup) |
| 137 | + if err != nil { |
| 138 | + return agamodel.EndpointGroupSpec{}, err |
| 139 | + } |
| 140 | + |
| 141 | + // Handle trafficDialPercentage |
| 142 | + trafficDialPercentage := endpointGroup.TrafficDialPercentage |
| 143 | + |
| 144 | + portOverrides, err := b.buildPortOverrides(ctx, listener, endpointGroup) |
| 145 | + if err != nil { |
| 146 | + return agamodel.EndpointGroupSpec{}, err |
| 147 | + } |
| 148 | + |
| 149 | + return agamodel.EndpointGroupSpec{ |
| 150 | + ListenerARN: listener.ListenerARN(), |
| 151 | + Region: region, |
| 152 | + TrafficDialPercentage: trafficDialPercentage, |
| 153 | + PortOverrides: portOverrides, |
| 154 | + }, nil |
| 155 | +} |
| 156 | + |
| 157 | +// validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are |
| 158 | +// contained within the listener's port ranges |
| 159 | +func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { |
| 160 | + if len(portOverrides) == 0 { |
| 161 | + return nil |
| 162 | + } |
| 163 | + |
| 164 | + for _, portOverride := range portOverrides { |
| 165 | + listenerPort := portOverride.ListenerPort |
| 166 | + if !IsPortInRanges(listenerPort, listener.Spec.PortRanges) { |
| 167 | + return fmt.Errorf("port override listener port %d is not within any listener port ranges - this will cause AWS Global Accelerator to reject the configuration", listenerPort) |
| 168 | + } |
| 169 | + } |
| 170 | + return nil |
| 171 | +} |
| 172 | + |
| 173 | +// determineRegion determines the region for the endpoint group |
| 174 | +func (b *defaultEndpointGroupBuilder) determineRegion(endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (string, error) { |
| 175 | + // Use explicit region from endpoint group if specified |
| 176 | + if endpointGroup.Region != nil && awssdk.ToString(endpointGroup.Region) != "" { |
| 177 | + return awssdk.ToString(endpointGroup.Region), nil |
| 178 | + } |
| 179 | + |
| 180 | + // Default to cluster region if available |
| 181 | + if b.clusterRegion != "" { |
| 182 | + return b.clusterRegion, nil |
| 183 | + } |
| 184 | + return "", fmt.Errorf("region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration") |
| 185 | +} |
| 186 | + |
| 187 | +// buildPortOverrides builds the port overrides for the endpoint group |
| 188 | +func (b *defaultEndpointGroupBuilder) buildPortOverrides(_ context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) ([]agamodel.PortOverride, error) { |
| 189 | + if endpointGroup.PortOverrides == nil { |
| 190 | + return nil, nil |
| 191 | + } |
| 192 | + |
| 193 | + var portOverrides []agamodel.PortOverride |
| 194 | + for _, po := range *endpointGroup.PortOverrides { |
| 195 | + portOverrides = append(portOverrides, agamodel.PortOverride{ |
| 196 | + ListenerPort: po.ListenerPort, |
| 197 | + EndpointPort: po.EndpointPort, |
| 198 | + }) |
| 199 | + } |
| 200 | + |
| 201 | + // Validate all port override rules |
| 202 | + if err := b.validatePortOverrides(listener, portOverrides); err != nil { |
| 203 | + return []agamodel.PortOverride{}, err |
| 204 | + } |
| 205 | + |
| 206 | + return portOverrides, nil |
| 207 | +} |
| 208 | + |
| 209 | +// validateNoDuplicatePorts checks both listener and endpoint ports for duplicates in a single pass |
| 210 | +func (b *defaultEndpointGroupBuilder) validateNoDuplicatePorts(portOverrides []agamodel.PortOverride) error { |
| 211 | + if len(portOverrides) <= 1 { |
| 212 | + return nil |
| 213 | + } |
| 214 | + |
| 215 | + listenerPorts := make(map[int32]bool) |
| 216 | + endpointPorts := make(map[int32]bool) |
| 217 | + |
| 218 | + for _, portOverride := range portOverrides { |
| 219 | + // Check for duplicate listener ports |
| 220 | + listenerPort := portOverride.ListenerPort |
| 221 | + if listenerPorts[listenerPort] { |
| 222 | + return fmt.Errorf("duplicate listener port %d in port overrides: each listener port can only be used once in port overrides for an endpoint group", listenerPort) |
| 223 | + } |
| 224 | + listenerPorts[listenerPort] = true |
| 225 | + |
| 226 | + // Check for duplicate endpoint ports |
| 227 | + endpointPort := portOverride.EndpointPort |
| 228 | + if endpointPorts[endpointPort] { |
| 229 | + return fmt.Errorf("duplicate endpoint port %d in port overrides: each endpoint port can only be used once in port overrides for an endpoint group", endpointPort) |
| 230 | + } |
| 231 | + endpointPorts[endpointPort] = true |
| 232 | + } |
| 233 | + |
| 234 | + return nil |
| 235 | +} |
| 236 | + |
| 237 | +// validatePortOverrides is a wrapper function that runs all port override validation rules |
| 238 | +func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { |
| 239 | + // Validate listener port overrides against listener port ranges |
| 240 | + if err := b.validateListenerPortOverrideWithinListenerPortRanges(listener, portOverrides); err != nil { |
| 241 | + return err |
| 242 | + } |
| 243 | + |
| 244 | + // Check for duplicate listener and endpoint ports within this endpoint group's port overrides |
| 245 | + if err := b.validateNoDuplicatePorts(portOverrides); err != nil { |
| 246 | + return err |
| 247 | + } |
| 248 | + |
| 249 | + return nil |
| 250 | +} |
0 commit comments