Commit d9b6428

fix: optimize interface

1 parent 1e242d5 commit d9b6428

10 files changed: +59 −80 lines

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
     "apimachinery",
     "apimachineryruntime",
     "apiruntime",
+    "apiserver",
     "apiutil",
     "automount",
     "AWSGPU",

internal/hypervisor/api/worker_types.go

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@ type WorkerInfo struct {
 	TemplateID  string
 	Annotations map[string]string
 	PodIndex    string
+
+	// Tombstone field to indicate if the worker is deleted
+	Deleted bool
 }

 type WorkerAllocation struct {
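
The new Deleted flag acts as a tombstone: consumers can distinguish workers that were removed upstream from workers that are simply absent. A minimal caller-side filtering sketch, assuming only the WorkerInfo fields shown in this diff:

// liveWorkers drops tombstoned entries from a worker list.
// Sketch only; api.WorkerInfo is the struct extended above.
func liveWorkers(workers []*api.WorkerInfo) []*api.WorkerInfo {
	out := make([]*api.WorkerInfo, 0, len(workers))
	for _, w := range workers {
		if w.Deleted {
			continue // tombstoned: the worker was deleted upstream
		}
		out = append(out, w)
	}
	return out
}

This mirrors the filtering that the rewritten ListWorkers in internal/hypervisor/worker/controller.go performs later in this commit.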

internal/hypervisor/backend/kubernetes/kubernetes_backend.go

Lines changed: 12 additions & 7 deletions
@@ -6,6 +6,7 @@ import (
 	"os"

 	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/api"
 	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/kubernetes/external_dp"
 	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework"
 	"k8s.io/client-go/rest"

@@ -22,7 +23,7 @@ type KubeletBackend struct {
 	deviceDetector *external_dp.DevicePluginDetector

 	workerChanged chan struct{}
-	workerCh      chan []string
+	workerCh      chan []*api.WorkerInfo
 	workerStopCh  chan struct{}
 }

@@ -140,10 +141,10 @@ func (b *KubeletBackend) watchWorkerChanges() {
 	}
 }

-func (b *KubeletBackend) ListAndWatchWorkers() (<-chan []string, <-chan struct{}, error) {
+func (b *KubeletBackend) ListAndWatchWorkers() (<-chan []*api.WorkerInfo, <-chan struct{}, error) {
 	// Initialize channels if not already created
 	if b.workerCh == nil {
-		b.workerCh = make(chan []string, 1)
+		b.workerCh = make(chan []*api.WorkerInfo, 1)
 		b.workerStopCh = make(chan struct{})
 	}

@@ -154,9 +155,11 @@ func (b *KubeletBackend) ListAndWatchWorkers() (<-chan []string, <-chan struct{}
 	// Send initial list
 	if b.kubeletClient != nil {
 		b.kubeletClient.mu.RLock()
-		workers := make([]string, 0, len(b.kubeletClient.podCache))
+		workers := make([]*api.WorkerInfo, 0, len(b.kubeletClient.podCache))
 		for podUID := range b.kubeletClient.podCache {
-			workers = append(workers, podUID)
+			workers = append(workers, &api.WorkerInfo{
+				WorkerUID: podUID,
+			})
 		}
 		b.kubeletClient.mu.RUnlock()

@@ -180,9 +183,11 @@ func (b *KubeletBackend) ListAndWatchWorkers() (<-chan []string, <-chan struct{}
 		case <-workerChangedCh:
 			if b.kubeletClient != nil {
 				b.kubeletClient.mu.RLock()
-				workers := make([]string, 0, len(b.kubeletClient.podCache))
+				workers := make([]*api.WorkerInfo, 0, len(b.kubeletClient.podCache))
 				for podUID := range b.kubeletClient.podCache {
-					workers = append(workers, podUID)
+					workers = append(workers, &api.WorkerInfo{
+						WorkerUID: podUID,
+					})
 				}
 				b.kubeletClient.mu.RUnlock()
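
With the element type widened from []string to []*api.WorkerInfo, watch consumers receive full worker records instead of bare pod UIDs. A consumption sketch, assuming any implementation of the updated framework.Backend interface; error handling and shutdown are simplified:

// consumeWorkers drains the watch channel until the backend signals stop.
// Sketch only, not code from this commit.
func consumeWorkers(backend framework.Backend) error {
	workerCh, stopCh, err := backend.ListAndWatchWorkers()
	if err != nil {
		return err
	}
	for {
		select {
		case workers := <-workerCh:
			// Each update now carries full records, not UID strings.
			for _, w := range workers {
				klog.V(4).Infof("worker %s (deleted=%v)", w.WorkerUID, w.Deleted)
			}
		case <-stopCh:
			return nil
		}
	}
}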

internal/hypervisor/backend/single_node/single_node_backend.go

Lines changed: 13 additions & 7 deletions
@@ -5,6 +5,7 @@ import (
 	"sync"
 	"time"

+	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/api"
 	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework"
 	"k8s.io/klog/v2"
 )

@@ -16,7 +17,7 @@ type SingleNodeBackend struct {
 	workers           map[string]*WorkerState // worker UID -> state
 	stopCh            chan struct{}
 	stopOnce          sync.Once
-	workerCh          chan []string
+	workerCh          chan []*api.WorkerInfo
 	workerChCloseOnce sync.Once
 	workerStopCh      chan struct{}
 	workerStopOnce    sync.Once

@@ -130,10 +131,10 @@ func (b *SingleNodeBackend) periodicWorkerDiscovery() {
 	}
 }

-func (b *SingleNodeBackend) ListAndWatchWorkers() (<-chan []string, <-chan struct{}, error) {
+func (b *SingleNodeBackend) ListAndWatchWorkers() (<-chan []*api.WorkerInfo, <-chan struct{}, error) {
 	// Initialize channels if not already created
 	if b.workerCh == nil {
-		b.workerCh = make(chan []string, 1)
+		b.workerCh = make(chan []*api.WorkerInfo, 1)
 		b.workerStopCh = make(chan struct{})
 	}

@@ -148,9 +149,11 @@ func (b *SingleNodeBackend) ListAndWatchWorkers() (<-chan []string, <-chan struc

 	// Send initial list
 	b.mu.RLock()
-	workers := make([]string, 0, len(b.workers))
+	workers := make([]*api.WorkerInfo, 0, len(b.workers))
 	for workerUID := range b.workers {
-		workers = append(workers, workerUID)
+		workers = append(workers, &api.WorkerInfo{
+			WorkerUID: workerUID,
+		})
 	}
 	b.mu.RUnlock()

@@ -179,9 +182,12 @@ func (b *SingleNodeBackend) ListAndWatchWorkers() (<-chan []string, <-chan struc
 			b.discoverWorkers()

 			b.mu.RLock()
-			workers := make([]string, 0, len(b.workers))
+			workers := make([]*api.WorkerInfo, 0, len(b.workers))
 			for workerUID := range b.workers {
-				workers = append(workers, workerUID)
+				workers = append(workers, &api.WorkerInfo{
+					WorkerUID:        workerUID,
+					AllocatedDevices: []string{"dummy"},
+				})
 			}
 			b.mu.RUnlock()
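
Both backends create workerCh with a buffer of one so the initial snapshot can be queued before a receiver attaches. The send path is not shown in these hunks; one common pattern for a single producer with a 1-buffered channel is to evict the stale snapshot so a slow consumer only ever sees the latest list. A hypothetical sketch, not code from this commit:

// publishLatest assumes a single producer goroutine; it discards any
// unconsumed snapshot so the buffer always holds the newest list.
func publishLatest(ch chan []*api.WorkerInfo, latest []*api.WorkerInfo) {
	select {
	case ch <- latest: // fast path: buffer slot was free
	default:
		select {
		case <-ch: // discard the stale snapshot
		default: // consumer grabbed it first; slot is free now
		}
		ch <- latest
	}
}

The []string{"dummy"} placeholder in AllocatedDevices suggests the single-node discovery path does not yet track real device bindings.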

internal/hypervisor/framework/framework.go

Lines changed: 4 additions & 4 deletions
@@ -56,8 +56,8 @@ type WorkerController interface {
 	// Returns map keyed by device UUID, then by worker UID, then by process ID
 	GetWorkerMetrics() (map[string]map[string]map[string]*api.WorkerMetrics, error)

-	// ListWorkers returns list of all worker UIDs
-	ListWorkers() ([]string, error)
+	// ListWorkers returns list of all worker infos
+	ListWorkers() ([]*api.WorkerInfo, error)
 }

 type QuotaController interface {

@@ -79,9 +79,9 @@ type Backend interface {
 	Stop() error

 	// ListAndWatchWorkers gets GPU workers from the workload orchestration platform
-	// Returns a channel that receives worker UID lists and a stop channel
+	// Returns a channel that receives worker info lists and a stop channel
 	// The channel should be closed when Stop() is called
-	ListAndWatchWorkers() (<-chan []string, <-chan struct{}, error)
+	ListAndWatchWorkers() (<-chan []*api.WorkerInfo, <-chan struct{}, error)

 	// GetWorkerToProcessMap links workers to actual running process list on OS
 	GetWorkerToProcessMap() (map[string][]string, error)
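
Because Backend and WorkerController change signatures in lockstep, compile-time assertions catch any implementation that missed the migration. A sketch; placement inside each implementation's own package is assumed, since only file paths appear in this commit:

// Compile-time interface checks (hypothetical, one per package).
// The build fails if a method signature drifts from the interface.
var _ framework.Backend = (*KubeletBackend)(nil)         // in the kubernetes backend package
var _ framework.Backend = (*SingleNodeBackend)(nil)      // in the single-node backend package
var _ framework.WorkerController = (*WorkerController)(nil) // in the worker package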

internal/hypervisor/metrics/metrics.go

Lines changed: 4 additions & 4 deletions
@@ -151,17 +151,17 @@ func (h *HypervisorMetricsRecorder) RecordWorkerMetrics(writer io.Writer) {
 		return
 	}

-	workerUIDs, err := h.workerController.ListWorkers()
+	workerInfos, err := h.workerController.ListWorkers()
 	if err != nil {
 		return
 	}

 	// Get worker allocations for metadata
 	workerAllocations := make(map[string]*api.WorkerAllocation)
-	for _, workerUID := range workerUIDs {
-		allocation, err := h.workerController.GetWorkerAllocation(workerUID)
+	for _, worker := range workerInfos {
+		allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID)
 		if err == nil && allocation != nil {
-			workerAllocations[workerUID] = allocation
+			workerAllocations[worker.WorkerUID] = allocation
 		}
 	}

internal/hypervisor/server/handlers/legacy.go

Lines changed: 7 additions & 7 deletions
@@ -49,8 +49,8 @@ func (h *LegacyHandler) HandleGetLimiter(c *gin.Context) {
 	}

 	limiterInfos := make([]api.LimiterInfo, 0, len(workers))
-	for _, workerUID := range workers {
-		allocation, err := h.workerController.GetWorkerAllocation(workerUID)
+	for _, worker := range workers {
+		allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID)
 		if err != nil || allocation == nil {
 			continue
 		}

@@ -72,7 +72,7 @@ func (h *LegacyHandler) HandleGetLimiter(c *gin.Context) {
 		}

 		limiterInfos = append(limiterInfos, api.LimiterInfo{
-			WorkerUID: workerUID,
+			WorkerUID: worker.WorkerUID,
 			Requests:  requests,
 			Limits:    limits,
 		})

@@ -91,8 +91,8 @@ func (h *LegacyHandler) HandleTrap(c *gin.Context) {
 	}

 	snapshotCount := 0
-	for _, workerUID := range workers {
-		allocation, err := h.workerController.GetWorkerAllocation(workerUID)
+	for _, worker := range workers {
+		allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID)
 		if err != nil || allocation == nil {
 			continue
 		}

@@ -123,8 +123,8 @@ func (h *LegacyHandler) HandleGetPods(c *gin.Context) {
 	}

 	pods := make([]api.PodInfo, 0)
-	for _, workerUID := range workers {
-		allocation, err := h.workerController.GetWorkerAllocation(workerUID)
+	for _, worker := range workers {
+		allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID)
 		if err != nil || allocation == nil {
 			continue
 		}
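
The same resolve-and-skip loop now appears in HandleGetLimiter, HandleTrap, and HandleGetPods. A hypothetical helper that factors it out, not part of this commit and assuming GetWorkerAllocation as used above:

// forEachAllocation runs fn for every worker whose allocation resolves.
// Hypothetical refactor sketch.
func forEachAllocation(
	wc framework.WorkerController,
	workers []*api.WorkerInfo,
	fn func(w *api.WorkerInfo, alloc *api.WorkerAllocation),
) {
	for _, w := range workers {
		alloc, err := wc.GetWorkerAllocation(w.WorkerUID)
		if err != nil || alloc == nil {
			continue // skip workers without a resolvable allocation
		}
		fn(w, alloc)
	}
}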

internal/hypervisor/server/handlers/worker.go

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ func (h *WorkerHandler) HandleGetWorkers(c *gin.Context) {

 	// Get worker details
 	workerDetails := make([]*api.WorkerAllocation, 0, len(workers))
-	for _, workerUID := range workers {
-		allocation, err := h.workerController.GetWorkerAllocation(workerUID)
+	for _, worker := range workers {
+		allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID)
 		if err != nil {
 			continue
 		}

internal/hypervisor/worker/controller.go

Lines changed: 12 additions & 48 deletions
@@ -18,7 +18,7 @@ type WorkerController struct {
 	quotaController framework.QuotaController

 	mu                  sync.RWMutex
-	workers             map[string]bool // worker UID -> exists
+	workers             map[string]*api.WorkerInfo
 	workerWatchStop     chan struct{}
 	workerWatchStopOnce sync.Once
 }

@@ -31,7 +31,7 @@ func NewWorkerController(
 		mode:            mode,
 		backend:         backend,
 		quotaController: quotaController,
-		workers:         make(map[string]bool),
+		workers:         make(map[string]*api.WorkerInfo, 16),
 		workerWatchStop: make(chan struct{}),
 	}
 }

@@ -63,9 +63,8 @@ func (w *WorkerController) Start() error {
 				}
 				// Update worker cache
 				w.mu.Lock()
-				w.workers = make(map[string]bool)
-				for _, workerUID := range workers {
-					w.workers[workerUID] = true
+				for _, worker := range workers {
+					w.workers[worker.WorkerUID] = worker
 				}
 				w.mu.Unlock()
 				klog.V(4).Infof("Updated worker list: %d workers", len(workers))

@@ -271,50 +270,15 @@ func (w *WorkerController) GetWorkerMetrics() (map[string]map[string]map[string]
 	return result, nil
 }

-func (w *WorkerController) ListWorkers() ([]string, error) {
-	// First check cache (updated by ListAndWatchWorkers)
+func (w *WorkerController) ListWorkers() ([]*api.WorkerInfo, error) {
 	w.mu.RLock()
-	cachedWorkers := make([]string, 0, len(w.workers))
-	for workerUID := range w.workers {
-		cachedWorkers = append(cachedWorkers, workerUID)
-	}
-	w.mu.RUnlock()
-
-	// If cache has workers, return them
-	if len(cachedWorkers) > 0 {
-		return cachedWorkers, nil
-	}
-
-	// If cache is empty, directly query device allocations to get immediate results
-	// This ensures we hit the key logic path and return accurate results
-	allocations, err := w.deviceController.GetDeviceAllocations("")
-	if err != nil {
-		return cachedWorkers, err
-	}
-
-	// Extract unique worker UIDs from allocations
-	workerSet := make(map[string]bool)
-	for _, allocation := range allocations {
-		workerUID := allocation.WorkerInfo.WorkerUID
-		if workerUID == "" {
-			workerUID = allocation.WorkerInfo.PodUID
-		}
-		if workerUID != "" {
-			workerSet[workerUID] = true
+	defer w.mu.RUnlock()
+	workerSnapshot := make([]*api.WorkerInfo, 0, len(w.workers))
+	for _, worker := range w.workers {
+		if worker.Deleted {
+			continue
 		}
+		workerSnapshot = append(workerSnapshot, worker)
 	}
-
-	// Update cache with discovered workers
-	w.mu.Lock()
-	for workerUID := range workerSet {
-		w.workers[workerUID] = true
-	}
-	w.mu.Unlock()
-
-	// Return list of workers
-	workers := make([]string, 0, len(workerSet))
-	for workerUID := range workerSet {
-		workers = append(workers, workerUID)
-	}
-	return workers, nil
+	return workerSnapshot, nil
 }
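
The rewritten ListWorkers hands out pointers straight from the cache under a read lock, so callers share WorkerInfo records with the controller and should treat them as read-only. Dropping the fallback query against device allocations also means ListWorkers no longer needs the write lock, leaving one short read-locked section. If callers ever needed to mutate results, a copying variant would avoid races; a hypothetical sketch:

// listWorkersCopy returns value copies that callers may mutate freely.
// Hypothetical variant, not part of this commit. The copies are
// shallow: the Annotations map inside each record is still shared.
func (w *WorkerController) listWorkersCopy() []api.WorkerInfo {
	w.mu.RLock()
	defer w.mu.RUnlock()
	out := make([]api.WorkerInfo, 0, len(w.workers))
	for _, worker := range w.workers {
		if worker.Deleted {
			continue // honor tombstones, same as ListWorkers
		}
		out = append(out, *worker)
	}
	return out
}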

internal/indexallocator/indexallocator.go

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ func (s *IndexAllocator) AssignIndex(podName string) (int, error) {
 	}
 	// Atomic increment and wrap around
 	next := atomic.AddInt64(&s.currentIndex, 1)
-	index := int((next-1)%IndexRangeEnd) + IndexRangeStart
+	index := int((next-1)%constants.IndexRangeEnd) + constants.IndexRangeStart
 	log.FromContext(s.ctx).Info("assigned index successfully", "podName", podName, "index", index)
 	return index, nil
 }
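
The fix qualifies IndexRangeEnd and IndexRangeStart with the constants package; the wrap-around arithmetic is unchanged. A standalone sketch of that arithmetic with assumed values (the real constants live in internal/constants and are not shown in this commit):

package main

import (
	"fmt"
	"sync/atomic"
)

// Illustrative values only; not the real constants.
const (
	indexRangeStart = 1
	indexRangeEnd   = 8
)

func main() {
	var current int64
	for i := 0; i < 10; i++ {
		next := atomic.AddInt64(&current, 1)
		// (next-1) % indexRangeEnd cycles 0..7, so index cycles 1..8
		// and wraps back to 1 on the ninth assignment.
		index := int((next-1)%indexRangeEnd) + indexRangeStart
		fmt.Println(index)
	}
}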
