From 975b86a9a8e95f24b9e55e9f33df019e44e5a0a5 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 4 Sep 2025 04:08:54 +0000 Subject: [PATCH 01/10] Support PaddlePaddle with compatible API --- flashinfer/fused_moe/core.py | 1 + flashinfer/jit/cpp_ext.py | 52 +++++++++++++++++++++++++++++------- flashinfer/utils.py | 6 ++++- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 9a7ba595cf..6f97462ea9 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -42,6 +42,7 @@ get_shuffle_matrix_sf_a_row_indices, register_custom_op, register_fake_op, + use_paddle_compatible_api, ) from .utils import ( get_last_power_of_2_num_tokens_buckets, diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 26f7a2a073..d4d84a4bf6 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -20,6 +20,7 @@ ) from . import env as jit_env +from ..utils import use_paddle_compatible_api @functools.cache @@ -63,12 +64,25 @@ def generate_ninja_build_for_op( ) -> str: system_includes = [ sysconfig.get_path("include"), - "$torch_home/include", - "$torch_home/include/torch/csrc/api/include", "$cuda_home/include", jit_env.FLASHINFER_INCLUDE_DIR.resolve(), jit_env.FLASHINFER_CSRC_DIR.resolve(), ] + if use_paddle_compatible_api(): + system_includes.extend( + [ + "$torch_home/include", + "$torch_home/include/torch/csrc/api/include", + ] + ) + else: + system_includes.extend( + [ + "$torch_home/include", + "$torch_home/include/paddle/phi/api/include/compat", + "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include", + ] + ) system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS] system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve()) @@ -113,15 +127,35 @@ def generate_ninja_build_for_op( ldflags = [ "-shared", - "-L$torch_home/lib", - "-L$cuda_home/lib64", - "-lc10", - "-lc10_cuda", - "-ltorch_cpu", - "-ltorch_cuda", - "-ltorch", "-lcudart", ] + if use_paddle_compatible_api(): + ldflags.extend( + [ + "-L$torch_home/lib", + "-L$cuda_home/lib64", + "-lc10", + "-lc10_cuda", + "-ltorch_cpu", + "-ltorch_cuda", + "-ltorch", + ] + ) + else: + ldflags.extend( + [ + "-shared", + "-L$torch_home/libs", + "-L$torch_home/base", + "-L$cuda_home/lib64", + "-lpaddle", + "-lphi", + "-lphi_core", + "-lphi_gpu", + "-lcommon", + "-lcudart", + ] + ) env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS") if env_extra_ldflags: diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 1c716d1e0e..15632a2a80 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -226,9 +226,13 @@ def _check_cached_qkv_data_type( raise ValueError( f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function." ) + +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] -if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"): + +if use_paddle_compatible_api() or IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"): def register_custom_op( name: str, From 9bb59d413c89b9c29ff67f441705773d7006ecfb Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 4 Sep 2025 04:09:06 +0000 Subject: [PATCH 02/10] update setup.py --- setup.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9b62a0ef33..bc537f0c37 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,9 @@ aot_ops_package_dir = root / "build" / "aot-ops-package-dir" enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir()) +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + def write_if_different(path: Path, content: str) -> None: if path.exists() and path.read_text() == content: @@ -55,7 +58,6 @@ def generate_build_meta(aot_build_meta: dict) -> None: cmdclass: Mapping[str, type[setuptools.Command]] = {} install_requires = [ "numpy", - "torch", "ninja", "requests", "cuda-python<=12.9", @@ -64,9 +66,16 @@ def generate_build_meta(aot_build_meta: dict) -> None: "packaging>=24.2", "nvidia-cudnn-frontend>=1.13.0", ] +if not use_paddle_compatible_api(): + install_requires.append("torch") + generate_build_meta({}) if enable_aot: + if use_paddle_compatible_api(): + import paddle + paddle.compat.install_torch_alias() + import torch import torch.utils.cpp_extension as torch_cpp_ext from packaging.version import Version From ee2044214fec96f5c10f53850ac6f2ed40908522 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 4 Sep 2025 08:44:03 +0000 Subject: [PATCH 03/10] rename `install_torch_alias` to `enable_torch_proxy` --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bc537f0c37..61ecdfc17b 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ def generate_build_meta(aot_build_meta: dict) -> None: if enable_aot: if use_paddle_compatible_api(): import paddle - paddle.compat.install_torch_alias() + paddle.compat.enable_torch_proxy() import torch import torch.utils.cpp_extension as torch_cpp_ext From 40c40f1e5a6994538b9008b18cebc81e06c0566a Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 5 Sep 2025 06:08:00 +0000 Subject: [PATCH 04/10] run pre-commit --- flashinfer/fused_moe/core.py | 1 - flashinfer/utils.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 6f97462ea9..9a7ba595cf 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -42,7 +42,6 @@ get_shuffle_matrix_sf_a_row_indices, register_custom_op, register_fake_op, - use_paddle_compatible_api, ) from .utils import ( get_last_power_of_2_num_tokens_buckets, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 15632a2a80..fbbe77963c 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -226,13 +226,17 @@ def _check_cached_qkv_data_type( raise ValueError( f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function." ) - + def use_paddle_compatible_api() -> bool: return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] -if use_paddle_compatible_api() or IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"): +if ( + use_paddle_compatible_api() + or IS_BUILDING_DOCS + or TorchVersion(torch_version) < TorchVersion("2.4") +): def register_custom_op( name: str, From 930cd627d942b582b390c78424e6ad268dae98dc Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 5 Sep 2025 07:44:17 +0000 Subject: [PATCH 05/10] resolve circular import --- flashinfer/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index fbbe77963c..caa0dc4c92 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -25,7 +25,7 @@ from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version -from .jit import gen_jit_spec, env as jit_env +import flashinfer IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1" @@ -473,14 +473,14 @@ def check_shape_dtype_device( def get_logging_module(): - return gen_jit_spec( + return flashinfer.jit.gen_jit_spec( "logging", [ - jit_env.FLASHINFER_CSRC_DIR / "logging.cc", + flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc", ], extra_include_paths=[ - jit_env.SPDLOG_INCLUDE_DIR, - jit_env.FLASHINFER_INCLUDE_DIR, + flashinfer.jit.env.SPDLOG_INCLUDE_DIR, + flashinfer.jit.env.FLASHINFER_INCLUDE_DIR, ], ).build_and_load() From 5d7c443561d44e55d46e073a7c304dc4a053c088 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 5 Sep 2025 09:11:41 +0000 Subject: [PATCH 06/10] run pre-commit --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 61ecdfc17b..4d5736822b 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ aot_ops_package_dir = root / "build" / "aot-ops-package-dir" enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir()) + def use_paddle_compatible_api() -> bool: return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] @@ -74,6 +75,7 @@ def generate_build_meta(aot_build_meta: dict) -> None: if enable_aot: if use_paddle_compatible_api(): import paddle + paddle.compat.enable_torch_proxy() import torch From b7a8db4a0bd4e67233a830a97a9395bce5daee17 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 5 Sep 2025 09:22:16 +0000 Subject: [PATCH 07/10] resolve conflict error --- flashinfer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 0728be0529..ea30ac034b 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -500,7 +500,7 @@ def check_shape_dtype_device( def gen_logging_module(): - return flashinfer.jitgen_jit_spec( + return flashinfer.jit.gen_jit_spec( "logging", [ flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc", From a7678a81ff34a890f72b2b124d5d2f782b2dbd93 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 17 Sep 2025 14:08:41 +0000 Subject: [PATCH 08/10] add some missing changes --- flashinfer/jit/cpp_ext.py | 38 ++++++++++++++++++++++---------------- flashinfer/sampling.py | 2 ++ 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 9b6314d04c..73277a76f4 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -12,7 +12,6 @@ import torch from torch.utils.cpp_extension import ( - _TORCH_PATH, CUDA_HOME, _get_num_workers, _get_pybind11_abi_build_flags, @@ -22,6 +21,11 @@ from ..utils import use_paddle_compatible_api from ..compilation_context import CompilationContext +if use_paddle_compatible_api(): + _TORCH_PATH = torch.__path__[0] +else: + from torch.utils.cpp_extension import _TORCH_PATH # type: ignore[no-redef] + @functools.cache def get_cuda_path() -> str: @@ -85,15 +89,15 @@ def generate_ninja_build_for_op( system_includes.extend( [ "$torch_home/include", - "$torch_home/include/torch/csrc/api/include", + "$torch_home/include/paddle/phi/api/include/compat", + "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include", ] ) else: system_includes.extend( [ "$torch_home/include", - "$torch_home/include/paddle/phi/api/include/compat", - "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include", + "$torch_home/include/torch/csrc/api/include", ] ) system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS] @@ -104,6 +108,8 @@ def generate_ninja_build_for_op( "-DTORCH_API_INCLUDE_EXTENSION_H", "-DPy_LIMITED_API=0x03090000", ] + if use_paddle_compatible_api(): + common_cflags.append("-DPADDLE_WITH_CUDA") common_cflags += _get_pybind11_abi_build_flags() common_cflags += _get_glibcxx_abi_build_flags() if extra_include_dirs is not None: @@ -161,18 +167,6 @@ def generate_ninja_build_for_op( "-lcudart", ] if use_paddle_compatible_api(): - ldflags.extend( - [ - "-L$torch_home/lib", - "-L$cuda_home/lib64", - "-lc10", - "-lc10_cuda", - "-ltorch_cpu", - "-ltorch_cuda", - "-ltorch", - ] - ) - else: ldflags.extend( [ "-shared", @@ -187,6 +181,18 @@ def generate_ninja_build_for_op( "-lcudart", ] ) + else: + ldflags.extend( + [ + "-L$torch_home/lib", + "-L$cuda_home/lib64", + "-lc10", + "-lc10_cuda", + "-ltorch_cpu", + "-ltorch_cuda", + "-ltorch", + ] + ) env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS") if env_extra_ldflags: diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 4cd7e5bd5a..f45ddabd3f 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -14,6 +14,8 @@ limitations under the License. """ +from __future__ import annotations # for torch.Generator + import functools from types import SimpleNamespace from typing import Optional, Union From a629c98ebf83b9e8ff50a56271d24204db2bca99 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 23 Sep 2025 07:02:36 +0000 Subject: [PATCH 09/10] add some temp patch --- flashinfer/compilation_context.py | 3 ++- flashinfer/fused_moe/core.py | 16 ++++++++++++---- flashinfer/jit/utils.py | 12 ++++++------ flashinfer/utils.py | 15 ++++++++------- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/flashinfer/compilation_context.py b/flashinfer/compilation_context.py index 5d24643f55..5e7078c5f6 100644 --- a/flashinfer/compilation_context.py +++ b/flashinfer/compilation_context.py @@ -42,7 +42,8 @@ def __init__(self): self.TARGET_CUDA_ARCHS.add((int(major), minor)) else: try: - for device in range(torch.cuda.device_count()): + # for device in range(torch.cuda.device_count()): + for device in range(1): major, minor = torch.cuda.get_device_capability(device) if major >= 9: minor = str(minor) + "a" diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index ad6169c515..85dff5723d 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -568,7 +568,8 @@ def cutlass_fused_moe( enable_pdl: Optional[bool] = None, ) -> List[torch.Tensor]: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) @@ -868,7 +869,8 @@ def cutlass_fused_moe( raise NotImplementedError("min latency mode not yet implemented for Blackwell.") if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) num_rows = input.shape[0] if min_latency_mode: @@ -877,10 +879,16 @@ def cutlass_fused_moe( output_shape = (num_rows, hidden_size) if output is None: - output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + # output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + output = torch.empty(output_shape, dtype=output_dtype, device=input.place) else: check_shape_dtype_device( - output, output_shape, output_dtype, input.device, "output" + # output, output_shape, output_dtype, input.device, "output" + output, + output_shape, + output_dtype, + input.place, + "output", ) major, minor = torch.cuda.get_device_capability() diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py index 4e19212e14..5f621f83e6 100644 --- a/flashinfer/jit/utils.py +++ b/flashinfer/jit/utils.py @@ -38,9 +38,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None: torch.int8: "int8_t", torch.uint8: "uint8_t", torch.int32: "int32_t", - torch.uint32: "uint32_t", + # torch.uint32: "uint32_t", torch.int64: "int64_t", - torch.uint64: "uint64_t", + # torch.uint64: "uint64_t", } dtype_cutlass_map = { @@ -51,9 +51,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None: torch.int8: "cutlass::int8_t", torch.uint8: "cutlass::uint8_t", torch.int32: "cutlass::int32_t", - torch.uint32: "cutlass::uint32_t", + # torch.uint32: "cutlass::uint32_t", torch.int64: "cutlass::int64_t", - torch.uint64: "cutlass::uint64_t", + # torch.uint64: "cutlass::uint64_t", } filename_safe_dtype_map = { @@ -64,9 +64,9 @@ def write_if_different(path: pathlib.Path, content: str) -> None: torch.int8: "i8", torch.uint8: "u8", torch.int32: "i32", - torch.uint32: "u32", + # torch.uint32: "u32", torch.int64: "i64", - torch.uint64: "u64", + # torch.uint64: "u64", } pos_encoding_mode_literal = { diff --git a/flashinfer/utils.py b/flashinfer/utils.py index ea30ac034b..3f2ee016ba 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -22,8 +22,6 @@ import torch import torch.version -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version import flashinfer @@ -222,6 +220,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @functools.cache def get_compute_capability(device: torch.device) -> Tuple[int, int]: + return torch.device.cuda.get_device_capability(device.gpu_device_id()) if device.type != "cuda": raise ValueError("device must be a cuda device") return torch.cuda.get_device_capability(device.index) @@ -247,7 +246,8 @@ def use_paddle_compatible_api() -> bool: if ( use_paddle_compatible_api() or IS_BUILDING_DOCS - or TorchVersion(torch_version) < TorchVersion("2.4") + or torch.torch_version.TorchVersion(torch.torch_version.__version__) + < torch.torch_version.TorchVersion("2.4") ): def register_custom_op( @@ -485,7 +485,7 @@ def check_shape_dtype_device( expected_device: Optional[torch.device], name: str, ) -> None: - if expected_shape and x.shape != torch.Size(expected_shape): + if expected_shape and tuple(x.shape) != torch.Size(expected_shape): raise ValueError( f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}" ) @@ -493,7 +493,8 @@ def check_shape_dtype_device( raise ValueError( f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}" ) - if expected_device and x.device != expected_device: + # if expected_device and x.device != expected_device: + if expected_device and x.place != expected_device: raise ValueError( f"Invalid device of {name}: expected {expected_device}, got {x.device}" ) @@ -541,8 +542,8 @@ def set_log_level(lvl_str: str) -> None: def device_support_pdl(device: torch.device) -> bool: - if device.type != "cuda": - return False + # if device.type != "cuda": + # return False major, _ = get_compute_capability(device) return major >= 9 From 03177f7e600a160c54a6a6aa37d43a3836026b5f Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 24 Sep 2025 12:24:42 +0000 Subject: [PATCH 10/10] replace `input.device` with `input.place` in fp4_quantization --- flashinfer/fp4_quantization.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index fc6bd96610..2ec4ae4be0 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -174,7 +174,8 @@ def fp4_quantize_sm100( - Scale factors tensor with shape determined by layout and sf_vec_size """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) return module.fp4_quantize( input, global_scale, @@ -355,9 +356,11 @@ def fp4_quantize( assert input.shape[-1] % sf_vec_size == 0 if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) # get input device sm version - major, minor = get_compute_capability(input.device) + # major, minor = get_compute_capability(input.device) + major, minor = get_compute_capability(input.place) x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100( input, global_scale,