
Commit af04aa5

feat: non-locking Kubernetes device extension (#431)
* fix: index allocator, pod webhook priorityClass issue
* fix: simplify index allocation
* fix: lint issue
* fix: optimize func length
* fix: add toleration for tensor-fusion managed nodes
* fix: add/remove node taint to isolate scheduler in progressive migration mode
* fix: add ignore resource group for scheduler
* fix: rename taint key var name
* fix: priority conflict with priorityClass issue
1 parent f520cfc commit af04aa5

File tree

20 files changed, +489 −45 lines changed


charts/tensor-fusion/Chart.yaml

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ type: application
 # This is the chart version. This version number should be incremented each time you make changes
 # to the chart and its templates, including the app version.
 # Versions are expected to follow Semantic Versioning (https://semver.org/)
-version: 1.7.4
+version: 1.7.5

 # This is the version number of the application being deployed. This version number should be
 # incremented each time you make changes to the application. Versions are not expected to

charts/tensor-fusion/values.yaml

Lines changed: 2 additions & 0 deletions

@@ -208,6 +208,8 @@ schedulerConfig:
       totalIntranetBandWidthGBps: 100
   - name: NodeResourcesFit
     args:
+      ignoredResourceGroups:
+        - "tensor-fusion.ai"
      scoringStrategy:
        resources:
          - name: cpu

cmd/main.go

Lines changed: 22 additions & 4 deletions

@@ -36,6 +36,7 @@ import (
 	"github.com/NexusGPU/tensor-fusion/internal/constants"
 	"github.com/NexusGPU/tensor-fusion/internal/controller"
 	"github.com/NexusGPU/tensor-fusion/internal/gpuallocator"
+	"github.com/NexusGPU/tensor-fusion/internal/indexallocator"
 	"github.com/NexusGPU/tensor-fusion/internal/metrics"
 	"github.com/NexusGPU/tensor-fusion/internal/portallocator"
 	"github.com/NexusGPU/tensor-fusion/internal/scheduler/expander"
@@ -232,17 +233,25 @@ func main() {
 	// Initialize GPU allocator and set up watches
 	allocator, portAllocator := startTensorFusionAllocators(ctx, mgr)

+	// Initialize Index allocator for Device Plugin communication
+	indexAllocator, err := indexallocator.NewIndexAllocator(ctx, mgr.GetClient())
+	if err != nil {
+		setupLog.Error(err, "unable to set up index allocator")
+		os.Exit(1)
+	}
+	_ = indexAllocator.SetupWithManager(ctx, mgr)
+
 	startAutoScaler(mgr, allocator)

 	// Create pricing provider for webhook
 	pricingProvider := pricing.NewStaticPricingProvider()
-	startWebhook(mgr, portAllocator, pricingProvider)
+	startWebhook(mgr, portAllocator, indexAllocator, pricingProvider)

 	scheduler, nodeExpander := startScheduler(ctx, allocator, mgr, k8sVersion)

 	startCustomResourceController(ctx, mgr, metricsRecorder, allocator, portAllocator, nodeExpander)

-	startHttpServerForTFClient(ctx, kc, portAllocator, allocator, scheduler, mgr.Elected())
+	startHttpServerForTFClient(ctx, kc, portAllocator, indexAllocator, allocator, scheduler, mgr.Elected())

 	// +kubebuilder:scaffold:builder
 	addHealthCheckAPI(mgr)
@@ -291,6 +300,7 @@ func startHttpServerForTFClient(
 	ctx context.Context,
 	kc *rest.Config,
 	portAllocator *portallocator.PortAllocator,
+	indexAllocator *indexallocator.IndexAllocator,
 	allocator *gpuallocator.GpuAllocator,
 	scheduler *scheduler.Scheduler,
 	leaderChan <-chan struct{},
@@ -310,12 +320,19 @@ func startHttpServerForTFClient(
 		setupLog.Error(err, "failed to create assign host port router")
 		os.Exit(1)
 	}
+	assignIndexRouter, err := router.NewAssignIndexRouter(ctx, indexAllocator)
+	if err != nil {
+		setupLog.Error(err, "failed to create assign index router")
+		os.Exit(1)
+	}
 	allocatorInfoRouter, err := router.NewAllocatorInfoRouter(ctx, allocator, scheduler)
 	if err != nil {
 		setupLog.Error(err, "failed to create allocator info router")
 		os.Exit(1)
 	}
-	httpServer := server.NewHTTPServer(connectionRouter, assignHostPortRouter, allocatorInfoRouter, leaderChan)
+	httpServer := server.NewHTTPServer(
+		connectionRouter, assignHostPortRouter, assignIndexRouter, allocatorInfoRouter, leaderChan,
+	)
 	go func() {
 		err := httpServer.Run()
 		if err != nil {
@@ -468,12 +485,13 @@ func startCustomResourceController(
 func startWebhook(
 	mgr manager.Manager,
 	portAllocator *portallocator.PortAllocator,
+	indexAllocator *indexallocator.IndexAllocator,
 	pricingProvider pricing.PricingProvider,
 ) {
 	if os.Getenv(constants.EnableWebhookEnv) == constants.FalseStringValue {
 		return
 	}
-	if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator, pricingProvider); err != nil {
+	if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator, indexAllocator, pricingProvider); err != nil {
 		setupLog.Error(err, "unable to create webhook", "webhook", "Pod")
 		os.Exit(1)
 	}

config/samples/scheduler-config.yaml

Lines changed: 2 additions & 0 deletions

@@ -40,6 +40,8 @@ profiles:
       totalIntranetBandWidthGBps: 100
   - name: NodeResourcesFit
     args:
+      ignoredResourceGroups:
+        - "tensor-fusion.ai"
      scoringStrategy:
        resources:
          - name: cpu

internal/constants/constants.go

Lines changed: 4 additions & 0 deletions

@@ -92,6 +92,9 @@ const (
 	// Additional worker pod template is set by user with /worker-pod-template annotation
 	WorkerPodTemplateAnnotation = Domain + "/worker-pod-template"

+	// Pod index annotation for Device Plugin communication (1-512)
+	PodIndexAnnotation = Domain + "/index"
+
 	WorkloadModeAnnotation = Domain + "/workload-mode"
 	WorkloadModeDynamic    = "dynamic"
 	WorkloadModeFixed      = "fixed"
@@ -119,6 +122,7 @@ const (
 	TensorFusionPodCounterKeyAnnotation = Domain + "/pod-counter-key"
 	TensorFusionPodCountAnnotation      = Domain + "/tf-pod-count"
 	TensorFusionWorkerSuffix            = "-tf"
+	NodeUsedByTaintKey                  = Domain + "/used-by"

 	// For grey release
 	TensorFusionEnabledReplicasAnnotation = Domain + "/enabled-replicas"
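
Note: the new PodIndexAnnotation resolves to "tensor-fusion.ai/index" (Domain is the tensor-fusion.ai domain also visible in the taint key and the scheduler's ignored resource group above). A minimal sketch, under that assumption, of how a consumer such as the device plugin might read the assigned index back off a pod; the podIndex helper is illustrative, not part of this commit:

package main

import (
	"fmt"
	"strconv"

	corev1 "k8s.io/api/core/v1"
)

// podIndexAnnotation mirrors constants.PodIndexAnnotation, assuming
// Domain == "tensor-fusion.ai".
const podIndexAnnotation = "tensor-fusion.ai/index"

// podIndex extracts the 1-512 index assigned by the webhook,
// or 0 if the annotation is missing or out of range.
func podIndex(pod *corev1.Pod) int {
	raw, ok := pod.Annotations[podIndexAnnotation]
	if !ok {
		return 0
	}
	idx, err := strconv.Atoi(raw)
	if err != nil || idx < 1 || idx > 512 {
		return 0
	}
	return idx
}

func main() {
	pod := &corev1.Pod{}
	pod.Annotations = map[string]string{podIndexAnnotation: "42"}
	fmt.Println(podIndex(pod)) // prints 42
}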

internal/constants/vendors.go

Lines changed: 10 additions & 10 deletions

@@ -8,22 +8,22 @@ const (

 	// DSA vendors - Global
 	AcceleratorVendorQualcomm  = "Qualcomm"
-	AcceleratorVendorAWSNeuron = "AWS-Neuron"
-	AcceleratorVendorGoogleTPU = "Google-TPU"
+	AcceleratorVendorAWSNeuron = "AWSNeuron"
+	AcceleratorVendorGoogleTPU = "Google"
 	AcceleratorVendorCerebras  = "Cerebras"

 	// GPGPU vendors - CN
-	AcceleratorVendorHygon    = "Hygon-DCU"
-	AcceleratorVendorMetaX    = "Meta-X"
+	AcceleratorVendorHygon    = "Hygon"
+	AcceleratorVendorMetaX    = "MetaX"
 	AcceleratorVendorMThreads = "MThreads"
-	AcceleratorVendorBiren        = "BirenGPU"
-	AcceleratorVendorAlibabaTHead = "THead-PPU"
+	AcceleratorVendorBiren        = "Biren"
+	AcceleratorVendorAlibabaTHead = "THead"

 	// DSA vendors - CN
-	AcceleratorVendorHuaweiAscendNPU = "Ascend-NPU"
-	AcceleratorVendorCambricon       = "Cambricon-MLU"
-	AcceleratorVendorEnflame         = "Enflame-XPU"
-	AcceleratorVendorKunlunX         = "KunlunXin-XPU"
+	AcceleratorVendorHuaweiAscendNPU = "Ascend"
+	AcceleratorVendorCambricon       = "Cambricon"
+	AcceleratorVendorEnflame         = "Enflame"
+	AcceleratorVendorKunlunX         = "KunlunXin"

 	AcceleratorVendorUnknown = "Unknown"
 )

internal/controller/gpunode_controller.go

Lines changed: 1 addition & 1 deletion

@@ -188,7 +188,7 @@ func (r *GPUNodeReconciler) checkStatusAndUpdateVirtualCapacity(

 		return nil
 	} else {
-		gpuModels, err := gpuallocator.RefreshGPUNodeCapacity(ctx, r.Client, node, poolObj, r.Allocator)
+		gpuModels, err := gpuallocator.RefreshGPUNodeCapacity(ctx, r.Client, node, poolObj, r.Allocator, coreNode)
 		if err != nil {
 			return err
 		}

internal/gpuallocator/gpuallocator_test.go

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ var _ = Describe("GPU Allocator", func() {
 		if err := k8sClient.Get(ctx, types.NamespacedName{Name: "test-pool"}, pool); err != nil {
 			Expect(err).NotTo(HaveOccurred())
 		}
-		_, _ = RefreshGPUNodeCapacity(ctx, k8sClient, gpuNode, pool, allocator)
+		_, _ = RefreshGPUNodeCapacity(ctx, k8sClient, gpuNode, pool, allocator, nil)

 		// Verify resources were reduced on the allocated GPU
 		gpu := getGPU(gpus[0].Name)

internal/gpuallocator/node_capacity.go

Lines changed: 30 additions & 0 deletions

@@ -6,15 +6,20 @@ import (

 	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
 	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
+	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/equality"
 	"k8s.io/apimachinery/pkg/api/resource"
+	"k8s.io/kubernetes/pkg/util/taints"
 	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/log"
 )

 func RefreshGPUNodeCapacity(
 	ctx context.Context, k8sClient client.Client,
 	node *tfv1.GPUNode, pool *tfv1.GPUPool,
 	allocator *GpuAllocator,
+	coreNode *corev1.Node,
 ) ([]string, error) {
 	gpuList := &tfv1.GPUList{}
 	if err := k8sClient.List(ctx, gpuList, client.MatchingLabels{constants.LabelKeyOwner: node.Name}); err != nil {
@@ -76,6 +81,31 @@
 		if err != nil {
 			return nil, fmt.Errorf("failed to update GPU node status: %w", err)
 		}
+
+		// check if need to update K8S node label
+		if utils.IsProgressiveMigration() && coreNode != nil {
+			taint := &corev1.Taint{
+				Key:    constants.NodeUsedByTaintKey,
+				Effect: corev1.TaintEffectNoSchedule,
+				Value:  constants.TensorFusionSystemName,
+			}
+			needUpdateNode := false
+			if node.Status.AvailableVRAM.Equal(node.Status.TotalVRAM) && node.Status.AvailableTFlops.Equal(node.Status.TotalTFlops) {
+				// check if need to remove the taint
+				coreNode, needUpdateNode, _ = taints.RemoveTaint(coreNode, taint)
+			} else if !taints.TaintExists(coreNode.Spec.Taints, taint) {
+				// check if need to add the taint
+				coreNode, needUpdateNode, _ = taints.AddOrUpdateTaint(coreNode, taint)
+			}
+			if needUpdateNode {
+				log.FromContext(ctx).Info("Updating K8S node taints for isolation of tensor-fusion and non-tensor-fusion used nodes",
+					"node", coreNode.Name, "taint", taint.Key)
+				err := k8sClient.Update(ctx, coreNode)
+				if err != nil {
+					return nil, fmt.Errorf("failed to update K8S node: %w", err)
+				}
+			}
+		}
 	}
 	return gpuModels, nil
 }
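
The taint added here uses key "tensor-fusion.ai/used-by" with effect NoSchedule, so ordinary pods are kept off tensor-fusion-used nodes during progressive migration; per the commit message, tensor-fusion managed pods get a matching toleration (added in one of the other changed files, not shown in this excerpt). A minimal sketch of what that toleration would look like, assuming constants.TensorFusionSystemName resolves to "tensor-fusion" (the actual value is not visible in this diff):

package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
)

// tensorFusionToleration builds the toleration a tensor-fusion managed pod
// needs to schedule onto nodes carrying the isolation taint added by
// RefreshGPUNodeCapacity. The value "tensor-fusion" is an assumption; the
// real webhook would use constants.TensorFusionSystemName.
func tensorFusionToleration() corev1.Toleration {
	return corev1.Toleration{
		Key:      "tensor-fusion.ai/used-by", // constants.NodeUsedByTaintKey
		Operator: corev1.TolerationOpEqual,
		Value:    "tensor-fusion", // assumed constants.TensorFusionSystemName
		Effect:   corev1.TaintEffectNoSchedule,
	}
}

func main() {
	fmt.Printf("%+v\n", tensorFusionToleration())
}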
internal/indexallocator (new file)

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+package indexallocator
+
+import (
+	"context"
+	"fmt"
+	"sync/atomic"
+
+	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
+	v1 "k8s.io/api/core/v1"
+	"k8s.io/client-go/util/retry"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/controller-runtime/pkg/manager"
+)
+
+const (
+	IndexRangeStart = 1
+	IndexRangeEnd   = 512
+)
+
+// IndexAllocator manages allocation of 1-512 temporary indices for Pod-to-DevicePlugin communication
+// Uses a simple atomic counter that increments from 1 to 512, then wraps around to 1
+// No bitmap tracking needed - index reuse is acceptable after 512 cycles
+type IndexAllocator struct {
+	IsLeader bool
+
+	// Atomic counter for index allocation (1-512, wraps around)
+	currentIndex int64
+
+	Client client.Client
+
+	ctx context.Context
+}
+
+func NewIndexAllocator(ctx context.Context, client client.Client) (*IndexAllocator, error) {
+	if client == nil {
+		return nil, fmt.Errorf("client cannot be nil")
+	}
+
+	allocator := &IndexAllocator{
+		Client:       client,
+		IsLeader:     false,
+		currentIndex: 0, // Will start from 1 on first assignment
+		ctx:          ctx,
+	}
+
+	return allocator, nil
+}
+
+func (s *IndexAllocator) SetupWithManager(ctx context.Context, mgr manager.Manager) <-chan struct{} {
+	readyCh := make(chan struct{}, 1)
+	_ = mgr.Add(manager.RunnableFunc(func(ctx context.Context) error {
+		<-mgr.Elected()
+		s.IsLeader = true
+		leaderInfo := &v1.ConfigMap{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      constants.LeaderInfoConfigMapName,
+				Namespace: utils.CurrentNamespace(),
+			},
+		}
+		err := retry.RetryOnConflict(retry.DefaultBackoff, func() error {
+			_, err := controllerutil.CreateOrUpdate(ctx, s.Client, leaderInfo, func() error {
+				leaderInfo.Data = map[string]string{
+					constants.LeaderInfoConfigMapLeaderIPKey: utils.CurrentIP(),
+				}
+				return nil
+			})
+			return err
+		})
+		if err != nil {
+			log.FromContext(ctx).Error(err, "Failed to update leader IP info in ConfigMap")
+		}
+
+		readyCh <- struct{}{}
+		return nil
+	}))
+	return readyCh
+}
+
+// AssignIndex assigns a temporary index (1-512) for Pod-to-DevicePlugin communication
+// Uses atomic increment to ensure thread-safe assignment
+// Index wraps around from 512 to 1 (simple modulo operation)
+func (s *IndexAllocator) AssignIndex(podName string) (int, error) {
+	if !s.IsLeader {
+		return 0, fmt.Errorf("only leader can assign index")
+	}
+
+	// Atomic increment and wrap around
+	next := atomic.AddInt64(&s.currentIndex, 1)
+	index := int((next-1)%IndexRangeEnd) + IndexRangeStart
+
+	return index, nil
+}
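
The non-locking behavior comes from the single atomic increment in AssignIndex: (next-1) % 512 maps the ever-growing counter into 0..511, and adding IndexRangeStart shifts it to 1..512, so the 513th allocation wraps back to index 1. A standalone sketch of that arithmetic (re-implemented outside the struct for illustration):

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	indexRangeStart = 1
	indexRangeEnd   = 512
)

var counter int64

// assignIndex reproduces the wrap-around logic of IndexAllocator.AssignIndex.
func assignIndex() int {
	next := atomic.AddInt64(&counter, 1)
	// (next-1) % 512 yields 0..511; adding the range start maps it to 1..512.
	return int((next-1)%indexRangeEnd) + indexRangeStart
}

func main() {
	for _, n := range []int64{1, 512, 513} {
		atomic.StoreInt64(&counter, n-1) // pretend n-1 allocations already happened
		fmt.Printf("allocation #%d -> index %d\n", n, assignIndex())
	}
	// Output:
	// allocation #1 -> index 1
	// allocation #512 -> index 512
	// allocation #513 -> index 1
}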
