From 0cc8e320d1495c475600ee790e09c2a9c1749c26 Mon Sep 17 00:00:00 2001
From: TANG Ding
Date: Thu, 12 Dec 2024 07:47:06 +0000
Subject: [PATCH 1/2] add Ascend runtime and fix some problems

---
 README.md            |   7 +-
 backend/ascend.py    | 273 +++++++++++++++++++++++++++++++++++++++++++
 backend/compiler.py  |  11 +-
 backend/driver.py    |  17 ++-
 backend/maca.py      |   2 +-
 backend/mlu.py       |   2 +-
 python/op/softmax.py |   2 +-
 7 files changed, 297 insertions(+), 17 deletions(-)
 create mode 100644 backend/ascend.py

diff --git a/README.md b/README.md
index 77df5a1f..c8439404 100755
--- a/README.md
+++ b/README.md
@@ -5,9 +5,10 @@ triton for dsa
 ```
 git clone https://github.com/llvm/llvm-project.git
 // commit id from llvm-hash.txt in the triton repo
-git reset --hard ed4e505c219fe6c7464ea5a056e90d8cd94c7332
+git checkout ed4e505c219fe6c7464ea5a056e90d8cd94c7332
+mkdir build && cd build
 
-cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS="llvm;mlir" -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="X86X86;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_INSTALL_UTILS=ON
+cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS="llvm;mlir" -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_INSTALL_UTILS=ON
 ninja -j64
 ```
 
@@ -18,7 +19,7 @@ ninja -j64
 export LLVM_BUILD_DIR=...
 bash compile.sh
 export PYTHONPATH=$PWD/third_party/triton/python
-export PATH=$PWD/third_party/triton/build/third_party/triton_shared/tools/triton-shared-opt/:$PATH
+export PATH=$PWD/third_party/triton/build/third_party/dicp_triton/third_party/triton_shared/tools/triton-shared-opt:$PATH
 ```
 
 
diff --git a/backend/ascend.py b/backend/ascend.py
new file mode 100644
index 00000000..4962ca5c
--- /dev/null
+++ b/backend/ascend.py
@@ -0,0 +1,273 @@
+import hashlib
+import os
+import tempfile
+import shutil
+import subprocess
+import sysconfig
+import contextlib
+import sys
+import io
+import functools
+import importlib
+import setuptools
+from pathlib import Path
+from triton.runtime.cache import get_cache_manager
+from triton.runtime import JITFunction
+from .utils import quiet
+import torch
+from torch_npu.contrib import transfer_to_npu
+
+def llir_to_ascendc(mod, metadata):
+    src = f"""
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ *
+ * Function : z = x + y
+ * This is a very basic sample that implements vector add on the Ascend platform.
+ */ +#include "kernel_operator.h" + +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue + +class KernelCustom {{ +public: + __aicore__ inline KernelCustom() {{}} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength) + {{ + this->blockLength = totalLength / GetBlockNum(); + this->tileNum = 8; + this->tileLength = this->blockLength / this->tileNum / BUFFER_NUM; + xGm.SetGlobalBuffer((__gm__ half*)x + this->blockLength * GetBlockIdx(), this->blockLength); + yGm.SetGlobalBuffer((__gm__ half*)y + this->blockLength * GetBlockIdx(), this->blockLength); + zGm.SetGlobalBuffer((__gm__ half*)z + this->blockLength * GetBlockIdx(), this->blockLength); + pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(half)); + pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(half)); + pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(half)); + }} + __aicore__ inline void Process() + {{ + int32_t loopCount = this->tileNum * BUFFER_NUM; + for (int32_t i = 0; i < loopCount; i++) {{ + CopyIn(i); + Compute(i); + CopyOut(i); + }} + }} + +private: + __aicore__ inline void CopyIn(int32_t progress) + {{ + LocalTensor xLocal = inQueueX.AllocTensor(); + LocalTensor yLocal = inQueueY.AllocTensor(); + DataCopy(xLocal, xGm[progress * this->tileLength], this->tileLength); + DataCopy(yLocal, yGm[progress * this->tileLength], this->tileLength); + inQueueX.EnQue(xLocal); + inQueueY.EnQue(yLocal); + }} + __aicore__ inline void Compute(int32_t progress) + {{ + LocalTensor xLocal = inQueueX.DeQue(); + LocalTensor yLocal = inQueueY.DeQue(); + LocalTensor zLocal = outQueueZ.AllocTensor(); + Add(zLocal, xLocal, yLocal, this->tileLength); + outQueueZ.EnQue(zLocal); + inQueueX.FreeTensor(xLocal); + inQueueY.FreeTensor(yLocal); + }} + __aicore__ inline void CopyOut(int32_t progress) + {{ + LocalTensor zLocal = outQueueZ.DeQue(); + DataCopy(zGm[progress * this->tileLength], zLocal, this->tileLength); + outQueueZ.FreeTensor(zLocal); + }} + +private: + TPipe pipe; + TQue inQueueX, inQueueY; + TQue outQueueZ; + GlobalTensor xGm; + GlobalTensor yGm; + GlobalTensor zGm; + uint32_t blockLength; + uint32_t tileNum; + uint32_t tileLength; +}}; + +extern "C" __global__ __aicore__ void {metadata['name']}(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength) +{{ + KernelCustom op; + op.Init(x, y, z, totalLength); + op.Process(); +}} + """ + return src + + +def generate_pybind_cpp(name): + src = f""" +#include +#include "aclrtlaunch_{name}.h" +#include +#include "torch_npu/csrc/core/npu/NPUStream.h" + +PYBIND11_MODULE(__triton_launcher, m) {{ + m.def("launch", &aclrtlaunch_{name}, ""); +}} + """ + return src + + +def generate_cmakelist(soc_version, cann_path, name, mode='npu'): + src = f""" +cmake_minimum_required(VERSION 3.16.0) +project(Ascend_C) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +# user-defined configuration +set(SOC_VERSION "{soc_version}" CACHE STRING "system on chip type") +set(ASCEND_CANN_PACKAGE_PATH "{cann_path}" CACHE PATH "ASCEND CANN package installation directory") +set(RUN_MODE "{mode}" CACHE STRING "run mode: npu/sim/cpu") +set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type Release/Debug (default Debug)" FORCE) +set(CMAKE_INSTALL_PREFIX "${{CMAKE_CURRENT_LIST_DIR}}/out" CACHE STRING "path for install()" FORCE) + +if(EXISTS ${{ASCEND_CANN_PACKAGE_PATH}}/tools/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${{ASCEND_CANN_PACKAGE_PATH}}/tools/tikcpp/ascendc_kernel_cmake) +elseif(EXISTS 
+elseif(EXISTS ${{ASCEND_CANN_PACKAGE_PATH}}/compiler/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${{ASCEND_CANN_PACKAGE_PATH}}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${{ASCEND_CANN_PACKAGE_PATH}}/ascendc_devkit/tikcpp/samples/cmake)
+    set(ASCENDC_CMAKE_DIR ${{ASCEND_CANN_PACKAGE_PATH}}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+    message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.")
+endif()
+
+include(${{ASCENDC_CMAKE_DIR}}/ascendc.cmake)
+
+ascendc_library(kernels STATIC
+    {name}.cpp
+)
+
+add_library(pybind11_lib SHARED pybind11.cpp)
+target_link_libraries(pybind11_lib PRIVATE
+    kernels
+    torch_npu
+)
+execute_process(COMMAND python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))"
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE TORCH_PATH
+)
+message("TORCH_PATH is ${{TORCH_PATH}}")
+set(ENV{{ASCEND_HOME_PATH}} ${{ASCEND_CANN_PACKAGE_PATH}})
+execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))"
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE TORCH_NPU_PATH
+)
+message("TORCH_NPU_PATH is ${{TORCH_NPU_PATH}}")
+target_link_directories(pybind11_lib PRIVATE
+    ${{TORCH_PATH}}/lib
+    ${{TORCH_NPU_PATH}}/lib
+)
+target_include_directories(pybind11_lib PRIVATE
+    ${{TORCH_NPU_PATH}}/include
+    ${{TORCH_PATH}}/include
+    ${{TORCH_PATH}}/include/torch/csrc/api/include
+)
+execute_process(COMMAND python3 -m pybind11 --includes
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE PYBIND11_INC
+)
+string(REPLACE " " ";" PYBIND11_INC ${{PYBIND11_INC}})
+target_compile_options(pybind11_lib PRIVATE
+    ${{PYBIND11_INC}}
+    -D_GLIBCXX_USE_CXX11_ABI=0
+)
+
+execute_process(COMMAND python3-config --extension-suffix
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE PYBIND11_SUFFIX
+)
+set_target_properties(pybind11_lib PROPERTIES
+    OUTPUT_NAME {name}${{PYBIND11_SUFFIX}}
+    PREFIX "" SUFFIX ""
+)
+    """
+    return src
+
+
+def generate_launcher_so(src, metadata):
+    current_dir = os.getcwd()
+    name = metadata['name']
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, f"{name}.cpp")
+        with open(src_path, "w") as f:
+            f.write(src)
+        pybind_src = generate_pybind_cpp(name)
+        pybind_src_path = os.path.join(tmpdir, "pybind11.cpp")
+        with open(pybind_src_path, "w") as f:
+            f.write(pybind_src)
+        #soc_version = torch.cuda.get_device_properties(0)['name']
+        soc_version = "Ascend910B2"
+        cann_path = os.getenv('ASCEND_TOOLKIT_HOME')
+        cmake_src = generate_cmakelist(soc_version, cann_path, name)
+        cmake_src_path = os.path.join(tmpdir, "CMakeLists.txt")
+        with open(cmake_src_path, "w") as f:
+            f.write(cmake_src)
+        build_dir = os.path.join(tmpdir, "build")
+        if os.path.exists(build_dir):
+            shutil.rmtree(build_dir)
+        os.makedirs(build_dir)
+        os.chdir(build_dir)
+        subprocess.check_call(["cmake", ".."])
+        subprocess.check_call(["make"])
+
+        so_cache_manager = get_cache_manager(metadata['hash'])
+        files = os.listdir(build_dir)
+        for file in files:
+            if file.endswith(".so"):
+                so_file = os.path.join(build_dir, file)
+                break
+        with open(so_file, "rb") as f:
+            cache_path = so_cache_manager.put(f.read(), f"{name}.so", binary=True)
+        os.chdir(current_dir)
+        return cache_path
+
+
+def load_binary(name, kernel, shared, device):
+    return None, None, 0, 0
+
+
+class AscendUtils(object):
+    def __new__(cls):
+        if not hasattr(cls, 'instance'):
+            cls.instance = super(AscendUtils, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        self.load_binary = load_binary
+        self.get_device_properties = torch.cuda.get_device_properties
+
+
+class AscendLauncher(object):
+
+    def __init__(self, src, metadata):
+        constants = src.constants if hasattr(src, "constants") else dict()
+        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
+        constants = {cst_key(key): value for key, value in constants.items()}
+        signature = {cst_key(key): value for key, value in src.signature.items()}
+        self.signature = signature
+        self.constants = constants
+        self.hash_key = metadata['hash']
+        self.name = metadata['name']
+
+    def __call__(self, grid_0, grid_1, grid_2, stream, kernel_function, kernel_packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, *args, **kwargs):
+        so_cache_manager = get_cache_manager(self.hash_key)
+        cache_path = so_cache_manager.get_file(f"{self.name}.so")
+        if launch_enter_hook is not None:
+            launch_enter_hook(launch_metadata)
+        spec = importlib.util.spec_from_file_location("__triton_launcher", cache_path)
+        mod = importlib.util.module_from_spec(spec); spec.loader.exec_module(mod)
+        params = tuple([args[i] for i in self.signature.keys() if i not in self.constants])
+        mod.launch(grid_0, stream, *params)
+        if launch_exit_hook is not None:
+            launch_exit_hook(launch_metadata)
\ No newline at end of file
diff --git a/backend/compiler.py b/backend/compiler.py
index cc4abf4e..cfc6ebd2 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -115,13 +115,17 @@ def __init__(self, target:str) -> None:
             self.binary_ext = "cnbin"
         elif self.driver.target == 'maca':
             self.binary_ext = "mcfatbin"
+        elif self.driver.target == 'ascend':
+            self.binary_ext = "so"
 
     @staticmethod
     def supports_target(target: GPUTarget):
-        return target.backend in ['dicp', 'mlu', 'maca']
+        return target.backend in ['dicp', 'mlu', 'maca', 'ascend']
 
     @staticmethod
    def make_ttir(mod, metadata, opt):
+        key = hashlib.md5(str(mod).encode("utf-8")).hexdigest()
+        metadata['hash'] = key
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
@@ -154,6 +158,11 @@ def add_stages(self, stages, options):
             if mxcc_arch is None:
                 raise RuntimeError('mxcc_arch is None (not specified)')
             stages["mcfatbin"] = lambda src, metadata: llir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH'))
+        elif self.driver.target == 'ascend':
+            from triton.backends.dicp_triton.ascend import llir_to_ascendc, generate_launcher_so
+            stages["ttlinalgdir"] = lambda src, metadata: _optimize_ttlinalgdir(_ttir_to_linalgdir(src))
+            stages["ascendc"] = lambda src, metadata: llir_to_ascendc(src, metadata)
+            stages["so"] = lambda src, metadata: generate_launcher_so(src, metadata)
         else:
             raise RuntimeError("backend not supported")
 
diff --git a/backend/driver.py b/backend/driver.py
index 0ec4f3f1..585f4900 100644
--- a/backend/driver.py
+++ b/backend/driver.py
@@ -117,6 +117,11 @@ def __init__(self, target=None):
             self.target = "maca"
             self.utils = MacaUtils()
             self.launcher_cls = MacaLauncher
+        elif target == "ascend":
+            from triton.backends.dicp_triton.ascend import AscendLauncher, AscendUtils
+            self.target = "ascend"
+            self.utils = AscendUtils()
+            self.launcher_cls = AscendLauncher
         else:
             self.target = "dicp"
 
@@ -132,11 +137,7 @@ def is_active():
         return True
 
     def get_device_capability(self):
-        if self.target == "mlu":
-            return ("mlu", 0)
-        elif self.target == "maca":
-            return ("maca", 0)
-        return ("dicp", 0)
+        return (self.target, 0)
 
     def get_current_stream(self, device):
         if self.target == "mlu":
@@ -167,11 +168,7 @@ def set_current_device(self, device):
         return
 
     def get_current_target(self):
-        if self.target == "mlu":
-            return GPUTarget("mlu", "x86", 32)
-        elif self.target == "maca":
-            return GPUTarget("maca", "x86", 32)
-        return GPUTarget("dicp", "x86", 32)
+        return GPUTarget(self.target, "x86", 32)
 
     def assemble_tensormap_to_arg(self, tensormaps_info, args):
         return args
\ No newline at end of file
diff --git a/backend/maca.py b/backend/maca.py
index 6d10c9ab..b31c9576 100644
--- a/backend/maca.py
+++ b/backend/maca.py
@@ -693,7 +693,7 @@ def __init__(self, src, metadata):
         cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
         constants = {cst_key(key): value for key, value in constants.items()}
         signature = {cst_key(key): value for key, value in src.signature.items()}
-        so_cache_key = src.hash()
+        so_cache_key = metadata['hash']
         self.so_path = self.make_launcher_stub(name, so_cache_key, signature, constants)
         spec = importlib.util.spec_from_file_location("__triton_launcher", self.so_path)
         mod = importlib.util.module_from_spec(spec)
diff --git a/backend/mlu.py b/backend/mlu.py
index 08bbb36c..165dc59b 100644
--- a/backend/mlu.py
+++ b/backend/mlu.py
@@ -349,7 +349,7 @@ def __init__(self, src, metadata):
         cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
         constants = {cst_key(key): value for key, value in constants.items()}
         signature = {cst_key(key): value for key, value in src.signature.items()}
-        so_cache_key = src.hash()
+        so_cache_key = metadata['hash']
         self.so_path = self.make_launcher_stub(name, so_cache_key, signature, constants, ids)
         print("sopath: ", self.so_path)
         spec = importlib.util.spec_from_file_location("__triton_launcher", self.so_path)
diff --git a/python/op/softmax.py b/python/op/softmax.py
index 196c3e70..90cb7963 100644
--- a/python/op/softmax.py
+++ b/python/op/softmax.py
@@ -172,4 +172,4 @@ def softmax(x):
 )
 ret = triton.compile(src)
 src_path = "softmax_optimize_kernel.mlir"
-Path(src_path).write_bytes(ret.asm["ttlinalgdir"])
+Path(src_path).write_text(ret.asm["ttlinalgdir"])

From 6a6e25a8c2659d2c906fb2e05ab723c1f0eec041 Mon Sep 17 00:00:00 2001
From: TANG Ding
Date: Thu, 12 Dec 2024 08:01:44 +0000
Subject: [PATCH 2/2] complete driver api

---
 backend/driver.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/backend/driver.py b/backend/driver.py
index 585f4900..7468e746 100644
--- a/backend/driver.py
+++ b/backend/driver.py
@@ -144,7 +144,7 @@ def get_current_stream(self, device):
             if device is None:
                 device = self.get_current_device()
             return torch.mlu.current_stream(device).mlu_stream
-        elif self.target == "maca":
+        elif self.target in ["maca", "ascend"]:
             if device is None:
                 device = self.get_current_device()
             return torch.cuda.current_stream(device).cuda_stream
@@ -154,7 +154,7 @@ def get_current_device(self):
         # dicp doesn't have a device to return. Return something.
         if self.target == "mlu":
             return torch.mlu.current_device()
-        elif self.target == "maca":
+        elif self.target in ["maca", "ascend"]:
             return torch.cuda.current_device()
         return "dicp"
 
@@ -162,7 +162,7 @@ def set_current_device(self, device):
         # dicp doesn't have a device to set
         if self.target == "mlu":
             return torch.mlu.set_device(device)
-        elif self.target == "maca":
+        elif self.target in ["maca", "ascend"]:
             return torch.cuda.set_device(device)
         #assert device == "dicp"
         return
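
For reviewers, a minimal smoke-test sketch of how the new `ascend` target could be exercised once both patches are applied. This is not part of the patch: the kernel, sizes, and the availability of an `npu` device through `torch_npu` are assumptions, and since `llir_to_ascendc` currently emits a fixed half-precision vector-add kernel, only an fp16 element-wise add of this shape is expected to run end to end.

```
# Hypothetical smoke test for the ascend runtime (names and shapes are assumptions).
import torch
import torch_npu  # noqa: F401  (registers the "npu" device with torch)
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(z_ptr + offs, x + y, mask=mask)

n = 8 * 2048
x = torch.rand(n, dtype=torch.float16, device="npu")
y = torch.rand(n, dtype=torch.float16, device="npu")
z = torch.empty_like(x)
# With the dicp_triton backend active for target "ascend", this lowers through the
# ttir -> ttlinalgdir -> ascendc -> so stages added in compiler.py and launches the
# cached .so through AscendLauncher.
add_kernel[(8,)](x, y, z, n, BLOCK=2048)
torch.testing.assert_close(z, x + y)
```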