diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 76d9f0b5..55312752 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -77,7 +77,7 @@ jobs: build-args: | GO_LDFLAGS=-X 'github.com/NexusGPU/tensor-fusion/internal/version.BuildVersion=${{ needs.release.outputs.version }}' - publish_node_discovery_image: + publish_hypervisor_image: needs: - release if: needs.release.outputs.published == 'true' || github.event_name == 'workflow_dispatch' @@ -95,7 +95,7 @@ jobs: - id: meta uses: docker/metadata-action@v5 with: - images: tensorfusion/tensor-fusion-node-discovery + images: tensorfusion/tensor-fusion-hypervisor tags: ${{ github.event_name == 'workflow_dispatch' && steps.set_tag.outputs.tag || format('type=semver,pattern={{{{version}}}},value={0}', needs.release.outputs.version) }} - name: Login to DockerHub @@ -104,12 +104,12 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - name: Build and push node discovery + - name: Build and push hypervisor uses: docker/build-push-action@v6 with: context: . 
push: true - file: dockerfile/node-discovery.Dockerfile + file: dockerfile/hypervisor.Dockerfile tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} no-cache: true diff --git a/.gitignore b/.gitignore index fc148c71..54b8a74d 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,13 @@ __debug* vendor logs -*.prof \ No newline at end of file +*.prof + +provider/build + +cmd/hypervisor/hypervisor +*.o + +_obj + +metrics.log \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 954d1d19..0c9c7fa9 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -21,15 +21,18 @@ ] }, { - "name": "Debug Discovery", + "name": "Debug Hypervisor", "type": "go", "request": "launch", "mode": "auto", + "console": "integratedTerminal", "env": { - "HOSTNAME": "mocknode", - "KUBECONFIG": "~/.kube/config", + "KUBECONFIG": "~/.kube/config-local-studio", + "HYPERVISOR_PORT": "8042", + "GPU_NODE_NAME": "ubuntu", }, - "program": "${workspaceFolder}/cmd/nodediscovery/main.go", + "cwd": "${workspaceFolder}", + "program": "${workspaceFolder}/cmd/hypervisor/main.go", }, { "name": "Debug Dev Env Operator", @@ -62,7 +65,8 @@ "ENABLE_WEBHOOKS": "false", "ENABLE_SCHEDULER": "true", "ENABLE_CR_CONTROLLER": "true", - "NVIDIA_OPERATOR_PROGRESSIVE_MIGRATION": "true" + "NVIDIA_OPERATOR_PROGRESSIVE_MIGRATION": "true", + "IMPERSONATE_SERVICE_ACCOUNT": "system:serviceaccount:tensor-fusion-sys:tensor-fusion-sys" }, "args": [ "--metrics-path", "${workspaceFolder}/logs/metrics.log", diff --git a/.vscode/settings.json b/.vscode/settings.json index 5be70139..012463ae 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,11 +12,15 @@ "apimachinery", "apimachineryruntime", "apiruntime", + "apiserver", "apiutil", "automount", "AWSGPU", "batchv", "Biren", + "bubbletea", + "BUILDPLATFORM", + "buildx", "burstable", "Cambricon", "CDNA", @@ -24,6 +28,8 @@ "certgen", "certificaterequests", "certmanager", + "CFLAGS", + "charmbracelet", 
"clientcmd", "clientcmdapi", "clientgoscheme", @@ -45,27 +51,35 @@ "datanode", "deepcopy", "defaultbinder", + "deviceplugin", "dylib", "eastus", "envtest", "essd", "Eventf", + "eventhandlers", "evictable", "featuregate", "finalizer", "Finalizers", "frameworkruntime", + "fsnotify", "FULLTEXT", + "GOBIN", "goconst", "gocyclo", "goerrors", + "golangci", "golint", "Gomega", "gonic", + "GOPATH", "gopsutil", "gorm", "gosec", + "GPGPU", "gpuallocator", + "GPUIDs", "gpunode", "gpunodeclaim", "gpunodeclaims", @@ -86,8 +100,11 @@ "imageutils", "indexallocator", "influxdata", + "Infof", "internalcache", "internalqueue", + "intstr", + "IVSHMEM", "jsonpatch", "karpenter", "karpv", @@ -99,9 +116,12 @@ "kubescheduler", "kubeschedulerconfig", "kustomization", + "libaccelerator", "libcuda", "libnvidia", "lineprotocol", + "lipgloss", + "LOCALBIN", "mapstructure", "metav", "metricsserver", @@ -113,14 +133,19 @@ "nindent", "nodeclaim", "nodeclassref", + "nodelist", "noderesources", "nolint", "NUMA", + "nvdp", "Nvlink", "NVML", "objs", "omitempty", "onsi", + "pids", + "pluginapi", + "podname", "portallocator", "Postable", "printcolumn", @@ -128,11 +153,13 @@ "prometheuses", "prometheusrules", "queuesort", + "Radeon", "RDNA", "readyz", "replicaset", "replicasets", "rolebinding", + "RTXA", "runbook", "runpod", "samber", @@ -145,12 +172,18 @@ "schedv", "serviceaccount", "shirou", + "shmem", "shortuuid", "statefulset", "statefulsets", + "stdbool", + "stddef", + "stdint", + "stdlib", "strategicpatch", "strategicpatches", "stretchr", + "strncpy", "subresource", "Tabler", "tensorfusion", @@ -165,6 +198,8 @@ "testutil", "tflops", "timberio", + "Timeslicing", + "tmpfs", "Tmpl", "tokenreviews", "Tolerations", @@ -173,9 +208,15 @@ "utilerrors", "utilruntime", "vgpu", + "Warningf", "webhookcorev", + "workerstate", "workloadprofiles", "workqueue", "Xlarge" - ] + ], + "files.associations": { + "__locale": "cpp", + "bitset": "cpp" + } } \ No newline at end of file diff --git a/Makefile b/Makefile 
index 87317a95..db0b7056 100644 --- a/Makefile +++ b/Makefile @@ -110,6 +110,26 @@ build: manifests generate fmt vet ## Build manager binary. run: manifests generate fmt vet ## Run a controller from your host. go run ./cmd/main.go +.PHONY: build-provider +build-provider: ## Build accelerator stub library. + $(MAKE) -C provider stub + +.PHONY: build-hypervisor +build-hypervisor: build-provider ## Build hypervisor binary with CGO enabled. + @PROVIDER_DIR=$$(pwd)/provider; \ + CGO_ENABLED=1 \ + CGO_CFLAGS="-I$$PROVIDER_DIR" \ + go build -o bin/hypervisor ./cmd/hypervisor + +.PHONY: build-hypervisor-tui +build-hypervisor-tui: + go build -o bin/hypervisor-tui ./cmd/hypervisor-tui + + +.PHONY: clean-cache +clean-cache: ## Clean Go build cache. + go clean -cache -testcache + # If you wish to build the manager image targeting other platforms you can use the --platform flag. # (i.e. docker build --platform linux/arm64). However, you must enable docker buildKit for it. # More info: https://docs.docker.com/develop/develop-images/build_enhancements/ diff --git a/README.md b/README.md index b327fa0e..346eea2a 100644 --- a/README.md +++ b/README.md @@ -57,30 +57,34 @@ Tensor Fusion is a state-of-the-art **GPU virtualization and pooling solution** - [x] Fractional GPU and flexible oversubscription - [x] Remote GPU sharing with SOTA GPU-over-IP technology, less than 4% performance loss -- [x] GPU VRAM expansion and hot/warm/cold tiering -- [ ] None NVIDIA GPU/NPU vendor support +- [x] GPU VRAM expansion and hot/cold tiering +- [x] None NVIDIA GPU/NPU vendor support ### Pooling & Scheduling & Management - [x] GPU/NPU pool management in Kubernetes -- [x] GPU-first scheduling and allocation, with single TFlops/MB precision -- [x] GPU node auto provisioning/termination +- [x] GPU-first scheduling and allocation, with 1 TFLOPs, 1% Computing, 1 MB precision +- [x] GPU node auto provisioning/termination, Karpenter integration - [x] GPU compaction/bin-packing +- [x] Take full control of 
GPU allocation with precision targeting by vendor, model, device index, and more - [x] Seamless onboarding experience for Pytorch, TensorFlow, llama.cpp, vLLM, Tensor-RT, SGlang and all popular AI training/serving frameworks +- [x] Seamless migration from existing NVIDIA operator and device-plugin stack - [x] Centralized Dashboard & Control Plane - [x] GPU-first autoscaling policies, auto set requests/limits/replicas - [x] Request multiple vGPUs with group scheduling for large models - [x] Support different QoS levels +- [x] Hardware partitioned mode isolation like NVIDIA Dynamic MIG +- [x] Support Kubernetes dynamic resource allocation (DRA) API ### Enterprise Features - [x] GPU live-migration, snapshot and restore GPU context cross cluster - [ ] AI model registry and preloading, build your own private MaaS(Model-as-a-Service) -- [ ] Advanced auto-scaling policies, scale to zero, rebalance of hot GPUs +- [x] Advanced auto-scaling policies, scale to zero, rebalance of hot GPUs - [ ] Advanced observability features, detailed metrics & tracing/profiling of CUDA calls -- [ ] Monetize your GPU cluster by multi-tenancy usage measurement & billing report -- [ ] Enterprise level high availability and resilience, support topology aware scheduling, GPU node auto failover etc. -- [ ] Enterprise level security, complete on-premise deployment support +- [x] Monetize your GPU cluster by multi-tenancy usage measurement & billing report +- [x] Enterprise level high availability and resilience, support topology aware scheduling, GPU node auto failover etc. 
+- [x] Enterprise level security, complete on-premise deployment support - [ ] Enterprise level compliance, SSO/SAML support, advanced audit, ReBAC control, SOC2 and other compliance reports available ### 🗳️ Platform Support diff --git a/api/v1/gpu_types.go b/api/v1/gpu_types.go index d59b747c..6606a4b5 100644 --- a/api/v1/gpu_types.go +++ b/api/v1/gpu_types.go @@ -38,6 +38,10 @@ type GPUStatus struct { UUID string `json:"uuid"` + // +optional + // +kubebuilder:default=soft + IsolationMode IsolationModeType `json:"isolationMode,omitempty"` + // +optional Index *int32 `json:"index,omitempty"` @@ -61,6 +65,16 @@ type GPUStatus struct { // +optional RunningApps []*RunningAppDetail `json:"runningApps,omitempty"` + + // +optional + // PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + // Reported from discovery, each template has fixed resource allocation + PartitionTemplates []PartitionTemplate `json:"partitionTemplates,omitempty"` + + // +optional + // AllocatedPartitions tracks allocated partitions on this GPU + // Key is partitionUUID, value contains template info and allocated resources + AllocatedPartitions map[string]AllocatedPartition `json:"allocatedPartitions,omitempty"` } // +kubebuilder:validation:Enum=tensor-fusion;nvidia-device-plugin @@ -94,6 +108,44 @@ type PodGPUInfo struct { QoS QoSLevel `json:"qos,omitempty"` } +// PartitionTemplate represents a hardware partition template (e.g., MIG profile) +// Only stores template ID and name in GPU status. Detailed resource information +// is stored in public GPU info config. 
+type PartitionTemplate struct { + // TemplateID is the unique identifier for this partition template (e.g., "1g.24gb", "4g.94gb") + TemplateID string `json:"templateId"` + + // Name is a human-readable name for this template + Name string `json:"name"` +} + +// AllocatedPartition represents an allocated partition on a GPU +// Key in AllocatedPartitions map is podUID +type AllocatedPartition struct { + // TemplateID is the template used to create this partition + TemplateID string `json:"templateId"` + + // PodUID is the UID of the pod using this partition (used as map key) + PodUID string `json:"podUid"` + + // PodName is the name of the pod using this partition + PodName string `json:"podName"` + + // Namespace is the namespace of the pod using this partition + Namespace string `json:"namespace"` + + // AllocatedAt is when this partition was allocated + AllocatedAt metav1.Time `json:"allocatedAt"` + + // AllocatedSlotStart is the starting slot position where this partition is allocated + // This is the actual hardware slot position (0-based index) + AllocatedSlotStart *uint32 `json:"allocatedSlotStart,omitempty"` + + // AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated + // The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd) + AllocatedSlotEnd *uint32 `json:"allocatedSlotEnd,omitempty"` +} + // +kubebuilder:validation:Enum=Pending;Provisioning;Running;Unknown;Destroying;Migrating type TensorFusionGPUPhase string diff --git a/api/v1/gpupool_types.go b/api/v1/gpupool_types.go index 78fe7e84..5d3cf8a2 100644 --- a/api/v1/gpupool_types.go +++ b/api/v1/gpupool_types.go @@ -33,6 +33,10 @@ type GPUPoolSpec struct { // +optional DefaultUsingLocalGPU *bool `json:"defaultUsingLocalGPU,omitempty"` + // +optional + // +kubebuilder:default=NVIDIA + Vendor string `json:"vendor,omitempty"` + CapacityConfig *CapacityConfig `json:"capacityConfig,omitempty"` NodeManagerConfig *NodeManagerConfig 
`json:"nodeManagerConfig,omitempty"` @@ -88,12 +92,23 @@ type NodeManagerConfig struct { // +kubebuilder:default="AutoSelect" ProvisioningMode ProvisioningMode `json:"provisioningMode,omitempty"` + // +optional + // +kubebuilder:default=NVIDIA + // In single AI accelerator hardware vendor mode, when default vendor set + // All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + DefaultVendor string `json:"defaultVendor,omitempty"` + // +optional NodeProvisioner *NodeProvisioner `json:"nodeProvisioner,omitempty"` // +optional NodeSelector *corev1.NodeSelector `json:"nodeSelector,omitempty"` + // +optional + // When this field set, the GPU pool will be in multi AI accelerator vendor mode + // each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + MultiVendorNodeSelector map[string]*corev1.NodeSelector `json:"multiVendorNodeSelector,omitempty"` + // +optional NodeCompaction *NodeCompaction `json:"nodeCompaction,omitempty"` diff --git a/api/v1/gpuresourcequota_types.go b/api/v1/gpuresourcequota_types.go index e5ba09b8..322bc9c5 100644 --- a/api/v1/gpuresourcequota_types.go +++ b/api/v1/gpuresourcequota_types.go @@ -194,6 +194,12 @@ type AllocRequest struct { PodMeta metav1.ObjectMeta QoS QoSLevel + + Isolation IsolationModeType + + // PartitionTemplateID is the template ID used for partitioned mode allocation + // This is set by the scheduler when a partition is matched, or read from pod annotation + PartitionTemplateID string } func (p *AllocRequest) Clone() fwk.StateData { diff --git a/api/v1/workloadprofile_types.go b/api/v1/workloadprofile_types.go index 5bd70f0c..57b7dec7 100644 --- a/api/v1/workloadprofile_types.go +++ b/api/v1/workloadprofile_types.go @@ -63,6 +63,11 @@ type WorkloadProfileSpec struct { // How to isolate resources, could be `shared` or `soft` or `hard` or `partitioned` Isolation IsolationModeType `json:"isolation,omitempty"` + // +optional + // PartitionTemplateID 
specifies the partition template ID for partitioned isolation mode + // This is read from pod annotation tensor-fusion.ai/partition if specified + PartitionTemplateID string `json:"partitionTemplateId,omitempty"` + // +optional // GPUModel specifies the required GPU model (e.g., "A100", "H100") GPUModel string `json:"gpuModel,omitempty"` diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index 110155a2..44089a1e 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -77,6 +77,32 @@ func (in *AllocRequest) DeepCopy() *AllocRequest { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *AllocatedPartition) DeepCopyInto(out *AllocatedPartition) { + *out = *in + in.AllocatedAt.DeepCopyInto(&out.AllocatedAt) + if in.AllocatedSlotStart != nil { + in, out := &in.AllocatedSlotStart, &out.AllocatedSlotStart + *out = new(uint32) + **out = **in + } + if in.AllocatedSlotEnd != nil { + in, out := &in.AllocatedSlotEnd, &out.AllocatedSlotEnd + *out = new(uint32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AllocatedPartition. +func (in *AllocatedPartition) DeepCopy() *AllocatedPartition { + if in == nil { + return nil + } + out := new(AllocatedPartition) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *AutoFreeze) DeepCopyInto(out *AutoFreeze) { *out = *in @@ -1350,6 +1376,18 @@ func (in *GPUStatus) DeepCopyInto(out *GPUStatus) { } } } + if in.PartitionTemplates != nil { + in, out := &in.PartitionTemplates, &out.PartitionTemplates + *out = make([]PartitionTemplate, len(*in)) + copy(*out, *in) + } + if in.AllocatedPartitions != nil { + in, out := &in.AllocatedPartitions, &out.AllocatedPartitions + *out = make(map[string]AllocatedPartition, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUStatus. @@ -1602,6 +1640,22 @@ func (in *NodeManagerConfig) DeepCopyInto(out *NodeManagerConfig) { *out = new(corev1.NodeSelector) (*in).DeepCopyInto(*out) } + if in.MultiVendorNodeSelector != nil { + in, out := &in.MultiVendorNodeSelector, &out.MultiVendorNodeSelector + *out = make(map[string]*corev1.NodeSelector, len(*in)) + for key, val := range *in { + var outVal *corev1.NodeSelector + if val == nil { + (*out)[key] = nil + } else { + inVal := (*in)[key] + in, out := &inVal, &outVal + *out = new(corev1.NodeSelector) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } if in.NodeCompaction != nil { in, out := &in.NodeCompaction, &out.NodeCompaction *out = new(NodeCompaction) @@ -1725,6 +1779,21 @@ func (in *Oversubscription) DeepCopy() *Oversubscription { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PartitionTemplate) DeepCopyInto(out *PartitionTemplate) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PartitionTemplate. +func (in *PartitionTemplate) DeepCopy() *PartitionTemplate { + if in == nil { + return nil + } + out := new(PartitionTemplate) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. 
in must be non-nil. func (in *PeriodicalBudget) DeepCopyInto(out *PeriodicalBudget) { *out = *in diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml index a8c2b5a0..afe2df8b 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml @@ -249,6 +249,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The + terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. 
If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. 
{ AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -608,6 +710,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object status: description: GPUPoolStatus defines the observed state of GPUPool. diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml index 50c76bce..84e3ee86 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml @@ -69,6 +69,54 @@ spec: GPUStatus defines the observed state of GPU. NOTE: When new fields added, remember to update syncGPUMetadataAndStatusFromCluster properties: + allocatedPartitions: + additionalProperties: + description: |- + AllocatedPartition represents an allocated partition on a GPU + Key in AllocatedPartitions map is podUID + properties: + allocatedAt: + description: AllocatedAt is when this partition was allocated + format: date-time + type: string + allocatedSlotEnd: + description: |- + AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated + The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd) + format: int32 + type: integer + allocatedSlotStart: + description: |- + AllocatedSlotStart is the starting slot position where this partition is allocated + This is the actual hardware slot position (0-based index) + format: int32 + type: integer + namespace: + description: Namespace is the namespace of the pod using this + partition + type: string + podName: + description: PodName is the name of the pod using this partition + type: string + podUid: + description: PodUID is the UID of the pod using this partition + (used as map key) + type: string + templateId: + description: TemplateID is the template used to create this + partition + type: string + required: + - allocatedAt + - namespace + - podName + - podUid + - templateId + type: object + description: |- 
+ AllocatedPartitions tracks allocated partitions on this GPU + Key is partitionUUID, value contains template info and allocated resources + type: object available: properties: compute: @@ -124,6 +172,14 @@ spec: index: format: int32 type: integer + isolationMode: + default: soft + enum: + - shared + - soft + - hard + - partitioned + type: string message: type: string model: @@ -138,6 +194,28 @@ spec: NUMA node format: int32 type: integer + partitionTemplates: + description: |- + PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + Reported from discovery, each template has fixed resource allocation + items: + description: |- + PartitionTemplate represents a hardware partition template (e.g., MIG profile) + Only stores template ID and name in GPU status. Detailed resource information + is stored in public GPU info config. + properties: + name: + description: Name is a human-readable name for this template + type: string + templateId: + description: TemplateID is the unique identifier for this partition + template (e.g., "1g.24gb", "4g.94gb") + type: string + required: + - name + - templateId + type: object + type: array phase: default: Pending enum: diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml index d80f589b..c43bb82b 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml @@ -315,6 +315,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more 
label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector + terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. 
+ type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -675,6 +777,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object required: - specTemplate diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml index 6fe04c9a..f432f499 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml @@ -466,6 +466,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml index f7fd3820..d22286b2 100644 --- 
a/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml @@ -453,6 +453,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/charts/tensor-fusion/templates/controller-deployment.yaml b/charts/tensor-fusion/templates/controller-deployment.yaml index c16c4aab..ef409a1d 100644 --- a/charts/tensor-fusion/templates/controller-deployment.yaml +++ b/charts/tensor-fusion/templates/controller-deployment.yaml @@ -57,7 +57,7 @@ spec: fieldPath: metadata.namespace # when deploy with AutoSelect mode, GPU node is managed by Kubernetes rather than TensorFusion, thus, need to specify the label selector to generate the GPUNode custom resource - name: INITIAL_GPU_NODE_LABEL_SELECTOR - value: "{{ default "nvidia.com/gpu.present=true" .Values.initialGpuNodeLabelSelector }}" + value: "{{ .Values.initialGpuNodeLabelSelector }}" - name: TSDB_MYSQL_HOST value: "{{ .Values.greptime.host }}" - name: TSDB_MYSQL_PORT diff --git a/charts/tensor-fusion/values-multi-vendor.yaml b/charts/tensor-fusion/values-multi-vendor.yaml new file mode 100644 index 00000000..66233244 --- /dev/null +++ b/charts/tensor-fusion/values-multi-vendor.yaml @@ -0,0 +1 @@ +initialGpuNodeLabelSelector: "" diff --git a/cmd/hypervisor-tui/main.go b/cmd/hypervisor-tui/main.go new file mode 100644 index 00000000..e0e1294a --- /dev/null +++ b/cmd/hypervisor-tui/main.go @@ -0,0 +1,54 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "flag" + "os" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/tui" + tea "github.com/charmbracelet/bubbletea" + "k8s.io/klog/v2" +) + +var ( + host = flag.String("host", "localhost", "Hypervisor server host") + port = flag.Int("port", 8001, "Hypervisor server port") +) + +func main() { + flag.Parse() + klog.InitFlags(nil) + defer klog.Flush() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create HTTP client + client := tui.NewClient(*host, *port) + + // Create TUI model + model := tui.NewModel(ctx, client) + + // Start TUI + p := tea.NewProgram(model, tea.WithAltScreen()) + if _, err := p.Run(); err != nil { + klog.Fatalf("Error running TUI: %v", err) + os.Exit(1) + } +} diff --git a/cmd/hypervisor/main.go b/cmd/hypervisor/main.go new file mode 100644 index 00000000..041f2b5b --- /dev/null +++ b/cmd/hypervisor/main.go @@ -0,0 +1,164 @@ +package main + +import ( + "context" + "flag" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/cmd/hypervisor/shm_init" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/kubernetes" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/single_node" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/metrics" + 
"github.com/NexusGPU/tensor-fusion/internal/hypervisor/server" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/klog/v2" +) + +var ( + acceleratorLibPath = flag.String("accelerator-lib", + "./provider/build/libaccelerator_stub.so", "Path to accelerator library") + isolationMode = flag.String("isolation-mode", "shared", + "Isolation mode: shared, soft, hard, partitioned") + backendType = flag.String("backend-type", "kubernetes", "Backend type: kubernetes, simple") + discoveryInterval = flag.Duration("discovery-interval", + 12*time.Hour, "Device discovery interval") + metricsPath = flag.String("metrics-output-path", "metrics.log", "Path to metrics output file") + + httpPort = flag.Int("port", int(constants.HypervisorDefaultPortNumber), "HTTP port for hypervisor API") +) + +const ( + TFHardwareVendorEnv = "TF_HARDWARE_VENDOR" + TFAcceleratorLibPathEnv = "TF_ACCELERATOR_LIB_PATH" +) + +const ( + MountShmSubcommand = "mount-shm" +) + +func main() { + // Check for subcommands (used inside init container for initializing shared memory of limiter of soft isolation) + if len(os.Args) > 1 && os.Args[1] == MountShmSubcommand { + shm_init.RunMountShm() + return + } + + flag.Parse() + klog.InitFlags(nil) + defer klog.Flush() + + ctx, cancel := context.WithCancel(context.Background()) + + utils.NormalizeKubeConfigEnv() + + // Determine accelerator library path from env var or flag + libPath := *acceleratorLibPath + if envLibPath := os.Getenv(TFAcceleratorLibPathEnv); envLibPath != "" { + libPath = envLibPath + klog.Infof("Using accelerator library path from env: %s", libPath) + } + if vendor := os.Getenv(TFHardwareVendorEnv); vendor != "" { + klog.Infof("Hardware vendor from env: %s", vendor) + } + + // Create and start device controller + deviceController, err := device.NewController(ctx, libPath, *discoveryInterval) + if err != nil { + 
klog.Fatalf("Failed to create device controller: %v", err) + } + if err := deviceController.Start(); err != nil { + klog.Fatalf("Failed to start device manager: %v", err) + } + klog.Info("Device manager started") + + mode := tfv1.IsolationModeType(*isolationMode) + + // initialize data backend and worker controller + var backend framework.Backend + var workerController framework.WorkerController + + switch *backendType { + case "kubernetes": + // Get Kubernetes rest config + var restConfig *rest.Config + kubeconfig := os.Getenv("KUBECONFIG") + if kubeconfig != "" { + restConfig, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + } else { + restConfig, err = rest.InClusterConfig() + } + if err != nil { + klog.Fatalf("Failed to get Kubernetes config: %v", err) + } + + backend, err = kubernetes.NewKubeletBackend(ctx, deviceController, workerController, restConfig) + if err != nil { + klog.Fatalf("Failed to create Kubernetes backend: %v", err) + } + workerController = worker.NewWorkerController(deviceController, mode, backend) + case "simple": + backend = single_node.NewSingleNodeBackend(ctx, deviceController) + workerController = worker.NewWorkerController(deviceController, mode, backend) + default: + klog.Fatalf("Invalid backend type: %s", *backendType) + } + err = workerController.Start() + if err != nil { + klog.Fatalf("Failed to start worker controller: %v", err) + } + defer func() { + _ = workerController.Stop() + }() + klog.Info("Worker controller started") + + // initialize metrics recorder + metricsRecorder := metrics.NewHypervisorMetricsRecorder(ctx, *metricsPath, deviceController, workerController) + metricsRecorder.Start() + klog.Info("Metrics recorder started") + + // initialize and start HTTP server + httpPortNum := *httpPort + if httpPortEnv := os.Getenv(constants.HypervisorPortEnv); httpPortEnv != "" { + httpPortNum, err = strconv.Atoi(httpPortEnv) + if err != nil { + klog.Fatalf("Failed to convert HTTP port from env: %v", err) + } + } + httpServer 
:= server.NewServer(ctx, deviceController, workerController, metricsRecorder, backend, httpPortNum) + go func() { + if err := httpServer.Start(); err != nil && err != http.ErrServerClosed { + klog.Fatalf("Failed to start HTTP server: %v", err) + } + }() + klog.Info("HTTP server started") + + // Wait for interrupt signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + klog.Info("Hypervisor running") + <-sigCh + klog.Info("Stopping hypervisor...") + + // Shutdown HTTP server + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := httpServer.Stop(shutdownCtx); err != nil { + klog.Errorf("Error shutting down HTTP server: %v", err) + } + + cancel() + klog.Info("Hypervisor stopped") +} diff --git a/cmd/hypervisor/shm_init/mount_shm.go b/cmd/hypervisor/shm_init/mount_shm.go new file mode 100644 index 00000000..cd6eea08 --- /dev/null +++ b/cmd/hypervisor/shm_init/mount_shm.go @@ -0,0 +1,91 @@ +package shm_init + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "k8s.io/klog/v2" +) + +// runMountShm handles the "mount-shm" subcommand +func RunMountShm() { + // Create a new flag set for mount-shm subcommand + mountShmFlags := flag.NewFlagSet("mount-shm", flag.ExitOnError) + mountPoint := mountShmFlags.String("mount-point", "", "Mount point directory path (required)") + sizeMB := mountShmFlags.Int("size", 0, "Size in MB (required)") + + klog.InitFlags(nil) + if err := mountShmFlags.Parse(os.Args[2:]); err != nil { + klog.Fatalf("Failed to parse flags: %v", err) + } + + if *mountPoint == "" { + klog.Fatalf("mount-point is required") + } + if *sizeMB <= 0 { + klog.Fatalf("size must be greater than 0") + } + + klog.Infof("mount point: %s", *mountPoint) + klog.Infof("size: %d MB", *sizeMB) + + // Create mount point directory if it doesn't exist + if _, err := os.Stat(*mountPoint); os.IsNotExist(err) { + klog.Infof("create 
mount point directory: %s", *mountPoint) + if err := os.MkdirAll(*mountPoint, 0755); err != nil { + klog.Fatalf("create mount point directory failed: %v", err) + } + } + + // Check if tmpfs is already mounted + mountCmd := exec.Command("mount") + mountOutput, err := mountCmd.Output() + if err != nil { + klog.Fatalf("execute mount command failed: %v", err) + } + + mountInfo := string(mountOutput) + mountPointAbs, err := filepath.Abs(*mountPoint) + if err != nil { + klog.Fatalf("get absolute path failed: %v", err) + } + + expectedMountStr := fmt.Sprintf("on %s type tmpfs", mountPointAbs) + if strings.Contains(mountInfo, expectedMountStr) { + klog.Infof("tmpfs is already mounted on %s", *mountPoint) + } else { + // Mount tmpfs + klog.Infof("mount tmpfs on %s", *mountPoint) + sizeArg := fmt.Sprintf("size=%dM", *sizeMB) + + mountTmpfsCmd := exec.Command("mount", + "-t", "tmpfs", + "-o", fmt.Sprintf("rw,nosuid,nodev,%s", sizeArg), + "tmpfs", + mountPointAbs, + ) + + if err := mountTmpfsCmd.Run(); err != nil { + klog.Fatalf("mount tmpfs failed: %v", err) + } + + klog.Info("mount tmpfs successfully") + } + + // Set directory permissions to 0777 + // Save old umask + oldUmask := syscall.Umask(0) + defer syscall.Umask(oldUmask) + + // Set permissions + if err := os.Chmod(*mountPoint, 0777); err != nil { + klog.Fatalf("set permissions failed: %v", err) + } + + klog.Info("mount-shm completed successfully") +} diff --git a/cmd/main.go b/cmd/main.go index c55a219c..22642b6b 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -20,9 +20,7 @@ import ( "context" "crypto/tls" "flag" - "fmt" "os" - "strings" "time" // Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.) 
@@ -189,7 +187,7 @@ func main() { metricsServerOptions.FilterProvider = filters.WithAuthenticationAndAuthorization } - normalizeKubeConfigEnv() + utils.NormalizeKubeConfigEnv() kc := ctrl.GetConfigOrDie() mgr, err := ctrl.NewManager(kc, ctrl.Options{ Scheme: scheme, @@ -688,19 +686,6 @@ func startWatchGPUInfoChanges(ctx context.Context, gpuInfos *[]config.GpuInfo, g }() } -// only for local development, won't set KUBECONFIG env var in none local environments -func normalizeKubeConfigEnv() { - cfgPath := os.Getenv("KUBECONFIG") - if cfgPath != "" && strings.HasPrefix(cfgPath, "~") { - home, err := os.UserHomeDir() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - _ = os.Setenv("KUBECONFIG", strings.Replace(cfgPath, "~", home, 1)) - } -} - // Setup GreptimeDB connection func setupTimeSeriesDB() *metrics.TimeSeriesDB { timeSeriesDB := &metrics.TimeSeriesDB{} diff --git a/config/crd/bases/tensor-fusion.ai_gpupools.yaml b/config/crd/bases/tensor-fusion.ai_gpupools.yaml index a8c2b5a0..afe2df8b 100644 --- a/config/crd/bases/tensor-fusion.ai_gpupools.yaml +++ b/config/crd/bases/tensor-fusion.ai_gpupools.yaml @@ -249,6 +249,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The + terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. 
+ The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. 
+ This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -608,6 +710,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object status: description: GPUPoolStatus defines the observed state of GPUPool. diff --git a/config/crd/bases/tensor-fusion.ai_gpus.yaml b/config/crd/bases/tensor-fusion.ai_gpus.yaml index 50c76bce..84e3ee86 100644 --- a/config/crd/bases/tensor-fusion.ai_gpus.yaml +++ b/config/crd/bases/tensor-fusion.ai_gpus.yaml @@ -69,6 +69,54 @@ spec: GPUStatus defines the observed state of GPU. 
NOTE: When new fields added, remember to update syncGPUMetadataAndStatusFromCluster properties: + allocatedPartitions: + additionalProperties: + description: |- + AllocatedPartition represents an allocated partition on a GPU + Key in AllocatedPartitions map is podUID + properties: + allocatedAt: + description: AllocatedAt is when this partition was allocated + format: date-time + type: string + allocatedSlotEnd: + description: |- + AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated + The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd) + format: int32 + type: integer + allocatedSlotStart: + description: |- + AllocatedSlotStart is the starting slot position where this partition is allocated + This is the actual hardware slot position (0-based index) + format: int32 + type: integer + namespace: + description: Namespace is the namespace of the pod using this + partition + type: string + podName: + description: PodName is the name of the pod using this partition + type: string + podUid: + description: PodUID is the UID of the pod using this partition + (used as map key) + type: string + templateId: + description: TemplateID is the template used to create this + partition + type: string + required: + - allocatedAt + - namespace + - podName + - podUid + - templateId + type: object + description: |- + AllocatedPartitions tracks allocated partitions on this GPU + Key is partitionUUID, value contains template info and allocated resources + type: object available: properties: compute: @@ -124,6 +172,14 @@ spec: index: format: int32 type: integer + isolationMode: + default: soft + enum: + - shared + - soft + - hard + - partitioned + type: string message: type: string model: @@ -138,6 +194,28 @@ spec: NUMA node format: int32 type: integer + partitionTemplates: + description: |- + PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + Reported from discovery, each template has fixed 
resource allocation + items: + description: |- + PartitionTemplate represents a hardware partition template (e.g., MIG profile) + Only stores template ID and name in GPU status. Detailed resource information + is stored in public GPU info config. + properties: + name: + description: Name is a human-readable name for this template + type: string + templateId: + description: TemplateID is the unique identifier for this partition + template (e.g., "1g.24gb", "4g.94gb") + type: string + required: + - name + - templateId + type: object + type: array phase: default: Pending enum: diff --git a/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml b/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml index d80f589b..c43bb82b 100644 --- a/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml +++ b/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml @@ -315,6 +315,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector + terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. 
+ items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. 
+ items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -675,6 +777,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object required: - specTemplate diff --git a/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml b/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml index 6fe04c9a..f432f499 100644 --- a/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml +++ b/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml @@ -466,6 +466,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml b/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml index f7fd3820..d22286b2 100644 --- a/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml +++ b/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml @@ -453,6 +453,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git 
a/dockerfile/node-discovery.Dockerfile b/dockerfile/hypervisor.Dockerfile similarity index 83% rename from dockerfile/node-discovery.Dockerfile rename to dockerfile/hypervisor.Dockerfile index 09ac6741..e2eae468 100644 --- a/dockerfile/node-discovery.Dockerfile +++ b/dockerfile/hypervisor.Dockerfile @@ -15,6 +15,7 @@ RUN go mod download COPY cmd/ cmd/ COPY api/ api/ COPY internal/ internal/ +COPY provider/ provider/ # Build @@ -22,13 +23,13 @@ COPY internal/ internal/ # was called. For example, if we call make docker-build in a local env which has the Apple Silicon M1 SO # the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore, # by leaving it empty we can ensure that the container and binary shipped on it will have the same platform. -RUN CGO_ENABLED=1 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o nodediscovery cmd/nodediscovery/main.go +RUN CGO_ENABLED=1 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o hypervisor cmd/hypervisor/main.go -# Use distroless as minimal base image to package the nodediscovery binary +# Use distroless as minimal base image to package the hypervisor binary # Refer to https://github.com/GoogleContainerTools/distroless for more details FROM ubuntu:24.04 WORKDIR / -COPY --from=builder /workspace/nodediscovery . +COPY --from=builder /workspace/hypervisor . 
USER 65532:65532 -ENTRYPOINT ["/nodediscovery"] +ENTRYPOINT ["/hypervisor"] diff --git a/go.mod b/go.mod index 7a806600..27a14399 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,13 @@ require ( github.com/NVIDIA/go-nvml v0.13.0-1 github.com/aliyun/alibaba-cloud-sdk-go v1.63.107 github.com/aws/aws-sdk-go-v2 v1.40.0 - github.com/aws/aws-sdk-go-v2/service/ec2 v1.274.0 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.275.0 github.com/aws/smithy-go v1.23.2 github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a + github.com/charmbracelet/bubbles v0.21.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-contrib/gzip v1.2.5 github.com/gin-gonic/gin v1.11.0 github.com/go-sql-driver/mysql v1.9.3 @@ -27,6 +31,7 @@ require ( go.uber.org/zap v1.27.1 golang.org/x/time v0.14.0 gomodules.xyz/jsonpatch/v2 v2.5.0 + google.golang.org/grpc v1.77.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.1 @@ -39,6 +44,7 @@ require ( k8s.io/component-helpers v0.34.2 k8s.io/klog/v2 v2.130.1 k8s.io/kube-scheduler v0.34.2 + k8s.io/kubelet v0.34.2 k8s.io/kubernetes v1.34.2 k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 sigs.k8s.io/controller-runtime v0.22.4 @@ -53,10 +59,12 @@ require ( github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.14 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.14 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.14 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/bytedance/gopkg v0.1.3 // 
indirect @@ -64,15 +72,19 @@ require ( github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/ansi v0.10.1 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/gin-contrib/sse v1.1.0 // indirect @@ -119,12 +131,18 @@ require ( github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/moby/term v0.5.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + 
github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect @@ -136,6 +154,8 @@ require ( github.com/prometheus/procfs v0.17.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.55.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect github.com/spf13/cobra v1.10.1 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stoewer/go-strcase v1.3.1 // indirect @@ -143,11 +163,12 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect github.com/x448/float16 v0.8.4 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.etcd.io/etcd/api/v3 v3.6.4 // indirect go.etcd.io/etcd/client/pkg/v3 v3.6.4 // indirect go.etcd.io/etcd/client/v3 v3.6.4 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect @@ -164,15 +185,14 @@ require ( golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.47.0 // indirect - golang.org/x/oauth2 v0.31.0 // indirect + golang.org/x/oauth2 v0.32.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/tools v0.38.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 // indirect - 
google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 // indirect - google.golang.org/grpc v1.75.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect @@ -186,7 +206,6 @@ require ( k8s.io/dynamic-resource-allocation v0.34.0 // indirect k8s.io/kms v0.34.2 // indirect k8s.io/kube-openapi v0.0.0-20250905212525-66792eed8611 // indirect - k8s.io/kubelet v0.34.0 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.33.0 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect diff --git a/go.sum b/go.sum index 4ce09d77..c4dbf19d 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/aliyun/alibaba-cloud-sdk-go v1.63.107 h1:qagvUyrgOnBIlVRQWOyCZGVKUIYb github.com/aliyun/alibaba-cloud-sdk-go v1.63.107/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= github.com/aws/aws-sdk-go-v2 v1.40.0 h1:/WMUA0kjhZExjOQN2z3oLALDREea1A7TobfuiBrKlwc= @@ -30,8 +32,8 @@ github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.14 h1:PZHqQACxYb8mYgms4 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.14/go.mod h1:VymhrMJUWs69D8u0/lZ7jSB6WgaG/NqHi3gX0aYf6U0= 
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.14 h1:bOS19y6zlJwagBfHxs0ESzr1XCOU2KXJCWcq3E2vfjY= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.14/go.mod h1:1ipeGBMAxZ0xcTm6y6paC2C/J6f6OO7LBODV9afuAyM= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.274.0 h1:Q2+WD4KSVRkd27QxD9I30nM3O7B4WYwE+ua5dm2NJY0= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.274.0/go.mod h1:QrV+/GjhSrJh6MRRuTO6ZEg4M2I0nwPakf0lZHSrE1o= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.275.0 h1:ymusjrsOjrcVBQNQXYFIQEHJIJ17/m+VoDSmWIMjGe0= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.275.0/go.mod h1:QrV+/GjhSrJh6MRRuTO6ZEg4M2I0nwPakf0lZHSrE1o= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.14 h1:FIouAnCE46kyYqyhs0XEBDFFSREtdnr8HQuLPQPLCrY= @@ -40,6 +42,10 @@ github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a h1:qstXCawuAwrgFLoaU1IIYGGFeVKVBkJMVSSSKJXBD14= github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a/go.mod h1:D4OLvXkR+2pp9RKo8Ovjc1Mqnd0qPRW0gz3cjxGSCkA= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= +github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 
h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -54,6 +60,22 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= +github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= +github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/term v0.2.1 
h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= @@ -74,6 +96,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U= github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -242,12 +266,18 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod 
h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE= github.com/mfridman/tparse v0.18.0/go.mod h1:gEvqZTuCgEhPbYk/2lS3Kcxg1GmTxxU7kTC8DvP0i/A= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= @@ -262,6 +292,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod 
h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -294,11 +330,16 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= +github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/shirou/gopsutil 
v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= @@ -348,6 +389,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 h1:S2dVYn90KE98chqDkyE9Z4N61UnQd+KOfgp5Iu53llk= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= @@ -366,8 +409,8 @@ go.etcd.io/etcd/server/v3 v3.6.4 h1:LsCA7CzjVt+8WGrdsnh6RhC0XqCsLkBly3ve5rTxMAU= go.etcd.io/etcd/server/v3 v3.6.4/go.mod h1:aYCL/h43yiONOv0QIR82kH/2xZ7m+IWYjzRmyQfnCAg= go.etcd.io/raft/v3 v3.6.0 h1:5NtvbDVYpnfZWcIHgGRk9DyzkBIXOi8j+DDp1IcnUWQ= go.etcd.io/raft/v3 v3.6.0/go.mod h1:nLvLevg6+xrVtHUmVaTcTz603gQPHfh7kUAwV6YpfGo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 
h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= @@ -432,8 +475,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= -golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -445,6 +488,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= @@ -478,12 +522,12 @@ gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= 
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= -google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 h1:APHvLLYBhtZvsbnpkfknDZ7NyH4z5+ub/I0u8L3Oz6g= -google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1/go.mod h1:xUjFWUnWDpZ/C0Gu0qloASKFb6f8/QXiiXhSPFsD668= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 h1:pmJpJEvT846VzausCQ5d7KreSROcDqmO388w5YbnltA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1/go.mod h1:GmFNa4BdJZ2a8G+wCe9Bg3wwThLrJun751XstdJt5Og= -google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= -google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -542,8 +586,8 @@ k8s.io/kube-openapi v0.0.0-20250905212525-66792eed8611 h1:o4oKOsvSymDkZRsMAPZU7b k8s.io/kube-openapi 
v0.0.0-20250905212525-66792eed8611/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/kube-scheduler v0.34.2 h1:TtLcaXeIpkqgzMr2ch7Ap8Cluq4M182XUDRlnOPDdoc= k8s.io/kube-scheduler v0.34.2/go.mod h1:PTn4QYiSet8/00VQ2qGO/HWdo5iNJlVRCXz/7R3Ut5I= -k8s.io/kubelet v0.34.0 h1:1nZt1Q6Kfx7xCaTS9vnqR9sjZDxf3cRSQkAFCczULmc= -k8s.io/kubelet v0.34.0/go.mod h1:NqbF8ViVettlZbf9hw9DJhubaWn7rGvDDTcLMDm6tQ0= +k8s.io/kubelet v0.34.2 h1:Dl+1uh7xwJr70r+SHKyIpvu6XvzuoPu0uDIC4cqgJUs= +k8s.io/kubelet v0.34.2/go.mod h1:RfwR03iuKeVV7Z1qD9XKH98c3tlPImJpQ3qHIW40htM= k8s.io/kubernetes v1.34.2 h1:WQdDvYJazkmkwSncgNwGvVtaCt4TYXIU3wSMRgvp3MI= k8s.io/kubernetes v1.34.2/go.mod h1:m6pZk6a179pRo2wsTiCPORJ86iOEQmfIzUvtyEF8BwA= k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck= diff --git a/internal/autoscaler/autoscaler_suite_test.go b/internal/autoscaler/autoscaler_suite_test.go index 0595acce..6e9f69fe 100644 --- a/internal/autoscaler/autoscaler_suite_test.go +++ b/internal/autoscaler/autoscaler_suite_test.go @@ -273,7 +273,9 @@ var _ = BeforeSuite(func() { var _ = AfterSuite(func() { By("tearing down the test environment") - allocator.Stop() + if allocator != nil { + allocator.Stop() + } cancel() err := testEnv.Stop() Expect(err).NotTo(HaveOccurred()) diff --git a/internal/autoscaler/autoscaler_test.go b/internal/autoscaler/autoscaler_test.go index 2eba22fb..1055f98e 100644 --- a/internal/autoscaler/autoscaler_test.go +++ b/internal/autoscaler/autoscaler_test.go @@ -91,11 +91,11 @@ var _ = Describe("Autoscaler", func() { // create two workloads pool := tfEnv.GetGPUPool(0) - // with two replias + // with two replicas workload0 := createWorkload(pool, 0, 2) workload0Workers := getWorkers(workload0) key0 := WorkloadID{workload0.Namespace, workload0.Name} - // with one replia + // with one replica workload1 := createWorkload(pool, 1, 1) workload1Workers := getWorkers(workload1) key1 := WorkloadID{workload1.Namespace, workload1.Name} @@ -539,8 
+539,8 @@ func (f *FakeRecommender) Name() string { return "fake" } -func (f *FakeRecommender) Recommend(ctx context.Context, workoad *workload.State) (*recommender.RecResult, error) { - meta.SetStatusCondition(&workoad.Status.Conditions, metav1.Condition{ +func (f *FakeRecommender) Recommend(ctx context.Context, workload *workload.State) (*recommender.RecResult, error) { + meta.SetStatusCondition(&workload.Status.Conditions, metav1.Condition{ Type: constants.ConditionStatusTypeRecommendationProvided, Status: metav1.ConditionTrue, LastTransitionTime: metav1.Now(), @@ -667,7 +667,7 @@ func mockSchedulerLoop(ctx context.Context, cfg *rest.Config) { func scheduleAndStartPod(pod *corev1.Pod, clientset *kubernetes.Clientset) { // simulate scheduling cycle Filter and Reserve - allocRequest, _, err := allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(ctx, pod) Expect(err).To(Succeed()) gpus, err := allocator.Alloc(allocRequest) if err != nil { diff --git a/internal/cloudprovider/pricing/pricing.go b/internal/cloudprovider/pricing/pricing.go index 45dd09bb..65dfccbd 100644 --- a/internal/cloudprovider/pricing/pricing.go +++ b/internal/cloudprovider/pricing/pricing.go @@ -31,6 +31,7 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/types" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" "k8s.io/apimachinery/pkg/api/resource" "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -104,6 +105,9 @@ func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.G tflopsMap[gpuInfo.Model] = completeInfo } + // Load partition templates from config + gpuallocator.LoadPartitionTemplatesFromConfig(*gpuInfos) + initOnce.Do(func() { globalAWSGPUInstanceData = make(map[string]GPUNodeInstanceInfoAndPrice) globalAzureGPUInstanceData = make(map[string]GPUNodeInstanceInfoAndPrice) diff --git 
a/internal/component/component.go b/internal/component/component.go index e3940a15..13446456 100644 --- a/internal/component/component.go +++ b/internal/component/component.go @@ -170,7 +170,7 @@ func calculateDesiredUpdatedDelta(total int, updatedSize int, batchPercentage in currentBatchIndex = newUpdateProgress / batchPercentage desiredSize = min((currentBatchIndex+1)*int32(batchSize), int32(total)) delta = desiredSize - int32(updatedSize) - // if rolling udpate policy changed or new nodes were added during update, we need to update progress + // if rolling update policy changed or new nodes were added during update, we need to update progress if delta < 0 { newUpdateProgress = min(newUpdateProgress+batchPercentage, 100) } else { diff --git a/internal/component/hypervisor.go b/internal/component/hypervisor.go index b33d03c8..55f9bba2 100644 --- a/internal/component/hypervisor.go +++ b/internal/component/hypervisor.go @@ -88,7 +88,7 @@ func (h *Hypervisor) GetResourcesInfo(r client.Client, ctx context.Context, pool } key := client.ObjectKey{ Namespace: utils.CurrentNamespace(), - Name: fmt.Sprintf("hypervisor-%s", node.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", node.Name), } pod := &corev1.Pod{} err := r.Get(ctx, key, pod) diff --git a/internal/config/gpu_info.go b/internal/config/gpu_info.go index f05bace1..830548b8 100644 --- a/internal/config/gpu_info.go +++ b/internal/config/gpu_info.go @@ -10,6 +10,49 @@ type GpuInfo struct { CostPerHour float64 `json:"costPerHour"` Fp16TFlops resource.Quantity `json:"fp16TFlops"` FullModelName string `json:"fullModelName"` + + // PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + // Only applicable for GPUs that support hardware partitioning + PartitionTemplates []PartitionTemplateInfo `json:"partitionTemplates,omitempty"` + + // MaxPartitions is the maximum number of partitions this GPU can support (e.g., 7 for MIG) + MaxPartitions uint32 `json:"maxPartitions,omitempty"` + + // 
MaxPlacementSlots is the maximum number of placement slots this GPU can support (e.g., 8 for NVIDIA MIG) + MaxPlacementSlots uint32 `json:"maxPlacementSlots,omitempty"` +} + +// PartitionTemplateInfo contains detailed resource information for a partition template +type PartitionTemplateInfo struct { + // TemplateID is the unique identifier for this partition template Profile `19` for 1g.10gb in A100 + TemplateID string `json:"templateId"` + + // TemplateID is the unique identifier (e.g., "1g.24gb", "4g.94gb") + Name string `json:"name"` + + // MemoryGigabytes is the memory allocated to this partition in gigabytes + MemoryGigabytes uint64 `json:"memoryGigabytes"` + + // ComputePercent is the percent of sliced GPU (0-100) + ComputePercent float64 `json:"computePercent"` + + // Description provides additional information about this template + Description string `json:"description,omitempty"` + + // MaxPartition for this single template, eg. 1g.10gb+me can only be allocate once + MaxPartition uint32 `json:"maxPartition"` + + // The placement limit for this template, use a bitmask to represent the placement limit + // e.g. 
sudo nvidia-smi mig -i 0 -lgipp + // GPU 0 Profile ID 19 Placements: {0,1,2,3,4,5,6}:1 + // GPU 0 Profile ID 20 Placements: {0,1,2,3,4,5,6}:1 + // GPU 0 Profile ID 15 Placements: {0,2,4,6}:2 + // GPU 0 Profile ID 14 Placements: {0,2,4}:2 + // GPU 0 Profile ID 9 Placements: {0,4}:4 + // GPU 0 Profile ID 5 Placement : {0}:4 + // GPU 0 Profile ID 0 Placement : {0}:8 + PlacementLimit []uint32 `json:"placementLimit"` + PlacementOffSet uint32 `json:"placementOffSet"` } func MockGpuInfo() *[]GpuInfo { diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 557fdabd..3a2dd406 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -83,7 +83,13 @@ const ( // GPUModelAnnotation specifies the required GPU model (e.g., "A100", "H100") GPUModelAnnotation = Domain + "/gpu-model" // GPU ID list is assigned by scheduler, should not specified by user - GPUDeviceIDsAnnotation = Domain + "/gpu-ids" + GPUDeviceIDsAnnotation = Domain + "/gpu-ids" + // User can specify the partition name to designate the partition template to use, e.g. 1g.20gb+me + // TODO: parse and pre-set in scheduler plugin to avoid find matched partition. 
+ PartitionNameAnnotation = Domain + "/partition" + // PartitionTemplateIDAnnotation is the partition UUID assigned to a pod in partitioned mode + // This is read by accelerator.c to mock slice GPU like MIG does + PartitionTemplateIDAnnotation = Domain + "/partition-id" DedicatedGPUAnnotation = Domain + "/dedicated-gpu" SetPendingOwnedWorkloadAnnotation = Domain + "/pending-owned-workload" PricingAnnotation = Domain + "/hourly-pricing" @@ -233,3 +239,11 @@ const DefaultEvictionProtectionPriceRatio = 1.2 const NodeCriticalPriorityClassName = "system-node-critical" const KarpenterNodeClaimKind = "NodeClaim" const KarpenterNodePoolKind = "NodePool" + +// Vendor label key for multi-vendor support +const AcceleratorLabelVendor = Domain + "/hardware-vendor" + +const ( + IndexRangeStart = 1 + IndexRangeEnd = 512 +) diff --git a/internal/constants/env.go b/internal/constants/env.go index 52801324..c5521e68 100644 --- a/internal/constants/env.go +++ b/internal/constants/env.go @@ -136,21 +136,22 @@ const ( // TensorFusion hypervisor related envs const ( - HypervisorPoolNameEnv = "TENSOR_FUSION_POOL_NAME" - PodNameEnv = "POD_NAME" - VectorPodNodeNameEnv = "NODE_NAME" - HypervisorGPUNodeNameEnv = "GPU_NODE_NAME" - HypervisorSchedulingConfigEnv = "TF_HYPERVISOR_SCHEDULING_CONFIG" - HypervisorListenAddrEnv = "API_LISTEN_ADDR" - HypervisorMetricsFormatEnv = "TF_HYPERVISOR_METRICS_FORMAT" - HypervisorMetricsExtraLabelsEnv = "TF_HYPERVISOR_METRICS_EXTRA_LABELS" - HypervisorDetectUsedGPUEnv = "DETECT_IN_USED_GPUS" - HypervisorDevicePluginPathEnv = "DEVICE_PLUGIN_PATH" + HypervisorPoolNameEnv = "TENSOR_FUSION_POOL_NAME" + PodNameEnv = "POD_NAME" + VectorPodNodeNameEnv = "NODE_NAME" + HypervisorGPUNodeNameEnv = "GPU_NODE_NAME" + HypervisorSchedulingConfigEnv = "TF_HYPERVISOR_SCHEDULING_CONFIG" + HypervisorListenAddrEnv = "API_LISTEN_ADDR" + HypervisorMetricsFormatEnv = "TF_HYPERVISOR_METRICS_FORMAT" + HypervisorMetricsExtraLabelsEnv = "TF_HYPERVISOR_METRICS_EXTRA_LABELS" + 
HypervisorDetectUsedGPUEnv = "DETECT_IN_USED_GPUS" + HypervisorDevicePluginPathEnv = "DEVICE_PLUGIN_PATH" + HypervisorKubeletCheckpointPathEnv = "KUBELET_CHECKPOINT_PATH" // Add ptrace capability to hypervisor container, to trace all host PID using GPU SystemPtraceCapability = "SYS_PTRACE" - HypervisorDefaultPortNumber int32 = 8000 + HypervisorDefaultPortNumber int32 = 8001 HypervisorPortName string = "http" // For security enhancement, there are 2 types of endpoints to protect @@ -161,6 +162,10 @@ const ( // but k3s and some K8S distribution may not support, need to find some way to get SA token JWT pub key HypervisorVerifyServiceAccountEnabledEnvVar = "SA_TOKEN_VERIFY_ENABLED" HypervisorVerifyServiceAccountPublicKeyEnvVar = "SA_TOKEN_VERIFY_PUBLIC_KEY" + + // Hardware vendor and accelerator library path for multi-vendor support + TFHardwareVendorEnv = "TF_HARDWARE_VENDOR" + TFAcceleratorLibPathEnv = "TF_ACCELERATOR_LIB_PATH" ) // Node discovery related envs diff --git a/internal/constants/vendors.go b/internal/constants/vendors.go index f72c4636..ba3fc16e 100644 --- a/internal/constants/vendors.go +++ b/internal/constants/vendors.go @@ -70,3 +70,19 @@ var L3VirtualizationSupportedVendors = []map[string]bool{ AcceleratorVendorHuaweiAscendNPU: false, }, } + +// GetAcceleratorLibPath returns the accelerator library path based on vendor +// Vendor string should match constants from internal/constants/vendors.go +func GetAcceleratorLibPath(vendor string) string { + switch vendor { + case AcceleratorVendorNvidia: + return "libaccelerator_nvidia.so" + case AcceleratorVendorAMD: + return "libaccelerator_amd.so" + case AcceleratorVendorHuaweiAscendNPU: + return "libaccelerator_ascend.so" + default: + // Default to stub library for unknown vendors + return "libaccelerator_stub.so" + } +} diff --git a/internal/controller/gpunode_controller.go b/internal/controller/gpunode_controller.go index 4a6c235f..6661faba 100644 --- a/internal/controller/gpunode_controller.go +++ 
b/internal/controller/gpunode_controller.go @@ -30,7 +30,6 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/metrics" "github.com/NexusGPU/tensor-fusion/internal/scheduler/expander" "github.com/NexusGPU/tensor-fusion/internal/utils" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -103,7 +102,7 @@ func (r *GPUNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct poolObj := &tfv1.GPUPool{} err = r.Get(ctx, client.ObjectKey{Name: poolName}, poolObj) if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get tensor-fusion pool, can not create node discovery job, pool: %s", poolName) + return ctrl.Result{}, fmt.Errorf("failed to get tensor-fusion pool, pool: %s", poolName) } // Check if the Kubernetes node exists; if not, the GPUNode should delete itself. @@ -135,15 +134,6 @@ func (r *GPUNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct } } - if err := r.reconcileNodeDiscoveryJob(ctx, node, poolObj); err != nil { - return ctrl.Result{}, err - } - - if node.Status.TotalGPUs == 0 { - log.Info("GPU on this node has not been discovered, wait next loop", "node", node.Name) - return ctrl.Result{}, nil - } - hypervisorName, err := r.reconcileHypervisorPod(ctx, node, poolObj, coreNode) if err != nil { return ctrl.Result{}, err @@ -259,77 +249,6 @@ func (r *GPUNodeReconciler) fetchAllOwnedGPUDevices(ctx context.Context, node *t return gpuList.Items, nil } -func (r *GPUNodeReconciler) reconcileNodeDiscoveryJob( - ctx context.Context, - gpunode *tfv1.GPUNode, - pool *tfv1.GPUPool, -) error { - log := log.FromContext(ctx) - log.Info("starting node discovery job") - - if pool.Spec.ComponentConfig == nil || pool.Spec.ComponentConfig.NodeDiscovery.PodTemplate == nil { - return fmt.Errorf(`missing node discovery pod template in pool spec`) - } - podTmpl := &corev1.PodTemplate{} - err := 
json.Unmarshal(pool.Spec.ComponentConfig.NodeDiscovery.PodTemplate.Raw, podTmpl) - if err != nil { - return fmt.Errorf("unmarshal pod template: %w", err) - } - tmpl := podTmpl.Template - if tmpl.Labels == nil { - tmpl.Labels = map[string]string{} - } - tmpl.Labels[constants.LabelComponent] = constants.ComponentNodeDiscovery - tmpl.Spec.NodeName = gpunode.Name - // allow job to run at any taint Nodes that marked as NoSchedule - tmpl.Spec.Tolerations = append(tmpl.Spec.Tolerations, corev1.Toleration{ - Key: string(corev1.TaintEffectNoSchedule), - Operator: corev1.TolerationOpExists, - }) - tmpl.Spec.EnableServiceLinks = ptr.To(false) - - utils.AddTFNodeDiscoveryConfAfterTemplate(ctx, &tmpl, pool, gpunode.Name, r.CompatibleWithNvidiaContainerToolkit) - - // create node-discovery job - job := &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: getDiscoveryJobName(gpunode.Name), - Namespace: utils.CurrentNamespace(), - Labels: tmpl.Labels, - Annotations: tmpl.Annotations, - }, - Spec: batchv1.JobSpec{ - TTLSecondsAfterFinished: ptr.To[int32](3600 * 10), - Template: tmpl, - }, - } - - if err := r.Get(ctx, client.ObjectKeyFromObject(job), job); err != nil { - if errors.IsNotFound(err) { - if err := ctrl.SetControllerReference(gpunode, job, r.Scheme); err != nil { - return fmt.Errorf("set owner reference %w", err) - } - if err := r.Create(ctx, job); err != nil { - return fmt.Errorf("create node discovery job %w", err) - } - } else { - return fmt.Errorf("create node discovery job %w", err) - } - } - - if job.Status.Failed > 0 { - log.Info("node discovery job failed, update GPU node status to failed", "node", gpunode.Name) - // Update phase to failed, require manual address why it failed and restart of node discovery job - gpunode.Status.Phase = tfv1.TensorFusionGPUNodePhaseFailed - if err := r.Status().Update(ctx, gpunode); err != nil { - return fmt.Errorf("failed to update GPU node status to failed: %w", err) - } - metrics.SetNodeMetrics(gpunode, pool, nil) - } - - 
return nil -} - func (r *GPUNodeReconciler) reconcileHypervisorPod( ctx context.Context, node *tfv1.GPUNode, @@ -344,7 +263,7 @@ func (r *GPUNodeReconciler) reconcileHypervisorPod( key := client.ObjectKey{ Namespace: utils.CurrentNamespace(), - Name: fmt.Sprintf("hypervisor-%s", node.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", node.Name), } currentPod := &corev1.Pod{} @@ -414,7 +333,21 @@ func (r *GPUNodeReconciler) createHypervisorPod( // add must-have tensor-fusion hypervisor manifest log.Info("adding must-have tensor-fusion hypervisor manifest", "node", node.Name) - utils.AddTFHypervisorConfAfterTemplate(ctx, &spec, pool) + utils.AddTFHypervisorConfAfterTemplate(ctx, &spec, pool, r.CompatibleWithNvidiaContainerToolkit) + + // add vendor-specific env vars for multi-vendor support + if node.Labels != nil && node.Labels[constants.AcceleratorLabelVendor] != "" { + vendor := node.Labels[constants.AcceleratorLabelVendor] + acceleratorLibPath := constants.GetAcceleratorLibPath(vendor) + spec.Containers[0].Env = utils.AppendEnvVarsIfNotExists(spec.Containers[0].Env, corev1.EnvVar{ + Name: constants.TFHardwareVendorEnv, + Value: vendor, + }, corev1.EnvVar{ + Name: constants.TFAcceleratorLibPathEnv, + Value: acceleratorLibPath, + }) + log.Info("added vendor env vars to hypervisor pod", "node", node.Name, "vendor", vendor, "libPath", acceleratorLibPath) + } // add scheduling config for hypervisor if pool.Spec.SchedulingConfigTemplate != nil { @@ -495,12 +428,7 @@ func (r *GPUNodeReconciler) SetupWithManager(mgr ctrl.Manager) error { {NamespacedName: client.ObjectKey{Name: obj.GetName()}}, } })). - Owns(&batchv1.Job{}). Owns(&corev1.Pod{}). Owns(&tfv1.GPU{}). 
Complete(r) } - -func getDiscoveryJobName(gpunodeName string) string { - return fmt.Sprintf("node-discovery-%s", gpunodeName) -} diff --git a/internal/controller/gpunode_controller_test.go b/internal/controller/gpunode_controller_test.go index 42ea9d7b..29ea919c 100644 --- a/internal/controller/gpunode_controller_test.go +++ b/internal/controller/gpunode_controller_test.go @@ -23,37 +23,24 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/utils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" - "k8s.io/utils/ptr" ) var _ = Describe("GPUNode Controller", func() { Context("When reconciling gpunodes", func() { - It("should create the node discovery job and the hypervisor pod", func() { + It("should create the hypervisor pod", func() { tfEnv := NewTensorFusionEnvBuilder(). AddPoolWithNodeCount(1). SetGpuCountPerNode(1). Build() gpuNode := tfEnv.GetGPUNode(0, 0) - By("checking that the node discovery job is created") - Eventually(func(g Gomega) { - job := &batchv1.Job{} - g.Expect(k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("node-discovery-%s", gpuNode.Name), - Namespace: utils.CurrentNamespace(), - }, job)).Should(Succeed()) - - g.Expect(job.Spec.TTLSecondsAfterFinished).Should(Equal(ptr.To[int32](3600 * 10))) - }).Should(Succeed()) - By("checking that the hypervisor pod is created") pod := &corev1.Pod{} Eventually(func(g Gomega) { err := k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod) g.Expect(err).ShouldNot(HaveOccurred()) @@ -72,7 +59,7 @@ var _ = Describe("GPUNode Controller", func() { Eventually(func(g Gomega) { newPod := &corev1.Pod{} err := k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: 
utils.CurrentNamespace(), }, newPod) g.Expect(err).ShouldNot(HaveOccurred()) diff --git a/internal/controller/gpupool_controller.go b/internal/controller/gpupool_controller.go index a823ba9f..2d0c2ed7 100644 --- a/internal/controller/gpupool_controller.go +++ b/internal/controller/gpupool_controller.go @@ -408,16 +408,73 @@ func (r *GPUPoolReconciler) reconcilePoolComponents(ctx context.Context, pool *t } func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, pool *tfv1.GPUPool) error { - if pool.Spec.NodeManagerConfig != nil && pool.Spec.NodeManagerConfig.NodeSelector != nil { - hash := utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + nodeManagerConfig := pool.Spec.NodeManagerConfig + if nodeManagerConfig == nil { + return nil + } + + // Handle MultiVendorNodeSelector mode + if len(nodeManagerConfig.MultiVendorNodeSelector) > 0 { + hash := utils.GetObjectHash(nodeManagerConfig.MultiVendorNodeSelector) + if poolSelectorChangeMap[pool.Name] == hash { + return nil + } + + // hash has changed, or first reconcile, should check all k8s nodes + nodes := &corev1.NodeList{} + if err := r.List(ctx, nodes); err != nil { + return err + } + for _, node := range nodes.Items { + // skip no label or deleting nodes + if node.Labels == nil || !node.DeletionTimestamp.IsZero() { + continue + } + // Loop through vendor keys, when any key matched, set vendor label and break + vendorMatched := false + for vendor, nodeSelector := range nodeManagerConfig.MultiVendorNodeSelector { + if nodeSelector == nil { + continue + } + matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, nodeSelector) + if err != nil { + return err + } + if matches { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, vendor); err != nil { + return err + } + vendorMatched = true + break + } + } + // If no vendor matched but node was previously matched, remove vendor label + if !vendorMatched && node.Labels[constants.AcceleratorLabelVendor] != "" { + 
if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, ""); err != nil { + return err + } + } + } + poolSelectorChangeMap[pool.Name] = hash + return nil + } + + // Handle default NodeSelector mode + if nodeManagerConfig.NodeSelector != nil { + hash := utils.GetObjectHash(nodeManagerConfig.NodeSelector) if poolSelectorChangeMap[pool.Name] == hash { return nil } + // Determine default vendor: use defaultVendor if set, otherwise NVIDIA + defaultVendor := constants.AcceleratorVendorNvidia + if nodeManagerConfig.DefaultVendor != "" { + defaultVendor = nodeManagerConfig.DefaultVendor + } + // hash has changed, or first reconcile, should check all k8s nodes nodes := &corev1.NodeList{} - selectors := utils.GetInitialGPUNodeSelector() - if err := r.List(ctx, nodes, client.MatchingLabels{selectors[0]: selectors[1]}); err != nil { + if err := r.List(ctx, nodes); err != nil { return err } for _, node := range nodes.Items { @@ -425,12 +482,12 @@ func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, poo if node.Labels == nil || !node.DeletionTimestamp.IsZero() { continue } - matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, pool.Spec.NodeManagerConfig.NodeSelector) + matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, nodeManagerConfig.NodeSelector) if err != nil { return err } if matches { - if err := UpdateK8SNodeSelectorHash(ctx, r.Client, &node, hash); err != nil { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, defaultVendor); err != nil { return err } } @@ -441,9 +498,9 @@ func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, poo return nil } -func UpdateK8SNodeSelectorHash(ctx context.Context, k8sClient client.Client, node *corev1.Node, hash string) error { - // skip nodes that already injected the hash - if node.Labels[constants.LabelNodeSelectorHash] == hash { +func UpdateK8SNodeSelectorHashAndVendor(ctx context.Context, k8sClient client.Client, node 
*corev1.Node, hash string, vendor string) error { + // skip nodes that already have the same hash and vendor + if node.Labels[constants.LabelNodeSelectorHash] == hash && node.Labels[constants.AcceleratorLabelVendor] == vendor { return nil } // update label to trigger the GPUNode reconcile @@ -452,7 +509,15 @@ func UpdateK8SNodeSelectorHash(ctx context.Context, k8sClient client.Client, nod if err := k8sClient.Get(ctx, client.ObjectKey{Name: node.Name}, latest); err != nil { return err } + if latest.Labels == nil { + latest.Labels = make(map[string]string) + } latest.Labels[constants.LabelNodeSelectorHash] = hash + if vendor != "" { + latest.Labels[constants.AcceleratorLabelVendor] = vendor + } else { + delete(latest.Labels, constants.AcceleratorLabelVendor) + } return k8sClient.Update(ctx, latest) }); err != nil { return err diff --git a/internal/controller/gpupool_controller_test.go b/internal/controller/gpupool_controller_test.go index 422a140c..caf85f6f 100644 --- a/internal/controller/gpupool_controller_test.go +++ b/internal/controller/gpupool_controller_test.go @@ -429,7 +429,7 @@ func verifyHypervisorPodHash(gpuNode *tfv1.GPUNode, hash string) { Eventually(func(g Gomega) { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -463,7 +463,7 @@ func verifyHypervisorPodHashConsistently(gpuNode *tfv1.GPUNode, hash string) { Consistently(func(g Gomega) { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -486,7 +486,7 @@ func 
verifyAllHypervisorPodHash(tfEnv *TensorFusionEnv, hash string) { for _, gpuNode := range nodeList.Items { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -552,7 +552,7 @@ func verifyAllHypervisorPodHashConsistently(tfEnv *TensorFusionEnv, hash string) for _, gpuNode := range nodeList.Items { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) diff --git a/internal/controller/node_controller.go b/internal/controller/node_controller.go index d8908847..67723625 100644 --- a/internal/controller/node_controller.go +++ b/internal/controller/node_controller.go @@ -32,11 +32,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - "sigs.k8s.io/controller-runtime/pkg/event" - "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" - "sigs.k8s.io/controller-runtime/pkg/reconcile" schedulingcorev1 "k8s.io/component-helpers/scheduling/corev1" ) @@ -115,6 +112,14 @@ func (r *NodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl. 
} } + // If node changed to other AI accelerator hardware vendor, update gpuNode label vendor and trigger hypervisor update + if gpuNode.Labels[constants.AcceleratorLabelVendor] != node.Labels[constants.AcceleratorLabelVendor] { + gpuNode.Labels[constants.AcceleratorLabelVendor] = node.Labels[constants.AcceleratorLabelVendor] + if err := r.Update(ctx, gpuNode); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update GPU node vendor: %w", err) + } + } + if !node.DeletionTimestamp.IsZero() { log.Info("GPU node is being deleted, mark related GPUNode resource as destroying", "node", node.Name) gpuNode.Status.Phase = tfv1.TensorFusionGPUNodePhaseDestroying @@ -125,9 +130,14 @@ func (r *NodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl. } // update k8s node hash - hash := utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + hash := "" + if len(pool.Spec.NodeManagerConfig.MultiVendorNodeSelector) > 0 { + hash = utils.GetObjectHash(pool.Spec.NodeManagerConfig.MultiVendorNodeSelector) + } else { + hash = utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + } if node.Labels[constants.LabelNodeSelectorHash] != hash { - if err := UpdateK8SNodeSelectorHash(ctx, r.Client, node, hash); err != nil { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, node, hash, node.Labels[constants.AcceleratorLabelVendor]); err != nil { return ctrl.Result{}, fmt.Errorf("failed to update k8s node hash: %w", err) } } @@ -203,51 +213,35 @@ func (r *NodeReconciler) generateGPUNode(node *corev1.Node, pool *tfv1.GPUPool, if provisioner != "" { gpuNode.Labels[constants.ProvisionerLabelKey] = provisioner } + // Copy vendor label from k8s node to GPUNode + if node.Labels != nil && node.Labels[constants.AcceleratorLabelVendor] != "" { + gpuNode.Labels[constants.AcceleratorLabelVendor] = node.Labels[constants.AcceleratorLabelVendor] + } _ = controllerutil.SetControllerReference(pool, gpuNode, r.Scheme) return gpuNode } // SetupWithManager sets 
up the controller with the Manager. func (r *NodeReconciler) SetupWithManager(mgr ctrl.Manager) error { - // must choose an initial label selector to avoid performance impact in large Kubernetes clusters + ctr := ctrl.NewControllerManagedBy(mgr) + // Prefer to choose an initial label selector to avoid performance impact in large Kubernetes clusters that has lots of CPU nodes selectors := utils.GetInitialGPUNodeSelector() - p, err := predicate.LabelSelectorPredicate(metav1.LabelSelector{ - MatchLabels: map[string]string{ - selectors[0]: selectors[1], - }, - }) - if err != nil { - return fmt.Errorf("unable to create predicate: %w", err) + if len(selectors) == 2 { + p, err := predicate.LabelSelectorPredicate(metav1.LabelSelector{ + MatchLabels: map[string]string{ + selectors[0]: selectors[1], + }, + }) + if err != nil { + return fmt.Errorf("unable to create predicate: %w", err) + } + ctr.For(&corev1.Node{}, builder.WithPredicates(p)) + } else { + ctr.For(&corev1.Node{}) } - return ctrl.NewControllerManagedBy(mgr). - For(&corev1.Node{}, builder.WithPredicates(p)). + return ctr. Named("node"). 
- Watches(&tfv1.GPUPool{}, handler.EnqueueRequestsFromMapFunc(func(ctx context.Context, obj client.Object) []reconcile.Request { - nodelist := &tfv1.GPUNodeList{} - if err := mgr.GetClient().List(ctx, nodelist, client.MatchingLabels{ - selectors[0]: selectors[1], - }); err != nil { - log.FromContext(ctx).Error(err, "failed to list GPUNode") - return []reconcile.Request{} - } - var requests []reconcile.Request - for _, n := range nodelist.Items { - requests = append(requests, reconcile.Request{NamespacedName: client.ObjectKey{Name: n.Name}}) - } - return requests - }), builder.WithPredicates(predicate.Funcs{ - UpdateFunc: func(e event.UpdateEvent) bool { - oldObj, ok1 := e.ObjectOld.(*tfv1.GPUPool) - newObj, ok2 := e.ObjectNew.(*tfv1.GPUPool) - if !ok1 || !ok2 { - return false - } - oldNodeSelector := oldObj.Spec.NodeManagerConfig.NodeSelector - newNodeSelector := newObj.Spec.NodeManagerConfig.NodeSelector - return utils.GetObjectHash(oldNodeSelector) != utils.GetObjectHash(newNodeSelector) - }, - })). 
Complete(r) } diff --git a/internal/controller/tensorfusionworkload_controller_test.go b/internal/controller/tensorfusionworkload_controller_test.go index 9c2a9cd3..f11fe3d5 100644 --- a/internal/controller/tensorfusionworkload_controller_test.go +++ b/internal/controller/tensorfusionworkload_controller_test.go @@ -37,6 +37,7 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/utils" ) var _ = Describe("TensorFusionWorkload Controller", func() { @@ -402,7 +403,7 @@ func mockSchedulerLoop(ctx context.Context, cfg *rest.Config) { func scheduleAndStartPod(pod *corev1.Pod, clientset *kubernetes.Clientset) { // simulate scheduling cycle Filter and Reserve - allocRequest, _, err := allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(ctx, pod) Expect(err).To(Succeed()) gpus, err := allocator.Alloc(allocRequest) if err != nil { diff --git a/internal/gpuallocator/filter/filter_test.go b/internal/gpuallocator/filter/filter_test.go index c47ab594..5c6e2e5a 100644 --- a/internal/gpuallocator/filter/filter_test.go +++ b/internal/gpuallocator/filter/filter_test.go @@ -111,7 +111,7 @@ func TestFilters(t *testing.T) { filter := NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil) + }) result, err := filter.Filter(ctx, testPodKey, gpus) assert.NoError(t, err) assert.Len(t, result, 2) @@ -126,7 +126,7 @@ func TestFilters(t *testing.T) { With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil)) + })) // Apply filters result, _, err := registry.Apply(ctx, testPodKey, gpus, false) @@ -137,10 +137,11 @@ func TestFilters(t *testing.T) { t.Run("FilterRegistry with gpu indices filtering", func(t *testing.T) { registry := NewFilterRegistry(). + With(NewGPUIndexFilter([]int32{2, 3})). 
With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("1"), Vram: resource.MustParse("1Gi"), - }, []int32{2, 3})) + })) // Apply filters result, _, err := registry.Apply(ctx, testPodKey, gpus, false) @@ -160,7 +161,7 @@ func TestFilters(t *testing.T) { With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil)) + })) // Apply base registry filters baseResult, _, err := baseRegistry.Apply(ctx, testPodKey, gpus, false) diff --git a/internal/gpuallocator/filter/gpu_index_filter.go b/internal/gpuallocator/filter/gpu_index_filter.go new file mode 100644 index 00000000..285f59bf --- /dev/null +++ b/internal/gpuallocator/filter/gpu_index_filter.go @@ -0,0 +1,57 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package filter + +import ( + "context" + "slices" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/samber/lo" +) + +// GPUIndexFilter filters GPUs based on required GPU indices +type GPUIndexFilter struct { + requiredIndices []int32 +} + +// NewGPUIndexFilter creates a new GPUIndexFilter with the specified indices +func NewGPUIndexFilter(requiredIndices []int32) *GPUIndexFilter { + return &GPUIndexFilter{ + requiredIndices: requiredIndices, + } +} + +// Filter implements GPUFilter.Filter +func (f *GPUIndexFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + // If no indices specified, pass all GPUs + if len(f.requiredIndices) == 0 { + return gpus, nil + } + + return lo.Filter(gpus, func(gpu *tfv1.GPU, _ int) bool { + // Check GPU index + if gpu.Status.Index != nil && slices.Contains(f.requiredIndices, *gpu.Status.Index) { + return true + } + return false + }), nil +} + +func (f *GPUIndexFilter) Name() string { + return "GPUIndexFilter" +} diff --git a/internal/gpuallocator/filter/gpu_isolation_mode_filter.go b/internal/gpuallocator/filter/gpu_isolation_mode_filter.go new file mode 100644 index 00000000..4d094e04 --- /dev/null +++ b/internal/gpuallocator/filter/gpu_isolation_mode_filter.go @@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUIsolationModeFilter filters GPUs based on their isolation mode +type GPUIsolationModeFilter struct { + requiredIsolationMode tfv1.IsolationModeType +} + +// NewGPUIsolationModeFilter creates a new filter that matches GPUs with the specified isolation mode +func NewGPUIsolationModeFilter(isolationMode tfv1.IsolationModeType) *GPUIsolationModeFilter { + return &GPUIsolationModeFilter{ + requiredIsolationMode: isolationMode, + } +} + +// Filter implements GPUFilter interface +func (f *GPUIsolationModeFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus 
[]*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredIsolationMode == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.IsolationMode == "" || gpu.Status.IsolationMode == f.requiredIsolationMode { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUIsolationModeFilter) Name() string { + return "GPUIsolationModeFilter" +} diff --git a/internal/gpuallocator/filter/gpu_model_filter.go b/internal/gpuallocator/filter/gpu_model_filter.go new file mode 100644 index 00000000..f3d927e3 --- /dev/null +++ b/internal/gpuallocator/filter/gpu_model_filter.go @@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUModelFilter filters GPUs based on their model (e.g., A100, H100) +type GPUModelFilter struct { + requiredModel string +} + +// NewGPUModelFilter creates a new filter that matches GPUs with the specified model +func NewGPUModelFilter(model string) *GPUModelFilter { + return &GPUModelFilter{ + requiredModel: model, + } +} + +// Filter implements GPUFilter interface +func (f *GPUModelFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredModel == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.GPUModel == f.requiredModel { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUModelFilter) Name() string { + return "GPUModelFilter" +} diff --git a/internal/gpuallocator/filter/gpu_model_vendor_filter.go b/internal/gpuallocator/filter/gpu_model_vendor_filter.go deleted file mode 100644 index f095d76a..00000000 --- a/internal/gpuallocator/filter/gpu_model_vendor_filter.go +++ /dev/null @@ -1,50 +0,0 @@ -package filter - -import ( - "context" - - tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" -) - -// GPUModelAndVendorFilter filters GPUs based 
on their model (e.g., A100, H100) -type GPUModelAndVendorFilter struct { - requiredModel string - requiredVendor string -} - -// NewGPUModelAndVendorFilter creates a new filter that matches GPUs with the specified model -func NewGPUModelAndVendorFilter(model string, vendor string) *GPUModelAndVendorFilter { - return &GPUModelAndVendorFilter{ - requiredModel: model, - requiredVendor: vendor, - } -} - -// Filter implements GPUFilter interface -func (f *GPUModelAndVendorFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { - if f.requiredModel == "" && f.requiredVendor == "" { - return gpus, nil - } - - filtered := make([]*tfv1.GPU, 0, len(gpus)) - - if f.requiredModel != "" { - for _, gpu := range gpus { - if gpu.Status.GPUModel == f.requiredModel { - filtered = append(filtered, gpu) - } - } - } - if f.requiredVendor != "" { - for _, gpu := range gpus { - if gpu.Status.Vendor == f.requiredVendor { - filtered = append(filtered, gpu) - } - } - } - return filtered, nil -} - -func (f *GPUModelAndVendorFilter) Name() string { - return "GPUModelAndVendorFilter" -} diff --git a/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go b/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go index 0f57173b..e25de11e 100644 --- a/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go +++ b/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go @@ -85,7 +85,7 @@ func TestGPUModelFilter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - filter := NewGPUModelAndVendorFilter(tt.requiredModel, "") + filter := NewGPUModelFilter(tt.requiredModel) got, err := filter.Filter(context.Background(), testPodKey, tt.gpus) if tt.wantErr { assert.Error(t, err) diff --git a/internal/gpuallocator/filter/gpu_vendor_filter.go b/internal/gpuallocator/filter/gpu_vendor_filter.go new file mode 100644 index 00000000..0f3ef5cf --- /dev/null +++ b/internal/gpuallocator/filter/gpu_vendor_filter.go 
@@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUVendorFilter filters GPUs based on their vendor +type GPUVendorFilter struct { + requiredVendor string +} + +// NewGPUVendorFilter creates a new filter that matches GPUs with the specified vendor +func NewGPUVendorFilter(vendor string) *GPUVendorFilter { + return &GPUVendorFilter{ + requiredVendor: vendor, + } +} + +// Filter implements GPUFilter interface +func (f *GPUVendorFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredVendor == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.Vendor == f.requiredVendor { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUVendorFilter) Name() string { + return "GPUVendorFilter" +} diff --git a/internal/gpuallocator/filter/partition_template_filter.go b/internal/gpuallocator/filter/partition_template_filter.go new file mode 100644 index 00000000..e6991764 --- /dev/null +++ b/internal/gpuallocator/filter/partition_template_filter.go @@ -0,0 +1,101 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/samber/lo" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// PartitionTemplateFilter filters GPUs based on partition template availability +// Only applies when isolation mode is partitioned +type PartitionTemplateFilter struct { + isolationMode tfv1.IsolationModeType + requiredTemplateID string + maxPartitionsMap map[string]uint32 // GPU model -> max partitions +} + +// NewPartitionTemplateFilter creates a new PartitionTemplateFilter +func NewPartitionTemplateFilter(isolationMode tfv1.IsolationModeType, requiredTemplateID string, maxPartitionsMap map[string]uint32) *PartitionTemplateFilter { + return &PartitionTemplateFilter{ + isolationMode: isolationMode, + requiredTemplateID: requiredTemplateID, + maxPartitionsMap: maxPartitionsMap, + } +} + +// Filter implements GPUFilter.Filter +func (f *PartitionTemplateFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + // Only apply filter for partitioned isolation mode + if f.isolationMode != tfv1.IsolationModePartitioned { + return gpus, nil + } + + logger := log.FromContext(ctx) + + return lo.Filter(gpus, func(gpu *tfv1.GPU, _ int) bool { + // Check if GPU has partition templates + if len(gpu.Status.PartitionTemplates) == 0 { + logger.V(5).Info("GPU has no partition templates", "gpu", gpu.Name) + return false + } + + // If a specific template ID is required, check if GPU has it + if f.requiredTemplateID != "" { + hasTemplate := false + for _, template := range gpu.Status.PartitionTemplates { + if template.TemplateID == f.requiredTemplateID { + hasTemplate = true + break + } + } + if !hasTemplate { + logger.V(5).Info("GPU does not have required partition template", + "gpu", gpu.Name, "template", f.requiredTemplateID) + return false + } + } + + // Check partition count limit + allocatedCount := 0 + if gpu.Status.AllocatedPartitions != nil { + 
allocatedCount = len(gpu.Status.AllocatedPartitions) + } + + // Get max partitions from config + maxPartitions := f.maxPartitionsMap[gpu.Status.GPUModel] + if maxPartitions == 0 { + // Default to 7 for MIG if not configured + maxPartitions = 7 + } + + if maxPartitions > 0 && uint32(allocatedCount) >= maxPartitions { + logger.V(5).Info("GPU has reached maximum partition count", + "gpu", gpu.Name, "allocated", allocatedCount, "max", maxPartitions) + return false + } + + return true + }), nil +} + +func (f *PartitionTemplateFilter) Name() string { + return "PartitionTemplateFilter" +} diff --git a/internal/gpuallocator/filter/partition_template_filter_test.go b/internal/gpuallocator/filter/partition_template_filter_test.go new file mode 100644 index 00000000..a6eaf1e2 --- /dev/null +++ b/internal/gpuallocator/filter/partition_template_filter_test.go @@ -0,0 +1,175 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package filter + +import ( + "context" + "testing" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestPartitionTemplateFilter(t *testing.T) { + testPodKey := tfv1.NameNamespace{ + Name: "test-pod", + Namespace: "test-namespace", + } + + tests := []struct { + name string + isolationMode tfv1.IsolationModeType + requiredTemplate string + maxPartitionsMap map[string]uint32 + gpus []*tfv1.GPU + expectedCount int + expectedGPUNames []string + }{ + { + name: "non-partitioned mode should pass all GPUs", + isolationMode: tfv1.IsolationModeSoft, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-1"}, + }, + { + name: "partitioned mode - GPU without templates filtered out", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + { + name: "partitioned mode - specific template required", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "1g.24gb", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + 
{TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + { + name: "partitioned mode - max partitions reached", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: "1g.24gb", PodUID: "pod-1"}, + "pod-2": {TemplateID: "1g.24gb", PodUID: "pod-2"}, + "pod-3": {TemplateID: "1g.24gb", PodUID: "pod-3"}, + "pod-4": {TemplateID: "1g.24gb", PodUID: "pod-4"}, + "pod-5": {TemplateID: "1g.24gb", PodUID: "pod-5"}, + "pod-6": {TemplateID: "1g.24gb", PodUID: "pod-6"}, + "pod-7": {TemplateID: "1g.24gb", PodUID: "pod-7"}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: "1g.24gb", PodUID: "pod-1"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + } + + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filter := NewPartitionTemplateFilter(tt.isolationMode, tt.requiredTemplate, tt.maxPartitionsMap) + result, err := filter.Filter(ctx, testPodKey, tt.gpus) + + assert.NoError(t, err) + assert.Len(t, result, tt.expectedCount) + if len(tt.expectedGPUNames) > 0 { + resultNames := make([]string, len(result)) + for i, gpu := range 
result { + resultNames[i] = gpu.Name + } + assert.ElementsMatch(t, tt.expectedGPUNames, resultNames) + } + }) + } +} diff --git a/internal/gpuallocator/filter/resource_filter.go b/internal/gpuallocator/filter/resource_filter.go index fa8ca805..9f0a76ef 100644 --- a/internal/gpuallocator/filter/resource_filter.go +++ b/internal/gpuallocator/filter/resource_filter.go @@ -2,7 +2,6 @@ package filter import ( "context" - "slices" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/utils" @@ -12,14 +11,12 @@ import ( // ResourceFilter filters GPUs based on available resources type ResourceFilter struct { requiredResource tfv1.Resource - requiredIndices []int32 } // NewResourceFilter creates a new ResourceFilter with the specified resource requirements -func NewResourceFilter(required tfv1.Resource, requiredIndices []int32) *ResourceFilter { +func NewResourceFilter(required tfv1.Resource) *ResourceFilter { return &ResourceFilter{ requiredResource: required, - requiredIndices: requiredIndices, } } @@ -31,13 +28,6 @@ func (f *ResourceFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNames return false } - // Check GPU indices range - if len(f.requiredIndices) > 0 { - if gpu.Status.Index != nil && !slices.Contains(f.requiredIndices, *gpu.Status.Index) { - return false - } - } - // Check TFlops availability hasTflops := gpu.Status.Available.Tflops.Cmp(f.requiredResource.Tflops) >= 0 diff --git a/internal/gpuallocator/gpuallocator.go b/internal/gpuallocator/gpuallocator.go index a32156da..0ee33431 100644 --- a/internal/gpuallocator/gpuallocator.go +++ b/internal/gpuallocator/gpuallocator.go @@ -5,9 +5,7 @@ import ( "context" "fmt" "math" - "slices" "sort" - "strconv" "strings" "sync" "time" @@ -38,12 +36,54 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" ) -const MaxGPUCounterPerAllocation = 128 const CleanUpCheckInterval = 3 * time.Minute var mu sync.Mutex var GPUCapacityMap = map[string]tfv1.Resource{} +// 
PartitionTemplateMap stores partition template info by GPU model +// Key: GPU model (e.g., "A100_SXM_80G"), Value: map of templateID -> template info +var PartitionTemplateMap = map[string]map[string]config.PartitionTemplateInfo{} + +// MaxPartitionsMap stores max partitions by GPU model +// Key: GPU model, Value: max partitions (e.g., 7 for MIG) +var MaxPartitionsMap = map[string]uint32{} + +// MaxPlacementSlotsMap stores max placement slots by GPU model +// Key: GPU model, Value: max placement slots (e.g., 8 for MIG) +var MaxPlacementSlotsMap = map[string]uint32{} + +// LoadPartitionTemplatesFromConfig loads partition templates and max partitions from GPU info config +// This should be called when GPU info config is loaded/updated +func LoadPartitionTemplatesFromConfig(gpuInfos []config.GpuInfo) { + mu.Lock() + defer mu.Unlock() + + for _, gpuInfo := range gpuInfos { + // Store max partitions + if gpuInfo.MaxPartitions > 0 { + MaxPartitionsMap[gpuInfo.Model] = gpuInfo.MaxPartitions + MaxPartitionsMap[gpuInfo.FullModelName] = gpuInfo.MaxPartitions + } + + // Store max placement slots + if gpuInfo.MaxPlacementSlots > 0 { + MaxPlacementSlotsMap[gpuInfo.Model] = gpuInfo.MaxPlacementSlots + MaxPlacementSlotsMap[gpuInfo.FullModelName] = gpuInfo.MaxPlacementSlots + } + + // Store partition templates + if len(gpuInfo.PartitionTemplates) > 0 { + templateMap := make(map[string]config.PartitionTemplateInfo, len(gpuInfo.PartitionTemplates)) + for _, template := range gpuInfo.PartitionTemplates { + templateMap[template.TemplateID] = template + } + PartitionTemplateMap[gpuInfo.Model] = templateMap + PartitionTemplateMap[gpuInfo.FullModelName] = templateMap + } + } +} + type Strategy interface { // When isForNode = true, indicates each GPU's node level score // otherwise it's single GPU score inside one node @@ -178,20 +218,43 @@ func (s *GpuAllocator) Filter( toFilterGPUs []*tfv1.GPU, isSimulateSchedule bool, ) ([]*tfv1.GPU, []filter.FilterDetail, error) { - // Add 
SameNodeFilter if count > 1 to ensure GPUs are from the same node - filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request, req.GPUIndices)) + // Filter order: index -> isolation -> partition -> resource -> (model, vendor, nodeAffinity) -> sameNode + filterRegistry := s.filterRegistry + + // 1. GPU index filter (extracted from resource filter) + if len(req.GPUIndices) > 0 { + filterRegistry = filterRegistry.With(filter.NewGPUIndexFilter(req.GPUIndices)) + } + + // 2. GPU isolation mode filter + if req.Isolation != "" { + filterRegistry = filterRegistry.With(filter.NewGPUIsolationModeFilter(req.Isolation)) + } - // Add GPU model filter if specified + // 3. Partition template filter (only for partitioned mode) + if req.Isolation == tfv1.IsolationModePartitioned { + filterRegistry = filterRegistry.With(filter.NewPartitionTemplateFilter(req.Isolation, req.PartitionTemplateID, MaxPartitionsMap)) + } + + // 4. Resource filter (moved after isolation/partition filters) + filterRegistry = filterRegistry.With(filter.NewResourceFilter(req.Request)) + + // 5. GPU model filter if specified if req.GPUModel != "" { - filterRegistry = filterRegistry.With(filter.NewGPUModelAndVendorFilter(req.GPUModel, req.GPUVendor)) + filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel)) } - // NOTE: deprecated, use Kubernetes native spec template affinity way + // 6. GPU vendor filter if specified + if req.GPUVendor != "" { + filterRegistry = filterRegistry.With(filter.NewGPUVendorFilter(req.GPUVendor)) + } + + // 7. NOTE: deprecated, use Kubernetes native spec template affinity way if req.NodeAffinity != nil { filterRegistry = filterRegistry.With(filter.NewNodeAffinityFilter(s.Client, req.NodeAffinity)) } - // Same node filter must be applied at final step + // 8. 
Same node filter must be applied at final step if req.Count > 1 { filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(req.Count)) } @@ -217,17 +280,59 @@ func (s *GpuAllocator) FilterWithPreempt( return nil, nil, fmt.Errorf("gpu %s not found", gpuName) } gpuCopy := gpu.DeepCopy() - gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) - gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + + // Handle partitioned mode: add back partition resources from config + if preemptAllocRequest.Isolation == tfv1.IsolationModePartitioned && preemptAllocRequest.PartitionTemplateID != "" { + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage( + gpuCopy.Status.Capacity.Tflops, gpuCopy.Status.GPUModel, preemptAllocRequest.PartitionTemplateID) + if err == nil { + gpuCopy.Status.Available.Tflops.Add(partitionTflops) + gpuCopy.Status.Available.Vram.Add(partitionVram) + } else { + // Fallback to request resources + gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) + gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + } + } else { + // Non-partitioned mode + gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) + gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + } toFilterGPUs = append(toFilterGPUs, gpuCopy) } } - filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request, req.GPUIndices)) - // Add GPU model filter if specified + // Use same filter order as regular Filter + filterRegistry := s.filterRegistry + + // 1. GPU index filter + if len(req.GPUIndices) > 0 { + filterRegistry = filterRegistry.With(filter.NewGPUIndexFilter(req.GPUIndices)) + } + + // 2. GPU isolation mode filter + if req.Isolation != "" { + filterRegistry = filterRegistry.With(filter.NewGPUIsolationModeFilter(req.Isolation)) + } + + // 3. 
Partition template filter (only for partitioned mode) + if req.Isolation == tfv1.IsolationModePartitioned { + filterRegistry = filterRegistry.With(filter.NewPartitionTemplateFilter(req.Isolation, req.PartitionTemplateID, MaxPartitionsMap)) + } + + // 4. Resource filter + filterRegistry = filterRegistry.With(filter.NewResourceFilter(req.Request)) + + // 5. GPU model filter if specified if req.GPUModel != "" { - filterRegistry = filterRegistry.With(filter.NewGPUModelAndVendorFilter(req.GPUModel, req.GPUVendor)) + filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel)) } + + // 6. GPU vendor filter if specified + if req.GPUVendor != "" { + filterRegistry = filterRegistry.With(filter.NewGPUVendorFilter(req.GPUVendor)) + } + // No need to check count and other filters since it's always in the same node during each preempt trial filteredGPUs, filterDetails, err := filterRegistry.Apply(s.ctx, req.WorkloadNameNamespace, toFilterGPUs, false) if err != nil { @@ -266,6 +371,71 @@ func (s *GpuAllocator) Select(req *tfv1.AllocRequest, filteredGPUs []*tfv1.GPU) return result, nil } +// GetMatchedPartition finds the best matching partition template for a request in partitioned mode. +// Returns the GPU, matched partition template, and partition UUID if a match is found. +// In partitioned mode, GPUs must have partition templates available, and we select the smallest +// template that can satisfy the request to minimize resource waste. 
+func (s *GpuAllocator) GetMatchedPartition( + req *tfv1.AllocRequest, + filteredGPUs []*tfv1.GPU, +) (*tfv1.GPU, *PartitionMatchResult, error) { + // Only process partitioned mode requests + if req.Isolation != tfv1.IsolationModePartitioned { + return nil, nil, fmt.Errorf("GetMatchedPartition only supports partitioned isolation mode") + } + + if len(filteredGPUs) == 0 { + return nil, nil, fmt.Errorf("no GPUs available for partition matching") + } + + var bestGPU *tfv1.GPU + var bestMatch *PartitionMatchResult + bestScore := math.MaxFloat64 + + s.storeMutex.RLock() + defer s.storeMutex.RUnlock() + + // Find the best GPU with the best matching partition template + for _, gpu := range filteredGPUs { + // Get partition templates from GPU status + if len(gpu.Status.PartitionTemplates) == 0 { + continue // Skip GPUs without partition templates + } + // Match partition template (gets template info from config) + match, err := MatchPartitionTemplate(gpu.Status, req) + if err != nil { + log.FromContext(s.ctx).V(5).Info("Failed to match partition template for GPU", + "gpu", gpu.Name, "error", err) + continue + } + + if !match.CanAllocate { + continue + } + + // Check if GPU has enough resources (gets template info from config) + if err := CheckPartitionAvailability(gpu, match.TemplateID); err != nil { + log.FromContext(s.ctx).V(5).Info("GPU does not have available resources for partition", + "gpu", gpu.Name, "error", err) + continue + } + + // Update best match if this is better (lower score = less waste) + if match.Score < bestScore { + bestGPU = gpu + bestMatch = match + bestScore = match.Score + } + } + + if bestGPU == nil || bestMatch == nil { + return nil, nil, fmt.Errorf("no suitable partition template found for request: TFLOPs=%s, VRAM=%s", + req.Request.Tflops.String(), req.Request.Vram.String()) + } + + return bestGPU, bestMatch, nil +} + // Bind allocates resources on the provided GPUs for the given request. 
// It updates the in-memory store and marks the GPUs as dirty for syncing. func (s *GpuAllocator) Bind( @@ -302,24 +472,32 @@ func (s *GpuAllocator) Bind( if gpu.Status.Available == nil { return nil, fmt.Errorf("GPU %s has nil available resources", selectedGPU) } - if gpu.Status.Available.Tflops.Cmp(req.Request.Tflops) < 0 { - return nil, fmt.Errorf("GPU %s insufficient TFLOPs: available %s, requested %s", - selectedGPU, gpu.Status.Available.Tflops.String(), req.Request.Tflops.String()) - } - if gpu.Status.Available.Vram.Cmp(req.Request.Vram) < 0 { - return nil, fmt.Errorf("GPU %s insufficient VRAM: available %s, requested %s", - selectedGPU, gpu.Status.Available.Vram.String(), req.Request.Vram.String()) - } - - // reduce available resource on the GPU status - if !req.Request.ComputePercent.IsZero() { - requiredTflops := utils.ComputePercentToTflops(gpu.Status.Capacity.Tflops, req.Request) - gpu.Status.Available.Tflops.Sub(*requiredTflops) + // Handle partitioned mode differently + if req.Isolation == tfv1.IsolationModePartitioned && req.PartitionTemplateID != "" { + if err := s.bindPartition(gpu, req, selectedGPU); err != nil { + return nil, err + } } else { - gpu.Status.Available.Tflops.Sub(req.Request.Tflops) + // Non-partitioned mode: subtract request resources + if gpu.Status.Available.Tflops.Cmp(req.Request.Tflops) < 0 { + return nil, fmt.Errorf("GPU %s insufficient TFLOPs: available %s, requested %s", + selectedGPU, gpu.Status.Available.Tflops.String(), req.Request.Tflops.String()) + } + if gpu.Status.Available.Vram.Cmp(req.Request.Vram) < 0 { + return nil, fmt.Errorf("GPU %s insufficient VRAM: available %s, requested %s", + selectedGPU, gpu.Status.Available.Vram.String(), req.Request.Vram.String()) + } + + // reduce available resource on the GPU status + if !req.Request.ComputePercent.IsZero() { + requiredTflops := utils.ComputePercentToTflops(gpu.Status.Capacity.Tflops, req.Request) + gpu.Status.Available.Tflops.Sub(*requiredTflops) + } else { + 
gpu.Status.Available.Tflops.Sub(req.Request.Tflops) + } + gpu.Status.Available.Vram.Sub(req.Request.Vram) } - gpu.Status.Available.Vram.Sub(req.Request.Vram) addRunningApp(s.ctx, gpu, req) @@ -460,18 +638,18 @@ func (s *GpuAllocator) Dealloc( ) { <-s.initializedCh podUID := string(podMeta.UID) - log := log.FromContext(s.ctx) + logger := log.FromContext(s.ctx) request, exists := s.uniqueAllocation[podUID] if !exists || request == nil { // should not block finalizer - log.Error(fmt.Errorf("pod has not allocated GPUs"), "pod", podUID) + logger.Error(fmt.Errorf("pod has not allocated GPUs"), "pod", podUID) return } if _, exists := s.uniqueDeallocation[podUID]; exists { // should not block finalizer - log.Error(fmt.Errorf("pod has already deallocated GPUs"), "pod", podUID) + logger.Error(fmt.Errorf("pod has already deallocated GPUs"), "pod", podUID) return } @@ -484,18 +662,23 @@ func (s *GpuAllocator) Dealloc( gpuNameNs := types.NamespacedName{Name: gpu} storeGPU, exists := s.gpuStore[gpuNameNs] if !exists { - log.Error(fmt.Errorf("GPU not found in store"), "Failed to deallocate GPU", "name", gpu) + logger.Error(fmt.Errorf("GPU not found in store"), "Failed to deallocate GPU", "name", gpu) continue } - // Add resources back to the GPU - if !request.Request.ComputePercent.IsZero() { - requiredTflops := utils.ComputePercentToTflops(storeGPU.Status.Capacity.Tflops, request.Request) - storeGPU.Status.Available.Tflops.Add(*requiredTflops) + // Handle partitioned mode deallocation + if request.Isolation == tfv1.IsolationModePartitioned && request.PartitionTemplateID != "" { + s.deallocPartition(storeGPU, request, gpu) } else { - storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + // Non-partitioned mode: add back request resources + if !request.Request.ComputePercent.IsZero() { + requiredTflops := utils.ComputePercentToTflops(storeGPU.Status.Capacity.Tflops, request.Request) + storeGPU.Status.Available.Tflops.Add(*requiredTflops) + } else { + 
storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + } + storeGPU.Status.Available.Vram.Add(request.Request.Vram) } - storeGPU.Status.Available.Vram.Add(request.Request.Vram) if nodeName == "" { nodeName = storeGPU.Status.NodeSelector[constants.KubernetesHostNameLabel] @@ -515,7 +698,7 @@ func (s *GpuAllocator) Dealloc( // Deallocate quota resources in memory (atomic operation) s.quotaStore.DeallocateQuota(workloadNameNamespace.Namespace, request) - log.Info("GPU deallocation successful", + logger.Info("GPU deallocation successful", "namespace", workloadNameNamespace.Namespace, "workload", workloadNameNamespace.Name, "gpu_count", len(gpus), @@ -1071,6 +1254,9 @@ func syncGPUMetadataAndStatusFromCluster(old *tfv1.GPU, gpu *tfv1.GPU) { old.Status.Vendor = gpu.Status.Vendor old.Status.NUMANode = gpu.Status.NUMANode old.Status.Index = gpu.Status.Index + // Sync partition templates from cluster (discovered by node discovery) + // Don't overwrite AllocatedPartitions as that's managed by the allocator + old.Status.PartitionTemplates = gpu.Status.PartitionTemplates } func (s *GpuAllocator) handleGPUUpdateCapacityDiff(old, gpu *tfv1.GPU) { @@ -1151,6 +1337,7 @@ func (s *GpuAllocator) SyncGPUsToK8s() { // Apply our status updates to the latest version latest.Status.Available = gpu.Status.Available latest.Status.RunningApps = gpu.Status.RunningApps + latest.Status.AllocatedPartitions = gpu.Status.AllocatedPartitions // Attempt to update with the latest version return s.Status().Update(s.ctx, latest) @@ -1316,7 +1503,7 @@ func (s *GpuAllocator) reconcileAllocationState() { !controllerutil.ContainsFinalizer(&worker, constants.Finalizer) if scheduled { - allocRequest, msg, err := s.ComposeAllocationRequest(&worker) + allocRequest, msg, err := utils.ComposeAllocationRequest(ctx, &worker) if err != nil { logger.Error(err, "Failed to compose allocation request for existing worker Pod, annotation may not be valid", "pod", worker.Name, "msg", msg) return false @@ -1340,6 
+1527,8 @@ func (s *GpuAllocator) reconcileAllocationState() { actualRunningAppsMap[gpuKey] = gpu.Status.RunningApps gpu.Status.RunningApps = []*tfv1.RunningAppDetail{} + // Clear AllocatedPartitions - will be rebuilt from workers + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) } // This is important for progressive migration mode @@ -1357,12 +1546,55 @@ func (s *GpuAllocator) reconcileAllocationState() { for gpuId := range gpuIdsList { gpuKey := types.NamespacedName{Name: gpuId} + gpu := s.gpuStore[gpuKey] + if gpu == nil { + continue + } + gpuAvailableRes, ok := actualAvailableMap[gpuKey] if ok { - gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) - gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + // Handle partitioned mode differently + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID != "" { + // Calculate partition resource usage from config + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, allocRequest.PartitionTemplateID) + if err == nil { + gpuAvailableRes.Tflops.Sub(partitionTflops) + gpuAvailableRes.Vram.Sub(partitionVram) + + // Rebuild AllocatedPartitions using podUID as key + if gpu.Status.AllocatedPartitions == nil { + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) + } + podUID := string(worker.UID) + // During reconciliation, preserve existing slot assignments if available + existingPartition, exists := gpu.Status.AllocatedPartitions[podUID] + allocatedPartition := tfv1.AllocatedPartition{ + TemplateID: allocRequest.PartitionTemplateID, + PodUID: podUID, + PodName: worker.Name, + Namespace: worker.Namespace, + AllocatedAt: metav1.Now(), // Use current time for reconciliation + } + // Preserve existing slot assignments if they exist + if exists { + allocatedPartition.AllocatedSlotStart = existingPartition.AllocatedSlotStart + allocatedPartition.AllocatedSlotEnd = 
existingPartition.AllocatedSlotEnd + } + gpu.Status.AllocatedPartitions[podUID] = allocatedPartition + } else { + // Fallback to request resources if template not found + logger.Info("Partition template not found in config during reconciliation, using request resources", + "gpu", gpuId, "template", allocRequest.PartitionTemplateID, "error", err) + gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) + gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + } + } else { + // Non-partitioned mode + gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) + gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + } } - addRunningApp(ctx, s.gpuStore[gpuKey], allocRequest) + addRunningApp(ctx, gpu, allocRequest) } } @@ -1384,6 +1616,12 @@ func (s *GpuAllocator) reconcileAllocationState() { s.markGPUDirtyLocked(gpuKey) log.FromContext(ctx).Info("Correcting gpu running apps", "gpu", gpuKey.Name, "runningApps", len(gpu.Status.RunningApps)) } + + // Mark GPU dirty if AllocatedPartitions need to be synced + // (they are already updated in the loop above, just need to sync to K8s) + if len(gpu.Status.AllocatedPartitions) > 0 { + s.markGPUDirtyLocked(gpuKey) + } } // reconcile quota store state @@ -1482,65 +1720,124 @@ func removeRunningApp(ctx context.Context, gpu *tfv1.GPU, allocRequest *tfv1.All } } -func (s *GpuAllocator) ComposeAllocationRequest(pod *v1.Pod) (*tfv1.AllocRequest, string, error) { - // allow Pods with no requests/limits to use TensorFusion, Pod webhook will ensure at least one request/limit is set - gpuRequestResource, err := utils.GetGPUResource(pod, true) - if err != nil { - log.FromContext(s.ctx).Error(err, "Invalid gpu request annotation", "pod", pod.Name, "namespace", pod.Namespace) +// bindPartition handles partition allocation for a single GPU in partitioned mode +func (s *GpuAllocator) bindPartition(gpu *tfv1.GPU, req *tfv1.AllocRequest, selectedGPU string) error { + // Verify template exists in GPU status + templateExists := false + for _, template := 
range gpu.Status.PartitionTemplates { + if template.TemplateID == req.PartitionTemplateID { + templateExists = true + break + } } - gpuLimitResource, err := utils.GetGPUResource(pod, false) + if !templateExists { + return fmt.Errorf("partition template %s not found on GPU %s", req.PartitionTemplateID, selectedGPU) + } + + // Calculate partition resource usage from config (no overhead) + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, req.PartitionTemplateID) if err != nil { - log.FromContext(s.ctx).Error(err, "Invalid gpu limit annotation", "pod", pod.Name, "namespace", pod.Namespace) + return fmt.Errorf("failed to get partition template info for GPU %s template %s: %w", selectedGPU, req.PartitionTemplateID, err) } - count := 1 - if gpuCountStr, exists := pod.Annotations[constants.GpuCountAnnotation]; exists { - count, err = strconv.Atoi(gpuCountStr) - if err != nil { - return &tfv1.AllocRequest{}, "invalid gpu count annotation", err - } + // Check availability for partition resources + if gpu.Status.Available.Tflops.Cmp(partitionTflops) < 0 { + return fmt.Errorf("GPU %s insufficient TFLOPs for partition: available %s, required %s", + selectedGPU, gpu.Status.Available.Tflops.String(), partitionTflops.String()) } - if count > MaxGPUCounterPerAllocation { - return &tfv1.AllocRequest{}, "gpu count annotation is too large", nil + if gpu.Status.Available.Vram.Cmp(partitionVram) < 0 { + return fmt.Errorf("GPU %s insufficient VRAM for partition: available %s, required %s", + selectedGPU, gpu.Status.Available.Vram.String(), partitionVram.String()) } - qosLevel := tfv1.QoSLevel(pod.Annotations[constants.QoSLevelAnnotation]) - if qosLevel == "" { - qosLevel = tfv1.QoSMedium - } + // Subtract partition resources (no overhead) + gpu.Status.Available.Tflops.Sub(partitionTflops) + gpu.Status.Available.Vram.Sub(partitionVram) - gpuVendor := pod.Annotations[constants.GpuVendorAnnotation] + // Initialize 
AllocatedPartitions map if needed + if gpu.Status.AllocatedPartitions == nil { + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) + } - gpuIndices, hasError := utils.ParseIndicesAnnotation(pod.Annotations[constants.GpuIndicesAnnotation]) - if hasError { - return &tfv1.AllocRequest{}, "invalid gpu-indices annotation", - fmt.Errorf("can not parse gpu indices annotation") + // Find and assign slot position + var slotStart, slotEnd *uint32 + templateConfigs, exists := PartitionTemplateMap[gpu.Status.GPUModel] + if exists { + if templateInfo, found := templateConfigs[req.PartitionTemplateID]; found { + if len(templateInfo.PlacementLimit) > 0 && templateInfo.PlacementOffSet > 0 { + // Build slot occupancy map from existing partitions + occupiedSlots := buildSlotOccupancyMap(gpu, templateConfigs) + // Find available slot position + if startPos, found := findAvailableSlotPosition(templateInfo, occupiedSlots); found { + slotStart = &startPos + endPos := startPos + templateInfo.PlacementOffSet + slotEnd = &endPos + } + } + } } - allocRequest := tfv1.AllocRequest{ - PoolName: pod.Annotations[constants.GpuPoolKey], - Request: gpuRequestResource, - Limit: gpuLimitResource, + // Store partition allocation info using podUID as key + podUID := string(req.PodMeta.UID) + gpu.Status.AllocatedPartitions[podUID] = tfv1.AllocatedPartition{ + TemplateID: req.PartitionTemplateID, + PodUID: podUID, + PodName: req.PodMeta.Name, + Namespace: req.PodMeta.Namespace, + AllocatedAt: metav1.Now(), + AllocatedSlotStart: slotStart, + AllocatedSlotEnd: slotEnd, + } + + log.FromContext(s.ctx).Info("Allocated partition on GPU", + "gpu", selectedGPU, + "template", req.PartitionTemplateID, + "podUID", podUID, + "slotStart", slotStart, + "slotEnd", slotEnd) + return nil +} - Count: uint(count), - GPUModel: pod.Annotations[constants.GPUModelAnnotation], - GPUIndices: gpuIndices, - GPUVendor: gpuVendor, - WorkloadNameNamespace: tfv1.NameNamespace{ - Name: 
pod.Labels[constants.WorkloadKey], - Namespace: pod.Namespace, - }, - PodMeta: pod.ObjectMeta, - QoS: qosLevel, - } +// deallocPartition handles partition deallocation for a single GPU in partitioned mode +func (s *GpuAllocator) deallocPartition(storeGPU *tfv1.GPU, request *tfv1.AllocRequest, gpu string) { + logger := log.FromContext(s.ctx) + // Find and remove the allocated partition using podUID as key + podUID := string(request.PodMeta.UID) + if storeGPU.Status.AllocatedPartitions != nil { + allocatedPartition, exists := storeGPU.Status.AllocatedPartitions[podUID] + if exists { + // Calculate partition resource usage from config (no overhead) + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(storeGPU.Status.Capacity.Tflops, storeGPU.Status.GPUModel, allocatedPartition.TemplateID) + if err != nil { + // Fallback: add back request resources if template not found in config + logger.Info("Partition template not found in config during deallocation, using request resources", + "gpu", gpu, "template", allocatedPartition.TemplateID, "error", err) + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) + } else { + // Add back partition resources (no overhead) + storeGPU.Status.Available.Tflops.Add(partitionTflops) + storeGPU.Status.Available.Vram.Add(partitionVram) + } - // for already allocated workers, set the GPU device IDs for further scaling and retrieval - if gpuIdStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { - gpuIds := strings.SplitSeq(gpuIdStr, ",") - allocRequest.GPUNames = slices.Collect(gpuIds) + // Remove partition from allocated partitions map using podUID + delete(storeGPU.Status.AllocatedPartitions, podUID) + logger.Info("Removed partition allocation", + "gpu", gpu, + "podUID", podUID, + "template", allocatedPartition.TemplateID) + } else { + logger.Info("Partition not found in allocated partitions during deallocation", + "gpu", gpu, 
"podUID", podUID) + // Fallback: add back request resources + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) + } + } else { + // No allocated partitions map, fallback to request resources + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) } - - return &allocRequest, "", nil } func (s *GpuAllocator) addAllocationMap(gpuNodeName string, podMeta metav1.ObjectMeta) { diff --git a/internal/gpuallocator/partitioned_scheduling.go b/internal/gpuallocator/partitioned_scheduling.go new file mode 100644 index 00000000..09bf650a --- /dev/null +++ b/internal/gpuallocator/partitioned_scheduling.go @@ -0,0 +1,342 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package gpuallocator + +import ( + "fmt" + "math" + "sort" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/config" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "k8s.io/apimachinery/pkg/api/resource" +) + +const DefaultMaxPartitionNum = 32 +const PartitionMatchingComputingWeight = 0.6 +const PartitionMatchingVRAMWeight = 0.4 + +// PartitionMatchResult represents the result of matching a partition template to a request +type PartitionMatchResult struct { + Template *config.PartitionTemplateInfo // Template info from config + TemplateID string // Template ID + Score float64 // Lower score means better match (less waste) + CanAllocate bool + Reason string +} + +// MatchPartitionTemplate matches a partition template to an allocation request. +// Gets template info from config (PartitionTemplateMap) based on GPU model. +// In partitioned mode, we find the smallest template that can satisfy the request. +func MatchPartitionTemplate(gpuStatus tfv1.GPUStatus, req *tfv1.AllocRequest) (*PartitionMatchResult, error) { + gpuModel := gpuStatus.GPUModel + gpuTemplates := gpuStatus.PartitionTemplates + if len(gpuTemplates) == 0 { + return nil, fmt.Errorf("no partition templates available for GPU model %s", gpuModel) + } + + // Get template configs from global map + templateConfigs, exists := PartitionTemplateMap[gpuModel] + if !exists || len(templateConfigs) == 0 { + return nil, fmt.Errorf("no partition template configs found for GPU model %s", gpuModel) + } + + // Convert request to comparable values + // Handle ComputePercent: convert to TFLOPs if specified + var requestTflops float64 + if !req.Request.ComputePercent.IsZero() { + // Get GPU capacity from global map to convert ComputePercent to TFLOPs + mu.Lock() + gpuCapacity, exists := GPUCapacityMap[gpuModel] + mu.Unlock() + if !exists { + return nil, fmt.Errorf("GPU capacity not found for model %s, cannot convert ComputePercent to TFLOPs", gpuModel) + } + 
requiredTflops := utils.ComputePercentToTflops(gpuCapacity.Tflops, req.Request) + requestTflops = requiredTflops.AsApproximateFloat64() + } else { + requestTflops = req.Request.Tflops.AsApproximateFloat64() + } + requestVramBytes := req.Request.Vram.Value() + + // Get max partitions from config + maxPartitions := MaxPartitionsMap[gpuModel] + if maxPartitions <= 0 { + maxPartitions = DefaultMaxPartitionNum + } + + // Find the best matching template + var bestMatch *PartitionMatchResult + bestScore := math.MaxFloat64 // Lower is better (we want smallest that fits) + + for _, gpuTemplate := range gpuTemplates { + // Get detailed template info from config + templateInfo, exists := templateConfigs[gpuTemplate.TemplateID] + if !exists { + continue // Skip if template not found in config + } + + // If a specific template is required, only consider that one + if req.PartitionTemplateID != "" && gpuTemplate.TemplateID != req.PartitionTemplateID { + continue + } + + result := &PartitionMatchResult{ + Template: &templateInfo, + TemplateID: gpuTemplate.TemplateID, + CanAllocate: false, + } + + // Check if template resources can satisfy the request + // NOTE: ComputePercent is a percentage (0-100); divide by 100 to match CalculatePartitionResourceUsage + templateTflops := templateInfo.ComputePercent * gpuStatus.Capacity.Tflops.AsApproximateFloat64() / 100.0 + templateVramBytes := int64(templateInfo.MemoryGigabytes * 1024 * 1024 * 1024) + + // Check if template has enough resources + if templateTflops < requestTflops { + result.Reason = fmt.Sprintf("template %s has insufficient TFLOPs: %.2f < %.2f", + gpuTemplate.TemplateID, templateTflops, requestTflops) + continue + } + + if templateVramBytes < requestVramBytes { + result.Reason = fmt.Sprintf("template %s has insufficient VRAM: %d < %d", + gpuTemplate.TemplateID, templateVramBytes, requestVramBytes) + continue + } + + // Check if we can allocate more partitions (MIG constraint) + currentPartitionCount := len(gpuStatus.AllocatedPartitions) + if maxPartitions > 0 && uint32(currentPartitionCount) >= maxPartitions { + result.Reason = 
fmt.Sprintf("GPU has reached maximum partition count: %d/%d", + currentPartitionCount, maxPartitions) + continue + } + + // Calculate score: prefer templates that are just large enough (minimize waste) + tflopsWaste := (templateTflops - requestTflops) / math.Max(requestTflops, 1.0) + vramWaste := float64(templateVramBytes-requestVramBytes) / math.Max(float64(requestVramBytes), 1.0) + score := tflopsWaste*PartitionMatchingComputingWeight + vramWaste*PartitionMatchingVRAMWeight + + result.Score = score + result.CanAllocate = true + result.Reason = "template can satisfy request" + + // Update best match if this is better + if bestMatch == nil || score < bestScore { + bestMatch = result + bestScore = score + } + } + + if bestMatch == nil { + return nil, fmt.Errorf("no partition template can satisfy request: TFLOPs=%.2f, VRAM=%d", + requestTflops, requestVramBytes) + } + + return bestMatch, nil +} + +// CalculatePartitionResourceUsage calculates the resource usage for a partition template. +// Gets template info from config. +func CalculatePartitionResourceUsage(capacityTflops resource.Quantity, gpuModel, templateID string) (tflops resource.Quantity, vram resource.Quantity, err error) { + templateConfigs, exists := PartitionTemplateMap[gpuModel] + if !exists { + return resource.Quantity{}, resource.Quantity{}, fmt.Errorf("no partition template configs for GPU model %s", gpuModel) + } + + templateInfo, exists := templateConfigs[templateID] + if !exists { + return resource.Quantity{}, resource.Quantity{}, fmt.Errorf("partition template %s not found for GPU model %s", templateID, gpuModel) + } + + tflops = resource.MustParse(fmt.Sprintf("%.2f", templateInfo.ComputePercent*capacityTflops.AsApproximateFloat64()/100.0)) + vram = resource.MustParse(fmt.Sprintf("%dGi", templateInfo.MemoryGigabytes)) + + return tflops, vram, nil +} + +// areSlotsFree checks if slots starting from startPos for offset slots are all free. 
+func areSlotsFree(occupiedSlots map[uint32]bool, startPos, offset uint32) bool { + for i := range offset { + if occupiedSlots[startPos+i] { + return false + } + } + return true +} + +// buildSlotOccupancyMap builds a map of occupied slots from existing partitions. +// Uses AllocatedSlotStart/End if available, otherwise falls back to greedy assignment. +func buildSlotOccupancyMap( + gpu *tfv1.GPU, + templateConfigs map[string]config.PartitionTemplateInfo, +) map[uint32]bool { + occupiedSlots := make(map[uint32]bool) + + // First, use explicit slot assignments if available + for _, partition := range gpu.Status.AllocatedPartitions { + if partition.AllocatedSlotStart != nil && partition.AllocatedSlotEnd != nil { + start := *partition.AllocatedSlotStart + end := *partition.AllocatedSlotEnd + for slot := start; slot < end; slot++ { + occupiedSlots[slot] = true + } + } + } + + // For partitions without explicit slot assignments, use greedy approach + // Convert map to slice and sort by AllocatedAt timestamp (ASC) + partitions := make([]tfv1.AllocatedPartition, 0, len(gpu.Status.AllocatedPartitions)) + for _, partition := range gpu.Status.AllocatedPartitions { + // Skip if already has explicit slot assignment + if partition.AllocatedSlotStart != nil && partition.AllocatedSlotEnd != nil { + continue + } + partitions = append(partitions, partition) + } + + if len(partitions) > 0 { + sort.Slice(partitions, func(i, j int) bool { + // If both have valid timestamps, compare by time + if !partitions[i].AllocatedAt.IsZero() && !partitions[j].AllocatedAt.IsZero() { + if !partitions[i].AllocatedAt.Equal(&partitions[j].AllocatedAt) { + return partitions[i].AllocatedAt.Before(&partitions[j].AllocatedAt) + } + } + // Fallback to PodUID for stable ordering when timestamps are zero or equal + return partitions[i].PodUID < partitions[j].PodUID + }) + + // Process each partition without explicit slots in allocation order + for _, partition := range partitions { + templateInfo, exists := 
templateConfigs[partition.TemplateID] + if !exists || len(templateInfo.PlacementLimit) == 0 || templateInfo.PlacementOffSet == 0 { + continue + } + + // Find first available starting position for this partition + for _, startPos := range templateInfo.PlacementLimit { + if areSlotsFree(occupiedSlots, startPos, templateInfo.PlacementOffSet) { + // Assign this partition to this position + for i := uint32(0); i < templateInfo.PlacementOffSet; i++ { + occupiedSlots[startPos+i] = true + } + break + } + } + } + } + + return occupiedSlots +} + +// findAvailableSlotPosition finds the first available slot position for a template. +// Returns the starting position and true if found, 0 and false otherwise. +func findAvailableSlotPosition( + templateInfo config.PartitionTemplateInfo, + occupiedSlots map[uint32]bool, +) (uint32, bool) { + if len(templateInfo.PlacementLimit) == 0 || templateInfo.PlacementOffSet == 0 { + return 0, false + } + + for _, startPos := range templateInfo.PlacementLimit { + if areSlotsFree(occupiedSlots, startPos, templateInfo.PlacementOffSet) { + return startPos, true + } + } + + return 0, false +} + +// CheckPartitionAvailability checks if a GPU has enough resources to allocate a partition. +// Gets template info from config. 
+func CheckPartitionAvailability( + gpu *tfv1.GPU, + templateID string, +) error { + // Get template info from config first to check template-specific constraints + templateConfigs, exists := PartitionTemplateMap[gpu.Status.GPUModel] + if !exists { + return fmt.Errorf("no partition template configs for GPU model %s", gpu.Status.GPUModel) + } + + templateInfo, exists := templateConfigs[templateID] + if !exists { + return fmt.Errorf("partition template %s not found for GPU model %s", templateID, gpu.Status.GPUModel) + } + + currentCount := len(gpu.Status.AllocatedPartitions) + + // Check general partition count limit first (cheaper check) + maxPartitions := MaxPartitionsMap[gpu.Status.GPUModel] + if maxPartitions == 0 { + maxPartitions = 7 // Default MIG limit + } + if maxPartitions > 0 && uint32(currentCount) >= maxPartitions { + return fmt.Errorf("GPU %s has reached maximum partition count: %d/%d", + gpu.Name, currentCount, maxPartitions) + } + + // Count how many partitions of this template are already allocated + templateCount := uint32(0) + for _, partition := range gpu.Status.AllocatedPartitions { + if partition.TemplateID == templateID { + templateCount++ + } + } + + // Check MaxPartition limit for this specific template + if templateInfo.MaxPartition > 0 && templateCount >= templateInfo.MaxPartition { + return fmt.Errorf("GPU %s has reached maximum partition count for template %s: %d/%d", + gpu.Name, templateID, templateCount, templateInfo.MaxPartition) + } + + // Check placement slots using bitmask-based tracking + if len(templateInfo.PlacementLimit) > 0 && templateInfo.PlacementOffSet > 0 { + // Build slot occupancy map from existing partitions + occupiedSlots := buildSlotOccupancyMap(gpu, templateConfigs) + + // Check if the new template can find a valid placement + _, found := findAvailableSlotPosition(templateInfo, occupiedSlots) + if !found { + return fmt.Errorf("GPU %s has no available placement slots for template %s: required %d slots starting from 
positions %v", + gpu.Name, templateID, templateInfo.PlacementOffSet, templateInfo.PlacementLimit) + } + } + + // Calculate required resources from config + requiredTflops, requiredVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, templateID) + if err != nil { + return err + } + + // Check TFLOPs availability + if gpu.Status.Available.Tflops.Cmp(requiredTflops) < 0 { + return fmt.Errorf("GPU %s insufficient TFLOPs for partition: available %s, required %s", + gpu.Name, gpu.Status.Available.Tflops.String(), requiredTflops.String()) + } + + // Check VRAM availability + if gpu.Status.Available.Vram.Cmp(requiredVram) < 0 { + return fmt.Errorf("GPU %s insufficient VRAM for partition: available %s, required %s", + gpu.Name, gpu.Status.Available.Vram.String(), requiredVram.String()) + } + + return nil +} diff --git a/internal/gpuallocator/partitioned_scheduling_test.go b/internal/gpuallocator/partitioned_scheduling_test.go new file mode 100644 index 00000000..5d020cf2 --- /dev/null +++ b/internal/gpuallocator/partitioned_scheduling_test.go @@ -0,0 +1,527 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/

package gpuallocator

import (
	"testing"
	"time"

	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
	"github.com/NexusGPU/tensor-fusion/internal/config"
	"github.com/stretchr/testify/assert"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const testGPUModel = "A100_SXM_80G"

// TestMatchPartitionTemplate drives MatchPartitionTemplate through a table of
// requests (TFLOPs/VRAM and ComputePercent variants) against a shared,
// package-global template/capacity setup.
// NOTE(review): tests mutate package-level maps (PartitionTemplateMap,
// GPUCapacityMap) without cleanup, so they are order-sensitive — confirm
// acceptable for this package.
func TestMatchPartitionTemplate(t *testing.T) {
	// Setup: Initialize partition template map
	gpuModel := testGPUModel
	PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{
		"1g.24gb": {
			TemplateID:      "19",
			Name:            "1g.24gb",
			MemoryGigabytes: 24, // 24GB (function converts to bytes)
			ComputePercent:  1.0 / 7.0 * 100,
		},
		"4g.94gb": {
			TemplateID:      "9",
			Name:            "4g.94gb",
			MemoryGigabytes: 94, // 94GB (function converts to bytes)
			ComputePercent:  4.0 / 7.0 * 100,
		},
	}
	// Setup: Initialize GPU capacity map for ComputePercent conversion
	// A100_SXM_80G has ~312 TFLOPs capacity
	mu.Lock()
	GPUCapacityMap[gpuModel] = tfv1.Resource{
		Tflops: resource.MustParse("312"),
		Vram:   resource.MustParse("80Gi"),
	}
	mu.Unlock()

	tests := []struct {
		name                string
		gpuTemplates        []tfv1.PartitionTemplate
		req                 *tfv1.AllocRequest
		allocatedPartitions map[string]tfv1.AllocatedPartition
		expectError         bool
		expectedTemplateID  string
	}{
		{
			name: "match smallest template that fits",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
				{TemplateID: "4g.94gb", Name: "4g.94gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					Tflops: resource.MustParse("30"),
					Vram:   resource.MustParse("20Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         false,
			expectedTemplateID:  "1g.24gb", // Should match smallest that fits
		},
		{
			name: "match specific template when required",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
				{TemplateID: "4g.94gb", Name: "4g.94gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					Tflops: resource.MustParse("30"),
					Vram:   resource.MustParse("20Gi"),
				},
				// Pinning a template bypasses smallest-fit selection.
				PartitionTemplateID: "4g.94gb",
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         false,
			expectedTemplateID:  "4g.94gb",
		},
		{
			name: "no template matches request",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					Tflops: resource.MustParse("300"), // Too large
					Vram:   resource.MustParse("100Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         true,
		},
		{
			name:         "no templates available",
			gpuTemplates: []tfv1.PartitionTemplate{},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					Tflops: resource.MustParse("30"),
					Vram:   resource.MustParse("20Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         true,
		},
		{
			name: "match with ComputePercent - smallest template that fits",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
				{TemplateID: "4g.94gb", Name: "4g.94gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					// 10% of 312 TFLOPs = 31.2 TFLOPs, should match 1g.24gb (50 TFLOPs)
					ComputePercent: resource.MustParse("10"),
					Vram:           resource.MustParse("20Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         false,
			expectedTemplateID:  "1g.24gb",
		},
		{
			name: "match with ComputePercent - requires larger template",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
				{TemplateID: "4g.94gb", Name: "4g.94gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					// 50% of 312 TFLOPs = 156 TFLOPs, should match 4g.94gb (200 TFLOPs)
					ComputePercent: resource.MustParse("50"),
					Vram:           resource.MustParse("50Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         false,
			expectedTemplateID:  "4g.94gb",
		},
		{
			name: "match with ComputePercent - no template matches",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					// 80% of 312 TFLOPs = 249.6 TFLOPs, too large for 1g.24gb (50 TFLOPs)
					ComputePercent: resource.MustParse("80"),
					Vram:           resource.MustParse("100Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         true,
		},
		{
			name: "match with ComputePercent - missing GPU capacity",
			gpuTemplates: []tfv1.PartitionTemplate{
				{TemplateID: "1g.24gb", Name: "1g.24gb"},
			},
			req: &tfv1.AllocRequest{
				Request: tfv1.Resource{
					ComputePercent: resource.MustParse("10"),
					Vram:           resource.MustParse("20Gi"),
				},
			},
			allocatedPartitions: map[string]tfv1.AllocatedPartition{},
			expectError:         true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use different GPU model for missing capacity test, so that
			// GPUCapacityMap lookup fails as intended.
			testGPUModel := gpuModel
			if tt.name == "match with ComputePercent - missing GPU capacity" {
				testGPUModel = "UNKNOWN_GPU_MODEL"
			}

			result, err := MatchPartitionTemplate(
				tfv1.GPUStatus{
					GPUModel:            testGPUModel,
					PartitionTemplates:  tt.gpuTemplates,
					AllocatedPartitions: tt.allocatedPartitions,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
				},
				tt.req,
			)

			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, result)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, result)
				assert.True(t, result.CanAllocate)
				assert.Equal(t, tt.expectedTemplateID, result.TemplateID)
			}
		})
	}
}

// TestCalculatePartitionResourceUsage checks the TFLOPs/VRAM derivation for a
// single 1g.24gb template: 1/7 of 312 TFLOPs and 24Gi of VRAM.
func TestCalculatePartitionResourceUsage(t *testing.T) {
	// Setup
	gpuModel := testGPUModel
	templateID := "1g.24gb"
	PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{
		templateID: {
			TemplateID:      templateID,
			Name:            "1g.24gb",
			MemoryGigabytes: 24, // 24GB (function converts to bytes)
			ComputePercent:  1.0 / 7.0 * 100,
		},
	}

	tflops, vram, err := CalculatePartitionResourceUsage(resource.MustParse("312"), gpuModel, templateID)

	assert.NoError(t, err)
	// Compare using Cmp to handle different formatting
	// 1/7 of 312 TFLOPs = 44.57 TFLOPs
	expectedTflops := resource.MustParse("44.57")
	assert.Equal(t, 0, tflops.Cmp(expectedTflops), "TFLOPs: got %s, expected %s", tflops.String(), expectedTflops.String())
	// Compare VRAM using Cmp to handle quantity representation differences
	assert.Equal(t, 0, vram.Cmp(resource.MustParse("24Gi")), "VRAM: got %s, expected 24Gi", vram.String())
}

// TestCheckPartitionAvailability models A100 MIG placement rules and exercises
// count limits, slot-placement conflicts, and resource-availability checks.
func TestCheckPartitionAvailability(t *testing.T) {
	// Setup: A100 MIG constraints based on nvidia-smi mig -lgipp output
	// Profile 19 (1g.24gb): Placements {0,1,2,3,4,5,6}:1 - can start at any of 7 positions, occupies 1 slot each
	// Profile 9 (4g.94gb): Placements {0,4}:4 - can start at position 0 or 4, occupies 4 slots each
	gpuModel := testGPUModel
	template1g := "1g.24gb" // Profile 19
	template4g := "4g.94gb" // Profile 9

	// Clear and setup maps for this test
	mu.Lock()
	PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{
		template1g: {
			TemplateID:      template1g,
			Name:            "1g.24gb",
			MemoryGigabytes: 24, // 24GB
			ComputePercent:  1.0 / 7.0 * 100,
			MaxPartition:    7,                               // Can allocate up to 7 instances
			PlacementLimit:  []uint32{0, 1, 2, 3, 4, 5, 6},   // Can start at any of these positions
			PlacementOffSet: 1,                               // Occupies 1 slot
		},
		template4g: {
			TemplateID:      template4g,
			Name:            "4g.94gb",
			MemoryGigabytes: 94, // 94GB
			ComputePercent:  4.0 / 7.0 * 100,
			MaxPartition:    2,               // Can only allocate 2 instances
			PlacementLimit:  []uint32{0, 4},  // Can start at position 0 or 4
			PlacementOffSet: 4,               // Occupies 4 slots (0-3 or 4-7)
		},
	}
	MaxPartitionsMap[gpuModel] = 7
	MaxPlacementSlotsMap[gpuModel] = 8 // A100 has 8 placement slots (0-7)
	mu.Unlock()

	tests := []struct {
		name          string
		gpu           *tfv1.GPU
		templateID    string
		expectError   bool
		errorContains string
	}{
		{
			name: "happy path - 1g.24gb allocation succeeds",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("100"),
						Vram:   resource.MustParse("50Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{},
				},
			},
			templateID:  template1g,
			expectError: false,
		},
		{
			name: "Profile 19 * 4 should fail - all valid positions occupied",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("200"),
						Vram:   resource.MustParse("96Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Profile 19 at position 0 (slot 0)
						"pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Profile 19 at position 1 (slot 1)
						"pod-3": {TemplateID: template1g, PodUID: "pod-3"}, // Profile 19 at position 2 (slot 2)
						"pod-4": {TemplateID: template1g, PodUID: "pod-4"}, // Profile 19 at position 3 (slot 3)
						// Four 1g instances occupy slots 0-3; positions 4,5,6 remain
						// free, so a fifth 1g instance still fits — this case must NOT
						// error. (Name is historical; the failing layout is covered by
						// the "Profile 9 at 0 + Profile 19 * 4" case below.)
					},
				},
			},
			templateID:  template1g,
			expectError: false, // Actually 4 instances can fit at positions 0,1,2,3, leaving 4,5,6 free
		},
		{
			name: "Profile 9 at 0 + Profile 19 * 4 should fail",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("200"),
						Vram:   resource.MustParse("96Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-p9": {TemplateID: template4g, PodUID: "pod-p9", AllocatedAt: metav1.NewTime(metav1.Now().Add(-3 * time.Hour))}, // Profile 9 allocated first at position 0, occupies slots 0,1,2,3
						"pod-1":  {TemplateID: template1g, PodUID: "pod-1", AllocatedAt: metav1.NewTime(metav1.Now().Add(-2 * time.Hour))},  // Profile 19 at position 4 (slot 4)
						"pod-2":  {TemplateID: template1g, PodUID: "pod-2", AllocatedAt: metav1.NewTime(metav1.Now().Add(-1 * time.Hour))},  // Profile 19 at position 5 (slot 5)
						"pod-3":  {TemplateID: template1g, PodUID: "pod-3", AllocatedAt: metav1.Now()},                                      // Profile 19 at position 6 (slot 6)
						// Trying to allocate 4th Profile 19 instance - should fail
						// All valid positions {0,1,2,3,4,5,6} are either occupied or conflict
					},
				},
			},
			templateID:    template1g,
			expectError:   true,
			errorContains: "placement slots",
		},
		{
			name: "Profile 9 * 1 + Profile 19 * 3 should work",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("150"),
						Vram:   resource.MustParse("118Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-p9": {TemplateID: template4g, PodUID: "pod-p9"}, // Profile 9 at position 0, occupies slots 0,1,2,3
						"pod-1":  {TemplateID: template1g, PodUID: "pod-1"},  // Profile 19 at slot 4
						"pod-2":  {TemplateID: template1g, PodUID: "pod-2"},  // Profile 19 at slot 5
						// Trying to allocate 3rd Profile 19 instance - should succeed at slot 6
					},
				},
			},
			templateID:  template1g,
			expectError: false, // 3rd Profile 19 instance should succeed
		},
		{
			// NOTE(review): identical fixture to the previous case — likely a
			// copy/paste duplicate; consider removing one.
			name: "Profile 9 * 1 + Profile 19 * 3 should work (happy case)",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("150"),
						Vram:   resource.MustParse("118Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-p9": {TemplateID: template4g, PodUID: "pod-p9"}, // Profile 9 at position 0, occupies slots 0,1,2,3
						"pod-1":  {TemplateID: template1g, PodUID: "pod-1"},  // Profile 19 at slot 4
						"pod-2":  {TemplateID: template1g, PodUID: "pod-2"},  // Profile 19 at slot 5
						// Trying to allocate 3rd Profile 19 instance - should succeed at slot 6
					},
				},
			},
			templateID:  template1g,
			expectError: false,
		},
		{
			name: "Profile 9 - all placement positions occupied",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("200"),
						Vram:   resource.MustParse("94Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-1": {TemplateID: template4g, PodUID: "pod-1"}, // Profile 9 at position 0, occupies slots 0,1,2,3
						"pod-2": {TemplateID: template4g, PodUID: "pod-2"}, // Profile 9 at position 4, occupies slots 4,5,6,7
						// Both positions {0,4} are now occupied
					},
				},
			},
			templateID:    template4g,
			expectError:   true,
			errorContains: "maximum partition count", // MaxPartition check happens first (2/2)
		},
		{
			name: "insufficient TFLOPs",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("10"), // Too low
						Vram:   resource.MustParse("50Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{},
				},
			},
			templateID:    template1g,
			expectError:   true,
			errorContains: "insufficient TFLOPs",
		},
		{
			name: "insufficient VRAM",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("100"),
						Vram:   resource.MustParse("10Gi"), // Too low for 24Gi required
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{},
				},
			},
			templateID:    template1g,
			expectError:   true,
			errorContains: "insufficient VRAM",
		},
		{
			name: "Profile 9 can allocate at position 4 when Profile 19 uses slots 0-2",
			gpu: &tfv1.GPU{
				ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"},
				Status: tfv1.GPUStatus{
					GPUModel: gpuModel,
					Capacity: &tfv1.Resource{
						Tflops: resource.MustParse("312"),
						Vram:   resource.MustParse("80Gi"),
					},
					Available: &tfv1.Resource{
						Tflops: resource.MustParse("200"),
						Vram:   resource.MustParse("94Gi"),
					},
					AllocatedPartitions: map[string]tfv1.AllocatedPartition{
						"pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Slot 0
						"pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Slot 1
						"pod-3": {TemplateID: template1g, PodUID: "pod-3"}, // Slot 2
						// Slots 3,4,5,6,7 are free
						// Profile 9 can use position 4 (slots 4,5,6,7) or position 0 (slots 0,1,2,3)
						// Position 0 conflicts, but position 4 is free
					},
				},
			},
			templateID:  template4g,
			expectError: false, // Profile 9 can use position 4
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := CheckPartitionAvailability(tt.gpu, tt.templateID)

			if tt.expectError {
				if !assert.Error(t, err) {
					return // Stop if no error when one is expected
				}
				if tt.errorContains != "" && err != nil {
					assert.Contains(t, err.Error(), tt.errorContains)
				}
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
diff --git a/internal/hypervisor/api/device_types.go b/internal/hypervisor/api/device_types.go
new file mode 100644
index 00000000..8b03888b
--- /dev/null
+++ b/internal/hypervisor/api/device_types.go
@@ -0,0 +1,87 @@
/*
Copyright 2024.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package api

// DeviceInfo represents discovered GPU device information
type DeviceInfo struct {
	UUID             string
	Vendor           string
	Model            string
	Index            int32
	NUMANode         int32
	TotalMemoryBytes uint64
	MaxTflops        float64
	Capabilities     DeviceCapabilities
	Properties       map[string]string
	Healthy          bool
}

// DeviceCapabilities represents device capabilities
type DeviceCapabilities struct {
	SupportsPartitioning  bool
	SupportsSoftIsolation bool
	SupportsHardIsolation bool
	SupportsSnapshot      bool
	SupportsMetrics       bool
	MaxPartitions         uint32
	MaxWorkersPerDevice   uint32
}

// ComputeUtilization represents compute utilization for a process on a device
type ComputeUtilization struct {
	ProcessID          string
	DeviceUUID         string
	UtilizationPercent float64
}

// MemoryUtilization represents memory utilization for a process on a device
type MemoryUtilization struct {
	ProcessID     string
	DeviceUUID    string
	UsedBytes     uint64
	ReservedBytes uint64
}

// GPUUsageMetrics represents GPU device metrics
type GPUUsageMetrics struct {
	DeviceUUID       string
	MemoryBytes      uint64
	MemoryPercentage
float64
	ComputePercentage float64
	ComputeTflops     float64
	Rx                float64 // PCIe RX in KB
	Tx                float64 // PCIe TX in KB
	Temperature       float64
	GraphicsClockMHz  float64
	SMClockMHz        float64
	MemoryClockMHz    float64
	VideoClockMHz     float64
	PowerUsage        int64 // in watts
	ExtraMetrics      map[string]float64
}

// WorkerMetrics represents worker process metrics on a device
type WorkerMetrics struct {
	DeviceUUID        string
	WorkerUID         string
	ProcessID         string
	MemoryBytes       uint64
	MemoryPercentage  float64
	ComputeTflops     float64
	ComputePercentage float64
	ExtraMetrics      map[string]float64
}
diff --git a/internal/hypervisor/api/http_types.go b/internal/hypervisor/api/http_types.go
new file mode 100644
index 00000000..16eecef5
--- /dev/null
+++ b/internal/hypervisor/api/http_types.go
@@ -0,0 +1,90 @@
/*
Copyright 2024.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package api

import (
	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
)

// HTTP API Response Types

// ErrorResponse represents an error response
type ErrorResponse struct {
	Error string `json:"error"`
}

// DataResponse is a generic response wrapper for data-only responses
type DataResponse[T any] struct {
	Data T `json:"data"`
}

// MessageAndDataResponse is a generic response wrapper for responses with message and data
type MessageAndDataResponse[T any] struct {
	Message string `json:"message"`
	Data    T      `json:"data"`
}

// StatusResponse represents a simple status response
type StatusResponse struct {
	Status string `json:"status"`
}

// Types to be compatible with legacy APIs

// LimiterInfo represents worker limiter information (used in legacy.go)
type LimiterInfo struct {
	WorkerUID string         `json:"worker_uid"`
	Requests  *tfv1.Resource `json:"requests,omitempty"`
	Limits    *tfv1.Resource `json:"limits,omitempty"`
}

// ListLimitersResponse represents the response from GET /api/v1/limiter (used in legacy.go)
type ListLimitersResponse struct {
	Limiters []LimiterInfo `json:"limiters"`
}

// TrapResponse represents the response from POST /api/v1/trap (used in legacy.go)
type TrapResponse struct {
	Message       string `json:"message"`
	SnapshotCount int    `json:"snapshot_count"`
}

// PodInfo represents pod information for the /api/v1/pod endpoint (used in legacy.go)
type PodInfo struct {
	PodName   string   `json:"pod_name"`
	Namespace string   `json:"namespace"`
	// NOTE(review): field is GPUIDs but the JSON key is "gpu_uuids" — confirm
	// the asymmetry is intentional before renaming either side.
	GPUIDs      []string `json:"gpu_uuids"`
	TflopsLimit *float64 `json:"tflops_limit,omitempty"`
	VramLimit   *uint64  `json:"vram_limit,omitempty"`
	QoSLevel    *string  `json:"qos_level,omitempty"`
}

// ListPodsResponse represents the response from GET /api/v1/pod (used in legacy.go)
type ListPodsResponse struct {
	Pods []PodInfo `json:"pods"`
}

// ProcessInfo represents process mapping information (used in legacy.go)
type ProcessInfo struct {
	WorkerUID      string            `json:"worker_uid"`
	ProcessMapping map[string]string `json:"process_mapping"` // container PID -> host PID
}

// ListProcessesResponse represents the response from GET /api/v1/process (used in legacy.go)
type ListProcessesResponse struct {
	Processes []ProcessInfo `json:"processes"`
}
diff --git a/internal/hypervisor/api/worker_types.go b/internal/hypervisor/api/worker_types.go
new file mode 100644
index 00000000..ccaabce5
--- /dev/null
+++ b/internal/hypervisor/api/worker_types.go
@@ -0,0 +1,34 @@
package api

import (
	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
)

// IsolationMode represents the isolation mode for worker processes
type IsolationMode = tfv1.IsolationModeType

// WorkerInfo describes a worker process and its allocation metadata.
type WorkerInfo struct {
	WorkerUID         string
	AllocatedDevices  []string
	Status            string
	PodUID            string
	PodName           string
	Namespace         string
	PartitionUUID     string
	IsolationMode     IsolationMode
	MemoryLimitBytes  uint64
	ComputeLimitUnits uint32
	TemplateID        string
	Annotations       map[string]string
	PodIndex          string

	// Tombstone field to indicate if the worker is deleted
	Deleted bool
}

// WorkerAllocation pairs a worker with the devices allocated to it.
type WorkerAllocation struct {
	WorkerInfo *WorkerInfo

	// the complete or partitioned device info
	DeviceInfos []*DeviceInfo
}
diff --git a/internal/hypervisor/backend/kubernetes/apiserver.go b/internal/hypervisor/backend/kubernetes/apiserver.go
new file mode 100644
index 00000000..8cc7a5b7
--- /dev/null
+++ b/internal/hypervisor/backend/kubernetes/apiserver.go
@@ -0,0 +1,290 @@
package kubernetes

import (
	"context"
	"fmt"
	"time"

	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
	"github.com/NexusGPU/tensor-fusion/internal/constants"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/util/retry"
	"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +const ( + // bytesPerMiB is the number of bytes in a MiB + bytesPerMiB = 1024 * 1024 +) + +var ( + scheme = runtime.NewScheme() +) + +func init() { + utilruntime.Must(tfv1.AddToScheme(scheme)) +} + +// APIServer provides CRUD operations for GPU resources +type APIServer struct { + client client.Client + ctx context.Context +} + +// NewAPIServer creates a new API server instance with an existing client +func NewAPIServer(ctx context.Context, k8sClient client.Client) *APIServer { + return &APIServer{ + client: k8sClient, + ctx: ctx, + } +} + +// NewAPIServerFromConfig creates a new API server instance from a rest.Config +func NewAPIServerFromConfig(ctx context.Context, restConfig *rest.Config) (*APIServer, error) { + k8sClient, err := client.New(restConfig, client.Options{ + Scheme: scheme, + }) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + return &APIServer{ + client: k8sClient, + ctx: ctx, + }, nil +} + +// GPUInfo contains information needed to create or update a GPU +type GPUInfo struct { + UUID string + DeviceName string + VRAMBytes uint64 + TFlops resource.Quantity + Index int32 + NUMANodeID int32 + NodeName string + Vendor string + IsolationMode tfv1.IsolationModeType +} + +// CreateOrUpdateGPU creates or updates a GPU resource with metadata and status +func (a *APIServer) CreateOrUpdateGPU(gpuNode *tfv1.GPUNode, info GPUInfo) (*tfv1.GPU, error) { + if len(gpuNode.OwnerReferences) == 0 { + return nil, fmt.Errorf("GPUNode %s has no owner references", gpuNode.Name) + } + + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: info.UUID, + }, + } + + // Create or update GPU metadata + if err := retry.OnError(wait.Backoff{ + Steps: 10, + Duration: time.Second, + Factor: 1.0, + Jitter: 0.1, + }, func(err error) bool { + return true // Retry on 
all errors + }, func() error { + _, err := controllerutil.CreateOrUpdate(a.ctx, a.client, gpu, func() error { + gpu.Labels = map[string]string{ + constants.LabelKeyOwner: gpuNode.Name, + constants.GpuPoolKey: gpuNode.OwnerReferences[0].Name, + } + gpu.Annotations = map[string]string{ + constants.LastSyncTimeAnnotationKey: time.Now().Format(time.RFC3339), + } + + if !metav1.IsControlledBy(gpu, gpuNode) { + gvk, err := apiutil.GVKForObject(gpuNode, scheme) + if err != nil { + return err + } + ref := metav1.OwnerReference{ + APIVersion: gvk.GroupVersion().String(), + Kind: gvk.Kind, + Name: gpuNode.GetName(), + UID: gpuNode.GetUID(), + BlockOwnerDeletion: ptr.To(true), + Controller: ptr.To(true), + } + gpu.OwnerReferences = []metav1.OwnerReference{ref} + } + return nil + }) + return err + }); err != nil { + return nil, fmt.Errorf("failed to create or update GPU %s: %w", info.UUID, err) + } + + // Update GPU status with retry on conflict + if err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { + if err := a.client.Get(a.ctx, client.ObjectKey{Name: info.UUID}, gpu); err != nil { + return err + } + + patch := client.MergeFrom(gpu.DeepCopy()) + a.setGPUStatus(gpu, info) + return a.client.Status().Patch(a.ctx, gpu, patch) + }); err != nil { + return nil, fmt.Errorf("failed to update GPU %s status: %w", info.UUID, err) + } + + return gpu, nil +} + +// setGPUStatus sets the GPU status fields from GPUInfo +func (a *APIServer) setGPUStatus(gpu *tfv1.GPU, info GPUInfo) { + gpu.Status.Capacity = &tfv1.Resource{ + Vram: resource.MustParse(fmt.Sprintf("%dMi", info.VRAMBytes/bytesPerMiB)), + Tflops: info.TFlops, + } + gpu.Status.UUID = info.UUID + gpu.Status.GPUModel = info.DeviceName + gpu.Status.Index = ptr.To(info.Index) + gpu.Status.Vendor = info.Vendor + gpu.Status.IsolationMode = info.IsolationMode + gpu.Status.NUMANode = ptr.To(info.NUMANodeID) + gpu.Status.NodeSelector = map[string]string{ + constants.KubernetesHostNameLabel: info.NodeName, + } + + if 
gpu.Status.Available == nil { + gpu.Status.Available = gpu.Status.Capacity.DeepCopy() + } + if gpu.Status.UsedBy == "" { + gpu.Status.UsedBy = tfv1.UsedByTensorFusion + } + if gpu.Status.Phase == "" { + gpu.Status.Phase = tfv1.TensorFusionGPUPhasePending + } +} + +// GetGPU retrieves a GPU resource by UUID +func (a *APIServer) GetGPU(uuid string) (*tfv1.GPU, error) { + gpu := &tfv1.GPU{} + if err := a.client.Get(a.ctx, client.ObjectKey{Name: uuid}, gpu); err != nil { + return nil, fmt.Errorf("failed to get GPU %s: %w", uuid, err) + } + return gpu, nil +} + +// ListGPUs lists all GPU resources +func (a *APIServer) ListGPUs() (*tfv1.GPUList, error) { + gpuList := &tfv1.GPUList{} + if err := a.client.List(a.ctx, gpuList); err != nil { + return nil, fmt.Errorf("failed to list GPUs: %w", err) + } + return gpuList, nil +} + +// UpdateGPUStatus updates the status of a GPU resource using merge patch +func (a *APIServer) UpdateGPUStatus(gpu *tfv1.GPU) error { + return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + current := &tfv1.GPU{} + if err := a.client.Get(a.ctx, client.ObjectKeyFromObject(gpu), current); err != nil { + return err + } + + patch := client.MergeFrom(current.DeepCopy()) + current.Status = gpu.Status + return a.client.Status().Patch(a.ctx, current, patch) + }) +} + +// patchGPUStatus patches a specific GPU status field using a function +func (a *APIServer) patchGPUStatus(uuid string, updateFn func(*tfv1.GPU)) error { + return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + gpu, err := a.GetGPU(uuid) + if err != nil { + return err + } + + patch := client.MergeFrom(gpu.DeepCopy()) + updateFn(gpu) + return a.client.Status().Patch(a.ctx, gpu, patch) + }) +} + +// UpdateGPUAvailableResources updates the available resources of a GPU +func (a *APIServer) UpdateGPUAvailableResources(uuid string, available *tfv1.Resource) error { + return a.patchGPUStatus(uuid, func(gpu *tfv1.GPU) { + gpu.Status.Available = available + }) +} + +// 
UpdateGPUPhase updates the phase of a GPU +func (a *APIServer) UpdateGPUPhase(uuid string, phase tfv1.TensorFusionGPUPhase) error { + return a.patchGPUStatus(uuid, func(gpu *tfv1.GPU) { + gpu.Status.Phase = phase + }) +} + +// GetGPUNode retrieves a GPUNode resource by name +func (a *APIServer) GetGPUNode(name string) (*tfv1.GPUNode, error) { + gpuNode := &tfv1.GPUNode{} + if err := a.client.Get(a.ctx, client.ObjectKey{Name: name}, gpuNode); err != nil { + return nil, fmt.Errorf("failed to get GPUNode %s: %w", name, err) + } + return gpuNode, nil +} + +// UpdateGPUNodeStatus updates the status of a GPUNode resource +func (a *APIServer) UpdateGPUNodeStatus( + gpuNode *tfv1.GPUNode, + totalTFlops, totalVRAM resource.Quantity, + totalGPUs int32, + deviceIDs []string, +) error { + return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + current := &tfv1.GPUNode{} + if err := a.client.Get(a.ctx, client.ObjectKeyFromObject(gpuNode), current); err != nil { + return err + } + + patch := client.MergeFrom(current.DeepCopy()) + a.updateGPUNodeStatus(&current.Status, totalTFlops, totalVRAM, totalGPUs, deviceIDs) + return a.client.Status().Patch(a.ctx, current, patch) + }) +} + +// updateGPUNodeStatus updates GPUNode status fields +func (a *APIServer) updateGPUNodeStatus( + status *tfv1.GPUNodeStatus, + totalTFlops, totalVRAM resource.Quantity, + totalGPUs int32, + deviceIDs []string, +) { + status.TotalTFlops = totalTFlops + status.TotalVRAM = totalVRAM + status.TotalGPUs = totalGPUs + status.ManagedGPUs = totalGPUs + status.ManagedGPUDeviceIDs = deviceIDs + + if status.Phase == "" { + status.Phase = tfv1.TensorFusionGPUNodePhasePending + } +} + +// DeleteGPU deletes a GPU resource +func (a *APIServer) DeleteGPU(uuid string) error { + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: uuid, + }, + } + if err := a.client.Delete(a.ctx, gpu); err != nil { + return fmt.Errorf("failed to delete GPU %s: %w", uuid, err) + } + return nil +} diff --git 
a/internal/hypervisor/backend/kubernetes/deviceplugin.go b/internal/hypervisor/backend/kubernetes/deviceplugin.go new file mode 100644 index 00000000..5a25cb73 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/deviceplugin.go @@ -0,0 +1,426 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kubernetes + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "k8s.io/klog/v2" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +const ( + // DevicePluginPath is the path where device plugins should register + DevicePluginPath = "/var/lib/kubelet/device-plugins" + // KubeletSocket is the kubelet registration socket + KubeletSocket = "kubelet.sock" + // ResourceName is the resource name advertised to kubelet + ResourceName = "tensor-fusion.ai/index" + // DevicePluginEndpoint is the endpoint name for this device plugin + DevicePluginEndpoint = "tensor-fusion-index.sock" +) + +// DevicePlugin implements the Kubernetes device plugin interface +type DevicePlugin struct { + pluginapi.UnimplementedDevicePluginServer + + ctx context.Context + deviceController framework.DeviceController + workerController framework.WorkerController + kubeletClient *PodCacheManager + + server *grpc.Server + socketPath string + 
resourceName string + + mu sync.RWMutex + devices []*pluginapi.Device + stopCh chan struct{} + updateCh chan []*pluginapi.Device +} + +// NewDevicePlugin creates a new device plugin instance +func NewDevicePlugin(ctx context.Context, deviceController framework.DeviceController, workerController framework.WorkerController, kubeletClient *PodCacheManager) *DevicePlugin { + return &DevicePlugin{ + ctx: ctx, + deviceController: deviceController, + workerController: workerController, + kubeletClient: kubeletClient, + socketPath: filepath.Join(DevicePluginPath, DevicePluginEndpoint), + resourceName: ResourceName, + stopCh: make(chan struct{}), + updateCh: make(chan []*pluginapi.Device, 1), + } +} + +// Start starts the device plugin gRPC server and registers with kubelet +func (dp *DevicePlugin) Start() error { + // Clean up any existing socket + // Check if file exists first to avoid permission errors on non-existent files + if _, err := os.Stat(dp.socketPath); err == nil { + // File exists, try to remove it + if err := os.Remove(dp.socketPath); err != nil { + return fmt.Errorf("failed to remove existing socket: %w", err) + } + } else if !os.IsNotExist(err) { + // Some other error checking file existence (e.g., permission denied on parent directory) + // Log warning but continue - net.Listen will handle it + klog.Warningf("Could not check socket file existence: %v", err) + } + + // Create directory if it doesn't exist + if err := os.MkdirAll(DevicePluginPath, 0750); err != nil { + return fmt.Errorf("failed to create device plugin directory: %w", err) + } + + // Create Unix socket listener + listener, err := net.Listen("unix", dp.socketPath) + if err != nil { + return fmt.Errorf("failed to create listener: %w", err) + } + + // Create gRPC server + dp.server = grpc.NewServer() + pluginapi.RegisterDevicePluginServer(dp.server, dp) + + // Start gRPC server + go func() { + klog.Infof("Starting device plugin gRPC server on %s", dp.socketPath) + if err := 
dp.server.Serve(listener); err != nil { + klog.Errorf("Device plugin gRPC server error: %v", err) + } + }() + + // Wait for server to be ready + conn, err := dp.dial(dp.socketPath, 5*time.Second) + if err != nil { + return fmt.Errorf("failed to dial device plugin socket: %w", err) + } + _ = conn.Close() + + // Register with kubelet + if err := dp.register(); err != nil { + return fmt.Errorf("failed to register with kubelet: %w", err) + } + + // Initialize device list with dummy index devices (1-512) + dp.updateDeviceList() + + // Start device monitoring + go dp.monitorDevices() + + return nil +} + +// Stop stops the device plugin +func (dp *DevicePlugin) Stop() error { + close(dp.stopCh) + if dp.server != nil { + dp.server.Stop() + } + return os.Remove(dp.socketPath) +} + +// register registers the device plugin with kubelet +func (dp *DevicePlugin) register() error { + kubeletSocketPath := filepath.Join(DevicePluginPath, KubeletSocket) + + // Check if kubelet socket exists + if _, err := os.Stat(kubeletSocketPath); os.IsNotExist(err) { + return fmt.Errorf("kubelet socket does not exist at %s (kubelet may not be running or device plugin support not enabled)", kubeletSocketPath) + } else if err != nil { + return fmt.Errorf("failed to check kubelet socket: %w", err) + } + + conn, err := dp.dial(kubeletSocketPath, 5*time.Second) + if err != nil { + return fmt.Errorf("failed to dial kubelet: %w", err) + } + defer func() { + _ = conn.Close() + }() + + client := pluginapi.NewRegistrationClient(conn) + req := &pluginapi.RegisterRequest{ + Version: pluginapi.Version, + Endpoint: DevicePluginEndpoint, + ResourceName: dp.resourceName, + Options: &pluginapi.DevicePluginOptions{ + PreStartRequired: false, + GetPreferredAllocationAvailable: false, + }, + } + + _, err = client.Register(context.Background(), req) + if err != nil { + return fmt.Errorf("failed to register: %w", err) + } + + klog.Infof("Successfully registered device plugin with kubelet: %s", dp.resourceName) + 
return nil +} + +// dial establishes a connection to a Unix socket +func (dp *DevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) { + // Use unix:// prefix for gRPC to recognize it as a Unix socket + // The dialer will receive the full address, so we need to strip the prefix + target := "unix://" + unixSocketPath + conn, err := grpc.NewClient(target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + // Strip unix:// prefix to get the actual socket path + socketPath := addr + if len(addr) > 7 && addr[:7] == "unix://" { + socketPath = addr[7:] + } + return net.DialTimeout("unix", socketPath, timeout) + }), + ) + return conn, err +} + +// monitorDevices periodically updates the device list +func (dp *DevicePlugin) monitorDevices() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-dp.ctx.Done(): + return + case <-dp.stopCh: + return + case <-ticker.C: + dp.updateDeviceList() + case devices := <-dp.updateCh: + dp.mu.Lock() + dp.devices = devices + dp.mu.Unlock() + } + } +} + +// updateDeviceList updates the list of available dummy index devices +// This device plugin registers tensor-fusion.ai/index resource, not real GPU devices. +// We advertise 512 dummy devices (indices 1-512) for pod identification. +// Real GPU devices are allocated by scheduler and set in pod annotations. 
+func (dp *DevicePlugin) updateDeviceList() { + dp.mu.Lock() + defer dp.mu.Unlock() + + // Advertise 512 dummy index devices (1-512) for pod identification + // These are NOT real GPU devices - they're just used to match pods by index + pluginDevices := make([]*pluginapi.Device, 0, 512) + for i := 1; i <= 512; i++ { + pluginDevices = append(pluginDevices, &pluginapi.Device{ + ID: fmt.Sprintf("%d", i), // Index as device ID + Health: pluginapi.Healthy, + }) + } + + dp.devices = pluginDevices + select { + case dp.updateCh <- pluginDevices: + default: + } +} + +// GetDevicePluginOptions returns options for the device plugin +func (dp *DevicePlugin) GetDevicePluginOptions(ctx context.Context, req *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { + return &pluginapi.DevicePluginOptions{ + PreStartRequired: false, + GetPreferredAllocationAvailable: false, + }, nil +} + +// ListAndWatch streams device list and health updates +func (dp *DevicePlugin) ListAndWatch(req *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error { + klog.Info("ListAndWatch called") + + // Send initial device list + dp.updateDeviceList() + dp.mu.RLock() + devices := make([]*pluginapi.Device, len(dp.devices)) + copy(devices, dp.devices) + dp.mu.RUnlock() + + if err := stream.Send(&pluginapi.ListAndWatchResponse{Devices: devices}); err != nil { + return fmt.Errorf("failed to send device list: %w", err) + } + + // Watch for updates + for { + select { + case <-dp.ctx.Done(): + return nil + case <-dp.stopCh: + return nil + case devices := <-dp.updateCh: + if err := stream.Send(&pluginapi.ListAndWatchResponse{Devices: devices}); err != nil { + return fmt.Errorf("failed to send device update: %w", err) + } + } + } +} + +// Allocate handles device allocation requests from kubelet +// IMPORTANT: This device plugin registers tensor-fusion.ai/index as a dummy resource. +// The pod index (1-512) is used to identify which pod is requesting allocation. 
+// The actual GPU device UUIDs are already set by the centralized scheduler in pod annotations: +// - tensor-fusion.ai/gpu-ids: comma-separated GPU UUIDs (for all isolation modes) +// - tensor-fusion.ai/partition: partition template ID (only for partitioned isolation mode) +// +// The len(req.ContainerRequests) is just the number of containers in the pod requesting +// tensor-fusion.ai/index resource - it's NOT the pod index. The pod index comes from +// DevicesIds[0] which contains the index value from resource limits. +// +// We do NOT allocate the fake tensor-fusion.ai/index device - it's only used for pod identification. +// CDIDevices in the response is kept empty to prevent kubelet from allocating the dummy device. +func (dp *DevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { + // len(req.ContainerRequests) identifies how many containers in the pod are requesting + // tensor-fusion.ai/index resource - this is for logging/identification only + klog.Infof("Allocate called with %d container requests (pod may have multiple containers)", len(req.ContainerRequests)) + + responses := make([]*pluginapi.ContainerAllocateResponse, 0, len(req.ContainerRequests)) + + for containerIdx, containerReq := range req.ContainerRequests { + // Extract pod index from DevicesIds - this contains the index value (1-512) from resource limits + // Resource limit: tensor-fusion.ai/index: 3 -> DevicesIds: ["3"] + // This is the actual pod index used to match the pod in the pod cache + podIndex := len(containerReq.DevicesIds) + if podIndex == 0 { + return nil, fmt.Errorf("container request %d has no DevicesIds (expected pod index value 1-512)", containerIdx) + } + + if podIndex < constants.IndexRangeStart || podIndex > constants.IndexRangeEnd { + return nil, fmt.Errorf("container request %d has index out of range: %d (expected 1-512)", containerIdx, podIndex) + } + + klog.V(4).Infof("Processing allocation for container index 
%d, pod index %d (from DevicesIds)", containerIdx, podIndex) + + // Get worker info from kubelet client using pod index + // This will automatically check for duplicate indices and fail fast if found + workerInfo, err := dp.kubeletClient.GetWorkerInfoForAllocationByIndex(ctx, podIndex) + if err != nil { + klog.Errorf("Failed to get worker info for pod index %d: %v", podIndex, err) + return nil, fmt.Errorf("failed to get worker info for pod index %d: %w", podIndex, err) + } + + if workerInfo == nil { + return nil, fmt.Errorf("worker info not found for pod index %d", podIndex) + } + + // Device UUIDs are already set by scheduler in annotations, not from DevicesIds + deviceUUIDs := workerInfo.AllocatedDevices + if len(deviceUUIDs) == 0 { + return nil, fmt.Errorf("no device UUIDs found in pod annotations for pod %s/%s", workerInfo.Namespace, workerInfo.PodName) + } + + // Call worker controller to allocate + allocResp, err := dp.workerController.AllocateWorker(workerInfo) + if err != nil { + return nil, fmt.Errorf("failed to allocate device: %w", err) + } + + // WorkerAllocation doesn't need Success/ErrMsg check - if no error, allocation succeeded + + // Build container response - create minimal response since allocation details are tracked separately + // IMPORTANT: CdiDevices must be empty to prevent dummy tensor-fusion.ai/index device + // from being allocated by kubelet + containerResp := &pluginapi.ContainerAllocateResponse{ + Envs: make(map[string]string), + Mounts: []*pluginapi.Mount{}, + Devices: []*pluginapi.DeviceSpec{}, + CdiDevices: []*pluginapi.CDIDevice{}, // Empty to prevent dummy device allocation + } + + // Add basic environment variables for worker info + if allocResp.WorkerInfo != nil { + containerResp.Envs["TF_WORKER_UID"] = allocResp.WorkerInfo.WorkerUID + containerResp.Envs["TF_POD_UID"] = allocResp.WorkerInfo.PodUID + + // Add device UUIDs as environment variable + if len(allocResp.DeviceInfos) > 0 { + deviceUUIDs := make([]string, 0, 
len(allocResp.DeviceInfos)) + for _, device := range allocResp.DeviceInfos { + deviceUUIDs = append(deviceUUIDs, device.UUID) + } + containerResp.Envs["TF_DEVICE_UUIDS"] = fmt.Sprintf("%v", deviceUUIDs) + } + } + + // Get pod to extract labels and annotations + pod := dp.kubeletClient.GetPodByUID(workerInfo.PodUID) + labels := make(map[string]string) + annotations := make(map[string]string) + if pod != nil { + if pod.Labels != nil { + labels = pod.Labels + } + if pod.Annotations != nil { + annotations = pod.Annotations + } + } + + // Update allocation in device controller with labels and annotations + // Use type assertion to access the concrete implementation + if deviceCtrl, ok := dp.deviceController.(interface { + UpdateAllocationLabelsAndAnnotations(workerUID string, labels, annotations map[string]string) + }); ok { + deviceCtrl.UpdateAllocationLabelsAndAnnotations(workerInfo.PodUID, labels, annotations) + } + + if err := dp.kubeletClient.StoreAllocation(workerInfo.PodUID, allocResp); err != nil { + klog.Warningf("Failed to store allocation: %v", err) + } + + // Remove PodIndexAnnotation after successful allocation to release the index + // This prevents the index from being matched to this pod in future allocation cycles + if err := dp.kubeletClient.RemovePodIndexAnnotation(ctx, workerInfo.PodUID, workerInfo.Namespace, workerInfo.PodName); err != nil { + klog.Warningf("Failed to remove pod index annotation for pod %s/%s: %v", workerInfo.Namespace, workerInfo.PodName, err) + // Don't fail allocation if annotation removal fails + } + + responses = append(responses, containerResp) + } + + return &pluginapi.AllocateResponse{ + ContainerResponses: responses, + }, nil +} + +// PreStartContainer is called before container start (optional) +func (dp *DevicePlugin) PreStartContainer(ctx context.Context, req *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { + return &pluginapi.PreStartContainerResponse{}, nil +} + +// 
GetPreferredAllocation returns preferred device allocation (optional) +func (dp *DevicePlugin) GetPreferredAllocation(ctx context.Context, req *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { + return &pluginapi.PreferredAllocationResponse{ + ContainerResponses: []*pluginapi.ContainerPreferredAllocationResponse{}, + }, nil +} diff --git a/internal/hypervisor/backend/kubernetes/deviceplugin_test.go b/internal/hypervisor/backend/kubernetes/deviceplugin_test.go new file mode 100644 index 00000000..3724d120 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/deviceplugin_test.go @@ -0,0 +1,81 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +// TestDevicePluginAllocate_ExtractsIndexFromDevicesIds tests that the device plugin +// correctly extracts the pod index from DevicesIds[0], not from len(req.ContainerRequests) +// This is a key test to verify the device plugin implementation matches the design: +// - DevicesIds[0] contains the index value (1-512) from resource limits +// - len(req.ContainerRequests) is just the number of containers, NOT the pod index +// - CdiDevices must be empty to prevent dummy device allocation +func TestDevicePluginAllocate_ExtractsIndexFromDevicesIds(t *testing.T) { + // This test verifies the key design principle: + // The pod index comes from DevicesIds[0], which contains the value from + // tensor-fusion.ai/index resource limit, NOT from len(req.ContainerRequests) + + req := &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIds: []string{"3"}, // Index "3" from resource limit + }, + }, + } + + // Verify the structure: len(ContainerRequests) = 1, but index is "3" from DevicesIds[0] + assert.Len(t, req.ContainerRequests, 1, "Should have 1 container request") + assert.Equal(t, "3", req.ContainerRequests[0].DevicesIds[0], "Index should come from DevicesIds[0], not from len(ContainerRequests)") + + // This demonstrates that len(req.ContainerRequests) is NOT the pod index + // The pod index is extracted from DevicesIds[0] + assert.NotEqual(t, len(req.ContainerRequests), 3, "len(ContainerRequests) should NOT equal the pod index") +} + +// TestDevicePluginAllocate_MultipleContainers tests that len(req.ContainerRequests) +// is used for iteration, not for pod index identification +func TestDevicePluginAllocate_MultipleContainers(t *testing.T) { + // Create request with 2 containers, both with index "5" + // len(ContainerRequests) = 2, but pod index is still "5" from DevicesIds 
+ req := &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIds: []string{"5"}, // First container: index 5 + }, + { + DevicesIds: []string{"5"}, // Second container: same pod, same index + }, + }, + } + + // Verify: len(ContainerRequests) = 2, but index is "5" from DevicesIds + assert.Len(t, req.ContainerRequests, 2, "Should have 2 container requests") + assert.Equal(t, "5", req.ContainerRequests[0].DevicesIds[0], "First container index from DevicesIds") + assert.Equal(t, "5", req.ContainerRequests[1].DevicesIds[0], "Second container index from DevicesIds") + + // Key verification: len(ContainerRequests) is NOT the pod index + assert.NotEqual(t, len(req.ContainerRequests), 5, "len(ContainerRequests) should NOT equal the pod index") + + // Both containers have the same index because they're in the same pod + assert.Equal(t, req.ContainerRequests[0].DevicesIds[0], req.ContainerRequests[1].DevicesIds[0], + "Both containers should have the same index (same pod)") +} diff --git a/internal/hypervisor/backend/kubernetes/dra.go b/internal/hypervisor/backend/kubernetes/dra.go new file mode 100644 index 00000000..276009a4 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/dra.go @@ -0,0 +1 @@ +package kubernetes diff --git a/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go b/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go new file mode 100644 index 00000000..65a90192 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go @@ -0,0 +1,263 @@ +package external_dp + +import ( + "context" + "os" + "testing" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// MockAPIServer is a mock implementation of APIServerInterface +type MockAPIServer struct { + mock.Mock +} + +func (m *MockAPIServer) GetGPU(uuid string) (*tfv1.GPU, 
error) { + args := m.Called(uuid) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*tfv1.GPU), args.Error(1) +} + +func (m *MockAPIServer) UpdateGPUStatus(gpu *tfv1.GPU) error { + args := m.Called(gpu) + return args.Error(0) +} + +// MockKubeletClient is a mock implementation of KubeletClientInterface +type MockKubeletClient struct { + mock.Mock + pods map[string]interface{} +} + +func (m *MockKubeletClient) GetAllPods() map[string]any { + return m.pods +} + +func TestReadCheckpointFile(t *testing.T) { + // Create a temporary checkpoint file with test data + testData := `{ + "Data": { + "PodDeviceEntries": [ + { + "PodUID": "a7461dc1-023a-4bd5-a403-c738bb1d7db4", + "ContainerName": "web", + "ResourceName": "nvidia.com/gpu", + "DeviceIDs": { + "-1": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + }, + "AllocResp": "CkIKFk5WSURJQV9WSVNJQkxFX0RFVklDRVMSKEdQVS03ZDg0MjlkNS01MzFkLWQ2YTYtNjUxMC0zYjY2MjA4MWE3NWEaJAoOL2Rldi9udmlkaWFjdGwSDi9kZXYvbnZpZGlhY3RsGgJydxomCg8vZGV2L252aWRpYS11dm0SDy9kZXYvbnZpZGlhLXV2bRoCcncaMgoVL2Rldi9udmlkaWEtdXZtLXRvb2xzEhUvZGV2L252aWRpYS11dm0tdG9vbHMaAnJ3Gi4KEy9kZXYvbnZpZGlhLW1vZGVzZXQSEy9kZXYvbnZpZGlhLW1vZGVzZXQaAnJ3GiAKDC9kZXYvbnZpZGlhMBIML2Rldi9udmlkaWEwGgJydw==" + } + ], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + }, + "Checksum": 2262205670 +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(testData) + assert.NoError(t, err) + _ = tmpFile.Close() + + detector := &DevicePluginDetector{ + checkpointPath: tmpFile.Name(), + } + + checkpoint, err := detector.readCheckpointFile() + assert.NoError(t, err) + assert.NotNil(t, checkpoint) + assert.Len(t, checkpoint.Data.PodDeviceEntries, 1) + assert.Equal(t, "a7461dc1-023a-4bd5-a403-c738bb1d7db4", checkpoint.Data.PodDeviceEntries[0].PodUID) + assert.Equal(t, "nvidia.com/gpu", 
checkpoint.Data.PodDeviceEntries[0].ResourceName) + assert.Contains(t, checkpoint.Data.RegisteredDevices, "nvidia.com/gpu") +} + +func TestExtractDeviceIDs(t *testing.T) { + checkpoint := &KubeletCheckpoint{ + Data: CheckpointData{ + PodDeviceEntries: []PodDeviceEntry{ + { + ResourceName: "nvidia.com/gpu", + DeviceIDs: map[string][]string{ + "-1": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + }, + RegisteredDevices: map[string][]string{ + "nvidia.com/gpu": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + } + + detector := &DevicePluginDetector{ + vendorDetectors: map[string]VendorDetector{ + "nvidia.com/gpu": NewNvidiaDevicePluginDetector(), + }, + } + + allocated, registered := detector.extractDeviceIDs(checkpoint) + assert.Contains(t, allocated, "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a") + assert.Contains(t, registered, "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a") +} + +func TestNvidiaDevicePluginDetector(t *testing.T) { + detector := NewNvidiaDevicePluginDetector() + assert.Equal(t, "nvidia.com/gpu", detector.GetResourceName()) + assert.Equal(t, string(tfv1.UsedByNvidiaDevicePlugin), detector.GetUsedBySystem()) +} + +func TestProcessDeviceState_DeviceAdded(t *testing.T) { + mockAPI := new(MockAPIServer) + mockKubelet := &MockKubeletClient{ + pods: map[string]interface{}{ + "a7461dc1-023a-4bd5-a403-c738bb1d7db4": struct{}{}, // Pod exists + }, + } + + checkpointData := `{ + "Data": { + "PodDeviceEntries": [ + { + "PodUID": "a7461dc1-023a-4bd5-a403-c738bb1d7db4", + "ContainerName": "web", + "ResourceName": "nvidia.com/gpu", + "DeviceIDs": { + "-1": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } + ], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(checkpointData) + assert.NoError(t, err) + _ = 
tmpFile.Close() + + // Mock GPU resource + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a", + }, + Status: tfv1.GPUStatus{ + UsedBy: tfv1.UsedByTensorFusion, + }, + } + + mockAPI.On("GetGPU", "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a").Return(gpu, nil) + mockAPI.On("UpdateGPUStatus", mock.AnythingOfType("*v1.GPU")).Return(nil) + + detector := &DevicePluginDetector{ + ctx: context.Background(), + checkpointPath: tmpFile.Name(), + apiServer: mockAPI, + kubeletClient: mockKubelet, + vendorDetectors: map[string]VendorDetector{"nvidia.com/gpu": NewNvidiaDevicePluginDetector()}, + previousDeviceIDs: make(map[string]bool), + } + + err = detector.processDeviceState(false) + assert.NoError(t, err) + mockAPI.AssertExpectations(t) +} + +func TestProcessDeviceState_DeviceRemoved(t *testing.T) { + mockAPI := new(MockAPIServer) + mockKubelet := &MockKubeletClient{ + pods: map[string]interface{}{}, // No pods - device should be removed + } + + checkpointData := `{ + "Data": { + "PodDeviceEntries": [], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(checkpointData) + assert.NoError(t, err) + _ = tmpFile.Close() + + // Mock GPU resource that was previously allocated + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a", + }, + Status: tfv1.GPUStatus{ + UsedBy: tfv1.UsedByNvidiaDevicePlugin, + }, + } + + mockAPI.On("GetGPU", "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a").Return(gpu, nil) + mockAPI.On("UpdateGPUStatus", mock.AnythingOfType("*v1.GPU")).Return(nil) + + detector := &DevicePluginDetector{ + ctx: context.Background(), + checkpointPath: tmpFile.Name(), + apiServer: mockAPI, + kubeletClient: mockKubelet, + vendorDetectors: 
map[string]VendorDetector{"nvidia.com/gpu": NewNvidiaDevicePluginDetector()}, + previousDeviceIDs: map[string]bool{"gpu-7d8429d5-531d-d6a6-6510-3b662081a75a": true}, + } + + err = detector.processDeviceState(false) + assert.NoError(t, err) + mockAPI.AssertExpectations(t) +} + +func TestFindEntryForDevice(t *testing.T) { + checkpoint := &KubeletCheckpoint{ + Data: CheckpointData{ + PodDeviceEntries: []PodDeviceEntry{ + { + ResourceName: "nvidia.com/gpu", + DeviceIDs: map[string][]string{ + "-1": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + }, + }, + } + + detector := &DevicePluginDetector{} + entry := detector.findEntryForDevice(checkpoint, "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a") + assert.Equal(t, "nvidia.com/gpu", entry.ResourceName) +} diff --git a/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go b/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go new file mode 100644 index 00000000..074ece2f --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go @@ -0,0 +1,482 @@ +package external_dp + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "path/filepath" + "strings" + "sync" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/fsnotify/fsnotify" + "k8s.io/klog/v2" +) + +const ( + // Default kubelet checkpoint file path + defaultKubeletCheckpointPath = "/var/lib/kubelet/device-plugins/kubelet_internal_checkpoint" + + // Polling intervals + defaultPollInterval = 30 * time.Second + defaultPatchAllInterval = 120 * time.Second + patchAllIntervalJitter = 0.15 // ±15% jitter +) + +// KubeletCheckpoint represents the structure of kubelet device checkpoint file +type KubeletCheckpoint struct { + Data CheckpointData `json:"Data"` +} + +type CheckpointData struct { + PodDeviceEntries []PodDeviceEntry `json:"PodDeviceEntries,omitempty"` + RegisteredDevices map[string][]string `json:"RegisteredDevices,omitempty"` +} + +type 
PodDeviceEntry struct { + PodUID string `json:"PodUID"` + ContainerName string `json:"ContainerName"` + ResourceName string `json:"ResourceName"` + DeviceIDs map[string][]string `json:"DeviceIDs"` +} + +// VendorDetector interface for vendor-specific device plugin detectors +type VendorDetector interface { + // GetResourceName returns the resource name this detector handles (e.g., "nvidia.com/gpu") + GetResourceName() string + // GetUsedBySystem returns the UsedBy system name for this vendor + GetUsedBySystem() string +} + +// APIServerInterface defines the interface for GPU API operations +type APIServerInterface interface { + GetGPU(uuid string) (*tfv1.GPU, error) + UpdateGPUStatus(gpu *tfv1.GPU) error +} + +// KubeletClientInterface defines the interface for pod listing +type KubeletClientInterface interface { + GetAllPods() map[string]interface{} // Returns map of pod UID to pod (can be *corev1.Pod) +} + +// DevicePluginDetector watches kubelet device checkpoint and manages GPU resource patching +type DevicePluginDetector struct { + ctx context.Context + checkpointPath string + apiServer APIServerInterface + kubeletClient KubeletClientInterface + vendorDetectors map[string]VendorDetector // key: resource name + previousDeviceIDs map[string]bool + mu sync.RWMutex + watcher *fsnotify.Watcher + stopCh chan struct{} +} + +// NewDevicePluginDetector creates a new device plugin detector +func NewDevicePluginDetector( + ctx context.Context, + checkpointPath string, + apiServer APIServerInterface, + kubeletClient KubeletClientInterface, +) (*DevicePluginDetector, error) { + if checkpointPath == "" { + checkpointPath = defaultKubeletCheckpointPath + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create filesystem watcher: %w", err) + } + + detector := &DevicePluginDetector{ + ctx: ctx, + checkpointPath: checkpointPath, + apiServer: apiServer, + kubeletClient: kubeletClient, + vendorDetectors: 
make(map[string]VendorDetector), + previousDeviceIDs: make(map[string]bool), + watcher: watcher, + stopCh: make(chan struct{}), + } + + // Register vendor-specific detectors + detector.registerVendorDetectors() + + return detector, nil +} + +// registerVendorDetectors registers all vendor-specific detectors +func (d *DevicePluginDetector) registerVendorDetectors() { + // Register NVIDIA detector + nvdpDetector := NewNvidiaDevicePluginDetector() + d.vendorDetectors[nvdpDetector.GetResourceName()] = nvdpDetector + + // Add more vendor detectors here as needed + // amdDetector := NewAMDDevicePluginDetector() + // d.vendorDetectors[amdDetector.GetResourceName()] = amdDetector +} + +// Start starts watching the checkpoint file and processing device allocations +func (d *DevicePluginDetector) Start() error { + klog.Info("Starting device plugin detector", "checkpointPath", d.checkpointPath) + + // Setup filesystem watcher + if err := d.setupFilesystemWatcher(); err != nil { + klog.Warningf("Failed to setup filesystem watcher, falling back to polling only: %v", err) + } + + // Start processing loop + go d.run() + + return nil +} + +// Stop stops the detector +func (d *DevicePluginDetector) Stop() { + close(d.stopCh) + if d.watcher != nil { + _ = d.watcher.Close() + } +} + +// setupFilesystemWatcher sets up filesystem watcher for the checkpoint file +func (d *DevicePluginDetector) setupFilesystemWatcher() error { + // Watch the directory containing the checkpoint file + dir := filepath.Dir(d.checkpointPath) + if err := d.watcher.Add(dir); err != nil { + return fmt.Errorf("failed to watch directory %s: %w", dir, err) + } + + // Also watch the file itself if it exists + if _, err := os.Stat(d.checkpointPath); err == nil { + if err := d.watcher.Add(d.checkpointPath); err != nil { + klog.Warningf("Failed to watch checkpoint file directly: %v", err) + } + } + + klog.Infof("Filesystem watcher enabled for checkpoint file: %s", d.checkpointPath) + return nil +} + +// run is the 
main processing loop +func (d *DevicePluginDetector) run() { + // Create tickers for periodic polling + pollTicker := time.NewTicker(defaultPollInterval) + defer pollTicker.Stop() + + patchAllInterval := d.durationWithJitter(defaultPatchAllInterval, patchAllIntervalJitter) + patchAllTicker := time.NewTicker(patchAllInterval) + defer patchAllTicker.Stop() + + // Process initial state + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process initial device state: %v", err) + } + + for { + select { + case <-d.ctx.Done(): + klog.Info("Device plugin detector shutdown requested") + return + + case <-d.stopCh: + klog.Info("Device plugin detector stopped") + return + + case event, ok := <-d.watcher.Events: + if !ok { + klog.Warning("Filesystem watcher channel closed, restarting watcher") + // Try to restart watcher + if err := d.setupFilesystemWatcher(); err != nil { + klog.Errorf("Failed to restart filesystem watcher: %v", err) + } + continue + } + + // Process checkpoint file changes + if event.Op&(fsnotify.Write|fsnotify.Create) != 0 && + (event.Name == d.checkpointPath || strings.HasSuffix(event.Name, filepath.Base(d.checkpointPath))) { + klog.V(4).Infof("Checkpoint file changed: %s", event.Name) + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process device state after filesystem event: %v", err) + } + } + + case err := <-d.watcher.Errors: + if err != nil { + klog.Errorf("Filesystem watcher error: %v", err) + } + + case <-pollTicker.C: + // Periodic polling fallback + klog.V(4).Info("Periodic polling check") + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process device state during periodic check: %v", err) + } + + case <-patchAllTicker.C: + // Periodic full patch check to handle deleted pods + klog.V(4).Info("Checking all devices for deleted pods") + if err := d.processDeviceState(true); err != nil { + klog.Errorf("Failed to process device state during patch all check: %v", 
err) + } + // Reset ticker with new jitter + patchAllTicker.Reset(d.durationWithJitter(defaultPatchAllInterval, patchAllIntervalJitter)) + } + } +} + +// processDeviceState reads and processes the device checkpoint state +func (d *DevicePluginDetector) processDeviceState(patchAllDevices bool) error { + // Read checkpoint file + checkpoint, err := d.readCheckpointFile() + if err != nil { + return fmt.Errorf("failed to read checkpoint file: %w", err) + } + + // Extract registered device IDs (for comparison) + _, registeredDeviceIDs := d.extractDeviceIDs(checkpoint) + + // Get current pods to check for deleted pods + currentPods := d.kubeletClient.GetAllPods() + currentPodUIDs := make(map[string]bool, len(currentPods)) + for uid := range currentPods { + currentPodUIDs[uid] = true + } + + // Build device ID to entry mapping for vendor-specific processing + deviceToEntry := make(map[string]PodDeviceEntry) + + // Filter allocated devices by checking if pods still exist + // This handles the case where pods are deleted but checkpoint isn't updated + validAllocatedDeviceIDs := make(map[string]bool) + + if checkpoint.Data.PodDeviceEntries != nil { + for _, entry := range checkpoint.Data.PodDeviceEntries { + // Check if we have a detector for this resource + if _, hasDetector := d.vendorDetectors[entry.ResourceName]; !hasDetector { + continue + } + + // Check if pod still exists + if !currentPodUIDs[entry.PodUID] { + // Pod was deleted, but checkpoint may still have it + // We'll handle this in the removed devices logic + continue + } + + // Extract device IDs from this entry + for _, deviceList := range entry.DeviceIDs { + for _, deviceID := range deviceList { + deviceIDLower := strings.ToLower(deviceID) + validAllocatedDeviceIDs[deviceIDLower] = true + deviceToEntry[deviceIDLower] = entry + } + } + } + } + + // Determine added and removed devices + d.mu.Lock() + previousDeviceIDs := make(map[string]bool, len(d.previousDeviceIDs)) + for k, v := range d.previousDeviceIDs { + 
previousDeviceIDs[k] = v + } + d.mu.Unlock() + + var addedDevices, removedDevices map[string]bool + + if patchAllDevices { + // Patch all devices: treat all allocated as added, and all registered but not allocated as removed + addedDevices = validAllocatedDeviceIDs + removedDevices = make(map[string]bool) + for deviceID := range registeredDeviceIDs { + if !validAllocatedDeviceIDs[deviceID] { + removedDevices[deviceID] = true + } + } + } else { + // Only process changes + addedDevices = make(map[string]bool) + removedDevices = make(map[string]bool) + + for deviceID := range validAllocatedDeviceIDs { + if !previousDeviceIDs[deviceID] { + addedDevices[deviceID] = true + } + } + + for deviceID := range previousDeviceIDs { + if !validAllocatedDeviceIDs[deviceID] { + removedDevices[deviceID] = true + } + } + } + + // Process added devices using vendor-specific detectors + hasError := false + for deviceID := range addedDevices { + entry, exists := deviceToEntry[deviceID] + if !exists { + // Try to find entry from checkpoint + entry = d.findEntryForDevice(checkpoint, deviceID) + } + + detector, hasDetector := d.vendorDetectors[entry.ResourceName] + if !hasDetector { + klog.Warningf("No detector found for resource %s, device %s", entry.ResourceName, deviceID) + continue + } + + usedBySystem := detector.GetUsedBySystem() + klog.Infof("Device added: %s, resource: %s, patching with usedBy: %s", deviceID, entry.ResourceName, usedBySystem) + if err := d.patchGPUResource(deviceID, usedBySystem); err != nil { + klog.Errorf("Failed to patch GPU resource for added device %s: %v", deviceID, err) + hasError = true + } + } + + // Process removed devices + for deviceID := range removedDevices { + // Find which resource this device belongs to + entry := d.findEntryForDevice(checkpoint, deviceID) + if entry.ResourceName == "" { + // Try to find from previous state - use NVIDIA as default + entry.ResourceName = "nvidia.com/gpu" + } + + usedBySystem := string(tfv1.UsedByTensorFusion) + 
klog.Infof("Device removed: %s, patching with usedBy: %s", deviceID, usedBySystem) + if err := d.patchGPUResource(deviceID, usedBySystem); err != nil { + klog.Errorf("Failed to patch GPU resource for removed device %s: %v", deviceID, err) + hasError = true + } + } + + // Update previous state only if no errors occurred + if !hasError { + d.mu.Lock() + d.previousDeviceIDs = validAllocatedDeviceIDs + d.mu.Unlock() + } + + return nil +} + +// patchGPUResource patches a GPU resource with the specified usedBy value +func (d *DevicePluginDetector) patchGPUResource(deviceID, usedBySystem string) error { + const maxRetries = 3 + + for i := 0; i < maxRetries; i++ { + // Get current GPU resource + gpu, err := d.apiServer.GetGPU(deviceID) + if err != nil { + if i < maxRetries-1 { + backoff := time.Duration(200*(1<= 2 { + // The last field is typically the container PID + // If there are multiple PIDs, the last one is in the innermost namespace + pidStr := fields[len(fields)-1] + pid, err := strconv.ParseUint(pidStr, 10, 32) + if err != nil { + return 0, fmt.Errorf("failed to parse container PID: %w", err) + } + return uint32(pid), nil + } + } + } + + if err := scanner.Err(); err != nil { + return 0, fmt.Errorf("failed to read status file: %w", err) + } + + return 0, fmt.Errorf("NSpid not found in status file") +} diff --git a/internal/hypervisor/backend/kubernetes/pod_cache.go b/internal/hypervisor/backend/kubernetes/pod_cache.go new file mode 100644 index 00000000..9c46fc68 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/pod_cache.go @@ -0,0 +1,438 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kubernetes + +import ( + "context" + "fmt" + "slices" + "strconv" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/utils" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/retry" + "k8s.io/klog/v2" +) + +// PodCacheManager manages pod watching and worker information extraction +type PodCacheManager struct { + ctx context.Context + clientset *kubernetes.Clientset + restConfig *rest.Config + nodeName string + + mu sync.RWMutex + podCache map[string]*corev1.Pod // key: pod UID + allocations map[string]*api.WorkerAllocation // key: pod UID + indexToWorkerInfo map[int]*api.WorkerInfo // key: pod index annotation + indexToPodList map[int][]string // key: pod index annotation, value: list of pod UIDs + stopCh chan struct{} + workerChangedCh chan struct{} +} + +// NewPodCacheManager creates a new pod cache manager +func NewPodCacheManager(ctx context.Context, restConfig *rest.Config, nodeName string) (*PodCacheManager, error) { + clientset, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes clientset: %w", err) + } + + return &PodCacheManager{ + ctx: 
ctx, + clientset: clientset, + restConfig: restConfig, + nodeName: nodeName, + podCache: make(map[string]*corev1.Pod), + allocations: make(map[string]*api.WorkerAllocation), + indexToWorkerInfo: make(map[int]*api.WorkerInfo), + indexToPodList: make(map[int][]string), + stopCh: make(chan struct{}), + workerChangedCh: make(chan struct{}, 1), + }, nil +} + +// Start starts watching pods on this node +func (kc *PodCacheManager) Start() error { + // Create a field selector to watch only pods on this node + fieldSelector := fields.OneTermEqualSelector("spec.nodeName", kc.nodeName).String() + + // Create a label selector for pods with tensor-fusion.ai/enabled=true + labelSelector := labels.Set{ + constants.TensorFusionEnabledLabelKey: constants.TrueStringValue, + }.AsSelector().String() + + // Create list watcher + lw := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + options.FieldSelector = fieldSelector + options.LabelSelector = labelSelector + return kc.clientset.CoreV1().Pods(metav1.NamespaceAll).List(kc.ctx, options) + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + options.FieldSelector = fieldSelector + options.LabelSelector = labelSelector + return kc.clientset.CoreV1().Pods(metav1.NamespaceAll).Watch(kc.ctx, options) + }, + } + + // Create informer + _, controller := cache.NewInformerWithOptions(cache.InformerOptions{ + ListerWatcher: lw, + ObjectType: &corev1.Pod{}, + ResyncPeriod: 0, + Handler: cache.ResourceEventHandlerFuncs{ + AddFunc: kc.onPodAdd, + UpdateFunc: kc.onPodUpdate, + DeleteFunc: kc.onPodDelete, + }, + }) + + // Start the informer + go controller.Run(kc.stopCh) + + klog.Infof("Started watching pods on node %s with label %s=%s", kc.nodeName, constants.TensorFusionEnabledLabelKey, constants.TrueStringValue) + return nil +} + +// Stop stops the pod cache manager +func (kc *PodCacheManager) Stop() { + close(kc.stopCh) +} + +// onPodAdd handles pod addition events +func (kc 
*PodCacheManager) onPodAdd(obj interface{}) { + pod := obj.(*corev1.Pod) + kc.mu.Lock() + kc.podCache[string(pod.UID)] = pod + if podIndexAnno, exists := pod.Annotations[constants.PodIndexAnnotation]; exists { + if podIndex, err := strconv.Atoi(podIndexAnno); err == nil { + // Parse and store WorkerInfo + workerInfo := kc.extractWorkerInfo(pod, podIndexAnno) + kc.indexToWorkerInfo[podIndex] = workerInfo + // Add pod UID to indexToPodList + kc.indexToPodList[podIndex] = append(kc.indexToPodList[podIndex], string(pod.UID)) + } + } else { + klog.Errorf("Pod %s/%s has no index annotation", pod.Namespace, pod.Name) + } + kc.mu.Unlock() + + klog.V(4).Infof("Pod added: %s/%s (UID: %s)", pod.Namespace, pod.Name, pod.UID) + kc.notifyWorkerChanged() +} + +// onPodUpdate handles pod update events +func (kc *PodCacheManager) onPodUpdate(oldObj, newObj interface{}) { + oldPod := oldObj.(*corev1.Pod) + newPod := newObj.(*corev1.Pod) + + kc.mu.Lock() + kc.podCache[string(newPod.UID)] = newPod + + // Handle old index if it changed + oldPodIndexAnno, oldExists := oldPod.Annotations[constants.PodIndexAnnotation] + newPodIndexAnno, newExists := newPod.Annotations[constants.PodIndexAnnotation] + + if oldExists { + if oldPodIndex, err := strconv.Atoi(oldPodIndexAnno); err == nil { + // Remove pod UID from old index + kc.removePodFromIndex(oldPodIndex, string(newPod.UID)) + } + } + + // Update WorkerInfo cache if pod has index annotation + if newExists { + if podIndex, err := strconv.Atoi(newPodIndexAnno); err == nil { + // Parse and store WorkerInfo + workerInfo := kc.extractWorkerInfo(newPod, newPodIndexAnno) + kc.indexToWorkerInfo[podIndex] = workerInfo + // Add pod UID to indexToPodList if not already present + podUID := string(newPod.UID) + found := slices.Contains(kc.indexToPodList[podIndex], podUID) + if !found { + kc.indexToPodList[podIndex] = append(kc.indexToPodList[podIndex], podUID) + } + } + } + kc.mu.Unlock() + + klog.V(4).Infof("Pod updated: %s/%s (UID: %s)", 
newPod.Namespace, newPod.Name, newPod.UID) + + // Check if annotations changed (which might affect allocation) + if !podAnnotationsEqual(oldPod.Annotations, newPod.Annotations) { + kc.notifyWorkerChanged() + } +} + +// onPodDelete handles pod deletion events +func (kc *PodCacheManager) onPodDelete(obj interface{}) { + pod, ok := obj.(*corev1.Pod) + if !ok { + // Handle deleted final state unknown + tombstone, ok := obj.(cache.DeletedFinalStateUnknown) + if !ok { + klog.Errorf("Unexpected object type: %T", obj) + return + } + pod, ok = tombstone.Obj.(*corev1.Pod) + if !ok { + klog.Errorf("Tombstone contained object that is not a pod: %T", tombstone.Obj) + return + } + } + + kc.mu.Lock() + podUID := string(pod.UID) + delete(kc.podCache, podUID) + delete(kc.allocations, podUID) + // Clean up WorkerInfo cache and indexToPodList if pod had index annotation + if podIndexAnno, exists := pod.Annotations[constants.PodIndexAnnotation]; exists { + if podIndex, err := strconv.Atoi(podIndexAnno); err == nil { + delete(kc.indexToWorkerInfo, podIndex) + kc.removePodFromIndex(podIndex, podUID) + } + } + kc.mu.Unlock() + + klog.V(4).Infof("Pod deleted: %s/%s (UID: %s)", pod.Namespace, pod.Name, pod.UID) + kc.notifyWorkerChanged() +} + +// removePodFromIndex removes a pod UID from the indexToPodList for a given index +func (kc *PodCacheManager) removePodFromIndex(podIndex int, podUID string) { + podList := kc.indexToPodList[podIndex] + newList := make([]string, 0, len(podList)) + for _, uid := range podList { + if uid != podUID { + newList = append(newList, uid) + } + } + if len(newList) == 0 { + delete(kc.indexToPodList, podIndex) + } else { + kc.indexToPodList[podIndex] = newList + } +} + +// notifyWorkerChanged notifies that worker information has changed +func (kc *PodCacheManager) notifyWorkerChanged() { + select { + case kc.workerChangedCh <- struct{}{}: + default: + } +} + +// GetWorkerInfoForAllocationByIndex finds a pod by its index annotation and extracts worker info +func 
(kc *PodCacheManager) GetWorkerInfoForAllocationByIndex(ctx context.Context, podIndex int) (*api.WorkerInfo, error) { + var workerInfo *api.WorkerInfo + var lastErr error + + // Retry for at most 5 seconds using k8s retry utility with 10ms backoff + startTime := time.Now() + err := retry.OnError(wait.Backoff{ + Duration: 10 * time.Millisecond, + Factor: 1.4, + Jitter: 0.1, + Cap: 5 * time.Second, + }, func(err error) bool { + // Check if we've exceeded 5 seconds + if time.Since(startTime) >= 5*time.Second { + return false + } + // Retry if worker info not found + return true + }, func() error { + kc.mu.RLock() + defer kc.mu.RUnlock() + + // Check for duplicate index - fast fail if multiple pods have same index + if podList, exists := kc.indexToPodList[podIndex]; exists { + if len(podList) > 1 { + // Build error message with pod details + var matchingPods []string + for _, podUID := range podList { + if pod := kc.podCache[podUID]; pod != nil { + matchingPods = append(matchingPods, fmt.Sprintf("%s/%s (UID: %s)", pod.Namespace, pod.Name, podUID)) + } + } + lastErr = fmt.Errorf("duplicate index %d found in pods: %v", podIndex, matchingPods) + return lastErr + } + } + + // Find worker info with matching index annotation + if info, exists := kc.indexToWorkerInfo[podIndex]; exists { + workerInfo = info + return nil // Success, stop retrying + } + + lastErr = fmt.Errorf("worker info not found for pod index %d", podIndex) + return lastErr // Return error to trigger retry + }) + + if err != nil { + return nil, fmt.Errorf("worker info not found for pod index %d after retrying for 5 seconds: %w", podIndex, err) + } + + return workerInfo, nil +} + +// GetPodByUID retrieves a pod from the cache by its UID +func (kc *PodCacheManager) GetPodByUID(podUID string) *corev1.Pod { + kc.mu.RLock() + defer kc.mu.RUnlock() + return kc.podCache[podUID] +} + +// RemovePodIndexAnnotation removes the PodIndexAnnotation from a pod after successful allocation +func (kc *PodCacheManager) 
RemovePodIndexAnnotation(ctx context.Context, podUID string, namespace string, podName string) error { + kc.mu.RLock() + pod, exists := kc.podCache[podUID] + kc.mu.RUnlock() + + // TODO: too complex, just a raw patch should work! and delete pod_cache before calling apiserver API + + if !exists { + return fmt.Errorf("pod %s/%s not found in cache", namespace, podName) + } + + // Check if annotation exists + if pod.Annotations == nil { + return nil // Nothing to remove + } + + if _, exists := pod.Annotations[constants.PodIndexAnnotation]; !exists { + return nil // Annotation already removed + } + + // Use API client to patch pod and remove annotation + // Get fresh pod from API server + currentPod, err := kc.clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get pod %s/%s: %w", namespace, podName, err) + } + + // Create patch to remove annotation + if currentPod.Annotations == nil { + return nil // No annotations to remove + } + + if _, exists := currentPod.Annotations[constants.PodIndexAnnotation]; !exists { + return nil // Annotation already removed + } + + // Remove annotation + delete(currentPod.Annotations, constants.PodIndexAnnotation) + + // Update pod + _, err = kc.clientset.CoreV1().Pods(namespace).Update(ctx, currentPod, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("failed to update pod %s/%s: %w", namespace, podName, err) + } + + klog.Infof("Successfully removed PodIndexAnnotation from pod %s/%s", namespace, podName) + return nil +} + +// extractWorkerInfo extracts worker information from pod annotations using the common utility function +func (kc *PodCacheManager) extractWorkerInfo(pod *corev1.Pod, podIndex string) *api.WorkerInfo { + // Use common utility function to extract pod worker info + allocRequest, msg, err := utils.ComposeAllocationRequest(kc.ctx, pod) + if err != nil { + klog.Error(err, "Failed to compose allocation request for existing worker Pod, 
annotation may not be valid", "pod", pod.Name, "msg", msg) + return nil + } + info := &api.WorkerInfo{ + PodUID: string(pod.UID), + PodName: pod.Name, + Namespace: pod.Namespace, + Annotations: pod.Annotations, + PodIndex: podIndex, + AllocatedDevices: allocRequest.GPUNames, + IsolationMode: allocRequest.Isolation, + MemoryLimitBytes: uint64(allocRequest.Limit.Vram.Value()), + ComputeLimitUnits: uint32(allocRequest.Limit.ComputePercent.Value()), + TemplateID: allocRequest.PartitionTemplateID, + } + + return info +} + +// StoreAllocation stores allocation information +func (kc *PodCacheManager) StoreAllocation(podUID string, allocation *api.WorkerAllocation) error { + kc.mu.Lock() + defer kc.mu.Unlock() + kc.allocations[podUID] = allocation + return nil +} + +// GetWorkerChangedChan returns the channel for worker change notifications +func (kc *PodCacheManager) GetWorkerChangedChan() <-chan struct{} { + return kc.workerChangedCh +} + +// GetAllPods returns all pods currently in the cache +func (kc *PodCacheManager) GetAllPods() map[string]*corev1.Pod { + kc.mu.RLock() + defer kc.mu.RUnlock() + + result := make(map[string]*corev1.Pod, len(kc.podCache)) + for k, v := range kc.podCache { + result[k] = v + } + return result +} + +// podAnnotationsEqual checks if two annotation maps are equal (for relevant keys) +func podAnnotationsEqual(old, new map[string]string) bool { + if old == nil && new == nil { + return true + } + if old == nil || new == nil { + return false + } + + // Check relevant annotation keys + relevantKeys := []string{ + constants.GPUDeviceIDsAnnotation, + constants.IsolationModeAnnotation, + constants.VRAMLimitAnnotation, + constants.ComputeLimitAnnotation, + constants.WorkloadProfileAnnotation, + } + + for _, key := range relevantKeys { + if old[key] != new[key] { + return false + } + } + + return true +} diff --git a/internal/hypervisor/backend/single_node/filestate.go b/internal/hypervisor/backend/single_node/filestate.go new file mode 100644 index 
00000000..d33a7996 --- /dev/null +++ b/internal/hypervisor/backend/single_node/filestate.go @@ -0,0 +1 @@ +package single_node diff --git a/internal/hypervisor/backend/single_node/single_node_backend.go b/internal/hypervisor/backend/single_node/single_node_backend.go new file mode 100644 index 00000000..84fabfce --- /dev/null +++ b/internal/hypervisor/backend/single_node/single_node_backend.go @@ -0,0 +1,246 @@ +package single_node + +import ( + "context" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "k8s.io/klog/v2" +) + +type SingleNodeBackend struct { + ctx context.Context + deviceController framework.DeviceController + mu sync.RWMutex + workers map[string]*WorkerState // worker UID -> state + stopCh chan struct{} + stopOnce sync.Once + workerCh chan []*api.WorkerInfo + workerChCloseOnce sync.Once + workerStopCh chan struct{} + workerStopOnce sync.Once +} + +type WorkerState struct { + UID string + ProcessIDs []string + CreatedAt time.Time + LastUpdated time.Time +} + +func NewSingleNodeBackend(ctx context.Context, deviceController framework.DeviceController) *SingleNodeBackend { + return &SingleNodeBackend{ + ctx: ctx, + deviceController: deviceController, + workers: make(map[string]*WorkerState), + stopCh: make(chan struct{}), + } +} + +func (b *SingleNodeBackend) Start() error { + // Start periodic worker discovery + go b.periodicWorkerDiscovery() + return nil +} + +func (b *SingleNodeBackend) Stop() error { + // Use sync.Once to ensure stopCh is only closed once + b.stopOnce.Do(func() { + close(b.stopCh) + }) + // Close worker watch stop channel (safe to close even if nil) + if b.workerStopCh != nil { + b.workerStopOnce.Do(func() { + close(b.workerStopCh) + }) + } + return nil +} + +// discoverWorkers discovers workers from device allocations and updates the internal state +func (b *SingleNodeBackend) discoverWorkers() { + // Discover workers from device 
allocations + allocations, err := b.deviceController.GetDeviceAllocations("") + if err != nil { + klog.Errorf("Failed to get device allocations: %v", err) + return + } + + b.mu.Lock() + defer b.mu.Unlock() + + // Update worker states from allocations + for _, allocation := range allocations { + workerUID := allocation.WorkerInfo.WorkerUID + if workerUID == "" { + workerUID = allocation.WorkerInfo.PodUID + } + if workerUID == "" { + continue + } + + if _, exists := b.workers[workerUID]; !exists { + b.workers[workerUID] = &WorkerState{ + UID: workerUID, + ProcessIDs: []string{}, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + } else { + b.workers[workerUID].LastUpdated = time.Now() + } + } + + // Remove workers that no longer have allocations + activeWorkers := make(map[string]bool) + for _, allocation := range allocations { + workerUID := allocation.WorkerInfo.WorkerUID + if workerUID == "" { + workerUID = allocation.WorkerInfo.PodUID + } + if workerUID != "" { + activeWorkers[workerUID] = true + } + } + + for workerUID := range b.workers { + if !activeWorkers[workerUID] { + delete(b.workers, workerUID) + } + } +} + +func (b *SingleNodeBackend) periodicWorkerDiscovery() { + // Run initial discovery immediately + b.discoverWorkers() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-b.stopCh: + return + case <-b.ctx.Done(): + return + case <-ticker.C: + b.discoverWorkers() + } + } +} + +func (b *SingleNodeBackend) ListAndWatchWorkers() (<-chan []*api.WorkerInfo, <-chan struct{}, error) { + // Initialize channels if not already created + if b.workerCh == nil { + b.workerCh = make(chan []*api.WorkerInfo, 1) + b.workerStopCh = make(chan struct{}) + } + + // Send initial worker list and watch for changes + go func() { + defer b.workerChCloseOnce.Do(func() { + close(b.workerCh) + }) + + // Trigger immediate discovery before sending initial list + b.discoverWorkers() + + // Send initial list + b.mu.RLock() + workers 
:= make([]*api.WorkerInfo, 0, len(b.workers)) + for workerUID := range b.workers { + workers = append(workers, &api.WorkerInfo{ + WorkerUID: workerUID, + }) + } + b.mu.RUnlock() + + select { + case b.workerCh <- workers: + case <-b.ctx.Done(): + return + case <-b.workerStopCh: + return + } + + // Watch for changes via periodic discovery (already running in background) + // The periodic discovery will update b.workers, but we don't have a direct + // notification mechanism, so we'll poll periodically + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-b.ctx.Done(): + return + case <-b.workerStopCh: + return + case <-ticker.C: + // Trigger discovery before sending update + b.discoverWorkers() + + b.mu.RLock() + workers := make([]*api.WorkerInfo, 0, len(b.workers)) + for workerUID := range b.workers { + workers = append(workers, &api.WorkerInfo{ + WorkerUID: workerUID, + AllocatedDevices: []string{"dummy"}, + }) + } + b.mu.RUnlock() + + select { + case b.workerCh <- workers: + case <-b.ctx.Done(): + return + case <-b.workerStopCh: + return + } + } + } + }() + + return b.workerCh, b.workerStopCh, nil +} + +func (b *SingleNodeBackend) GetWorkerToProcessMap() (map[string][]string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + result := make(map[string][]string) + for workerUID, state := range b.workers { + result[workerUID] = append([]string{}, state.ProcessIDs...) 
+ } + return result, nil +} + +func (b *SingleNodeBackend) StartWorker(workerUID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + if _, exists := b.workers[workerUID]; !exists { + b.workers[workerUID] = &WorkerState{ + UID: workerUID, + ProcessIDs: []string{}, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + } + return nil +} + +func (b *SingleNodeBackend) StopWorker(workerUID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + delete(b.workers, workerUID) + return nil +} + +func (b *SingleNodeBackend) ReconcileDevices(devices []string) error { + // In single node mode, we don't need to reconcile with external systems + // Devices are managed locally + return nil +} diff --git a/internal/hypervisor/device/accelerator.go b/internal/hypervisor/device/accelerator.go new file mode 100644 index 00000000..1b407b2b --- /dev/null +++ b/internal/hypervisor/device/accelerator.go @@ -0,0 +1,331 @@ +package device + +/* +#cgo CFLAGS: -I../../../provider +#cgo LDFLAGS: -ldl +#include "../../../provider/accelerator.h" +#include +#include +#include +#include +#include +#include + +// Forward declarations from wrapper.c +extern int loadAcceleratorLibrary(const char* libPath); +extern void unloadAcceleratorLibrary(void); +extern Result GetDeviceCountWrapper(size_t* deviceCount); +extern Result GetAllDevicesWrapper(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount); +extern Result GetPartitionTemplatesWrapper(int32_t deviceIndex, PartitionTemplate* templates, size_t maxCount, size_t* templateCount); +extern bool AssignPartitionWrapper(PartitionAssignment* assignment); +extern bool RemovePartitionWrapper(const char* templateId, const char* deviceUUID); +extern Result SetMemHardLimitWrapper(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes); +extern Result SetComputeUnitHardLimitWrapper(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit); +extern Result 
GetProcessComputeUtilizationWrapper(ComputeUtilization* utilizations, size_t maxCount, size_t* utilizationCount); +extern Result GetProcessMemoryUtilizationWrapper(MemoryUtilization* utilizations, size_t maxCount, size_t* utilizationCount); +extern const char* getDlError(void); +*/ +import "C" +import ( + "fmt" + "sync" + "unsafe" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +// AcceleratorInterface provides Go bindings for the C accelerator library +type AcceleratorInterface struct { + libPath string + // deviceProcesses maps device UUID to list of process IDs + deviceProcesses map[string][]string + mu sync.RWMutex + loaded bool +} + +// NewAcceleratorInterface creates a new accelerator interface and loads the library +func NewAcceleratorInterface(libPath string) (*AcceleratorInterface, error) { + accel := &AcceleratorInterface{ + libPath: libPath, + deviceProcesses: make(map[string][]string), + loaded: false, + } + + // Load the library + if err := accel.Load(); err != nil { + return nil, fmt.Errorf("failed to load accelerator library from %s: %w", libPath, err) + } + + return accel, nil +} + +// Load loads the accelerator library dynamically +func (a *AcceleratorInterface) Load() error { + if a.libPath == "" { + return fmt.Errorf("library path is empty") + } + + cLibPath := C.CString(a.libPath) + defer C.free(unsafe.Pointer(cLibPath)) + + result := C.loadAcceleratorLibrary(cLibPath) + if result != 0 { + var errMsg string + if dlErr := C.getDlError(); dlErr != nil { + errMsg = C.GoString(dlErr) + } else { + errMsg = "unknown error" + } + + switch result { + case -1: + return fmt.Errorf("failed to load library: %s", errMsg) + case -2: + return fmt.Errorf("missing required symbols in library: %s", errMsg) + } + return fmt.Errorf("failed to load library (code %d): %s", result, errMsg) + } + + a.loaded = true + return nil +} + +// Close unloads the accelerator library +func (a *AcceleratorInterface) Close() error { + if a.loaded { + 
C.unloadAcceleratorLibrary() + a.loaded = false + } + return nil +} + +// GetTotalProcessCount returns the total number of processes across all devices +func (a *AcceleratorInterface) GetTotalProcessCount() int { + a.mu.RLock() + defer a.mu.RUnlock() + + total := 0 + for _, processes := range a.deviceProcesses { + total += len(processes) + } + return total +} + +// GetAllDevices retrieves all available devices from the accelerator library +func (a *AcceleratorInterface) GetAllDevices() ([]*api.DeviceInfo, error) { + // First, get the device count + var cDeviceCount C.size_t + //nolint:staticcheck + result := C.GetDeviceCountWrapper(&cDeviceCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get device count: %d", result) + } + + if cDeviceCount == 0 { + return []*api.DeviceInfo{}, nil + } + + // Allocate stack buffer (max 256 devices to avoid stack overflow) + const maxStackDevices = 256 + var stackDevices [maxStackDevices]C.ExtendedDeviceInfo + maxDevices := int(cDeviceCount) + if maxDevices > maxStackDevices { + maxDevices = maxStackDevices + } + + var cCount C.size_t + //nolint:staticcheck + result = C.GetAllDevicesWrapper(&stackDevices[0], C.size_t(maxDevices), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get all devices: %d", result) + } + + if cCount == 0 { + return []*api.DeviceInfo{}, nil + } + + devices := make([]*api.DeviceInfo, int(cCount)) + + for i := 0; i < int(cCount); i++ { + cInfo := &stackDevices[i] + devices[i] = &api.DeviceInfo{ + UUID: C.GoString(&cInfo.basic.uuid[0]), + Vendor: C.GoString(&cInfo.basic.vendor[0]), + Model: C.GoString(&cInfo.basic.model[0]), + Index: int32(cInfo.basic.index), + NUMANode: int32(cInfo.basic.numaNode), + TotalMemoryBytes: uint64(cInfo.basic.totalMemoryBytes), + MaxTflops: float64(cInfo.basic.maxTflops), + Capabilities: api.DeviceCapabilities{ + SupportsPartitioning: bool(cInfo.capabilities.supportsPartitioning), + SupportsSoftIsolation: 
bool(cInfo.capabilities.supportsSoftIsolation), + SupportsHardIsolation: bool(cInfo.capabilities.supportsHardIsolation), + SupportsSnapshot: bool(cInfo.capabilities.supportsSnapshot), + SupportsMetrics: bool(cInfo.capabilities.supportsMetrics), + MaxPartitions: uint32(cInfo.capabilities.maxPartitions), + MaxWorkersPerDevice: uint32(cInfo.capabilities.maxWorkersPerDevice), + }, + Properties: make(map[string]string, 0), + } + } + + return devices, nil +} + +// AssignPartition assigns a partition to a device +func (a *AcceleratorInterface) AssignPartition(templateID, deviceUUID string) (string, uint64, error) { + cTemplateID := C.CString(templateID) + defer C.free(unsafe.Pointer(cTemplateID)) + + cDeviceUUID := C.CString(deviceUUID) + defer C.free(unsafe.Pointer(cDeviceUUID)) + + var assignment C.PartitionAssignment + C.strncpy(&assignment.templateId[0], cTemplateID, C.size_t(len(templateID))) + C.strncpy(&assignment.deviceUUID[0], cDeviceUUID, C.size_t(len(deviceUUID))) + + //nolint:staticcheck + result := C.AssignPartitionWrapper(&assignment) + if !result { + return "", 0, fmt.Errorf("failed to assign partition") + } + + partitionUUID := C.GoString(&assignment.partitionUUID[0]) + overhead := uint64(assignment.partitionOverheadBytes) + + return partitionUUID, overhead, nil +} + +// RemovePartition removes a partition from a device +func (a *AcceleratorInterface) RemovePartition(templateID, deviceUUID string) error { + cTemplateID := C.CString(templateID) + defer C.free(unsafe.Pointer(cTemplateID)) + + cDeviceUUID := C.CString(deviceUUID) + defer C.free(unsafe.Pointer(cDeviceUUID)) + + //nolint:staticcheck + result := C.RemovePartitionWrapper(cTemplateID, cDeviceUUID) + if !result { + return fmt.Errorf("failed to remove partition") + } + + return nil +} + +// SetMemHardLimit sets hard memory limit for a worker +func (a *AcceleratorInterface) SetMemHardLimit(workerID, deviceUUID string, memoryLimitBytes uint64) error { + cWorkerID := C.CString(workerID) + defer 
C.free(unsafe.Pointer(cWorkerID)) + + cDeviceUUID := C.CString(deviceUUID) + defer C.free(unsafe.Pointer(cDeviceUUID)) + + //nolint:staticcheck + result := C.SetMemHardLimitWrapper(cWorkerID, cDeviceUUID, C.uint64_t(memoryLimitBytes)) + if result != C.RESULT_SUCCESS { + return fmt.Errorf("failed to set memory hard limit: %d", result) + } + + return nil +} + +// SetComputeUnitHardLimit sets hard compute unit limit for a worker +func (a *AcceleratorInterface) SetComputeUnitHardLimit(workerID, deviceUUID string, computeUnitLimit uint32) error { + cWorkerID := C.CString(workerID) + defer C.free(unsafe.Pointer(cWorkerID)) + + cDeviceUUID := C.CString(deviceUUID) + defer C.free(unsafe.Pointer(cDeviceUUID)) + + //nolint:staticcheck + result := C.SetComputeUnitHardLimitWrapper(cWorkerID, cDeviceUUID, C.uint32_t(computeUnitLimit)) + if result != C.RESULT_SUCCESS { + return fmt.Errorf("failed to set compute unit hard limit: %d", result) + } + + return nil +} + +// GetProcessComputeUtilization retrieves compute utilization for all tracked processes +func (a *AcceleratorInterface) GetProcessComputeUtilization() ([]api.ComputeUtilization, error) { + // Get total process count from the map + totalCount := a.GetTotalProcessCount() + if totalCount == 0 { + return []api.ComputeUtilization{}, nil + } + + // Allocate stack buffer (max 1024 to avoid stack overflow) + const maxStackUtilizations = 1024 + var stackUtilizations [maxStackUtilizations]C.ComputeUtilization + maxCount := totalCount + if maxCount > maxStackUtilizations { + maxCount = maxStackUtilizations + } + + var cCount C.size_t + //nolint:staticcheck + result := C.GetProcessComputeUtilizationWrapper(&stackUtilizations[0], C.size_t(maxCount), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get process compute utilization: %d", result) + } + + if cCount == 0 { + return []api.ComputeUtilization{}, nil + } + + utilizations := make([]api.ComputeUtilization, int(cCount)) + for i := 0; i < 
int(cCount); i++ { + cu := &stackUtilizations[i] + utilizations[i] = api.ComputeUtilization{ + ProcessID: C.GoString(&cu.processId[0]), + DeviceUUID: C.GoString(&cu.deviceUUID[0]), + UtilizationPercent: float64(cu.utilizationPercent), + // Note: ActiveSMs, TotalSMs, and TFLOPsUsed will be added to ComputeUtilization if needed + } + } + + return utilizations, nil +} + +// GetProcessMemoryUtilization retrieves memory utilization for all tracked processes +func (a *AcceleratorInterface) GetProcessMemoryUtilization() ([]api.MemoryUtilization, error) { + // Get total process count from the map + totalCount := a.GetTotalProcessCount() + if totalCount == 0 { + return []api.MemoryUtilization{}, nil + } + + // Allocate stack buffer (max 1024 to avoid stack overflow) + const maxStackUtilizations = 1024 + var stackUtilizations [maxStackUtilizations]C.MemoryUtilization + maxCount := totalCount + if maxCount > maxStackUtilizations { + maxCount = maxStackUtilizations + } + + var cCount C.size_t + //nolint:staticcheck + result := C.GetProcessMemoryUtilizationWrapper(&stackUtilizations[0], C.size_t(maxCount), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get process memory utilization: %d", result) + } + + if cCount == 0 { + return []api.MemoryUtilization{}, nil + } + + utilizations := make([]api.MemoryUtilization, int(cCount)) + for i := 0; i < int(cCount); i++ { + mu := &stackUtilizations[i] + utilizations[i] = api.MemoryUtilization{ + ProcessID: C.GoString(&mu.processId[0]), + DeviceUUID: C.GoString(&mu.deviceUUID[0]), + UsedBytes: uint64(mu.usedBytes), + ReservedBytes: uint64(mu.reservedBytes), + // Note: UtilizationPercent will be calculated separately if needed + } + } + + return utilizations, nil +} diff --git a/internal/hypervisor/device/controller.go b/internal/hypervisor/device/controller.go new file mode 100644 index 00000000..2f7025e4 --- /dev/null +++ b/internal/hypervisor/device/controller.go @@ -0,0 +1,335 @@ +package device + 
+import ( + "context" + "fmt" + "sync" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "k8s.io/klog/v2" +) + +// Controller manages GPU device discovery, allocation, and lifecycle +type Controller struct { + ctx context.Context + mu sync.RWMutex + devices map[string]*api.DeviceInfo // key: device UUID + allocations map[string]*api.WorkerInfo // key: worker UID + deviceToAlloc map[string][]string // device UUID -> []worker UID + accelerator *AcceleratorInterface + discoveryInterval time.Duration +} + +var _ framework.DeviceController = &Controller{} + +// NewController creates a new device manager +func NewController(ctx context.Context, acceleratorLibPath string, discoveryInterval time.Duration) (framework.DeviceController, error) { + accel, err := NewAcceleratorInterface(acceleratorLibPath) + if err != nil { + return nil, fmt.Errorf("failed to create accelerator interface: %w", err) + } + return &Controller{ + ctx: ctx, + devices: make(map[string]*api.DeviceInfo), + allocations: make(map[string]*api.WorkerInfo), + deviceToAlloc: make(map[string][]string), + accelerator: accel, + discoveryInterval: discoveryInterval, + }, nil +} + +// DiscoverDevices discovers all available GPU devices +func (m *Controller) StartDiscoverDevices() error { + // Initial device discovery + if err := m.discoverDevices(); err != nil { + return fmt.Errorf("initial device discovery failed: %w", err) + } + + go m.periodicDiscovery() + return nil +} + +// discoverDevices discovers all available GPU devices +func (m *Controller) discoverDevices() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Get all devices at once + devices, err := m.accelerator.GetAllDevices() + if err != nil { + return fmt.Errorf("failed to get all devices: %w", err) + } + + // Update device map + for _, device := range devices { + m.devices[device.UUID] = device + } + + 
return nil +} + +// periodicDiscovery periodically discovers devices +func (m *Controller) periodicDiscovery() { + ticker := time.NewTicker(m.discoveryInterval) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + if err := m.discoverDevices(); err != nil { + // Log error but continue + continue + } + } + } +} + +// GetDevices returns all discovered devices +func (m *Controller) GetDevices() []*api.DeviceInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + devices := make([]*api.DeviceInfo, 0, len(m.devices)) + for _, device := range m.devices { + devices = append(devices, device) + } + return devices +} + +// getDevice returns a device by UUID (internal method) +func (m *Controller) getDevice(uuid string) (*api.DeviceInfo, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + device, exists := m.devices[uuid] + return device, exists +} + +// Deallocate de-allocates devices for a pod +func (m *Controller) Deallocate(workerUID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + allocation, exists := m.allocations[workerUID] + if !exists { + return fmt.Errorf("allocation not found for pod %s", workerUID) + } + + // Handle partitioned mode cleanup + if allocation.IsolationMode == tfv1.IsolationModePartitioned && allocation.TemplateID != "" { + if err := m.accelerator.RemovePartition(allocation.TemplateID, allocation.AllocatedDevices[0]); err != nil { + // Log error but continue + klog.Errorf("failed to remove partition: %v", err) + } + } + + // Remove from allocations + delete(m.allocations, workerUID) + + // Remove from device mapping + for _, deviceUUID := range allocation.AllocatedDevices { + if workerUIDs, exists := m.deviceToAlloc[deviceUUID]; exists { + for i, uid := range workerUIDs { + if uid == workerUID { + m.deviceToAlloc[deviceUUID] = append(workerUIDs[:i], workerUIDs[i+1:]...) 
+ break + } + } + } + } + + return nil +} + +// GetAllocation returns allocation for a pod +func (m *Controller) GetAllocation(workerUID string) (*api.WorkerInfo, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + allocation, exists := m.allocations[workerUID] + return allocation, exists +} + +// Start implements framework.DeviceController +func (m *Controller) Start() error { + // Start device discovery + return m.StartDiscoverDevices() +} + +// DiscoverDevices implements framework.DeviceController +func (m *Controller) DiscoverDevices() error { + return m.discoverDevices() +} + +// ListDevices implements framework.DeviceController +func (m *Controller) ListDevices() ([]*api.DeviceInfo, error) { + return m.GetDevices(), nil +} + +// DevicesUpdates implements framework.DeviceController +func (m *Controller) DevicesUpdates() (<-chan []*api.DeviceInfo, error) { + ch := make(chan []*api.DeviceInfo, 1) + // Send initial device list + go func() { + devices := m.GetDevices() + select { + case ch <- devices: + default: + } + // TODO: Implement proper device updates channel with periodic updates + // Channel will be closed when controller is stopped + }() + return ch, nil +} + +// GetDevice implements framework.DeviceController +func (m *Controller) GetDevice(deviceUUID string) (*api.DeviceInfo, error) { + device, exists := m.getDevice(deviceUUID) + if !exists { + return nil, fmt.Errorf("device not found: %s", deviceUUID) + } + return device, nil +} + +// GetDeviceAllocations implements framework.DeviceController +func (m *Controller) GetDeviceAllocations(deviceUUID string) ([]*api.WorkerAllocation, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + var workerUIDs []string + if deviceUUID == "" { + // Return all allocations + workerUIDs = make([]string, 0, len(m.allocations)) + for workerUID := range m.allocations { + workerUIDs = append(workerUIDs, workerUID) + } + } else { + // Return allocations for specific device + workerUIDs = m.deviceToAlloc[deviceUUID] + } + + 
allocations := make([]*api.WorkerAllocation, 0, len(workerUIDs)) + for _, workerUID := range workerUIDs { + if workerInfo, exists := m.allocations[workerUID]; exists { + // Create WorkerAllocation with WorkerInfo and DeviceInfos + deviceInfos := make([]*api.DeviceInfo, 0, len(workerInfo.AllocatedDevices)) + for _, devUUID := range workerInfo.AllocatedDevices { + if device, devExists := m.devices[devUUID]; devExists { + deviceInfos = append(deviceInfos, device) + } + } + + allocation := &api.WorkerAllocation{ + WorkerInfo: workerInfo, + DeviceInfos: deviceInfos, + } + allocations = append(allocations, allocation) + } + } + return allocations, nil +} + +// GetDeviceAllocationUpdates implements framework.DeviceController +func (m *Controller) GetDeviceAllocationUpdates(deviceUUID string, allocationID string) (<-chan []*api.WorkerAllocation, error) { + ch := make(chan []*api.WorkerAllocation, 1) + // Send initial allocation list + go func() { + allocations, err := m.GetDeviceAllocations(deviceUUID) + if err == nil { + select { + case ch <- allocations: + default: + } + } + // TODO: Implement proper allocation updates channel with periodic updates + // Channel will be closed when controller is stopped + }() + return ch, nil +} + +// GetGPUMetrics implements framework.DeviceController +func (m *Controller) GetGPUMetrics() (map[string]*api.GPUUsageMetrics, error) { + m.mu.RLock() + devices := make([]*api.DeviceInfo, 0, len(m.devices)) + for _, device := range m.devices { + devices = append(devices, device) + } + m.mu.RUnlock() + + // Get device metrics from accelerator interface + // Note: This requires GetDeviceMetrics from accelerator.h which needs to be implemented + // For now, we'll use process-level metrics to aggregate + result := make(map[string]*api.GPUUsageMetrics) + + // Get memory utilization from processes + memUtils, err := m.accelerator.GetProcessMemoryUtilization() + if err != nil { + // If we can't get metrics, return empty metrics for each device + for 
_, device := range devices { + result[device.UUID] = &api.GPUUsageMetrics{ + DeviceUUID: device.UUID, + } + } + return result, nil + } + + // Aggregate memory usage per device + deviceMemoryUsed := make(map[string]uint64) + for _, memUtil := range memUtils { + deviceMemoryUsed[memUtil.DeviceUUID] += memUtil.UsedBytes + } + + // Get compute utilization + computeUtils, err := m.accelerator.GetProcessComputeUtilization() + if err != nil { + // Continue with memory metrics only + computeUtils = []api.ComputeUtilization{} + } + + // Aggregate compute usage per device + deviceComputePercent := make(map[string]float64) + deviceComputeTflops := make(map[string]float64) + for _, computeUtil := range computeUtils { + deviceComputePercent[computeUtil.DeviceUUID] += computeUtil.UtilizationPercent + // Note: TFLOPs calculation will be implemented separately based on device capabilities + } + + // Build metrics for each device + for _, device := range devices { + memoryUsed := deviceMemoryUsed[device.UUID] + memoryPercent := 0.0 + if device.TotalMemoryBytes > 0 { + memoryPercent = float64(memoryUsed) / float64(device.TotalMemoryBytes) * 100.0 + } + + result[device.UUID] = &api.GPUUsageMetrics{ + DeviceUUID: device.UUID, + MemoryBytes: memoryUsed, + MemoryPercentage: memoryPercent, + ComputePercentage: deviceComputePercent[device.UUID], + ComputeTflops: deviceComputeTflops[device.UUID], + } + } + + return result, nil +} + +// GetProcessComputeUtilization exposes accelerator interface method +func (m *Controller) GetProcessComputeUtilization() ([]api.ComputeUtilization, error) { + return m.accelerator.GetProcessComputeUtilization() +} + +// GetProcessMemoryUtilization exposes accelerator interface method +func (m *Controller) GetProcessMemoryUtilization() ([]api.MemoryUtilization, error) { + return m.accelerator.GetProcessMemoryUtilization() +} + +// Close closes the device controller and unloads the accelerator library +func (m *Controller) Close() error { + return 
m.accelerator.Close() +} diff --git a/internal/hypervisor/device/provider_log.go b/internal/hypervisor/device/provider_log.go new file mode 100644 index 00000000..aa425e78 --- /dev/null +++ b/internal/hypervisor/device/provider_log.go @@ -0,0 +1,56 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package device + +/* +#cgo CFLAGS: -I../../../provider +#include +*/ +import "C" +import ( + "k8s.io/klog/v2" +) + +// GoLog is exported to C code via //export directive +// This function is called by C code (wrapper.c) to log messages using klog +// +//export GoLog +func GoLog(level *C.char, message *C.char) { + if level == nil || message == nil { + return + } + + levelStr := C.GoString(level) + messageStr := C.GoString(message) + + // Map C log levels to klog levels + switch levelStr { + case "DEBUG", "debug": + klog.V(4).Info(messageStr) + case "INFO", "info": + klog.Info(messageStr) + case "WARN", "warn", "WARNING", "warning": + klog.Warning(messageStr) + case "ERROR", "error": + klog.Error(messageStr) + case "FATAL", "fatal": + klog.Fatal(messageStr) + default: + // Default to Info level for unknown levels + klog.Info(messageStr) + } +} diff --git a/internal/hypervisor/device/wrapper.c b/internal/hypervisor/device/wrapper.c new file mode 100644 index 00000000..dbf9822f --- /dev/null +++ b/internal/hypervisor/device/wrapper.c @@ -0,0 +1,205 @@ +/* + * Copyright 2024. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../../../provider/accelerator.h" +#include +#include +#include +#include +#include +#include +#include + +// Forward declaration of Go Log function +extern void GoLog(const char* level, const char* message); + +// Function pointer types for dynamic loading +typedef Result (*GetDeviceCountFunc)(size_t*); +typedef Result (*GetAllDevicesFunc)(ExtendedDeviceInfo*, size_t, size_t*); +typedef Result (*GetPartitionTemplatesFunc)(int32_t, PartitionTemplate*, size_t, size_t*); +typedef bool (*AssignPartitionFunc)(PartitionAssignment*); +typedef bool (*RemovePartitionFunc)(const char*, const char*); +typedef Result (*SetMemHardLimitFunc)(const char*, const char*, uint64_t); +typedef Result (*SetComputeUnitHardLimitFunc)(const char*, const char*, uint32_t); +typedef Result (*GetProcessComputeUtilizationFunc)(ComputeUtilization*, size_t, size_t*); +typedef Result (*GetProcessMemoryUtilizationFunc)(MemoryUtilization*, size_t, size_t*); +typedef Result (*LogFunc)(const char*, const char*); + +// Global handle for the loaded library +static void* libHandle = NULL; + +// Function pointers +static GetDeviceCountFunc getDeviceCountFunc = NULL; +static GetAllDevicesFunc getAllDevicesFunc = NULL; +static GetPartitionTemplatesFunc getPartitionTemplatesFunc = NULL; +static AssignPartitionFunc assignPartitionFunc = NULL; +static RemovePartitionFunc removePartitionFunc = NULL; +static SetMemHardLimitFunc 
setMemHardLimitFunc = NULL; +static SetComputeUnitHardLimitFunc setComputeUnitHardLimitFunc = NULL; +static GetProcessComputeUtilizationFunc getProcessComputeUtilizationFunc = NULL; +static GetProcessMemoryUtilizationFunc getProcessMemoryUtilizationFunc = NULL; +static LogFunc logFunc = NULL; + +// Load library dynamically +int loadAcceleratorLibrary(const char* libPath) { + if (libHandle != NULL) { + dlclose(libHandle); + } + + libHandle = dlopen(libPath, RTLD_LAZY | RTLD_LOCAL); + if (libHandle == NULL) { + return -1; // Failed to load + } + + // Load function symbols + getDeviceCountFunc = (GetDeviceCountFunc)dlsym(libHandle, "GetDeviceCount"); + getAllDevicesFunc = (GetAllDevicesFunc)dlsym(libHandle, "GetAllDevices"); + getPartitionTemplatesFunc = (GetPartitionTemplatesFunc)dlsym(libHandle, "GetPartitionTemplates"); + assignPartitionFunc = (AssignPartitionFunc)dlsym(libHandle, "AssignPartition"); + removePartitionFunc = (RemovePartitionFunc)dlsym(libHandle, "RemovePartition"); + setMemHardLimitFunc = (SetMemHardLimitFunc)dlsym(libHandle, "SetMemHardLimit"); + setComputeUnitHardLimitFunc = (SetComputeUnitHardLimitFunc)dlsym(libHandle, "SetComputeUnitHardLimit"); + getProcessComputeUtilizationFunc = (GetProcessComputeUtilizationFunc)dlsym(libHandle, "GetProcessComputeUtilization"); + getProcessMemoryUtilizationFunc = (GetProcessMemoryUtilizationFunc)dlsym(libHandle, "GetProcessMemoryUtilization"); + logFunc = (LogFunc)dlsym(libHandle, "Log"); + + // Check if all required functions are loaded (Log is optional) + if (!getDeviceCountFunc || !getAllDevicesFunc || !getPartitionTemplatesFunc || + !assignPartitionFunc || !removePartitionFunc || !setMemHardLimitFunc || + !setComputeUnitHardLimitFunc || !getProcessComputeUtilizationFunc || + !getProcessMemoryUtilizationFunc) { + dlclose(libHandle); + libHandle = NULL; + return -2; // Missing symbols + } + + // If the library has a Log function, we can't directly replace it, + // but we provide our own Log function that 
the library can use. + // The library's internal Log calls will use its own implementation, + // but if the library is designed to call Log via function pointer or + // if it doesn't have its own Log, it will use our implementation. + + return 0; // Success +} + +// Unload library +void unloadAcceleratorLibrary(void) { + if (libHandle != NULL) { + dlclose(libHandle); + libHandle = NULL; + getDeviceCountFunc = NULL; + getAllDevicesFunc = NULL; + getPartitionTemplatesFunc = NULL; + assignPartitionFunc = NULL; + removePartitionFunc = NULL; + setMemHardLimitFunc = NULL; + setComputeUnitHardLimitFunc = NULL; + getProcessComputeUtilizationFunc = NULL; + getProcessMemoryUtilizationFunc = NULL; + logFunc = NULL; + } +} + +// Wrapper functions that call the dynamically loaded functions +Result GetDeviceCountWrapper(size_t* deviceCount) { + if (getDeviceCountFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getDeviceCountFunc(deviceCount); +} + +Result GetAllDevicesWrapper(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount) { + if (getAllDevicesFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getAllDevicesFunc(devices, maxCount, deviceCount); +} + +Result GetPartitionTemplatesWrapper(int32_t deviceIndex, PartitionTemplate* templates, size_t maxCount, size_t* templateCount) { + if (getPartitionTemplatesFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getPartitionTemplatesFunc(deviceIndex, templates, maxCount, templateCount); +} + +bool AssignPartitionWrapper(PartitionAssignment* assignment) { + if (assignPartitionFunc == NULL) { + return false; + } + return assignPartitionFunc(assignment); +} + +bool RemovePartitionWrapper(const char* templateId, const char* deviceUUID) { + if (removePartitionFunc == NULL) { + return false; + } + return removePartitionFunc(templateId, deviceUUID); +} + +Result SetMemHardLimitWrapper(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes) { + if (setMemHardLimitFunc == NULL) { 
+ return RESULT_ERROR_INTERNAL; + } + return setMemHardLimitFunc(workerId, deviceUUID, memoryLimitBytes); +} + +Result SetComputeUnitHardLimitWrapper(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit) { + if (setComputeUnitHardLimitFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return setComputeUnitHardLimitFunc(workerId, deviceUUID, computeUnitLimit); +} + +Result GetProcessComputeUtilizationWrapper(ComputeUtilization* utilizations, size_t maxCount, size_t* utilizationCount) { + if (getProcessComputeUtilizationFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getProcessComputeUtilizationFunc(utilizations, maxCount, utilizationCount); +} + +Result GetProcessMemoryUtilizationWrapper(MemoryUtilization* utilizations, size_t maxCount, size_t* utilizationCount) { + if (getProcessMemoryUtilizationFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getProcessMemoryUtilizationFunc(utilizations, maxCount, utilizationCount); +} + +// Get error message from dlopen +const char* getDlError(void) { + return dlerror(); +} + +// Log wrapper that calls Go's Log function +// This function provides a Log implementation that the dynamically loaded library can use +// When the library calls Log(), it will call this function which forwards to Go's klog +Result LogWrapper(const char* level, const char* message) { + if (level == NULL || message == NULL) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Call Go's Log function + GoLog(level, message); + + return RESULT_SUCCESS; +} + +// Provide a Log function that can be called by the dynamically loaded library +// This is the Log function that accelerator.h defines - we provide an implementation +// that forwards to Go's klog via GoLog +Result Log(const char* level, const char* message) { + return LogWrapper(level, message); +} + diff --git a/internal/hypervisor/framework/framework.go b/internal/hypervisor/framework/framework.go new file mode 100644 index 00000000..790bce65 --- /dev/null +++ 
b/internal/hypervisor/framework/framework.go @@ -0,0 +1,97 @@ +package framework + +import ( + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +type DeviceController interface { + Start() error + + DiscoverDevices() error + + // ListDevices returns all discovered devices + ListDevices() ([]*api.DeviceInfo, error) + + // GetDevice returns device information by UUID + GetDevice(deviceUUID string) (*api.DeviceInfo, error) + + // GetDeviceAllocations returns device allocations + // If deviceUUID is empty, returns all allocations + GetDeviceAllocations(deviceUUID string) ([]*api.WorkerAllocation, error) + + // DevicesUpdates returns a channel that receives device list updates + // The channel should be closed when Stop() is called + DevicesUpdates() (<-chan []*api.DeviceInfo, error) + + // GetDeviceAllocationUpdates returns a channel that receives allocation updates + // The channel should be closed when Stop() is called + GetDeviceAllocationUpdates(deviceUUID string, allocationID string) (<-chan []*api.WorkerAllocation, error) + + // GetGPUMetrics returns current GPU metrics for all devices + GetGPUMetrics() (map[string]*api.GPUUsageMetrics, error) +} + +type DeviceInterface interface { + SplitDevice(deviceUUID string) error + + GetDeviceMetrics() (*api.MemoryUtilization, error) +} + +type WorkerController interface { + Start() error + + Stop() error + + // AllocateWorker allocates devices for a worker + AllocateWorker(request *api.WorkerInfo) (*api.WorkerAllocation, error) + + // GetWorkerAllocation returns allocation information for a worker + GetWorkerAllocation(workerUID string) (*api.WorkerAllocation, error) + + // GetWorkerMetricsUpdates returns a channel that receives worker metrics updates + // The channel should be closed when Stop() is called + GetWorkerMetricsUpdates() (<-chan *api.WorkerAllocation, error) + + // GetWorkerMetrics returns current worker metrics for all workers + // Returns map keyed by device UUID, then by worker UID, then by 
process ID + GetWorkerMetrics() (map[string]map[string]map[string]*api.WorkerMetrics, error) + + // ListWorkers returns list of all worker infos + ListWorkers() ([]*api.WorkerInfo, error) +} + +type QuotaController interface { + // SetQuota sets quota for a worker + SetQuota(workerUID string) error + + StartSoftQuotaLimiter() error + + StopSoftQuotaLimiter() error + + // GetWorkerQuotaStatus gets quota status for a worker + GetWorkerQuotaStatus(workerUID string) error +} + +// The backend interface for the hypervisor to interact with the underlying infrastructure +type Backend interface { + Start() error + + Stop() error + + // ListAndWatchWorkers gets GPU workers from the workload orchestration platform + // Returns a channel that receives worker info lists and a stop channel + // The channel should be closed when Stop() is called + ListAndWatchWorkers() (<-chan []*api.WorkerInfo, <-chan struct{}, error) + + // GetWorkerToProcessMap links workers to actual running process list on OS + GetWorkerToProcessMap() (map[string][]string, error) + + // StartWorker spawns worker process + StartWorker(workerUID string) error + + // StopWorker stops worker process + StopWorker(workerUID string) error + + // ReconcileDevices reports devices to backend orchestration and O&M platform + ReconcileDevices(devices []string) error +} diff --git a/internal/hypervisor/hypervisor_suite_test.go b/internal/hypervisor/hypervisor_suite_test.go new file mode 100644 index 00000000..0006d2c0 --- /dev/null +++ b/internal/hypervisor/hypervisor_suite_test.go @@ -0,0 +1,505 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package hypervisor + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/single_node" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/metrics" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/server" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker" +) + +func TestHypervisor(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Hypervisor Suite") +} + +var _ = Describe("Hypervisor Integration Tests", func() { + var ( + ctx context.Context + cancel context.CancelFunc + deviceController framework.DeviceController + backend framework.Backend + workerController framework.WorkerController + metricsRecorder *metrics.HypervisorMetricsRecorder + httpServer *server.Server + stubLibPath string + tempMetricsFile string + ) + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + + // Find stub library path + // Try relative path first (from provider/build) + stubLibPath = filepath.Join("..", "..", "provider", "build", "libaccelerator_stub.so") + if _, err := os.Stat(stubLibPath); os.IsNotExist(err) { + // Try absolute path from workspace root + workspaceRoot := os.Getenv("WORKSPACE_ROOT") + if workspaceRoot == "" { + // Try to 
find it relative to current directory + cwd, _ := os.Getwd() + stubLibPath = filepath.Join(cwd, "..", "..", "provider", "build", "libaccelerator_stub.so") + } else { + stubLibPath = filepath.Join(workspaceRoot, "provider", "build", "libaccelerator_stub.so") + } + } + + // Create temp file for metrics + tempFile, err := os.CreateTemp("", "hypervisor-metrics-*.log") + Expect(err).NotTo(HaveOccurred()) + tempMetricsFile = tempFile.Name() + _ = tempFile.Close() + }) + + AfterEach(func() { + if cancel != nil { + cancel() + } + if httpServer != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer shutdownCancel() + _ = httpServer.Stop(shutdownCtx) + } + if workerController != nil { + _ = workerController.Stop() + } + if backend != nil { + _ = backend.Stop() + } + if deviceController != nil { + if closer, ok := deviceController.(interface{ Close() error }); ok { + _ = closer.Close() + } + } + _ = os.Remove(tempMetricsFile) + }) + + Context("With stub device library", func() { + BeforeEach(func() { + // Check if stub library exists, skip if not + if _, err := os.Stat(stubLibPath); os.IsNotExist(err) { + Skip("Stub library not found. 
Run 'make stub' in provider directory first.") + } + + var err error + deviceController, err = device.NewController(ctx, stubLibPath, 1*time.Hour) + Expect(err).NotTo(HaveOccurred()) + Expect(deviceController).NotTo(BeNil()) + + backend = single_node.NewSingleNodeBackend(ctx, deviceController) + Expect(backend).NotTo(BeNil()) + + workerController = worker.NewWorkerController(deviceController, tfv1.IsolationModeShared, backend) + Expect(workerController).NotTo(BeNil()) + + metricsRecorder = metrics.NewHypervisorMetricsRecorder(ctx, tempMetricsFile, deviceController, workerController) + Expect(metricsRecorder).NotTo(BeNil()) + + httpServer = server.NewServer(ctx, deviceController, workerController, metricsRecorder, backend, 0) + Expect(httpServer).NotTo(BeNil()) + }) + + Describe("C Stub Library Integration", func() { + It("should load stub accelerator library", func() { + // Verify library can be loaded + accel, err := device.NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + Expect(accel).NotTo(BeNil()) + + // Test device discovery through C library + devices, err := accel.GetAllDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + // Verify stub device properties + device := devices[0] + Expect(device.UUID).To(ContainSubstring("stub-device")) + Expect(device.Vendor).To(Equal("STUB")) + Expect(device.TotalMemoryBytes).To(Equal(uint64(16 * 1024 * 1024 * 1024))) // 16GB + + _ = accel.Close() + }) + + It("should get process utilization from stub library", func() { + accel, err := device.NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + defer func() { + _ = accel.Close() + }() + + // Get compute utilization (may be empty for stub) + computeUtils, err := accel.GetProcessComputeUtilization() + Expect(err).NotTo(HaveOccurred()) + Expect(computeUtils).NotTo(BeNil()) + + // Get memory utilization (may be empty for stub) + memUtils, err := accel.GetProcessMemoryUtilization() + 
Expect(err).NotTo(HaveOccurred()) + Expect(memUtils).NotTo(BeNil()) + }) + }) + + Describe("Device Controller", func() { + It("should start and discover devices", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + // Wait a bit for discovery + time.Sleep(100 * time.Millisecond) + + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty(), "Should discover at least one stub device") + + // Verify device properties + device := devices[0] + Expect(device.UUID).NotTo(BeEmpty()) + Expect(device.Vendor).To(Equal("STUB")) + Expect(device.TotalMemoryBytes).To(BeNumerically(">", 0)) + }) + + It("should allocate devices", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + time.Sleep(100 * time.Millisecond) + + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + deviceUUID := devices[0].UUID + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{deviceUUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + + resp, err := workerController.AllocateWorker(req) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + // TODO verify the mounts/envs + + // Verify allocation exists + allocations, err := deviceController.GetDeviceAllocations(deviceUUID) + Expect(err).NotTo(HaveOccurred()) + Expect(allocations).To(HaveLen(1)) + }) + + It("should get GPU metrics", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + time.Sleep(100 * time.Millisecond) + + metrics, err := deviceController.GetGPUMetrics() + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).NotTo(BeNil()) + + // Should have metrics for all discovered devices + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(len(devices))) + }) + }) + + Describe("Single Node Backend", func() { + BeforeEach(func() { + 
err := deviceController.Start()
+				Expect(err).NotTo(HaveOccurred())
+				time.Sleep(100 * time.Millisecond)
+
+				err = backend.Start()
+				Expect(err).NotTo(HaveOccurred())
+			})
+
+			It("should start and stop", func() {
+				Expect(backend).NotTo(BeNil())
+			})
+
+			It("should list workers from allocations", func() {
+				// Create an allocation
+				devices, err := deviceController.ListDevices()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(devices).ToNot(BeEmpty())
+
+				req := &api.WorkerInfo{
+					WorkerUID:        "test-worker-1",
+					AllocatedDevices: []string{devices[0].UUID},
+					IsolationMode:    tfv1.IsolationModeSoft,
+				}
+				_, err = workerController.AllocateWorker(req)
+				Expect(err).NotTo(HaveOccurred())
+
+				// Wait for backend to discover
+				time.Sleep(2 * time.Second)
+
+				workerCh, _, err := backend.ListAndWatchWorkers()
+				Expect(err).NotTo(HaveOccurred())
+				// Note: stopCh is receive-only, backend will close it when stopped
+
+				// Read initial worker list from channel
+				select {
+				case workers := <-workerCh:
+					// workers is []*api.WorkerInfo (see framework.Backend), so match
+					// on the WorkerUID field — a bare string never equals a pointer.
+					Expect(workers).To(ContainElement(HaveField("WorkerUID", "test-worker-1")))
+				case <-time.After(5 * time.Second):
+					Fail("timeout waiting for workers")
+				}
+			})
+
+			It("should track worker to process mapping", func() {
+				// Start a worker
+				err := backend.StartWorker("test-worker-1")
+				Expect(err).NotTo(HaveOccurred())
+
+				processMap, err := backend.GetWorkerToProcessMap()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(processMap).NotTo(BeNil())
+			})
+		})
+
+		Describe("Worker Controller", func() {
+			BeforeEach(func() {
+				err := deviceController.Start()
+				Expect(err).NotTo(HaveOccurred())
+				time.Sleep(100 * time.Millisecond)
+
+				err = workerController.Start()
+				Expect(err).NotTo(HaveOccurred())
+			})
+
+			It("should start and stop", func() {
+				Expect(workerController).NotTo(BeNil())
+			})
+
+			It("should list workers", func() {
+				// Create an allocation
+				devices, err := deviceController.ListDevices()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(devices).ToNot(BeEmpty())
+
+				req := &api.WorkerInfo{
+					WorkerUID:
"test-worker-1",
+					AllocatedDevices: []string{devices[0].UUID},
+					IsolationMode:    tfv1.IsolationModeSoft,
+				}
+				_, err = workerController.AllocateWorker(req)
+				Expect(err).NotTo(HaveOccurred())
+
+				workers, err := workerController.ListWorkers()
+				Expect(err).NotTo(HaveOccurred())
+				// ListWorkers returns []*api.WorkerInfo (see framework.WorkerController),
+				// so match on the WorkerUID field rather than comparing pointers
+				// against a plain string, which can never succeed.
+				Expect(workers).To(ContainElement(HaveField("WorkerUID", "test-worker-1")))
+			})
+
+			It("should get worker allocation", func() {
+				// Create an allocation
+				devices, err := deviceController.ListDevices()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(devices).ToNot(BeEmpty())
+
+				req := &api.WorkerInfo{
+					WorkerUID:        "test-worker-1",
+					AllocatedDevices: []string{devices[0].UUID},
+					IsolationMode:    tfv1.IsolationModeSoft,
+				}
+				_, err = workerController.AllocateWorker(req)
+				Expect(err).NotTo(HaveOccurred())
+
+				allocation, err := workerController.GetWorkerAllocation("test-worker-1")
+				Expect(err).NotTo(HaveOccurred())
+				Expect(allocation).NotTo(BeNil())
+				Expect(allocation.WorkerInfo.WorkerUID).To(Equal("test-worker-1"))
+			})
+
+			It("should get worker metrics", func() {
+				// Create an allocation
+				devices, err := deviceController.ListDevices()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(devices).ToNot(BeEmpty())
+
+				req := &api.WorkerInfo{
+					WorkerUID:        "test-worker-1",
+					AllocatedDevices: []string{devices[0].UUID},
+					IsolationMode:    tfv1.IsolationModeSoft,
+				}
+				_, err = workerController.AllocateWorker(req)
+				Expect(err).NotTo(HaveOccurred())
+
+				metrics, err := workerController.GetWorkerMetrics()
+				Expect(err).NotTo(HaveOccurred())
+				Expect(metrics).NotTo(BeNil())
+			})
+		})
+
+		Describe("Metrics Recorder", func() {
+			BeforeEach(func() {
+				err := deviceController.Start()
+				Expect(err).NotTo(HaveOccurred())
+				time.Sleep(100 * time.Millisecond)
+
+				err = workerController.Start()
+				Expect(err).NotTo(HaveOccurred())
+
+				metricsRecorder.Start()
+			})
+
+			It("should record metrics", func() {
+				// Wait for metrics to be recorded
+				time.Sleep(2 * time.Second)
+
+				// Check if metrics file was created and has content
+				info,
err := os.Stat(tempMetricsFile) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Size()).To(BeNumerically(">=", 0)) + }) + }) + + Describe("HTTP Server", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + + metricsRecorder.Start() + }) + + It("should start HTTP server", func() { + // Start server in background + go func() { + err := httpServer.Start() + Expect(err).To(Or(BeNil(), MatchError("http: Server closed"))) + }() + + // Wait for server to start + time.Sleep(500 * time.Millisecond) + + // Server should be running (we can't easily test HTTP endpoints without knowing the port) + // But we can verify the server object is created + Expect(httpServer).NotTo(BeNil()) + }) + }) + + Describe("Full Integration", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = backend.Start() + Expect(err).NotTo(HaveOccurred()) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + + metricsRecorder.Start() + + // Start HTTP server in background + go func() { + _ = httpServer.Start() + }() + time.Sleep(500 * time.Millisecond) + }) + + It("should handle complete workflow: discover -> allocate -> track -> metrics", func() { + // 1. Discover devices + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + deviceUUID := devices[0].UUID + + // 2. Allocate device + req := &api.WorkerInfo{ + WorkerUID: "integration-worker-1", + AllocatedDevices: []string{deviceUUID}, + IsolationMode: tfv1.IsolationModeSoft, + MemoryLimitBytes: 1024 * 1024 * 1024, // 1GB + } + resp, err := workerController.AllocateWorker(req) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).To(Not(BeNil())) + + // 3. 
Verify allocation
+			allocations, err := deviceController.GetDeviceAllocations(deviceUUID)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(allocations).To(HaveLen(1))
+
+			// 4. Backend should discover worker
+			time.Sleep(2 * time.Second)
+			workerCh, _, err := backend.ListAndWatchWorkers()
+			Expect(err).NotTo(HaveOccurred())
+			// Note: stopCh is receive-only, backend will close it when stopped
+
+			// Read initial worker list from channel
+			select {
+			case workers := <-workerCh:
+				// workers is []*api.WorkerInfo — match on the WorkerUID field.
+				Expect(workers).To(ContainElement(HaveField("WorkerUID", "integration-worker-1")))
+			case <-time.After(5 * time.Second):
+				Fail("timeout waiting for workers")
+			}
+
+			// 5. Worker controller should list worker
+			workerList, err := workerController.ListWorkers()
+			Expect(err).NotTo(HaveOccurred())
+			Expect(workerList).To(ContainElement(HaveField("WorkerUID", "integration-worker-1")))
+
+			// 6. Get worker allocation
+			allocation, err := workerController.GetWorkerAllocation("integration-worker-1")
+			Expect(err).NotTo(HaveOccurred())
+			Expect(allocation).NotTo(BeNil())
+			// The allocation belongs to the worker we created, so its UID must be
+			// the worker UID — not the device UUID it was scheduled onto.
+			Expect(allocation.WorkerInfo.WorkerUID).To(Equal("integration-worker-1"))
+
+			// 7. Get metrics
+			gpuMetrics, err := deviceController.GetGPUMetrics()
+			Expect(err).NotTo(HaveOccurred())
+			Expect(gpuMetrics).NotTo(BeNil())
+			Expect(gpuMetrics[deviceUUID]).NotTo(BeNil())
+
+			workerMetrics, err := workerController.GetWorkerMetrics()
+			Expect(err).NotTo(HaveOccurred())
+			Expect(workerMetrics).NotTo(BeNil())
+
+			// 8. Deallocate (if method exists)
+			if deallocator, ok := deviceController.(interface{ Deallocate(string) error }); ok {
+				err = deallocator.Deallocate("integration-worker-1")
+				Expect(err).NotTo(HaveOccurred())
+			}
+
+			// 9.
Verify deallocation + allocations, err = deviceController.GetDeviceAllocations(deviceUUID) + Expect(err).NotTo(HaveOccurred()) + Expect(allocations).To(BeEmpty()) + }) + }) + }) +}) diff --git a/internal/hypervisor/metrics/metrics.go b/internal/hypervisor/metrics/metrics.go new file mode 100644 index 00000000..d674ee67 --- /dev/null +++ b/internal/hypervisor/metrics/metrics.go @@ -0,0 +1,252 @@ +package metrics + +import ( + "context" + "encoding/json" + "io" + "os" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/metrics" + "gopkg.in/natefinch/lumberjack.v2" +) + +type HypervisorMetricsRecorder struct { + ctx context.Context + outputPath string + nodeName string + gpuPool string + deviceController framework.DeviceController + workerController framework.WorkerController + gpuCapacityMap map[string]float64 // GPU UUID -> MaxTflops + extraLabelsMap map[string]string // podLabelKey -> tagName mapping from env config +} + +const ( + defaultNodeName = "unknown" + defaultGPUPool = "unknown" +) + +func NewHypervisorMetricsRecorder( + ctx context.Context, outputPath string, + deviceController framework.DeviceController, + workerController framework.WorkerController, +) *HypervisorMetricsRecorder { + nodeName := os.Getenv(constants.HypervisorGPUNodeNameEnv) + if nodeName == "" { + nodeName = defaultNodeName + } + gpuPool := os.Getenv(constants.HypervisorPoolNameEnv) + if gpuPool == "" { + gpuPool = defaultGPUPool + } + + // Parse extra labels config once at initialization + extraLabelsMap := make(map[string]string) + extraLabelsConfig := os.Getenv(constants.HypervisorMetricsExtraLabelsEnv) + if extraLabelsConfig != "" { + if err := json.Unmarshal([]byte(extraLabelsConfig), &extraLabelsMap); err != nil { + // Log error but continue without extra labels + extraLabelsMap = 
make(map[string]string) + } + } + + return &HypervisorMetricsRecorder{ + ctx: ctx, + outputPath: outputPath, + nodeName: nodeName, + gpuPool: gpuPool, + deviceController: deviceController, + workerController: workerController, + gpuCapacityMap: make(map[string]float64), + extraLabelsMap: extraLabelsMap, + } +} + +func (h *HypervisorMetricsRecorder) Start() { + writer := &lumberjack.Logger{ + Filename: h.outputPath, + MaxSize: 100, + MaxBackups: 10, + MaxAge: 14, + } + + // Initialize GPU capacity map from devices + h.initGPUCapacityMap() + + // Record device and worker metrics + deviceMetricsTicker := time.NewTicker(10 * time.Second) + go func() { + for { + select { + case <-h.ctx.Done(): + return + case <-deviceMetricsTicker.C: + h.RecordDeviceMetrics(writer) + h.RecordWorkerMetrics(writer) + } + } + }() +} + +func (h *HypervisorMetricsRecorder) initGPUCapacityMap() { + devices, err := h.deviceController.ListDevices() + if err != nil { + return + } + for _, device := range devices { + h.gpuCapacityMap[device.UUID] = device.MaxTflops + } +} + +func (h *HypervisorMetricsRecorder) RecordDeviceMetrics(writer io.Writer) { + gpuMetrics, err := h.deviceController.GetGPUMetrics() + if err != nil { + return + } + + // Output GPU metrics directly + now := time.Now() + enc := metrics.NewEncoder(os.Getenv(constants.HypervisorMetricsFormatEnv)) + + for gpuUUID, metrics := range gpuMetrics { + enc.StartLine("tf_gpu_usage") + enc.AddTag("uuid", gpuUUID) + enc.AddTag("node", h.nodeName) + enc.AddTag("pool", h.gpuPool) + + enc.AddField("rx", metrics.Rx) + enc.AddField("tx", metrics.Tx) + // Add vendor-specific metrics from ExtraMetrics map + if metrics.ExtraMetrics != nil { + for key, value := range metrics.ExtraMetrics { + enc.AddField(key, value) + } + } + enc.AddField("temperature", metrics.Temperature) + enc.AddField("graphics_clock_mhz", metrics.GraphicsClockMHz) + enc.AddField("sm_clock_mhz", metrics.SMClockMHz) + enc.AddField("memory_clock_mhz", metrics.MemoryClockMHz) + 
enc.AddField("video_clock_mhz", metrics.VideoClockMHz) + enc.AddField("memory_bytes", int64(metrics.MemoryBytes)) + enc.AddField("memory_percentage", metrics.MemoryPercentage) + enc.AddField("compute_percentage", metrics.ComputePercentage) + enc.AddField("compute_tflops", metrics.ComputeTflops) + enc.AddField("power_usage", float64(metrics.PowerUsage)) + + enc.EndLine(now) + } + + if err := enc.Err(); err == nil { + _, _ = writer.Write(enc.Bytes()) + } +} + +func (h *HypervisorMetricsRecorder) RecordWorkerMetrics(writer io.Writer) { + workerMetrics, err := h.workerController.GetWorkerMetrics() + if err != nil { + return + } + + workerInfos, err := h.workerController.ListWorkers() + if err != nil { + return + } + + // Get worker allocations for metadata + workerAllocations := make(map[string]*api.WorkerAllocation) + for _, worker := range workerInfos { + allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if err == nil && allocation != nil { + workerAllocations[worker.WorkerUID] = allocation + } + } + + // Output worker metrics directly + now := time.Now() + enc := metrics.NewEncoder(os.Getenv(constants.HypervisorMetricsFormatEnv)) + + for deviceUUID, workerMap := range workerMetrics { + for workerUID, processMap := range workerMap { + allocation, ok := workerAllocations[workerUID] + if !ok { + continue + } + + var memoryBytes uint64 + var computePercentage float64 + var computeTflops float64 + var memoryPercentage float64 + + // Sum up metrics from all processes for this worker + for _, metrics := range processMap { + memoryBytes += metrics.MemoryBytes + computePercentage += metrics.ComputePercentage + computeTflops += metrics.ComputeTflops + + // Calculate memory percentage + vramLimit := float64(0) + if allocation.WorkerInfo != nil { + vramLimit = float64(allocation.WorkerInfo.MemoryLimitBytes) + } + if vramLimit > 0 { + memoryPercentage += float64(metrics.MemoryBytes) / vramLimit * 100.0 + } + } + + enc.StartLine("tf_worker_usage") + 
enc.AddTag("uuid", deviceUUID) + enc.AddTag("node", h.nodeName) + enc.AddTag("pool", h.gpuPool) + if allocation.WorkerInfo != nil { + enc.AddTag("pod_name", allocation.WorkerInfo.PodName) + enc.AddTag("namespace", allocation.WorkerInfo.Namespace) + } + + workloadName := "unknown" + // Try to get workload name from worker ID or pod name + if allocation.WorkerInfo != nil && allocation.WorkerInfo.WorkerUID != "" { + workloadName = allocation.WorkerInfo.WorkerUID + } + enc.AddTag("workload", workloadName) + enc.AddTag("worker", workerUID) + + // Add extra labels if configured + h.addExtraLabels(enc, allocation) + + enc.AddField("memory_bytes", int64(memoryBytes)) + enc.AddField("compute_percentage", computePercentage) + enc.AddField("compute_tflops", computeTflops) + enc.AddField("memory_percentage", memoryPercentage) + + enc.EndLine(now) + } + } + + if err := enc.Err(); err == nil { + _, _ = writer.Write(enc.Bytes()) + } +} + +// addExtraLabels adds dynamic tags based on HypervisorMetricsExtraLabelsEnv configuration +// The config is a JSON map where keys are tag names and values are pod label keys to extract +// Labels are read directly from allocation.Labels which is populated by the backend +func (h *HypervisorMetricsRecorder) addExtraLabels(enc metrics.Encoder, allocation *api.WorkerAllocation) { + if len(h.extraLabelsMap) == 0 { + return + } + + if allocation.WorkerInfo == nil || len(allocation.WorkerInfo.Annotations) == 0 { + return + } + + // Add tags based on the mapping + for podLabelKey, tagName := range h.extraLabelsMap { + if labelValue, exists := allocation.WorkerInfo.Annotations[podLabelKey]; exists && labelValue != "" { + enc.AddTag(tagName, labelValue) + } + } +} diff --git a/internal/hypervisor/server/handlers/device.go b/internal/hypervisor/server/handlers/device.go new file mode 100644 index 00000000..bc8c8627 --- /dev/null +++ b/internal/hypervisor/server/handlers/device.go @@ -0,0 +1,67 @@ +/* +Copyright 2024. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// DeviceHandler handles device-related endpoints +type DeviceHandler struct { + deviceController framework.DeviceController +} + +// NewDeviceHandler creates a new device handler +func NewDeviceHandler(deviceController framework.DeviceController) *DeviceHandler { + return &DeviceHandler{ + deviceController: deviceController, + } +} + +// HandleGetDevices handles GET /api/v1/devices +func (h *DeviceHandler) HandleGetDevices(c *gin.Context) { + devices, err := h.deviceController.ListDevices() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + c.JSON(http.StatusOK, api.DataResponse[[]*api.DeviceInfo]{Data: devices}) +} + +// HandleGetDevice handles GET /api/v1/devices/:uuid +func (h *DeviceHandler) HandleGetDevice(c *gin.Context) { + uuid := c.Param("uuid") + device, err := h.deviceController.GetDevice(uuid) + if err != nil { + c.JSON(http.StatusNotFound, api.ErrorResponse{Error: err.Error()}) + return + } + c.JSON(http.StatusOK, api.DataResponse[*api.DeviceInfo]{Data: device}) +} + +// HandleDiscoverDevices handles POST /api/v1/devices/discover +func (h *DeviceHandler) HandleDiscoverDevices(c *gin.Context) { + if err := h.deviceController.DiscoverDevices(); err != nil { 
+ c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + c.JSON(http.StatusOK, api.StatusResponse{Status: "Device discovery triggered"}) +} diff --git a/internal/hypervisor/server/handlers/health.go b/internal/hypervisor/server/handlers/health.go new file mode 100644 index 00000000..2ccd1167 --- /dev/null +++ b/internal/hypervisor/server/handlers/health.go @@ -0,0 +1,47 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// HealthHandler handles health check endpoints +type HealthHandler struct{} + +// NewHealthHandler creates a new health handler +func NewHealthHandler() *HealthHandler { + return &HealthHandler{} +} + +// HandleHealthz handles GET /healthz +func (h *HealthHandler) HandleHealthz(c *gin.Context) { + c.JSON(http.StatusOK, api.StatusResponse{Status: "ok"}) +} + +// HandleReadyz handles GET /readyz +func (h *HealthHandler) HandleReadyz(c *gin.Context, deviceController framework.DeviceController, workerController framework.WorkerController) { + if deviceController == nil || workerController == nil { + c.JSON(http.StatusServiceUnavailable, api.StatusResponse{Status: "not ready"}) + return + } + c.JSON(http.StatusOK, api.StatusResponse{Status: "ready"}) +} diff --git 
a/internal/hypervisor/server/handlers/legacy.go b/internal/hypervisor/server/handlers/legacy.go new file mode 100644 index 00000000..23eeed30 --- /dev/null +++ b/internal/hypervisor/server/handlers/legacy.go @@ -0,0 +1,204 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "net/http" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" + "k8s.io/apimachinery/pkg/api/resource" +) + +// LegacyHandler handles legacy endpoints +type LegacyHandler struct { + workerController framework.WorkerController + backend framework.Backend +} + +// NewLegacyHandler creates a new legacy handler +func NewLegacyHandler(workerController framework.WorkerController, backend framework.Backend) *LegacyHandler { + return &LegacyHandler{ + workerController: workerController, + backend: backend, + } +} + +// HandleGetLimiter handles GET /api/v1/limiter +func (h *LegacyHandler) HandleGetLimiter(c *gin.Context) { + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + limiterInfos := make([]api.LimiterInfo, 0, len(workers)) + for _, worker := range workers { + allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if err != nil || allocation == nil { + continue + } + + var 
requests, limits *tfv1.Resource + if allocation.WorkerInfo != nil && allocation.WorkerInfo.MemoryLimitBytes > 0 { + vramQty := resource.NewQuantity(int64(allocation.WorkerInfo.MemoryLimitBytes), resource.BinarySI) + limits = &tfv1.Resource{ + Vram: *vramQty, + } + } + if allocation.WorkerInfo != nil && allocation.WorkerInfo.ComputeLimitUnits > 0 { + computeLimit := float64(allocation.WorkerInfo.ComputeLimitUnits) + computeQty := resource.NewQuantity(int64(computeLimit), resource.DecimalSI) + if limits == nil { + limits = &tfv1.Resource{} + } + limits.ComputePercent = *computeQty + } + + limiterInfos = append(limiterInfos, api.LimiterInfo{ + WorkerUID: worker.WorkerUID, + Requests: requests, + Limits: limits, + }) + } + + c.JSON(http.StatusOK, api.ListLimitersResponse{Limiters: limiterInfos}) +} + +// HandleTrap handles POST /api/v1/trap +func (h *LegacyHandler) HandleTrap(c *gin.Context) { + // Trap endpoint: start snapshot low QoS workers to release VRAM + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + snapshotCount := 0 + for _, worker := range workers { + allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if err != nil || allocation == nil { + continue + } + + // TODO: Check QoS level and snapshot low QoS workers + // For now, snapshot all workers (this should be filtered by QoS) + snapshotCount++ + } + + c.JSON(http.StatusOK, api.TrapResponse{ + Message: "trap initiated", + SnapshotCount: snapshotCount, + }) +} + +// HandleGetPods handles GET /api/v1/pod +func (h *LegacyHandler) HandleGetPods(c *gin.Context) { + // Only available when k8s backend is enabled + if h.backend == nil { + c.JSON(http.StatusServiceUnavailable, api.ErrorResponse{Error: "kubernetes backend not enabled"}) + return + } + + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, 
api.ErrorResponse{Error: err.Error()}) + return + } + + pods := make([]api.PodInfo, 0) + for _, worker := range workers { + allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if err != nil || allocation == nil { + continue + } + + var tflopsLimit *float64 + var vramLimit *uint64 + var qosLevel *string + + if allocation.WorkerInfo != nil && allocation.WorkerInfo.MemoryLimitBytes > 0 { + vramLimit = &allocation.WorkerInfo.MemoryLimitBytes + } + + // Try to get QoS from allocation or default to medium + qos := "medium" + qosLevel = &qos + + pods = append(pods, api.PodInfo{ + PodName: getAllocationPodName(allocation), + Namespace: getAllocationNamespace(allocation), + GPUIDs: getDeviceUUIDs(allocation), + TflopsLimit: tflopsLimit, + VramLimit: vramLimit, + QoSLevel: qosLevel, + }) + } + + c.JSON(http.StatusOK, api.ListPodsResponse{Pods: pods}) +} + +// Helper functions for WorkerAllocation field access +func getAllocationPodName(allocation *api.WorkerAllocation) string { + if allocation.WorkerInfo != nil { + return allocation.WorkerInfo.PodName + } + return "" +} + +func getAllocationNamespace(allocation *api.WorkerAllocation) string { + if allocation.WorkerInfo != nil { + return allocation.WorkerInfo.Namespace + } + return "" +} + +func getDeviceUUIDs(allocation *api.WorkerAllocation) []string { + uuids := make([]string, 0, len(allocation.DeviceInfos)) + for _, device := range allocation.DeviceInfos { + uuids = append(uuids, device.UUID) + } + return uuids +} + +// HandleGetProcesses handles GET /api/v1/process +func (h *LegacyHandler) HandleGetProcesses(c *gin.Context) { + // Get worker to process mapping + processMap, err := h.backend.GetWorkerToProcessMap() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + processInfos := make([]api.ProcessInfo, 0, len(processMap)) + for workerUID, pids := range processMap { + mapping := make(map[string]string) + for _, pid := range pids { + 
// In a real implementation, this would map container PID to host PID + // For now, use the same PID + mapping[pid] = pid + } + processInfos = append(processInfos, api.ProcessInfo{ + WorkerUID: workerUID, + ProcessMapping: mapping, + }) + } + + c.JSON(http.StatusOK, api.ListProcessesResponse{Processes: processInfos}) +} diff --git a/internal/hypervisor/server/handlers/worker.go b/internal/hypervisor/server/handlers/worker.go new file mode 100644 index 00000000..e3f1ca82 --- /dev/null +++ b/internal/hypervisor/server/handlers/worker.go @@ -0,0 +1,134 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// WorkerHandler handles worker-related endpoints +type WorkerHandler struct { + workerController framework.WorkerController +} + +// NewWorkerHandler creates a new worker handler +func NewWorkerHandler(workerController framework.WorkerController) *WorkerHandler { + return &WorkerHandler{ + workerController: workerController, + } +} + +// HandleGetWorkers handles GET /api/v1/workers +func (h *WorkerHandler) HandleGetWorkers(c *gin.Context) { + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + // Get worker details + workerDetails := make([]*api.WorkerAllocation, 0, len(workers)) + for _, worker := range workers { + allocation, err := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if err != nil { + continue + } + workerDetails = append(workerDetails, allocation) + } + + c.JSON(http.StatusOK, api.DataResponse[[]*api.WorkerAllocation]{Data: workerDetails}) +} + +// HandleGetWorker handles GET /api/v1/workers/:id +func (h *WorkerHandler) HandleGetWorker(c *gin.Context) { + workerID := c.Param("id") + allocation, err := h.workerController.GetWorkerAllocation(workerID) + if err != nil { + c.JSON(http.StatusNotFound, api.ErrorResponse{Error: err.Error()}) + return + } + if allocation == nil { + c.JSON(http.StatusNotFound, api.ErrorResponse{Error: "worker not found"}) + return + } + + // Get worker metrics + metrics, err := h.workerController.GetWorkerMetrics() + if err != nil { + c.JSON(http.StatusOK, api.DataResponse[map[string]interface{}]{ + Data: map[string]interface{}{ + "worker_uid": workerID, + "allocation": allocation, + }, + }) + return + } + + // Filter metrics for this worker + workerMetrics := 
make(map[string]map[string]map[string]*api.WorkerMetrics) + // Get metrics for all devices in the allocation + for _, device := range allocation.DeviceInfos { + if allMetrics, exists := metrics[device.UUID]; exists { + if wm, exists := allMetrics[workerID]; exists { + if workerMetrics[device.UUID] == nil { + workerMetrics[device.UUID] = make(map[string]map[string]*api.WorkerMetrics) + } + workerMetrics[device.UUID][workerID] = wm + } + } + } + + type WorkerDetail struct { + WorkerUID string `json:"worker_uid"` + Allocation *api.WorkerAllocation `json:"allocation"` + Metrics map[string]map[string]map[string]*api.WorkerMetrics `json:"metrics,omitempty"` + } + + c.JSON(http.StatusOK, api.DataResponse[WorkerDetail]{ + Data: WorkerDetail{ + WorkerUID: workerID, + Allocation: allocation, + Metrics: workerMetrics, + }, + }) +} + +// HandleSnapshotWorker handles POST /api/v1/workers/:id/snapshot +func (h *WorkerHandler) HandleSnapshotWorker(c *gin.Context) { + workerID := c.Param("id") + // TODO: Implement actual snapshot logic using accelerator interface + // For now, return success + c.JSON(http.StatusOK, api.MessageAndDataResponse[string]{ + Message: "worker snapshot initiated", + Data: workerID, + }) +} + +// HandleResumeWorker handles POST /api/v1/workers/:id/resume +func (h *WorkerHandler) HandleResumeWorker(c *gin.Context) { + workerID := c.Param("id") + // TODO: Implement actual resume logic using accelerator interface + // For now, return success + c.JSON(http.StatusOK, api.MessageAndDataResponse[string]{ + Message: "worker resume initiated", + Data: workerID, + }) +} diff --git a/internal/hypervisor/server/server.go b/internal/hypervisor/server/server.go new file mode 100644 index 00000000..0578825e --- /dev/null +++ b/internal/hypervisor/server/server.go @@ -0,0 +1,131 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "context" + "fmt" + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/server/handlers" + "github.com/gin-gonic/gin" + "k8s.io/klog/v2" +) + +// MetricsRecorder interface for metrics +type MetricsRecorder interface { + Start() +} + +// Server represents the hypervisor HTTP server +type Server struct { + deviceController framework.DeviceController + workerController framework.WorkerController + metricsRecorder MetricsRecorder + backend framework.Backend + ctx context.Context + router *gin.Engine + httpServer *http.Server + + // Handlers + healthHandler *handlers.HealthHandler + deviceHandler *handlers.DeviceHandler + workerHandler *handlers.WorkerHandler + legacyHandler *handlers.LegacyHandler +} + +// NewServer creates a new hypervisor HTTP server +func NewServer( + ctx context.Context, + deviceController framework.DeviceController, + workerController framework.WorkerController, + metricsRecorder MetricsRecorder, + backend framework.Backend, + port int, +) *Server { + gin.SetMode(gin.ReleaseMode) + router := gin.New() + router.Use(gin.Logger(), gin.Recovery()) + + // Initialize handlers + healthHandler := handlers.NewHealthHandler() + deviceHandler := handlers.NewDeviceHandler(deviceController) + workerHandler := handlers.NewWorkerHandler(workerController) + legacyHandler := handlers.NewLegacyHandler(workerController, backend) + + s := &Server{ + deviceController: deviceController, + workerController: workerController, + metricsRecorder: 
metricsRecorder, + backend: backend, + ctx: ctx, + router: router, + httpServer: &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: router, + }, + healthHandler: healthHandler, + deviceHandler: deviceHandler, + workerHandler: workerHandler, + legacyHandler: legacyHandler, + } + + s.setupRoutes() + return s +} + +func (s *Server) setupRoutes() { + // Health check routes + s.router.GET("/healthz", s.healthHandler.HandleHealthz) + s.router.GET("/readyz", func(c *gin.Context) { + s.healthHandler.HandleReadyz(c, s.deviceController, s.workerController) + }) + + // RESTful API routes + // TODO: add authentication and authorization for worker APIs + apiV1 := s.router.Group("/api/v1") + { + // Device routes + apiV1.GET("/devices", s.deviceHandler.HandleGetDevices) + apiV1.GET("/devices/:uuid", s.deviceHandler.HandleGetDevice) + apiV1.POST("/devices/discover", s.deviceHandler.HandleDiscoverDevices) + + // Worker routes + apiV1.GET("/workers", s.workerHandler.HandleGetWorkers) + apiV1.GET("/workers/:id", s.workerHandler.HandleGetWorker) + apiV1.POST("/workers/:id/snapshot", s.workerHandler.HandleSnapshotWorker) + apiV1.POST("/workers/:id/resume", s.workerHandler.HandleResumeWorker) + + // Legacy routes + apiV1.GET("/limiter", s.legacyHandler.HandleGetLimiter) + apiV1.POST("/trap", s.legacyHandler.HandleTrap) + apiV1.GET("/pod", s.legacyHandler.HandleGetPods) + apiV1.GET("/process", s.legacyHandler.HandleGetProcesses) + } +} + +// Start starts the HTTP server +func (s *Server) Start() error { + klog.Infof("Starting hypervisor HTTP server on %s", s.httpServer.Addr) + return s.httpServer.ListenAndServe() +} + +// Stop stops the HTTP server +func (s *Server) Stop(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} diff --git a/internal/hypervisor/tui/chart.go b/internal/hypervisor/tui/chart.go new file mode 100644 index 00000000..ed5f1fb4 --- /dev/null +++ b/internal/hypervisor/tui/chart.go @@ -0,0 +1,219 @@ +/* +Copyright 2024. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "fmt" + "strings" +) + +const ( + maxHistorySize = 60 // Keep 60 data points for ~2 minutes at 2s intervals +) + +// TimeSeriesChart represents a time-series chart for metrics +type TimeSeriesChart struct { + data []float64 + width int + height int + maxValue float64 + minValue float64 + label string +} + +// NewTimeSeriesChart creates a new time-series chart +func NewTimeSeriesChart(width, height int, label string) *TimeSeriesChart { + return &TimeSeriesChart{ + data: make([]float64, 0, maxHistorySize), + width: width, + height: height, + maxValue: 100.0, // Default max for percentages + minValue: 0.0, + label: label, + } +} + +// AddDataPoint adds a new data point to the chart +func (c *TimeSeriesChart) AddDataPoint(value float64) { + c.data = append(c.data, value) + if len(c.data) > maxHistorySize { + c.data = c.data[1:] // Remove oldest point + } + + // Auto-scale max value + if value > c.maxValue { + c.maxValue = value * 1.1 // Add 10% padding + } + if value < c.minValue { + c.minValue = value + } +} + +// SetMaxValue sets the maximum value for the chart scale +func (c *TimeSeriesChart) SetMaxValue(max float64) { + c.maxValue = max +} + +// SetDimensions sets the width and height of the chart +func (c *TimeSeriesChart) SetDimensions(width, height int) { + c.width = width + c.height = height +} + +// Render renders the time-series chart as a string +// +//nolint:gocyclo // Complex rendering logic with 
multiple conditional branches +func (c *TimeSeriesChart) Render() string { + if len(c.data) == 0 { + return fmt.Sprintf("%s: No data\n", c.label) + } + + var result strings.Builder + result.WriteString(fmt.Sprintf("%s (max: %.1f)\n", c.label, c.maxValue)) + + if c.height < 2 { + // Single line mode - just show current value + lastValue := c.data[len(c.data)-1] + result.WriteString(renderBarChart(lastValue, c.width)) + return result.String() + } + + // Multi-line chart + chartHeight := c.height - 1 // Reserve one line for label + if chartHeight < 1 { + chartHeight = 1 + } + + // Create a grid for the chart + grid := make([][]rune, chartHeight) + for i := range grid { + grid[i] = make([]rune, c.width) + for j := range grid[i] { + grid[i][j] = ' ' + } + } + + // Handle edge case: maxValue == minValue + valueRange := c.maxValue - c.minValue + if valueRange == 0 { + valueRange = 1.0 // Avoid division by zero + } + + // Draw the data + dataLen := len(c.data) + if dataLen > c.width { + // Downsample if we have more data points than width + step := float64(dataLen) / float64(c.width) + for x := 0; x < c.width; x++ { + idx := int(float64(x) * step) + if idx >= dataLen { + idx = dataLen - 1 + } + value := c.data[idx] + y := int((c.maxValue - value) / valueRange * float64(chartHeight-1)) + if y < 0 { + y = 0 + } + if y >= chartHeight { + y = chartHeight - 1 + } + grid[y][x] = '█' + + // Draw line connecting to previous point + if x > 0 { + prevIdx := int(float64(x-1) * step) + if prevIdx >= dataLen { + prevIdx = dataLen - 1 + } + prevValue := c.data[prevIdx] + prevY := int((c.maxValue - prevValue) / valueRange * float64(chartHeight-1)) + if prevY < 0 { + prevY = 0 + } + if prevY >= chartHeight { + prevY = chartHeight - 1 + } + + // Draw connecting line + startY, endY := prevY, y + if startY > endY { + startY, endY = endY, startY + } + for lineY := startY; lineY <= endY; lineY++ { + if lineY < chartHeight { + if grid[lineY][x] == ' ' { + grid[lineY][x] = '│' + } + } + } + } + 
} + } else { + // Draw all data points + for x, value := range c.data { + if x >= c.width { + break + } + y := int((c.maxValue - value) / valueRange * float64(chartHeight-1)) + if y < 0 { + y = 0 + } + if y >= chartHeight { + y = chartHeight - 1 + } + grid[y][x] = '█' + + // Draw connecting line + if x > 0 { + prevValue := c.data[x-1] + prevY := int((c.maxValue - prevValue) / valueRange * float64(chartHeight-1)) + if prevY < 0 { + prevY = 0 + } + if prevY >= chartHeight { + prevY = chartHeight - 1 + } + + startY, endY := prevY, y + if startY > endY { + startY, endY = endY, startY + } + for lineY := startY; lineY <= endY; lineY++ { + if lineY < chartHeight { + if grid[lineY][x] == ' ' { + grid[lineY][x] = '│' + } + } + } + } + } + } + + // Render the grid + for _, row := range grid { + result.WriteString(ChartBarStyle.Render(string(row))) + result.WriteString("\n") + } + + // Add current value + if len(c.data) > 0 { + lastValue := c.data[len(c.data)-1] + result.WriteString(fmt.Sprintf("Current: %.1f", lastValue)) + } + + return result.String() +} diff --git a/internal/hypervisor/tui/client.go b/internal/hypervisor/tui/client.go new file mode 100644 index 00000000..db1160d2 --- /dev/null +++ b/internal/hypervisor/tui/client.go @@ -0,0 +1,181 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +// Client is an HTTP client for fetching data from the hypervisor server +type Client struct { + baseURL string + httpClient *http.Client +} + +// NewClient creates a new HTTP client for the hypervisor +func NewClient(host string, port int) *Client { + return &Client{ + baseURL: fmt.Sprintf("http://%s:%d/api/v1", host, port), + httpClient: &http.Client{ + Timeout: 5 * time.Second, + }, + } +} + +// doRequest performs an HTTP request and decodes the JSON response +// +//nolint:unparam // method parameter is kept for API consistency, even though it's always "GET" +func (c *Client) doRequest(ctx context.Context, method, path string, result interface{}) error { + url := fmt.Sprintf("%s/%s", c.baseURL, path) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("execute request: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + if err := json.NewDecoder(resp.Body).Decode(result); err != nil { + return fmt.Errorf("decode response: %w", err) + } + + return nil +} + +// ListDevices fetches all devices from the hypervisor +func (c *Client) ListDevices(ctx context.Context) ([]*api.DeviceInfo, error) { + var result api.DataResponse[[]*api.DeviceInfo] + if err := c.doRequest(ctx, "GET", "devices", &result); err != nil { + return nil, fmt.Errorf("list devices: %w", err) + } + return result.Data, nil +} + +// GetDevice fetches a specific device by UUID +func (c *Client) GetDevice(ctx context.Context, uuid string) (*api.DeviceInfo, error) { + var result 
api.DataResponse[*api.DeviceInfo] + if err := c.doRequest(ctx, "GET", fmt.Sprintf("devices/%s", uuid), &result); err != nil { + return nil, fmt.Errorf("get device %s: %w", uuid, err) + } + return result.Data, nil +} + +// GetDeviceAllocations fetches allocations for a specific device +func (c *Client) GetDeviceAllocations(ctx context.Context, uuid string) ([]*api.WorkerAllocation, error) { + workers, err := c.ListWorkers(ctx) + if err != nil { + return nil, fmt.Errorf("list workers: %w", err) + } + + allocations := make([]*api.WorkerAllocation, 0) + for _, worker := range workers { + // Check if any device in the allocation matches the UUID + for _, device := range worker.DeviceInfos { + if device.UUID == uuid { + allocations = append(allocations, worker) + break + } + } + } + + return allocations, nil +} + +// GetGPUMetrics fetches GPU metrics for all devices +// Note: This is a placeholder until a dedicated metrics endpoint is available +func (c *Client) GetGPUMetrics(ctx context.Context) (map[string]*api.GPUUsageMetrics, error) { + // TODO: Implement when metrics endpoint is available + // For now, return empty metrics to avoid errors + return make(map[string]*api.GPUUsageMetrics), nil +} + +// ListWorkers fetches all workers from the hypervisor +func (c *Client) ListWorkers(ctx context.Context) ([]*api.WorkerAllocation, error) { + var result api.DataResponse[[]*api.WorkerAllocation] + if err := c.doRequest(ctx, "GET", "workers", &result); err != nil { + return nil, fmt.Errorf("list workers: %w", err) + } + return result.Data, nil +} + +// GetWorker fetches a specific worker by ID +func (c *Client) GetWorker(ctx context.Context, workerID string) (*api.WorkerAllocation, map[string]map[string]map[string]*api.WorkerMetrics, error) { + type WorkerDetail struct { + WorkerUID string `json:"worker_uid"` + Allocation *api.WorkerAllocation `json:"allocation"` + Metrics map[string]map[string]map[string]*api.WorkerMetrics `json:"metrics,omitempty"` + } + + var result 
api.DataResponse[WorkerDetail] + if err := c.doRequest(ctx, "GET", fmt.Sprintf("workers/%s", workerID), &result); err != nil { + return nil, nil, fmt.Errorf("get worker %s: %w", workerID, err) + } + return result.Data.Allocation, result.Data.Metrics, nil +} + +// GetWorkerMetrics fetches worker metrics for all workers +// This is optimized to batch requests when possible +func (c *Client) GetWorkerMetrics(ctx context.Context) (map[string]map[string]map[string]*api.WorkerMetrics, error) { + workers, err := c.ListWorkers(ctx) + if err != nil { + return nil, err + } + + metrics := make(map[string]map[string]map[string]*api.WorkerMetrics) + for _, worker := range workers { + // Get WorkerUID from WorkerInfo + if worker.WorkerInfo == nil { + continue + } + workerUID := worker.WorkerInfo.WorkerUID + _, workerMetrics, err := c.GetWorker(ctx, workerUID) + if err != nil { + // Continue on individual worker errors to get as much data as possible + continue + } + + // Merge metrics by device UUID + for deviceUUID, deviceMetrics := range workerMetrics { + if metrics[deviceUUID] == nil { + metrics[deviceUUID] = make(map[string]map[string]*api.WorkerMetrics) + } + // Copy worker metrics for this device + for wUID, wMetrics := range deviceMetrics { + metrics[deviceUUID][wUID] = wMetrics + } + } + } + + return metrics, nil +} diff --git a/internal/hypervisor/tui/device_view.go b/internal/hypervisor/tui/device_view.go new file mode 100644 index 00000000..c7b1ca90 --- /dev/null +++ b/internal/hypervisor/tui/device_view.go @@ -0,0 +1,146 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "context" + "fmt" + "strings" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" +) + +// deviceItem represents a device in the list +type deviceItem struct { + uuid string + model string + index int32 +} + +func (d deviceItem) FilterValue() string { + return fmt.Sprintf("%s %s %d", d.uuid, d.model, d.index) +} + +func (d deviceItem) Title() string { + return fmt.Sprintf("[%d] %s", d.index, d.model) +} + +func (d deviceItem) Description() string { + return d.uuid +} + +func newDeviceDelegate() list.DefaultDelegate { + d := list.NewDefaultDelegate() + d.Styles.SelectedTitle = SelectedStyle + d.Styles.SelectedDesc = SelectedStyle + d.Styles.NormalTitle = NormalStyle + d.Styles.NormalDesc = NormalStyle + return d +} + +// updateDeviceList updates the device list with current devices +func updateDeviceList(deviceList *list.Model, devices []*api.DeviceInfo) { + deviceItems := make([]list.Item, len(devices)) + for i, device := range devices { + deviceItems[i] = deviceItem{ + uuid: device.UUID, + model: device.Model, + index: device.Index, + } + } + deviceList.SetItems(deviceItems) +} + +// updateDeviceDetail updates the device detail viewport +func updateDeviceDetail( + ctx context.Context, + client *Client, + deviceDetail *viewport.Model, + selectedDeviceUUID string, + devices []*api.DeviceInfo, + metrics map[string]*api.GPUUsageMetrics, + deviceMetricsHistory map[string]*DeviceMetricsHistory, +) { + var device *api.DeviceInfo + for _, d := 
range devices { + if d.UUID == selectedDeviceUUID { + device = d + break + } + } + if device == nil { + deviceDetail.SetContent("Device not found") + return + } + + deviceMetrics, hasMetrics := metrics[device.UUID] + + var content strings.Builder + content.WriteString(TitleStyle.Render("Device Details\n\n")) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("UUID"), MetricValueStyle.Render(device.UUID))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Vendor"), MetricValueStyle.Render(device.Vendor))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Model"), MetricValueStyle.Render(device.Model))) + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Index"), device.Index)) + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("NUMA Node"), device.NUMANode)) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Total Memory"), formatBytes(device.TotalMemoryBytes))) + content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n\n", MetricLabelStyle.Render("Max TFLOPS"), device.MaxTflops)) + + if hasMetrics && deviceMetrics != nil { + content.WriteString(TitleStyle.Render("Current Metrics\n\n")) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Memory Usage"), deviceMetrics.MemoryPercentage)) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Used"), formatBytes(deviceMetrics.MemoryBytes))) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Compute Usage"), deviceMetrics.ComputePercentage)) + content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n", MetricLabelStyle.Render("Compute TFLOPS"), deviceMetrics.ComputeTflops)) + content.WriteString(fmt.Sprintf("%s: %.1f°C\n", MetricLabelStyle.Render("Temperature"), deviceMetrics.Temperature)) + content.WriteString(fmt.Sprintf("%s: %d W\n", MetricLabelStyle.Render("Power Usage"), deviceMetrics.PowerUsage)) + content.WriteString(fmt.Sprintf("%s: %.1f 
MHz\n", MetricLabelStyle.Render("Graphics Clock"), deviceMetrics.GraphicsClockMHz)) + content.WriteString(fmt.Sprintf("%s: %.1f MHz\n\n", MetricLabelStyle.Render("SM Clock"), deviceMetrics.SMClockMHz)) + + // Time-series charts + if history, exists := deviceMetricsHistory[selectedDeviceUUID]; exists && history != nil { + content.WriteString("\n") + content.WriteString(history.MemoryChart.Render()) + content.WriteString("\n") + content.WriteString(history.ComputeChart.Render()) + content.WriteString("\n") + content.WriteString(history.TempChart.Render()) + content.WriteString("\n") + content.WriteString(history.PowerChart.Render()) + content.WriteString("\n") + } + } + + // Get allocations for this device + allocations, err := client.GetDeviceAllocations(ctx, device.UUID) + if err == nil && len(allocations) > 0 { + content.WriteString(TitleStyle.Render("Allocations\n\n")) + for _, alloc := range allocations { + content.WriteString(fmt.Sprintf(" Worker: %s\n", alloc.WorkerInfo.WorkerUID)) + content.WriteString(fmt.Sprintf(" Pod: %s/%s\n", alloc.WorkerInfo.Namespace, alloc.WorkerInfo.PodName)) + content.WriteString(fmt.Sprintf(" Mode: %s\n", alloc.WorkerInfo.IsolationMode)) + if alloc.WorkerInfo.MemoryLimitBytes > 0 { + content.WriteString(fmt.Sprintf(" Memory Limit: %s\n", formatBytes(alloc.WorkerInfo.MemoryLimitBytes))) + } + content.WriteString("\n") + } + } + + deviceDetail.SetContent(content.String()) +} diff --git a/internal/hypervisor/tui/metrics_view.go b/internal/hypervisor/tui/metrics_view.go new file mode 100644 index 00000000..df925d62 --- /dev/null +++ b/internal/hypervisor/tui/metrics_view.go @@ -0,0 +1,76 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "fmt" + "strings" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/viewport" +) + +// updateMetricsView updates the metrics viewport +func updateMetricsView( + metricsView *viewport.Model, + devices []*api.DeviceInfo, + workers []WorkerInfo, + metrics map[string]*api.GPUUsageMetrics, + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics, + lastUpdate time.Time, +) { + var content strings.Builder + content.WriteString(TitleStyle.Render("System Metrics\n\n")) + content.WriteString(fmt.Sprintf("Last Update: %s\n\n", lastUpdate.Format(time.RFC3339))) + + // Device metrics overview + content.WriteString(TitleStyle.Render("Device Metrics Overview\n\n")) + for _, device := range devices { + metrics, hasMetrics := metrics[device.UUID] + content.WriteString(fmt.Sprintf("%s [%s]\n", device.Model, device.UUID[:8])) + if hasMetrics && metrics != nil { + content.WriteString(fmt.Sprintf(" Memory: %.1f%% %s\n", metrics.MemoryPercentage, renderBarChart(metrics.MemoryPercentage, 20))) + content.WriteString(fmt.Sprintf(" Compute: %.1f%% %s\n", metrics.ComputePercentage, renderBarChart(metrics.ComputePercentage, 20))) + content.WriteString(fmt.Sprintf(" Temperature: %.1f°C Power: %dW\n", metrics.Temperature, metrics.PowerUsage)) + } else { + content.WriteString(" No metrics available\n") + } + content.WriteString("\n") + } + + // Worker metrics overview + content.WriteString(TitleStyle.Render("Worker Metrics Overview\n\n")) + for _, worker := range workers { + 
content.WriteString(fmt.Sprintf("%s/%s\n", worker.Namespace, worker.PodName)) + if workerMetrics, exists := workerMetrics[worker.DeviceUUID]; exists { + if wm, exists := workerMetrics[worker.UID]; exists { + var totalMemory uint64 + var totalCompute float64 + for _, metrics := range wm { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + } + content.WriteString(fmt.Sprintf(" Memory: %s\n", formatBytes(totalMemory))) + content.WriteString(fmt.Sprintf(" Compute: %.1f%% %s\n", totalCompute, renderBarChart(totalCompute, 20))) + } + } + content.WriteString("\n") + } + + metricsView.SetContent(content.String()) +} diff --git a/internal/hypervisor/tui/model.go b/internal/hypervisor/tui/model.go new file mode 100644 index 00000000..a08db355 --- /dev/null +++ b/internal/hypervisor/tui/model.go @@ -0,0 +1,558 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "context" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +const ( + viewDevices = iota + viewWorkers + viewMetrics + viewDeviceDetail + viewWorkerDetail +) + +// Model represents the TUI model +type Model struct { + ctx context.Context + client *Client + + currentView int + devices []*api.DeviceInfo + workers []WorkerInfo + metrics map[string]*api.GPUUsageMetrics + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics + + // Metrics history for time-series charts + deviceMetricsHistory map[string]*DeviceMetricsHistory + workerMetricsHistory map[string]*WorkerMetricsHistory + + deviceList list.Model + workerList list.Model + deviceDetail viewport.Model + workerDetail viewport.Model + metricsView viewport.Model + + shmDialog *ShmDialogModel + + selectedDeviceUUID string + selectedWorkerUID string + + width int + height int + + lastUpdate time.Time +} + +// DeviceMetricsHistory tracks historical metrics for a device +type DeviceMetricsHistory struct { + MemoryChart *TimeSeriesChart + ComputeChart *TimeSeriesChart + TempChart *TimeSeriesChart + PowerChart *TimeSeriesChart +} + +// WorkerMetricsHistory tracks historical metrics for a worker +type WorkerMetricsHistory struct { + MemoryChart *TimeSeriesChart + ComputeChart *TimeSeriesChart +} + +type tickMsg time.Time +type updateDataMsg struct { + devices []*api.DeviceInfo + workers []WorkerInfo + metrics map[string]*api.GPUUsageMetrics + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics +} + +// NewModel creates a new TUI model +func NewModel(ctx context.Context, client *Client) *Model { + m := &Model{ + ctx: ctx, + client: client, + currentView: viewDevices, + metrics: make(map[string]*api.GPUUsageMetrics), + workerMetrics: 
make(map[string]map[string]map[string]*api.WorkerMetrics), + deviceMetricsHistory: make(map[string]*DeviceMetricsHistory), + workerMetricsHistory: make(map[string]*WorkerMetricsHistory), + } + + // Initialize device list + deviceItems := []list.Item{} + m.deviceList = list.New(deviceItems, newDeviceDelegate(), 0, 0) + m.deviceList.Title = "GPU Devices" + m.deviceList.SetShowStatusBar(false) + m.deviceList.SetFilteringEnabled(true) + m.deviceList.Styles.Title = TitleStyle + m.deviceList.Styles.FilterPrompt = SubtitleStyle + m.deviceList.Styles.FilterCursor = SelectedStyle + + // Initialize worker list + workerItems := []list.Item{} + m.workerList = list.New(workerItems, newWorkerDelegate(), 0, 0) + m.workerList.Title = "Workers" + m.workerList.SetShowStatusBar(false) + m.workerList.SetFilteringEnabled(true) + m.workerList.Styles.Title = TitleStyle + m.workerList.Styles.FilterPrompt = SubtitleStyle + m.workerList.Styles.FilterCursor = SelectedStyle + + // Initialize detail viewports + m.deviceDetail = viewport.New(0, 0) + m.workerDetail = viewport.New(0, 0) + m.metricsView = viewport.New(0, 0) + + // Initialize SHM dialog + m.shmDialog = NewShmDialogModel() + + return m +} + +func (m *Model) Init() tea.Cmd { + return tea.Batch( + m.updateData(), + tick(), + ) +} + +func (m *Model) updateData() tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second) + defer cancel() + + // Get devices + devices, err := m.client.ListDevices(ctx) + if err != nil { + devices = []*api.DeviceInfo{} + } + + // Get workers + workerDetails, err := m.client.ListWorkers(ctx) + if err != nil { + workerDetails = []*api.WorkerAllocation{} + } + + workers := make([]WorkerInfo, 0, len(workerDetails)) + for _, worker := range workerDetails { + if worker == nil { + continue + } + // Extract device UUID from the first device in allocation + deviceUUID := "" + if len(worker.DeviceInfos) > 0 { + deviceUUID = worker.DeviceInfos[0].UUID + } + workers = append(workers, 
WorkerInfo{ + UID: worker.WorkerInfo.WorkerUID, + PodName: worker.WorkerInfo.PodName, + Namespace: worker.WorkerInfo.Namespace, + DeviceUUID: deviceUUID, + Allocation: worker, + }) + } + + // Get GPU metrics - for now, we'll need to add a metrics endpoint + // For now, return empty metrics + metrics := make(map[string]*api.GPUUsageMetrics) + + // Get worker metrics + workerMetrics, err := m.client.GetWorkerMetrics(ctx) + if err != nil { + workerMetrics = make(map[string]map[string]map[string]*api.WorkerMetrics) + } + + return updateDataMsg{ + devices: devices, + workers: workers, + metrics: metrics, + workerMetrics: workerMetrics, + } + } +} + +func tick() tea.Cmd { + return tea.Tick(2*time.Second, func(t time.Time) tea.Msg { + return tickMsg(t) + }) +} + +//nolint:gocyclo // Complex state machine with many message types and view transitions +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.resizeViews() + if m.shmDialog != nil { + m.shmDialog.width = msg.Width + m.shmDialog.height = msg.Height + } + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c": + return m, tea.Quit + case "1": + m.currentView = viewDevices + return m, nil + case "2": + m.currentView = viewWorkers + return m, nil + case "3": + m.currentView = viewMetrics + return m, nil + case "esc": + // Close SHM dialog if visible + if m.shmDialog != nil && m.shmDialog.IsVisible() { + m.shmDialog.Hide() + return m, nil + } + if m.currentView == viewDeviceDetail || m.currentView == viewWorkerDetail { + if m.currentView == viewDeviceDetail { + m.currentView = viewDevices + } else { + m.currentView = viewWorkers + } + return m, nil + } + case "enter": + if m.currentView == viewDevices { + if selectedItem := m.deviceList.SelectedItem(); selectedItem != nil { + item := selectedItem.(deviceItem) + m.selectedDeviceUUID = item.uuid + m.currentView = 
viewDeviceDetail + // Initialize history if needed + if m.deviceMetricsHistory[m.selectedDeviceUUID] == nil { + m.initDeviceHistory(m.selectedDeviceUUID) + } + updateDeviceDetail(m.ctx, m.client, &m.deviceDetail, m.selectedDeviceUUID, m.devices, m.metrics, m.deviceMetricsHistory) + return m, nil + } + } else if m.currentView == viewWorkers { + if selectedItem := m.workerList.SelectedItem(); selectedItem != nil { + item := selectedItem.(workerItem) + m.selectedWorkerUID = item.uid + m.currentView = viewWorkerDetail + // Initialize history if needed + if m.workerMetricsHistory[m.selectedWorkerUID] == nil { + m.initWorkerHistory(m.selectedWorkerUID) + } + updateWorkerDetail(&m.workerDetail, m.selectedWorkerUID, m.workers, m.workerMetrics, m.workerMetricsHistory) + return m, nil + } + } else if m.currentView == viewWorkerDetail { + // Check if SHM dialog is visible, if so, close it + if m.shmDialog != nil && m.shmDialog.IsVisible() { + m.shmDialog.Hide() + return m, nil + } + // Otherwise, show SHM dialog if isolation mode is soft + var worker *WorkerInfo + for _, w := range m.workers { + if w.UID == m.selectedWorkerUID { + worker = &w + break + } + } + if worker != nil && worker.Allocation != nil && worker.Allocation.WorkerInfo != nil { + m.shmDialog.Show(worker) + return m, nil + } + } + } + + case tickMsg: + return m, tea.Batch(m.updateData(), tick()) + + case updateDataMsg: + m.devices = msg.devices + m.workers = msg.workers + m.metrics = msg.metrics + m.workerMetrics = msg.workerMetrics + m.lastUpdate = time.Now() + + // Update metrics history for charts + m.updateMetricsHistory() + + updateDeviceList(&m.deviceList, m.devices) + updateWorkerList(&m.workerList, m.workers) + switch m.currentView { + case viewDeviceDetail: + updateDeviceDetail(m.ctx, m.client, &m.deviceDetail, m.selectedDeviceUUID, m.devices, m.metrics, m.deviceMetricsHistory) + case viewWorkerDetail: + updateWorkerDetail(&m.workerDetail, m.selectedWorkerUID, m.workers, m.workerMetrics, 
m.workerMetricsHistory) + case viewMetrics: + updateMetricsView(&m.metricsView, m.devices, m.workers, m.metrics, m.workerMetrics, m.lastUpdate) + } + return m, nil + } + + // Update sub-views + // If SHM dialog is visible, it should handle input first + if m.shmDialog != nil && m.shmDialog.IsVisible() { + var cmd tea.Cmd + _, cmd = m.shmDialog.Update(msg) + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) + } + + switch m.currentView { + case viewDevices: + var cmd tea.Cmd + m.deviceList, cmd = m.deviceList.Update(msg) + cmds = append(cmds, cmd) + case viewWorkers: + var cmd tea.Cmd + m.workerList, cmd = m.workerList.Update(msg) + cmds = append(cmds, cmd) + case viewDeviceDetail: + var cmd tea.Cmd + m.deviceDetail, cmd = m.deviceDetail.Update(msg) + cmds = append(cmds, cmd) + case viewWorkerDetail: + var cmd tea.Cmd + m.workerDetail, cmd = m.workerDetail.Update(msg) + cmds = append(cmds, cmd) + case viewMetrics: + var cmd tea.Cmd + m.metricsView, cmd = m.metricsView.Update(msg) + cmds = append(cmds, cmd) + } + + return m, tea.Batch(cmds...) 
+} + +func (m *Model) resizeViews() { + headerHeight := 3 + footerHeight := 2 + availableHeight := m.height - headerHeight - footerHeight + + switch m.currentView { + case viewDevices: + m.deviceList.SetWidth(m.width) + m.deviceList.SetHeight(availableHeight) + case viewWorkers: + m.workerList.SetWidth(m.width) + m.workerList.SetHeight(availableHeight) + case viewDeviceDetail, viewWorkerDetail, viewMetrics: + width := m.width + height := availableHeight + m.deviceDetail.Width = width + m.deviceDetail.Height = height + m.workerDetail.Width = width + m.workerDetail.Height = height + m.metricsView.Width = width + m.metricsView.Height = height + + // Update chart dimensions when resizing + chartWidth := width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + if m.currentView == viewDeviceDetail && m.selectedDeviceUUID != "" { + if history := m.deviceMetricsHistory[m.selectedDeviceUUID]; history != nil { + history.MemoryChart.SetDimensions(chartWidth, chartHeight) + history.ComputeChart.SetDimensions(chartWidth, chartHeight) + history.TempChart.SetDimensions(chartWidth, chartHeight) + history.PowerChart.SetDimensions(chartWidth, chartHeight) + } + } else if m.currentView == viewWorkerDetail && m.selectedWorkerUID != "" { + if history := m.workerMetricsHistory[m.selectedWorkerUID]; history != nil { + history.MemoryChart.SetDimensions(chartWidth, chartHeight) + history.ComputeChart.SetDimensions(chartWidth, chartHeight) + } + } + } +} + +func (m *Model) View() string { + if m.width == 0 || m.height == 0 { + return "Initializing..." 
+ } + + var view string + switch m.currentView { + case viewDevices: + view = m.deviceList.View() + case viewWorkers: + view = m.workerList.View() + case viewDeviceDetail: + view = m.deviceDetail.View() + case viewWorkerDetail: + view = m.workerDetail.View() + case viewMetrics: + view = m.metricsView.View() + } + + header := m.renderHeader() + footer := m.renderFooter() + + mainView := lipgloss.JoinVertical(lipgloss.Left, header, view, footer) + + // Render SHM dialog on top if visible + if m.shmDialog != nil && m.shmDialog.IsVisible() { + dialogView := m.shmDialog.View() + // The dialog already handles centering, so we just return it + // It will overlay on top of the main view + return dialogView + } + + return mainView +} + +// initDeviceHistory initializes metrics history for a device +func (m *Model) initDeviceHistory(deviceUUID string) { + chartWidth := m.width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + m.deviceMetricsHistory[deviceUUID] = &DeviceMetricsHistory{ + MemoryChart: NewTimeSeriesChart(chartWidth, chartHeight, "Memory Usage"), + ComputeChart: NewTimeSeriesChart(chartWidth, chartHeight, "Compute Usage"), + TempChart: NewTimeSeriesChart(chartWidth, chartHeight, "Temperature"), + PowerChart: NewTimeSeriesChart(chartWidth, chartHeight, "Power Usage"), + } + + // Set max values + m.deviceMetricsHistory[deviceUUID].MemoryChart.SetMaxValue(100.0) + m.deviceMetricsHistory[deviceUUID].ComputeChart.SetMaxValue(100.0) + m.deviceMetricsHistory[deviceUUID].TempChart.SetMaxValue(100.0) // Will auto-scale + m.deviceMetricsHistory[deviceUUID].PowerChart.SetMaxValue(500.0) // Will auto-scale +} + +// initWorkerHistory initializes metrics history for a worker +func (m *Model) initWorkerHistory(workerUID string) { + chartWidth := m.width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + m.workerMetricsHistory[workerUID] = &WorkerMetricsHistory{ + MemoryChart: NewTimeSeriesChart(chartWidth, chartHeight, "Memory 
Usage"), + ComputeChart: NewTimeSeriesChart(chartWidth, chartHeight, "Compute Usage"), + } + + // Set max values + m.workerMetricsHistory[workerUID].MemoryChart.SetMaxValue(100.0) + m.workerMetricsHistory[workerUID].ComputeChart.SetMaxValue(100.0) +} + +// updateMetricsHistory updates the metrics history with current values +func (m *Model) updateMetricsHistory() { + // Update device metrics history + for deviceUUID, metrics := range m.metrics { + if metrics == nil { + continue + } + + history := m.deviceMetricsHistory[deviceUUID] + if history == nil { + // Only initialize if we're viewing this device + if m.currentView == viewDeviceDetail && m.selectedDeviceUUID == deviceUUID { + m.initDeviceHistory(deviceUUID) + history = m.deviceMetricsHistory[deviceUUID] + } else { + continue + } + } + + history.MemoryChart.AddDataPoint(metrics.MemoryPercentage) + history.ComputeChart.AddDataPoint(metrics.ComputePercentage) + history.TempChart.AddDataPoint(metrics.Temperature) + history.PowerChart.AddDataPoint(float64(metrics.PowerUsage)) + } + + // Update worker metrics history + for _, deviceWorkers := range m.workerMetrics { + for workerUID, workerMetrics := range deviceWorkers { + history := m.workerMetricsHistory[workerUID] + if history == nil { + // Only initialize if we're viewing this worker + if m.currentView == viewWorkerDetail && m.selectedWorkerUID == workerUID { + m.initWorkerHistory(workerUID) + history = m.workerMetricsHistory[workerUID] + } else { + continue + } + } + + // Aggregate metrics for this worker + var totalMemory uint64 + var totalCompute float64 + for _, metrics := range workerMetrics { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + } + + // Calculate percentage if we have allocation info + var memPercent float64 + for _, worker := range m.workers { + if worker.UID == workerUID && worker.Allocation != nil && worker.Allocation.WorkerInfo != nil && worker.Allocation.WorkerInfo.MemoryLimitBytes > 0 { + memPercent = 
float64(totalMemory) / float64(worker.Allocation.WorkerInfo.MemoryLimitBytes) * 100.0 + break + } + } + + history.MemoryChart.AddDataPoint(memPercent) + history.ComputeChart.AddDataPoint(totalCompute) + } + } +} + +func (m *Model) renderHeader() string { + title := TitleStyle.Render("Tensor Fusion Hypervisor") + tabs := []string{} + tabs = append(tabs, m.renderTab("Devices [1]", m.currentView == viewDevices)) + tabs = append(tabs, m.renderTab("Workers [2]", m.currentView == viewWorkers)) + tabs = append(tabs, m.renderTab("Metrics [3]", m.currentView == viewMetrics)) + tabLine := lipgloss.JoinHorizontal(lipgloss.Left, tabs...) + return lipgloss.JoinVertical(lipgloss.Left, title, tabLine) +} + +func (m *Model) renderTab(text string, active bool) string { + if active { + return SelectedStyle.Render(text) + } + return NormalStyle.Render(text) +} + +func (m *Model) renderFooter() string { + help := "Press 'q' to quit | 'Enter' to view details" + if m.currentView == viewWorkerDetail { + help += " (Enter again for SHM details if soft isolation)" + } + help += " | 'Esc' to go back | '1/2/3' to switch views" + return SubtitleStyle.Render(help) +} diff --git a/internal/hypervisor/tui/shm_dialog.go b/internal/hypervisor/tui/shm_dialog.go new file mode 100644 index 00000000..0dd3983b --- /dev/null +++ b/internal/hypervisor/tui/shm_dialog.go @@ -0,0 +1,300 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "fmt" + "path/filepath" + "strings" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + workerstate "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker/state" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +const ( + shmBasePath = constants.TFDataPath + constants.SharedMemMountSubPath +) + +// ShmDialogModel represents the shared memory detail dialog +type ShmDialogModel struct { + viewport viewport.Model + content string + width int + height int + isVisible bool + workerInfo *WorkerInfo +} + +// NewShmDialogModel creates a new SHM dialog model +func NewShmDialogModel() *ShmDialogModel { + return &ShmDialogModel{ + viewport: viewport.New(0, 0), + isVisible: false, + } +} + +// Init initializes the dialog +func (m *ShmDialogModel) Init() tea.Cmd { + return nil +} + +// Update updates the dialog +func (m *ShmDialogModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + if !m.isVisible { + return m, nil + } + + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "esc", "q": + m.isVisible = false + return m, nil + } + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.resize() + return m, nil + } + + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + return m, cmd +} + +// View renders the dialog +func (m *ShmDialogModel) View() string { + if !m.isVisible { + return "" + } + + // Calculate dialog dimensions (80% of screen, centered) + dialogWidth := int(float64(m.width) * 0.8) + dialogHeight := int(float64(m.height) * 0.8) + + if dialogWidth < 40 { + dialogWidth = 40 + } + if dialogHeight < 10 { + dialogHeight = 10 + } + + // Create dialog box + box := BorderStyle. + Width(dialogWidth). + Height(dialogHeight). 
+ Render(m.viewport.View()) + + // Center the dialog + return lipgloss.Place( + m.width, + m.height, + lipgloss.Center, + lipgloss.Center, + box, + ) +} + +// Show displays the dialog with SHM details for the given worker +func (m *ShmDialogModel) Show(workerInfo *WorkerInfo) { + m.workerInfo = workerInfo + m.isVisible = true + m.resize() + m.updateContent() +} + +// Hide hides the dialog +func (m *ShmDialogModel) Hide() { + m.isVisible = false +} + +// IsVisible returns whether the dialog is visible +func (m *ShmDialogModel) IsVisible() bool { + return m.isVisible +} + +// resize resizes the dialog viewport +func (m *ShmDialogModel) resize() { + if !m.isVisible { + return + } + + dialogWidth := int(float64(m.width) * 0.8) + dialogHeight := int(float64(m.height) * 0.8) + + if dialogWidth < 40 { + dialogWidth = 40 + } + if dialogHeight < 10 { + dialogHeight = 10 + } + + // Account for border + m.viewport.Width = dialogWidth - 2 + m.viewport.Height = dialogHeight - 2 +} + +// updateContent updates the dialog content with SHM details +func (m *ShmDialogModel) updateContent() { + if m.workerInfo == nil { + m.content = "No worker information available" + m.viewport.SetContent(m.content) + return + } + + var content strings.Builder + + // Title + content.WriteString(TitleStyle.Render("Shared Memory Details\n\n")) + + // Construct pod identifier and path + podIdentifier := workerstate.NewPodIdentifier(m.workerInfo.Namespace, m.workerInfo.PodName) + podPath := podIdentifier.ToPath(shmBasePath) + shmPath := filepath.Join(podPath, workerstate.ShmPathSuffix) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Pod"), MetricValueStyle.Render(podIdentifier.String()))) + content.WriteString(fmt.Sprintf("%s: %s\n\n", MetricLabelStyle.Render("SHM Path"), MetricValueStyle.Render(shmPath))) + + // Try to open the shared memory handle + handle, err := workerstate.OpenSharedMemoryHandle(podPath) + if err != nil { + content.WriteString(fmt.Sprintf("%s: %s\n\n", 
MetricLabelStyle.Render("Error"), MetricValueStyle.Render(err.Error()))) + m.content = content.String() + m.viewport.SetContent(m.content) + return + } + defer func() { + _ = handle.Close() + }() + + // Get the state + state := handle.GetState() + if state == nil { + content.WriteString(fmt.Sprintf("%s: %s\n\n", MetricLabelStyle.Render("Error"), MetricValueStyle.Render("Shared memory state is null"))) + m.content = content.String() + m.viewport.SetContent(m.content) + return + } + + // Basic information + deviceCount := state.DeviceCount() + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Device Count"), deviceCount)) + + lastHeartbeat := state.GetLastHeartbeat() + heartbeatTime := time.Unix(int64(lastHeartbeat), 0) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Last Heartbeat"), heartbeatTime.Format(time.RFC3339))) + + // Health check (2 seconds timeout) + isHealthy := state.IsHealthy(2 * time.Second) + healthStatus := "Healthy" + if !isHealthy { + healthStatus = "Unhealthy" + } + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Health Status"), MetricValueStyle.Render(healthStatus))) + + // Version information + version := state.Version() + content.WriteString(fmt.Sprintf("%s: v%d\n\n", MetricLabelStyle.Render("State Version"), version)) + + // Device details based on version + if version == 1 && state.V1 != nil { + // V1 format + for i := 0; i < deviceCount; i++ { + if !state.V1.HasDevice(i) { + continue + } + + device := &state.V1.Devices[i] + if !device.IsActive() { + continue + } + + uuid := device.GetUUID() + availableCores := device.DeviceInfo.AvailableCudaCores + totalCores := device.DeviceInfo.TotalCudaCores + memLimit := device.DeviceInfo.MemLimit + podMemoryUsed := device.DeviceInfo.PodMemoryUsed + upLimit := device.DeviceInfo.UpLimit + + content.WriteString(fmt.Sprintf("Device %d:\n", i)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("UUID"), 
MetricValueStyle.Render(uuid))) + content.WriteString(fmt.Sprintf(" %s: %d / %d\n", MetricLabelStyle.Render("Cores"), availableCores, totalCores)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Limit"), formatBytes(memLimit))) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Used"), formatBytes(podMemoryUsed))) + content.WriteString(fmt.Sprintf(" %s: %d%%\n\n", MetricLabelStyle.Render("Up Limit"), upLimit)) + } + } else if version == 2 && state.V2 != nil { + // V2 format with ERL + for i := 0; i < deviceCount; i++ { + if !state.V2.HasDevice(i) { + continue + } + + device := &state.V2.Devices[i] + if !device.IsActive() { + continue + } + + uuid := device.GetUUID() + totalCores := device.DeviceInfo.TotalCudaCores + memLimit := device.DeviceInfo.MemLimit + podMemoryUsed := device.DeviceInfo.PodMemoryUsed + upLimit := device.DeviceInfo.UpLimit + + // ERL information + erlCurrentTokens := device.DeviceInfo.GetERLCurrentTokens() + erlTokenCapacity := device.DeviceInfo.GetERLTokenCapacity() + erlTokenRefillRate := device.DeviceInfo.GetERLTokenRefillRate() + erlLastTokenUpdate := device.DeviceInfo.GetERLLastTokenUpdate() + + content.WriteString(fmt.Sprintf("Device %d:\n", i)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("UUID"), MetricValueStyle.Render(uuid))) + content.WriteString(fmt.Sprintf(" %s: %d\n", MetricLabelStyle.Render("Total Cores"), totalCores)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Limit"), formatBytes(memLimit))) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Used"), formatBytes(podMemoryUsed))) + content.WriteString(fmt.Sprintf(" %s: %d%%\n", MetricLabelStyle.Render("Up Limit"), upLimit)) + content.WriteString(fmt.Sprintf(" %s: %.1f / %.1f (rate: %.1f/s, updated: %.0fµs)\n\n", + MetricLabelStyle.Render("ERL Tokens"), + erlCurrentTokens, + erlTokenCapacity, + erlTokenRefillRate, + erlLastTokenUpdate)) + } + } 
else { + content.WriteString(fmt.Sprintf("Unknown shared memory version: %d\n\n", version)) + } + + // Additional state information + pids := state.GetAllPIDs() + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Active PIDs Count"), len(pids))) + if len(pids) > 0 { + pidStrs := make([]string, len(pids)) + for i, pid := range pids { + pidStrs[i] = fmt.Sprintf("%d", pid) + } + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Active PIDs"), strings.Join(pidStrs, ", "))) + } + + m.content = content.String() + m.viewport.SetContent(m.content) + m.viewport.GotoTop() +} diff --git a/internal/hypervisor/tui/styles.go b/internal/hypervisor/tui/styles.go new file mode 100644 index 00000000..6fb4c01d --- /dev/null +++ b/internal/hypervisor/tui/styles.go @@ -0,0 +1,33 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "github.com/charmbracelet/lipgloss" +) + +var ( + TitleStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("63")) + SubtitleStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("241")) + BorderStyle = lipgloss.NewStyle().Border(lipgloss.RoundedBorder()).BorderForeground(lipgloss.Color("62")) + SelectedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("212")).Bold(true) + NormalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("250")) + MetricLabelStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("243")).Width(20) + MetricValueStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Bold(true) + ChartBarStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("46")) + ChartEmptyStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("238")) +) diff --git a/internal/hypervisor/tui/utils.go b/internal/hypervisor/tui/utils.go new file mode 100644 index 00000000..dc8722e0 --- /dev/null +++ b/internal/hypervisor/tui/utils.go @@ -0,0 +1,57 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "fmt" + "strings" +) + +// formatBytes formats bytes into human-readable format +func formatBytes(bytes uint64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// renderBarChart renders a bar chart for a percentage value +// This is a simple wrapper that calls the chart implementation +func renderBarChart(percentage float64, width int) string { + if percentage > 100 { + percentage = 100 + } + if percentage < 0 { + percentage = 0 + } + + filled := int(percentage / 100.0 * float64(width)) + empty := width - filled + + var bar strings.Builder + bar.WriteString(ChartBarStyle.Render(strings.Repeat("█", filled))) + bar.WriteString(ChartEmptyStyle.Render(strings.Repeat("░", empty))) + bar.WriteString(fmt.Sprintf(" %.1f%%", percentage)) + + return bar.String() +} diff --git a/internal/hypervisor/tui/worker_view.go b/internal/hypervisor/tui/worker_view.go new file mode 100644 index 00000000..3ac363d0 --- /dev/null +++ b/internal/hypervisor/tui/worker_view.go @@ -0,0 +1,148 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "fmt" + "strings" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" +) + +// WorkerInfo represents worker information +type WorkerInfo struct { + UID string + PodName string + Namespace string + DeviceUUID string + Allocation *api.WorkerAllocation +} + +// workerItem represents a worker in the list +type workerItem struct { + uid string + podName string + namespace string +} + +func (w workerItem) FilterValue() string { + return fmt.Sprintf("%s %s %s", w.uid, w.podName, w.namespace) +} + +func (w workerItem) Title() string { + return fmt.Sprintf("%s/%s", w.namespace, w.podName) +} + +func (w workerItem) Description() string { + return w.uid +} + +func newWorkerDelegate() list.DefaultDelegate { + d := list.NewDefaultDelegate() + d.Styles.SelectedTitle = SelectedStyle + d.Styles.SelectedDesc = SelectedStyle + d.Styles.NormalTitle = NormalStyle + d.Styles.NormalDesc = NormalStyle + return d +} + +// updateWorkerList updates the worker list with current workers +func updateWorkerList(workerList *list.Model, workers []WorkerInfo) { + workerItems := make([]list.Item, len(workers)) + for i, worker := range workers { + workerItems[i] = workerItem{ + uid: worker.UID, + podName: worker.PodName, + namespace: worker.Namespace, + } + } + workerList.SetItems(workerItems) +} + +// updateWorkerDetail updates the worker detail viewport +func updateWorkerDetail( + workerDetail *viewport.Model, + selectedWorkerUID string, + workers []WorkerInfo, + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics, + workerMetricsHistory map[string]*WorkerMetricsHistory, +) { + var worker *WorkerInfo + for _, w := range workers { + if w.UID == selectedWorkerUID { + worker = &w + break + } + } + if worker == nil { + workerDetail.SetContent("Worker not found") + return + } + + var content strings.Builder + content.WriteString(TitleStyle.Render("Worker 
Details\n\n")) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Worker UID"), MetricValueStyle.Render(worker.UID))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Pod Name"), MetricValueStyle.Render(worker.PodName))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Namespace"), MetricValueStyle.Render(worker.Namespace))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Device UUID"), MetricValueStyle.Render(worker.DeviceUUID))) + + if worker.Allocation != nil && worker.Allocation.WorkerInfo != nil { + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Isolation Mode"), MetricValueStyle.Render(string(worker.Allocation.WorkerInfo.IsolationMode)))) + if worker.Allocation.WorkerInfo.MemoryLimitBytes > 0 { + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Limit"), formatBytes(worker.Allocation.WorkerInfo.MemoryLimitBytes))) + } + if worker.Allocation.WorkerInfo.ComputeLimitUnits > 0 { + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Compute Limit Units"), worker.Allocation.WorkerInfo.ComputeLimitUnits)) + } + // Note: AllocatedAt timestamp will be added to WorkerInfo if needed for business logic + content.WriteString("\n") + } + + // Get worker metrics + if deviceWorkerMetrics, exists := workerMetrics[worker.DeviceUUID]; exists { + if wm, exists := deviceWorkerMetrics[worker.UID]; exists { + content.WriteString(TitleStyle.Render("Current Metrics\n\n")) + var totalMemory uint64 + var totalCompute float64 + var totalTflops float64 + + for _, metrics := range wm { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + totalTflops += metrics.ComputeTflops + } + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Used"), formatBytes(totalMemory))) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Compute Usage"), totalCompute)) + 
content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n\n", MetricLabelStyle.Render("Compute TFLOPS"), totalTflops)) + + // Time-series charts + if history, exists := workerMetricsHistory[selectedWorkerUID]; exists && history != nil { + content.WriteString("\n") + content.WriteString(history.MemoryChart.Render()) + content.WriteString("\n") + content.WriteString(history.ComputeChart.Render()) + content.WriteString("\n") + } + } + } + + workerDetail.SetContent(content.String()) +} diff --git a/internal/hypervisor/worker/computing/erl.go b/internal/hypervisor/worker/computing/erl.go new file mode 100644 index 00000000..e882c738 --- /dev/null +++ b/internal/hypervisor/worker/computing/erl.go @@ -0,0 +1,352 @@ +package computing + +import ( + "errors" + "fmt" + "math" +) + +var ( + ErrInvalidConfig = errors.New("invalid configuration") +) + +// DeviceBackend defines the interface for device token/quota operations +type DeviceBackend interface { + ReadTokenState(device int) (*TokenState, error) + WriteTokenState(device int, state *TokenState) error + ReadQuota(device int) (*DeviceQuota, error) + WriteRefillRate(device int, refillRate float64) error + WriteCapacity(device int, capacity float64) error + FetchSubTokens(device int, cost float64) (float64, error) + FetchAddTokens(device int, amount float64) (float64, error) +} + +// TokenState represents the current token bucket state +type TokenState struct { + Tokens float64 + LastUpdate float64 +} + +// DeviceQuota represents device quota configuration +type DeviceQuota struct { + Capacity float64 + RefillRate float64 +} + +// DeviceControllerConfig holds configuration for the PID-based device controller +type DeviceControllerConfig struct { + // Target GPU utilization (0.0 to 1.0, e.g., 0.5 = 50%) + TargetUtilization float64 + + // Minimum refill rate (tokens/second) - prevents rate from dropping to zero + RateMin float64 + + // Maximum refill rate (tokens/second) + RateMax float64 + + // PID proportional gain - how 
aggressively to respond to error + Kp float64 + + // PID integral gain - how quickly to eliminate steady-state error + Ki float64 + + // PID derivative gain - how much to dampen oscillations + Kd float64 + + // Low-pass filter coefficient for smoothing utilization (0.0 to 1.0) + // Higher values = less filtering (more responsive, more noise) + FilterAlpha float64 + + // Burst window in seconds - capacity = refill_rate × burst_window + BurstWindow float64 + + // Minimum capacity (tokens) + CapacityMin float64 + + // Maximum capacity (tokens) - prevents unbounded growth + CapacityMax float64 + + // Minimum time between updates (seconds) + MinDeltaTime float64 + + // Integral decay factor (0.0 to 1.0) for exponential decay of integral term + // Higher values (closer to 1.0) = slower decay, retains more history + // Lower values = faster decay, responds more quickly to changes + // Default 0.95 means ~20 update cycles for integral to decay to ~35.8% of original value + IntegralDecayFactor float64 +} + +// DefaultDeviceControllerConfig returns a default configuration +func DefaultDeviceControllerConfig() DeviceControllerConfig { + return DeviceControllerConfig{ + TargetUtilization: 0.5, + RateMin: 10.0, + RateMax: 100_000.0, + Kp: 0.5, + Ki: 0.1, + Kd: 0.05, + FilterAlpha: 0.3, + BurstWindow: 2.0, + CapacityMin: 100.0, + CapacityMax: 200_000.0, + MinDeltaTime: 0.05, + IntegralDecayFactor: 0.95, + } +} + +// DeviceControllerState is a snapshot of controller state after an update +type DeviceControllerState struct { + TargetUtilization float64 + SmoothedUtilization float64 + CurrentRate float64 + CurrentCapacity float64 + TokenDrainRate float64 +} + +// DeviceController is a PID-based controller that dynamically adjusts token refill rates +type DeviceController struct { + backend DeviceBackend + device int + cfg DeviceControllerConfig + + // PID state + integral float64 + lastError float64 + + // Filtering state + smoothedUtil *float64 + + // Rate tracking + currentRate 
float64 + + // Drain rate estimation + lastTokenLevel float64 + lastTimestamp *float64 +} + +// NewDeviceController creates a new device controller +func NewDeviceController(backend DeviceBackend, device int, cfg DeviceControllerConfig) (*DeviceController, error) { + // Validate configuration + if cfg.TargetUtilization < 0.0 || cfg.TargetUtilization > 1.0 { + return nil, fmt.Errorf("%w: target_utilization must be in [0, 1]", ErrInvalidConfig) + } + if cfg.RateMin <= 0.0 || cfg.RateMax <= cfg.RateMin { + return nil, fmt.Errorf("%w: rate_max must be greater than rate_min > 0", ErrInvalidConfig) + } + if cfg.FilterAlpha < 0.0 || cfg.FilterAlpha > 1.0 { + return nil, fmt.Errorf("%w: filter_alpha must be in [0, 1]", ErrInvalidConfig) + } + if cfg.IntegralDecayFactor < 0.0 || cfg.IntegralDecayFactor > 1.0 { + return nil, fmt.Errorf("%w: integral_decay_factor must be in [0, 1]", ErrInvalidConfig) + } + + // Initialize with a conservative starting rate + startRate := math.Min(100.0, cfg.RateMax) + startRate = math.Max(startRate, cfg.RateMin) + initialCapacity := math.Max(cfg.CapacityMin, math.Min(cfg.CapacityMax, startRate*cfg.BurstWindow)) + + // Initialize backend + if err := backend.WriteCapacity(device, initialCapacity); err != nil { + return nil, err + } + if err := backend.WriteRefillRate(device, startRate); err != nil { + return nil, err + } + + tokenState, err := backend.ReadTokenState(device) + if err != nil { + return nil, err + } + tokenState.Tokens = initialCapacity + if err := backend.WriteTokenState(device, tokenState); err != nil { + return nil, err + } + + return &DeviceController{ + backend: backend, + device: device, + cfg: cfg, + integral: 0.0, + lastError: 0.0, + smoothedUtil: nil, + currentRate: startRate, + lastTokenLevel: initialCapacity, + lastTimestamp: nil, + }, nil +} + +// State returns the current controller state +func (dc *DeviceController) State() DeviceControllerState { + capacity := math.Max(dc.cfg.CapacityMin, math.Min(dc.cfg.CapacityMax, 
dc.currentRate*dc.cfg.BurstWindow)) + smoothedUtil := 0.0 + if dc.smoothedUtil != nil { + smoothedUtil = *dc.smoothedUtil + } + return DeviceControllerState{ + TargetUtilization: dc.cfg.TargetUtilization, + SmoothedUtilization: smoothedUtil, + CurrentRate: dc.currentRate, + CurrentCapacity: capacity, + TokenDrainRate: 0.0, // Will be updated during next cycle + } +} + +// Update updates controller with new utilization measurement and explicit delta time +func (dc *DeviceController) Update(utilization float64, deltaTime float64) (*DeviceControllerState, error) { + if deltaTime < dc.cfg.MinDeltaTime { + state := dc.State() + return &state, nil + } + return dc.updateInternal(utilization, deltaTime) +} + +// UpdateWithTimestamp updates controller with timestamp (calculates delta automatically) +func (dc *DeviceController) UpdateWithTimestamp(utilization float64, timestampMicros uint64) (*DeviceControllerState, error) { + seconds := float64(timestampMicros) / 1_000_000.0 + var delta float64 + if dc.lastTimestamp != nil { + rawDelta := seconds - *dc.lastTimestamp + if rawDelta < dc.cfg.MinDeltaTime { + state := dc.State() + return &state, nil + } + delta = rawDelta + } else { + delta = dc.cfg.MinDeltaTime + } + dc.lastTimestamp = &seconds + return dc.updateInternal(utilization, delta) +} + +// updateInternal performs the core update logic +func (dc *DeviceController) updateInternal(measuredUtil float64, deltaTime float64) (*DeviceControllerState, error) { + // Clamp measured utilization + measured := math.Max(0.0, math.Min(1.0, measuredUtil)) + + // Step 1: Low-pass filter to smooth NVML noise + smoothed := dc.smoothUtilization(measured) + + // Step 2: Estimate token drain rate + drainRate, err := dc.estimateDrainRate(deltaTime) + if err != nil { + return nil, err + } + + // Step 3: Calculate base rate from drain rate and target + baseRate := dc.calculateBaseRate(smoothed, drainRate) + + // Step 4: Compute PID correction + error := dc.cfg.TargetUtilization - smoothed + 
correction := dc.computePIDCorrection(error, deltaTime) + + // Step 5: Apply correction to base rate + newRate := math.Max(dc.cfg.RateMin, math.Min(dc.cfg.RateMax, baseRate*(1.0+correction))) + dc.currentRate = newRate + + // Step 6: Calculate capacity (bounded) + newCapacity := math.Max(dc.cfg.CapacityMin, math.Min(dc.cfg.CapacityMax, newRate*dc.cfg.BurstWindow)) + + // Step 7: Refill tokens + refillAmount := newRate * deltaTime + if _, err := dc.backend.FetchAddTokens(dc.device, refillAmount); err != nil { + return nil, err + } + + // Step 8: Update backend (capacity must be updated before clamping) + if err := dc.backend.WriteRefillRate(dc.device, newRate); err != nil { + return nil, err + } + if err := dc.backend.WriteCapacity(dc.device, newCapacity); err != nil { + return nil, err + } + + // Step 9: Clamp tokens to capacity (after capacity update, tokens may exceed new capacity) + // Optimization: only read and write if clamping is needed + state, err := dc.backend.ReadTokenState(dc.device) + if err != nil { + return nil, err + } + if state.Tokens > newCapacity { + state.Tokens = newCapacity + if err := dc.backend.WriteTokenState(dc.device, state); err != nil { + return nil, err + } + } + + return &DeviceControllerState{ + TargetUtilization: dc.cfg.TargetUtilization, + SmoothedUtilization: smoothed, + CurrentRate: newRate, + CurrentCapacity: newCapacity, + TokenDrainRate: drainRate, + }, nil +} + +// smoothUtilization applies exponential moving average to smooth utilization measurements +func (dc *DeviceController) smoothUtilization(measured float64) float64 { + alpha := dc.cfg.FilterAlpha + var smoothed float64 + if dc.smoothedUtil != nil { + smoothed = alpha*measured + (1.0-alpha)**dc.smoothedUtil + } else { + smoothed = measured + } + dc.smoothedUtil = &smoothed + return smoothed +} + +// estimateDrainRate estimates token drain rate from bucket level changes +func (dc *DeviceController) estimateDrainRate(deltaTime float64) (float64, error) { + currentState, 
err := dc.backend.ReadTokenState(dc.device) + if err != nil { + return 0, err + } + currentTokens := currentState.Tokens + + // Expected tokens = last level + refill during delta_time + expectedTokens := dc.lastTokenLevel + dc.currentRate*deltaTime + + // Actual drain = expected - actual + drainRate := math.Max(0.0, (expectedTokens-currentTokens)/deltaTime) + + dc.lastTokenLevel = currentTokens + return drainRate, nil +} + +// calculateBaseRate calculates base refill rate from current utilization and drain rate +// The idea: if we're at `actual_util` with `drain_rate`, then to reach +// `target_util` we need: `base_rate = drain_rate × (target / actual)` +func (dc *DeviceController) calculateBaseRate(smoothedUtil float64, drainRate float64) float64 { + if smoothedUtil > 0.01 { + // Theoretical base rate to reach target + theoretical := drainRate * (dc.cfg.TargetUtilization / smoothedUtil) + return math.Max(dc.cfg.RateMin, math.Min(dc.cfg.RateMax, theoretical)) + } + // Very low utilization - maintain current rate or use minimum + return math.Max(dc.currentRate, dc.cfg.RateMin) +} + +// computePIDCorrection computes PID correction term +// Returns a correction factor in the range [-0.5, 0.5] to apply to base_rate +func (dc *DeviceController) computePIDCorrection(error float64, deltaTime float64) float64 { + // Proportional term + p := dc.cfg.Kp * error + + // Integral term with exponential decay and anti-windup + // Apply decay factor to forget old errors gradually + dc.integral *= dc.cfg.IntegralDecayFactor + // Add new error contribution + dc.integral += error * deltaTime + // Clamp to prevent windup + dc.integral = math.Max(-1.0, math.Min(1.0, dc.integral)) + i := dc.cfg.Ki * dc.integral + + // Derivative term + derivative := (error - dc.lastError) / deltaTime + d := dc.cfg.Kd * derivative + + dc.lastError = error + + // Total correction, clamped to avoid over-reaction + return math.Max(-0.5, math.Min(0.5, p+i+d)) +} diff --git 
a/internal/hypervisor/worker/computing/erl_test.go b/internal/hypervisor/worker/computing/erl_test.go new file mode 100644 index 00000000..bb7e5978 --- /dev/null +++ b/internal/hypervisor/worker/computing/erl_test.go @@ -0,0 +1,335 @@ +package computing + +import ( + "math" + "sync" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestERL(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "ERL Controller Suite") +} + +var _ = Describe("DeviceController", func() { + var ( + backend *MockBackend + device int + cfg DeviceControllerConfig + ) + + BeforeEach(func() { + device = 0 + cfg = DefaultDeviceControllerConfig() + cfg.RateMax = 50000.0 + cfg.CapacityMax = 100_000.0 + }) + + Describe("Initialization", func() { + It("should initialize correctly with valid config", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.TargetUtilization = 0.7 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + Expect(ctrl).NotTo(BeNil()) + Expect(ctrl.cfg.TargetUtilization).To(Equal(0.7)) + Expect(ctrl.currentRate).To(BeNumerically(">=", ctrl.cfg.RateMin)) + Expect(ctrl.currentRate).To(BeNumerically("<=", ctrl.cfg.RateMax)) + }) + + It("should reject invalid target_utilization", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.TargetUtilization = 1.5 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("target_utilization must be in [0, 1]"))) + }) + + It("should reject invalid rate_min/rate_max", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.RateMin = 100.0 + cfg.RateMax = 50.0 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("rate_max must be greater than rate_min"))) + }) + + It("should reject invalid filter_alpha", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.FilterAlpha = 1.5 + + _, err := 
NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("filter_alpha must be in [0, 1]"))) + }) + + It("should reject invalid integral_decay_factor", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.IntegralDecayFactor = 1.5 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("integral_decay_factor must be in [0, 1]"))) + }) + }) + + Describe("Rate Adjustment", func() { + It("should increase rate when utilization is below target", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.7 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + rateBefore := ctrl.currentRate + + // Utilization 20% when target is 70% -> should increase rate + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateAfter := ctrl.currentRate + Expect(rateAfter).To(BeNumerically(">", rateBefore), "Rate should increase when utilization is below target") + }) + + It("should decrease rate when utilization is above target", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // First establish a higher rate + _, err = ctrl.Update(0.3, 0.1) + Expect(err).NotTo(HaveOccurred()) + _, err = ctrl.Update(0.3, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateBefore := ctrl.currentRate + + // Now push utilization above target + _, err = ctrl.Update(0.95, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateAfter := ctrl.currentRate + Expect(rateAfter).To(BeNumerically("<", rateBefore), "Rate should decrease when utilization is above target") + }) + + It("should respect rate limits", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + cfg.RateMin = 50.0 + cfg.RateMax = 500.0 + cfg.CapacityMax 
= 1000.0 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Try to push rate very low + for i := 0; i < 10; i++ { + _, err = ctrl.Update(0.99, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + Expect(ctrl.currentRate).To(BeNumerically(">=", 50.0), "Rate should not go below rate_min") + + // Try to push rate very high + for i := 0; i < 10; i++ { + _, err = ctrl.Update(0.01, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + Expect(ctrl.currentRate).To(BeNumerically("<=", 500.0), "Rate should not exceed rate_max") + }) + }) + + Describe("Utilization Smoothing", func() { + It("should smooth utilization measurements", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + cfg.FilterAlpha = 0.3 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Feed alternating utilization values + _, err = ctrl.Update(0.8, 0.1) + Expect(err).NotTo(HaveOccurred()) + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + + state := ctrl.State() + // Smoothed value should be between the extremes + Expect(state.SmoothedUtilization).To(BeNumerically(">", 0.2)) + Expect(state.SmoothedUtilization).To(BeNumerically("<", 0.8)) + }) + }) + + Describe("Edge Cases", func() { + It("should handle zero utilization", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Feed zero utilization repeatedly + for i := 0; i < 5; i++ { + _, err = ctrl.Update(0.0, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + + // Rate should still be above minimum + Expect(ctrl.currentRate).To(BeNumerically(">=", ctrl.cfg.RateMin), "Rate should never drop below rate_min") + }) + + It("should handle very small delta_time", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := 
NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + rateBefore := ctrl.currentRate + + // Update with delta_time smaller than min_delta_time + _, err = ctrl.Update(0.3, 0.001) + Expect(err).NotTo(HaveOccurred()) + + // Rate should not change + Expect(ctrl.currentRate).To(Equal(rateBefore)) + }) + }) + + Describe("Capacity Scaling", func() { + It("should scale capacity with rate", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + state1 := ctrl.State() + + // Continue to increase rate + for i := 0; i < 5; i++ { + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + + state2 := ctrl.State() + if state2.CurrentRate > state1.CurrentRate { + Expect(state2.CurrentCapacity).To(BeNumerically(">=", state1.CurrentCapacity), "Capacity should scale with rate") + } + }) + }) + + Describe("Timestamp-based Updates", func() { + It("should handle timestamp-based updates", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Update with timestamps (in microseconds) + t1 := uint64(1_000_000) // 1 second + t2 := uint64(1_200_000) // 1.2 seconds (0.2s delta) + + _, err = ctrl.UpdateWithTimestamp(0.3, t1) + Expect(err).NotTo(HaveOccurred()) + + _, err = ctrl.UpdateWithTimestamp(0.4, t2) + Expect(err).NotTo(HaveOccurred()) + }) + }) +}) + +// MockBackend is a mock implementation of DeviceBackend for testing +type MockBackend struct { + mu sync.RWMutex + quotaCapacity float64 + quotaRefillRate float64 + tokens float64 + lastUpdate float64 +} + +func NewMockBackend(capacity, refillRate, tokens float64) *MockBackend { + return &MockBackend{ + quotaCapacity: capacity, + quotaRefillRate: refillRate, + 
tokens: tokens, + lastUpdate: 0, + } +} + +func (m *MockBackend) ReadTokenState(device int) (*TokenState, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return &TokenState{ + Tokens: m.tokens, + LastUpdate: m.lastUpdate, + }, nil +} + +func (m *MockBackend) WriteTokenState(device int, state *TokenState) error { + m.mu.Lock() + defer m.mu.Unlock() + m.tokens = state.Tokens + m.lastUpdate = state.LastUpdate + return nil +} + +func (m *MockBackend) ReadQuota(device int) (*DeviceQuota, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return &DeviceQuota{ + Capacity: m.quotaCapacity, + RefillRate: m.quotaRefillRate, + }, nil +} + +func (m *MockBackend) WriteRefillRate(device int, refillRate float64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.quotaRefillRate = refillRate + return nil +} + +func (m *MockBackend) WriteCapacity(device int, capacity float64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.quotaCapacity = capacity + return nil +} + +func (m *MockBackend) FetchSubTokens(device int, cost float64) (float64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + current := m.tokens + if current < cost { + return current, nil + } + + capacity := m.quotaCapacity + newTokens := math.Max(0.0, math.Min(capacity, current-cost)) + m.tokens = newTokens + return current, nil +} + +func (m *MockBackend) FetchAddTokens(device int, amount float64) (float64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + current := m.tokens + capacity := m.quotaCapacity + newTokens := math.Max(0.0, math.Min(capacity, current+amount)) + m.tokens = newTokens + return current, nil +} diff --git a/internal/hypervisor/worker/computing/qos.go b/internal/hypervisor/worker/computing/qos.go new file mode 100644 index 00000000..0bfc86b9 --- /dev/null +++ b/internal/hypervisor/worker/computing/qos.go @@ -0,0 +1,3 @@ +package computing + +// diff --git a/internal/hypervisor/worker/computing/quota_controller.go b/internal/hypervisor/worker/computing/quota_controller.go new file mode 100644 index 
00000000..91bb9330 --- /dev/null +++ b/internal/hypervisor/worker/computing/quota_controller.go @@ -0,0 +1,72 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package computing + +import ( + "sync" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "k8s.io/klog/v2" +) + +type Controller struct { + deviceController framework.DeviceController + mu sync.RWMutex + running bool + stopCh chan struct{} +} + +func NewQuotaController(deviceController framework.DeviceController) framework.QuotaController { + return &Controller{ + deviceController: deviceController, + stopCh: make(chan struct{}), + } +} + +func (c *Controller) SetQuota(workerUID string) error { + // TODO: Implement quota setting + return nil +} + +func (c *Controller) StartSoftQuotaLimiter() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.running { + return nil + } + c.running = true + // TODO: Start soft quota limiter thread + klog.Info("Soft quota limiter started") + return nil +} + +func (c *Controller) StopSoftQuotaLimiter() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.running { + return nil + } + close(c.stopCh) + c.running = false + klog.Info("Soft quota limiter stopped") + return nil +} + +func (c *Controller) GetWorkerQuotaStatus(workerUID string) error { + // TODO: Implement quota status retrieval + return nil +} diff --git a/internal/hypervisor/worker/controller.go b/internal/hypervisor/worker/controller.go new file mode 100644 index 00000000..09f01ec4 --- 
/dev/null +++ b/internal/hypervisor/worker/controller.go @@ -0,0 +1,284 @@ +package worker + +import ( + "fmt" + "sync" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker/computing" + "k8s.io/klog/v2" +) + +type WorkerController struct { + mode api.IsolationMode + backend framework.Backend + + deviceController framework.DeviceController + quotaController framework.QuotaController + + mu sync.RWMutex + workers map[string]*api.WorkerInfo + workerWatchStop chan struct{} + workerWatchStopOnce sync.Once +} + +func NewWorkerController( + deviceController framework.DeviceController, mode api.IsolationMode, backend framework.Backend) framework.WorkerController { + quotaController := computing.NewQuotaController(deviceController) + return &WorkerController{ + deviceController: deviceController, + mode: mode, + backend: backend, + quotaController: quotaController, + workers: make(map[string]*api.WorkerInfo, 16), + workerWatchStop: make(chan struct{}), + } +} + +func (w *WorkerController) Start() error { + err := w.backend.Start() + if err != nil { + return err + } + klog.Info("Worker backend started") + + // Start watching workers from backend + workerCh, stopCh, err := w.backend.ListAndWatchWorkers() + if err != nil { + return err + } + + // Start worker watcher goroutine + go func() { + for { + select { + case <-w.workerWatchStop: + return + case <-stopCh: + return + case workers, ok := <-workerCh: + if !ok { + return + } + // Update worker cache + w.mu.Lock() + for _, worker := range workers { + w.workers[worker.WorkerUID] = worker + } + w.mu.Unlock() + klog.V(4).Infof("Updated worker list: %d workers", len(workers)) + } + } + }() + + // Start soft quota limiter + if err := w.quotaController.StartSoftQuotaLimiter(); err != nil { + klog.Fatalf("Failed to start soft quota limiter: %v", err) + } + klog.Info("Soft quota limiter 
started") + + return nil +} + +func (w *WorkerController) Stop() error { + w.workerWatchStopOnce.Do(func() { + close(w.workerWatchStop) + }) + _ = w.backend.Stop() + _ = w.quotaController.StopSoftQuotaLimiter() + return nil +} + +// AllocateWorker implements framework.WorkerController +func (w *WorkerController) AllocateWorker(request *api.WorkerInfo) (*api.WorkerAllocation, error) { + // Validate devices exist + devices, err := w.deviceController.ListDevices() + if err != nil { + return nil, fmt.Errorf("failed to list devices: %w", err) + } + + deviceMap := make(map[string]*api.DeviceInfo) + for _, device := range devices { + deviceMap[device.UUID] = device + } + + for _, deviceUUID := range request.AllocatedDevices { + if _, exists := deviceMap[deviceUUID]; !exists { + return nil, fmt.Errorf("device not found: %s", deviceUUID) + } + } + + // Store allocation (this logic would ideally be in device controller's state management) + // For now, we'll create the allocation and let device controller track it + + // Create WorkerAllocation with WorkerInfo and DeviceInfos + deviceInfos := make([]*api.DeviceInfo, 0, len(request.AllocatedDevices)) + for _, deviceUUID := range request.AllocatedDevices { + if device, exists := deviceMap[deviceUUID]; exists { + deviceInfos = append(deviceInfos, device) + } + } + + allocation := &api.WorkerAllocation{ + WorkerInfo: request, + DeviceInfos: deviceInfos, + } + + return allocation, nil +} + +func (w *WorkerController) GetWorkerAllocation(workerUID string) (*api.WorkerAllocation, error) { + allocations, err := w.deviceController.GetDeviceAllocations("") + if err != nil { + return nil, err + } + // Find allocation for this worker + for _, allocation := range allocations { + if allocation.WorkerInfo.PodUID == workerUID || allocation.WorkerInfo.WorkerUID == workerUID { + return allocation, nil + } + } + return nil, nil +} + +func (w *WorkerController) GetWorkerMetricsUpdates() (<-chan *api.WorkerAllocation, error) { + ch := make(chan 
*api.WorkerAllocation, 1) + // TODO: Implement proper worker metrics updates channel with periodic updates + // Channel will be closed when controller is stopped + return ch, nil +} + +func (w *WorkerController) GetWorkerMetrics() (map[string]map[string]map[string]*api.WorkerMetrics, error) { + // Get all allocations to know which workers exist + allocations, err := w.deviceController.GetDeviceAllocations("") + if err != nil { + return nil, err + } + + // Get process compute and memory utilization from device controller + // Try to cast to concrete type to access accelerator methods + type acceleratorExposer interface { + GetProcessComputeUtilization() ([]api.ComputeUtilization, error) + GetProcessMemoryUtilization() ([]api.MemoryUtilization, error) + } + + var computeUtils []api.ComputeUtilization + var memUtils []api.MemoryUtilization + + if exposer, ok := w.deviceController.(acceleratorExposer); ok { + var err error + computeUtils, err = exposer.GetProcessComputeUtilization() + if err != nil { + computeUtils = []api.ComputeUtilization{} + } + memUtils, err = exposer.GetProcessMemoryUtilization() + if err != nil { + memUtils = []api.MemoryUtilization{} + } + } else { + // Fallback to empty metrics if interface not available + computeUtils = []api.ComputeUtilization{} + memUtils = []api.MemoryUtilization{} + } + + // Build worker to process mapping + workerToProcesses, err := w.backend.GetWorkerToProcessMap() + if err != nil { + workerToProcesses = make(map[string][]string) + } + + // Build process to metrics mapping + processMetrics := make(map[string]map[string]*api.WorkerMetrics) // processID -> deviceUUID -> metrics + + // Aggregate compute metrics by process + for _, computeUtil := range computeUtils { + if processMetrics[computeUtil.ProcessID] == nil { + processMetrics[computeUtil.ProcessID] = make(map[string]*api.WorkerMetrics) + } + if processMetrics[computeUtil.ProcessID][computeUtil.DeviceUUID] == nil { + 
processMetrics[computeUtil.ProcessID][computeUtil.DeviceUUID] = &api.WorkerMetrics{ + DeviceUUID: computeUtil.DeviceUUID, + ProcessID: computeUtil.ProcessID, + ComputePercentage: computeUtil.UtilizationPercent, + ComputeTflops: 0, // ComputeTflops calculation will be implemented separately + } + } else { + processMetrics[computeUtil.ProcessID][computeUtil.DeviceUUID].ComputePercentage += computeUtil.UtilizationPercent + // ComputeTflops calculation will be implemented separately + } + } + + // Aggregate memory metrics by process + for _, memUtil := range memUtils { + if processMetrics[memUtil.ProcessID] == nil { + processMetrics[memUtil.ProcessID] = make(map[string]*api.WorkerMetrics) + } + if processMetrics[memUtil.ProcessID][memUtil.DeviceUUID] == nil { + processMetrics[memUtil.ProcessID][memUtil.DeviceUUID] = &api.WorkerMetrics{ + DeviceUUID: memUtil.DeviceUUID, + ProcessID: memUtil.ProcessID, + MemoryBytes: memUtil.UsedBytes, + } + } else { + processMetrics[memUtil.ProcessID][memUtil.DeviceUUID].MemoryBytes += memUtil.UsedBytes + } + } + + // Build result: deviceUUID -> workerUID -> processID -> metrics + result := make(map[string]map[string]map[string]*api.WorkerMetrics) + + // Map processes to workers + for workerUID, processIDs := range workerToProcesses { + for _, processID := range processIDs { + if deviceMetrics, exists := processMetrics[processID]; exists { + for deviceUUID, metrics := range deviceMetrics { + if result[deviceUUID] == nil { + result[deviceUUID] = make(map[string]map[string]*api.WorkerMetrics) + } + if result[deviceUUID][workerUID] == nil { + result[deviceUUID][workerUID] = make(map[string]*api.WorkerMetrics) + } + result[deviceUUID][workerUID][processID] = metrics + metrics.WorkerUID = workerUID + } + } + } + } + + // Also include allocations that might not have process mappings yet + for _, allocation := range allocations { + workerUID := allocation.WorkerInfo.WorkerUID + if workerUID == "" { + workerUID = allocation.WorkerInfo.PodUID + 
} + if workerUID == "" { + continue + } + + // Process all devices in the allocation + for _, deviceInfo := range allocation.DeviceInfos { + if result[deviceInfo.UUID] == nil { + result[deviceInfo.UUID] = make(map[string]map[string]*api.WorkerMetrics) + } + if result[deviceInfo.UUID][workerUID] == nil { + result[deviceInfo.UUID][workerUID] = make(map[string]*api.WorkerMetrics) + } + } + } + + return result, nil +} + +func (w *WorkerController) ListWorkers() ([]*api.WorkerInfo, error) { + w.mu.RLock() + defer w.mu.RUnlock() + workerSnapshot := make([]*api.WorkerInfo, 0, len(w.workers)) + for _, worker := range w.workers { + if worker.Deleted { + continue + } + workerSnapshot = append(workerSnapshot, worker) + } + return workerSnapshot, nil +} diff --git a/internal/hypervisor/worker/state/ctx_migration.go b/internal/hypervisor/worker/state/ctx_migration.go new file mode 100644 index 00000000..4df0094f --- /dev/null +++ b/internal/hypervisor/worker/state/ctx_migration.go @@ -0,0 +1 @@ +package worker diff --git a/internal/hypervisor/worker/state/soft_limiter_shm.go b/internal/hypervisor/worker/state/soft_limiter_shm.go new file mode 100644 index 00000000..c548006b --- /dev/null +++ b/internal/hypervisor/worker/state/soft_limiter_shm.go @@ -0,0 +1,937 @@ +package worker + +import ( + "fmt" + "math" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +// Constants +const ( + MaxProcesses = 2048 + MaxDevices = 16 + MaxUUIDLen = 64 + ShmPathSuffix = "shm" +) + +// RefCountError represents errors in reference count operations +type RefCountError struct { + Type string +} + +func (e *RefCountError) Error() string { + return fmt.Sprintf("ref count error: %s", e.Type) +} + +var ( + ErrRefCountUnderflow = &RefCountError{Type: "underflow"} +) + +// PodIdentifier contains namespace and name +type PodIdentifier struct { + Namespace string + Name string +} + +// NewPodIdentifier creates a new PodIdentifier +func 
NewPodIdentifier(namespace, name string) *PodIdentifier { + return &PodIdentifier{ + Namespace: namespace, + Name: name, + } +} + +// ToPath returns the path for this pod identifier +func (p *PodIdentifier) ToPath(basePath string) string { + return filepath.Join(basePath, p.Namespace, p.Name) +} + +// FromShmFilePath parses a PodIdentifier from a full shared memory path +// Path format: {base_path}/{namespace}/{name}/shm +func FromShmFilePath(path string) (*PodIdentifier, error) { + path = filepath.Clean(path) + components := strings.Split(path, string(filepath.Separator)) + + // Filter out empty components (from leading/trailing separators) + var filtered []string + for _, comp := range components { + if comp != "" { + filtered = append(filtered, comp) + } + } + components = filtered + + // Need at least: namespace, name, and "shm" (3 components minimum) + if len(components) < 3 { + return nil, fmt.Errorf("invalid path format: %s (need at least namespace/name/shm)", path) + } + + // Extract the last 3 components: {namespace}/{name}/shm + compLen := len(components) + + // Verify the last component is "shm" + if components[compLen-1] != ShmPathSuffix { + return nil, fmt.Errorf("invalid path format: %s (last component must be 'shm')", path) + } + + namespace := components[compLen-3] + name := components[compLen-2] + + // Validate namespace and name are not empty + if namespace == "" || name == "" { + return nil, fmt.Errorf("invalid path format: %s (namespace and name must be non-empty)", path) + } + + return NewPodIdentifier(namespace, name), nil +} + +// String returns the string representation +func (p *PodIdentifier) String() string { + return fmt.Sprintf("%s/%s", p.Namespace, p.Name) +} + +// CleanupEmptyParentDirectories removes empty parent directories after removing a file +func CleanupEmptyParentDirectories(filePath string, stopAtPath *string) error { + parentDir := filepath.Dir(filePath) + + // Skip if we've reached the stop path + if stopAtPath != nil && 
parentDir == *stopAtPath { + return nil + } + + // Try to remove the immediate parent directory if it's empty + entries, err := os.ReadDir(parentDir) + if err != nil { + return err + } + + if len(entries) == 0 { + if err := os.Remove(parentDir); err != nil { + return err + } + + // Recursively try to remove parent directories if they're also empty + return CleanupEmptyParentDirectories(parentDir, stopAtPath) + } + + return nil +} + +// SharedDeviceInfoV1 is the legacy device state (without ERL) +type SharedDeviceInfoV1 struct { + AvailableCudaCores int32 + UpLimit uint32 + MemLimit uint64 + TotalCudaCores uint32 + PodMemoryUsed uint64 +} + +// SharedDeviceInfoV2 is the V2 device state with ERL support +type SharedDeviceInfoV2 struct { + UpLimit uint32 + MemLimit uint64 + TotalCudaCores uint32 + PodMemoryUsed uint64 + + // ERL (Elastic Rate Limiting) - PID-controlled token bucket + ERLTokenRefillRate uint64 // f64 stored as bits + ERLTokenCapacity uint64 // f64 stored as bits + ERLCurrentTokens uint64 // f64 stored as bits + ERLLastTokenUpdate uint64 // f64 stored as bits +} + +// SharedDeviceInfo is a type alias for backward compatibility +type SharedDeviceInfo = SharedDeviceInfoV2 + +// NewSharedDeviceInfoV1 creates a new V1 device info +func NewSharedDeviceInfoV1(totalCudaCores, upLimit uint32, memLimit uint64) *SharedDeviceInfoV1 { + return &SharedDeviceInfoV1{ + AvailableCudaCores: 0, + UpLimit: upLimit, + MemLimit: memLimit, + TotalCudaCores: totalCudaCores, + PodMemoryUsed: 0, + } +} + +// NewSharedDeviceInfoV2 creates a new V2 device info +func NewSharedDeviceInfoV2(totalCudaCores, upLimit uint32, memLimit uint64) *SharedDeviceInfoV2 { + return &SharedDeviceInfoV2{ + UpLimit: upLimit, + MemLimit: memLimit, + TotalCudaCores: totalCudaCores, + PodMemoryUsed: 0, + ERLTokenRefillRate: math.Float64bits(10.0), // Default 10 tokens/sec + ERLTokenCapacity: math.Float64bits(100.0), + ERLCurrentTokens: math.Float64bits(100.0), + ERLLastTokenUpdate: 
math.Float64bits(0.0), + } +} + +// DeviceEntryV1 is the legacy device entry +type DeviceEntryV1 struct { + UUID [MaxUUIDLen]byte + DeviceInfo SharedDeviceInfoV1 + IsActiveField uint32 + //nolint:unused // Padding field for memory alignment in shared memory structures + _padding [4]byte +} + +// DeviceEntryV2 is the V2 device entry with ERL +type DeviceEntryV2 struct { + UUID [MaxUUIDLen]byte + DeviceInfo SharedDeviceInfoV2 + IsActiveField uint32 +} + +// DeviceEntry is a type alias for backward compatibility +type DeviceEntry = DeviceEntryV2 + +// NewDeviceEntryV1 creates a new V1 device entry +func NewDeviceEntryV1() *DeviceEntryV1 { + return &DeviceEntryV1{ + DeviceInfo: *NewSharedDeviceInfoV1(0, 0, 0), + } +} + +// NewDeviceEntryV2 creates a new V2 device entry +func NewDeviceEntryV2() *DeviceEntryV2 { + return &DeviceEntryV2{ + DeviceInfo: *NewSharedDeviceInfoV2(0, 0, 0), + } +} + +// SetUUID sets the device UUID +func (d *DeviceEntryV1) SetUUID(uuid string) { + copyLen := len(uuid) + if copyLen > MaxUUIDLen-1 { + copyLen = MaxUUIDLen - 1 + } + + // Clear the UUID array + for i := range d.UUID { + d.UUID[i] = 0 + } + + // Copy the new UUID + copy(d.UUID[:], uuid[:copyLen]) +} + +// GetUUID gets the device UUID as a string +func (d *DeviceEntryV1) GetUUID() string { + nullPos := MaxUUIDLen - 1 + for i, b := range d.UUID { + if b == 0 { + nullPos = i + break + } + } + return string(d.UUID[:nullPos]) +} + +// IsActive checks if this entry is active +func (d *DeviceEntryV1) IsActive() bool { + return atomic.LoadUint32(&d.IsActiveField) != 0 +} + +// SetActive sets the active status +func (d *DeviceEntryV1) SetActive(active bool) { + var val uint32 + if active { + val = 1 + } + atomic.StoreUint32(&d.IsActiveField, val) +} + +// SetUUID sets the device UUID +func (d *DeviceEntryV2) SetUUID(uuid string) { + copyLen := len(uuid) + if copyLen > MaxUUIDLen-1 { + copyLen = MaxUUIDLen - 1 + } + + // Clear the UUID array + for i := range d.UUID { + d.UUID[i] = 0 + } + + // 
Copy the new UUID + copy(d.UUID[:], uuid[:copyLen]) +} + +// GetUUID gets the device UUID as a string +func (d *DeviceEntryV2) GetUUID() string { + nullPos := MaxUUIDLen - 1 + for i, b := range d.UUID { + if b == 0 { + nullPos = i + break + } + } + return string(d.UUID[:nullPos]) +} + +// IsActive checks if this entry is active +func (d *DeviceEntryV2) IsActive() bool { + return atomic.LoadUint32(&d.IsActiveField) != 0 +} + +// SetActive sets the active status +func (d *DeviceEntryV2) SetActive(active bool) { + var val uint32 + if active { + val = 1 + } + atomic.StoreUint32(&d.IsActiveField, val) +} + +// DeviceConfig contains device configuration information +type DeviceConfig struct { + DeviceIdx uint32 + DeviceUUID string + UpLimit uint32 + MemLimit uint64 + SMCount uint32 + MaxThreadPerSM uint32 + TotalCudaCores uint32 +} + +// SharedDeviceStateV1 is the V1 shared device state +type SharedDeviceStateV1 struct { + Devices [MaxDevices]DeviceEntryV1 + DeviceCountField uint32 + LastHeartbeat uint64 + PIDs *ShmMutex[*PIDSet] +} + +// SharedDeviceStateV2 is the V2 shared device state with ERL +type SharedDeviceStateV2 struct { + Devices [MaxDevices]DeviceEntryV2 + DeviceCountField uint32 + LastHeartbeat uint64 + PIDs *ShmMutex[*PIDSet] +} + +// SharedDeviceState is a versioned enum for compatibility +type SharedDeviceState struct { + V1 *SharedDeviceStateV1 + V2 *SharedDeviceStateV2 +} + +// Version returns the version number +func (s *SharedDeviceState) Version() uint32 { + if s.V1 != nil { + return 1 + } + return 2 +} + +// HasERL checks if this state uses ERL features +func (s *SharedDeviceState) HasERL() bool { + return s.V2 != nil +} + +// NewSharedDeviceStateV1 creates a new V1 state +func NewSharedDeviceStateV1(configs []DeviceConfig) (*SharedDeviceStateV1, error) { + now := uint64(time.Now().Unix()) + + state := &SharedDeviceStateV1{ + DeviceCountField: uint32(len(configs)), + LastHeartbeat: now, + PIDs: NewShmMutex(NewPIDSet()), + } + + for _, config := 
range configs { + deviceIdx := int(config.DeviceIdx) + if deviceIdx >= MaxDevices { + return nil, fmt.Errorf("device index %d exceeds maximum devices %d", deviceIdx, MaxDevices) + } + + entry := &state.Devices[deviceIdx] + entry.SetUUID(config.DeviceUUID) + entry.DeviceInfo.TotalCudaCores = config.TotalCudaCores + entry.DeviceInfo.AvailableCudaCores = int32(config.TotalCudaCores) + entry.DeviceInfo.UpLimit = config.UpLimit + entry.DeviceInfo.MemLimit = config.MemLimit + entry.SetActive(true) + } + + return state, nil +} + +// NewSharedDeviceStateV2 creates a new V2 state +func NewSharedDeviceStateV2(configs []DeviceConfig) (*SharedDeviceStateV2, error) { + now := uint64(time.Now().Unix()) + + state := &SharedDeviceStateV2{ + DeviceCountField: uint32(len(configs)), + LastHeartbeat: now, + PIDs: NewShmMutex(NewPIDSet()), + } + + for _, config := range configs { + deviceIdx := int(config.DeviceIdx) + if deviceIdx >= MaxDevices { + return nil, fmt.Errorf("device index %d exceeds maximum devices %d", deviceIdx, MaxDevices) + } + + entry := &state.Devices[deviceIdx] + entry.SetUUID(config.DeviceUUID) + entry.DeviceInfo.TotalCudaCores = config.TotalCudaCores + entry.DeviceInfo.UpLimit = config.UpLimit + entry.DeviceInfo.MemLimit = config.MemLimit + + // Initialize ERL fields with defaults + entry.DeviceInfo.ERLTokenCapacity = math.Float64bits(100.0) + entry.DeviceInfo.ERLTokenRefillRate = math.Float64bits(10.0) + entry.DeviceInfo.ERLCurrentTokens = math.Float64bits(100.0) + entry.DeviceInfo.ERLLastTokenUpdate = math.Float64bits(float64(now)) + + entry.SetActive(true) + } + + return state, nil +} + +// NewSharedDeviceState creates a new SharedDeviceState (defaults to V2) +func NewSharedDeviceState(configs []DeviceConfig) (*SharedDeviceState, error) { + v2, err := NewSharedDeviceStateV2(configs) + if err != nil { + return nil, err + } + return &SharedDeviceState{V2: v2}, nil +} + +// HasDevice checks if a device exists at the given index +func (s *SharedDeviceStateV1) 
HasDevice(index int) bool { + return index < MaxDevices && s.Devices[index].IsActive() +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceStateV1) DeviceCount() int { + return int(atomic.LoadUint32(&s.DeviceCountField)) +} + +// UpdateHeartbeat updates the heartbeat timestamp +func (s *SharedDeviceStateV1) UpdateHeartbeat(timestamp uint64) { + atomic.StoreUint64(&s.LastHeartbeat, timestamp) +} + +// GetLastHeartbeat returns the last heartbeat timestamp +func (s *SharedDeviceStateV1) GetLastHeartbeat() uint64 { + return atomic.LoadUint64(&s.LastHeartbeat) +} + +// IsHealthy checks if the shared memory is healthy based on heartbeat +func (s *SharedDeviceStateV1) IsHealthy(timeout time.Duration) bool { + now := uint64(time.Now().Unix()) + lastHeartbeat := s.GetLastHeartbeat() + + if lastHeartbeat == 0 { + return false + } + + if lastHeartbeat > now { + return false + } + + return now-lastHeartbeat <= uint64(timeout.Seconds()) +} + +// AddPID adds a PID to the set +func (s *SharedDeviceStateV1) AddPID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.InsertIfAbsent(pid) +} + +// RemovePID removes a PID from the set +func (s *SharedDeviceStateV1) RemovePID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.RemoveValue(pid) +} + +// GetAllPIDs returns all PIDs currently stored +func (s *SharedDeviceStateV1) GetAllPIDs() []int { + s.PIDs.Lock() + defer s.PIDs.Unlock() + return s.PIDs.Value.Values() +} + +// CleanupOrphanedLocks cleans up any orphaned locks +func (s *SharedDeviceStateV1) CleanupOrphanedLocks() { + s.PIDs.CleanupOrphanedLock() +} + +// HasDevice checks if a device exists at the given index +func (s *SharedDeviceStateV2) HasDevice(index int) bool { + return index < MaxDevices && s.Devices[index].IsActive() +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceStateV2) DeviceCount() int { + return int(atomic.LoadUint32(&s.DeviceCountField)) +} + +// UpdateHeartbeat updates the heartbeat 
timestamp +func (s *SharedDeviceStateV2) UpdateHeartbeat(timestamp uint64) { + atomic.StoreUint64(&s.LastHeartbeat, timestamp) +} + +// GetLastHeartbeat returns the last heartbeat timestamp +func (s *SharedDeviceStateV2) GetLastHeartbeat() uint64 { + return atomic.LoadUint64(&s.LastHeartbeat) +} + +// IsHealthy checks if the shared memory is healthy based on heartbeat +func (s *SharedDeviceStateV2) IsHealthy(timeout time.Duration) bool { + now := uint64(time.Now().Unix()) + lastHeartbeat := s.GetLastHeartbeat() + + if lastHeartbeat == 0 { + return false + } + + if lastHeartbeat > now { + return false + } + + return now-lastHeartbeat <= uint64(timeout.Seconds()) +} + +// AddPID adds a PID to the set +func (s *SharedDeviceStateV2) AddPID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.InsertIfAbsent(pid) +} + +// RemovePID removes a PID from the set +func (s *SharedDeviceStateV2) RemovePID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.RemoveValue(pid) +} + +// GetAllPIDs returns all PIDs currently stored +func (s *SharedDeviceStateV2) GetAllPIDs() []int { + s.PIDs.Lock() + defer s.PIDs.Unlock() + return s.PIDs.Value.Values() +} + +// CleanupOrphanedLocks cleans up any orphaned locks +func (s *SharedDeviceStateV2) CleanupOrphanedLocks() { + s.PIDs.CleanupOrphanedLock() +} + +// Helper methods for SharedDeviceState that delegate to the appropriate version + +// HasDevice checks if a device exists +func (s *SharedDeviceState) HasDevice(index int) bool { + if s.V1 != nil { + return s.V1.HasDevice(index) + } + return s.V2.HasDevice(index) +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceState) DeviceCount() int { + if s.V1 != nil { + return s.V1.DeviceCount() + } + return s.V2.DeviceCount() +} + +// UpdateHeartbeat updates the heartbeat +func (s *SharedDeviceState) UpdateHeartbeat(timestamp uint64) { + if s.V1 != nil { + s.V1.UpdateHeartbeat(timestamp) + } else { + s.V2.UpdateHeartbeat(timestamp) + } +} + +// 
GetLastHeartbeat returns the last heartbeat +func (s *SharedDeviceState) GetLastHeartbeat() uint64 { + if s.V1 != nil { + return s.V1.GetLastHeartbeat() + } + return s.V2.GetLastHeartbeat() +} + +// IsHealthy checks if healthy +func (s *SharedDeviceState) IsHealthy(timeout time.Duration) bool { + if s.V1 != nil { + return s.V1.IsHealthy(timeout) + } + return s.V2.IsHealthy(timeout) +} + +// AddPID adds a PID +func (s *SharedDeviceState) AddPID(pid int) { + if s.V1 != nil { + s.V1.AddPID(pid) + } else { + s.V2.AddPID(pid) + } +} + +// RemovePID removes a PID +func (s *SharedDeviceState) RemovePID(pid int) { + if s.V1 != nil { + s.V1.RemovePID(pid) + } else { + s.V2.RemovePID(pid) + } +} + +// GetAllPIDs returns all PIDs +func (s *SharedDeviceState) GetAllPIDs() []int { + if s.V1 != nil { + return s.V1.GetAllPIDs() + } + return s.V2.GetAllPIDs() +} + +// CleanupOrphanedLocks cleans up orphaned locks +func (s *SharedDeviceState) CleanupOrphanedLocks() { + if s.V1 != nil { + s.V1.CleanupOrphanedLocks() + } else { + s.V2.CleanupOrphanedLocks() + } +} + +// SetPodMemoryUsed sets pod memory used for a device +func (s *SharedDeviceState) SetPodMemoryUsed(index int, memory uint64) bool { + if s.V1 != nil { + if index >= MaxDevices || !s.V1.Devices[index].IsActive() { + return false + } + atomic.StoreUint64(&s.V1.Devices[index].DeviceInfo.PodMemoryUsed, memory) + return true + } + if index >= MaxDevices || !s.V2.Devices[index].IsActive() { + return false + } + atomic.StoreUint64(&s.V2.Devices[index].DeviceInfo.PodMemoryUsed, memory) + return true +} + +// ERL token bucket operations for SharedDeviceInfoV2 + +// GetERLTokenCapacity returns the token capacity +func (d *SharedDeviceInfoV2) GetERLTokenCapacity() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLTokenCapacity)) +} + +// SetERLTokenCapacity sets the token capacity +func (d *SharedDeviceInfoV2) SetERLTokenCapacity(capacity float64) { + atomic.StoreUint64(&d.ERLTokenCapacity, 
math.Float64bits(capacity)) +} + +// GetERLTokenRefillRate returns the refill rate +func (d *SharedDeviceInfoV2) GetERLTokenRefillRate() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLTokenRefillRate)) +} + +// SetERLTokenRefillRate sets the refill rate +func (d *SharedDeviceInfoV2) SetERLTokenRefillRate(rate float64) { + atomic.StoreUint64(&d.ERLTokenRefillRate, math.Float64bits(rate)) +} + +// GetERLCurrentTokens returns the current tokens +func (d *SharedDeviceInfoV2) GetERLCurrentTokens() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLCurrentTokens)) +} + +// SetERLCurrentTokens sets the current tokens +func (d *SharedDeviceInfoV2) SetERLCurrentTokens(tokens float64) { + atomic.StoreUint64(&d.ERLCurrentTokens, math.Float64bits(tokens)) +} + +// GetERLLastTokenUpdate returns the last token update timestamp +func (d *SharedDeviceInfoV2) GetERLLastTokenUpdate() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLLastTokenUpdate)) +} + +// SetERLLastTokenUpdate sets the last token update timestamp +func (d *SharedDeviceInfoV2) SetERLLastTokenUpdate(timestamp float64) { + atomic.StoreUint64(&d.ERLLastTokenUpdate, math.Float64bits(timestamp)) +} + +// LoadERLTokenState loads the token state atomically +func (d *SharedDeviceInfoV2) LoadERLTokenState() (float64, float64) { + return d.GetERLCurrentTokens(), d.GetERLLastTokenUpdate() +} + +// StoreERLTokenState stores the token state atomically +func (d *SharedDeviceInfoV2) StoreERLTokenState(tokens, timestamp float64) { + d.SetERLCurrentTokens(tokens) + d.SetERLLastTokenUpdate(timestamp) +} + +// LoadERLQuota loads the quota configuration +func (d *SharedDeviceInfoV2) LoadERLQuota() (float64, float64) { + return d.GetERLTokenCapacity(), d.GetERLTokenRefillRate() +} + +// FetchSubERLTokens atomically subtracts tokens and returns the value before subtraction +func (d *SharedDeviceInfoV2) FetchSubERLTokens(cost float64) float64 { + for { + currentBits := 
atomic.LoadUint64(&d.ERLCurrentTokens) + current := math.Float64frombits(currentBits) + + if current < cost { + return current + } + + newValue := math.Max(0.0, current-cost) + newBits := math.Float64bits(newValue) + + if atomic.CompareAndSwapUint64(&d.ERLCurrentTokens, currentBits, newBits) { + return current + } + } +} + +// FetchAddERLTokens atomically adds tokens (capped at capacity) and returns the value before addition +func (d *SharedDeviceInfoV2) FetchAddERLTokens(amount float64) float64 { + capacity := d.GetERLTokenCapacity() + + for { + currentBits := atomic.LoadUint64(&d.ERLCurrentTokens) + current := math.Float64frombits(currentBits) + + newValue := math.Max(0.0, math.Min(capacity, current+amount)) + newBits := math.Float64bits(newValue) + + if atomic.CompareAndSwapUint64(&d.ERLCurrentTokens, currentBits, newBits) { + return current + } + } +} + +// PIDSet is a set of process IDs with a fixed capacity +type PIDSet struct { + values []int + mu sync.Mutex //nolint:unused // Used via ShmMutex wrapper +} + +// NewPIDSet creates a new PID set +func NewPIDSet() *PIDSet { + return &PIDSet{ + values: make([]int, 0, MaxProcesses), + } +} + +// InsertIfAbsent inserts a value if it's not already present +func (s *PIDSet) InsertIfAbsent(pid int) bool { + for _, v := range s.values { + if v == pid { + return false + } + } + if len(s.values) >= MaxProcesses { + return false + } + s.values = append(s.values, pid) + return true +} + +// RemoveValue removes a value from the set +func (s *PIDSet) RemoveValue(pid int) bool { + for i, v := range s.values { + if v == pid { + s.values = append(s.values[:i], s.values[i+1:]...) 
+ return true + } + } + return false +} + +// Values returns all values in the set +func (s *PIDSet) Values() []int { + result := make([]int, len(s.values)) + copy(result, s.values) + return result +} + +// ShmMutex is a shared memory mutex wrapper +type ShmMutex[T any] struct { + mu sync.Mutex + Value T +} + +// NewShmMutex creates a new shared memory mutex +func NewShmMutex[T any](value T) *ShmMutex[T] { + return &ShmMutex[T]{ + Value: value, + } +} + +// Lock locks the mutex +func (m *ShmMutex[T]) Lock() { + m.mu.Lock() +} + +// Unlock unlocks the mutex +func (m *ShmMutex[T]) Unlock() { + m.mu.Unlock() +} + +// CleanupOrphanedLock cleans up orphaned locks (placeholder for now) +func (m *ShmMutex[T]) CleanupOrphanedLock() { + // In a real implementation, this would check for dead processes + // and release their locks. For now, it's a no-op. +} + +// SharedMemoryHandle manages a shared memory mapping +type SharedMemoryHandle struct { + path string + data []byte + state *SharedDeviceState + file *os.File + fileSize int64 +} + +// CreateSharedMemoryHandle creates a new shared memory handle +func CreateSharedMemoryHandle(podPath string, configs []DeviceConfig) (*SharedMemoryHandle, error) { + shmPath := filepath.Join(podPath, ShmPathSuffix) + + // Create directory if it doesn't exist + if err := os.MkdirAll(podPath, 0755); err != nil { + return nil, fmt.Errorf("failed to create directory: %w", err) + } + + // Calculate size needed for SharedDeviceStateV2 + stateSize := int(unsafe.Sizeof(SharedDeviceStateV2{})) + + // Create or open the file + file, err := os.OpenFile(shmPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return nil, fmt.Errorf("failed to create file: %w", err) + } + + // Truncate to the required size + if err := file.Truncate(int64(stateSize)); err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to truncate file: %w", err) + } + + // Memory map the file + data, err := syscall.Mmap(int(file.Fd()), 0, stateSize, 
syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to mmap: %w", err) + } + + // Initialize the state + state, err := NewSharedDeviceStateV2(configs) + if err != nil { + _ = syscall.Munmap(data) + _ = file.Close() + return nil, err + } + + // Copy the state to the mapped memory + stateBytes := (*[1 << 30]byte)(unsafe.Pointer(state))[:stateSize:stateSize] + copy(data, stateBytes) + + // Get a pointer to the mapped state + mappedState := (*SharedDeviceStateV2)(unsafe.Pointer(&data[0])) + + // Initialize the PIDs mutex in the mapped memory + // Note: This is a simplified version - in a real implementation, + // you'd need to properly initialize the mutex for shared memory + mappedState.PIDs = NewShmMutex(NewPIDSet()) + + return &SharedMemoryHandle{ + path: shmPath, + data: data, + state: &SharedDeviceState{V2: mappedState}, + file: file, + fileSize: int64(stateSize), + }, nil +} + +// OpenSharedMemoryHandle opens an existing shared memory handle +func OpenSharedMemoryHandle(podPath string) (*SharedMemoryHandle, error) { + shmPath := filepath.Join(podPath, ShmPathSuffix) + + // Open the file + file, err := os.OpenFile(shmPath, os.O_RDWR, 0666) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + + // Get file size + stat, err := file.Stat() + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to stat file: %w", err) + } + + fileSize := stat.Size() + + // Memory map the file + data, err := syscall.Mmap(int(file.Fd()), 0, int(fileSize), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to mmap: %w", err) + } + + // Get a pointer to the mapped state (assume V2 for now) + mappedState := (*SharedDeviceStateV2)(unsafe.Pointer(&data[0])) + + return &SharedMemoryHandle{ + path: shmPath, + data: data, + state: &SharedDeviceState{V2: mappedState}, + file: file, + fileSize: fileSize, 
+ }, nil +} + +// GetState returns the shared device state +func (h *SharedMemoryHandle) GetState() *SharedDeviceState { + return h.state +} + +// Close closes the shared memory handle +func (h *SharedMemoryHandle) Close() error { + if h.data != nil { + _ = syscall.Munmap(h.data) + h.data = nil + } + if h.file != nil { + _ = h.file.Close() + h.file = nil + } + return nil +} + +// Cleanup removes the shared memory file and cleans up empty directories +func (h *SharedMemoryHandle) Cleanup(stopAtPath *string) error { + if err := h.Close(); err != nil { + return err + } + + if err := os.Remove(h.path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove file: %w", err) + } + + if stopAtPath != nil { + return CleanupEmptyParentDirectories(h.path, stopAtPath) + } + return CleanupEmptyParentDirectories(h.path, nil) +} diff --git a/internal/hypervisor/worker/state/soft_limiter_shm_test.go b/internal/hypervisor/worker/state/soft_limiter_shm_test.go new file mode 100644 index 00000000..41d67ff7 --- /dev/null +++ b/internal/hypervisor/worker/state/soft_limiter_shm_test.go @@ -0,0 +1,648 @@ +package worker + +import ( + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testShmBasePath = "/tmp/test_shm" + testDeviceIdx = uint32(0) + testTotalCores = uint32(1024) + testUpLimit = uint32(80) + testMemLimit = uint64(1024 * 1024 * 1024) // 1GB +) + +func createTestConfigs() []DeviceConfig { + return []DeviceConfig{ + { + DeviceIdx: testDeviceIdx, + DeviceUUID: "test-device-uuid", + UpLimit: testUpLimit, + MemLimit: testMemLimit, + TotalCudaCores: testTotalCores, + SMCount: 10, + MaxThreadPerSM: 1024, + }, + } +} + +func TestDeviceEntryBasicOperations(t *testing.T) { + entry := NewDeviceEntryV2() + + // Test UUID operations + entry.SetUUID("test-uuid-123") + assert.Equal(t, "test-uuid-123", entry.GetUUID()) + + // Test active status + 
assert.False(t, entry.IsActive()) + entry.SetActive(true) + assert.True(t, entry.IsActive()) + entry.SetActive(false) + assert.False(t, entry.IsActive()) + + // Test very long UUID handling + longUUID := strings.Repeat("a", MaxUUIDLen+10) + entry.SetUUID(longUUID) + storedUUID := entry.GetUUID() + assert.Less(t, len(storedUUID), MaxUUIDLen) + assert.Contains(t, storedUUID, "a") +} + +func TestSharedDeviceStateCreationAndBasicOps(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + // Test initial state (V2 by default) + assert.Equal(t, uint32(2), state.Version()) + assert.Equal(t, 1, state.DeviceCount()) + + // Test that heartbeat is initialized to current time (should be non-zero and recent) + heartbeat := state.GetLastHeartbeat() + assert.Greater(t, heartbeat, uint64(0)) + now := uint64(time.Now().Unix()) + assert.Less(t, now-heartbeat, uint64(2)) // Should be within 2 seconds + + // Should be healthy since heartbeat was just set + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test device exists by index + deviceIdx := int(configs[0].DeviceIdx) + assert.True(t, state.HasDevice(deviceIdx)) +} + +func TestSharedDeviceStateHeartbeatFunctionality(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Test initial healthy state (heartbeat is initialized to current time) + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test setting heartbeat to a specific time + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test old heartbeat (should be unhealthy) + state.UpdateHeartbeat(now - 60) + assert.False(t, state.IsHealthy(30*time.Second)) +} + +func TestSharedDeviceInfoAtomicOperations(t *testing.T) { + // Test V1 device info (has available_cores) + deviceInfoV1 := NewSharedDeviceInfoV1(testTotalCores, testUpLimit, 
testMemLimit) + + // Test available cores operations (V1 only) + deviceInfoV1.AvailableCudaCores = 512 + assert.Equal(t, int32(512), deviceInfoV1.AvailableCudaCores) + + deviceInfoV1.AvailableCudaCores = 600 + assert.Equal(t, int32(600), deviceInfoV1.AvailableCudaCores) + + // Test negative values + deviceInfoV1.AvailableCudaCores = -50 + assert.Equal(t, int32(-50), deviceInfoV1.AvailableCudaCores) + + // Test other fields + deviceInfoV1.UpLimit = 90 + assert.Equal(t, uint32(90), deviceInfoV1.UpLimit) + + deviceInfoV1.MemLimit = 2 * 1024 * 1024 * 1024 + assert.Equal(t, uint64(2*1024*1024*1024), deviceInfoV1.MemLimit) + + // Test V2 device info (has ERL fields) + deviceInfoV2 := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + // Test ERL fields - refill rate is now the control parameter + deviceInfoV2.SetERLTokenRefillRate(15.0) + assert.Equal(t, 15.0, deviceInfoV2.GetERLTokenRefillRate()) + + deviceInfoV2.SetERLTokenCapacity(100.0) + assert.Equal(t, 100.0, deviceInfoV2.GetERLTokenCapacity()) + + deviceInfoV2.PodMemoryUsed = 512 * 1024 * 1024 + assert.Equal(t, uint64(512*1024*1024), deviceInfoV2.PodMemoryUsed) +} + +func TestERLTokenBucketPreservesTokensWhenInsufficient(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + + deviceInfo.SetERLCurrentTokens(1.5) + before := deviceInfo.FetchSubERLTokens(2.0) + assert.Equal(t, 1.5, before) + assert.Equal(t, 1.5, deviceInfo.GetERLCurrentTokens()) + + deviceInfo.SetERLCurrentTokens(5.0) + beforeSuccess := deviceInfo.FetchSubERLTokens(2.0) + assert.Equal(t, 5.0, beforeSuccess) + assert.Equal(t, 3.0, deviceInfo.GetERLCurrentTokens()) +} + +func TestSharedMemoryHandleCreateAndOpen(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("handle_create_open", "test") + + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(podPath) + }() + + // Create shared memory + handle1, err := 
CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + defer func() { + _ = handle1.Close() + }() + + state1 := handle1.GetState() + assert.Equal(t, uint32(2), state1.Version()) + assert.Equal(t, 1, state1.DeviceCount()) + + // Verify shared memory file exists after creation + assert.True(t, fileExists(filepath.Join(podPath, ShmPathSuffix))) + + // Open existing shared memory + handle2, err := OpenSharedMemoryHandle(podPath) + require.NoError(t, err) + defer func() { + _ = handle2.Close() + }() + + state2 := handle2.GetState() + assert.Equal(t, uint32(2), state2.Version()) + assert.Equal(t, 1, state2.DeviceCount()) + + // Verify they access the same memory + deviceIdx := int(configs[0].DeviceIdx) + state1.SetPodMemoryUsed(deviceIdx, 42) + memory := state2.GetPodMemoryUsed(deviceIdx) + assert.Equal(t, uint64(42), memory) +} + +func TestSharedMemoryHandleErrorHandling(t *testing.T) { + _, err := OpenSharedMemoryHandle("non_existent_memory") + assert.Error(t, err) +} + +func TestConcurrentDeviceAccess(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("concurrent_access", "test") + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(podPath) + }() + + handle, err := CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + defer func() { + _ = handle.Close() + }() + + deviceIdx := int(configs[0].DeviceIdx) + var wg sync.WaitGroup + numGoroutines := 5 + iterations := 20 + + // Spawn multiple goroutines doing concurrent access + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + state := handle.GetState() + + for j := 0; j < iterations; j++ { + value := uint64(id*iterations + j) + state.SetPodMemoryUsed(deviceIdx, value) + + time.Sleep(time.Millisecond) + + readValue := state.GetPodMemoryUsed(deviceIdx) + // Value should be valid (set by some goroutine) + assert.GreaterOrEqual(t, readValue, uint64(0)) + assert.Less(t, readValue, uint64(100)) + 
} + }(i) + } + + wg.Wait() +} + +func TestDeviceIterationMethods(t *testing.T) { + // Create multiple device configurations + configs := []DeviceConfig{ + { + DeviceIdx: 0, + DeviceUUID: "device-0", + UpLimit: 80, + MemLimit: 1024 * 1024 * 1024, + TotalCudaCores: 1024, + SMCount: 10, + MaxThreadPerSM: 1024, + }, + { + DeviceIdx: 2, + DeviceUUID: "device-2", + UpLimit: 70, + MemLimit: 2 * 1024 * 1024 * 1024, + TotalCudaCores: 2048, + SMCount: 20, + MaxThreadPerSM: 1024, + }, + } + + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + // Test iterating over active devices + activeCount := 0 + for i := 0; i < MaxDevices; i++ { + if state.HasDevice(i) { + activeCount++ + } + } + assert.Equal(t, 2, activeCount) + + // Check that indices match the device_idx from configs + assert.True(t, state.HasDevice(0)) + assert.True(t, state.HasDevice(2)) + + // Test deactivating a device and checking + if state.V2 != nil { + state.V2.Devices[2].SetActive(false) + assert.False(t, state.HasDevice(2)) + assert.True(t, state.HasDevice(0)) + } +} + +func TestPIDSetDeduplicatesOnAdd(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Add the same pid multiple times + state.AddPID(1234) + state.AddPID(1234) + state.AddPID(1234) + + pids := state.GetAllPIDs() + assert.Equal(t, 1, len(pids), "should contain only one PID after duplicate adds") + if len(pids) > 0 { + assert.Equal(t, 1234, pids[0]) + } +} + +func TestPIDRemoveByValueWorks(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + state.AddPID(111) + state.AddPID(222) + state.AddPID(333) + + state.RemovePID(222) + + pids := state.GetAllPIDs() + assert.Equal(t, 2, len(pids), "should remove the specified PID") + assert.Contains(t, pids, 111) + assert.Contains(t, pids, 333) + assert.NotContains(t, pids, 222) +} + +func TestPIDSetCapacityAndDuplicateBehavior(t *testing.T) { + state, err := 
NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Fill to capacity with unique PIDs + for pid := 0; pid < MaxProcesses; pid++ { + state.AddPID(pid) + } + + pids := state.GetAllPIDs() + assert.Equal(t, MaxProcesses, len(pids), "should reach max capacity with unique PIDs") + + // Adding an existing PID should not change the count + state.AddPID(0) + pidsAfterDup := state.GetAllPIDs() + assert.Equal(t, MaxProcesses, len(pidsAfterDup), "should remain at capacity when inserting duplicate") +} + +func TestCleanupEmptyParentDirectories(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create a file in the pod directory + testFile := filepath.Join(podDir, ShmPathSuffix) + err = os.WriteFile(testFile, []byte("test data"), 0644) + require.NoError(t, err) + + // Verify structure exists + assert.True(t, fileExists(testFile)) + assert.True(t, fileExists(podDir)) + assert.True(t, fileExists(namespaceDir)) + + // Remove the file + err = os.Remove(testFile) + require.NoError(t, err) + + // Test cleanup without stop_at_path (should remove all empty dirs) + err = CleanupEmptyParentDirectories(testFile, nil) + assert.NoError(t, err) + + // Pod directory should be removed + assert.False(t, fileExists(podDir)) + // Namespace directory should be removed + assert.False(t, fileExists(namespaceDir)) +} + +func TestCleanupEmptyParentDirectoriesWithStopAtPath(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: 
base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create a file in the pod directory + testFile := filepath.Join(podDir, ShmPathSuffix) + err = os.WriteFile(testFile, []byte("test data"), 0644) + require.NoError(t, err) + + // Remove the file + err = os.Remove(testFile) + require.NoError(t, err) + + // Test cleanup with stop_at_path set to base_path + stopAtPath := tempDir + err = CleanupEmptyParentDirectories(testFile, &stopAtPath) + assert.NoError(t, err) + + // Pod directory should be removed + assert.False(t, fileExists(podDir)) + // Namespace directory should be removed + assert.False(t, fileExists(namespaceDir)) + // Base directory should remain (it's the stop_at_path) + assert.True(t, fileExists(tempDir)) +} + +func TestCleanupEmptyParentDirectoriesStopsAtNonEmptyDir(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create two files in the pod directory + testFile1 := filepath.Join(podDir, ShmPathSuffix) + testFile2 := filepath.Join(podDir, "other_file") + err = os.WriteFile(testFile1, []byte("test data"), 0644) + require.NoError(t, err) + err = os.WriteFile(testFile2, []byte("other data"), 0644) + require.NoError(t, err) + + // Remove only one file + err = os.Remove(testFile1) + require.NoError(t, err) + + // Test cleanup - should not remove pod directory since it's not empty + stopAtPath := tempDir + err = CleanupEmptyParentDirectories(testFile1, &stopAtPath) + assert.NoError(t, err) + + // Pod directory should still exist (not 
empty) + assert.True(t, fileExists(podDir)) + assert.True(t, fileExists(namespaceDir)) + assert.True(t, fileExists(testFile2)) +} + +func TestPodIdentifierFromShmFilePath(t *testing.T) { + tests := []struct { + name string + path string + expectError bool + expectedNS string + expectedName string + }{ + { + name: "valid path", + path: "/base/namespace/podname/shm", + expectError: false, + expectedNS: "namespace", + expectedName: "podname", + }, + { + name: "invalid path - too short", + path: "/base/shm", + expectError: true, + }, + { + name: "invalid path - only two components", + path: "/namespace/shm", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pid, err := FromShmFilePath(tt.path) + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, pid) + } else { + assert.NoError(t, err) + assert.NotNil(t, pid) + assert.Equal(t, tt.expectedNS, pid.Namespace) + assert.Equal(t, tt.expectedName, pid.Name) + } + }) + } +} + +func TestPodIdentifierToPath(t *testing.T) { + pid := NewPodIdentifier("test-namespace", "test-pod") + path := pid.ToPath("/base") + expected := filepath.Join("/base", "test-namespace", "test-pod") + assert.Equal(t, expected, path) +} + +func TestSharedDeviceStateSetPodMemoryUsed(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + deviceIdx := int(configs[0].DeviceIdx) + + // Test setting memory + success := state.SetPodMemoryUsed(deviceIdx, 1024*1024*1024) + assert.True(t, success) + + // Test setting memory for non-existent device + success = state.SetPodMemoryUsed(999, 1024) + assert.False(t, success) +} + +func TestERLTokenOperations(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + + // Test initial values + assert.Equal(t, 10.0, deviceInfo.GetERLTokenRefillRate()) + assert.Equal(t, 100.0, deviceInfo.GetERLTokenCapacity()) + assert.Equal(t, 100.0, 
deviceInfo.GetERLCurrentTokens()) + + // Test setting values + deviceInfo.SetERLTokenRefillRate(50.0) + deviceInfo.SetERLTokenCapacity(200.0) + deviceInfo.SetERLCurrentTokens(150.0) + + assert.Equal(t, 50.0, deviceInfo.GetERLTokenRefillRate()) + assert.Equal(t, 200.0, deviceInfo.GetERLTokenCapacity()) + assert.Equal(t, 150.0, deviceInfo.GetERLCurrentTokens()) + + // Test LoadERLTokenState + tokens, timestamp := deviceInfo.LoadERLTokenState() + assert.Equal(t, 150.0, tokens) + assert.Equal(t, 0.0, timestamp) // Initial timestamp is 0.0 + + // Test StoreERLTokenState + deviceInfo.StoreERLTokenState(175.0, 12345.0) + tokens, timestamp = deviceInfo.LoadERLTokenState() + assert.Equal(t, 175.0, tokens) + assert.Equal(t, 12345.0, timestamp) + + // Test LoadERLQuota + capacity, rate := deviceInfo.LoadERLQuota() + assert.Equal(t, 200.0, capacity) + assert.Equal(t, 50.0, rate) +} + +func TestFetchAddERLTokens(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + deviceInfo.SetERLTokenCapacity(100.0) + deviceInfo.SetERLCurrentTokens(50.0) + + // Add tokens + before := deviceInfo.FetchAddERLTokens(30.0) + assert.Equal(t, 50.0, before) + assert.Equal(t, 80.0, deviceInfo.GetERLCurrentTokens()) + + // Add tokens that would exceed capacity + before = deviceInfo.FetchAddERLTokens(50.0) + assert.Equal(t, 80.0, before) + assert.Equal(t, 100.0, deviceInfo.GetERLCurrentTokens()) // Capped at capacity +} + +func TestSharedDeviceStateV1Operations(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceStateV1(configs) + require.NoError(t, err) + + assert.Equal(t, 1, state.DeviceCount()) + assert.True(t, state.HasDevice(0)) + assert.False(t, state.HasDevice(1)) + + // Test heartbeat + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) +} + +func TestSharedDeviceStateV2Operations(t *testing.T) { + configs := 
createTestConfigs() + state, err := NewSharedDeviceStateV2(configs) + require.NoError(t, err) + + assert.Equal(t, 1, state.DeviceCount()) + assert.True(t, state.HasDevice(0)) + assert.False(t, state.HasDevice(1)) + + // Test heartbeat + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) +} + +func TestDeviceEntryV1Operations(t *testing.T) { + entry := NewDeviceEntryV1() + + entry.SetUUID("v1-uuid-test") + assert.Equal(t, "v1-uuid-test", entry.GetUUID()) + + assert.False(t, entry.IsActive()) + entry.SetActive(true) + assert.True(t, entry.IsActive()) +} + +func TestSharedMemoryHandleCleanup(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("cleanup_test", "test") + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(testShmBasePath) + }() + + handle, err := CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + + shmPath := filepath.Join(podPath, ShmPathSuffix) + assert.True(t, fileExists(shmPath)) + + // Cleanup + stopAtPath := testShmBasePath + err = handle.Cleanup(&stopAtPath) + assert.NoError(t, err) + + // File should be removed + assert.False(t, fileExists(shmPath)) +} + +// Helper function to check if file exists +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +// Helper function to get pod memory used (needed for tests) +func (s *SharedDeviceState) GetPodMemoryUsed(index int) uint64 { + if s.V1 != nil { + if index >= MaxDevices || !s.V1.Devices[index].IsActive() { + return 0 + } + return atomic.LoadUint64(&s.V1.Devices[index].DeviceInfo.PodMemoryUsed) + } + if index >= MaxDevices || !s.V2.Devices[index].IsActive() { + return 0 + } + return atomic.LoadUint64(&s.V2.Devices[index].DeviceInfo.PodMemoryUsed) +} diff --git a/internal/hypervisor/worker/vram/vram_trap.go b/internal/hypervisor/worker/vram/vram_trap.go new file mode 
100644 index 00000000..15728f5d --- /dev/null +++ b/internal/hypervisor/worker/vram/vram_trap.go @@ -0,0 +1,3 @@ +package worker + +// diff --git a/internal/indexallocator/indexallocator.go b/internal/indexallocator/indexallocator.go index d839589e..dd712e2b 100644 --- a/internal/indexallocator/indexallocator.go +++ b/internal/indexallocator/indexallocator.go @@ -17,11 +17,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" ) -const ( - IndexRangeStart = 1 - IndexRangeEnd = 512 -) - // IndexAllocator manages allocation of 1-512 temporary indices for Pod-to-DevicePlugin communication // Uses a simple atomic counter that increments from 1 to 512, then wraps around to 1 // No bitmap tracking needed - index reuse is acceptable after 512 cycles @@ -91,7 +86,7 @@ func (s *IndexAllocator) AssignIndex(podName string) (int, error) { } // Atomic increment and wrap around next := atomic.AddInt64(&s.currentIndex, 1) - index := int((next-1)%IndexRangeEnd) + IndexRangeStart + index := int((next-1)%constants.IndexRangeEnd) + constants.IndexRangeStart log.FromContext(s.ctx).Info("assigned index successfully", "podName", podName, "index", index) return index, nil } diff --git a/internal/metrics/encoder.go b/internal/metrics/encoder.go index a78fa50c..892e36bc 100644 --- a/internal/metrics/encoder.go +++ b/internal/metrics/encoder.go @@ -37,6 +37,9 @@ type MultiProtocolEncoder struct { } func NewEncoder(encoderType string) Encoder { + if encoderType == "" { + encoderType = config.MetricsFormatInflux + } encoderEnum, exists := stringToEncoderType[encoderType] if !exists { // Default to influx for unknown types diff --git a/internal/scheduler/expander/handler.go b/internal/scheduler/expander/handler.go index 26da438a..77a7cffc 100644 --- a/internal/scheduler/expander/handler.go +++ b/internal/scheduler/expander/handler.go @@ -155,15 +155,15 @@ func (e *NodeExpander) ProcessExpansion(ctx context.Context, pod *corev1.Pod) er gpuNodesPassedOtherFilters, err := 
e.simulateSchedulingWithoutGPU(ctx, pod) if err != nil { e.eventRecorder.Eventf(pod, corev1.EventTypeNormal, "NodeExpansionCheck", - "can not schedule on any nodes even without GPU constraints, manual check required. error: %w", err) - e.logger.Info("Pod schedulable but no GPU nodes available, manual check required", + "can not schedule on any nodes even without GPU constraints, karpenter should take over expansion. error: %w", err) + e.logger.Info("Pod schedulable but no GPU nodes available, karpenter should take over expansion", "namespace", pod.Namespace, "pod", pod.Name, "error", err) return nil } if len(gpuNodesPassedOtherFilters) == 0 { e.eventRecorder.Eventf(pod, corev1.EventTypeNormal, "NodeExpansionCheck", - "can not schedule on any nodes, manual check required, 0 fit nodes") - e.logger.Info("Pod schedulable but no GPU nodes available, manual check required", + "can not schedule on any nodes even without GPU constraints, karpenter should take over expansion, 0 fit nodes") + e.logger.Info("Pod schedulable but no GPU nodes available, karpenter should take over expansion", "namespace", pod.Namespace, "pod", pod.Name) return nil } @@ -417,7 +417,7 @@ func (e *NodeExpander) checkGPUFitWithInflightNodes(pod *corev1.Pod, potentialGp // Get allocation request e.mu.RLock() defer e.mu.RUnlock() - allocRequest, _, err := e.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(e.ctx, pod) if err != nil { return nil, false, true, false } @@ -468,7 +468,7 @@ func (e *NodeExpander) checkGPUFitWithInflightNodes(pod *corev1.Pod, potentialGp } func (e *NodeExpander) checkGPUFitForNewNode(pod *corev1.Pod, gpus []*tfv1.GPU) bool { - allocRequest, _, err := e.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(e.ctx, pod) if err != nil { return false } diff --git a/internal/scheduler/gpuresources/gpuresources.go b/internal/scheduler/gpuresources/gpuresources.go index c3759fad..17e96203 
100644 --- a/internal/scheduler/gpuresources/gpuresources.go +++ b/internal/scheduler/gpuresources/gpuresources.go @@ -128,7 +128,7 @@ func (s *GPUFit) PreFilter(ctx context.Context, state fwk.CycleState, pod *v1.Po // Handle tensor-fusion mode scheduling s.logger.Info("checking GPU node resources for pod", "pod", pod.Name) - allocRequest, reason, err := s.allocator.ComposeAllocationRequest(pod) + allocRequest, reason, err := utils.ComposeAllocationRequest(s.ctx, pod) if err != nil { return nil, fwk.NewStatus(fwk.Error, reason) } @@ -162,6 +162,29 @@ func (s *GPUFit) PreFilter(ctx context.Context, state fwk.CycleState, pod *v1.Po } } + // For partitioned mode, match partition template if not already specified + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID == "" { + matchedGPU, partitionMatch, err := s.allocator.GetMatchedPartition(allocRequest, filteredGPUs) + if err != nil { + metrics.SetSchedulerMetrics(allocRequest.PoolName, false) + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "PartitionTemplateMatchFailed", + "match partition template", "Failed to match partition template: "+err.Error()) + s.logger.Error(err, "failed to match partition template", "pod", pod.Name) + return nil, fwk.NewStatus(fwk.Unschedulable, fmt.Sprintf("no suitable partition template: %v", err)) + } + + // Set partition template ID in alloc request + allocRequest.PartitionTemplateID = partitionMatch.TemplateID + s.logger.Info("Matched partition template in PreFilter", + "pod", pod.Name, + "gpu", matchedGPU.Name, + "template", allocRequest.PartitionTemplateID, + "score", partitionMatch.Score) + + // Update state with the updated alloc request + state.Write(CycleStateAllocateRequest, allocRequest) + } + validNodesValidGPUs := lo.GroupBy(filteredGPUs, func(gpu *tfv1.GPU) string { return gpu.Status.NodeSelector[constants.KubernetesHostNameLabel] }) @@ -424,9 +447,10 @@ func (s *GPUFit) Reserve(ctx context.Context, state 
fwk.CycleState, pod *v1.Pod, } // reserve GPU resources inside memory and asynchronously update GPU custom resource + allocReq := allocRequest.(*tfv1.AllocRequest) _, err = s.allocator.Bind( schedulingResult.FinalGPUs, - allocRequest.(*tfv1.AllocRequest), + allocReq, ) if err != nil { return fwk.NewStatus(fwk.Error, err.Error()) @@ -477,14 +501,40 @@ func (s *GPUFit) PostBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod gpuIDs := strings.Join(gpuSchedulingResult.(*GPUSchedulingStateData).FinalGPUs, ",") s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) - // Patch GPU device IDs annotation - patch := []byte(`[{ - "op": "add", - "path": "/metadata/annotations/` + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation) + `", - "value": "` + gpuIDs + `"}]`) - err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)) + // Build patch operations + patchOps := []map[string]interface{}{ + { + "op": "add", + "path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation), + "value": gpuIDs, + }, + } + + // Add partition template ID annotation if in partitioned mode + allocRequestRaw, err := state.Read(CycleStateAllocateRequest) + if err == nil { + allocRequest := allocRequestRaw.(*tfv1.AllocRequest) + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID != "" { + patchOps = append(patchOps, map[string]interface{}{ + "op": "add", + "path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.PartitionTemplateIDAnnotation), + "value": allocRequest.PartitionTemplateID, + }) + s.logger.Info("Adding partition template ID annotation", "pod", pod.Name, "templateID", allocRequest.PartitionTemplateID) + } + } + + // Convert patch operations to JSON + patchBytes, err := json.Marshal(patchOps) + if err != nil { + s.logger.Error(err, "failed to marshal patch operations", "pod", pod.Name) + return + } + + // Patch pod 
annotations + err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patchBytes)) if err != nil { - s.logger.Error(err, "failed to patch gpu device ids", "pod", pod.Name) + s.logger.Error(err, "failed to patch pod annotations", "pod", pod.Name) s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed", "Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs) } else { @@ -573,7 +623,7 @@ func (s *GPUFit) queueingHint(logger klog.Logger, pod *v1.Pod, oldObj, newObj in } // Compose allocation request for the pod passed in by scheduler framework - allocRequest, _, err := s.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(s.ctx, pod) if err != nil { logger.V(5).Info("Failed to compose allocation request for pod, skip", "pod", klog.KObj(pod), "error", err) diff --git a/internal/utils/compose.go b/internal/utils/compose.go index 5ca775a2..98f4a322 100644 --- a/internal/utils/compose.go +++ b/internal/utils/compose.go @@ -135,6 +135,10 @@ func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo Tens // add inject container annotation for client Pod, in case user doesn't specify it pod.Annotations[constants.InjectContainerAnnotation] = strings.Join(tfInfo.ContainerNames, ",") pod.Annotations[constants.IsolationModeAnnotation] = string(tfInfo.Profile.Isolation) + // add partition template ID if in partitioned mode + if tfInfo.Profile.Isolation == tfv1.IsolationModePartitioned && tfInfo.Profile.PartitionTemplateID != "" { + pod.Annotations[constants.PartitionTemplateIDAnnotation] = tfInfo.Profile.PartitionTemplateID + } } func AppendTFWorkerLabelsAndAnnotationsAfterTemplate( @@ -196,6 +200,10 @@ func AppendTFWorkerLabelsAndAnnotationsAfterTemplate( }), ",") } annotations[constants.IsolationModeAnnotation] = string(workload.Spec.Isolation) + // add partition template ID if in partitioned mode + if workload.Spec.Isolation == 
tfv1.IsolationModePartitioned && workload.Spec.PartitionTemplateID != "" { + annotations[constants.PartitionTemplateIDAnnotation] = workload.Spec.PartitionTemplateID + } return labels, annotations } @@ -449,7 +457,7 @@ func configureFeatures4InjectLib(isLocalGPU bool, disabledFeatures string) []v1. return envList } -func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, pool *tfv1.GPUPool) { +func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, pool *tfv1.GPUPool, compatibleWithNvidiaContainerToolkit bool) { // Hypervisor needs to read /proc to map pod with processID spec.HostPID = true spec.TerminationGracePeriodSeconds = constants.GracefulPeriodSeconds @@ -534,7 +542,7 @@ func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, poo }, }) - composeHypervisorInitContainer(spec, pool) + composeHypervisorInitContainer(spec, pool, compatibleWithNvidiaContainerToolkit) composeHypervisorContainer(spec, pool, enableVector) if enableVector { @@ -542,7 +550,7 @@ func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, poo } } -func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool) { +func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool, compatibleWithNvidiaContainerToolkit bool) { spec.InitContainers = append(spec.InitContainers, v1.Container{ Name: "init-shm", Image: pool.Spec.ComponentConfig.Hypervisor.Image, @@ -559,6 +567,49 @@ func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool) { }, }, }) + + // Add initContainer to wait for NVIDIA Container Toolkit toolkit-ready validation + if compatibleWithNvidiaContainerToolkit { + initContainerImage := pool.Spec.ComponentConfig.Hypervisor.Image + if initContainerImage == "" { + // Use the same image as the main container if not specified + if len(spec.Containers) > 0 { + initContainerImage = spec.Containers[0].Image + } + } + + initContainer := v1.Container{ + Name: 
"toolkit-validation", + Image: initContainerImage, + Command: []string{"sh", "-c"}, + Args: []string{ + "until [ -f /run/nvidia/validations/toolkit-ready ]; do echo waiting for nvidia container stack to be setup; sleep 5; done", + }, + SecurityContext: &v1.SecurityContext{ + Privileged: ptr.To(true), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "run-nvidia-validations", + MountPath: "/run/nvidia/validations", + MountPropagation: ptr.To(v1.MountPropagationHostToContainer), + }, + }, + } + + spec.InitContainers = append(spec.InitContainers, initContainer) + + // Add volume for NVIDIA validations + spec.Volumes = append(spec.Volumes, v1.Volume{ + Name: "run-nvidia-validations", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/run/nvidia/validations", + Type: ptr.To(v1.HostPathDirectoryOrCreate), + }, + }, + }) + } } func composeHypervisorContainer(spec *v1.PodSpec, pool *tfv1.GPUPool, enableVector bool) { diff --git a/internal/utils/config.go b/internal/utils/config.go index 23256dc2..ed8bd192 100644 --- a/internal/utils/config.go +++ b/internal/utils/config.go @@ -127,6 +127,67 @@ func GetEnvOrDefault(key, defaultValue string) string { return defaultValue } +// PodWorkerInfo contains extracted worker information from pod annotations +type PodWorkerInfo struct { + DeviceUUIDs []string + IsolationMode string + MemoryLimitBytes uint64 + ComputeLimitUnits uint32 + TemplateID string +} + +// ExtractPodWorkerInfo extracts worker information from pod annotations +// This is a common utility function used by both GpuAllocator and PodCacheManager +func ExtractPodWorkerInfo(pod *corev1.Pod) PodWorkerInfo { + info := PodWorkerInfo{} + + // Extract GPU device IDs + if gpuIDsStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { + ids := strings.Split(gpuIDsStr, ",") + info.DeviceUUIDs = make([]string, 0, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id != "" { + info.DeviceUUIDs = 
append(info.DeviceUUIDs, id) + } + } + } + + // Extract isolation mode + if isolationMode, exists := pod.Annotations[constants.IsolationModeAnnotation]; exists { + info.IsolationMode = isolationMode + } else { + info.IsolationMode = string(tfv1.IsolationModeSoft) // default + } + + // Extract memory limit (VRAM) + if vramLimit, exists := pod.Annotations[constants.VRAMLimitAnnotation]; exists { + if qty, err := resource.ParseQuantity(vramLimit); err == nil { + info.MemoryLimitBytes = uint64(qty.Value()) + } + } + + // Extract compute limit (compute percent) + if computeLimit, exists := pod.Annotations[constants.ComputeLimitAnnotation]; exists { + if qty, err := resource.ParseQuantity(computeLimit); err == nil { + // Convert to percentage units (e.g., "50" -> 50, "100" -> 100) + percent := qty.AsApproximateFloat64() + info.ComputeLimitUnits = uint32(percent) + } + } + + // Extract template ID (for partitioned mode) + // First check PartitionTemplateIDAnnotation (set by scheduler) + if templateID, exists := pod.Annotations[constants.PartitionTemplateIDAnnotation]; exists { + info.TemplateID = templateID + } else if templateID, exists := pod.Annotations[constants.WorkloadProfileAnnotation]; exists { + // Fallback to WorkloadProfileAnnotation + info.TemplateID = templateID + } + + return info +} + func GetGPUResource(pod *corev1.Pod, isRequest bool) (tfv1.Resource, error) { tflopsKey := constants.TFLOPSRequestAnnotation vramKey := constants.VRAMRequestAnnotation @@ -222,3 +283,16 @@ func GetLeaderIP(client client.Client) string { } return leaderInfo.Data[constants.LeaderInfoConfigMapLeaderIPKey] } + +// only for local development, won't set KUBECONFIG env var in none local environments +func NormalizeKubeConfigEnv() { + cfgPath := os.Getenv("KUBECONFIG") + if cfgPath != "" && strings.HasPrefix(cfgPath, "~") { + home, err := os.UserHomeDir() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + _ = os.Setenv("KUBECONFIG", strings.Replace(cfgPath, "~", home, 1)) + } +} 
diff --git a/internal/utils/reconcile.go b/internal/utils/reconcile.go index ce2138a6..f4376be0 100644 --- a/internal/utils/reconcile.go +++ b/internal/utils/reconcile.go @@ -245,9 +245,12 @@ func IsDesignatedNodePod(pod *corev1.Pod) bool { func GetInitialGPUNodeSelector() []string { selector := os.Getenv("INITIAL_GPU_NODE_LABEL_SELECTOR") if selector == "" { - selector = constants.InitialGPUNodeSelector + return nil } selectors := strings.Split(selector, "=") + if len(selectors) != 2 { + return nil + } return selectors } @@ -265,3 +268,21 @@ func containsGPUResources(res corev1.ResourceList) bool { } return false } + +// AppendEnvVarsIfNotExists appends environment variables to the slice only if they don't already exist (by name). +// It returns the updated slice with new env vars appended. +func AppendEnvVarsIfNotExists(envVars []corev1.EnvVar, newEnvVars ...corev1.EnvVar) []corev1.EnvVar { + existingNames := make(map[string]bool) + for _, env := range envVars { + existingNames[env.Name] = true + } + + for _, newEnv := range newEnvVars { + if !existingNames[newEnv.Name] { + envVars = append(envVars, newEnv) + existingNames[newEnv.Name] = true + } + } + + return envVars +} diff --git a/internal/utils/resource.go b/internal/utils/resource.go index b78f579e..e9b5a328 100644 --- a/internal/utils/resource.go +++ b/internal/utils/resource.go @@ -1,6 +1,7 @@ package utils import ( + context "context" "fmt" "math" "slices" @@ -10,10 +11,14 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/samber/lo" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/log" ) +const MaxGPUCounterPerAllocation = 128 + func GPUResourcesFromAnnotations(annotations map[string]string) (*tfv1.Resources, error) { result := tfv1.Resources{} resInfo := []struct { @@ -73,3 +78,78 @@ func ParseIndicesAnnotation(gpuIndicesStr string) 
([]int32, bool) { }) return gpuIndices, false } + +func ComposeAllocationRequest(ctx context.Context, pod *corev1.Pod) (*tfv1.AllocRequest, string, error) { + // allow Pods with no requests/limits to use TensorFusion, Pod webhook will ensure at least one request/limit is set + gpuRequestResource, err := GetGPUResource(pod, true) + if err != nil { + log.FromContext(ctx).Error(err, "Invalid gpu request annotation", "pod", pod.Name, "namespace", pod.Namespace) + } + gpuLimitResource, err := GetGPUResource(pod, false) + if err != nil { + log.FromContext(ctx).Error(err, "Invalid gpu limit annotation", "pod", pod.Name, "namespace", pod.Namespace) + } + + count := 1 + if gpuCountStr, exists := pod.Annotations[constants.GpuCountAnnotation]; exists { + count, err = strconv.Atoi(gpuCountStr) + if err != nil { + return &tfv1.AllocRequest{}, "invalid gpu count annotation", err + } + } + if count > MaxGPUCounterPerAllocation { + return &tfv1.AllocRequest{}, "gpu count annotation is too large", nil + } + + qosLevel := tfv1.QoSLevel(pod.Annotations[constants.QoSLevelAnnotation]) + if qosLevel == "" { + qosLevel = tfv1.QoSMedium + } + + gpuVendor := pod.Annotations[constants.GpuVendorAnnotation] + + gpuIndices, hasError := ParseIndicesAnnotation(pod.Annotations[constants.GpuIndicesAnnotation]) + if hasError { + return &tfv1.AllocRequest{}, "invalid gpu-indices annotation", + fmt.Errorf("can not parse gpu indices annotation") + } + + // Read isolation mode + isolationMode := tfv1.IsolationModeType(pod.Annotations[constants.IsolationModeAnnotation]) + if isolationMode == "" { + isolationMode = tfv1.IsolationModeSoft + } + + allocRequest := tfv1.AllocRequest{ + PoolName: pod.Annotations[constants.GpuPoolKey], + Request: gpuRequestResource, + Limit: gpuLimitResource, + + Count: uint(count), + GPUModel: pod.Annotations[constants.GPUModelAnnotation], + GPUIndices: gpuIndices, + GPUVendor: gpuVendor, + Isolation: isolationMode, + WorkloadNameNamespace: tfv1.NameNamespace{ + Name: 
pod.Labels[constants.WorkloadKey], + Namespace: pod.Namespace, + }, + PodMeta: pod.ObjectMeta, + QoS: qosLevel, + } + + // Read partition template ID annotation if in partitioned mode + if allocRequest.Isolation == tfv1.IsolationModePartitioned { + if partitionTemplateID, ok := pod.Annotations[constants.PartitionTemplateIDAnnotation]; ok && partitionTemplateID != "" { + allocRequest.PartitionTemplateID = partitionTemplateID + } + } + + // for already allocated workers, set the GPU device IDs for further scaling and retrieval + if gpuIdStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { + gpuIds := strings.SplitSeq(gpuIdStr, ",") + allocRequest.GPUNames = slices.Collect(gpuIds) + } + + return &allocRequest, "", nil +} diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 0841f423..2bf46671 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -302,8 +302,17 @@ func (m *TensorFusionPodMutator) patchTFClient( // Index must be assigned in webhook stage since scheduler cannot modify Pod // This is a special index resource (1-512), not a real device resource // Index is assigned in ascending order (1, 2, 3, ...) 
via distributed lock (leader election) - index := m.assignDeviceAllocationIndex(ctx, pod) - log.FromContext(ctx).Info("assigned device allocation index successfully", "index", index, "pod", pod.Name) + index := 0 + if pod.Annotations[constants.PodIndexAnnotation] == "" { + index = m.assignDeviceAllocationIndex(ctx, pod) + log.FromContext(ctx).Info("assigned device allocation index successfully", "index", index, "pod", pod.Name) + } else { + var err error + index, err = strconv.Atoi(pod.Annotations[constants.PodIndexAnnotation]) + if err != nil { + return nil, fmt.Errorf("invalid pod index annotation: %w", err) + } + } for _, containerIndex := range containerIndices { container := &pod.Spec.Containers[containerIndex] diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index 0066b442..c4adf622 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -106,6 +106,13 @@ func ParseTensorFusionInfo( workloadProfile.Spec.Isolation = tfv1.IsolationModeSoft } + // Read partition template ID annotation if in partitioned mode + if workloadProfile.Spec.Isolation == tfv1.IsolationModePartitioned { + if partitionTemplateID, ok := pod.Annotations[constants.PartitionTemplateIDAnnotation]; ok && partitionTemplateID != "" { + workloadProfile.Spec.PartitionTemplateID = partitionTemplateID + } + } + workerPodTemplate, ok := pod.Annotations[constants.WorkerPodTemplateAnnotation] if ok && workerPodTemplate != "" { if workloadProfile.Spec.IsLocalGPU { diff --git a/provider/Makefile b/provider/Makefile new file mode 100644 index 00000000..c1ad8680 --- /dev/null +++ b/provider/Makefile @@ -0,0 +1,89 @@ +# Makefile for building accelerator libraries +# Supports both stub and vendor-specific implementations (NVIDIA, Ascend, etc.) 
+ +CC ?= gcc +CFLAGS ?= -Wall -Wextra -std=c11 -fPIC -O2 +LDFLAGS ?= -shared + +# Directories +PROVIDER_DIR := $(shell pwd) +STUB_DIR := $(PROVIDER_DIR)/stub +ASCEND_DIR := $(PROVIDER_DIR)/ascend +BUILD_DIR := $(PROVIDER_DIR)/build +TEST_DIR := $(PROVIDER_DIR)/test + +# Output libraries +STUB_LIB := $(BUILD_DIR)/libaccelerator_stub.so +ASCEND_LIB := $(BUILD_DIR)/libaccelerator_ascend.so + +# Source files +STUB_SRC := $(STUB_DIR)/accelerator.c +ASCEND_SRC := $(ASCEND_DIR)/accelerator.c + +# Object files +STUB_OBJ := $(BUILD_DIR)/accelerator_stub.o +ASCEND_OBJ := $(BUILD_DIR)/accelerator_ascend.o + +# Test executables +TEST_BIN := $(BUILD_DIR)/test_accelerator + +.PHONY: all clean stub ascend test install + +all: stub + +# Build stub implementation +stub: $(STUB_LIB) + +$(STUB_LIB): $(STUB_OBJ) | $(BUILD_DIR) + $(CC) $(LDFLAGS) -o $@ $< + +$(STUB_OBJ): $(STUB_SRC) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) -c -o $@ $< + +# Build Ascend implementation (requires Ascend CANN SDK) +ascend: $(ASCEND_LIB) + +$(ASCEND_LIB): $(ASCEND_OBJ) | $(BUILD_DIR) + $(CC) $(LDFLAGS) -o $@ $< $(ASCEND_LDFLAGS) + +$(ASCEND_OBJ): $(ASCEND_SRC) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) $(ASCEND_CFLAGS) -c -o $@ $< + +# Build test executable +test: $(TEST_BIN) + +$(TEST_BIN): $(TEST_DIR)/test_accelerator.c $(STUB_LIB) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) -o $@ $(TEST_DIR)/test_accelerator.c -L$(BUILD_DIR) -laccelerator_stub -Wl,-rpath,$(BUILD_DIR) + +# Run tests +test-run: test + LD_LIBRARY_PATH=$(BUILD_DIR):$$LD_LIBRARY_PATH $(TEST_BIN) + +# Create build directory +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +# Clean build artifacts +clean: + rm -rf $(BUILD_DIR) + +# Install libraries to system path (optional) +install: $(STUB_LIB) + install -d /usr/local/lib/tensor-fusion + install -m 755 $(STUB_LIB) /usr/local/lib/tensor-fusion/ + install -d /usr/local/include/tensor-fusion + install -m 644 $(PROVIDER_DIR)/accelerator.h /usr/local/include/tensor-fusion/ + 
install -m 644 $(PROVIDER_DIR)/limiter.h /usr/local/include/tensor-fusion/ + +# Help target +help: + @echo "Available targets:" + @echo " all - Build stub implementation (default)" + @echo " stub - Build stub accelerator library" + @echo " ascend - Build Ascend accelerator library (requires CANN SDK)" + @echo " test - Build test executable" + @echo " test-run - Build and run tests" + @echo " clean - Remove build artifacts" + @echo " install - Install libraries to system path" + @echo " help - Show this help message" + diff --git a/provider/README.md b/provider/README.md new file mode 100644 index 00000000..d6a7ffb5 --- /dev/null +++ b/provider/README.md @@ -0,0 +1,129 @@ +# Accelerator Provider Interface + +This directory contains the abstract ABI (Application Binary Interface) for vGPU vendor accelerator libraries. + +## Overview + +The accelerator interface abstracts vGPU vendor-specific implementations into a unified API, supporting four isolation modes: + +- **Shared Mode**: Oversubscription, high elasticity, no resource control (equivalent to NVIDIA timeslicing) +- **Soft Mode**: Oversubscription, high elasticity, time-sharing resource control via hooks and limiter +- **Hard Mode**: No oversubscription, medium elasticity, space-sharing via one-time resource limits +- **Partitioned Mode**: No oversubscription, low elasticity, hardware/driver-level partitioning (e.g., MIG) + +## Structure + +``` +provider/ +├── accelerator.h # Main interface definition +├── limiter.h # Limiter.so API (not vendor-implemented) +├── Makefile # Build scripts +├── stub/ +│ └── accelerator.c # Stub implementation for testing +├── ascend/ +│ └── accelerator.c # Huawei Ascend implementation +└── test/ + └── test_accelerator.c # Test suite +``` + +## Building + +### Build Stub Implementation + +```bash +cd provider +make stub +``` + +### Build Ascend Implementation + +```bash +cd provider +make ascend +``` + +### Run Tests + +```bash +cd provider +make test-run +``` + +## Interface 
Categories + +### 1. DeviceInfo APIs + +- `getDeviceInfo()`: Get device information (capabilities, basic info, NUMA, etc.) +- `getPartitionTemplates()`: Get hardware partition templates (e.g., MIG) +- `getDeviceTopology()`: Get device topology (NVLink, IB NIC, etc.) + +### 2. Virtualization APIs + +#### Partitioned Isolation +- `assignPartition()`: Assign hardware partition (returns partitionOverhead) +- `removePartition()`: Remove partition + +#### Hard Isolation +- `setMemHardLimit()`: Set hard memory limit (one-time) +- `setComputeUnitHardLimit()`: Set hard compute limit (one-time) + +#### Snapshot/Migration +- `snapshot()`: Snapshot device state for processes +- `resume()`: Resume device state for processes + +### 3. Metrics APIs + +- `getProcessComputeUtilization()`: Get compute utilization per process +- `getProcessMemoryUtilization()`: Get memory utilization per process +- `getDeviceMetrics()`: Get basic device metrics (power, PCIe, SM active, TC usage) +- `getExtendedDeviceMetrics()`: Get extended metrics (NVLink bandwidth, etc.) + +## Vendor Implementations + +### Stub Implementation + +The stub implementation (`stub/accelerator.c`) provides a reference implementation for testing and development. + +### Ascend Implementation + +The Ascend implementation (`ascend/accelerator.c`) provides support for Huawei Ascend accelerators: + +- Supports Soft and Hard isolation modes +- Does not support hardware partitioning (MIG-like features) +- Uses HCCS (Huawei Cache Coherent System) for device interconnects +- Typical device: Ascend 910 with 32GB memory, 2 AI cores, 320 TFLOPS (FP16) + +## Usage in Hypervisor + +The hypervisor uses the accelerator library via CGO bindings: + +```go +import "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + +mgr, err := device.NewManager("path/to/libaccelerator.so", 30*time.Second) +``` + +See `internal/hypervisor/device/` for the Go bindings and device manager implementation. 
+ +## Testing + +All tests pass successfully: + +```bash +$ make test-run +======================================== +Accelerator Library Test Suite +======================================== +Total tests: 47 +Passed: 47 +Failed: 0 +All tests passed! ✓ +``` + +## Notes + +- All struct parameters are carefully designed with key attributes +- Memory management: Use provided cleanup functions to free allocated memory +- Thread safety: Vendor implementations should be thread-safe +- Error handling: All APIs return Result enum for error handling + diff --git a/provider/accelerator.h b/provider/accelerator.h new file mode 100644 index 00000000..386d6de3 --- /dev/null +++ b/provider/accelerator.h @@ -0,0 +1,405 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef ACCELERATOR_H
+#define ACCELERATOR_H
+
+#include <stdint.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <sys/types.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// ============================================================================
+// Common Types
+// ============================================================================
+
+typedef enum {
+    RESULT_SUCCESS = 0,
+    RESULT_ERROR_INVALID_PARAM = 1,
+    RESULT_ERROR_NOT_FOUND = 2,
+    RESULT_ERROR_NOT_SUPPORTED = 3,
+    RESULT_ERROR_RESOURCE_EXHAUSTED = 4,
+    RESULT_ERROR_OPERATION_FAILED = 5,
+    RESULT_ERROR_INTERNAL = 6
+} Result;
+
+typedef enum {
+    ISOLATION_MODE_SHARED = 0,      // Timeslicing, no resource control
+    ISOLATION_MODE_SOFT = 1,        // Hook-based, token-based limiting
+    ISOLATION_MODE_HARD = 2,        // One-time resource limits
+    ISOLATION_MODE_PARTITIONED = 3  // Hardware/driver-level partitioning (MIG)
+} IsolationMode;
+
+// ============================================================================
+// DeviceInfo Types
+// ============================================================================
+
+// Device capabilities
+typedef struct {
+    bool supportsPartitioning;      // e.g., MIG support
+    bool supportsSoftIsolation;     // Hook-based isolation support
+    bool supportsHardIsolation;     // One-time limit support
+    bool supportsSnapshot;          // Process snapshot/resume support
+    bool supportsMetrics;           // Metrics collection support
+    uint32_t maxPartitions;         // Maximum number of partitions
+    uint32_t maxWorkersPerDevice;   // Maximum workers per device
+} DeviceCapabilities;
+
+// Basic device information
+typedef struct {
+    char uuid[64];                  // Device UUID
+    char vendor[32];                // Vendor name (e.g., "NVIDIA", "AMD")
+    char model[128];                // Model name (e.g., "A100", "H100")
+    char driverVersion[64];         // Driver version
+    char firmwareVersion[64];       // Firmware version
+    int32_t index;                  // Device index
+    int32_t numaNode;               // NUMA node ID (-1 if not assigned)
+    uint64_t totalMemoryBytes;      // Total memory in bytes
+    uint64_t totalComputeUnits;     // Total
compute units (e.g., SMs for NVIDIA) + double maxTflops; // Maximum TFLOPS + uint32_t pcieGen; // PCIe generation + uint32_t pcieWidth; // PCIe width (lanes) +} DeviceBasicInfo; + +// Device properties +typedef struct { + uint32_t clockGraphics; // Graphics clock (MHz) + uint32_t clockSM; // SM clock (MHz) - for NVIDIA + uint32_t clockMem; // Memory clock (MHz) + uint32_t clockAI; // AI core clock (MHz) - for Ascend + uint32_t powerLimit; // Power limit (W) + uint32_t temperatureThreshold; // Temperature threshold (C) + bool eccEnabled; // ECC enabled + bool persistenceModeEnabled; // Persistence mode + char computeCapability[16]; // Compute capability (e.g., "8.0", "9.0" for NVIDIA, "Ascend310" for Ascend) + char chipType[32]; // Chip type (e.g., "NVIDIA", "Ascend", "AMD") +} DeviceProperties; + +// Related device information (for topology) +typedef struct { + char deviceUUID[64]; // Related device UUID + char connectionType[32]; // Connection type (e.g., "NVLink", "PCIe", "IB") + uint32_t bandwidthMBps; // Bandwidth in MB/s + uint32_t latencyNs; // Latency in nanoseconds +} RelatedDevice; + +// Extended device information +typedef struct { + DeviceBasicInfo basic; + DeviceProperties props; + RelatedDevice* relatedDevices; // Array of related devices + size_t relatedDeviceCount; // Number of related devices + DeviceCapabilities capabilities; +} ExtendedDeviceInfo; + +// Partition template for hardware partitioning (e.g., MIG) +typedef struct { + char templateId[64]; // Template identifier + char name[128]; // Human-readable name + uint64_t memoryBytes; // Memory allocated to partition + uint64_t computeUnits; // Compute units allocated + double tflops; // TFLOPS for this partition + uint32_t sliceCount; // Number of slices (for MIG) + bool isDefault; // Is this a default template + char description[256]; // Description +} PartitionTemplate; + +// Device topology information +typedef struct { + char deviceUUID[64]; // Device UUID + int32_t numaNode; // NUMA node + 
RelatedDevice* connections; // Array of connections + size_t connectionCount; // Number of connections +} DeviceTopology; + +// Extended topology (includes NVLink, IB NIC, etc.) +typedef struct { + DeviceTopology* devices; // Array of device topologies + size_t deviceCount; // Number of devices + uint32_t nvlinkBandwidthMBps; // NVLink total bandwidth + uint32_t ibNicCount; // InfiniBand NIC count + char topologyType[32]; // Topology type (e.g., "NVLink", "PCIe") +} ExtendedDeviceTopology; + +// ============================================================================ +// Virtualization Types +// ============================================================================ + +// Partition assignment request +typedef struct { + char templateId[64]; // Template ID to use + char deviceUUID[64]; // Target device UUID + char partitionUUID[64]; // Output: assigned partition UUID + uint64_t partitionOverheadBytes; // Memory overhead for partition (output) +} PartitionAssignment; + +// Worker information for isolation +typedef struct { + char workerId[64]; // Worker identifier + char deviceUUID[64]; // Device UUID + pid_t processId; // Process ID + uint64_t memoryLimitBytes; // Memory limit (for hard isolation) + uint32_t computeUnitLimit; // Compute unit limit (for hard isolation) + IsolationMode isolationMode; // Isolation mode +} WorkerInfo; + +// Process array for snapshot/resume +typedef struct { + pid_t* processIds; // Array of process IDs + size_t processCount; // Number of processes + char deviceUUID[64]; // Device UUID +} ProcessArray; + +// ============================================================================ +// Metrics Types +// ============================================================================ + +// Compute utilization +typedef struct { + char processId[32]; // Process ID as string + char deviceUUID[64]; // Device UUID + double utilizationPercent; // Utilization percentage (0-100) + uint64_t activeSMs; // Active SMs/Compute Units + uint64_t 
totalSMs; // Total SMs/Compute Units + double tflopsUsed; // TFLOPS currently used +} ComputeUtilization; + +// Memory utilization +typedef struct { + char processId[32]; // Process ID as string + char deviceUUID[64]; // Device UUID + uint64_t usedBytes; // Memory used in bytes + uint64_t reservedBytes; // Memory reserved in bytes + double utilizationPercent; // Utilization percentage (0-100) +} MemoryUtilization; + +// Basic device metrics +typedef struct { + char deviceUUID[64]; // Device UUID + double powerUsageWatts; // Current power usage (W) + double temperatureCelsius; // Temperature (C) + uint64_t pcieRxBytes; // PCIe RX bytes + uint64_t pcieTxBytes; // PCIe TX bytes + uint32_t smActivePercent; // SM active percentage + uint32_t tensorCoreUsagePercent; // Tensor Core usage percentage + uint64_t memoryUsedBytes; // Memory used + uint64_t memoryTotalBytes; // Memory total +} DeviceMetrics; + +// Extended device metrics (NVLink, etc.) +typedef struct { + char deviceUUID[64]; // Device UUID + uint32_t* nvlinkBandwidthMBps; // NVLink bandwidth per link (MB/s) + size_t nvlinkCount; // Number of NVLink connections + uint64_t* ibNicBandwidthMBps; // IB NIC bandwidth per NIC (MB/s) + size_t ibNicCount; // Number of IB NICs + uint32_t* pcieBandwidthMBps; // PCIe bandwidth per link (MB/s) + size_t pcieLinkCount; // Number of PCIe links +} ExtendedDeviceMetrics; + +// ============================================================================ +// DeviceInfo APIs +// ============================================================================ + +/** + * Get the number of available devices. + * + * @param deviceCount Output parameter for number of devices + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetDeviceCount(size_t* deviceCount); + +/** + * Get all available devices information. 
+ * + * @param devices Output buffer for device information (allocated by caller) + * @param maxCount Maximum number of devices that can fit in the buffer + * @param deviceCount Output parameter for number of devices actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetAllDevices(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount); + +/** + * Get device topology including NVLink, IB NIC, and other interconnects. + * + * @param deviceIndexArray Array of device indices to query + * @param deviceCount Number of devices in array + * @param topology Output parameter for extended topology (allocated by caller) + * @param maxConnectionsPerDevice Maximum number of connections per device in topology buffer + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetDeviceTopology(int32_t* deviceIndexArray, size_t deviceCount, ExtendedDeviceTopology* topology, size_t maxConnectionsPerDevice); + +// ============================================================================ +// Virtualization APIs - Partitioned Isolation +// ============================================================================ + +/** + * Assign a partition to a device using a template (e.g., create MIG instance). + * + * @param assignment Partition assignment request (templateId, deviceUUID) + * Output: partitionUUID and partitionOverheadBytes + * @return true on success, false otherwise + */ +bool AssignPartition(PartitionAssignment* assignment); + +/** + * Remove a partition from a device. 
+ * + * @param templateId Template ID used to create the partition + * @param deviceUUID Device UUID + * @return true on success, false otherwise + */ +bool RemovePartition(const char* templateId, const char* deviceUUID); + +// ============================================================================ +// Virtualization APIs - Hard Isolation +// ============================================================================ + +/** + * Set hard memory limit for a worker (one-time, called at worker start by limiter.so). + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param memoryLimitBytes Memory limit in bytes + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result SetMemHardLimit(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes); + +/** + * Set hard compute unit limit for a worker (one-time, called at worker start). + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param computeUnitLimit Compute unit limit (e.g., percentage 0-100) + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result SetComputeUnitHardLimit(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit); + +// ============================================================================ +// Virtualization APIs - Device Snapshot/Migration +// ============================================================================ + +/** + * Snapshot device state for processes (lock processes, checkpoint state). + * Called from hypervisor for migration. + * + * @param processes Array of processes to snapshot + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Snapshot(ProcessArray* processes); + +/** + * Resume device state for processes (unlock processes, restore state). + * Called from hypervisor after migration. 
+ * + * @param processes Array of processes to resume + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Resume(ProcessArray* processes); + +// ============================================================================ +// Metrics APIs +// ============================================================================ + +/** + * Get compute utilization for all processes on all devices. + * + * @param utilizations Output buffer for compute utilizations (allocated by caller) + * @param maxCount Maximum number of utilizations that can fit in the buffer + * @param utilizationCount Output parameter for number of utilizations actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetProcessComputeUtilization( + ComputeUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +); + +/** + * Get memory utilization for all processes on all devices. + * + * @param utilizations Output buffer for memory utilizations (allocated by caller) + * @param maxCount Maximum number of utilizations that can fit in the buffer + * @param utilizationCount Output parameter for number of utilizations actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetProcessMemoryUtilization( + MemoryUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +); + +/** + * Get basic device metrics (power, PCIe, SM active, TC usage, etc.). + * + * @param deviceUUIDArray Array of device UUIDs + * @param deviceCount Number of devices + * @param metrics Output buffer for device metrics (allocated by caller, size >= deviceCount) + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + DeviceMetrics* metrics +); + +/** + * Get extended device metrics (NVLink bandwidth, etc.). 
+ * + * @param deviceUUIDArray Array of device UUIDs + * @param deviceCount Number of devices + * @param metrics Output buffer for extended device metrics (allocated by caller, size >= deviceCount) + * @param maxNvlinkPerDevice Maximum number of NVLink connections per device + * @param maxIbNicPerDevice Maximum number of IB NICs per device + * @param maxPciePerDevice Maximum number of PCIe links per device + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetExtendedDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + ExtendedDeviceMetrics* metrics, + size_t maxNvlinkPerDevice, + size_t maxIbNicPerDevice, + size_t maxPciePerDevice +); + +// ============================================================================ +// Utility APIs +// ============================================================================ + +/** + * Log a message (for debugging and diagnostics). + * + * @param level Log level (e.g., "DEBUG", "INFO", "WARN", "ERROR") + * @param message Log message + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Log(const char* level, const char* message); + +#ifdef __cplusplus +} +#endif + +// Include limiter.h after defining Result enum +#include "limiter.h" + +#endif // ACCELERATOR_H + diff --git a/provider/ascend/accelerator.c b/provider/ascend/accelerator.c new file mode 100644 index 00000000..19409576 --- /dev/null +++ b/provider/ascend/accelerator.c @@ -0,0 +1,387 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../accelerator.h"
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdbool.h>
+#include <unistd.h>
+#include <signal.h>
+
+// Ascend CANN API headers (when available)
+// #include "acl/acl.h"
+// For now, we'll use stub implementations that match Ascend behavior
+
+// ============================================================================
+// Ascend Implementation - DeviceInfo APIs
+// ============================================================================
+
+Result GetDeviceCount(size_t* deviceCount) {
+    if (!deviceCount) {
+        return RESULT_ERROR_INVALID_PARAM;
+    }
+
+    // TODO: Use actual Ascend CANN API when available
+    // uint32_t deviceCount;
+    // aclError ret = aclrtGetDeviceCount(&deviceCount);
+
+    // Stub: return 2 devices
+    *deviceCount = 2;
+    return RESULT_SUCCESS;
+}
+
+// Helper function to initialize a single device info
+static void initDeviceInfo(ExtendedDeviceInfo* info, int32_t deviceIndex) {
+    // Initialize basic info for Ascend device
+    snprintf(info->basic.uuid, sizeof(info->basic.uuid), "ascend-device-%d", deviceIndex);
+    snprintf(info->basic.vendor, sizeof(info->basic.vendor), "Huawei");
+    snprintf(info->basic.model, sizeof(info->basic.model), "Ascend-910");
+    snprintf(info->basic.driverVersion, sizeof(info->basic.driverVersion), "CANN-7.0");
+    snprintf(info->basic.firmwareVersion, sizeof(info->basic.firmwareVersion), "1.0.0");
+    info->basic.index = deviceIndex;
+    info->basic.numaNode = deviceIndex % 2;  // Stub: alternate NUMA nodes
+    info->basic.totalMemoryBytes = 32ULL * 1024 * 1024 * 1024;  // 32GB (Ascend 910)
+    info->basic.totalComputeUnits = 2;  // Ascend uses AI cores, typically 2 per chip
+    info->basic.maxTflops = 320.0;  // Ascend 910: 320 TFLOPS (FP16)
+    info->basic.pcieGen = 4;
+    info->basic.pcieWidth = 16;
+
+    // Initialize properties for Ascend
+    info->props.clockGraphics = 0;  // Not applicable for Ascend
+    info->props.clockSM = 0;        // Not
applicable for Ascend + info->props.clockMem = 1200; // MHz + info->props.clockAI = 1000; // AI core clock (MHz) - Ascend specific + info->props.powerLimit = 310; // W (Ascend 910) + info->props.temperatureThreshold = 85; // C + info->props.eccEnabled = true; + info->props.persistenceModeEnabled = false; + snprintf(info->props.computeCapability, sizeof(info->props.computeCapability), "Ascend910"); + snprintf(info->props.chipType, sizeof(info->props.chipType), "Ascend"); + + // Initialize capabilities + // Ascend typically doesn't support hardware partitioning like MIG + info->capabilities.supportsPartitioning = false; + info->capabilities.supportsSoftIsolation = true; + info->capabilities.supportsHardIsolation = true; + info->capabilities.supportsSnapshot = true; + info->capabilities.supportsMetrics = true; + info->capabilities.maxPartitions = 0; // No hardware partitioning + info->capabilities.maxWorkersPerDevice = 32; // Higher than NVIDIA due to different architecture + + // Initialize related devices (stub: no related devices) + info->relatedDevices = NULL; + info->relatedDeviceCount = 0; +} + +Result GetAllDevices(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount) { + if (!devices || !deviceCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Use actual Ascend CANN API when available + // uint32_t deviceCount; + // aclError ret = aclrtGetDeviceCount(&deviceCount); + + // Stub: return 2 devices (but not more than maxCount) + size_t actualCount = 2; + if (actualCount > maxCount) { + actualCount = maxCount; + } + *deviceCount = actualCount; + + // Initialize each device + for (size_t i = 0; i < actualCount; i++) { + initDeviceInfo(&devices[i], (int32_t)i); + } + + return RESULT_SUCCESS; +} + +Result GetPartitionTemplates(int32_t deviceIndex __attribute__((unused)), PartitionTemplate* templates, size_t maxCount, size_t* templateCount) { + if (!templates || !templateCount || maxCount == 0) { + return 
RESULT_ERROR_INVALID_PARAM; + } + + // Ascend doesn't support hardware partitioning like MIG + *templateCount = 0; + return RESULT_SUCCESS; +} + +Result GetDeviceTopology(int32_t* deviceIndexArray, size_t deviceCount, ExtendedDeviceTopology* topology, size_t maxConnectionsPerDevice) { + if (!deviceIndexArray || deviceCount == 0 || !topology || maxConnectionsPerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Note: topology->devices must be pre-allocated by caller with size >= deviceCount + // topology->devices[i].connections must be pre-allocated by caller with size >= maxConnectionsPerDevice + if (!topology->devices) { + return RESULT_ERROR_INVALID_PARAM; + } + topology->deviceCount = deviceCount; + + // Initialize each device topology + for (size_t i = 0; i < deviceCount; i++) { + DeviceTopology* dt = &topology->devices[i]; + snprintf(dt->deviceUUID, sizeof(dt->deviceUUID), "ascend-device-%d", deviceIndexArray[i]); + dt->numaNode = deviceIndexArray[i] % 2; + + // Ascend devices typically connect via PCIe or HCCS (Huawei Cache Coherent System) + size_t connectionCount = (deviceCount > 1) ? 
(deviceCount - 1) : 0; + if (connectionCount > maxConnectionsPerDevice) { + connectionCount = maxConnectionsPerDevice; + } + + if (connectionCount > 0 && dt->connections) { + dt->connectionCount = connectionCount; + + size_t connIdx = 0; + for (size_t j = 0; j < deviceCount && connIdx < connectionCount; j++) { + if (j != i) { + RelatedDevice* rd = &dt->connections[connIdx]; + snprintf(rd->deviceUUID, sizeof(rd->deviceUUID), "ascend-device-%d", deviceIndexArray[j]); + snprintf(rd->connectionType, sizeof(rd->connectionType), "HCCS"); // Huawei Cache Coherent System + rd->bandwidthMBps = 200000; // 200 GB/s (stub) + rd->latencyNs = 150; // 150ns (stub) + connIdx++; + } + } + } else { + dt->connections = NULL; + dt->connectionCount = 0; + } + } + + // Set extended topology info + topology->nvlinkBandwidthMBps = 0; // Not applicable for Ascend + topology->ibNicCount = 0; // Stub: no IB NICs + snprintf(topology->topologyType, sizeof(topology->topologyType), "HCCS"); + + return RESULT_SUCCESS; +} + +// ============================================================================ +// Ascend Implementation - Virtualization APIs - Partitioned Isolation +// ============================================================================ + +bool AssignPartition(PartitionAssignment* assignment) { + if (!assignment || assignment->templateId[0] == '\0' || assignment->deviceUUID[0] == '\0') { + return false; + } + + // Ascend doesn't support hardware partitioning + return false; +} + +bool RemovePartition(const char* templateId, const char* deviceUUID) { + if (!templateId || !deviceUUID) { + return false; + } + + // Ascend doesn't support hardware partitioning + return false; +} + +// ============================================================================ +// Ascend Implementation - Virtualization APIs - Hard Isolation +// ============================================================================ + +Result SetMemHardLimit(const char* workerId, const char* deviceUUID, uint64_t 
memoryLimitBytes) { + if (!workerId || !deviceUUID || memoryLimitBytes == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Use Ascend CANN API to set memory limit + // aclrtSetDevice(deviceIndex); + // aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + + // Stub: always succeed + return RESULT_SUCCESS; +} + +Result SetComputeUnitHardLimit(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit) { + if (!workerId || !deviceUUID || computeUnitLimit == 0 || computeUnitLimit > 100) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Use Ascend CANN API to set compute unit limit + // This might involve setting AI core allocation + + // Stub: always succeed + return RESULT_SUCCESS; +} + +// ============================================================================ +// Ascend Implementation - Virtualization APIs - Device Snapshot/Migration +// ============================================================================ + +Result Snapshot(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: verify processes exist (basic check) + for (size_t i = 0; i < processes->processCount; i++) { + if (kill(processes->processIds[i], 0) != 0) { + // Process doesn't exist or no permission + return RESULT_ERROR_NOT_FOUND; + } + } + + // TODO: Use Ascend CANN API to snapshot device context + // This would involve saving device memory state, context, etc. + + // Stub: always succeed (no actual snapshot implementation) + return RESULT_SUCCESS; +} + +Result Resume(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Use Ascend CANN API to resume device context + // This would involve restoring device memory state, context, etc. 
+ + // Stub: always succeed (no actual resume implementation) + return RESULT_SUCCESS; +} + +// ============================================================================ +// Ascend Implementation - Metrics APIs +// ============================================================================ + +Result GetProcessComputeUtilization( + ComputeUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // TODO: Use Ascend CANN API or ascend-toolkit to get actual metrics + // aclprofGetDeviceUtilizationRate() + // For now, stub implementation returns empty + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetProcessMemoryUtilization( + MemoryUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // TODO: Use Ascend CANN API to get actual memory usage + // aclrtGetMemInfo() + // For now, stub implementation returns empty + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + DeviceMetrics* metrics +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Use Ascend CANN API or ascend-toolkit to get actual metrics + // aclrtGetDeviceUtilizationRate() + // ascend-toolkit: npu-smi info + + // Fill stub data + for (size_t i = 0; i < deviceCount; i++) { + DeviceMetrics* dm = &metrics[i]; + snprintf(dm->deviceUUID, sizeof(dm->deviceUUID), "%s", deviceUUIDArray[i]); + dm->powerUsageWatts = 250.0 + (i * 20.0); // Stub: 250-270W + dm->temperatureCelsius = 50.0 + (i * 5.0); // Stub: 50-55C + dm->pcieRxBytes = 2ULL * 1024 * 1024 * 1024 * (i + 1); // Stub: 2-4GB + 
dm->pcieTxBytes = 1ULL * 1024 * 1024 * 1024 * (i + 1); // Stub: 1-2GB + dm->smActivePercent = 60 + (i * 10); // Stub: 60-80% (AI core active) + dm->tensorCoreUsagePercent = 0; // Not applicable for Ascend + dm->memoryUsedBytes = 16ULL * 1024 * 1024 * 1024; // Stub: 16GB + dm->memoryTotalBytes = 32ULL * 1024 * 1024 * 1024; // Stub: 32GB + } + + return RESULT_SUCCESS; +} + +Result GetExtendedDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + ExtendedDeviceMetrics* metrics, + size_t maxNvlinkPerDevice, + size_t maxIbNicPerDevice, + size_t maxPciePerDevice +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics || + maxNvlinkPerDevice == 0 || maxIbNicPerDevice == 0 || maxPciePerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Fill stub data + // Note: metrics[i].nvlinkBandwidthMBps, ibNicBandwidthMBps, pcieBandwidthMBps + // must be pre-allocated by caller with appropriate sizes + for (size_t i = 0; i < deviceCount; i++) { + ExtendedDeviceMetrics* edm = &metrics[i]; + snprintf(edm->deviceUUID, sizeof(edm->deviceUUID), "%s", deviceUUIDArray[i]); + + // Ascend doesn't have NVLink, but may have HCCS connections + edm->nvlinkCount = 0; + edm->nvlinkBandwidthMBps = NULL; + + // Stub: 2 HCCS connections per device (but not IB) + edm->ibNicCount = 0; // Not IB, but HCCS + edm->ibNicBandwidthMBps = NULL; + + // Stub: 1 PCIe link (but not more than max) + edm->pcieLinkCount = 1; + if (edm->pcieLinkCount > maxPciePerDevice) { + edm->pcieLinkCount = maxPciePerDevice; + } + if (edm->pcieBandwidthMBps && edm->pcieLinkCount > 0) { + edm->pcieBandwidthMBps[0] = 32000; // Stub: 32 GB/s (PCIe 4.0 x16) + } + } + + return RESULT_SUCCESS; +} + +// ============================================================================ +// Ascend Implementation - Utility APIs +// ============================================================================ + +Result Log(const char* level, const char* message) { + if (!level || !message) { + return 
RESULT_ERROR_INVALID_PARAM;
+    }
+
+    // Stub: print to stderr
+    fprintf(stderr, "[%s] %s\n", level, message);
+    fflush(stderr);
+
+    return RESULT_SUCCESS;
+}
+
diff --git a/provider/limiter.h b/provider/limiter.h
new file mode 100644
index 00000000..681a0ec2
--- /dev/null
+++ b/provider/limiter.h
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2024.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIMITER_H
+#define LIMITER_H
+
+#include <stdint.h>
+#include <stdbool.h>
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// ============================================================================
+// Limiter Types
+// ============================================================================
+
+// Memory operation record
+typedef struct {
+    char deviceUUID[64];        // Device UUID
+    int64_t bytesDiff;          // Bytes difference (positive = allocation, negative = deallocation)
+    bool shouldBlock;           // Output: whether this operation should be blocked
+    uint64_t availableBytes;    // Output: available bytes after this operation
+} MemoryOpRecord;
+
+// Compute operation record
+typedef struct {
+    char deviceUUID[64];        // Device UUID
+    uint64_t computeTokens;     // Compute tokens consumed (e.g., SM-cycles)
+    bool shouldBlock;           // Output: whether this operation should be blocked
+    uint64_t availableTokens;   // Output: available tokens after this operation
+} ComputeOpRecord;
+
+// Worker freeze state
+typedef struct {
+    char workerId[64];          // Worker identifier
+    bool isFrozen;              // Current freeze
state + uint64_t freezeTimeMs; // Time frozen in milliseconds +} WorkerFreezeState; + +// ============================================================================ +// Limiter APIs (Implemented by limiter.so, NOT by vendor accelerator.so) +// ============================================================================ + +/** + * Check and record memory operations for soft isolation. + * This API is called from hooks in CUDA runtime (via dlsym replacement). + * + * @param processId Process identifier + * @param deviceUUID Device UUID + * @param bytesDiff Bytes difference (positive = allocation, negative = deallocation) + * @param record Output parameter for operation record + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result CheckAndRecordMemoryOps(const char* processId, const char* deviceUUID, int64_t bytesDiff, MemoryOpRecord* record); + +/** + * Check and record compute operations for soft isolation. + * This API is called from hooks in CUDA runtime (via dlsym replacement). + * + * @param processId Process identifier + * @param deviceUUID Device UUID + * @param computeTokens Compute tokens consumed (e.g., SM-cycles) + * @param record Output parameter for operation record + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result CheckAndRecordComputeOps(const char* processId, const char* deviceUUID, uint64_t computeTokens, ComputeOpRecord* record); + +/** + * Freeze a worker process (pause execution when resource limit reached). + * This API is called automatically when resources are exhausted. + * + * @param workerId Worker identifier + * @param state Output parameter for freeze state + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result FreezeWorker(const char* workerId, WorkerFreezeState* state); + +/** + * Resume a worker process (resume execution when resources become available). + * This API is called automatically when resources become available. 
+ * + * @param workerId Worker identifier + * @param state Output parameter for freeze state + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result ResumeWorker(const char* workerId, WorkerFreezeState* state); + +/** + * Auto-freeze hook: called when resource limit is reached. + * This triggers automatic freezing of the worker. + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param resourceType Resource type ("memory" or "compute") + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AutoFreeze(const char* workerId, const char* deviceUUID, const char* resourceType); + +/** + * Auto-resume hook: called when resources become available. + * This triggers automatic resuming of the worker. + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param resourceType Resource type ("memory" or "compute") + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AutoResume(const char* workerId, const char* deviceUUID, const char* resourceType); + +/** + * Add a worker process to the limiter tracking. + * This API is called when a process starts using a device. + * + * @param deviceUUID Device UUID + * @param processId Process identifier (as string) + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AddWorkerProcess(const char* deviceUUID, const char* processId); + +#ifdef __cplusplus +} +#endif + +#endif // LIMITER_H + diff --git a/provider/stub/accelerator.c b/provider/stub/accelerator.c new file mode 100644 index 00000000..7fed0e2f --- /dev/null +++ b/provider/stub/accelerator.c @@ -0,0 +1,575 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Feature test macros for POSIX functions (required on Linux) +#define _POSIX_C_SOURCE 200809L +#define _DEFAULT_SOURCE + +#include "../accelerator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// Global Variables for Limiter Thread +// ============================================================================ + +static const char* g_processId = "stub-process-0"; +static _Atomic uint64_t g_lastComputeCallTimeMs = 0; // Last call time in milliseconds +static pthread_t g_limiterThread; +static volatile int g_threadRunning = 0; + +// ============================================================================ +// Limiter Thread Function +// ============================================================================ + +static void* limiterThreadFunc(void* arg __attribute__((unused))) { + // Get first device UUID for testing + ExtendedDeviceInfo devices[256]; // Stack-allocated buffer + size_t deviceCount = 0; + char deviceUUID[64] = {0}; + + if (GetAllDevices(devices, 256, &deviceCount) != RESULT_SUCCESS || deviceCount == 0) { + return NULL; + } + snprintf(deviceUUID, sizeof(deviceUUID), "%s", devices[0].basic.uuid); + + // Add worker process to limiter tracking + AddWorkerProcess(deviceUUID, g_processId); + + // Call CheckAndRecordMemoryOps once + MemoryOpRecord memRecord; + CheckAndRecordMemoryOps(g_processId, deviceUUID, 0, &memRecord); + + // Call CheckAndRecordComputeOps every second + while (g_threadRunning) { + 
struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + uint64_t currentTimeMs = (uint64_t)ts.tv_sec * 1000 + (uint64_t)ts.tv_nsec / 1000000; + + ComputeOpRecord computeRecord; + CheckAndRecordComputeOps(g_processId, deviceUUID, 1000, &computeRecord); + + // Update global variable + g_lastComputeCallTimeMs = currentTimeMs; + + // Sleep for 1 second + sleep(1); + } + + return NULL; +} + +// ============================================================================ +// Constructor - Initialize Limiter Thread +// ============================================================================ + +__attribute__((constructor)) +static void initLimiterThread(void) { + g_threadRunning = 1; + if (pthread_create(&g_limiterThread, NULL, limiterThreadFunc, NULL) != 0) { + fprintf(stderr, "Failed to create limiter thread\n"); + return; + } + pthread_detach(g_limiterThread); +} + +// ============================================================================ +// Destructor - Cleanup Limiter Thread +// ============================================================================ + +__attribute__((destructor)) +static void cleanupLimiterThread(void) { + g_threadRunning = 0; + // Thread will exit on next iteration +} + +// ============================================================================ +// Stub Implementation - Limiter APIs +// ============================================================================ + +Result AddWorkerProcess(const char* deviceUUID, const char* processId) { + (void)deviceUUID; // Unused in stub + (void)processId; // Unused in stub + return RESULT_SUCCESS; +} + +Result CheckAndRecordMemoryOps(const char* processId, const char* deviceUUID, int64_t bytesDiff, MemoryOpRecord* record) { + (void)processId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)bytesDiff; // Unused in stub + + if (!record) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always allow, set available bytes to a large value + record->shouldBlock = false; + 
record->availableBytes = 16ULL * 1024 * 1024 * 1024; // 16GB + return RESULT_SUCCESS; +} + +Result CheckAndRecordComputeOps(const char* processId, const char* deviceUUID, uint64_t computeTokens, ComputeOpRecord* record) { + (void)processId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)computeTokens; // Unused in stub + + if (!record) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always allow, set available tokens to a large value + record->shouldBlock = false; + record->availableTokens = 1000000; // Large token pool + return RESULT_SUCCESS; +} + +Result FreezeWorker(const char* workerId, WorkerFreezeState* state) { + (void)workerId; // Unused in stub + if (!state) { + return RESULT_ERROR_INVALID_PARAM; + } + state->isFrozen = false; + state->freezeTimeMs = 0; + return RESULT_SUCCESS; +} + +Result ResumeWorker(const char* workerId, WorkerFreezeState* state) { + (void)workerId; // Unused in stub + if (!state) { + return RESULT_ERROR_INVALID_PARAM; + } + state->isFrozen = false; + state->freezeTimeMs = 0; + return RESULT_SUCCESS; +} + +Result AutoFreeze(const char* workerId, const char* deviceUUID, const char* resourceType) { + (void)workerId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)resourceType; // Unused in stub + return RESULT_SUCCESS; +} + +Result AutoResume(const char* workerId, const char* deviceUUID, const char* resourceType) { + (void)workerId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)resourceType; // Unused in stub + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - DeviceInfo APIs +// ============================================================================ + +Result GetDeviceCount(size_t* deviceCount) { + if (!deviceCount) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 4 devices + *deviceCount = 4; + return RESULT_SUCCESS; +} + +// Helper function to initialize a single device 
info +static void initDeviceInfo(ExtendedDeviceInfo* info, int32_t deviceIndex) { + // Initialize basic info + snprintf(info->basic.uuid, sizeof(info->basic.uuid), "stub-device-%d", deviceIndex); + snprintf(info->basic.vendor, sizeof(info->basic.vendor), "STUB"); + snprintf(info->basic.model, sizeof(info->basic.model), "Stub-GPU-Model"); + snprintf(info->basic.driverVersion, sizeof(info->basic.driverVersion), "1.0.0-stub"); + snprintf(info->basic.firmwareVersion, sizeof(info->basic.firmwareVersion), "1.0.0-stub"); + info->basic.index = deviceIndex; + info->basic.numaNode = deviceIndex % 2; // Stub: alternate NUMA nodes + info->basic.totalMemoryBytes = 16ULL * 1024 * 1024 * 1024; // 16GB + info->basic.totalComputeUnits = 108; // Stub: 108 SMs + info->basic.maxTflops = 312.0; // Stub: 312 TFLOPS + info->basic.pcieGen = 4; + info->basic.pcieWidth = 16; + + // Initialize properties + info->props.clockGraphics = 1410; // MHz + info->props.clockSM = 1410; // MHz + info->props.clockMem = 1215; // MHz + info->props.powerLimit = 400; // W + info->props.temperatureThreshold = 83; // C + info->props.eccEnabled = true; + info->props.persistenceModeEnabled = false; + snprintf(info->props.computeCapability, sizeof(info->props.computeCapability), "8.0"); + info->props.clockAI = 0; // Not applicable for stub + snprintf(info->props.chipType, sizeof(info->props.chipType), "STUB"); + + // Initialize capabilities + info->capabilities.supportsPartitioning = true; + info->capabilities.supportsSoftIsolation = true; + info->capabilities.supportsHardIsolation = true; + info->capabilities.supportsSnapshot = true; + info->capabilities.supportsMetrics = true; + info->capabilities.maxPartitions = 7; + info->capabilities.maxWorkersPerDevice = 16; + + // Initialize related devices (stub: no related devices) + info->relatedDevices = NULL; + info->relatedDeviceCount = 0; +} + +Result GetAllDevices(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount) { + if (!devices || !deviceCount 
|| maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 4 devices (but not more than maxCount) + size_t actualCount = 4; + if (actualCount > maxCount) { + actualCount = maxCount; + } + *deviceCount = actualCount; + + // Initialize each device + for (size_t i = 0; i < actualCount; i++) { + initDeviceInfo(&devices[i], (int32_t)i); + } + + return RESULT_SUCCESS; +} + +Result GetPartitionTemplates(int32_t deviceIndex __attribute__((unused)), PartitionTemplate* templates, size_t maxCount, size_t* templateCount) { + if (!templates || !templateCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 3 example templates (but not more than maxCount) + size_t actualCount = 3; + if (actualCount > maxCount) { + actualCount = maxCount; + } + *templateCount = actualCount; + + // Template 1: 1/7 slice + if (actualCount > 0) { + PartitionTemplate* t1 = &templates[0]; + snprintf(t1->templateId, sizeof(t1->templateId), "mig-1g.7gb"); + snprintf(t1->name, sizeof(t1->name), "1/7 GPU Slice"); + t1->memoryBytes = 7ULL * 1024 * 1024 * 1024; // 7GB + t1->computeUnits = 14; // 1/7 of 108 SMs + t1->tflops = 312.0 / 7.0; // ~44.6 TFLOPS + t1->sliceCount = 1; + t1->isDefault = false; + snprintf(t1->description, sizeof(t1->description), "1/7 GPU slice with 7GB memory"); + } + + // Template 2: 2/7 slice + if (actualCount > 1) { + PartitionTemplate* t2 = &templates[1]; + snprintf(t2->templateId, sizeof(t2->templateId), "mig-2g.14gb"); + snprintf(t2->name, sizeof(t2->name), "2/7 GPU Slice"); + t2->memoryBytes = 14ULL * 1024 * 1024 * 1024; // 14GB + t2->computeUnits = 28; // 2/7 of 108 SMs + t2->tflops = 312.0 * 2.0 / 7.0; // ~89.1 TFLOPS + t2->sliceCount = 2; + t2->isDefault = true; + snprintf(t2->description, sizeof(t2->description), "2/7 GPU slice with 14GB memory"); + } + + // Template 3: 3/7 slice + if (actualCount > 2) { + PartitionTemplate* t3 = &templates[2]; + snprintf(t3->templateId, sizeof(t3->templateId), "mig-3g.21gb"); + 
snprintf(t3->name, sizeof(t3->name), "3/7 GPU Slice"); + t3->memoryBytes = 21ULL * 1024 * 1024 * 1024; // 21GB (stub, exceeds total) + t3->computeUnits = 42; // 3/7 of 108 SMs + t3->tflops = 312.0 * 3.0 / 7.0; // ~133.7 TFLOPS + t3->sliceCount = 3; + t3->isDefault = false; + snprintf(t3->description, sizeof(t3->description), "3/7 GPU slice with 21GB memory"); + } + + return RESULT_SUCCESS; +} + +Result GetDeviceTopology(int32_t* deviceIndexArray, size_t deviceCount, ExtendedDeviceTopology* topology, size_t maxConnectionsPerDevice) { + if (!deviceIndexArray || deviceCount == 0 || !topology || maxConnectionsPerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Note: topology->devices must be pre-allocated by caller with size >= deviceCount + // topology->devices[i].connections must be pre-allocated by caller with size >= maxConnectionsPerDevice + if (!topology->devices) { + return RESULT_ERROR_INVALID_PARAM; + } + topology->deviceCount = deviceCount; + + // Initialize each device topology + for (size_t i = 0; i < deviceCount; i++) { + DeviceTopology* dt = &topology->devices[i]; + snprintf(dt->deviceUUID, sizeof(dt->deviceUUID), "stub-device-%d", deviceIndexArray[i]); + dt->numaNode = deviceIndexArray[i] % 2; + + // Stub: create connections to other devices + size_t connectionCount = (deviceCount > 1) ? 
(deviceCount - 1) : 0; + if (connectionCount > maxConnectionsPerDevice) { + connectionCount = maxConnectionsPerDevice; + } + + if (connectionCount > 0 && dt->connections) { + dt->connectionCount = connectionCount; + + size_t connIdx = 0; + for (size_t j = 0; j < deviceCount && connIdx < connectionCount; j++) { + if (j != i) { + RelatedDevice* rd = &dt->connections[connIdx]; + snprintf(rd->deviceUUID, sizeof(rd->deviceUUID), "stub-device-%d", deviceIndexArray[j]); + snprintf(rd->connectionType, sizeof(rd->connectionType), "NVLink"); + rd->bandwidthMBps = 600000; // 600 GB/s (stub) + rd->latencyNs = 100; // 100ns (stub) + connIdx++; + } + } + } else { + dt->connections = NULL; + dt->connectionCount = 0; + } + } + + // Set extended topology info + topology->nvlinkBandwidthMBps = 600000 * deviceCount; // Total bandwidth + topology->ibNicCount = 0; // Stub: no IB NICs + snprintf(topology->topologyType, sizeof(topology->topologyType), "NVLink"); + + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Partitioned Isolation +// ============================================================================ + +bool AssignPartition(PartitionAssignment* assignment) { + if (!assignment || assignment->templateId[0] == '\0' || assignment->deviceUUID[0] == '\0') { + return false; + } + + // Stub: generate a partition UUID + // Limit string lengths to ensure output fits in 64-byte buffer: + // "partition-" (9) + templateId (26) + "-" (1) + deviceUUID (26) + null (1) = 63 bytes + snprintf(assignment->partitionUUID, sizeof(assignment->partitionUUID), + "partition-%.26s-%.26s", assignment->templateId, assignment->deviceUUID); + + // Stub: set partition overhead (e.g., 100MB) + assignment->partitionOverheadBytes = 100ULL * 1024 * 1024; + + return true; +} + +bool RemovePartition(const char* templateId, const char* deviceUUID) { + if (!templateId || !deviceUUID) { + return false; + } + 
+ // Stub: always succeed + return true; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Hard Isolation +// ============================================================================ + +Result SetMemHardLimit(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes) { + if (!workerId || !deviceUUID || memoryLimitBytes == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed + return RESULT_SUCCESS; +} + +Result SetComputeUnitHardLimit(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit) { + if (!workerId || !deviceUUID || computeUnitLimit == 0 || computeUnitLimit > 100) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Device Snapshot/Migration +// ============================================================================ + +Result Snapshot(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: verify processes exist (basic check) + for (size_t i = 0; i < processes->processCount; i++) { + if (kill(processes->processIds[i], 0) != 0) { + // Process doesn't exist or no permission + return RESULT_ERROR_NOT_FOUND; + } + } + + // Stub: always succeed (no actual snapshot implementation) + return RESULT_SUCCESS; +} + +Result Resume(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed (no actual resume implementation) + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Metrics APIs +// 
============================================================================ + +Result GetProcessComputeUtilization( + ComputeUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // For now, stub implementation returns empty + // The actual implementation should query limiter for all tracked processes + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetProcessMemoryUtilization( + MemoryUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // For now, stub implementation returns empty + // The actual implementation should query limiter for all tracked processes + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + DeviceMetrics* metrics +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Fill stub data + for (size_t i = 0; i < deviceCount; i++) { + DeviceMetrics* dm = &metrics[i]; + snprintf(dm->deviceUUID, sizeof(dm->deviceUUID), "%s", deviceUUIDArray[i]); + dm->powerUsageWatts = 200.0 + (i * 10.0); // Stub: 200-300W + dm->temperatureCelsius = 45.0 + (i * 5.0); // Stub: 45-50C + dm->pcieRxBytes = 1024ULL * 1024 * 1024 * (i + 1); // Stub: 1-4GB + dm->pcieTxBytes = 512ULL * 1024 * 1024 * (i + 1); // Stub: 0.5-2GB + dm->smActivePercent = 50 + (i * 10); // Stub: 50-90% + dm->tensorCoreUsagePercent = 30 + (i * 5); // Stub: 30-50% + dm->memoryUsedBytes = 8ULL * 1024 * 1024 * 1024; // Stub: 8GB + dm->memoryTotalBytes = 16ULL * 1024 * 1024 * 1024; // Stub: 16GB + } + + return RESULT_SUCCESS; +} + +Result GetExtendedDeviceMetrics( + const char** 
deviceUUIDArray, + size_t deviceCount, + ExtendedDeviceMetrics* metrics, + size_t maxNvlinkPerDevice, + size_t maxIbNicPerDevice, + size_t maxPciePerDevice +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics || + maxNvlinkPerDevice == 0 || maxIbNicPerDevice == 0 || maxPciePerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Fill stub data + // Note: metrics[i].nvlinkBandwidthMBps, ibNicBandwidthMBps, pcieBandwidthMBps + // must be pre-allocated by caller with appropriate sizes + for (size_t i = 0; i < deviceCount; i++) { + ExtendedDeviceMetrics* edm = &metrics[i]; + snprintf(edm->deviceUUID, sizeof(edm->deviceUUID), "%s", deviceUUIDArray[i]); + + // Stub: 6 NVLink connections per device (but not more than max) + edm->nvlinkCount = 6; + if (edm->nvlinkCount > maxNvlinkPerDevice) { + edm->nvlinkCount = maxNvlinkPerDevice; + } + if (edm->nvlinkBandwidthMBps) { + for (size_t j = 0; j < edm->nvlinkCount; j++) { + edm->nvlinkBandwidthMBps[j] = 500000 + (j * 10000); // Stub: 500-550 GB/s + } + } + + // Stub: 2 IB NICs per device (but not more than max) + edm->ibNicCount = 2; + if (edm->ibNicCount > maxIbNicPerDevice) { + edm->ibNicCount = maxIbNicPerDevice; + } + if (edm->ibNicBandwidthMBps) { + for (size_t j = 0; j < edm->ibNicCount; j++) { + edm->ibNicBandwidthMBps[j] = 200000; // Stub: 200 GB/s per NIC + } + } + + // Stub: 1 PCIe link (but not more than max) + edm->pcieLinkCount = 1; + if (edm->pcieLinkCount > maxPciePerDevice) { + edm->pcieLinkCount = maxPciePerDevice; + } + if (edm->pcieBandwidthMBps && edm->pcieLinkCount > 0) { + edm->pcieBandwidthMBps[0] = 32000; // Stub: 32 GB/s (PCIe 4.0 x16) + } + } + + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Utility APIs +// ============================================================================ + +Result Log(const char* level, const char* message) { + if (!level || !message) { + return 
RESULT_ERROR_INVALID_PARAM; + } + + // Stub: print to stderr + fprintf(stderr, "[%s] %s\n", level, message); + fflush(stderr); + + return RESULT_SUCCESS; +} + diff --git a/provider/test/test_accelerator.c b/provider/test/test_accelerator.c new file mode 100644 index 00000000..6b04e3bc --- /dev/null +++ b/provider/test/test_accelerator.c @@ -0,0 +1,293 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "../accelerator.h" + +// Test result tracking +static int tests_run = 0; +static int tests_passed = 0; +static int tests_failed = 0; + +#define TEST_ASSERT(condition, message) \ + do { \ + tests_run++; \ + if (condition) { \ + tests_passed++; \ + printf(" ✓ %s\n", message); \ + } else { \ + tests_failed++; \ + printf(" ✗ %s\n", message); \ + } \ + } while (0) + +// Test getDeviceInfo +void test_getDeviceInfo() { + printf("\n=== Testing getDeviceInfo ===\n"); + + ExtendedDeviceInfo info; + Result result = getDeviceInfo(0, &info); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceInfo returns success"); + TEST_ASSERT(strlen(info.basic.uuid) > 0, "Device UUID is not empty"); + TEST_ASSERT(strlen(info.basic.vendor) > 0, "Vendor is not empty"); + TEST_ASSERT(strlen(info.basic.model) > 0, "Model is not empty"); + TEST_ASSERT(info.basic.totalMemoryBytes > 0, "Total memory > 0"); + TEST_ASSERT(info.basic.totalComputeUnits > 0, "Total compute units > 0"); + 
TEST_ASSERT(info.basic.maxTflops > 0, "Max TFLOPS > 0"); + TEST_ASSERT(info.capabilities.maxPartitions > 0, "Max partitions > 0"); + + // Test invalid device index + result = getDeviceInfo(-1, &info); + TEST_ASSERT(result != RESULT_SUCCESS, "Invalid device index returns error"); + + // Cleanup + freeExtendedDeviceInfo(&info); +} + +// Test getPartitionTemplates +void test_getPartitionTemplates() { + printf("\n=== Testing getPartitionTemplates ===\n"); + + PartitionTemplate* templates = NULL; + size_t templateCount = 0; + Result result = getPartitionTemplates(0, &templates, &templateCount); + + TEST_ASSERT(result == RESULT_SUCCESS, "getPartitionTemplates returns success"); + TEST_ASSERT(templates != NULL, "Templates array is not NULL"); + TEST_ASSERT(templateCount > 0, "Template count > 0"); + + if (templates && templateCount > 0) { + TEST_ASSERT(strlen(templates[0].templateId) > 0, "First template has ID"); + TEST_ASSERT(strlen(templates[0].name) > 0, "First template has name"); + TEST_ASSERT(templates[0].memoryBytes > 0, "First template has memory"); + TEST_ASSERT(templates[0].computeUnits > 0, "First template has compute units"); + } + + // Cleanup + freePartitionTemplates(templates, templateCount); +} + +// Test getDeviceTopology +void test_getDeviceTopology() { + printf("\n=== Testing getDeviceTopology ===\n"); + + int32_t deviceIndices[] = {0, 1}; + size_t deviceCount = 2; + ExtendedDeviceTopology topology; + + Result result = getDeviceTopology(deviceIndices, deviceCount, &topology); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceTopology returns success"); + TEST_ASSERT(topology.devices != NULL, "Devices array is not NULL"); + TEST_ASSERT(topology.deviceCount == deviceCount, "Device count matches"); + + if (topology.devices && topology.deviceCount > 0) { + TEST_ASSERT(strlen(topology.devices[0].deviceUUID) > 0, "First device has UUID"); + } + + // Cleanup + freeExtendedDeviceTopology(&topology); +} + +// Test assignPartition +void test_assignPartition() 
{ + printf("\n=== Testing assignPartition ===\n"); + + PartitionAssignment assignment; + snprintf(assignment.templateId, sizeof(assignment.templateId), "mig-1g.7gb"); + snprintf(assignment.deviceUUID, sizeof(assignment.deviceUUID), "stub-device-0"); + + bool result = assignPartition(&assignment); + + TEST_ASSERT(result == true, "assignPartition returns true"); + TEST_ASSERT(strlen(assignment.partitionUUID) > 0, "Partition UUID is assigned"); + TEST_ASSERT(assignment.partitionOverheadBytes > 0, "Partition overhead > 0"); + + // Test invalid input + PartitionAssignment invalid; + invalid.templateId[0] = '\0'; + invalid.deviceUUID[0] = '\0'; + result = assignPartition(&invalid); + TEST_ASSERT(result == false, "Invalid assignment returns false"); +} + +// Test removePartition +void test_removePartition() { + printf("\n=== Testing removePartition ===\n"); + + bool result = removePartition("mig-1g.7gb", "stub-device-0"); + TEST_ASSERT(result == true, "removePartition returns true"); + + result = removePartition(NULL, "stub-device-0"); + TEST_ASSERT(result == false, "NULL templateId returns false"); +} + +// Test setMemHardLimit +void test_setMemHardLimit() { + printf("\n=== Testing setMemHardLimit ===\n"); + + Result result = setMemHardLimit("worker-1", "stub-device-0", 4ULL * 1024 * 1024 * 1024); + TEST_ASSERT(result == RESULT_SUCCESS, "setMemHardLimit returns success"); + + result = setMemHardLimit(NULL, "stub-device-0", 4ULL * 1024 * 1024 * 1024); + TEST_ASSERT(result == RESULT_ERROR_INVALID_PARAM, "NULL workerId returns error"); +} + +// Test setComputeUnitHardLimit +void test_setComputeUnitHardLimit() { + printf("\n=== Testing setComputeUnitHardLimit ===\n"); + + Result result = setComputeUnitHardLimit("worker-1", "stub-device-0", 50); + TEST_ASSERT(result == RESULT_SUCCESS, "setComputeUnitHardLimit returns success"); + + result = setComputeUnitHardLimit("worker-1", "stub-device-0", 150); + TEST_ASSERT(result == RESULT_ERROR_INVALID_PARAM, "Invalid limit > 100 
returns error"); +} + +// Test getProcessComputeUtilization +void test_getProcessComputeUtilization() { + printf("\n=== Testing getProcessComputeUtilization ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + const char* processIds[] = {"12345"}; + ComputeUtilization* utilizations = NULL; + size_t utilizationCount = 0; + + Result result = getProcessComputeUtilization( + deviceUUIDs, 1, + processIds, 1, + &utilizations, &utilizationCount + ); + + TEST_ASSERT(result == RESULT_SUCCESS, "getProcessComputeUtilization returns success"); + TEST_ASSERT(utilizations != NULL, "Utilizations array is not NULL"); + TEST_ASSERT(utilizationCount > 0, "Utilization count > 0"); + + if (utilizations && utilizationCount > 0) { + TEST_ASSERT(utilizations[0].utilizationPercent >= 0 && + utilizations[0].utilizationPercent <= 100, + "Utilization percent in valid range"); + } + + freeComputeUtilizations(utilizations, utilizationCount); +} + +// Test getProcessMemoryUtilization +void test_getProcessMemoryUtilization() { + printf("\n=== Testing getProcessMemoryUtilization ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + const char* processIds[] = {"12345"}; + MemoryUtilization* utilizations = NULL; + size_t utilizationCount = 0; + + Result result = getProcessMemoryUtilization( + deviceUUIDs, 1, + processIds, 1, + &utilizations, &utilizationCount + ); + + TEST_ASSERT(result == RESULT_SUCCESS, "getProcessMemoryUtilization returns success"); + TEST_ASSERT(utilizations != NULL, "Utilizations array is not NULL"); + TEST_ASSERT(utilizationCount > 0, "Utilization count > 0"); + + if (utilizations && utilizationCount > 0) { + TEST_ASSERT(utilizations[0].usedBytes > 0, "Used bytes > 0"); + } + + freeMemoryUtilizations(utilizations, utilizationCount); +} + +// Test getDeviceMetrics +void test_getDeviceMetrics() { + printf("\n=== Testing getDeviceMetrics ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + DeviceMetrics* metrics = NULL; + + Result result = 
getDeviceMetrics(deviceUUIDs, 1, &metrics); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceMetrics returns success"); + TEST_ASSERT(metrics != NULL, "Metrics array is not NULL"); + + if (metrics) { + TEST_ASSERT(strlen(metrics[0].deviceUUID) > 0, "Device UUID is not empty"); + TEST_ASSERT(metrics[0].powerUsageWatts >= 0, "Power usage >= 0"); + TEST_ASSERT(metrics[0].temperatureCelsius >= 0, "Temperature >= 0"); + } + + freeDeviceMetrics(metrics, 1); +} + +// Test getExtendedDeviceMetrics +void test_getExtendedDeviceMetrics() { + printf("\n=== Testing getExtendedDeviceMetrics ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + ExtendedDeviceMetrics* metrics = NULL; + + Result result = getExtendedDeviceMetrics(deviceUUIDs, 1, &metrics); + + TEST_ASSERT(result == RESULT_SUCCESS, "getExtendedDeviceMetrics returns success"); + TEST_ASSERT(metrics != NULL, "Metrics array is not NULL"); + + if (metrics) { + TEST_ASSERT(strlen(metrics[0].deviceUUID) > 0, "Device UUID is not empty"); + TEST_ASSERT(metrics[0].nvlinkCount > 0, "NVLink count > 0"); + } + + freeExtendedDeviceMetrics(metrics, 1); +} + +// Main test runner +int main() { + printf("========================================\n"); + printf("Accelerator Library Test Suite\n"); + printf("========================================\n"); + + test_getDeviceInfo(); + test_getPartitionTemplates(); + test_getDeviceTopology(); + test_assignPartition(); + test_removePartition(); + test_setMemHardLimit(); + test_setComputeUnitHardLimit(); + test_getProcessComputeUtilization(); + test_getProcessMemoryUtilization(); + test_getDeviceMetrics(); + test_getExtendedDeviceMetrics(); + + printf("\n========================================\n"); + printf("Test Summary\n"); + printf("========================================\n"); + printf("Total tests: %d\n", tests_run); + printf("Passed: %d\n", tests_passed); + printf("Failed: %d\n", tests_failed); + printf("========================================\n"); + + if (tests_failed == 0) 
{ + printf("All tests passed! ✓\n"); + return 0; + } else { + printf("Some tests failed! ✗\n"); + return 1; + } +} +