Skip to content

Commit 13041fe

Browse files
committed
minor updates and godoc to weighted random picker (#1514)
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 6023213 commit 13041fe

File tree

1 file changed

+40
-56
lines changed

1 file changed

+40
-56
lines changed

pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go

Lines changed: 40 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"math"
2424
"math/rand"
25+
"slices"
2526
"sort"
2627
"time"
2728

@@ -43,12 +44,12 @@ type weightedScoredPod struct {
4344
key float64
4445
}
4546

47+
// compile-time type validation
4648
var _ framework.Picker = &WeightedRandomPicker{}
4749

50+
// WeightedRandomPickerFactory defines the factory function for WeightedRandomPicker.
4851
func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
49-
parameters := pickerParameters{
50-
MaxNumOfEndpoints: DefaultMaxNumOfEndpoints,
51-
}
52+
parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints}
5253
if rawParameters != nil {
5354
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
5455
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err)
@@ -58,9 +59,10 @@ func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ p
5859
return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
5960
}
6061

62+
// NewWeightedRandomPicker initializes a new WeightedRandomPicker and returns its pointer.
6163
func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
6264
if maxNumOfEndpoints <= 0 {
63-
maxNumOfEndpoints = DefaultMaxNumOfEndpoints
65+
maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value
6466
}
6567

6668
return &WeightedRandomPicker{
@@ -70,81 +72,68 @@ func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
7072
}
7173
}
7274

75+
// WeightedRandomPicker picks pod(s) from the list of candidates based on weighted random sampling using A-Res algorithm.
76+
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf.
77+
//
78+
// The picker at its core is picking pods randomly, where the probability of the pod to get picked is derived
79+
// from its weighted score.
80+
// Algorithm:
81+
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
82+
// - Selects k items with largest keys for mathematically correct weighted sampling
83+
// - More efficient than traditional cumulative probability approach
84+
//
85+
// Key characteristics:
86+
// - Mathematically correct weighted random sampling
87+
// - Single pass algorithm with O(n + k log k) complexity
7388
type WeightedRandomPicker struct {
7489
typedName plugins.TypedName
7590
maxNumOfEndpoints int
7691
randomPicker *RandomPicker // fallback for zero weights
7792
}
7893

94+
// WithName sets the name of the picker.
7995
func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker {
8096
p.typedName.Name = name
8197
return p
8298
}
8399

100+
// TypedName returns the type and name tuple of this plugin instance.
84101
func (p *WeightedRandomPicker) TypedName() plugins.TypedName {
85102
return p.typedName
86103
}
87104

88-
// WeightedRandomPicker performs weighted random sampling using A-Res algorithm.
89-
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf
90-
// Algorithm:
91-
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
92-
// - Selects k items with largest keys for mathematically correct weighted sampling
93-
// - More efficient than traditional cumulative probability approach
94-
//
95-
// Key characteristics:
96-
// - Mathematically correct weighted random sampling
97-
// - Single pass algorithm with O(n + k log k) complexity
105+
// Pick selects the pod(s) randomly from the list of candidates, where the probability of the pod to get picked is derived
106+
// from its weighted score.
98107
func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
99-
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v",
100-
p.maxNumOfEndpoints, len(scoredPods), scoredPods))
101-
102-
// Check if all weights are zero or negative
103-
allZeroWeights := true
104-
for _, scoredPod := range scoredPods {
105-
if scoredPod.Score > 0 {
106-
allZeroWeights = false
107-
break
108-
}
109-
}
110-
111-
// Delegate to RandomPicker for uniform selection when all weights are zero
112-
if allZeroWeights {
113-
log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection")
108+
// Check if there is at least one pod with Score > 0, if not let random picker run
109+
if slices.IndexFunc(scoredPods, func(scoredPod *types.ScoredPod) bool { return scoredPod.Score > 0 }) == -1 {
110+
log.FromContext(ctx).V(logutil.DEBUG).Info("All scores are zero, delegating to RandomPicker for uniform selection")
114111
return p.randomPicker.Pick(ctx, cycleState, scoredPods)
115112
}
116113

114+
log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates by random weighted picker", "max-num-of-endpoints", p.maxNumOfEndpoints,
115+
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
116+
117117
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
118118

119119
// A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ)
120-
weightedPods := make([]weightedScoredPod, 0, len(scoredPods))
121-
122-
for _, scoredPod := range scoredPods {
123-
weight := float64(scoredPod.Score)
124-
125-
// Handle zero or negative weights
126-
if weight <= 0 {
127-
// Assign very small key for zero-weight pods (effectively excludes them)
128-
weightedPods = append(weightedPods, weightedScoredPod{
129-
ScoredPod: scoredPod,
130-
key: 0,
131-
})
120+
weightedPods := make([]weightedScoredPod, len(scoredPods))
121+
122+
for i, scoredPod := range scoredPods {
123+
// Handle zero score
124+
if scoredPod.Score <= 0 {
125+
// Assign key=0 for zero-score pods (effectively excludes them from selection)
126+
weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: 0}
132127
continue
133128
}
134129

135-
// Generate random number U in (0,1)
130+
// If we're here the scoredPod.Score > 0. Generate a random number U in (0,1)
136131
u := randomGenerator.Float64()
137132
if u == 0 {
138133
u = 1e-10 // Avoid log(0)
139134
}
140135

141-
// Calculate key = U^(1/weight)
142-
key := math.Pow(u, 1.0/weight)
143-
144-
weightedPods = append(weightedPods, weightedScoredPod{
145-
ScoredPod: scoredPod,
146-
key: key,
147-
})
136+
weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: math.Pow(u, 1.0/scoredPod.Score)} // key = U^(1/weight)
148137
}
149138

150139
// Sort by key in descending order (largest keys first)
@@ -155,14 +144,9 @@ func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.Cycle
155144
// Select top k pods
156145
selectedCount := min(p.maxNumOfEndpoints, len(weightedPods))
157146

158-
scoredPods = make([]*types.ScoredPod, selectedCount)
147+
targetPods := make([]types.Pod, selectedCount)
159148
for i := range selectedCount {
160-
scoredPods[i] = weightedPods[i].ScoredPod
161-
}
162-
163-
targetPods := make([]types.Pod, len(scoredPods))
164-
for i, scoredPod := range scoredPods {
165-
targetPods[i] = scoredPod
149+
targetPods[i] = weightedPods[i].ScoredPod
166150
}
167151

168152
return &types.ProfileRunResult{TargetPods: targetPods}

0 commit comments

Comments
 (0)