Skip to content

Commit 9c6a527

Browse files
committed
Update lora affinity to be a scorer.
1 parent 0e1e964 commit 9c6a527

File tree

3 files changed

+267
-1
lines changed

3 files changed

+267
-1
lines changed

cmd/epp/runner/runner.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,11 @@ func (r *Runner) initializeScheduler() (*scheduling.Scheduler, error) {
289289
if schedulerV2 {
290290
queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog)
291291
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
292+
loraAffinityScorerWeight := envutil.GetEnvInt("LORA_AFFINITY_SCORE_WEIGHT", scorer.DefaultLoraAffinityScorerWeight, setupLog)
292293

293294
schedulerProfile := framework.NewSchedulerProfile().
294295
WithScorers(framework.NewWeightedScorer(scorer.NewQueueScorer(), queueScorerWeight),
295-
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight)).
296+
framework.NewWeightedScorer(scorer.NewKVCacheScorer(), kvCacheScorerWeight), framework.NewWeightedScorer(scorer.NewLoraAffinityScorer(), loraAffinityScorerWeight)).
296297
WithPicker(picker.NewMaxScorePicker())
297298

298299
if prefixCacheScheduling {
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+
)
27+
28+
const (
29+
DefaultLoraAffinityScorerWeight = 1
30+
LoraAffinityScorerType = "lora-affinity"
31+
)
32+
33+
// compile-time type assertion
34+
var _ framework.Scorer = &LoraAffinityScorer{}
35+
36+
// LoraAffinityScorerFactory defines the factory function for LoraAffinityScorer.
37+
func LoraAffinityScorerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
38+
return NewLoraAffinityScorer().WithName(name), nil
39+
}
40+
41+
// NewLoraAffinityScorer initializes a new LoraAffinityScorer and returns its pointer.
42+
func NewLoraAffinityScorer() *LoraAffinityScorer {
43+
return &LoraAffinityScorer{
44+
name: plugins.TypedName{Type: LoraAffinityScorerType, Name: LoraAffinityScorerType},
45+
}
46+
}
47+
48+
// LoraAffinityScorer scores list of candidate pods based on KV cache utilization.
49+
type LoraAffinityScorer struct {
50+
name plugins.TypedName
51+
}
52+
53+
// TypedName returns the type and name tuple of this plugin instance.
54+
func (s *LoraAffinityScorer) TypedName() plugins.TypedName {
55+
return s.name
56+
}
57+
58+
// Type returns the type of the scorer.
59+
func (s *LoraAffinityScorer) Type() string {
60+
return LoraAffinityScorerType
61+
}
62+
63+
// WithName sets the name of the scorer.
64+
func (s *LoraAffinityScorer) WithName(name string) *LoraAffinityScorer {
65+
s.name.Name = name
66+
return s
67+
}
68+
69+
func (s *LoraAffinityScorer) Score(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
70+
scores := make(map[types.Pod]float64, len(pods))
71+
72+
// Categorize pods based on affinity and availability
73+
for _, pod := range pods {
74+
_, active := pod.GetMetrics().ActiveModels[request.TargetModel]
75+
_, waiting := pod.GetMetrics().WaitingModels[request.TargetModel]
76+
77+
if active {
78+
scores[pod] = 1
79+
} else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels {
80+
scores[pod] = 0.8
81+
} else if waiting {
82+
scores[pod] = 0.6
83+
} else {
84+
scores[pod] = 0.0
85+
}
86+
}
87+
88+
return scores
89+
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/stretchr/testify/assert"
24+
k8stypes "k8s.io/apimachinery/pkg/types"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
26+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
27+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
28+
)
29+
30+
func TestLoraAffinityScorer(t *testing.T) {
31+
tests := []struct {
32+
name string
33+
request *types.LLMRequest
34+
pods []types.Pod
35+
expectedScoresPod map[int]float64 // Map of pod index to expected score
36+
}{
37+
{
38+
name: "Target model is active",
39+
request: &types.LLMRequest{TargetModel: "active-model-1"},
40+
pods: []types.Pod{
41+
&types.PodMetrics{
42+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
43+
MetricsState: &backendmetrics.MetricsState{
44+
ActiveModels: map[string]int{"active-model-1": 1},
45+
WaitingModels: map[string]int{},
46+
MaxActiveModels: 5,
47+
},
48+
},
49+
},
50+
expectedScoresPod: map[int]float64{
51+
0: 1.0,
52+
},
53+
},
54+
{
55+
name: "Target model is waiting",
56+
request: &types.LLMRequest{TargetModel: "active-model-1"},
57+
pods: []types.Pod{
58+
&types.PodMetrics{
59+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
60+
MetricsState: &backendmetrics.MetricsState{
61+
ActiveModels: map[string]int{"active-model-2": 2},
62+
WaitingModels: map[string]int{"active-model-1": 1},
63+
MaxActiveModels: 2,
64+
},
65+
},
66+
},
67+
expectedScoresPod: map[int]float64{
68+
0: 0.6,
69+
},
70+
},
71+
{
72+
name: "Pods have no space for new model",
73+
request: &types.LLMRequest{TargetModel: "active-model-1"},
74+
pods: []types.Pod{
75+
&types.PodMetrics{
76+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
77+
MetricsState: &backendmetrics.MetricsState{
78+
ActiveModels: map[string]int{"active-model-2": 2},
79+
WaitingModels: map[string]int{"active-model-3": 1},
80+
MaxActiveModels: 2,
81+
},
82+
},
83+
&types.PodMetrics{
84+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
85+
MetricsState: &backendmetrics.MetricsState{
86+
ActiveModels: map[string]int{},
87+
WaitingModels: map[string]int{},
88+
MaxActiveModels: 0,
89+
},
90+
},
91+
},
92+
expectedScoresPod: map[int]float64{
93+
0: 0.0,
94+
1: 0.0,
95+
},
96+
},
97+
{
98+
name: "Multiple pods with mixed active and waiting models",
99+
request: &types.LLMRequest{TargetModel: "active-model-1"},
100+
pods: []types.Pod{
101+
&types.PodMetrics{
102+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
103+
MetricsState: &backendmetrics.MetricsState{
104+
ActiveModels: map[string]int{"active-model-1": 1},
105+
WaitingModels: map[string]int{},
106+
MaxActiveModels: 5,
107+
},
108+
},
109+
&types.PodMetrics{
110+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
111+
MetricsState: &backendmetrics.MetricsState{
112+
ActiveModels: map[string]int{"active-model-2": 4},
113+
WaitingModels: map[string]int{"active-model-1": 1},
114+
MaxActiveModels: 5,
115+
},
116+
},
117+
&types.PodMetrics{
118+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}},
119+
MetricsState: &backendmetrics.MetricsState{
120+
ActiveModels: map[string]int{"active-model-2": 1},
121+
WaitingModels: map[string]int{},
122+
MaxActiveModels: 2,
123+
},
124+
},
125+
&types.PodMetrics{
126+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}},
127+
MetricsState: &backendmetrics.MetricsState{
128+
ActiveModels: map[string]int{"active-model-3": 1},
129+
WaitingModels: map[string]int{"active-model-1": 1},
130+
MaxActiveModels: 2,
131+
},
132+
},
133+
&types.PodMetrics{
134+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}},
135+
MetricsState: &backendmetrics.MetricsState{
136+
ActiveModels: map[string]int{"active-model-4": 1, "active-model-5": 1},
137+
WaitingModels: map[string]int{},
138+
MaxActiveModels: 2,
139+
},
140+
},
141+
},
142+
expectedScoresPod: map[int]float64{
143+
0: 1.0,
144+
1: 0.8,
145+
2: 0.8,
146+
3: 0.6,
147+
4: 0.0,
148+
},
149+
},
150+
{
151+
name: "Empty pods slice",
152+
request: &types.LLMRequest{TargetModel: "modelA"},
153+
pods: []types.Pod{},
154+
expectedScoresPod: map[int]float64{}, // No pods, no scores
155+
},
156+
}
157+
158+
for _, test := range tests {
159+
t.Run(test.name, func(t *testing.T) {
160+
scorer := &LoraAffinityScorer{}
161+
scores := scorer.Score(context.Background(), types.NewCycleState(), test.request, test.pods)
162+
163+
for i, pod := range test.pods {
164+
expectedScore, ok := test.expectedScoresPod[i]
165+
if !ok {
166+
t.Fatalf("Expected score not found for pod index %d in test %s", i, test.name)
167+
}
168+
// Use pod.GetPod().NamespacedName.Name for better identification in error messages
169+
assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s (index %d) should have score %f", pod.GetPod().NamespacedName.Name, i, expectedScore)
170+
}
171+
172+
// Also, ensure no unexpected pods are scored
173+
assert.Len(t, scores, len(test.expectedScoresPod), "Number of scored pods should match expected")
174+
})
175+
}
176+
}

0 commit comments

Comments
 (0)