From 67e0c15276f55043fef347a62b19938c7f274b0e Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 18 Mar 2025 11:11:50 -0700 Subject: [PATCH 1/7] triton_backend_v2 --- scripts/eval_from_generations.py | 13 +- scripts/generate_and_eval_single_sample.py | 27 +- .../generate_and_eval_single_sample_modal.py | 38 +- scripts/generate_samples.py | 2 + src/eval.py | 68 ++- src/prompt_constructor_triton.py | 489 ++++++++++++++++++ .../triton/model_new_ex_add_triton.py | 63 +++ src/prompts/model_new_ex_add_triton.py | 63 +++ 8 files changed, 731 insertions(+), 32 deletions(-) create mode 100644 src/prompt_constructor_triton.py create mode 100644 src/prompts/few_shot/triton/model_new_ex_add_triton.py create mode 100644 src/prompts/model_new_ex_add_triton.py diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 82913fce..e007abbe 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -6,7 +6,8 @@ import json from tqdm import tqdm -from src import eval, utils, compile +# Import only what we need +from src import compile import torch import os import multiprocessing as mp @@ -15,7 +16,7 @@ from datasets import load_dataset from src.dataset import construct_kernelbench_dataset -from src.eval import build_compile_cache, eval_kernel_against_ref, KernelExecResult, check_metadata_serializable_all_types +from src.eval import eval_kernel_against_ref, KernelExecResult, check_metadata_serializable_all_types from src.utils import set_gpu_arch, read_file """ @@ -82,6 +83,9 @@ def __init__(self): # number of GPUs to do batch evaluation self.num_gpu_devices = 1 + # Backend to use for kernel implementation (cuda or triton) + self.backend = "cuda" + def __repr__(self): return f"EvalConfig({self.to_dict()})" @@ -160,6 +164,7 @@ def evaluate_single_sample(work_args: WorkArgs, configs: EvalConfig, dataset, ru num_perf_trials=configs.num_perf_trials, build_dir=build_dir, device=device, + backend=configs.backend, ) return eval_result except Exception as e: @@ -205,7 +210,7 @@ def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: dict, dataset, run_di pool.join() raise except mp.TimeoutError as e: - print(f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}") + print(f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}") print(f"[Eval Result] Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}: {result}") return result @@ -406,7 +411,7 @@ def main(config: EvalConfig): print(f"Evaluating 1 sample each for level {config.level} problems: {problem_id_range}") run_dir = os.path.join(config.runs_dir, config.run_name) - eval_file_path = os.path.join(run_dir, f"eval_results.json") + eval_file_path = os.path.join(run_dir, "eval_results.json") # set GPU arch to configure what target to build for diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 3fdb14b5..15ba4c1a 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -9,6 +9,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor_triton import prompt_generate_custom_triton_from_prompt_template from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets """ @@ -56,6 +57,8 @@ def __init__(self): self.log_generated_kernel = False self.log_eval_result = False + self.backend = "cuda" + def verbose_logging(self): self.log = True self.log_prompt = True @@ -127,27 +130,35 @@ def main(config: EvalConfig): - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + # Use appropriate prompt constructor based on backend + if config.backend == "cuda": + custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + elif config.backend == "triton": + custom_prompt = prompt_generate_custom_triton_from_prompt_template(ref_arch_src) + else: + raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda' or 'triton'.") + if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) - # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + custom_kernel = inference_server(custom_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) + # check LLM is able to generate custom kernel code + assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed" # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation kernel_exec_result = eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, custom_kernel, verbose=config.verbose, measure_performance=True, + num_correct_trials=5, num_perf_trials=100, backend=config.backend ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index e4a31233..d235b460 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -10,6 +10,7 @@ #from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor_triton import prompt_generate_custom_triton_from_prompt_template from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -63,6 +64,8 @@ def __init__(self): self.log_generated_kernel = False self.log_eval_result = False + self.backend = "cuda" + def verbose_logging(self): self.log = True self.log_prompt = True @@ -106,15 +109,17 @@ def __repr__(self): class EvalFunc: @modal.method() - def eval_single_sample_modal(self, ref_arch_src, custom_cuda, verbose, gpu_arch): + def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend): # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation from src.eval import eval_kernel_against_ref - from src.utils import set_gpu_arch - set_gpu_arch(gpu_arch) + # Use utility function to set the GPU architecture in the modal environment + from src.utils import set_gpu_arch as modal_set_gpu_arch + modal_set_gpu_arch(gpu_arch) return eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, + num_correct_trials=5, num_perf_trials=100, backend=backend ) @pydra.main(base=EvalConfig) @@ -174,24 +179,33 @@ def main(config: EvalConfig): - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + # Use appropriate prompt constructor based on backend + if config.backend == "cuda": + custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + elif config.backend == "triton": + custom_prompt = prompt_generate_custom_triton_from_prompt_template(ref_arch_src) + else: + raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda' or 'triton'.") + if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) - # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + custom_kernel = inference_server(custom_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) + # check LLM is able to generate custom kernel code + assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed" # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) with app.run(): - kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_cuda, config.verbose, gpu_arch_mapping[config.gpu]) + kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote( + ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend + ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 0d552b8b..51f1e892 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -62,6 +62,8 @@ def __init__(self): self.log_prompt = False + self.backend = "cuda" + def greedy(self): # For greedy decoding, epsecially baseline eval self.greedy_sample = True diff --git a/src/eval.py b/src/eval.py index 4532154e..ba9bff30 100644 --- a/src/eval.py +++ b/src/eval.py @@ -112,6 +112,42 @@ def load_original_model_and_inputs( Model = context.get("Model") return (Model, get_init_inputs_fn, get_inputs_fn) +def load_custom_model_with_tempfile(code_string, build_directory= None ,entry_point="ModelNew"): + """ + Writes the provided Python code string to a temporary .py file, + dynamically imports the module so we can access the modified model class. + + Returns both a Model class and the temporary file. The temporary file must be + deleted manually be the caller. + + This is a hack that is needed for triton code as compile / exec do not play well + with the @triton.jit decorator. + """ + + if build_directory: + model_custom_src = ( + "import os\n" f"os.environ['TORCH_EXTENSIONS_DIR'] = '{build_directory}'\n" + ) + model_custom_src + + # Create a temporary named file with a .py extension + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file: + # Write the code string into the file + tmp_file.write(code_string) + # Capture the path to the file + tempfile_path = tmp_file.name + temp_file = tmp_file + + # Create a module specification pointing to our temp file + spec = importlib.util.spec_from_file_location("temp_module", tempfile_path) + # Create a new module based on that spec + temp_module = importlib.util.module_from_spec(spec) + # Execute the code in the module's namespace + spec.loader.exec_module(temp_module) + + ModelNew = getattr(temp_module, entry_point) + + # Return the object (class, function, etc.) that was defined in the code + return ModelNew, temp_file def load_custom_model( model_custom_src: str, context: dict, build_directory: str = None @@ -151,7 +187,7 @@ def _cleanup_cuda_extensions(): shutil.rmtree(torch_extensions_path) -def graceful_eval_cleanup(curr_context: dict, device: torch.device): +def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile: tempfile.NamedTemporaryFile = None): """ Clean up env, gpu cache, and compiled CUDA extensions after evaluation """ # delete ran-specific function definitions before next eval run @@ -166,6 +202,9 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device): torch.cuda.synchronize( device=device ) # Wait for all CUDA operations to complete + if tempfile: + tempfile.close() + os.remove(tempfile.name) # _cleanup_cuda_extensions() # SIMON NOTE: is this necessary? @@ -301,6 +340,7 @@ def eval_kernel_against_ref( measure_performance: bool = False, build_dir: os.PathLike = None, device: torch.device = torch.cuda.current_device() if torch.cuda.is_available() else None, # have to run on GPU + backend: str = "cuda", # can be 'cuda' or 'triton' ) -> KernelExecResult: """ Evaluate the custom kernel against the original model @@ -308,6 +348,7 @@ def eval_kernel_against_ref( num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass num_perf_trials: run the evalutation many times to take the average device: GPU (cuda) device to run the evalutation on + backend: str, either 'cuda' or 'triton', determines which backend implementation to use """ # TODO: check device is busy assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval" @@ -320,7 +361,11 @@ def eval_kernel_against_ref( # set CUDA device torch.cuda.set_device(device) - + is_triton = backend == "triton" + if is_triton: + # need to set env var for triton code to guarentee no wrong device shennanignas + assert device.type == "cuda", "CUDA is not availible on device, cannot run Eval" + os.environ["CUDA_VISIBLE_DEVICES"] = str(device.index) context = {} if verbose: @@ -352,8 +397,12 @@ def eval_kernel_against_ref( # this is where compilation happens try: os.environ["TORCH_USE_CUDA_DSA"] = "1" # compile with device side assertion + tempfile = None # add hash for later to distinguish between multi-turn kernels - ModelNew = load_custom_model(custom_model_src, context, build_dir) + if is_triton: + ModelNew, tempfile = load_custom_model_with_tempfile(custom_model_src, build_dir) + else: + ModelNew = load_custom_model(custom_model_src, context, build_dir) torch.cuda.synchronize(device=device) # not sure if this is too much except Exception as e: print( @@ -367,11 +416,11 @@ def eval_kernel_against_ref( print( f"[Eval] Lock file error during compilation, Please retry. Error: {e}" ) - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return None else: metadata["compilation_error"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return KernelExecResult( compiled=False, metadata=metadata ) # skip further steps @@ -390,7 +439,7 @@ def eval_kernel_against_ref( f"Failed to load custom CUDA kernel; Compiled but not able to run, count as runtime error. \nError: {e}" ) # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) metadata["runtime_error"] = e return KernelExecResult( compiled=True, correctness=False, metadata=metadata @@ -411,6 +460,7 @@ def eval_kernel_against_ref( verbose=verbose, seed=seed_num, device=device, + truncate_errors=not is_triton, ) except Exception as e: # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... @@ -454,7 +504,7 @@ def eval_kernel_against_ref( print(f"[Eval] Error in Measuring Performance: {e}") kernel_exec_result.metadata["error_during_performance"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return kernel_exec_result @@ -550,6 +600,8 @@ def run_and_check_correctness( verbose=False, seed=42, device=None, + truncate_errors: bool =True, + ) -> KernelExecResult: """ run the model and check correctness, @@ -629,7 +681,7 @@ def run_and_check_correctness( print(f"Error in launching kernel for ModelNew: {e}") metadata = register_and_format_exception( - "runtime_error", e, metadata, truncate=True + "runtime_error", e, metadata, truncate=truncate_errors ) return KernelExecResult( compiled=True, correctness=False, metadata=metadata diff --git a/src/prompt_constructor_triton.py b/src/prompt_constructor_triton.py new file mode 100644 index 00000000..807cceb8 --- /dev/null +++ b/src/prompt_constructor_triton.py @@ -0,0 +1,489 @@ +import os +from .utils import read_file + + +""" +Construct Prompt + +Design principles: +- To evaluate base model performance on KernelBench, we use the simplest prompt possible to guide model output to generated desired output format. +- However, we do not do extensive prompt engineering or few-shot example in the LLM to steer behaviour. +""" + +REPO_TOP_PATH = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + "..", + ) +) +KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") + + +def get_arch_definition_from_file(arch_path): + arch_src = read_file(arch_path) + return get_arch_definition(arch_src) + + +def get_arch_definition(arch_src): + """ + Construct torch definition from original torch nn.Module definition + """ + prompt = f"Here is a pytorch defintion of a neural network architecture in the file model.py: ```{arch_src}```\n" + return prompt + + +############################################ +# Triton Prompt +############################################ + +PROBLEM_STATEMENT = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n + You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +PROBLEM_INSTRUCTION = """ +Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + +def prompt_generate_custom_triton( + arc_src: str, example_arch_src: str, example_new_arch_src: str +) -> str: + prompt = PROBLEM_STATEMENT + + assert "@triton.jit" in example_new_arch_src, "Example new arch must contain Triton kernel" + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom Triton kernels looks like this: \n + ``` + {example_new_arch_src} + ``` \n + """ + + prompt += f""" + You are given the following architecture: \n + ``` + {arc_src} + ``` + """ + prompt += PROBLEM_INSTRUCTION + return prompt + + +PROBLEM_STATEMENT_CLEANED = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" +PROBLEM_INSTRUCTION_CLEANED = """ +Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + +def prompt_generate_custom_triton_fewshot_and_template(ref_arch_src: str, shots: list) -> str: + """ + Generate a prompt with specified few-shot examples following a template + + shots: list of few-shot examples to include in the prompt + Avaliable few shot options to start with: + - ex_add: pointwise addition + - ex_fuse_gelu: fused gelu + - ex_mnist2: fused convolutions and relus (DEPRECATED) + - ex_tiled_matmul: tiled matrix multiplication + - ex_flash_attn: simple flash attention + """ + prompt = PROBLEM_STATEMENT_CLEANED + + # k = 1 + example_add = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_add.py") + ) + example_add_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_add.py") + ) + example_add_desc = "This given architecture is for a pointwise addition: " + + # k = 2 + example_fuse_gelu = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") + ) + example_fuse_gelu_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") + ) + example_fuse_gelu_desc = "This given architecture is for a fused gelu: " + + # k = 3 (DEPRECATED) + example_mnist2 = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") + ) + example_mnist2_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") + ) + exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " + + # k = 4 + example_tiled_matmul = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") + ) + example_tiled_matmul_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") + ) + example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " + + # k = 5 + example_flash_attn = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_flash_attn.py") + ) + example_flash_attn_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_flash_attn.py") + ) + example_flash_attn_desc = "This given architecture is for a model with simple io-aware implementation of attention, also known as flash attention: " + + examples = [] + for s in shots: + if s not in ["ex_add", "ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul", "ex_flash_attn"]: + raise ValueError(f"Invalid shot: {s}") + elif s == "ex_add": + examples.append((example_add, example_add_new, example_add_desc)) + elif s == "ex_fuse_gelu": + examples.append((example_fuse_gelu, example_fuse_gelu_new, example_fuse_gelu_desc)) + elif s == "ex_mnist2": # DEPRECATED + raise ValueError("ex_mnist2 is deprecated") + examples.append((example_mnist2, example_mnist2_new, exmaple_mnist2_desc)) + elif s == "ex_tiled_matmul": + examples.append((example_tiled_matmul, example_tiled_matmul_new, example_tiled_matmul_desc)) + elif s == "ex_flash_attn": + examples.append((example_flash_attn, example_flash_attn_new, example_flash_attn_desc)) + + + for i, tup in enumerate(examples): + base, kernel, desc = tup + + prompt += f""" +Example {i+1}:\n\n +Here is an example architecture:\n\n +``` +{base} +```\n +{PROBLEM_INSTRUCTION_CLEANED} \n +Here is an optimized verison with custom CUDA kernels: \n +``` +{kernel} +```\n\n +""" + +# should we put task here? + prompt += f""" +Task:\n\n +Here is an example architecture:\n\n +``` +{ref_arch_src} +```\n +""" + prompt += PROBLEM_INSTRUCTION_CLEANED + return prompt + +def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) -> str: + """ + Generate a prompt with a CoT example following a template + Avaliable CoT examples: + - ex_fuse_gelu: fused gelu + - ex_mnist2: fused convolutions and relus + - ex_tiled_matmul: tiled matrix multiplication + """ + + # I updated this to allow CoT. Also explicilty state think step by step. + PROBLEM_INSTRUCTION_COT = """ +Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Do not output testing code. +In the end, make sure the final code block contains code for output architecture ModelNew with cuda code.\n +Let's think step by step.\n +""" + + prompt = PROBLEM_STATEMENT_CLEANED + + assert cot_example in ["ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul"] + + # k = 2 + example_fuse_gelu = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") + ) + example_fuse_gelu_cot = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_fuse_gelu.py") + ) + example_fuse_gelu_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") + ) + example_fuse_gelu_desc = "This given architecture is for a fused gelu: " + + # k = 3 + example_mnist2 = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") + ) + example_mnist2_cot = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_mnist2.py") + ) + example_mnist2_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") + ) + exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " + + # k = 4 + example_tiled_matmul = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") + ) + example_tiled_matmul_cot = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_tiled_matmul.py") + ) + example_tiled_matmul_new = read_file( + os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") + ) + example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " + + match cot_example: + case "ex_fuse_gelu": + base = example_fuse_gelu + cot = example_fuse_gelu_cot + kernel = example_fuse_gelu_new + desc = example_fuse_gelu_desc + case "ex_mnist2": + base = example_mnist2 + cot = example_mnist2_cot + kernel = example_mnist2_new + desc = exmaple_mnist2_desc + case "ex_tiled_matmul": + base = example_tiled_matmul + cot = example_tiled_matmul_cot + kernel = example_tiled_matmul_new + desc = example_tiled_matmul_desc + case _: + raise ValueError(f"Invalid CoT example: {cot_example} not found in CoT examples") + + # construct example with + # NOTE: we only do one example with CoT for now + # 1. ref_src problem -> 2. Instruction -> 3. CoT -> 4. Solution + prompt += f""" +Here is an example architecture:\n\n +``` +{base} +```\n +{PROBLEM_INSTRUCTION_COT} \n +{cot} \n +``` +{kernel} +```\n\n +""" + +# show task to solve + prompt += f""" +Task:\n\n +Here is an example architecture:\n\n +``` +{ref_arch_src} +```\n +""" + prompt += PROBLEM_INSTRUCTION_COT + + return prompt + + +def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: + """ + Using prompt example (an element-wise addition) for prompt templates + The most basic form of example just to show LLM the task and the expected output format + """ + arch = ref_arch_src + # These are strictly defined for now + + # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) + example_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_ex_add.py" + ) + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" + ) + + if not os.path.exists(example_arch_path): + raise FileNotFoundError( + f"Example architecture file not found: {example_arch_path}" + ) + if not os.path.exists(example_new_arch_path): + raise FileNotFoundError( + f"Example new architecture file not found: {example_new_arch_path}" + ) + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + + return prompt_generate_custom_triton(arch, example_arch, example_new_arch) + + +def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, gpu_name: str) -> str: + """ + Similar to prompt_generate_custom_triton_from_prompt_template, + but with hardware information for the given GPU + """ + + arch = ref_arch_src + # These are strictly defined for now + + # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) + example_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_ex_add.py" + ) + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" + ) + + gpu_spec_file_path = os.path.join(REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py") + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + gpu_spec_info = read_file(gpu_spec_file_path) + + return prompt_generate_prompt_with_hardware_info( + ref_arch_src=arch, + gpu_name=gpu_name, + example_arch_src=example_arch, + example_new_arch_src=example_new_arch, + gpu_spec_info_src=gpu_spec_info + ) + + + +def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, + gpu_name: str, + example_arch_src: str, + example_new_arch_src: str, + gpu_spec_info_src: str) -> str: + """ + Generate a prompt with hardware information for the given GPU + gpu_spec_info_src: str of the gpu spec src file + """ + + # Create a dictionary to store the local namespace + local_dict = {} + + # Execute the GPU spec file in the local namespace + exec(gpu_spec_info_src, {}, local_dict) + + # Get the required variables from the local namespace + GPU_SPEC_INFO = local_dict.get('GPU_SPEC_INFO') + GPU_DEFINITIONS = local_dict.get('GPU_DEFINITIONS') + GPU_BEST_PRACTICES = local_dict.get('GPU_BEST_PRACTICES') + + if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: + raise ValueError("GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src") + + assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" + + prompt = PROBLEM_STATEMENT + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom CUDA operators in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom CUDA kernels looks like this: + ``` + {example_new_arch_src} + ``` \n + """ + + curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] + + gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") + prompt += f""" + Here is some information about the underlying hardware that you should keep in mind. \n\n +The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" + + for key, value in curr_gpu_spec_info.items(): + if key == "GPU Architecture": + continue + prompt += f"""- We have {value} of {key}.\n""" + + + prompt += f"""\n\n +Here are some concepts about the GPU architecture that could be helpful: \n\n""" + for key, value in GPU_DEFINITIONS.items(): + prompt += f"""- {key}: {value}\n""" + + prompt += f"""\n\n +Here are some best practices for writing CUDA kernels on GPU: \n\n""" + for best_practice in GPU_BEST_PRACTICES: + prompt += f"""- {best_practice}\n""" + + + prompt += f""" + You are given the following architecture: \n + ``` + {ref_arch_src} + ``` + """ + + + prompt += PROBLEM_INSTRUCTION + return prompt + + + return None + + + + + +def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): + prompt = PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed to compile: + ``` + {custom_cuda} + ``` + Here's the metadata of the compilation error: + ``` + {metadata} + ``` + + Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +def prompt_fix_correctness(ref_arch_src, custom_cuda, metadata): + prompt = PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed correctness: + ``` + {custom_cuda} + ``` + Here's the metadata of the correctness error: + ``` + {metadata} + ``` + Please consider how your custom CUDA kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + +def main(): + gpu_name = "L40S" + + + ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) + assert len(ref_arch_src) > 0, "ref_arch_src is empty" + prompt = prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src, gpu_name) + print(prompt) + # Write prompt to temp file + temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", "prompt_draft.txt") + os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) + with open(temp_file_path, "w") as f: + f.write(prompt) + +if __name__ == "__main__": + main() diff --git a/src/prompts/few_shot/triton/model_new_ex_add_triton.py b/src/prompts/few_shot/triton/model_new_ex_add_triton.py new file mode 100644 index 00000000..43a3f712 --- /dev/null +++ b/src/prompts/few_shot/triton/model_new_ex_add_triton.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # Pointer to first input + y_ptr, # Pointer to second input + out_ptr, # Pointer to output + n_elements, # Total number of elements in input/output + BLOCK_SIZE: tl.constexpr, +): + # Each program handles a contiguous block of data of size BLOCK_SIZE + block_start = tl.program_id(0) * BLOCK_SIZE + # Create a range of offsets [0..BLOCK_SIZE-1] + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Mask to ensure we don't go out of bounds + mask = offsets < n_elements + # Load input values + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + # Perform the elementwise addition + out = x + y + # Store the result + tl.store(out_ptr + offsets, out, mask=mask) + + +def triton_add(x: torch.Tensor, y: torch.Tensor): + """ + This function wraps the Triton kernel call. It: + 1. Ensures the inputs are contiguous on GPU. + 2. Calculates the grid (blocks) needed. + 3. Launches the Triton kernel. + """ + assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." + x = x.contiguous() + y = y.contiguous() + + # Prepare output tensor + out = torch.empty_like(x) + + # Number of elements in the tensor + n_elements = x.numel() + BLOCK_SIZE = 128 # Tunable parameter for block size + + # Determine the number of blocks needed + grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) + + # Launch the Triton kernel + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + # Instead of "return a + b", call our Triton-based addition + return triton_add(a, b) \ No newline at end of file diff --git a/src/prompts/model_new_ex_add_triton.py b/src/prompts/model_new_ex_add_triton.py new file mode 100644 index 00000000..43a3f712 --- /dev/null +++ b/src/prompts/model_new_ex_add_triton.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # Pointer to first input + y_ptr, # Pointer to second input + out_ptr, # Pointer to output + n_elements, # Total number of elements in input/output + BLOCK_SIZE: tl.constexpr, +): + # Each program handles a contiguous block of data of size BLOCK_SIZE + block_start = tl.program_id(0) * BLOCK_SIZE + # Create a range of offsets [0..BLOCK_SIZE-1] + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Mask to ensure we don't go out of bounds + mask = offsets < n_elements + # Load input values + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + # Perform the elementwise addition + out = x + y + # Store the result + tl.store(out_ptr + offsets, out, mask=mask) + + +def triton_add(x: torch.Tensor, y: torch.Tensor): + """ + This function wraps the Triton kernel call. It: + 1. Ensures the inputs are contiguous on GPU. + 2. Calculates the grid (blocks) needed. + 3. Launches the Triton kernel. + """ + assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." + x = x.contiguous() + y = y.contiguous() + + # Prepare output tensor + out = torch.empty_like(x) + + # Number of elements in the tensor + n_elements = x.numel() + BLOCK_SIZE = 128 # Tunable parameter for block size + + # Determine the number of blocks needed + grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) + + # Launch the Triton kernel + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + # Instead of "return a + b", call our Triton-based addition + return triton_add(a, b) \ No newline at end of file From 8bfdd210d85054ed7afca2cb6d9f82c3f6ac1988 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Mar 2025 11:30:12 -0700 Subject: [PATCH 2/7] fix eval bugs --- src/eval.py | 90 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/src/eval.py b/src/eval.py index ba9bff30..8902c488 100644 --- a/src/eval.py +++ b/src/eval.py @@ -2,17 +2,21 @@ Helpers for Evaluations """ +import importlib +import json +import os, subprocess +import random +import sys +import tempfile +from contextlib import redirect_stderr, redirect_stdout +from io import StringIO +from typing import Union + +import numpy as np import requests import torch import torch.nn as nn -import os, subprocess from pydantic import BaseModel -import numpy as np -import random -import json -from contextlib import redirect_stdout, redirect_stderr -from io import StringIO -import sys from . import utils @@ -112,7 +116,10 @@ def load_original_model_and_inputs( Model = context.get("Model") return (Model, get_init_inputs_fn, get_inputs_fn) -def load_custom_model_with_tempfile(code_string, build_directory= None ,entry_point="ModelNew"): + +def load_custom_model_with_tempfile( + code_string, build_directory=None, entry_point="ModelNew" +): """ Writes the provided Python code string to a temporary .py file, dynamically imports the module so we can access the modified model class. @@ -149,6 +156,7 @@ def load_custom_model_with_tempfile(code_string, build_directory= None ,entry_po # Return the object (class, function, etc.) that was defined in the code return ModelNew, temp_file + def load_custom_model( model_custom_src: str, context: dict, build_directory: str = None ) -> nn.Module: @@ -187,7 +195,11 @@ def _cleanup_cuda_extensions(): shutil.rmtree(torch_extensions_path) -def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile: tempfile.NamedTemporaryFile = None): +def graceful_eval_cleanup( + curr_context: dict, + device: torch.device, + tempfile: tempfile.NamedTemporaryFile = None, +): """ Clean up env, gpu cache, and compiled CUDA extensions after evaluation """ # delete ran-specific function definitions before next eval run @@ -208,6 +220,7 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile: te # _cleanup_cuda_extensions() # SIMON NOTE: is this necessary? + def build_compile_cache_legacy( custom_model_src: str, verbose: bool = False, @@ -241,11 +254,12 @@ def build_compile_cache_legacy( if verbose: print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}") except Exception as e: - print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}") + print( + f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}" + ) return False, stdout_buffer.getvalue(), str(e) - - return True, stdout_buffer.getvalue(), None + return True, stdout_buffer.getvalue(), None def build_compile_cache( @@ -281,16 +295,16 @@ def build_compile_cache( if verbose: print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}") except Exception as e: - print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}") + print( + f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}" + ) return False, stdout_buffer.getvalue(), str(e) return True, stdout_buffer.getvalue(), None def build_compile_cache_with_capturing( - custom_model_src: str, - verbose: bool = False, - build_dir: os.PathLike = None + custom_model_src: str, verbose: bool = False, build_dir: os.PathLike = None ) -> tuple[int, str, str]: """ Write a temporary python file to compile the custom model on CPU @@ -312,22 +326,21 @@ def build_compile_cache_with_capturing( f.write(custom_model_src) # Execute the temporary Python file and capture output - process = subprocess.Popen(['python', tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen( + ["python", tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout, stderr = process.communicate() returncode = process.returncode # Clean up temporary file os.remove(tmp) - if verbose: print("[CPU Precompile] return code: ", returncode) - print("[CPU Precompile] stdout: \n", stdout.decode('utf-8')) - print("[CPU Precompile] stderr: \n", stderr.decode('utf-8')) - - return returncode, stdout.decode('utf-8'), stderr.decode('utf-8') - + print("[CPU Precompile] stdout: \n", stdout.decode("utf-8")) + print("[CPU Precompile] stderr: \n", stderr.decode("utf-8")) + return returncode, stdout.decode("utf-8"), stderr.decode("utf-8") def eval_kernel_against_ref( @@ -339,8 +352,10 @@ def eval_kernel_against_ref( verbose: bool = False, measure_performance: bool = False, build_dir: os.PathLike = None, - device: torch.device = torch.cuda.current_device() if torch.cuda.is_available() else None, # have to run on GPU - backend: str = "cuda", # can be 'cuda' or 'triton' + device: Union[torch.device, int] = ( + torch.cuda.current_device() if torch.cuda.is_available() else None + ), # have to run on GPU + backend: str = "cuda", # can be 'cuda' or 'triton' ) -> KernelExecResult: """ Evaluate the custom kernel against the original model @@ -363,9 +378,19 @@ def eval_kernel_against_ref( torch.cuda.set_device(device) is_triton = backend == "triton" if is_triton: - # need to set env var for triton code to guarentee no wrong device shennanignas - assert device.type == "cuda", "CUDA is not availible on device, cannot run Eval" - os.environ["CUDA_VISIBLE_DEVICES"] = str(device.index) + # need to set env var for triton code to guarentee no wrong device shennanignas + if isinstance(device, int): + device_num = device + elif isinstance(device, torch.device): + assert ( + device.type == "cuda" + ), "CUDA is not availible on device, cannot run Eval" + device_num = device.index + else: + raise ValueError( + f"device must be an int or torch.device, got {type(device)}" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_num) context = {} if verbose: @@ -400,7 +425,9 @@ def eval_kernel_against_ref( tempfile = None # add hash for later to distinguish between multi-turn kernels if is_triton: - ModelNew, tempfile = load_custom_model_with_tempfile(custom_model_src, build_dir) + ModelNew, tempfile = load_custom_model_with_tempfile( + custom_model_src, build_dir + ) else: ModelNew = load_custom_model(custom_model_src, context, build_dir) torch.cuda.synchronize(device=device) # not sure if this is too much @@ -600,8 +627,7 @@ def run_and_check_correctness( verbose=False, seed=42, device=None, - truncate_errors: bool =True, - + truncate_errors: bool = True, ) -> KernelExecResult: """ run the model and check correctness, @@ -730,11 +756,13 @@ def check_metadata_serializable(metadata: dict): return metadata + def check_metadata_serializable_all_types(metadata: dict): """ Ensure metadata is JSON serializable, if not, convert non-serializable values to strings recursively """ + def convert_to_serializable(obj): if isinstance(obj, dict): return {k: convert_to_serializable(v) for k, v in obj.items()} From 32ff679f2512c7d61e3a887c7d5252cbcfa0fdab Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Mar 2025 15:09:47 -0700 Subject: [PATCH 3/7] fix issues --- scripts/eval_from_generations.py | 248 ++++++++++++------ scripts/generate_and_eval_single_sample.py | 119 ++++++--- scripts/generate_samples.py | 181 ++++++++----- src/eval.py | 26 +- src/prompt_constructor_triton.py | 290 ++++----------------- 5 files changed, 428 insertions(+), 436 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index e007abbe..4f9e560d 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -1,23 +1,29 @@ -from dataclasses import dataclass +import json +import multiprocessing as mp +import os import shutil import time -import pydra -from pydra import REQUIRED, Config +import traceback +from dataclasses import dataclass -import json -from tqdm import tqdm -# Import only what we need -from src import compile +import pydra import torch -import os -import multiprocessing as mp - from datasets import load_dataset +from pydra import Config, REQUIRED + +# Import only what we need +from src import compile from src.dataset import construct_kernelbench_dataset -from src.eval import eval_kernel_against_ref, KernelExecResult, check_metadata_serializable_all_types -from src.utils import set_gpu_arch, read_file +from src.eval import ( + check_metadata_serializable_all_types, + eval_kernel_against_ref, + get_error_name, + KernelExecResult, +) +from src.utils import read_file, set_gpu_arch +from tqdm import tqdm """ Batch Evaluation from Existing Generations @@ -40,9 +46,9 @@ class EvalConfig(Config): def __init__(self): - self.run_name = REQUIRED # name of the run to evaluate + self.run_name = REQUIRED # name of the run to evaluate - self.dataset_src = REQUIRED # either huggingface or local + self.dataset_src = REQUIRED # either huggingface or local # name of dataset name on Hugging Face self.dataset_name = "ScalingIntelligence/KernelBench" @@ -51,7 +57,7 @@ def __init__(self): self.level = REQUIRED # subset of problems to evaluate - self.subset = (None, None) # (start_id, end_id), these are the logical index + self.subset = (None, None) # (start_id, end_id), these are the logical index # Evaluation Mode: local (requires GPU), see modal (cloud GPU) in the modal file self.eval_mode = "local" @@ -69,23 +75,24 @@ def __init__(self): # Eval settings self.num_correct_trials = 5 self.num_perf_trials = 100 - self.timeout = 180 # in seconds + self.timeout = 180 # in seconds self.measure_performance = True - + # Eval Flow setting # To speedup evaluation, you can start building the kernel on CPU on disk as cache self.build_cache = False - self.num_cpu_workers = 20 # number of parallel process to to parallelize the build on CPUs - + self.num_cpu_workers = ( + 96 # number of parallel process to to parallelize the build on CPUs + ) + # Directory to build kernels for evaluation self.kernel_eval_build_dir = os.path.join(REPO_TOP_DIR, "cache") # number of GPUs to do batch evaluation - self.num_gpu_devices = 1 - + self.num_gpu_devices = 8 + # Backend to use for kernel implementation (cuda or triton) self.backend = "cuda" - def __repr__(self): return f"EvalConfig({self.to_dict()})" @@ -98,43 +105,58 @@ class WorkArgs: device: torch.device -def fetch_ref_arch_from_problem_id(dataset, problem_id: int, dataset_src: str) -> str | None: +def fetch_ref_arch_from_problem_id( + dataset, problem_id: int, dataset_src: str +) -> str | None: """ Fetch reference architecture from problem directory Either from Hugging Face or Local Dataset """ if dataset_src == "huggingface": - curr_problem_row = dataset.filter(lambda x: x["problem_id"] == problem_id, num_proc=1, desc=None) + curr_problem_row = dataset.filter( + lambda x: x["problem_id"] == problem_id, num_proc=1, desc=None + ) ref_arch_src = curr_problem_row["code"][0] problem_name = curr_problem_row["name"][0] - + elif dataset_src == "local": - problem_idx_in_dataset = problem_id - 1 # due to dataset list being 0-indexed locally + problem_idx_in_dataset = ( + problem_id - 1 + ) # due to dataset list being 0-indexed locally ref_arch_path = dataset[problem_idx_in_dataset] problem_name = os.path.basename(ref_arch_path) ref_arch_src = read_file(ref_arch_path) # verify - # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") + # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") problem_number = int(problem_name.split("_")[0]) - assert problem_number == problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({problem_id})" - + assert ( + problem_number == problem_id + ), f"Problem number in filename ({problem_number}) does not match config problem_id ({problem_id})" + return ref_arch_src -def fetch_kernel_from_disk(run_dir: str, level: int, problem_id: int, sample_id: int) -> str | None: +def fetch_kernel_from_disk( + run_dir: str, level: int, problem_id: int, sample_id: int +) -> str | None: """ Fetch kernel file from disk (stored in runs/{run_name}) """ - kernel_path = os.path.join(run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py") - + kernel_path = os.path.join( + run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py" + ) + if os.path.exists(kernel_path): return read_file(kernel_path) else: return None -def evaluate_single_sample(work_args: WorkArgs, configs: EvalConfig, dataset, run_dir: str) -> KernelExecResult | None: + +def evaluate_single_sample( + work_args: WorkArgs, configs: EvalConfig, dataset, run_dir: str +) -> KernelExecResult | None: """ Evaluate a single sample on a single GPU """ @@ -144,22 +166,28 @@ def evaluate_single_sample(work_args: WorkArgs, configs: EvalConfig, dataset, ru work_args.device, ) # fetch reference architecture from problem directory - ref_arch_src = fetch_ref_arch_from_problem_id(dataset, problem_id, configs.dataset_src) + ref_arch_src = fetch_ref_arch_from_problem_id( + dataset, problem_id, configs.dataset_src + ) # fetch kernel from disk # Add database support in the future kernel_src = fetch_kernel_from_disk(run_dir, configs.level, problem_id, sample_id) - assert kernel_src is not None, f"Kernel not found for problem {problem_id} sample {sample_id}" + assert ( + kernel_src is not None + ), f"Kernel not found for problem {problem_id} sample {sample_id}" - build_dir = os.path.join(configs.kernel_eval_build_dir, configs.run_name, f"{problem_id}", f"{sample_id}") + build_dir = os.path.join( + configs.kernel_eval_build_dir, configs.run_name, f"{problem_id}", f"{sample_id}" + ) - try: + try: eval_result = eval_kernel_against_ref( original_model_src=ref_arch_src, custom_model_src=kernel_src, measure_performance=configs.measure_performance, - verbose=configs.verbose, + verbose=configs.verbose, num_correct_trials=configs.num_correct_trials, num_perf_trials=configs.num_perf_trials, build_dir=build_dir, @@ -175,6 +203,7 @@ def evaluate_single_sample(work_args: WorkArgs, configs: EvalConfig, dataset, ru # NOTE: count this as compilation failure as it is not runnable code metadata = { "cuda_error": f"CUDA Error: {str(e)}", + "cuda_error_name": get_error_name(e), "hardware": torch.cuda.get_device_name(device=device), "device": str(device), } # log this for debugging as this usually signifies illegal memory access @@ -183,14 +212,18 @@ def evaluate_single_sample(work_args: WorkArgs, configs: EvalConfig, dataset, ru ) return eval_result else: - metadata = {"other_error": f"error: {str(e)}", - "hardware": torch.cuda.get_device_name(device=device), - "device": str(device) - } # for debugging - eval_result = KernelExecResult(compiled=False, correctness=False, - metadata=metadata) + metadata = { + "other_error": f"error: {str(e)}", + "other_error_name": get_error_name(e), + "hardware": torch.cuda.get_device_name(device=device), + "device": str(device), + } # for debugging + eval_result = KernelExecResult( + compiled=False, correctness=False, metadata=metadata + ) return eval_result - + + def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: dict, dataset, run_dir: str): """ Wrapper to handle timeout and keyboard interrupt @@ -203,16 +236,18 @@ def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: dict, dataset, run_di args=(curr_work, configs, dataset, run_dir), ).get(timeout=configs.timeout) except KeyboardInterrupt: - print( - "\n [Terminate] Caught KeyboardInterrupt, terminating workers..." - ) + print("\n [Terminate] Caught KeyboardInterrupt, terminating workers...") pool.terminate() pool.join() raise except mp.TimeoutError as e: - print(f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}") + print( + f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}" + ) - print(f"[Eval Result] Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}: {result}") + print( + f"[Eval Result] Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}: {result}" + ) return result @@ -221,15 +256,20 @@ def remove_cache_dir(cache_dir: str, run_name: str, problem_id, sample_id): Remove the cached folder for sample compilation so it can start a clean build next time useful for time out, failed build, etc. """ - problem_cache_dir = os.path.join(cache_dir, run_name, f"{problem_id}", f"{sample_id}") + problem_cache_dir = os.path.join( + cache_dir, run_name, f"{problem_id}", f"{sample_id}" + ) print(f"cache_dir to remove: {problem_cache_dir}") if os.path.exists(cache_dir): try: shutil.rmtree(cache_dir, ignore_errors=True) - print(f"\n[INFO] Removed cached folder for Problem ID: {problem_id}, Sample ID: {sample_id}") + print( + f"\n[INFO] Removed cached folder for Problem ID: {problem_id}, Sample ID: {sample_id}" + ) except Exception as e: print(f"\n[WARNING] Failed to remove cache directory {cache_dir}: {str(e)}") + def batch_eval( total_work: list[tuple[int, int]], config: EvalConfig, @@ -253,7 +293,9 @@ def batch_eval( print( f"[Curr Batch] {len(curr_work_batch)} tasks over {config.num_gpu_devices} GPUs; [Total Work left] {len(total_work)}" ) - assert len(curr_work_batch) <= batch_size, f"Current batch size {len(curr_work_batch)} is greater than the number of GPUs {batch_size}" + assert ( + len(curr_work_batch) <= batch_size + ), f"Current batch size {len(curr_work_batch)} is greater than the number of GPUs {batch_size}" with mp.Pool(batch_size) as pool: @@ -278,7 +320,7 @@ def batch_eval( async_results.append( pool.apply_async(evaluate_single_sample, work_arg) ) - + # Collect results with a batch timeout results = [] batch_timeout = config.timeout @@ -290,20 +332,31 @@ def batch_eval( remaining_time = max(0, batch_timeout - elapsed_time) result = async_result.get(timeout=remaining_time) results.append((problem_id, sample_id, result)) - + except mp.TimeoutError: print( f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}, Sample ID: {sample_id}" ) results.append((problem_id, sample_id, None)) - - remove_cache_dir(config.kernel_eval_build_dir, config.run_name, problem_id, sample_id) + + remove_cache_dir( + config.kernel_eval_build_dir, + config.run_name, + problem_id, + sample_id, + ) except Exception as e: print( f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}, Sample ID: {sample_id}: {str(e)}" ) + traceback.print_exc() results.append((problem_id, sample_id, None)) - remove_cache_dir(config.kernel_eval_build_dir, config.run_name, problem_id, sample_id) + remove_cache_dir( + config.kernel_eval_build_dir, + config.run_name, + problem_id, + sample_id, + ) end_time = time.time() @@ -318,8 +371,12 @@ def batch_eval( # add all the batch results here to avoid file race condition # add to eval result if valid result if result is not None: - print(f"Adding Eval Result to file for problem {problem_id} sample {sample_id}") - add_to_eval_results_file(problem_id, sample_id, result, eval_file_path) + print( + f"Adding Eval Result to file for problem {problem_id} sample {sample_id}" + ) + add_to_eval_results_file( + problem_id, sample_id, result, eval_file_path + ) print("-" * 128) print( @@ -328,51 +385,62 @@ def batch_eval( pbar.update(len(curr_work_batch)) -def check_if_eval_exists_local(problem_id: int, sample_id: int, eval_file_path: str) -> bool: + +def check_if_eval_exists_local( + problem_id: int, sample_id: int, eval_file_path: str +) -> bool: """ Check if evaluation result already exists in eval results file """ if os.path.exists(eval_file_path): - with open(eval_file_path, 'r') as f: + with open(eval_file_path, "r") as f: eval_results = json.load(f) return str(problem_id) in eval_results return False -def add_to_eval_results_file(problem_id: int, sample_id: int, eval_result: KernelExecResult, eval_file_path: str): + +def add_to_eval_results_file( + problem_id: int, sample_id: int, eval_result: KernelExecResult, eval_file_path: str +): """ Add evaluation result to eval results file TODO: migrate database support """ # Load existing results if file exists if os.path.exists(eval_file_path): - with open(eval_file_path, 'r') as f: + with open(eval_file_path, "r") as f: eval_results = json.load(f) else: eval_results = {} - + # Add new result eval_results[str(problem_id)] = { # assume 1 sample for now, will think about how to do this better for more samples - 'sample_id': sample_id, - 'compiled': eval_result.compiled, - 'correctness': eval_result.correctness, - 'metadata': check_metadata_serializable_all_types(eval_result.metadata), - 'runtime': eval_result.runtime, - 'runtime_stats': eval_result.runtime_stats, + "sample_id": sample_id, + "compiled": eval_result.compiled, + "correctness": eval_result.correctness, + "metadata": check_metadata_serializable_all_types(eval_result.metadata), + "runtime": eval_result.runtime, + "runtime_stats": eval_result.runtime_stats, } - + # Write updated results back to file if not os.path.exists(eval_file_path): os.makedirs(os.path.dirname(eval_file_path), exist_ok=True) - + with open(eval_file_path, "w") as f: - json.dump(eval_results, f) + json.dump(eval_results, f, indent=4) -def single_eval_example(config: EvalConfig, curr_level_dataset: list[str], run_dir: str, eval_file_path ): + +def single_eval_example( + config: EvalConfig, curr_level_dataset: list[str], run_dir: str, eval_file_path +): device = torch.device("cuda:0") example_work = WorkArgs(problem_id=1, sample_id=0, device=device) # example_eval_result = evaluate_single_sample(example_work, config, curr_level_dataset, run_dir) - example_eval_result = cuda_single_eval_wrapper(example_work, config, curr_level_dataset, run_dir) + example_eval_result = cuda_single_eval_wrapper( + example_work, config, curr_level_dataset, run_dir + ) print(example_eval_result) if not check_if_eval_exists_local(1, 0, eval_file_path): add_to_eval_results_file(1, 0, example_eval_result, eval_file_path) @@ -385,7 +453,7 @@ def main(config: EvalConfig): Store Eval Results in specified eval results file """ print(f"Starting Batch Eval with config: {config}") - + # Check if CUDA is available if not torch.cuda.is_available(): raise RuntimeError("CUDA device not available. Evaluation requires GPU.") @@ -399,35 +467,44 @@ def main(config: EvalConfig): curr_level_dataset = dataset[f"level_{config.level}"] elif config.dataset_src == "local": curr_level_dataset = construct_kernelbench_dataset(config.level) - + num_problems_in_level = len(curr_level_dataset) if config.subset == (None, None): problem_id_range = range(1, num_problems_in_level) else: - assert config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level, f"Subset range {config.subset} out of range for Level {config.level}" + assert ( + config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level + ), f"Subset range {config.subset} out of range for Level {config.level}" problem_id_range = range(config.subset[0], config.subset[1]) - print(f"Evaluating 1 sample each for level {config.level} problems: {problem_id_range}") + print( + f"Evaluating 1 sample each for level {config.level} problems: {problem_id_range}" + ) run_dir = os.path.join(config.runs_dir, config.run_name) eval_file_path = os.path.join(run_dir, "eval_results.json") - # set GPU arch to configure what target to build for set_gpu_arch(config.gpu_arch) - assert config.num_gpu_devices <= torch.cuda.device_count(), f"Number of GPUs requested ({config.num_gpu_devices}) is greater than the number of available GPUs ({torch.cuda.device_count()})" + assert ( + config.num_gpu_devices <= torch.cuda.device_count() + ), f"Number of GPUs requested ({config.num_gpu_devices}) is greater than the number of available GPUs ({torch.cuda.device_count()})" # To Debug # single_eval_example(config, curr_level_dataset, run_dir, eval_file_path) total_work = [] - for problem_id in range(problem_id_range.start, problem_id_range.stop + 1): # end index is inclusive - sample_id = 0 # only evaluate 1 sample for now + for problem_id in range( + problem_id_range.start, problem_id_range.stop + 1 + ): # end index is inclusive + sample_id = 0 # only evaluate 1 sample for now if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path): total_work.append((problem_id, sample_id)) - print(f"Start evaluation on {len(total_work)} unevaluated samples in range: {problem_id_range}") + print( + f"Start evaluation on {len(total_work)} unevaluated samples in range: {problem_id_range}" + ) # Build Cache on CPU as that is faster if config.build_cache: compile.batch_compile(total_work, config.to_dict()) @@ -438,4 +515,3 @@ def main(config: EvalConfig): if __name__ == "__main__": main() - \ No newline at end of file diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 15ba4c1a..223cdc39 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -1,16 +1,25 @@ -import pydra -from pydra import REQUIRED, Config +import json import os, sys + +import pydra import torch -import json from datasets import load_dataset +from pydra import Config, REQUIRED from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_triton import prompt_generate_custom_triton_from_prompt_template -from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets +from src.prompt_constructor_triton import ( + prompt_generate_custom_triton_from_prompt_template, +) +from src.utils import ( + create_inference_server_from_presets, + extract_first_code, + query_server, + read_file, + set_gpu_arch, +) """ Generate and evaluate a single sample @@ -21,15 +30,15 @@ torch.set_printoptions(precision=4, threshold=10) + class EvalConfig(Config): def __init__(self): - - self.dataset_src = REQUIRED # either huggingface or local + + self.dataset_src = REQUIRED # either huggingface or local # name of dataset name on Hugging Face self.dataset_name = "ScalingIntelligence/KernelBench" - # Problem Specification self.level = REQUIRED # NOTE: this is the logical index (problem id the problem_name)\ @@ -89,24 +98,31 @@ def main(config: EvalConfig): if config.log: os.makedirs(config.logdir, exist_ok=True) - + # Problem Checks num_problems = len(curr_level_dataset) print(f"Number of problems in Level {config.level}: {num_problems}") - print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}") - - assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}" + print( + f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}" + ) + assert ( + config.problem_id <= num_problems + ), f"Problem ID {config.problem_id} out of range for Level {config.level}" # 1. Fetch Problem if config.dataset_src == "huggingface": - curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id) + curr_problem_row = curr_level_dataset.filter( + lambda x: x["problem_id"] == config.problem_id + ) ref_arch_src = curr_problem_row["code"][0] problem_name = curr_problem_row["name"][0] elif config.dataset_src == "local": - problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally + problem_idx_in_dataset = ( + config.problem_id - 1 + ) # due to dataset list being 0-indexed locally ref_arch_path = curr_level_dataset[problem_idx_in_dataset] problem_name = os.path.basename(ref_arch_path) @@ -115,20 +131,21 @@ def main(config: EvalConfig): # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") problem_number = int(problem_name.split("_")[0]) - assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" - - + assert ( + problem_number == config.problem_id + ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + # 2. Generate Sample # Create inference function with config parameters # We provide some presets in utils but you can also pass in your own, see query_server for more details - inference_server = create_inference_server_from_presets(server_type=config.server_type, - model_name=config.model_name, - temperature=config.temperature, - max_tokens=config.max_tokens, - verbose=config.verbose, - time_generation=True) - - + inference_server = create_inference_server_from_presets( + server_type=config.server_type, + model_name=config.model_name, + temperature=config.temperature, + max_tokens=config.max_tokens, + verbose=config.verbose, + time_generation=True, + ) # Use appropriate prompt constructor based on backend if config.backend == "cuda": @@ -136,39 +153,67 @@ def main(config: EvalConfig): elif config.backend == "triton": custom_prompt = prompt_generate_custom_triton_from_prompt_template(ref_arch_src) else: - raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda' or 'triton'.") - + raise ValueError( + f"Unsupported backend: {config.backend}. Must be 'cuda' or 'triton'." + ) + if config.log_prompt: - with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: + with open( + os.path.join( + config.logdir, + f"prompt_level_{config.level}_problem_{config.problem_id}.txt", + ), + "w", + ) as f: f.write(custom_prompt) # Query server with constructed prompt custom_kernel = inference_server(custom_prompt) custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom kernel code - assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed" - + assert ( + custom_kernel is not None + ), f"Custom {config.backend} kernel code generation failed" + # this should be optional if config.log: - with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: + with open( + os.path.join( + config.logdir, + f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py", + ), + "w", + ) as f: f.write(custom_kernel) # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation kernel_exec_result = eval_kernel_against_ref( - ref_arch_src, custom_kernel, verbose=config.verbose, measure_performance=True, - num_correct_trials=5, num_perf_trials=100, backend=config.backend + ref_arch_src, + custom_kernel, + verbose=config.verbose, + measure_performance=True, + num_correct_trials=5, + num_perf_trials=100, + backend=config.backend, + ) + + print( + f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}" ) - - print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") if config.log: - with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a") as f: + with open( + os.path.join( + config.logdir, + f"eval_result_level_{config.level}_problem_{config.problem_id}.txt", + ), + "a", + ) as f: f.write(f"Problem Name: {problem_name}\n") f.write(str(kernel_exec_result)) if __name__ == "__main__": main() - diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 51f1e892..9dd48bf5 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -1,17 +1,26 @@ -import pydra -from pydra import REQUIRED, Config -import os, sys -import torch import json +import os, sys from dataclasses import dataclass +import pydra +import torch from datasets import load_dataset +from pydra import Config, REQUIRED from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.utils import extract_first_code, set_gpu_arch, read_file, create_inference_server_from_presets, maybe_multithread +from src.prompt_constructor_triton import ( + prompt_generate_custom_triton_from_prompt_template, +) +from src.utils import ( + create_inference_server_from_presets, + extract_first_code, + maybe_multithread, + read_file, + set_gpu_arch, +) """ Batch Generate Samples for Particular Level @@ -23,21 +32,25 @@ torch.set_printoptions(precision=4, threshold=10) + class GenerationConfig(Config): def __init__(self): - - self.dataset_src = REQUIRED # either huggingface or local + + self.dataset_src = REQUIRED # either huggingface or local # name of dataset name on Hugging Face self.dataset_name = "ScalingIntelligence/KernelBench" # Problem Specification self.level = REQUIRED - + # subset of problems to generate, otherwise generate on all problems in the level - self.subset = (None, None) # (problem_id, problem_name), these are the logical index + self.subset = ( + None, + None, + ) # (problem_id, problem_name), these are the logical index - self.run_name = REQUIRED # name of the run + self.run_name = REQUIRED # name of the run # num of thread pool to call inference server in parallel self.num_workers = 1 @@ -48,13 +61,13 @@ def __init__(self): self.model_name = "deepseek-coder" self.max_tokens = 4096 self.temperature = 0.0 - + # Logging # Top Directory to Store Runs self.runs_dir = os.path.join(REPO_TOP_DIR, "runs") - + self.verbose = False - self.store_type = "local" # TODO: add Database Integration + self.store_type = "local" # TODO: add Database Integration # Future support # Migrate Monkeys code base to KernelBench @@ -70,23 +83,34 @@ def greedy(self): def __repr__(self): return f"EvalConfig({self.to_dict()})" - + @dataclass class WorkArgs: - problem_id: int # logically indexed + problem_id: int # logically indexed sample_id: int -def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, inference_server: callable, run_dir: str) -> bool: + +def generate_sample_single( + work: WorkArgs, + config: GenerationConfig, + dataset, + inference_server: callable, + run_dir: str, +) -> bool: # 1. Fetch Problem if config.dataset_src == "huggingface": - curr_problem_row = dataset.filter(lambda x: x["problem_id"] == work.problem_id, desc=None) + curr_problem_row = dataset.filter( + lambda x: x["problem_id"] == work.problem_id, desc=None + ) ref_arch_src = curr_problem_row["code"][0] problem_name = curr_problem_row["name"][0] elif config.dataset_src == "local": - problem_idx_in_dataset = work.problem_id - 1 # due to dataset list being 0-indexed locally + problem_idx_in_dataset = ( + work.problem_id - 1 + ) # due to dataset list being 0-indexed locally ref_arch_path = dataset[problem_idx_in_dataset] problem_name = os.path.basename(ref_arch_path) @@ -94,14 +118,24 @@ def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, in # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") problem_number = int(problem_name.split("_")[0]) - assert problem_number == work.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" - - - - # Construct Prompt - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + assert ( + problem_number == work.problem_id + ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + + # Construct Prompt + if config.backend == "cuda": + custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template( + ref_arch_src + ) + elif config.backend == "triton": + custom_cuda_prompt = prompt_generate_custom_triton_from_prompt_template( + ref_arch_src + ) if config.log_prompt: - prompt_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt") + prompt_path = os.path.join( + run_dir, + f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt", + ) with open(prompt_path, "w") as f: f.write(custom_cuda_prompt) @@ -112,17 +146,28 @@ def generate_sample_single(work: WorkArgs, config: GenerationConfig, dataset, in assert custom_cuda is not None, "Custom CUDA code generation failed" if config.verbose: - print(f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}") + print( + f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}" + ) # Store to local file - kernel_path = os.path.join(run_dir, f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py") + kernel_path = os.path.join( + run_dir, + f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py", + ) with open(kernel_path, "w") as f: f.write(custom_cuda) - + return True - -def generate_sample_launcher(work: WorkArgs, config: GenerationConfig, dataset, inference_server: callable, run_dir: str): + +def generate_sample_launcher( + work: WorkArgs, + config: GenerationConfig, + dataset, + inference_server: callable, + run_dir: str, +): try: return generate_sample_single(work, config, dataset, inference_server, run_dir) except Exception as e: @@ -130,13 +175,17 @@ def generate_sample_launcher(work: WorkArgs, config: GenerationConfig, dataset, return None -def check_kernel_exists(run_dir: str, level: int, problem_id: int, sample_id: int) -> bool: +def check_kernel_exists( + run_dir: str, level: int, problem_id: int, sample_id: int +) -> bool: """ Check if a kernel for a given problem and sample ID already exists in the run directory """ - kernel_path = os.path.join(run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py") + kernel_path = os.path.join( + run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py" + ) return os.path.exists(kernel_path) - + @pydra.main(base=GenerationConfig) def main(config: GenerationConfig): @@ -153,63 +202,69 @@ def main(config: GenerationConfig): elif config.dataset_src == "local": curr_level_dataset = construct_kernelbench_dataset(config.level) - num_problems_in_level = len(curr_level_dataset) if config.subset == (None, None): problem_id_range = range(1, num_problems_in_level) else: - assert config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level, f"Subset range {config.subset} out of range for Level {config.level}" + assert ( + config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level + ), f"Subset range {config.subset} out of range for Level {config.level}" problem_id_range = range(config.subset[0], config.subset[1]) - print(f"Generating on 1 sample each for level {config.level} problems: {problem_id_range}") + print( + f"Generating on 1 sample each for level {config.level} problems: {problem_id_range}" + ) # set up run directory run_dir = os.path.join(config.runs_dir, config.run_name) os.makedirs(run_dir, exist_ok=True) pydra.save_yaml(config.to_dict(), os.path.join(run_dir, "generation_config.yaml")) - - assert config.store_type == "local", "supporting local file-system based storage for now" # database integreation coming soon, need to migrate from CUDA Monkeys code + assert ( + config.store_type == "local" + ), "supporting local file-system based storage for now" # database integreation coming soon, need to migrate from CUDA Monkeys code problems_to_run = [] - for problem_id in range(problem_id_range.start, problem_id_range.stop + 1): # end index is inclusive + for problem_id in range( + problem_id_range.start, problem_id_range.stop + 1 + ): # end index is inclusive # assume sample id is 0 for now if not check_kernel_exists(run_dir, config.level, problem_id, sample_id=0): problems_to_run.append( - WorkArgs( - problem_id=int(problem_id), - sample_id=0 # fix to 0 for now - ) - ) - + WorkArgs(problem_id=int(problem_id), sample_id=0) # fix to 0 for now + ) # Create inference function with config parameters # We provide some presets in utils but you can also pass in your own, see query_server for more details - inference_server = create_inference_server_from_presets(server_type=config.server_type, - model_name=config.model_name, - temperature=config.temperature, - max_tokens=config.max_tokens, - verbose=config.verbose) + inference_server = create_inference_server_from_presets( + server_type=config.server_type, + model_name=config.model_name, + temperature=config.temperature, + max_tokens=config.max_tokens, + verbose=config.verbose, + ) # Launch workers - generation_results = maybe_multithread(generate_sample_launcher, - problems_to_run, - config.num_workers, - time_interval=config.api_query_interval, - # extra args - config=config, - dataset=curr_level_dataset, - inference_server=inference_server, - run_dir=run_dir - ) - + generation_results = maybe_multithread( + generate_sample_launcher, + problems_to_run, + config.num_workers, + time_interval=config.api_query_interval, + # extra args + config=config, + dataset=curr_level_dataset, + inference_server=inference_server, + run_dir=run_dir, + ) + num_generated_samples = len(generation_results) total_problems = len(problems_to_run) num_failed_problems = total_problems - num_generated_samples - print(f"Generated {num_generated_samples} samples for total {total_problems} problems, Please retry for the {num_failed_problems} failed problems.") + print( + f"Generated {num_generated_samples} samples for total {total_problems} problems, Please retry for the {num_failed_problems} failed problems." + ) if __name__ == "__main__": main() - diff --git a/src/eval.py b/src/eval.py index 8902c488..345495a5 100644 --- a/src/eval.py +++ b/src/eval.py @@ -29,6 +29,11 @@ KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") +def get_error_name(e: Exception) -> str: + + return f"{e.__class__.__module__}.{e.__class__.__name__}" + + def fetch_kernel_from_database( run_name: str, problem_id: int, sample_id: int, server_url: str ): @@ -118,7 +123,7 @@ def load_original_model_and_inputs( def load_custom_model_with_tempfile( - code_string, build_directory=None, entry_point="ModelNew" + model_custom_src, build_directory=None, entry_point="ModelNew" ): """ Writes the provided Python code string to a temporary .py file, @@ -139,7 +144,7 @@ def load_custom_model_with_tempfile( # Create a temporary named file with a .py extension with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file: # Write the code string into the file - tmp_file.write(code_string) + tmp_file.write(model_custom_src) # Capture the path to the file tempfile_path = tmp_file.name temp_file = tmp_file @@ -377,6 +382,10 @@ def eval_kernel_against_ref( # set CUDA device torch.cuda.set_device(device) is_triton = backend == "triton" + metadata = {} # for storing result metadata + metadata["hardware"] = torch.cuda.get_device_name(device=device) + metadata["device"] = str(device) # for debugging + if is_triton: # need to set env var for triton code to guarentee no wrong device shennanignas if isinstance(device, int): @@ -415,10 +424,6 @@ def eval_kernel_against_ref( if verbose: print("[Eval] Loading and Compiling New Model with Custom CUDA Kernel") - metadata = {} # for storing result metadata - metadata["hardware"] = torch.cuda.get_device_name(device=device) - metadata["device"] = str(device) # for debugging - # this is where compilation happens try: os.environ["TORCH_USE_CUDA_DSA"] = "1" # compile with device side assertion @@ -446,6 +451,7 @@ def eval_kernel_against_ref( graceful_eval_cleanup(context, device, tempfile) return None else: + metadata["compilation_error_name"] = get_error_name(e) metadata["compilation_error"] = e graceful_eval_cleanup(context, device, tempfile) return KernelExecResult( @@ -468,6 +474,7 @@ def eval_kernel_against_ref( # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... graceful_eval_cleanup(context, device, tempfile) metadata["runtime_error"] = e + metadata["runtime_error_name"] = get_error_name(e) return KernelExecResult( compiled=True, correctness=False, metadata=metadata ) # skip further steps @@ -487,11 +494,11 @@ def eval_kernel_against_ref( verbose=verbose, seed=seed_num, device=device, - truncate_errors=not is_triton, ) except Exception as e: # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... metadata["runtime_error"] = e + metadata["runtime_error_name"] = get_error_name(e) kernel_exec_result = KernelExecResult( compiled=True, correctness=False, metadata=metadata ) @@ -627,7 +634,6 @@ def run_and_check_correctness( verbose=False, seed=42, device=None, - truncate_errors: bool = True, ) -> KernelExecResult: """ run the model and check correctness, @@ -678,6 +684,7 @@ def run_and_check_correctness( f"Output shape mismatch: Expected {output.shape}, got {output_new.shape}", metadata, ) + metadata["correctness_issue_name"] = "correctness_issue" if verbose: print( f"[FAIL] trial {trial}: Output shape mismatch: Expected {output.shape}, got {output_new.shape}" @@ -707,8 +714,9 @@ def run_and_check_correctness( print(f"Error in launching kernel for ModelNew: {e}") metadata = register_and_format_exception( - "runtime_error", e, metadata, truncate=truncate_errors + "runtime_error", e, metadata, truncate=True ) + metadata["runtime_error_name"] = get_error_name(e) return KernelExecResult( compiled=True, correctness=False, metadata=metadata ) diff --git a/src/prompt_constructor_triton.py b/src/prompt_constructor_triton.py index 807cceb8..e3ea945b 100644 --- a/src/prompt_constructor_triton.py +++ b/src/prompt_constructor_triton.py @@ -1,4 +1,5 @@ import os + from .utils import read_file @@ -44,12 +45,15 @@ def get_arch_definition(arch_src): Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n """ + def prompt_generate_custom_triton( arc_src: str, example_arch_src: str, example_new_arch_src: str ) -> str: prompt = PROBLEM_STATEMENT - assert "@triton.jit" in example_new_arch_src, "Example new arch must contain Triton kernel" + assert ( + "@triton.jit" in example_new_arch_src + ), "Example new arch must contain Triton kernel" if example_arch_src != "" and example_new_arch_src != "": prompt += f""" @@ -79,210 +83,15 @@ def prompt_generate_custom_triton( Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n """ -def prompt_generate_custom_triton_fewshot_and_template(ref_arch_src: str, shots: list) -> str: - """ - Generate a prompt with specified few-shot examples following a template - - shots: list of few-shot examples to include in the prompt - Avaliable few shot options to start with: - - ex_add: pointwise addition - - ex_fuse_gelu: fused gelu - - ex_mnist2: fused convolutions and relus (DEPRECATED) - - ex_tiled_matmul: tiled matrix multiplication - - ex_flash_attn: simple flash attention - """ - prompt = PROBLEM_STATEMENT_CLEANED - - # k = 1 - example_add = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_add.py") - ) - example_add_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_add.py") - ) - example_add_desc = "This given architecture is for a pointwise addition: " - - # k = 2 - example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") - ) - example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") - ) - example_fuse_gelu_desc = "This given architecture is for a fused gelu: " - - # k = 3 (DEPRECATED) - example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") - ) - example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") - ) - exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " - - # k = 4 - example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") - ) - example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") - ) - example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " - - # k = 5 - example_flash_attn = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_flash_attn.py") - ) - example_flash_attn_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_flash_attn.py") - ) - example_flash_attn_desc = "This given architecture is for a model with simple io-aware implementation of attention, also known as flash attention: " - - examples = [] - for s in shots: - if s not in ["ex_add", "ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul", "ex_flash_attn"]: - raise ValueError(f"Invalid shot: {s}") - elif s == "ex_add": - examples.append((example_add, example_add_new, example_add_desc)) - elif s == "ex_fuse_gelu": - examples.append((example_fuse_gelu, example_fuse_gelu_new, example_fuse_gelu_desc)) - elif s == "ex_mnist2": # DEPRECATED - raise ValueError("ex_mnist2 is deprecated") - examples.append((example_mnist2, example_mnist2_new, exmaple_mnist2_desc)) - elif s == "ex_tiled_matmul": - examples.append((example_tiled_matmul, example_tiled_matmul_new, example_tiled_matmul_desc)) - elif s == "ex_flash_attn": - examples.append((example_flash_attn, example_flash_attn_new, example_flash_attn_desc)) - - - for i, tup in enumerate(examples): - base, kernel, desc = tup - prompt += f""" -Example {i+1}:\n\n -Here is an example architecture:\n\n -``` -{base} -```\n -{PROBLEM_INSTRUCTION_CLEANED} \n -Here is an optimized verison with custom CUDA kernels: \n -``` -{kernel} -```\n\n -""" +def prompt_generate_custom_triton_fewshot_and_template( + ref_arch_src: str, shots: list +) -> str: + raise NotImplementedError("This function has not been implemented yet") -# should we put task here? - prompt += f""" -Task:\n\n -Here is an example architecture:\n\n -``` -{ref_arch_src} -```\n -""" - prompt += PROBLEM_INSTRUCTION_CLEANED - return prompt def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) -> str: - """ - Generate a prompt with a CoT example following a template - Avaliable CoT examples: - - ex_fuse_gelu: fused gelu - - ex_mnist2: fused convolutions and relus - - ex_tiled_matmul: tiled matrix multiplication - """ - - # I updated this to allow CoT. Also explicilty state think step by step. - PROBLEM_INSTRUCTION_COT = """ -Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Do not output testing code. -In the end, make sure the final code block contains code for output architecture ModelNew with cuda code.\n -Let's think step by step.\n -""" - - prompt = PROBLEM_STATEMENT_CLEANED - - assert cot_example in ["ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul"] - - # k = 2 - example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") - ) - example_fuse_gelu_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_fuse_gelu.py") - ) - example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") - ) - example_fuse_gelu_desc = "This given architecture is for a fused gelu: " - - # k = 3 - example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") - ) - example_mnist2_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_mnist2.py") - ) - example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") - ) - exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " - - # k = 4 - example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") - ) - example_tiled_matmul_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_tiled_matmul.py") - ) - example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") - ) - example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " - - match cot_example: - case "ex_fuse_gelu": - base = example_fuse_gelu - cot = example_fuse_gelu_cot - kernel = example_fuse_gelu_new - desc = example_fuse_gelu_desc - case "ex_mnist2": - base = example_mnist2 - cot = example_mnist2_cot - kernel = example_mnist2_new - desc = exmaple_mnist2_desc - case "ex_tiled_matmul": - base = example_tiled_matmul - cot = example_tiled_matmul_cot - kernel = example_tiled_matmul_new - desc = example_tiled_matmul_desc - case _: - raise ValueError(f"Invalid CoT example: {cot_example} not found in CoT examples") - - # construct example with - # NOTE: we only do one example with CoT for now - # 1. ref_src problem -> 2. Instruction -> 3. CoT -> 4. Solution - prompt += f""" -Here is an example architecture:\n\n -``` -{base} -```\n -{PROBLEM_INSTRUCTION_COT} \n -{cot} \n -``` -{kernel} -```\n\n -""" - -# show task to solve - prompt += f""" -Task:\n\n -Here is an example architecture:\n\n -``` -{ref_arch_src} -```\n -""" - prompt += PROBLEM_INSTRUCTION_COT - - return prompt + raise NotImplementedError("This function has not been implemented yet") def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: @@ -294,9 +103,7 @@ def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str # These are strictly defined for now # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) - example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" - ) + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") example_new_arch_path = os.path.join( REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" ) @@ -316,9 +123,11 @@ def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str return prompt_generate_custom_triton(arch, example_arch, example_new_arch) -def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, gpu_name: str) -> str: +def prompt_generate_prompt_with_hardware_info_from_template( + ref_arch_src: str, gpu_name: str +) -> str: """ - Similar to prompt_generate_custom_triton_from_prompt_template, + Similar to prompt_generate_custom_triton_from_prompt_template, but with hardware information for the given GPU """ @@ -326,34 +135,35 @@ def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, g # These are strictly defined for now # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) - example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" - ) + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") example_new_arch_path = os.path.join( REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" ) - gpu_spec_file_path = os.path.join(REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py") + gpu_spec_file_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py" + ) example_arch = read_file(example_arch_path) example_new_arch = read_file(example_new_arch_path) gpu_spec_info = read_file(gpu_spec_file_path) return prompt_generate_prompt_with_hardware_info( - ref_arch_src=arch, - gpu_name=gpu_name, - example_arch_src=example_arch, - example_new_arch_src=example_new_arch, - gpu_spec_info_src=gpu_spec_info - ) - + ref_arch_src=arch, + gpu_name=gpu_name, + example_arch_src=example_arch, + example_new_arch_src=example_new_arch, + gpu_spec_info_src=gpu_spec_info, + ) -def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, - gpu_name: str, - example_arch_src: str, - example_new_arch_src: str, - gpu_spec_info_src: str) -> str: +def prompt_generate_prompt_with_hardware_info( + ref_arch_src: str, + gpu_name: str, + example_arch_src: str, + example_new_arch_src: str, + gpu_spec_info_src: str, +) -> str: """ Generate a prompt with hardware information for the given GPU gpu_spec_info_src: str of the gpu spec src file @@ -361,17 +171,19 @@ def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, # Create a dictionary to store the local namespace local_dict = {} - + # Execute the GPU spec file in the local namespace exec(gpu_spec_info_src, {}, local_dict) - + # Get the required variables from the local namespace - GPU_SPEC_INFO = local_dict.get('GPU_SPEC_INFO') - GPU_DEFINITIONS = local_dict.get('GPU_DEFINITIONS') - GPU_BEST_PRACTICES = local_dict.get('GPU_BEST_PRACTICES') - + GPU_SPEC_INFO = local_dict.get("GPU_SPEC_INFO") + GPU_DEFINITIONS = local_dict.get("GPU_DEFINITIONS") + GPU_BEST_PRACTICES = local_dict.get("GPU_BEST_PRACTICES") + if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: - raise ValueError("GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src") + raise ValueError( + "GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src" + ) assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" @@ -388,20 +200,19 @@ def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, {example_new_arch_src} ``` \n """ - + curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") prompt += f""" Here is some information about the underlying hardware that you should keep in mind. \n\n The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" - + for key, value in curr_gpu_spec_info.items(): if key == "GPU Architecture": continue prompt += f"""- We have {value} of {key}.\n""" - - + prompt += f"""\n\n Here are some concepts about the GPU architecture that could be helpful: \n\n""" for key, value in GPU_DEFINITIONS.items(): @@ -412,25 +223,19 @@ def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, for best_practice in GPU_BEST_PRACTICES: prompt += f"""- {best_practice}\n""" - prompt += f""" You are given the following architecture: \n ``` {ref_arch_src} ``` """ - prompt += PROBLEM_INSTRUCTION return prompt - return None - - - def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): prompt = PROBLEM_STATEMENT prompt += f""" @@ -471,13 +276,15 @@ def prompt_fix_correctness(ref_arch_src, custom_cuda, metadata): """ return prompt + def main(): gpu_name = "L40S" - ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) assert len(ref_arch_src) > 0, "ref_arch_src is empty" - prompt = prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src, gpu_name) + prompt = prompt_generate_prompt_with_hardware_info_from_template( + ref_arch_src, gpu_name + ) print(prompt) # Write prompt to temp file temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", "prompt_draft.txt") @@ -485,5 +292,6 @@ def main(): with open(temp_file_path, "w") as f: f.write(prompt) + if __name__ == "__main__": main() From eb4e8aaceb3d6466f5fce8bdb2d5623f9a4280d0 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Mar 2025 15:14:05 -0700 Subject: [PATCH 4/7] revert eval --- scripts/eval_from_generations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 4f9e560d..4898ed7a 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -82,14 +82,14 @@ def __init__(self): # To speedup evaluation, you can start building the kernel on CPU on disk as cache self.build_cache = False self.num_cpu_workers = ( - 96 # number of parallel process to to parallelize the build on CPUs + 20 # number of parallel process to to parallelize the build on CPUs ) # Directory to build kernels for evaluation self.kernel_eval_build_dir = os.path.join(REPO_TOP_DIR, "cache") # number of GPUs to do batch evaluation - self.num_gpu_devices = 8 + self.num_gpu_devices = 1 # Backend to use for kernel implementation (cuda or triton) self.backend = "cuda" From e9bf734c08a4fb36b7f067ad6e7e369d63feecec Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 18 Mar 2025 15:16:44 -0700 Subject: [PATCH 5/7] remove traceback --- scripts/eval_from_generations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 4898ed7a..25bf3e80 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -3,7 +3,6 @@ import os import shutil import time -import traceback from dataclasses import dataclass import pydra @@ -349,7 +348,6 @@ def batch_eval( print( f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}, Sample ID: {sample_id}: {str(e)}" ) - traceback.print_exc() results.append((problem_id, sample_id, None)) remove_cache_dir( config.kernel_eval_build_dir, From 26d4cc0d07d48cdafa862fa3c5eb6a725ed09d39 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 25 Mar 2025 09:36:10 -0700 Subject: [PATCH 6/7] remove cot --- .../triton/model_new_ex_add_triton.py | 63 ------------------- 1 file changed, 63 deletions(-) delete mode 100644 src/prompts/few_shot/triton/model_new_ex_add_triton.py diff --git a/src/prompts/few_shot/triton/model_new_ex_add_triton.py b/src/prompts/few_shot/triton/model_new_ex_add_triton.py deleted file mode 100644 index 43a3f712..00000000 --- a/src/prompts/few_shot/triton/model_new_ex_add_triton.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton -import triton.language as tl - - -@triton.jit -def add_kernel( - x_ptr, # Pointer to first input - y_ptr, # Pointer to second input - out_ptr, # Pointer to output - n_elements, # Total number of elements in input/output - BLOCK_SIZE: tl.constexpr, -): - # Each program handles a contiguous block of data of size BLOCK_SIZE - block_start = tl.program_id(0) * BLOCK_SIZE - # Create a range of offsets [0..BLOCK_SIZE-1] - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Mask to ensure we don't go out of bounds - mask = offsets < n_elements - # Load input values - x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - y = tl.load(y_ptr + offsets, mask=mask, other=0.0) - # Perform the elementwise addition - out = x + y - # Store the result - tl.store(out_ptr + offsets, out, mask=mask) - - -def triton_add(x: torch.Tensor, y: torch.Tensor): - """ - This function wraps the Triton kernel call. It: - 1. Ensures the inputs are contiguous on GPU. - 2. Calculates the grid (blocks) needed. - 3. Launches the Triton kernel. - """ - assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." - x = x.contiguous() - y = y.contiguous() - - # Prepare output tensor - out = torch.empty_like(x) - - # Number of elements in the tensor - n_elements = x.numel() - BLOCK_SIZE = 128 # Tunable parameter for block size - - # Determine the number of blocks needed - grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) - - # Launch the Triton kernel - add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) - return out - - -class ModelNew(nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, a, b): - # Instead of "return a + b", call our Triton-based addition - return triton_add(a, b) \ No newline at end of file From e5ba9b3d1d12f5cca88978c3429eefaeb52899df Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 25 Mar 2025 11:06:51 -0700 Subject: [PATCH 7/7] improve eval --- src/eval.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/eval.py b/src/eval.py index 345495a5..370786d4 100644 --- a/src/eval.py +++ b/src/eval.py @@ -122,9 +122,7 @@ def load_original_model_and_inputs( return (Model, get_init_inputs_fn, get_inputs_fn) -def load_custom_model_with_tempfile( - model_custom_src, build_directory=None, entry_point="ModelNew" -): +def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"): """ Writes the provided Python code string to a temporary .py file, dynamically imports the module so we can access the modified model class. @@ -136,11 +134,6 @@ def load_custom_model_with_tempfile( with the @triton.jit decorator. """ - if build_directory: - model_custom_src = ( - "import os\n" f"os.environ['TORCH_EXTENSIONS_DIR'] = '{build_directory}'\n" - ) + model_custom_src - # Create a temporary named file with a .py extension with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file: # Write the code string into the file @@ -431,7 +424,7 @@ def eval_kernel_against_ref( # add hash for later to distinguish between multi-turn kernels if is_triton: ModelNew, tempfile = load_custom_model_with_tempfile( - custom_model_src, build_dir + custom_model_src, entry_point="ModelNew" ) else: ModelNew = load_custom_model(custom_model_src, context, build_dir)