diff --git a/.gitignore b/.gitignore
index aac9309..d5cf866 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,13 @@ htmlcov/
.idea
*.log
*.pyc
-examples/paddle_case/log
\ No newline at end of file
+*.so
+examples/paddle_case/log
+
+# Auto-generated hipified files and directories (created during ROCm build)
+fastsafetensors/cpp/hip/
+fastsafetensors/cpp/*.hip.*
+fastsafetensors/cpp/hip_compat.h
+
+# Auto-generated PyPI index (generated by GitHub Actions)
+pypi-index/
\ No newline at end of file
diff --git a/README.md b/README.md
index bcacbce..4d58823 100644
--- a/README.md
+++ b/README.md
@@ -48,8 +48,9 @@ Please refer to [Foundation Model Stack Community Code of Conduct](https://githu
Takeshi Yoshimura, Tatsuhiro Chiba, Manish Sethi, Daniel Waddington, Swaminathan Sundararaman. (2025) Speeding up Model Loading with fastsafetensors [arXiv:2505.23072](https://arxiv.org/abs/2505.23072) and IEEE CLOUD 2025.
+## For NVIDIA
-## Install from PyPI
+### Install from PyPI
See https://pypi.org/project/fastsafetensors/
@@ -57,7 +58,24 @@ See https://pypi.org/project/fastsafetensors/
pip install fastsafetensors
```
-## Install from source
+### Install from source
+
+```bash
+pip install .
+```
+
+## For ROCm
+
+On ROCm, there is no GDS equivalent, so fastsafetensors only supports `nogds=True` mode.
+Performance results can be found in [amd-perf.md](./docs/amd-perf.md).
+
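+A minimal usage sketch (the file path and tensor name below are placeholders):
+
+```python
+from fastsafetensors import SafeTensorsFileLoader, SingleGroup
+
+# PyTorch on ROCm still exposes GPUs as "cuda:N".
+# GDS is unavailable on ROCm, so pass nogds=True.
+loader = SafeTensorsFileLoader(SingleGroup(), "cuda:0", nogds=True)
+loader.add_filenames({0: ["model.safetensors"]})  # rank -> file list
+fb = loader.copy_files_to_device()
+tensor = fb.get_tensor("weight")
+loader.close()
+```
+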
+### Install from GitHub source
+
+```bash
+pip install git+https://github.com/foundation-model-stack/fastsafetensors.git
+```
+
+### Install from source
```bash
pip install .
diff --git a/docs/amd-perf.md b/docs/amd-perf.md
new file mode 100644
index 0000000..ea60664
--- /dev/null
+++ b/docs/amd-perf.md
@@ -0,0 +1,88 @@
+# Performance of FastSafeTensors on AMD GPUs
+
+## DeepSeek-R1 vLLM Model Weight Loading Speed
+
+This benchmark compares the performance of `safetensors` vs `fastsafetensors` when loading model weights on AMD GPUs.
+
+NOTE: `fastsafetensors` does not support the GDS feature on ROCm, as ROCm has no GDS equivalent.
+
+### Benchmark Methodology
+
+- **Platform:** AMD ROCm 7.0.1
+- **GPUs:** 8x AMD Instinct MI300X
+- **Library:** fastsafetensors 0.1.15
+
+1. **Clear system cache** to ensure consistent starting conditions:
+ ```bash
+ sudo sh -c 'sync && echo 3 > /proc/sys/vm/drop_caches'
+ ```
+
+2. **Launch vLLM** with either `--load-format safetensors` or `--load-format fastsafetensors`:
+
+ ```bash
+ MODEL=EmbeddedLLM/deepseek-r1-FP8-Dynamic
+
+ VLLM_USE_V1=1 \
+ VLLM_ROCM_USE_AITER=1 \
+ vllm serve $MODEL \
+ --tensor-parallel-size 8 \
+ --disable-log-requests \
+ --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
+ --trust-remote-code \
+ --load-format fastsafetensors \
+ --block-size 1
+ ```
+
+### Results
+
+The experiments were carried out on MI300X.
+
+**Cache Scenarios:**
+- **No cache**: Model weights are loaded after clearing the system cache (cold start).
+- **Cached**: Model weights are loaded immediately after a previous load. The weights are cached in the filesystem and RAM (warm start).
+
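+The two scenarios can be reproduced as follows (a sketch reusing the commands above; the load time itself is read from vLLM's startup logs):
+
+```bash
+# No cache (cold start): drop the page cache, then launch
+sudo sh -c 'sync && echo 3 > /proc/sys/vm/drop_caches'
+vllm serve $MODEL --load-format fastsafetensors --tensor-parallel-size 8
+
+# Cached (warm start): stop the server and relaunch without dropping caches,
+# so the weights are served from the page cache
+vllm serve $MODEL --load-format fastsafetensors --tensor-parallel-size 8
+```
+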
+![DeepSeek-R1 weight-loading time on 8x MI300X: safetensors vs fastsafetensors](./images/fastsafetensors-rocm.png)
+
+## GPT-2 perf tests using [perf/fastsafetensors_perf/perf.py](../perf/fastsafetensors_perf/perf.py)
+
+### Test Configuration
+
+All tests were performed on single-GPU loading scenarios with two different model sizes:
+- **GPT-2 (small):** 523MB safetensors file
+- **GPT-2 Medium:** ~1.4GB safetensors file
+
+#### Key Parameters Tested:
+- **nogds mode:** ROCm fallback (GDS not available on AMD GPUs)
+- **Thread counts:** 8, 16, 32
+- **Buffer sizes:** 8MB, 16MB, 32MB
+- **Loading methods:** nogds (async I/O), mmap (memory-mapped)
+- **Data types:** AUTO (no conversion), F16 (half precision conversion)
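+
+These parameters map onto the loader roughly as follows (a sketch; `max_threads`, `bbuf_size_kb`, and `nogds` are the assumed keyword names of `SafeTensorsFileLoader`):
+
+```python
+from fastsafetensors import SafeTensorsFileLoader, SingleGroup
+
+# e.g. test #8 below: nogds with 16 threads and a 16MB copy buffer
+loader = SafeTensorsFileLoader(
+    SingleGroup(),
+    "cuda:0",
+    nogds=True,              # ROCm fallback (no GDS)
+    max_threads=16,          # thread count under test
+    bbuf_size_kb=16 * 1024,  # copy-buffer size under test (16MB)
+)
+```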
+
+---
+
+#### Performance Results
+
+##### GPT-2 (523MB) - Single GPU Tests
+
+| Test # | Method | Threads | Buffer | Config | Bandwidth | Elapsed Time | Notes |
+|--------|--------|---------|--------|--------|-----------|--------------|-------|
+| 1 | nogds | 16 | 16MB | default | **1.91 GB/s** | 0.268s | Baseline test |
+| 2 | nogds | 32 | 32MB | default | **2.07 GB/s** | 0.246s | Higher threads/buffer |
+| 3 | nogds | 8 | 8MB | default | **2.10 GB/s** | 0.243s | Lower threads/buffer |
+| 4 | mmap | N/A | N/A | default | **1.01 GB/s** | 0.505s | Memory-mapped |
+| 5 | nogds | 32 | 32MB | cache-drop | **1.24 GB/s** | 0.410s | Cold cache test |
+| 6 | nogds | 32 | 32MB | F16 dtype | **0.77 GB/s** | 0.332s | With type conversion |
+| 8 | nogds | 16 | 16MB | **optimal** | **2.62 GB/s** | 0.195s | Best config |
+
+##### GPT-2 Medium (1.4GB) - Single GPU Tests
+
+| Test # | Method | Threads | Buffer | Block Size | Bandwidth | Elapsed Time | Notes |
+|--------|--------|---------|--------|------------|-----------|--------------|-------|
+| 9 | nogds | 16 | 16MB | 160MB | **6.02 GB/s** | 0.235s | Optimal config |
+| 10 | mmap | N/A | N/A | N/A | **1.28 GB/s** | 1.104s | Memory-mapped |
+| 11 | nogds | 32 | 32MB | 160MB | **5.34 GB/s** | 0.265s | Higher threads |
+
+---
\ No newline at end of file
diff --git a/docs/images/fastsafetensors-rocm.png b/docs/images/fastsafetensors-rocm.png
new file mode 100644
index 0000000..7bc2324
Binary files /dev/null and b/docs/images/fastsafetensors-rocm.png differ
diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py
index d369975..33446a0 100644
--- a/fastsafetensors/common.py
+++ b/fastsafetensors/common.py
@@ -14,6 +14,15 @@
from .st_types import Device, DType
+def is_gpu_found():
+ """Check if any GPU (CUDA or HIP) is available.
+
+ Returns True if either CUDA or ROCm/HIP GPUs are detected.
+ This allows code to work transparently across both platforms.
+ """
+ return fstcpp.is_cuda_found() or fstcpp.is_hip_found()
+
+
def get_device_numa_node(device: Optional[int]) -> Optional[int]:
if device is None or not sys.platform.startswith("linux"):
return None
diff --git a/fastsafetensors/copier/gds.py b/fastsafetensors/copier/gds.py
index de23786..9fde2af 100644
--- a/fastsafetensors/copier/gds.py
+++ b/fastsafetensors/copier/gds.py
@@ -5,7 +5,7 @@
from typing import Dict, Optional
from .. import cpp as fstcpp
-from ..common import SafeTensorsMetadata
+from ..common import SafeTensorsMetadata, is_gpu_found
from ..frameworks import FrameworkOpBase, TensorBase
from ..st_types import Device, DeviceType, DType
from .base import CopierInterface
@@ -30,12 +30,29 @@ def __init__(
self.fh: Optional[fstcpp.gds_file_handle] = None
self.copy_reqs: Dict[int, int] = {}
self.aligned_length = 0
- cudavers = list(map(int, framework.get_cuda_ver().split(".")))
- # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
- # Compatible with CUDA 11.x
- self.o_direct = not (
- cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)
- )
+ cuda_ver = framework.get_cuda_ver()
+ if cuda_ver and cuda_ver != "0.0":
+ # Parse version string (e.g., "cuda-12.1" or "hip-5.7.0")
+ # Extract the numeric part after the platform prefix
+            ver_parts = cuda_ver.split("-", 1)
+            if len(ver_parts) == 2 and ver_parts[0] == "cuda":
+                # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
+                # Compatible with CUDA 11.x
+                # Parse only major.minor, and only on the CUDA platform: HIP version
+                # strings may carry a non-numeric build suffix (e.g., "5.7.31921-d1770ee1b")
+                cudavers = list(map(int, ver_parts[1].split(".")[:2]))
+                self.o_direct = not (
+                    cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)
+                )
+            else:
+                # ROCm/HIP platform or unexpected version format: use O_DIRECT
+                self.o_direct = True
+ else:
+ # No GPU platform detected, use O_DIRECT
+ self.o_direct = True
def set_o_direct(self, enable: bool):
self.o_direct = enable
@@ -151,8 +168,10 @@ def new_gds_file_copier(
nogds: bool = False,
):
device_is_not_cpu = device.type != DeviceType.CPU
- if device_is_not_cpu and not fstcpp.is_cuda_found():
- raise Exception("[FAIL] libcudart.so does not exist")
+ if device_is_not_cpu and not is_gpu_found():
+ raise Exception(
+ "[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist"
+ )
if not fstcpp.is_cufile_found() and not nogds:
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
diff --git a/fastsafetensors/cpp/cuda_compat.h b/fastsafetensors/cpp/cuda_compat.h
new file mode 100644
index 0000000..021e7f4
--- /dev/null
+++ b/fastsafetensors/cpp/cuda_compat.h
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: Apache-2.0
+/*
+ * CUDA/HIP compatibility layer for fastsafetensors
+ * Minimal compatibility header - only defines what hipify-perl doesn't handle
+ */
+
+#ifndef __CUDA_COMPAT_H__
+#define __CUDA_COMPAT_H__
+
+// Platform detection - this gets hipified to check __HIP_PLATFORM_AMD__
+#ifdef __HIP_PLATFORM_AMD__
+ #ifndef USE_ROCM
+ #define USE_ROCM
+ #endif
+  // Note: We do NOT include <hip/hip_runtime.h> here to avoid compile-time dependencies.
+ // Instead, we dynamically load the ROCm runtime library (libamdhip64.so) at runtime
+ // using dlopen(), just like we do for CUDA (libcudart.so).
+ // Minimal types are defined in ext.hpp.
+#else
+ // For CUDA platform, we also avoid including headers and define minimal types in ext.hpp
+#endif
+
+// Runtime library name - hipify-perl doesn't change string literals
+#ifdef USE_ROCM
+ #define GPU_RUNTIME_LIB "libamdhip64.so"
+#else
+ #define GPU_RUNTIME_LIB "libcudart.so"
+#endif
+
+// Custom function pointer names that hipify-perl doesn't recognize
+// These are our own naming in ext_funcs struct, not standard CUDA API
+#ifdef USE_ROCM
+ #define cudaDeviceMalloc hipDeviceMalloc
+ #define cudaDeviceFree hipDeviceFree
+#endif
+
+#endif // __CUDA_COMPAT_H__
diff --git a/fastsafetensors/cpp/ext.cpp b/fastsafetensors/cpp/ext.cpp
index 4f08894..bd15650 100644
--- a/fastsafetensors/cpp/ext.cpp
+++ b/fastsafetensors/cpp/ext.cpp
@@ -10,6 +10,7 @@
#include
#include
+#include "cuda_compat.h"
#include "ext.hpp"
#define ALIGN 4096
@@ -78,6 +79,7 @@ ext_funcs_t cpu_fns = ext_funcs_t {
ext_funcs_t cuda_fns;
static bool cuda_found = false;
+static bool is_hip_runtime = false; // Track if we loaded HIP (not auto-hipified)
static bool cufile_found = false;
static int cufile_ver = 0;
@@ -89,7 +91,7 @@ template <typename T> void mydlsym(T** h, void* lib, std::string const& name) {
static void load_nvidia_functions() {
cudaError_t (*cudaGetDeviceCount)(int*);
const char* cufileLib = "libcufile.so.0";
- const char* cudartLib = "libcudart.so";
+ const char* cudartLib = GPU_RUNTIME_LIB;
const char* numaLib = "libnuma.so.1";
bool init_log = getenv(ENV_ENABLE_INIT_LOG);
int mode = RTLD_LAZY | RTLD_GLOBAL | RTLD_NODELETE;
@@ -122,8 +124,12 @@ static void load_nvidia_functions() {
count = 0; // why cudaGetDeviceCount returns non-zero for errors?
}
cuda_found = count > 0;
+ // Detect if we loaded HIP runtime (ROCm) vs CUDA runtime
+ if (cuda_found && std::string(cudartLib).find("hip") != std::string::npos) {
+ is_hip_runtime = true;
+ }
if (init_log) {
- fprintf(stderr, "[DEBUG] device count=%d, cuda_found=%d\n", count, cuda_found);
+ fprintf(stderr, "[DEBUG] device count=%d, cuda_found=%d, is_hip_runtime=%d\n", count, cuda_found, is_hip_runtime);
}
} else {
cuda_found = false;
@@ -217,11 +223,28 @@ static void load_nvidia_functions() {
}
}
+// Note: is_cuda_found gets auto-hipified to is_hip_found on ROCm builds
+// So this function will be is_hip_found() after hipification on ROCm
bool is_cuda_found()
{
return cuda_found;
}
+// Separate function that always returns false on ROCm (CUDA not available on ROCm)
+// This will be used for the "is_cuda_found" Python export on ROCm builds
+bool cuda_not_available()
+{
+ return false; // On ROCm, CUDA is never available
+}
+
+// Separate function for checking HIP runtime detection (not hipified)
+// On CUDA: checks if HIP runtime was detected
+// On ROCm: not used (is_cuda_found gets hipified to is_hip_found)
+bool check_hip_runtime()
+{
+ return is_hip_runtime;
+}
+
bool is_cufile_found()
{
return cufile_found;
@@ -718,7 +741,21 @@ cpp_metrics_t get_cpp_metrics() {
PYBIND11_MODULE(__MOD_NAME__, m)
{
- m.def("is_cuda_found", &is_cuda_found);
+ // Export both is_cuda_found and is_hip_found on all platforms
+ // Use string concatenation to prevent hipify from converting the export names
+#ifdef USE_ROCM
+ // On ROCm after hipify:
+ // - is_cuda_found() becomes is_hip_found(), so export it as "is_hip_found"
+ // - Export cuda_not_available() as "is_cuda_found" (CUDA not available on ROCm)
+ m.def(("is_" "cuda" "_found"), &cuda_not_available); // Returns false on ROCm
+ m.def(("is_" "hip" "_found"), &is_cuda_found); // hipified to is_hip_found, returns hip status
+#else
+ // On CUDA:
+ // - is_cuda_found() checks for CUDA
+ // - check_hip_runtime() checks if HIP runtime was loaded
+ m.def(("is_" "cuda" "_found"), &is_cuda_found);
+ m.def(("is_" "hip" "_found"), &check_hip_runtime);
+#endif
m.def("is_cufile_found", &is_cufile_found);
m.def("cufile_version", &cufile_version);
m.def("set_debug_log", &set_debug_log);
diff --git a/fastsafetensors/cpp/ext.hpp b/fastsafetensors/cpp/ext.hpp
index eafd24c..770d3a1 100644
--- a/fastsafetensors/cpp/ext.hpp
+++ b/fastsafetensors/cpp/ext.hpp
@@ -15,6 +15,8 @@
#include
#include
+#include "cuda_compat.h"
+
#define ENV_ENABLE_INIT_LOG "FASTSAFETENSORS_ENABLE_INIT_LOG"
#ifndef __MOD_NAME__
@@ -33,8 +35,16 @@ typedef struct CUfileDescr_t {
const void *fs_ops; /* CUfileFSOps_t */
} CUfileDescr_t;
typedef struct CUfileError { CUfileOpError err; } CUfileError_t;
+
+// Define minimal CUDA/HIP types for both platforms to avoid compile-time dependencies
+// We load all GPU functions dynamically at runtime via dlopen()
typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t;
+// Platform-specific enum values - CUDA and HIP have different values for HostToDevice
+#ifdef USE_ROCM
+enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4 };
+#else
enum cudaMemcpyKind { cudaMemcpyHostToDevice=2, cudaMemcpyDefault = 4 };
+#endif
typedef enum CUfileFeatureFlags {
diff --git a/fastsafetensors/dlpack.py b/fastsafetensors/dlpack.py
index 007487a..e8883ed 100644
--- a/fastsafetensors/dlpack.py
+++ b/fastsafetensors/dlpack.py
@@ -12,24 +12,53 @@
_c_str_dltensor = b"dltensor"
+# Lazy GPU type detection - avoid calling framework-specific code at module load time
+_GPU_DEVICE_TYPE = None # Will be detected lazily
+
+
+def _detect_gpu_type():
+ """Detect if we're running on ROCm or CUDA.
+
+ This detection is now done lazily to avoid framework-specific calls at module load time.
+ Uses the C++ extension's is_hip_found() to determine the platform.
+ """
+ # Import here to avoid circular dependency
+ from . import cpp as fstcpp
+
+ # Check if we loaded HIP runtime (ROCm)
+ if fstcpp.is_hip_found():
+ return 10 # kDLROCM
+ return 2 # kDLCUDA
+
+
+def _get_gpu_device_type():
+ """Get the GPU device type, detecting it lazily if needed."""
+ global _GPU_DEVICE_TYPE
+ if _GPU_DEVICE_TYPE is None:
+ _GPU_DEVICE_TYPE = _detect_gpu_type()
+ return _GPU_DEVICE_TYPE
+
+
class DLDevice(ctypes.Structure):
def __init__(self, dev: Device):
- self.device_type = self.DeviceToDL[dev.type]
+ # Use lazy detection to get the GPU device type
+ gpu_type = _get_gpu_device_type()
+ device_to_dl = {
+ DeviceType.CPU: self.kDLCPU,
+ DeviceType.CUDA: gpu_type,
+ DeviceType.GPU: gpu_type,
+ }
+ self.device_type = device_to_dl[dev.type]
self.device_id = dev.index if dev.index is not None else 0
kDLCPU = 1
kDLCUDA = 2
+ kDLROCM = 10
_fields_ = [
("device_type", ctypes.c_int),
("device_id", ctypes.c_int),
]
- DeviceToDL = {
- DeviceType.CPU: kDLCPU,
- DeviceType.CUDA: kDLCUDA,
- DeviceType.GPU: kDLCUDA,
- }
-
class c_DLDataType(ctypes.Structure):
def __init__(self, dtype: DType):
diff --git a/fastsafetensors/frameworks/_paddle.py b/fastsafetensors/frameworks/_paddle.py
index 13f5eec..8ced6eb 100644
--- a/fastsafetensors/frameworks/_paddle.py
+++ b/fastsafetensors/frameworks/_paddle.py
@@ -214,11 +214,18 @@ def copy_tensor(self, dst: PaddleTensor, src: PaddleTensor) -> None:
dst.device = src.device
def get_cuda_ver(self) -> str:
- return (
- str(paddle.version.cuda())
- if paddle.device.is_compiled_with_cuda()
- else "0.0"
- )
+ """Get GPU runtime version with platform indicator.
+
+ Returns a string like 'hip-5.7.0' for ROCm or 'cuda-12.1' for CUDA,
+ or '0.0' if no GPU is available. This allows code to distinguish
+ between different GPU platforms without using paddle directly.
+ """
+ if paddle.device.is_compiled_with_cuda():
+ # Check if this is ROCm/HIP build
+ if paddle.device.is_compiled_with_rocm():
+ return f"hip-{paddle.version.cuda()}"
+ return f"cuda-{paddle.version.cuda()}"
+ return "0.0"
def get_device_ptr_align(self) -> int:
CUDA_PTR_ALIGN: int = 16
diff --git a/fastsafetensors/frameworks/_torch.py b/fastsafetensors/frameworks/_torch.py
index 1487153..aeb8084 100644
--- a/fastsafetensors/frameworks/_torch.py
+++ b/fastsafetensors/frameworks/_torch.py
@@ -186,8 +186,17 @@ def copy_tensor(self, dst: TorchTensor, src: TorchTensor):
dst.real_tensor.copy_(src.real_tensor)
def get_cuda_ver(self) -> str:
+ """Get GPU runtime version with platform indicator.
+
+ Returns a string like 'hip-5.7.0' for ROCm or 'cuda-12.1' for CUDA,
+ or '0.0' if no GPU is available. This allows code to distinguish
+ between different GPU platforms without using torch directly.
+ """
if torch.cuda.is_available():
- return str(torch.version.cuda)
+ # Check if this is ROCm/HIP build
+ if hasattr(torch.version, "hip") and torch.version.hip is not None:
+ return f"hip-{torch.version.hip}"
+ return f"cuda-{torch.version.cuda}"
return "0.0"
def get_device_ptr_align(self) -> int:
diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py
index 602c0f7..cfbc52a 100644
--- a/fastsafetensors/loader.py
+++ b/fastsafetensors/loader.py
@@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
from . import cpp as fstcpp
-from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node
+from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node, is_gpu_found
from .copier.gds import new_gds_file_copier
from .file_buffer import FilesBufferOnDevice
from .frameworks import FrameworkOpBase, TensorBase, get_framework_op
diff --git a/setup.py b/setup.py
index b052909..b513c58 100644
--- a/setup.py
+++ b/setup.py
@@ -2,11 +2,163 @@
# SPDX-License-Identifier: Apache-2.0
import os
+import re
+import shutil
+import subprocess
+from pathlib import Path
from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext
-def MyExtension(name, sources, mod_name, *args, **kwargs):
+def detect_platform():
+ """
+ Detect if we're on NVIDIA CUDA or AMD ROCm platform.
+
+ Returns:
+ tuple: (platform_type, rocm_version, rocm_path)
+ platform_type: 'cuda' or 'rocm'
+ rocm_version: ROCm version string (e.g., '7.0.1') or None
+ rocm_path: Path to ROCm installation or None
+ """
+ # Check for ROCm installation
+ rocm_path = os.environ.get("ROCM_PATH")
+ if not rocm_path:
+ # Try common ROCm installation paths
+ for path in ["/opt/rocm", "/opt/rocm-*"]:
+ if "*" in path:
+ import glob
+
+ matches = sorted(glob.glob(path), reverse=True)
+ if matches:
+ rocm_path = matches[0]
+ break
+ elif os.path.exists(path):
+ rocm_path = path
+ break
+
+ # Check if ROCm is available
+ if rocm_path and os.path.exists(rocm_path):
+ # Detect ROCm version
+ rocm_version = None
+ version_file = os.path.join(rocm_path, ".info", "version")
+ if os.path.exists(version_file):
+ with open(version_file, "r") as f:
+ rocm_version = f.read().strip()
+ else:
+ # Try to extract version from path
+ match = re.search(r"rocm[-/](\d+\.\d+(?:\.\d+)?)", rocm_path)
+ if match:
+ rocm_version = match.group(1)
+
+ print(f"Detected ROCm platform at {rocm_path}")
+ if rocm_version:
+ print(f"ROCm version: {rocm_version}")
+ return ("rocm", rocm_version, rocm_path)
+
+ # Check for CUDA
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+ if not cuda_home:
+ # Try to find nvcc
+ nvcc_path = shutil.which("nvcc")
+ if nvcc_path:
+ cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
+
+ if cuda_home and os.path.exists(cuda_home):
+ print(f"Detected CUDA platform at {cuda_home}")
+ return ("cuda", None, None)
+
+ # Default to CUDA if nothing detected
+ print("No GPU platform detected, defaulting to CUDA")
+ return ("cuda", None, None)
+
+
+def hipify_source_files(rocm_path):
+ """
+ Automatically hipify CUDA source files to HIP using hipify-perl from ROCm.
+ The cuda_compat.h header handles what hipify doesn't convert.
+
+ Args:
+ rocm_path: Path to ROCm installation
+
+ Returns:
+ list: Paths to hipified source files
+ """
+ cpp_dir = Path("fastsafetensors/cpp").resolve()
+ hip_dir = cpp_dir / "hip"
+
+ # Create hip/ subdirectory if it doesn't exist
+ hip_dir.mkdir(exist_ok=True)
+
+ # Find hipify-perl in ROCm installation
+ hipify_perl = os.path.join(rocm_path, "bin", "hipify-perl")
+ if not os.path.exists(hipify_perl):
+ raise RuntimeError(
+ f"hipify-perl not found at {hipify_perl}. "
+ f"Please ensure ROCm is properly installed at {rocm_path}"
+ )
+
+ # Files to hipify
+ source_files = [
+ ("ext.cpp", "ext.cpp"),
+ ("ext.hpp", "ext.hpp"),
+ ]
+
+ print(f"Hipifying files using hipify-perl from {hipify_perl}:")
+ hipified_files = []
+
+ for src_name, dst_name in source_files:
+ src_path = cpp_dir / src_name
+ dst_path = hip_dir / dst_name
+
+ print(f" - {src_path} -> {dst_path}")
+
+ try:
+ # Run hipify-perl: hipify-perl input.cpp -o output.cpp
+ result = subprocess.run(
+ [hipify_perl, str(src_path), "-o", str(dst_path)],
+ check=True,
+ capture_output=True,
+ text=True,
+ )
+ print(f" Successfully hipified: {src_name}")
+ hipified_files.append(str(dst_path))
+
+ # Print any warnings from hipify-perl
+ if result.stderr:
+ print(f" hipify-perl output: {result.stderr.strip()}")
+
+ # Post-process: Replace cuda_compat.h with hip_compat.h
+ # hipify-perl doesn't convert custom header names
+ with open(dst_path, "r") as f:
+ content = f.read()
+ content = content.replace(
+ '#include "cuda_compat.h"', '#include "hip_compat.h"'
+ )
+ with open(dst_path, "w") as f:
+ f.write(content)
+ print(f" Post-processed: cuda_compat.h -> hip_compat.h")
+
+ except subprocess.CalledProcessError as e:
+ raise RuntimeError(
+ f"Failed to hipify {src_path}:\n"
+ f"stdout: {e.stdout}\n"
+ f"stderr: {e.stderr}"
+ ) from e
+
+ # Copy cuda_compat.h to hip directory as hip_compat.h
+    # (the post-processing step above rewrites the #include from cuda_compat.h to hip_compat.h)
+ cuda_compat = cpp_dir / "cuda_compat.h"
+ hip_compat = hip_dir / "hip_compat.h"
+ shutil.copy2(cuda_compat, hip_compat)
+ print(f"Copied {cuda_compat} -> {hip_compat}")
+
+ return hipified_files
+
+
+def MyExtension(
+ name, sources, mod_name, platform_type, rocm_path=None, *args, **kwargs
+):
import pybind11
pybind11_path = os.path.dirname(pybind11.__file__)
@@ -21,9 +173,83 @@ def MyExtension(name, sources, mod_name, *args, **kwargs):
# https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes
kwargs["extra_compile_args"] = ["-fvisibility=hidden", "-std=c++17"]
+ # Platform-specific configuration
+ if platform_type == "rocm" and rocm_path:
+ # ROCm/HIP configuration
+ kwargs["define_macros"].append(("__HIP_PLATFORM_AMD__", "1"))
+ kwargs["libraries"].append("amdhip64")
+ kwargs["library_dirs"] = [f"{rocm_path}/lib"]
+ kwargs["include_dirs"].append(f"{rocm_path}/include")
+ kwargs["extra_compile_args"].append("-D__HIP_PLATFORM_AMD__")
+ kwargs["extra_link_args"] = [f"-L{rocm_path}/lib", "-lamdhip64"]
+
return Extension(name, sources, *args, **kwargs)
+class CustomBuildExt(build_ext):
+ """Custom build_ext to handle automatic hipification for ROCm platforms"""
+
+ def run(self):
+ # Detect platform
+ platform_type, rocm_version, rocm_path = detect_platform()
+
+ # Store platform info
+ self.platform_type = platform_type
+ self.rocm_version = rocm_version
+ self.rocm_path = rocm_path
+
+ # Configure build based on platform
+ if platform_type == "rocm" and rocm_path:
+ print("=" * 60)
+ print("Building for AMD ROCm platform")
+ if rocm_version:
+ print(f"ROCm version: {rocm_version}")
+ print("=" * 60)
+
+ # Hipify sources
+ hipify_source_files(rocm_path)
+
+ # Update extension sources to use hipified files
+ for ext in self.extensions:
+ new_sources = []
+ for src in ext.sources:
+ if "fastsafetensors/cpp/ext.cpp" in src:
+                        # hipify_source_files() writes the hipified sources into the hip/ subdirectory
+ new_sources.append(
+ src.replace(
+ "fastsafetensors/cpp/ext.cpp",
+ "fastsafetensors/cpp/hip/ext.cpp",
+ )
+ )
+ else:
+ new_sources.append(src)
+ ext.sources = new_sources
+
+ # Update include dirs to include hip/ subdirectory
+ ext.include_dirs.append("fastsafetensors/cpp/hip")
+
+ # Update extension with ROCm-specific settings
+ ext.define_macros.append(("__HIP_PLATFORM_AMD__", "1"))
+ ext.define_macros.append(("USE_ROCM", "1"))
+ ext.libraries.append("amdhip64")
+ ext.library_dirs = [f"{rocm_path}/lib"]
+ ext.include_dirs.append(f"{rocm_path}/include")
+ ext.extra_compile_args.append("-D__HIP_PLATFORM_AMD__")
+ ext.extra_compile_args.append("-DUSE_ROCM")
+ ext.extra_link_args = [f"-L{rocm_path}/lib", "-lamdhip64"]
+ else:
+ print("=" * 60)
+ print("Building for NVIDIA CUDA platform")
+ print("=" * 60)
+
+ # Continue with normal build
+ build_ext.run(self)
+
+
+# Detect platform for package_data
+platform_type, _, rocm_path_detected = detect_platform()
+package_data_patterns = ["*.hpp", "*.h", "cpp.pyi"]
+
setup(
packages=[
"fastsafetensors",
@@ -32,13 +258,18 @@ def MyExtension(name, sources, mod_name, *args, **kwargs):
"fastsafetensors.frameworks",
],
include_package_data=True,
- package_data={"fastsafetensors.cpp": ["*.hpp", "cpp.pyi"]},
+ package_data={"fastsafetensors.cpp": package_data_patterns},
ext_modules=[
MyExtension(
name=f"fastsafetensors.cpp",
sources=["fastsafetensors/cpp/ext.cpp"],
include_dirs=["fastsafetensors/cpp"],
mod_name="cpp",
+ platform_type=platform_type,
+ rocm_path=rocm_path_detected,
)
],
+ cmdclass={
+ "build_ext": CustomBuildExt,
+ },
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 96d2f95..af4d27f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,15 +1,20 @@
import os
+import sys
from typing import List
import pytest
from fastsafetensors import SingleGroup
from fastsafetensors import cpp as fstcpp
+from fastsafetensors.common import is_gpu_found
from fastsafetensors.cpp import load_nvidia_functions
from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op
from fastsafetensors.st_types import Device
+# Add the tests directory to sys.path so platform_utils can be imported
TESTS_DIR = os.path.dirname(__file__)
+sys.path.insert(0, TESTS_DIR)
+from platform_utils import get_platform_info
+
REPO_ROOT = os.path.dirname(os.path.dirname(TESTS_DIR))
DATA_DIR = os.path.join(REPO_ROOT, ".testdata")
TF_DIR = os.path.join(DATA_DIR, "transformers_cache")
@@ -20,6 +25,15 @@
load_nvidia_functions()
FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set"))
+# Print platform information at test startup
+platform_info = get_platform_info()
+print("\n" + "=" * 60)
+print("Platform Detection:")
+print("=" * 60)
+for key, value in platform_info.items():
+ print(f" {key}: {value}")
+print("=" * 60 + "\n")
+
@pytest.fixture(scope="session", autouse=True)
def framework() -> FrameworkOpBase:
@@ -68,7 +82,7 @@ def pg():
@pytest.fixture(scope="session", autouse=True)
def dev_init() -> None:
- if fstcpp.is_cuda_found():
+ if is_gpu_found():
dev_str = "cuda:0" if FRAMEWORK.get_name() == "pytorch" else "gpu:0"
else:
dev_str = "cpu"
diff --git a/tests/platform_utils.py b/tests/platform_utils.py
new file mode 100644
index 0000000..d9da5f6
--- /dev/null
+++ b/tests/platform_utils.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 IBM Inc. All rights reserved
+
+"""Utilities for platform detection and conditional test execution."""
+
+import pytest
+
+
+def is_rocm_platform():
+ """Detect if running on ROCm/AMD platform.
+
+ Uses the C++ extension's is_hip_found() to avoid framework-specific calls.
+ """
+ try:
+ from fastsafetensors import cpp as fstcpp
+
+ return fstcpp.is_hip_found()
+    except Exception:
+ return False
+
+
+def is_cuda_platform():
+ """Detect if running on CUDA/NVIDIA platform."""
+ return not is_rocm_platform()
+
+
+# List of tests that are expected to fail on ROCm (based on TEST_RESULTS.md)
+ROCM_EXPECTED_FAILURES = {
+ "test_GdsFileCopier", # GDS not available on AMD
+}
+
+
+def skip_if_rocm_expected_failure(test_name):
+ """Skip test if it's an expected failure on ROCm."""
+ if is_rocm_platform() and test_name in ROCM_EXPECTED_FAILURES:
+ pytest.skip(
+ f"Test '{test_name}' is expected to fail on ROCm (GDS not supported)"
+ )
+
+
+def get_platform_info():
+ """Get platform information for debugging.
+
+ Uses framework's get_cuda_ver() to avoid direct torch calls where possible.
+ """
+ info = {
+ "is_rocm": is_rocm_platform(),
+ "is_cuda": is_cuda_platform(),
+ }
+
+ try:
+ from fastsafetensors import cpp as fstcpp
+ from fastsafetensors.common import is_gpu_found
+
+ if is_gpu_found():
+ # Get version info from framework
+ try:
+ from fastsafetensors.frameworks import get_framework_op
+
+ framework = get_framework_op("pytorch")
+ gpu_ver = framework.get_cuda_ver()
+ info["gpu_version"] = gpu_ver
+
+ # Parse the version to get specific info
+ if gpu_ver.startswith("hip-"):
+ info["hip_version"] = gpu_ver[4:] # Remove 'hip-' prefix
+ info["rocm_version"] = gpu_ver[4:]
+ elif gpu_ver.startswith("cuda-"):
+ info["cuda_version"] = gpu_ver[5:] # Remove 'cuda-' prefix
+            except Exception:
+ pass
+
+ # Get device count and name (still needs torch for this)
+ try:
+ import torch
+
+ if torch.cuda.is_available():
+ info["torch_version"] = torch.__version__
+ info["device_count"] = torch.cuda.device_count()
+ info["device_name"] = (
+ torch.cuda.get_device_name(0)
+ if torch.cuda.device_count() > 0
+ else None
+ )
+        except Exception:
+ pass
+    except Exception:
+ pass
+
+ return info
diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py
index dcbdccd..1f4ff28 100644
--- a/tests/test_fastsafetensors.py
+++ b/tests/test_fastsafetensors.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
+import sys
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
@@ -10,13 +11,17 @@
from fastsafetensors import SafeTensorsFileLoader, SafeTensorsMetadata, SingleGroup
from fastsafetensors import cpp as fstcpp
from fastsafetensors import fastsafe_open
-from fastsafetensors.common import get_device_numa_node
+from fastsafetensors.common import get_device_numa_node, is_gpu_found
from fastsafetensors.copier.gds import GdsFileCopier
from fastsafetensors.copier.nogds import NoGdsFileCopier
from fastsafetensors.dlpack import from_cuda_buffer
from fastsafetensors.frameworks import FrameworkOpBase
from fastsafetensors.st_types import Device, DeviceType, DType
+# Add tests directory to path to import platform_utils
+sys.path.insert(0, os.path.dirname(__file__))
+from platform_utils import skip_if_rocm_expected_failure
+
def load_safetensors_file(
filename: str,
@@ -58,7 +63,7 @@ def save_safetensors_file(
def get_and_check_device(framework: FrameworkOpBase):
- dev_is_gpu = fstcpp.is_cuda_found()
+ dev_is_gpu = is_gpu_found()
device = "cpu"
if dev_is_gpu:
if framework.get_name() == "pytorch":
@@ -105,18 +110,27 @@ def test_framework(fstcpp_log, framework) -> None:
framework.is_equal(t, [float(0.0)])
with pytest.raises(Exception):
framework.get_process_group(int(0))
+ # Test that get_cuda_ver() returns a string with platform prefix
+ cuda_ver = framework.get_cuda_ver()
+ assert isinstance(cuda_ver, str)
+ # Should be "hip-X.Y.Z", "cuda-X.Y", or "0.0"
+ assert (
+ cuda_ver.startswith("hip-") or cuda_ver.startswith("cuda-") or cuda_ver == "0.0"
+ )
+
+ # Verify it matches what torch reports
if framework.get_name() == "pytorch":
import torch
- cuda_ver = str(torch.version.cuda) if torch.cuda.is_available() else "0.0"
- elif framework.get_name() == "paddle":
- import paddle
-
- if paddle.device.is_compiled_with_cuda():
- cuda_ver = str(paddle.version.cuda())
+ if torch.cuda.is_available():
+ if hasattr(torch.version, "hip") and torch.version.hip:
+ assert cuda_ver.startswith("hip-")
+ assert str(torch.version.hip) in cuda_ver
+ else:
+ assert cuda_ver.startswith("cuda-")
+ assert str(torch.version.cuda) in cuda_ver
else:
- cuda_ver = "0.0"
- assert framework.get_cuda_ver() == cuda_ver
+ assert cuda_ver == "0.0"
def test_get_framework_fail(fstcpp_log) -> None:
@@ -228,10 +242,10 @@ def test_close_gds(fstcpp_log) -> None:
def test_get_device_pci_bus(fstcpp_log) -> None:
bus = fstcpp.get_device_pci_bus(0)
- if not fstcpp.is_cuda_found():
+ if not is_gpu_found():
assert bus == ""
else:
- print(f"bus for cuda:0: {bus}")
+ print(f"bus for gpu:0: {bus}")
assert len(bus) > 0
@@ -326,6 +340,7 @@ def test_NoGdsFileCopier(fstcpp_log, input_files, framework) -> None:
def test_GdsFileCopier(fstcpp_log, input_files, framework) -> None:
print("test_GdsFileCopier")
+ skip_if_rocm_expected_failure("test_GdsFileCopier")
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
device, dev_is_gpu = get_and_check_device(framework)
reader = fstcpp.gds_file_reader(4, dev_is_gpu)
diff --git a/tests/test_multi.py b/tests/test_multi.py
index b80b2f9..12b95cf 100644
--- a/tests/test_multi.py
+++ b/tests/test_multi.py
@@ -5,6 +5,7 @@
from fastsafetensors import SafeTensorsFileLoader
from fastsafetensors import cpp as fstcpp
+from fastsafetensors.common import is_gpu_found
def test_shuffle(fstcpp_log, input_files, pg, framework):
@@ -14,13 +15,13 @@ def test_shuffle(fstcpp_log, input_files, pg, framework):
rank = pg.rank()
world_size = pg.size()
- device = "cuda:0" if fstcpp.is_cuda_found() else "cpu"
+ device = "cuda:0" if is_gpu_found() else "cpu"
elif framework.get_name() == "paddle":
from safetensors.paddle import load_file
rank = pg.process_group.rank()
world_size = pg.process_group.size()
- device = "gpu:0" if fstcpp.is_cuda_found() else "cpu"
+ device = "gpu:0" if is_gpu_found() else "cpu"
else:
raise Exception(f"Unknown framework: {framework.get_name()}")
loader = SafeTensorsFileLoader(