Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"math"
"math/rand"
"slices"
"sort"
"time"

Expand All @@ -43,12 +44,12 @@ type weightedScoredPod struct {
key float64
}

// compile-time type validation
var _ framework.Picker = &WeightedRandomPicker{}

// WeightedRandomPickerFactory defines the factory function for WeightedRandomPicker.
func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := pickerParameters{
MaxNumOfEndpoints: DefaultMaxNumOfEndpoints,
}
parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err)
Expand All @@ -58,9 +59,10 @@ func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ p
return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
}

// NewWeightedRandomPicker initializes a new WeightedRandomPicker and returns its pointer.
func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
if maxNumOfEndpoints <= 0 {
maxNumOfEndpoints = DefaultMaxNumOfEndpoints
maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value
}

return &WeightedRandomPicker{
Expand All @@ -70,81 +72,68 @@ func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
}
}

// WeightedRandomPicker picks pod(s) from the list of candidates based on weighted random sampling using A-Res algorithm.
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf.
//
// The picker at its core is picking pods randomly, where the probability of the pod to get picked is derived
// from its weighted score.
// Algorithm:
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
// - Selects k items with largest keys for mathematically correct weighted sampling
// - More efficient than traditional cumulative probability approach
//
// Key characteristics:
// - Mathematically correct weighted random sampling
// - Single pass algorithm with O(n + k log k) complexity
type WeightedRandomPicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int
randomPicker *RandomPicker // fallback for zero weights
}

// WithName sets the name of the picker.
func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker {
p.typedName.Name = name
return p
}

// TypedName returns the type and name tuple of this plugin instance.
func (p *WeightedRandomPicker) TypedName() plugins.TypedName {
return p.typedName
}

// WeightedRandomPicker performs weighted random sampling using A-Res algorithm.
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf
// Algorithm:
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
// - Selects k items with largest keys for mathematically correct weighted sampling
// - More efficient than traditional cumulative probability approach
//
// Key characteristics:
// - Mathematically correct weighted random sampling
// - Single pass algorithm with O(n + k log k) complexity
// Pick selects the pod(s) randomly from the list of candidates, where the probability of the pod to get picked is derived
// from its weighted score.
func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v",
p.maxNumOfEndpoints, len(scoredPods), scoredPods))

// Check if all weights are zero or negative
allZeroWeights := true
for _, scoredPod := range scoredPods {
if scoredPod.Score > 0 {
allZeroWeights = false
break
}
}

// Delegate to RandomPicker for uniform selection when all weights are zero
if allZeroWeights {
log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection")
// Check if there is at least one pod with Score > 0, if not let random picker run
if slices.IndexFunc(scoredPods, func(scoredPod *types.ScoredPod) bool { return scoredPod.Score > 0 }) == -1 {
log.FromContext(ctx).V(logutil.DEBUG).Info("All scores are zero, delegating to RandomPicker for uniform selection")
return p.randomPicker.Pick(ctx, cycleState, scoredPods)
}

log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates by random weighted picker", "max-num-of-endpoints", p.maxNumOfEndpoints,
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)

randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))

// A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ)
weightedPods := make([]weightedScoredPod, 0, len(scoredPods))

for _, scoredPod := range scoredPods {
weight := float64(scoredPod.Score)

// Handle zero or negative weights
if weight <= 0 {
// Assign very small key for zero-weight pods (effectively excludes them)
weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: 0,
})
weightedPods := make([]weightedScoredPod, len(scoredPods))

for i, scoredPod := range scoredPods {
// Handle zero score
if scoredPod.Score <= 0 {
// Assign key=0 for zero-score pods (effectively excludes them from selection)
weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: 0}
continue
}

// Generate random number U in (0,1)
// If we're here the scoredPod.Score > 0. Generate a random number U in (0,1)
u := randomGenerator.Float64()
if u == 0 {
u = 1e-10 // Avoid log(0)
}

// Calculate key = U^(1/weight)
key := math.Pow(u, 1.0/weight)

weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: key,
})
weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: math.Pow(u, 1.0/scoredPod.Score)} // key = U^(1/weight)
}

// Sort by key in descending order (largest keys first)
Expand All @@ -155,14 +144,9 @@ func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.Cycle
// Select top k pods
selectedCount := min(p.maxNumOfEndpoints, len(weightedPods))

scoredPods = make([]*types.ScoredPod, selectedCount)
targetPods := make([]types.Pod, selectedCount)
for i := range selectedCount {
scoredPods[i] = weightedPods[i].ScoredPod
}

targetPods := make([]types.Pod, len(scoredPods))
for i, scoredPod := range scoredPods {
targetPods[i] = scoredPod
targetPods[i] = weightedPods[i].ScoredPod
}

return &types.ProfileRunResult{TargetPods: targetPods}
Expand Down