#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original patch header for this file (kept verbatim from the diff):
#   diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py
#   new file mode 100644
#   index 000000000..2f68cfa68
#   --- /dev/null
#   +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py
#   @@ -0,0 +1,337 @@

"""Example script for applying sparse attention to HuggingFace models."""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

RAND_SEED = 1234

# Marker inserted where the middle of an over-long prompt was cut out.
_TRUNCATION_MARKER = " [...] "

# Enable HuggingFace checkpointing support
mto.enable_huggingface_checkpointing()

# You can define custom configurations or use the default
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
}


def get_narrativeqa_samples(num_samples=3):
    """Build prompts from the NarrativeQA test split.

    Each prompt joins a document's text with one question about it. Items that
    lack either field are skipped and do not count towards ``num_samples``.

    Args:
        num_samples: Number of prompts to collect.

    Returns:
        List of prompt strings, at most ``num_samples`` long.

    Raises:
        RuntimeError: If dataset loading fails
        ValueError: If no valid samples could be loaded
    """
    try:
        dataset = load_dataset("narrativeqa", split="test", streaming=True)
    except Exception as e:
        # Chain the cause so the underlying download/parse error stays visible.
        raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}") from e

    samples = []
    for item in dataset:
        # Stop on collected prompts, not on dataset items seen, so skipped
        # items do not shrink the result.
        if len(samples) >= num_samples:
            break

        # Combine document context and question
        context = item.get("document", {}).get("text", "")
        question = item.get("question", {}).get("text", "")
        if not (context and question):
            continue

        # Use the full context as-is; truncation happens later, per tokenizer.
        samples.append(f"Context: {context}\n\nQuestion: {question}\n\nAnswer:")

    if not samples:
        raise ValueError("Could not load NarrativeQA samples")

    print(f"Loaded {len(samples)} NarrativeQA samples")
    return samples


def truncate_text(text: str, tokenizer, max_length: int):
    """Truncate text from the middle to preserve beginning and end.

    Args:
        text: Input text to truncate
        tokenizer: Tokenizer providing ``encode``/``decode``
        max_length: Maximum number of tokens, special tokens included

    Returns:
        ``text`` unchanged when it already fits. Otherwise the head and tail of
        the text joined by " [...] ", sized so that head + marker + tail +
        special tokens target ``max_length`` tokens. Decoded text may
        re-tokenize to a slightly different count, so callers that need a hard
        limit should still tokenize with ``truncation=True``.
    """
    # First tokenize to see if truncation is needed
    tokens = tokenizer.encode(text, add_special_tokens=True)
    if len(tokens) <= max_length:
        return text

    # Slice content tokens only, so special tokens are budgeted exactly once
    # instead of silently eating into the head slice.
    content_tokens = tokenizer.encode(text, add_special_tokens=False)
    special_token_count = len(tokenizer.encode("", add_special_tokens=True))
    marker_token_count = len(tokenizer.encode(_TRUNCATION_MARKER, add_special_tokens=False))

    content_budget = max_length - special_token_count
    available_tokens = content_budget - marker_token_count

    if available_tokens < 2:
        # No room for head + marker + tail: keep only the head. This also
        # avoids the `tokens[-0:]` trap, which would return the whole sequence.
        head_tokens = content_tokens[: max(content_budget, 0)]
        return tokenizer.decode(head_tokens, skip_special_tokens=True)

    # Split tokens roughly in half for beginning and end
    begin_tokens = available_tokens // 2
    end_tokens = available_tokens - begin_tokens  # >= 1 here, so the negative slice is safe

    begin_text = tokenizer.decode(content_tokens[:begin_tokens], skip_special_tokens=True)
    end_text = tokenizer.decode(content_tokens[-end_tokens:], skip_special_tokens=True)

    return begin_text + _TRUNCATION_MARKER + end_text


def _generate_text(model, inputs, args, tokenizer):
    """Run ``model.generate`` on ``inputs`` and decode only the newly generated tokens."""
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature if args.do_sample else 1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt tokens that generate() echoes back.
    input_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def verify_outputs(model, tokenizer, args):
    """Compare generations with sparse attention disabled and enabled.

    One NarrativeQA prompt is generated twice: first with every sparse attention
    module disabled (baseline), then with the modules back in the enabled/disabled
    state they had on entry. Both generations are printed with a note on whether
    they match.

    Args:
        model: Model that already went through ``sparsify_model``.
        tokenizer: Tokenizer matching ``model``.
        args: Parsed CLI arguments; ``seq_len`` may be overwritten (see below).
    """
    # Update seq_len to match calibration max_seqlen if calibration was used.
    # SKIP_SOFTMAX_DEFAULT carries no top-level "calibration" entry, so this
    # only triggers for custom configs added to SPARSE_ATTN_CFG_CHOICES.
    base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {})
    if "calibration" in base_config and "max_seqlen" in base_config["calibration"]:
        calib_max_seqlen = base_config["calibration"]["max_seqlen"]
        if args.seq_len != calib_max_seqlen:
            print(
                f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} "
                f"to match calibration config"
            )
            args.seq_len = calib_max_seqlen

    # Load and prepare a single test prompt
    print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
    prompt = get_narrativeqa_samples(num_samples=1)[0]

    # Prepare inputs
    truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len)
    if len(truncated_prompt) > 150:
        display_prompt = truncated_prompt[:150] + "..."
    else:
        display_prompt = truncated_prompt

    # truncation=True is the hard backstop in case the decoded text
    # re-tokenizes to more than seq_len tokens.
    inputs = tokenizer(
        truncated_prompt,
        return_tensors="pt",
        max_length=args.seq_len,
        truncation=True,
        padding=False,
    )
    # Follow the model's device instead of assuming the default CUDA device.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    print("\n" + "=" * 60)
    print("BASELINE vs SPARSE ATTENTION COMPARISON")
    print("=" * 60)
    print(f"\nTest prompt: {display_prompt}")
    print(f"Input tokens: {inputs['input_ids'].shape[1]}")

    # Find all sparse attention modules and remember which were enabled, so
    # modules the config left disabled are not switched on as a side effect.
    sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]
    previously_enabled = [m for m in sparse_modules if m.is_enabled]

    # Generate baseline by temporarily disabling sparse attention
    print("\n" + "-" * 60)
    print("Generating baseline (sparse attention disabled)...")
    try:
        for module in sparse_modules:
            module.disable()
        baseline_text = _generate_text(model, inputs, args, tokenizer)
    finally:
        # Restore the configured state even if generation fails.
        for module in previously_enabled:
            module.enable()

    # Generate with sparse attention enabled
    print("\nGenerating with sparse attention (calibrated thresholds)...")
    sparse_text = _generate_text(model, inputs, args, tokenizer)

    # Display comparison
    print("\n" + "-" * 60)
    print("RESULTS:")
    baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text
    sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text

    print(f"\nBaseline: {baseline_display}")
    print(f"With Sparse: {sparse_display}")

    if baseline_text == sparse_text:
        print("\nOutputs are identical")
    else:
        print("\nOutputs differ")


def sparsify_model(model, args):
    """Apply the selected sparse attention configuration to the model.

    Args:
        model: Model to sparsify.
        args: Parsed CLI arguments; reads ``sparse_attn`` and ``backend``.

    Returns:
        The model returned by ``mtsa.sparsify``.
    """
    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

    # Copy each pattern's settings and override the backend, leaving the shared
    # default config untouched.
    modified_sparse_cfg = {
        pattern: {**cfg, "backend": args.backend}
        for pattern, cfg in base_config["sparse_cfg"].items()
    }

    # Create new config with modified settings
    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)

    # Sparsify the model
    model = mtsa.sparsify(model, config=sparse_config)

    print("Sparse attention applied successfully!")

    return model


def main(args):
    """Load the model, apply sparse attention, then optionally verify and export.

    Args:
        args: Parsed CLI arguments (see ``parse_args``).

    Raises:
        OSError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise OSError("GPU is required for inference.")

    # Seed every RNG that generation can touch; torch matters for --do_sample.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)
    launch_memory_monitor()

    print(f"Loading model: {args.pyt_ckpt_path}")

    # Load model and tokenizer
    # Note: attn_implementation="eager" is required for calibration to work properly
    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    # Set pad token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # CUDA availability was already enforced above.
    model = model.cuda()
    print("Model moved to CUDA")

    # Apply sparse attention to the model
    model = sparsify_model(model, args)

    # Verify outputs if requested (compares baseline vs sparse model)
    if args.verify_output:
        verify_outputs(model, tokenizer, args)

    # Export if requested
    if args.export_dir:
        print(f"\nExporting model to: {args.export_dir}")
        export_dir = Path(args.export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)

        with torch.inference_mode():
            export_hf_checkpoint(model, export_dir=export_dir)

        tokenizer.save_pretrained(export_dir)
        print(f"Model exported successfully to: {export_dir}")


def parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description=__doc__)

    # Model arguments
    parser.add_argument(
        "--pyt_ckpt_path",
        type=str,
        required=True,
        help="Specify where the PyTorch checkpoint path is",
    )
    parser.add_argument(
        "--sparse_attn",
        type=str,
        default="skip_softmax",
        choices=list(SPARSE_ATTN_CFG_CHOICES.keys()),
        help="Sparse attention configuration to apply.",
    )
    # Only "pytorch" is offered: the sparse attention config documents it as
    # the single supported backend, so other values would only fail later.
    parser.add_argument(
        "--backend",
        type=str,
        default="pytorch",
        choices=["pytorch"],
        help="Backend to use for sparse attention computation (default: pytorch)",
    )

    # Sequence length arguments
    parser.add_argument(
        "--seq_len",
        type=int,
        default=2048,
        help="Maximum sequence length for input prompts (will be truncated if longer)",
    )
    # NOTE(review): --num_samples is not consumed anywhere in this script;
    # verify_outputs always loads a single sample. Kept for CLI compatibility.
    parser.add_argument(
        "--num_samples",
        type=int,
        default=3,
        help="Number of samples to use from NarrativeQA dataset",
    )

    # Generation arguments
    parser.add_argument(
        "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate"
    )
    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")

    # Operation arguments
    parser.add_argument(
        "--verify_output",
        action="store_true",
        help="Verify that sparse attention outputs match baseline",
    )
    parser.add_argument(
        "--export_dir",
        type=str,
        default=None,
        help="Directory to export the model with sparse attention applied",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())

# --- Patch metadata that followed this file in the original diff (kept verbatim) ---
# diff --git a/examples/llm_sparsity/.gitignore b/examples/llm_sparsity/weight_sparsity/.gitignore
# similarity index 100%
# rename from examples/llm_sparsity/.gitignore
# rename to examples/llm_sparsity/weight_sparsity/.gitignore
# diff --git a/examples/llm_sparsity/README.md b/examples/llm_sparsity/weight_sparsity/README.md
# similarity index 100%
# rename from examples/llm_sparsity/README.md
# rename to examples/llm_sparsity/weight_sparsity/README.md
# diff --git a/examples/llm_sparsity/data_prep.py b/examples/llm_sparsity/weight_sparsity/data_prep.py
# similarity index 100%
# rename from examples/llm_sparsity/data_prep.py
# rename to examples/llm_sparsity/weight_sparsity/data_prep.py
# diff --git a/examples/llm_sparsity/eval.py b/examples/llm_sparsity/weight_sparsity/eval.py
# similarity index 100%
# rename from examples/llm_sparsity/eval.py
# rename to examples/llm_sparsity/weight_sparsity/eval.py
# diff --git a/examples/llm_sparsity/export_trtllm_ckpt.py b/examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py
# similarity index 100%
# rename from examples/llm_sparsity/export_trtllm_ckpt.py
# rename to examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py
# diff --git a/examples/llm_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py
# similarity index 95%
# rename from examples/llm_sparsity/finetune.py
# rename to examples/llm_sparsity/weight_sparsity/finetune.py
# index 3cfc1073f..869068dbd 100644
# --- a/examples/llm_sparsity/finetune.py
# +++ b/examples/llm_sparsity/weight_sparsity/finetune.py
# @@ -1,3 +1,18 @@
# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +# SPDX-License-Identifier: Apache-2.0
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li diff --git a/examples/llm_sparsity/hf_pts.py b/examples/llm_sparsity/weight_sparsity/hf_pts.py similarity index 100% rename from examples/llm_sparsity/hf_pts.py rename to examples/llm_sparsity/weight_sparsity/hf_pts.py diff --git a/examples/llm_sparsity/launch_finetune.sh b/examples/llm_sparsity/weight_sparsity/launch_finetune.sh similarity index 100% rename from examples/llm_sparsity/launch_finetune.sh rename to examples/llm_sparsity/weight_sparsity/launch_finetune.sh diff --git a/examples/llm_sparsity/requirements.txt b/examples/llm_sparsity/weight_sparsity/requirements.txt similarity index 100% rename from examples/llm_sparsity/requirements.txt rename to examples/llm_sparsity/weight_sparsity/requirements.txt diff --git a/examples/llm_sparsity/utils.py b/examples/llm_sparsity/weight_sparsity/utils.py similarity index 100% rename from examples/llm_sparsity/utils.py rename to examples/llm_sparsity/weight_sparsity/utils.py diff --git a/modelopt/torch/sparsity/attention_sparsity/__init__.py b/modelopt/torch/sparsity/attention_sparsity/__init__.py new file mode 100644 index 000000000..150f93a3a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible sparse attention optimization for transformer models.""" + +# Initialize mode +from . import mode + +# Add methods to namespace +from .config import * +from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py new file mode 100644 index 000000000..e72dacc94 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Configuration classes for sparse attention optimization."""

from collections.abc import Callable
from typing import Any

from pydantic import Field, field_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

# Type definitions for sparse configuration
SparseAttributeConfig = dict[str, Any]  # Configuration for a specific pattern

SparseAttentionCfgType = dict[
    str | Callable,  # Pattern or callable for matching modules
    SparseAttributeConfig,  # Configuration dict with threshold, enable, etc.
]

# Keys accepted when ``threshold`` is given as a per-phase dict.
_VALID_THRESHOLD_PHASES = {"prefill", "decode", "default"}


def _is_valid_threshold(value) -> bool:
    """Return True when ``value`` is a number strictly between 0 and 1.

    The chained comparison is False for NaN, so NaN is rejected; testing
    ``value <= 0 or value >= 1`` instead would let NaN through.
    """
    return isinstance(value, (int, float)) and 0 < value < 1


class SparseAttentionAttributeConfig(ModeloptBaseConfig):
    """Sparse attention attribute configuration for pattern-based module config."""

    method: str = ModeloptField(
        default="flash_skip_softmax",
        title="Sparse attention method.",
        description="The sparse attention method to use (e.g., 'flash_skip_softmax').",
    )

    enable: bool = ModeloptField(
        default=True,
        title="Enable sparse attention.",
        description="If True, enables sparse attention. If False, bypasses sparsity.",
    )

    threshold: float | dict[str, float] = ModeloptField(
        default=1e-3,
        title="Sparsity threshold.",
        description=(
            "Threshold for determining which attention values to skip. "
            "Can be a float or dict with phase-specific values."
        ),
    )

    br: int = ModeloptField(
        default=128,
        title="Block row size.",
        description="Block row size for block-wise sparsity in Flash Attention.",
    )

    bc: int = ModeloptField(
        default=128,
        title="Block column size.",
        description="Block column size for block-wise sparsity in Flash Attention.",
    )

    backend: str = ModeloptField(
        default="pytorch",
        title="Backend implementation.",
        description=(
            "Backend to use for sparse attention computation. "
            "Only 'pytorch' is supported, which uses softmax patching with F.softmax. "
            "Requires model to be loaded with attn_implementation='eager'."
        ),
    )

    is_causal: bool = ModeloptField(
        default=True,
        title="Causal attention flag.",
        description=(
            "Whether the model uses causal (autoregressive) attention. "
            "If True, sparsity statistics are calculated over the lower triangle only. "
            "Defaults to True for decoder-only models like GPT, LLaMA, etc."
        ),
    )

    calibration: dict | None = ModeloptField(
        default=None,
        title="Calibration configuration",
        description=(
            "Calibration settings for this pattern. "
            "If provided, enables automatic threshold calibration. "
            "Only one pattern should have calibration enabled."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
        """Validate method is a string."""
        if not isinstance(v, str):
            raise ValueError("method must be a string")
        return v

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v):
        """Validate backend is pytorch."""
        if v != "pytorch":
            raise ValueError(
                f"Invalid backend: {v}. Only 'pytorch' backend is supported. "
                f"Model must be loaded with attn_implementation='eager'."
            )
        return v

    @field_validator("br", "bc")
    @classmethod
    def validate_block_size(cls, v):
        """Validate block sizes are positive integers."""
        if v <= 0:
            raise ValueError(f"Block size must be positive, got {v}")
        return v

    @field_validator("threshold")
    @classmethod
    def validate_threshold(cls, v):
        """Validate threshold is in the open range (0, 1), or a dict of such values.

        A dict may only use the phase keys "prefill", "decode" and "default".
        NaN is rejected in both the scalar and the dict form.

        Raises:
            ValueError: On an unknown phase key, an out-of-range or non-numeric
                value, or a threshold that is neither a number nor a dict.
        """
        if isinstance(v, dict):
            # Validate phase keys
            invalid_keys = set(v.keys()) - _VALID_THRESHOLD_PHASES
            if invalid_keys:
                raise ValueError(
                    f"Invalid threshold phases: {invalid_keys}. "
                    f"Valid phases: {_VALID_THRESHOLD_PHASES}"
                )
            # Validate all values are in range (0, 1)
            for phase, threshold in v.items():
                if not _is_valid_threshold(threshold):
                    raise ValueError(
                        f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}"
                    )
        elif isinstance(v, (int, float)):
            if not _is_valid_threshold(v):
                raise ValueError(f"Threshold must be in range (0, 1), got {v}")
        else:
            raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}")
        return v


# Pre-defined Sparse Attention Configuration
# Default configuration with block-wise sparsity optimized for Flash Attention
SKIP_SOFTMAX_DEFAULT = {
    "sparse_cfg": {
        "*attn*": {
            "method": "flash_skip_softmax",
            "threshold": {
                "prefill": 1e-3,  # More aggressive during prefill
                "decode": 1e-4,  # Conservative during decode
            },
            "br": 128,  # Flash Attention block rows
            "bc": 128,  # Flash Attention block columns
            "backend": "pytorch",  # Only pytorch backend supported
            "enable": True,
        },
        "default": {"enable": False},
    },
}


class SparseAttentionConfig(ModeloptBaseConfig):
    """Base configuration for sparse attention optimization.

    This base configuration provides the common structure for all sparse
    attention methods and supports pattern-based layer configuration.
    """

    # Pattern-based sparse configuration (similar to quant_cfg in quantization)
    sparse_cfg: SparseAttentionCfgType = ModeloptField(
        default={
            "*attention*": {"method": "flash_skip_softmax", "enable": True},
            "default": {"enable": False},
        },
        title="Sparse attention configuration",
        description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, "
        "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.",
        validate_default=True,
    )

    # Export configuration
    export_format: str | None = Field(
        None, description="Export format for sparse attention (e.g., 'onnx', 'tensorrt')"
    )


class FlashSkipSoftmaxConfig(SparseAttentionConfig):
    """Configuration for Flash Attention-aware softmax skip sparse attention."""

    # Override sparse_cfg with flash_skip_softmax specific defaults
    sparse_cfg: SparseAttentionCfgType = ModeloptField(
        default={
            "*attention*": {
                "method": "flash_skip_softmax",
                "threshold": {"prefill": 1e-3, "decode": 1e-4},
                "br": 128,  # Flash Attention block rows
                "bc": 128,  # Flash Attention block columns
                "backend": "pytorch",  # Only pytorch backend supported
                "enable": True,
            },
            "default": {"enable": False},
        },
        title="Flash softmax skip sparse configuration",
        # The description lists only settings that exist in the default above.
        description="Pattern-based configuration with flash_skip_softmax specific defaults. "
        "Includes FA block sizes (br, bc) and phase-specific thresholds.",
        validate_default=True,
    )


__all__ = [
    "SKIP_SOFTMAX_DEFAULT",
    "FlashSkipSoftmaxConfig",
    "SparseAttentionAttributeConfig",
    "SparseAttentionCfgType",
    "SparseAttentionConfig",
    "SparseAttributeConfig",
]

# --- Patch metadata that followed this file in the original diff (kept verbatim) ---
# diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py
# new file mode 100644
# index 000000000..25347c37f
# --- /dev/null
# +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py
# @@ -0,0 +1,362 @@
# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +# SPDX-License-Identifier: Apache-2.0
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conversion and restoration utilities for sparse attention."""

import fnmatch
from collections import Counter
from collections.abc import Callable
from typing import Any

import torch.nn as nn

from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
from modelopt.torch.utils import get_unwrapped_name

from .config import SparseAttentionConfig
from .plugins.huggingface import register_sparse_attention_on_the_fly
from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry


def _iter_sparse_attention_modules(model: nn.Module):
    """Yield ``(name, module)`` for every SparseAttentionModule in ``model``."""
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            yield name, module


def _name_matches(name: str, wildcard_or_filter: str | Callable) -> bool:
    """Return whether a module name is selected by a wildcard string or a filter callable.

    Raises:
        NotImplementedError: If ``wildcard_or_filter`` is neither a str nor callable.
    """
    if isinstance(wildcard_or_filter, str):
        return fnmatch.fnmatch(name, wildcard_or_filter)
    if callable(wildcard_or_filter):
        return bool(wildcard_or_filter(name))
    raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}")


def is_attn_sparsified(model: nn.Module) -> bool:
    """Check if a model has sparse attention applied.

    Similar to quantization's is_quantized for API consistency.

    Args:
        model: Model to check

    Returns:
        True if model contains any SparseAttentionModule instances
    """
    return any(isinstance(module, SparseAttentionModule) for module in model.modules())


def convert_to_sparse_attention_model(
    model: ModelLikeModule, config: SparseAttentionConfig
) -> ConvertReturnType:
    """Convert model to use sparse attention.

    Args:
        model: Model to convert
        config: Sparse attention configuration

    Returns:
        Tuple of (converted_model, metadata)
    """
    # Initialize the true module if necessary
    model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

    # Register sparse attention modules dynamically
    register_sparse_attention_on_the_fly(model)

    # Replace attention modules with sparse versions
    replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version)

    # Apply configuration to sparse attention modules; a config object without
    # ``sparse_cfg`` applies nothing.
    set_sparse_attention_by_cfg(model, getattr(config, "sparse_cfg", {}))

    # Create metadata
    metadata = {}
    update_sparse_attention_metadata(model, config, metadata)

    return model, metadata


def replace_sparse_attention_modules(model: nn.Module, version=None):
    """Replace regular attention modules with sparse attention modules.

    Recursively replace all attention modules in the model with their sparse attention counterparts.

    Args:
        model: Model to process
        version: State version for tracking (optional)
    """
    # Recursively replace modules
    _replace_sparse_attention_modules(model, version=version)

    # Count and report every sparse attention module now present in the model
    replaced_count = sum(1 for _ in _iter_sparse_attention_modules(model))
    if replaced_count > 0:
        print(f"Inserted {replaced_count} sparse attention modules")


def _replace_sparse_attention_modules(model: nn.Module, version=None):
    """Helper function for replace_sparse_attention_modules."""
    for name, child in model.named_children():
        if type(child) in SparseAttentionRegistry:
            # REPLACE on the parent (model), not on child
            sparse_module = SparseAttentionRegistry.convert(child)
            setattr(model, name, sparse_module)

        # Now recurse into whichever module is now at `model.name`
        _replace_sparse_attention_modules(getattr(model, name), version=version)


def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict):
    """Apply sparse attention configuration to model.

    Similar to quantization's set_quantizer_by_cfg. The "default" entry, if
    present, is applied to every sparse attention module first; the remaining
    patterns are then applied in dict order, so later patterns override it.

    Args:
        model: Model with sparse attention modules
        sparse_cfg: Sparse configuration dictionary mapping patterns to attributes
    """
    # Apply default first if exists
    if "default" in sparse_cfg:
        set_sparse_attention_attribute(model, "*", sparse_cfg["default"])

    # Apply pattern-specific configs
    for pattern, cfg in sparse_cfg.items():
        if pattern == "default":
            continue
        set_sparse_attention_attribute(model, pattern, cfg)


def set_sparse_attention_attribute(
    model: nn.Module,
    wildcard_or_filter: str | Callable,
    attribute_cfg: dict[str, Any],
):
    """Set sparse attention attributes for modules matching pattern.

    Similar to quantization's set_quantizer_attribute.

    Args:
        model: Model to configure
        wildcard_or_filter: Pattern to match module names
        attribute_cfg: Attributes to apply (must include 'method')

    Raises:
        NotImplementedError: If ``wildcard_or_filter`` is neither a str nor
            callable and the model contains a sparse attention module.
    """
    # Filter out model-level configs that shouldn't be passed to modules
    module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"}

    for name, module in _iter_sparse_attention_modules(model):
        if _name_matches(name, wildcard_or_filter):
            # Apply config using the same method as TensorQuantizer
            module.set_from_attribute_config(module_cfg)


def restore_sparse_attention_model(
    model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict
) -> nn.Module:
    """Restore sparse attention model from saved state.

    Args:
        model: Model to restore
        config: Sparse attention configuration
        metadata: Saved metadata

    Returns:
        Restored model
    """
    # Convert to sparse attention model
    model, _ = convert_to_sparse_attention_model(model, config)

    # Restore sparse attention state from metadata
    if "sparse_attention_state" in metadata:
        restore_sparse_attention_state(model, metadata["sparse_attention_state"])

    return model


def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Modules whose unwrapped name is missing from ``state_dict`` are left untouched.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in _iter_sparse_attention_modules(model):
        module_name = get_unwrapped_name(name, model)
        if module_name not in state_dict:
            continue
        module_state = state_dict[module_name]

        # Restore method and config
        if "method" in module_state:
            module._method = module_state["method"]
        if "method_config" in module_state:
            # Each saved config key is written back as a private attribute
            for key, val in module_state["method_config"].items():
                setattr(module, f"_{key}", val)

        # Re-setup with restored config
        module._setup()


def update_sparse_attention_metadata(
    model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict
) -> None:
    """Update metadata with sparse attention state.

    Writes two keys into ``metadata``: "sparse_attention_state" (per-module
    method name and method config) and "sparse_attention_config" (the dumped
    ``config``).

    Args:
        model: Model with sparse attention
        config: Configuration used
        metadata: Metadata dict to update
    """
    sparse_state = {}
    for name, module in _iter_sparse_attention_modules(model):
        # Save the method configuration that was used
        # _method_config already contains the validated config dict
        sparse_state[get_unwrapped_name(name, model)] = {
            "method": module._sparse_method_instance.name,
            "method_config": module._method_config.copy(),
        }

    metadata["sparse_attention_state"] = sparse_state
    metadata["sparse_attention_config"] = (
        config.model_dump() if hasattr(config, "model_dump") else vars(config)
    )


def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable):
    """Disable sparse attention for matching modules.

    Similar to mtq.disable_quantizer for API consistency.

    Args:
        model: Model with sparse attention applied
        wildcard_or_filter_func: Wildcard string or filter function to match module names.
            For example: "*lm_head*", "*layer_0*", etc.

    Raises:
        NotImplementedError: If ``wildcard_or_filter_func`` is neither a str nor
            callable and the model contains a sparse attention module.

    Example:
        >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
        >>> model = sparse_attn.sparsify(model, config)
        >>> # Disable sparse attention for lm_head
        >>> sparse_attn.disable_sparse_attention(model, "*lm_head*")
    """
    for name, module in _iter_sparse_attention_modules(model):
        if _name_matches(name, wildcard_or_filter_func):
            module.disable()


def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable):
    """Enable sparse attention for matching modules.

    Similar to mtq.enable_quantizer for API consistency.

    Args:
        model: Model with sparse attention applied
        wildcard_or_filter_func: Wildcard string or filter function to match module names.
            For example: "*attention*", "*attn*", etc.

    Raises:
        NotImplementedError: If ``wildcard_or_filter_func`` is neither a str nor
            callable and the model contains a sparse attention module.

    Example:
        >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
        >>> model = sparse_attn.sparsify(model, config)
        >>> # Re-enable sparse attention for all attention modules
        >>> sparse_attn.enable_sparse_attention(model, "*attention*")
    """
    for name, module in _iter_sparse_attention_modules(model):
        if _name_matches(name, wildcard_or_filter_func):
            module.enable()


def print_sparse_attention_summary(model: nn.Module):
    """Print summary of sparse attention modules in the model.

    Similar to mtq.print_quant_summary for API consistency.

    Args:
        model: Model with sparse attention applied

    Prints:
        - Total sparse attention modules
        - Enabled vs disabled count
        - Method distribution
        - Configuration summary by module

    Example:
        >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
        >>> model = sparse_attn.sparsify(model, config)
        >>> sparse_attn.print_sparse_attention_summary(model)
    """
    sparse_modules = list(_iter_sparse_attention_modules(model))

    if not sparse_modules:
        print("No sparse attention modules found in model")
        return

    enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled)
    disabled_count = len(sparse_modules) - enabled_count

    # Count methods
    method_counts = Counter(getattr(m, "_method", "unknown") for _, m in sparse_modules)

    print(f"Total sparse attention modules: {len(sparse_modules)}")
    print(f"Enabled: {enabled_count}")
    print(f"Disabled: {disabled_count}")

    if method_counts:
        print("\nMethods:")
        for method, count in sorted(method_counts.items()):
            print(f"{method}: {count}")

    for name, module in sparse_modules:
        method = getattr(module, "_method", "unknown")
        threshold = getattr(module, "_threshold", "N/A")

        # Floats get scientific notation; dicts and anything else print as-is
        threshold_str = f"{threshold:.2e}" if isinstance(threshold, float) else str(threshold)

        print(f"{name}: Method: {method}, Threshold: {threshold_str}")

# --- Patch metadata that followed this file in the original diff (kept verbatim) ---
# diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
# new file mode 100644
# index 000000000..8a109fda7
# --- /dev/null
# +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
# @@ -0,0 +1,27 @@
# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION &
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention methods package.""" + +from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method + +__all__ = [ + "SparseAttentionMethod", + "get_sparse_method", + "register_sparse_method", +] + +# Import method implementations to trigger registration +from . import flash_skip_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py new file mode 100644 index 000000000..8801bafb0 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@register_sparse_method("flash_skip_softmax")
class FlashSkipSoftmax(SparseAttentionMethod):
    """Flash Attention-aware softmax skip sparse attention method.

    Implements row-level block-wise sparsity aligned with Flash Attention's
    processing pattern: attention scores are tiled into ``br x bc`` blocks and
    a block is masked out when none of its entries comes within ``threshold``
    (in log space) of the running row maximum.
    """

    def __init__(self, method_config: dict | None = None):
        """Initialize Flash softmax skip method.

        Args:
            method_config: Configuration dict with threshold, br, bc, backend,
                is_causal, etc. All required fields should have defaults from
                SparseAttentionAttributeConfig.

        Raises:
            KeyError: If a required field is missing from ``method_config``.
        """
        config = method_config or {}

        # Required fields (defaults are expected to be filled in by the config class)
        self.threshold_config = config["threshold"]
        self.br = config["br"]
        self.bc = config["bc"]
        self.backend = config["backend"]
        self.is_causal = config["is_causal"]

        # Optional parameters not in Pydantic config
        self.enable_correction_factor = config.get("enable_correction_factor", True)
        self.phase = config.get("phase", None)

        # Initial threshold: a dict holds per-phase values, otherwise a scalar
        if isinstance(self.threshold_config, dict):
            self.threshold = self.threshold_config.get(
                "default", self.threshold_config.get("prefill", 1e-4)
            )
        else:
            self.threshold = self.threshold_config

    def _update_threshold(self, phase: str):
        """Update threshold based on phase (no-op for scalar thresholds)."""
        if isinstance(self.threshold_config, dict):
            self.threshold = self.threshold_config.get(
                phase, self.threshold_config.get("default", self.threshold)
            )

    def _infer_phase(self, attention_scores: torch.Tensor) -> str:
        """Infer phase from attention scores shape: one query row means decode."""
        return "decode" if attention_scores.shape[2] == 1 else "prefill"

    def _reshape_to_blocks(
        self, tensor: torch.Tensor, br: int, bc: int
    ) -> tuple[torch.Tensor, int, int, int, int]:
        """Reshape tensor into blocks for Flash Attention processing.

        Args:
            tensor: Input tensor of shape [batch, heads, seq_q, seq_k]
            br: Block row size
            bc: Block column size

        Returns:
            Tuple of (blocked_tensor, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k)
            where blocked_tensor has shape
            [batch, heads, num_block_rows, br, num_block_cols, bc].
        """
        batch_size, num_heads, seq_q, seq_k = tensor.shape

        # Round both sequence dimensions up to a multiple of the block size
        padded_seq_q = math.ceil(seq_q / br) * br
        padded_seq_k = math.ceil(seq_k / bc) * bc

        if padded_seq_q != seq_q or padded_seq_k != seq_k:
            pad_q = padded_seq_q - seq_q
            pad_k = padded_seq_k - seq_k
            # Use dtype min instead of -inf for numerical stability
            pad_value = torch.finfo(tensor.dtype).min
            tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_q), value=pad_value)

        num_block_rows = padded_seq_q // br
        num_block_cols = padded_seq_k // bc

        # Keep natural order for row-level processing.
        # reshape (not view) so non-contiguous score tensors that need no
        # padding are accepted too; it still returns a view when possible.
        blocked = tensor.reshape(batch_size, num_heads, num_block_rows, br, num_block_cols, bc)

        return blocked, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k

    def calc_correction_factor_and_p(
        self, attn_weights: torch.Tensor, phase: str
    ) -> tuple[torch.Tensor, dict]:
        """Calculate sparse mask and statistics for Flash Attention.

        Implements block-wise sparsity compatible with Flash Attention's online softmax:
        1. Reshape attention scores into ``br x bc`` blocks
        2. Track block-wise maximum values (simulating Flash Attention's row processing)
        3. Compute cumulative maximum across blocks (for online normalization)
        4. Apply threshold: mask blocks where p = score - cummax < log(threshold)
        5. Calculate correction factor and sparsity statistics

        Args:
            attn_weights: Pre-softmax attention scores [batch, heads, seq_q, seq_k]
            phase: "prefill" (seq_q > 1) or "decode" (seq_q = 1)

        Returns:
            element_mask: Boolean mask [batch, heads, seq_q, seq_k]; True means keep
            stats: Dict with sparsity, correction_factor, total_blocks, etc.
        """
        batch_size, num_heads, seq_q, seq_k = attn_weights.shape

        # Calculate threshold
        threshold_scale_factor = getattr(self, "threshold_scale_factor", None)
        if threshold_scale_factor:
            # Use calibrated dynamic threshold: lambda = scale_factor / length
            log_threshold = np.log(threshold_scale_factor / seq_k)
        else:
            # Use static threshold from config
            log_threshold = np.log(self.threshold)

        if phase == "prefill":
            blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = (
                self._reshape_to_blocks(attn_weights, self.br, self.bc)
            )

            # Step 1: Per-row maximum inside each block
            # blocked_attn: [batch, heads, block_rows, br, block_cols, bc]
            # block_max:    [batch, heads, block_rows, br, block_cols]
            block_max = blocked_attn.max(dim=-1)[0]

            # Step 2: Running maximum across blocks (left to right), simulating
            # Flash Attention's online softmax normalization
            block_max_cummax = block_max.cummax(dim=-1)[0]

            # Step 3: Correction factor (how often the running max increases)
            # NOTE(review): padded query rows are included in this average — confirm
            # whether calibration expects them to be excluded.
            block_max_larger = torch.ones_like(block_max)
            block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
            correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger))

            # Step 4: Normalize attention scores by cumulative max (log-space difference)
            p = blocked_attn - block_max_cummax[..., None]

            # Step 5: Apply threshold and create block-level mask
            p_larger_than_thresh = p > log_threshold
            # Reduce over bc first: [batch, heads, block_rows, br, block_cols]
            row_keep = p_larger_than_thresh.any(dim=-1)
            if padded_seq_q != seq_q:
                # Padded query rows are filled with the dtype minimum everywhere, so
                # their running max equals the pad value and p == 0 > log_threshold.
                # Exclude them, otherwise every block of the last block row is kept.
                valid_rows = torch.arange(padded_seq_q, device=attn_weights.device) < seq_q
                row_keep = row_keep & valid_rows.view(num_block_rows, self.br, 1)
            # Reduce over br: [batch, heads, block_rows, block_cols]
            block_mask = row_keep.any(dim=-2)

            # Step 6: Expand block mask back to element level
            element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn)

            # Step 7: Reshape to original attention shape and remove padding
            element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k)
            element_mask = element_mask[:, :, :seq_q, :seq_k]

            # Step 8: Sparsity statistics (kept blocks averaged across batch and heads)
            kept_blocks = block_mask.sum().item() / (batch_size * num_heads)

            # Total valid blocks (lower triangle only for causal attention).
            # NOTE(review): the causal count assumes a square block grid
            # (num_block_rows == num_block_cols) — confirm for seq_q != seq_k.
            total_blocks = (
                num_block_rows * (num_block_rows + 1) // 2  # Causal: N(N+1)/2
                if self.is_causal
                else num_block_rows * num_block_cols  # Non-causal: N*N
            )
            sparsity = 1 - (kept_blocks / total_blocks)
        else:  # decode
            blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks(
                attn_weights, 1, self.bc
            )

            # Decode: single query row attends to all past key blocks
            # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc]

            # Step 1: Maximum in each key block
            block_max = blocked_attn.max(dim=-1)[0]

            # Step 2: Running maximum across key blocks (left to right)
            block_max_cummax = block_max.cummax(dim=-1)[0]

            # Step 3: Correction factor (how often the running max increases)
            block_max_larger = torch.ones_like(block_max)
            block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1]
            correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger))

            # Step 4: Normalize scores by cumulative max (log-space difference)
            p = blocked_attn - block_max_cummax[..., None]

            # Step 5: Keep blocks where at least one element exceeds the threshold
            p_larger_than_thresh = p > log_threshold
            block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False)

            # Step 6: Expand to element level and remove padding
            element_mask = block_mask[..., None].expand_as(blocked_attn)
            element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k)
            element_mask = element_mask[:, :, :seq_q, :seq_k]

            # Step 7: Calculate statistics
            kept_blocks = block_mask.sum().item() / (batch_size * num_heads)
            total_blocks = num_block_cols
            sparsity = 1 - (kept_blocks / total_blocks)

        stats = {
            "correction_factor": correction_factor if self.enable_correction_factor else 1.0,
            "sparsity": sparsity,
            "phase": phase,
            "total_blocks": total_blocks,
            # Take the difference directly: int(sparsity * total_blocks) can
            # truncate one too low through float round-off (total=5, kept=4 -> 0).
            "sparse_blocks": int(total_blocks - kept_blocks),
            "sample_length": seq_k,
        }

        return element_mask, stats

    def apply_sparsity(
        self,
        query: torch.Tensor | None = None,
        key: torch.Tensor | None = None,
        value: torch.Tensor | None = None,
        attention_scores: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Apply Flash Attention-aware block-wise sparsity.

        Args:
            query: Query tensor (unused, for API compatibility)
            key: Key tensor (unused, for API compatibility)
            value: Value tensor (unused, for API compatibility)
            attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k]

        Returns:
            Tuple ``(query, key, value, sparse_scores)`` where masked-out
            positions of ``sparse_scores`` hold the dtype minimum.

        Raises:
            AssertionError: If ``attention_scores`` is missing or not 4D.
        """
        # Attention scores must be provided for sparse attention
        assert attention_scores is not None, "attention_scores must be provided for apply_sparsity"

        # Attention scores are always 4D: [batch, heads, seq_q, seq_k]
        assert len(attention_scores.shape) == 4, (
            f"Expected 4D attention scores, got shape {attention_scores.shape}"
        )

        # Infer phase from tensor shape and pick the matching threshold
        phase = self._infer_phase(attention_scores)
        self._update_threshold(phase)

        # Apply block-wise sparsity
        sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase)

        # Store stats for module to collect (overwritten on every call)
        self._last_stats = stats

        # Masked-out positions get the dtype minimum, i.e. ~0 weight after softmax
        mask_value = torch.finfo(attention_scores.dtype).min
        sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value)

        return query, key, value, sparse_scores

    @property
    def name(self) -> str:
        """Method identifier."""
        return "flash_skip_softmax"
b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py new file mode 100644 index 000000000..df7b5853b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry and base class for sparse attention methods.""" + +import re +import warnings +from abc import ABC, abstractmethod + +import torch + + +class SparseAttentionMethod(ABC): + """Base class for sparse attention methods.""" + + @abstractmethod + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Apply sparsity to attention computation. 
+ + Args: + query: Query tensor + key: Key tensor + value: Value tensor + attention_scores: Pre-computed attention scores + + Returns: + Tuple of (query, key, value, attention_scores) with sparsity applied + """ + + @property + @abstractmethod + def name(self) -> str: + """Method name identifier.""" + + +# Method Registry with versioning support +_SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} + + +def _version_key(version_str: str) -> list[int]: + """Extract numeric parts for proper version sorting. + + Args: + version_str: Version string (e.g., "v1", "v2", "v10") + + Returns: + List of integers extracted from version string for sorting + + Examples: + >>> _version_key("v1") + [1] + >>> _version_key("v10") + [10] + >>> _version_key("v2.3.1") + [2, 3, 1] + """ + parts = re.findall(r"\d+", version_str) + return [int(p) for p in parts] if parts else [0] + + +def register_sparse_method(name: str, version: str = "v1"): + """Decorator to register sparse attention methods with version support. + + Args: + name: Method name to register + version: Version string (default: "v1") + + Example:: + + @register_sparse_method("my_method", version="v3") + class MyMethodV3(SparseAttentionMethod): ... + """ + + def decorator(cls: type[SparseAttentionMethod]): + if name not in _SPARSE_ATTENTION_METHODS: + _SPARSE_ATTENTION_METHODS[name] = {} + + if version in _SPARSE_ATTENTION_METHODS[name]: + warnings.warn( + f"Overriding existing sparse attention method: {name}@{version}", + RuntimeWarning, + stacklevel=2, + ) + + _SPARSE_ATTENTION_METHODS[name][version] = cls + return cls + + return decorator + + +def get_sparse_method(name: str, version: str | None = None) -> type[SparseAttentionMethod]: + """Get sparse attention method by name and optional version. + + Args: + name: Method name to retrieve + version: Optional version string. If None, uses latest version. 
+ + Returns: + Method class + + Raises: + ValueError: If method name or version is not registered + + Example: + >>> get_sparse_method("flash_skip_softmax") # Latest version + >>> get_sparse_method("flash_skip_softmax", "v1") # Specific version + """ + if name not in _SPARSE_ATTENTION_METHODS: + available = list(_SPARSE_ATTENTION_METHODS.keys()) + raise ValueError(f"Unknown sparse attention method: {name}. Available: {available}") + + method_versions = _SPARSE_ATTENTION_METHODS[name] + + if not version: + version = sorted(method_versions.keys(), key=_version_key)[-1] + + if version not in method_versions: + available_versions = list(method_versions.keys()) + raise ValueError( + f"Unknown version {version} for method {name}. Available: {available_versions}" + ) + + return method_versions[version] diff --git a/modelopt/torch/sparsity/attention_sparsity/mode.py b/modelopt/torch/sparsity/attention_sparsity/mode.py new file mode 100644 index 000000000..f389509a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/mode.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Registry holding the sparse attention modes
SparseAttentionModeRegistry = _ModeRegistryCls("sparse_attention")


@SparseAttentionModeRegistry.register_mode
class SparseAttentionModeDescriptor(ModeDescriptor):
    """Describes the ``sparse_attention`` mode to the ModelOpt mode machinery.

    The descriptor names the mode, its config class, and the entrypoints used
    to convert, restore and save a model with sparse attention applied.
    """

    @property
    def name(self) -> str:
        """String identifier of the mode."""
        return "sparse_attention"

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """Config class associated with the mode."""
        return SparseAttentionConfig

    @property
    def next_prohibited_modes(self) -> set[str] | None:
        """Modes that must not be applied once this mode is active."""
        # Quantization may follow; weight sparsity may not.
        return {"sparsity"}

    @property
    def export_mode(self) -> str | None:
        """Name of the export mode paired with this mode."""
        return "export_sparse_attention"

    @property
    def convert(self) -> ConvertEntrypoint:
        """Entrypoint that converts a model to sparse attention."""
        return convert_to_sparse_attention_model

    @property
    def restore(self) -> RestoreEntrypoint:
        """Entrypoint that restores a sparse attention model."""
        return restore_sparse_attention_model

    @property
    def update_for_save(self) -> UpdateEntrypoint:
        """Entrypoint that refreshes the metadata before the model is saved."""
        return update_sparse_attention_metadata

    @property
    def update_for_new_mode(self) -> UpdateEntrypoint:
        """Entrypoint that refreshes the metadata before a new mode is applied."""
        return update_sparse_attention_metadata


__all__ = [
    "sparsify",
]


def sparsify(
    model: torch.nn.Module,
    config: dict[str, Any] | SparseAttentionConfig,
    forward_loop: ForwardLoop | None = None,
) -> torch.nn.Module:
    """Apply sparse attention optimization to the model in-place.

    Attention modules are replaced with their sparse counterparts by applying
    the ``"sparse_attention"`` mode from ``SparseAttentionModeRegistry``.

    Args:
        model: A pytorch model
        config: A dictionary or an instance of :class:`SparseAttentionConfig`
            specifying the values for keys ``"sparse_cfg"`` and ``"method"``.

            The ``"sparse_cfg"`` key specifies the sparse attention configurations.
            The ``"method"`` key specifies the sparse attention method
            (e.g., "flash_skip_softmax").

            The sparse attention configuration is a dictionary mapping wildcards
            or filter functions to sparse attention attributes. The wildcards or
            filter functions are matched against the module names. The sparse
            attention attributes include ``"threshold"``, ``"enable"``, and
            method-specific parameters.

            An example ``config`` dictionary is given below:

            .. code-block:: python

                config = {
                    "method": "flash_skip_softmax",
                    "sparse_cfg": {
                        "*attention*": {
                            "threshold": {"prefill": 1e-3, "decode": 1e-5},
                            "backend": "pytorch",
                            "enable": True,
                        },
                        "default": {"enable": False},
                    },
                }

            The ``"backend"`` parameter must be set to ``"pytorch"``:

            - ``"pytorch"``: Softmax patching approach (only supported backend)

            This requires the model to be loaded with ``attn_implementation="eager"``.

        forward_loop: Reserved for future use. This function currently does not
            use it.

            Here are a few examples for correct ``forward_loop`` definitions:

            Example 1:

            .. code-block::

                def forward_loop(model) -> None:
                    # iterate over the data loader and forward data through the model
                    for batch in data_loader:
                        model(batch)

            Example 2:

            .. code-block::

                def forward_loop(model) -> float:
                    # evaluate the model on the task
                    return evaluate(model, task, ....)

            .. note::

                Calibration does not require forwarding the entire dataset through the model.
                Please subsample the dataset or reduce the number of batches if needed.

    .. important::

        The model must always be loaded with ``attn_implementation="eager"``
        for sparse attention to work correctly:

        .. code-block:: python

            from transformers import AutoModelForCausalLM

            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation="eager",  # Required for sparse attention
                torch_dtype=torch.bfloat16,
            )

        This is because sparse attention works by patching torch.nn.functional.softmax,
        which is only called in the eager attention implementation.

    Returns:
        A pytorch model which has sparse attention applied and optionally calibrated.
    """
    return apply_mode(
        model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry
    )
logger = logging.getLogger(__name__)


class _GenericSparseAttention(SparseAttentionModule):
    """Generic sparse attention that works with any HF attention module.

    This class provides a universal sparse attention wrapper that can
    work with various transformer attention implementations.
    """

    def _setup(self):
        """Setup sparse attention for any attention type.

        Delegates entirely to the base SparseAttentionModule.
        """
        super()._setup()

    def get_attn_type(self, attn_module) -> type:
        """Get the original attention type.

        Args:
            attn_module: Attention module (possibly wrapped)

        Returns:
            Original class type
        """
        # A DynamicModule wraps the original class; level 0 is the unwrapped one
        if isinstance(attn_module, DynamicModule):
            return attn_module.get_original_cls_by_level(level=0)
        return type(attn_module)


def register_sparse_attention_on_the_fly(model: nn.Module) -> bool:
    """Dynamically register sparse attention for any model.

    Every module whose class name contains "attention" (case-insensitive) and
    whose class is not yet in ``SparseAttentionRegistry`` is registered with
    the generic sparse attention wrapper.

    Args:
        model: Model to process

    Returns:
        True if any modules were registered
    """
    if not _is_supported_model(model):
        return False

    newly_registered = set()

    for module in model.modules():
        # Skip if already a sparse attention module
        if isinstance(module, SparseAttentionModule):
            continue

        module_type = type(module)
        type_name = module_type.__name__

        # Detect attention modules by class name. The substring test also covers
        # names ending in "Attention"/"SelfAttention".
        if "attention" not in type_name.lower():
            continue

        # Register each attention class once
        if module_type in SparseAttentionRegistry or module_type in newly_registered:
            continue

        SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention)
        newly_registered.add(module_type)
        logger.info("Registered %s for sparse attention optimization", type_name)

    registered_count = len(newly_registered)
    if registered_count > 0:
        logger.info(
            "Dynamically registered %d attention module types for sparsity", registered_count
        )

    return registered_count > 0


def _is_supported_model(model: nn.Module) -> bool:
    """Check if model is supported for sparse attention.

    Any PyTorch module is supported. HuggingFace ``PreTrainedModel`` instances
    are ``nn.Module`` instances too, so they need no separate check.

    Args:
        model: Model to check

    Returns:
        True if model is supported
    """
    return isinstance(model, nn.Module)
Check if sparse attention is enabled (pass-through if disabled) + 2. Create softmax patch context with sparse_softmax function + 3. Apply sparse attention by patching F.softmax: + - Patches torch.nn.functional.softmax with sparse_softmax + - sparse_softmax applies method's sparsity logic before softmax + 4. Forward through original attention with sparsity applied + + Requirements: + ------------- + - Model must be loaded with attn_implementation="eager" for proper softmax interception + - Only PyTorch backend is supported (patches F.softmax) + + Attributes: + ----------- + _enabled: bool + Whether sparse attention is enabled + _method: str + The sparse attention method to use (e.g., "flash_skip_softmax") + _method_config: dict + Configuration dictionary for the sparse method (threshold, br, bc, etc.) + _sparse_method_instance: SparseAttentionMethod + Instance of the configured sparse attention method + """ + + def set_from_attribute_config( + self, attribute_cfg: SparseAttentionAttributeConfig | dict | None = None + ): + """Set sparse attention attributes from configuration. + + Similar to TensorQuantizer.set_from_attribute_config. + + Args: + attribute_cfg: Sparse attention attribute configuration. 
+ """ + # Ensure config is validated through Pydantic + if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): + attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) + + # Store raw config for method initialization + self._method_config = {} + + # Define which attributes are method-specific vs module-specific + # Module-specific attributes control the SparseAttentionModule behavior + _module_attributes = {"enable", "method"} + + # Custom setters for special module attributes + _custom_setters = { + "enable": ("_enabled", lambda val: bool(val)), + "method": ("_method", lambda val: str(val)), + } + + # Process each attribute from validated config + for attribute, val in attribute_cfg.model_dump().items(): + # Validate attribute if using config class + if hasattr(SparseAttentionAttributeConfig, "model_fields"): + assert attribute in SparseAttentionAttributeConfig.model_fields, ( + f"{attribute} is not a valid SparseAttentionModule attribute" + ) + + if attribute in _module_attributes: + # Module-level attribute: store with underscore prefix + attr_name, setter = _custom_setters.get(attribute, (f"_{attribute}", lambda v: v)) + setattr(self, attr_name, setter(val)) + else: + # Method-specific attribute: store in config dict + self._method_config[attribute] = val + + # Initialize sparse method instance + self._init_sparse_method() + + def _init_sparse_method(self): + """Initialize the sparse method instance.""" + method_class = get_sparse_method(self._method) + + # Initialize the sparse method instance + # _method_config is always initialized in set_from_attribute_config + self._sparse_method_instance = method_class(method_config=self._method_config) # type: ignore[call-arg] + + def enable(self): + """Enable sparse attention for this module.""" + self._enabled = True + + def disable(self): + """Disable sparse attention for this module.""" + self._enabled = False + + @property + def is_enabled(self) -> bool: + """Check if sparse attention is 
enabled.""" + return getattr(self, "_enabled", True) + + def get_stats(self) -> dict: + """Get sparsity statistics from the stats manager. + + Returns: + Dictionary with sparsity statistics including 'average_sparsity' if available. + Returns empty dict (statistics collection will be added in calibration PR). + """ + # TODO: Statistics collection will be added in calibration PR + return {} + + def _setup(self): + """Setup called by DynamicModule.""" + # Apply default configuration if not yet configured + if not hasattr(self, "_method"): + self.set_from_attribute_config(None) + + def forward(self, *args, **kwargs): + """Forward with selected sparse attention method. + + This method dispatches to the appropriate sparse attention implementation + based on the configured method and backend. + """ + # Pass through if sparse attention is disabled + if not self.is_enabled: + return super().forward(*args, **kwargs) + + # Get the appropriate context manager for this configuration + context = self._get_sparse_context() + + # Apply sparse attention through the context + with context: + result = super().forward(*args, **kwargs) + + return result + + def _get_sparse_context(self): + """Get the softmax patch context for applying sparse attention.""" + return self._create_softmax_patch_context() + + def _create_softmax_patch_context(self): + """Create context manager for patching softmax function.""" + return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) + + def _create_sparse_softmax(self): + """Create sparse softmax function for current method.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + # Let the method handle the sparsification + _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( + None, None, None, input + ) + + # Use sparse input if modified, otherwise use original + if sparse_input is not None: + return original_softmax(sparse_input, dim, *args, **kwargs) + return 
original_softmax(input, dim, *args, **kwargs) + + return sparse_softmax + + +# Create registry for sparse attention modules +SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule) diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py new file mode 100644 index 000000000..7724908b0 --- /dev/null +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common utilities for sparse attention testing.""" + +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +# Test models for sparse attention +class SimpleAttentionModel(nn.Module): + """Simple attention model for testing.""" + + def __init__(self, hidden_size=256, num_heads=8): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.fc = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x, need_weights=False) + return self.fc(attn_output) + + @classmethod + def get_input(cls, hidden_size=256, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, hidden_size) + + +class SimpleTransformerEncoderLayer(nn.Module): + """Simple TransformerEncoderLayer wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, dim_feedforward=256): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + + def forward(self, x): + return self.layer(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=20, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +class SimpleTransformerEncoder(nn.Module): + """Simple TransformerEncoder wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, num_layers=2): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True), + num_layers=num_layers, + ) + + def 
forward(self, x): + return self.encoder(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +# Test configurations +FLASH_SKIP_SOFTMAX_DEFAULT_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + +FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + + +def get_test_configs(): + """Get test configurations for parameterized tests. + + Note: Calibration config excluded (requires GPU and real tokenizers). + """ + return [FLASH_SKIP_SOFTMAX_DEFAULT_CFG, FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG] + + +def sparsify_model_and_forward(model, config, calib_data): + """Apply sparse attention, then run and sanity-check forward passes. + + Args: + model: Model to sparsify + config: Sparse attention configuration + calib_data: List of calibration data tensors + + Returns: + Sparsified model + """ + + def forward_loop(model): + for batch in calib_data: + model(batch) + + # Apply sparse attention + model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules were inserted + assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), ( + "No sparse attention modules found" + ) + + # Test forward passes (check for None before isnan, which requires a tensor) + model.eval() + with torch.no_grad(): + for batch in calib_data: + output = model(batch) + assert output is not None, "Output is None" + assert not torch.isnan(output).any(), "NaN in output" + + return model + + +def save_restore_test(model_cls, device, sparse_config): + """Test save and restore of sparse attention state.
+ + Args: + model_cls: Model class to test + device: Device to run on ('cpu' or 'cuda') + sparse_config: Sparse attention configuration + """ + # Create and sparsify reference model + model_sparse = model_cls().to(device) + calib_data = [model_sparse.get_input().to(device) for _ in range(2)] + + sparsify_model_and_forward(model_sparse, sparse_config, calib_data) + + # Save state + state_dict = mto.modelopt_state(model_sparse) + + # Restore to new model + model_restored = model_cls().to(device) + mto.restore_from_modelopt_state(model_restored, state_dict) + model_restored.load_state_dict(model_sparse.state_dict()) + + # Verify outputs match + test_input = calib_data[0] + model_sparse.eval() + model_restored.eval() + + with torch.no_grad(): + output_sparse = model_sparse(test_input) + output_restored = model_restored(test_input) + + assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + "Restored model output doesn't match original" + ) diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py new file mode 100644 index 000000000..b82303990 --- /dev/null +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test attention sparsity example script.""" + +import pytest +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.misc import minimum_gpu + + +def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs): + """Run attention sparsity example script. + + Args: + model: Path to model + method: Sparse attention method (corresponds to --sparse_attn arg) + **kwargs: Additional arguments to pass to the script + """ + kwargs.update( + { + "pyt_ckpt_path": model, + "sparse_attn": method, + } + ) + kwargs.setdefault("seq_len", 128) + kwargs.setdefault("num_samples", 1) + kwargs.setdefault("max_new_tokens", 16) + + cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) + run_example_command(cmd_parts, "llm_sparsity/attention_sparsity") + + +@minimum_gpu(1) +@pytest.mark.parametrize("method", ["skip_softmax"]) +def test_attention_sparsity(tiny_llama_path, tmp_path, method): + """Test sparse attention with TinyLlama.""" + run_attention_sparsity_command( + model=tiny_llama_path, + method=method, + ) diff --git a/tests/examples/llm_sparsity/test_llama_sparsify.py b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py similarity index 93% rename from tests/examples/llm_sparsity/test_llama_sparsify.py rename to tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py index 7f9ef929b..7094b2989 100644 --- a/tests/examples/llm_sparsity/test_llama_sparsify.py +++ b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py @@ -31,7 +31,7 @@ def run_llm_sparsity_command( kwargs.setdefault("model_max_length", 1024) cmd_parts = extend_cmd_parts(["python", "hf_pts.py"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") def run_llm_sparsity_ft_command( @@ -51,13 +51,15 @@ def run_llm_sparsity_ft_command( kwargs.setdefault("eval_bs", 1) cmd_parts = extend_cmd_parts(["bash", 
"launch_finetune.sh"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") @pytest.fixture(scope="session") def data_path(tmp_path_factory): data_path = tmp_path_factory.mktemp("data") - run_example_command(["python", "data_prep.py", "--save_path", data_path], "llm_sparsity") + run_example_command( + ["python", "data_prep.py", "--save_path", data_path], "llm_sparsity/weight_sparsity" + ) # Copy eval data to train path for faster test run_example_command( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py new file mode 100644 index 000000000..bad077fdb --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU tests for attention sparsity module.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoder, + SimpleTransformerEncoderLayer, + get_test_configs, + save_restore_test, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +class TestAttentionSparsityGPU: + """GPU tests for attention sparsity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup for each test.""" + self.device = torch.device("cuda") + torch.cuda.empty_cache() + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + @pytest.mark.parametrize("config", get_test_configs()) + def test_gpu_forward(self, model_cls, config): + """Test sparse attention forward pass on GPU.""" + model = model_cls().to(self.device) + calib_data = [model.get_input().to(self.device) for _ in range(2)] + + sparsify_model_and_forward(model, config, calib_data) + + # Additional GPU-specific checks + for batch in calib_data: + with torch.no_grad(): + output = model(batch) + assert output.device.type == "cuda" + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + def test_save_restore(self, model_cls): + """Test save and restore on GPU.""" + save_restore_test(model_cls, "cuda", FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_different_dtypes(self, dtype): + """Test sparse attention with different dtypes.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device).to(dtype) + calib_data = [model.get_input(d_model=256).to(self.device).to(dtype) for 
_ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG, calib_data) + + # Test forward + x = model.get_input(d_model=256).to(self.device).to(dtype) + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == dtype + assert not torch.isnan(output).any() + if dtype != torch.bfloat16: # bfloat16 can have inf + assert not torch.isinf(output).any() + + def test_backward_pass(self): + """Test that gradients flow correctly through sparse attention.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Enable training mode + model.train() + + x = model.get_input(hidden_size=128, seq_len=32).to(self.device) + x.requires_grad = True + + # Forward + output = model(x) + loss = output.sum() + + # Backward + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + # Check model gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + @pytest.mark.parametrize("seq_len", [1, 1024, 2048]) + def test_various_sequence_lengths(self, seq_len): + """Test sparse attention with various sequence lengths.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(hidden_size=128, seq_len=seq_len, batch_size=1).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (1, seq_len, 128) + assert not torch.isnan(output).any() + + @pytest.mark.parametrize("batch_size", [1, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test sparse attention with various batch sizes.""" + model = SimpleTransformerEncoderLayer(d_model=128, nhead=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + 
+ x = model.get_input(d_model=128, seq_len=64, batch_size=batch_size).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (batch_size, 64, 128) + assert not torch.isnan(output).any() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py new file mode 100644 index 000000000..586cb3b9d --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration testing with locally created minimal Llama model.""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create minimal Llama model locally.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, # Minimal layers for fast testing + hidden_size=512, + intermediate_size=1024, + ) + + +@pytest.fixture(scope="module") +def tinyllama_model(tiny_llama_dir): + """Load locally created tiny Llama model.""" + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + device_map="cuda", + ) + return model + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(tiny_llama_dir): + """Load tokenizer for tiny Llama model.""" + tokenizer = AutoTokenizer.from_pretrained(tiny_llama_dir) + return tokenizer + + +class TestTinyLlama: + """TinyLlama sparse attention tests.""" + + def test_load_and_sparsify(self, tinyllama_model): + """Load TinyLlama and apply sparse attention.""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse attention modules were added + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, 
SparseAttentionModule) + ) + assert sparse_count > 0, "No sparse attention modules found" + + # Our tiny llama has 2 layers, so should have 2 attention modules + assert sparse_count == 2, f"Expected 2 sparse modules, got {sparse_count}" + + def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): + """Forward pass with seq_len=64 (prefill).""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create prefill input (seq_len > 1) + test_text = "Once upon a time in a land far away" + inputs = tokenizer(test_text, return_tensors="pt").to("cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(**inputs) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape[1] == inputs.input_ids.shape[1] # seq_len preserved + + def test_forward_decode(self, tinyllama_model): + """Forward pass with seq_len=1 (decode).""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-5, # More conservative for decode + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create decode input (seq_len = 1) + input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + + def test_gqa_attention(self, tinyllama_model): + """Verify GQA support (num_kv_heads < num_heads).""" + model = tinyllama_model + + # Check if model uses GQA + config = model.config + has_gqa = 
hasattr(config, "num_key_value_heads") and ( + config.num_key_value_heads < config.num_attention_heads + ) + + if not has_gqa: + pytest.skip("Model does not use GQA") + + # Apply sparse attention + sparse_config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, sparse_config) + + # Test forward pass with GQA + input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py new file mode 100644 index 000000000..b487d8639 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for FlashSkipSoftmax method internals.""" + +import pytest +import torch + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax + + +class TestFlashSkipSoftmaxMethod: + """Test FlashSkipSoftmax method internals.""" + + def test_phase_inference(self): + """Test phase detection from attention score shape.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Prefill: seq_q > 1 + prefill_scores = torch.randn(2, 4, 64, 64) + assert method._infer_phase(prefill_scores) == "prefill" + + # Decode: seq_q = 1 + decode_scores = torch.randn(2, 4, 1, 64) + assert method._infer_phase(decode_scores) == "decode" + + def test_threshold_update_dict_config(self): + """Test threshold updates with dict config.""" + method = FlashSkipSoftmax( + { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Initially uses prefill threshold + initial_threshold = method.threshold + + # Update to decode + method._update_threshold("decode") + assert method.threshold == 1e-5 + assert method.threshold != initial_threshold + + # Update back to prefill + method._update_threshold("prefill") + assert method.threshold == 1e-3 + + def test_threshold_update_static_config(self): + """Test threshold with static float config.""" + method = FlashSkipSoftmax( + { + "threshold": 5e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + initial_threshold = method.threshold + assert initial_threshold == 5e-4 + + # Should not change for static config + method._update_threshold("decode") + assert method.threshold == 5e-4 + + def test_block_reshaping_divisible(self): + """Test block reshaping with divisible sequence lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": 
"pytorch", + "is_causal": True, + } + ) + + # Seq lengths divisible by 128 + attn = torch.randn(2, 4, 256, 256) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify block dimensions + assert blocked.shape == (2, 4, 2, 128, 2, 128) # 256/128 = 2 blocks + assert num_br == 2 + assert num_bc == 2 + assert padded_q == 256 # No padding + assert padded_k == 256 # No padding + + def test_block_reshaping_with_padding(self): + """Test block reshaping with non-divisible lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths NOT divisible by 128 + attn = torch.randn(2, 4, 200, 300) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify padding applied + assert padded_q == 256 # ceil(200/128) * 128 = 2 * 128 + assert padded_k == 384 # ceil(300/128) * 128 = 3 * 128 + assert num_br == 2 + assert num_bc == 3 + assert blocked.shape == (2, 4, 2, 128, 3, 128) + + def test_correction_factor_calculation_prefill(self): + """Test correction factor for prefill phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Create simple attention pattern + attn = torch.randn(1, 1, 128, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify stats structure + assert "correction_factor" in stats + assert "sparsity" in stats + assert "phase" in stats + assert "total_blocks" in stats + assert stats["phase"] == "prefill" + assert 0 <= stats["correction_factor"] <= 1 + # Sparsity can be negative if threshold is too low (more blocks kept than expected) + assert -1 <= stats["sparsity"] <= 1 + + def test_correction_factor_calculation_decode(self): + """Test correction factor for decode phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-5, + "br": 128, + "bc": 128, + 
"backend": "pytorch", + "is_causal": True, + } + ) + + # Decode: single query + attn = torch.randn(1, 1, 1, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "decode") + + # Verify stats structure + assert stats["phase"] == "decode" + assert "correction_factor" in stats + assert 0 <= stats["sparsity"] <= 1 + assert mask.shape == (1, 1, 1, 256) + + def test_sparsity_statistics(self): + """Test sparsity statistics structure.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(1, 1, 128, 256) + _, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify statistics are present + assert stats["total_blocks"] > 0 + assert "sparse_blocks" in stats + assert "sample_length" in stats + assert stats["sample_length"] == 256 + + def test_block_mask_correctness(self): + """Test block mask shape and type.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + mask, _ = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify mask properties + assert mask.shape == attn.shape + assert mask.dtype == torch.bool + assert mask.device == attn.device + + def test_causal_vs_noncausal(self): + """Test total_blocks calculation for causal vs non-causal.""" + config_base = { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + } + + method_causal = FlashSkipSoftmax({**config_base, "is_causal": True}) + method_noncausal = FlashSkipSoftmax({**config_base, "is_causal": False}) + + attn = torch.randn(1, 1, 256, 256) # 2x2 blocks + + _, stats_causal = method_causal.calc_correction_factor_and_p(attn, "prefill") + _, stats_noncausal = method_noncausal.calc_correction_factor_and_p(attn, "prefill") + + # Causal: 2*(2+1)/2 = 3 blocks + # Non-causal: 2*2 = 4 blocks + assert stats_causal["total_blocks"] == 3 + assert 
stats_noncausal["total_blocks"] == 4 + + def test_apply_sparsity_assertions(self): + """Test apply_sparsity input validation.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Test: attention_scores required + with pytest.raises(AssertionError, match="attention_scores must be provided"): + method.apply_sparsity() + + # Test: 4D shape required + with pytest.raises(AssertionError, match="Expected 4D"): + method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + + def test_name_property(self): + """Test method name property.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + assert method.name == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Test sparse attention configuration validation."""

import pytest
from pydantic import ValidationError

# Skip the whole module when transformers is unavailable; this must run before
# the modelopt imports below, which is why they are not at the very top.
pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_DEFAULT,
    FlashSkipSoftmaxConfig,
    SparseAttentionAttributeConfig,
    SparseAttentionConfig,
)


class TestSparseAttentionAttributeConfig:
    """Test SparseAttentionAttributeConfig validators."""

    def test_valid_config(self):
        """Test creating valid config."""
        config = SparseAttentionAttributeConfig(
            method="flash_skip_softmax",
            threshold=1e-4,
            br=128,
            bc=128,
            enable=True,
        )
        assert config.method == "flash_skip_softmax"
        assert config.threshold == 1e-4
        assert config.br == 128
        assert config.bc == 128

    def test_method_validation(self):
        """Test method must be string."""
        with pytest.raises(ValidationError, match="Input should be a valid string"):
            SparseAttentionAttributeConfig(method=123)

    def test_block_size_validation_negative(self):
        """Test block sizes must be positive (negative and zero are rejected)."""
        with pytest.raises(ValidationError, match="Block size must be positive"):
            SparseAttentionAttributeConfig(br=-1)

        with pytest.raises(ValidationError, match="Block size must be positive"):
            SparseAttentionAttributeConfig(bc=0)

    def test_block_size_validation_large(self):
        """Test that large block sizes are accepted."""
        # Large block sizes are allowed (warning removed for simplicity)
        config = SparseAttentionAttributeConfig(br=2048)
        assert config.br == 2048

    # One test case per invalid value so a failure names the offending input.
    @pytest.mark.parametrize("threshold", [0, -0.1, 1.0, 1.5])
    def test_threshold_validation_range(self, threshold):
        """Test threshold must be in range (0, 1); both bounds are rejected."""
        with pytest.raises(ValidationError, match="Threshold must be in range"):
            SparseAttentionAttributeConfig(threshold=threshold)

    def test_threshold_validation_dict(self):
        """Test threshold dict validation."""
        # Valid phase-aware threshold
        config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5})
        assert config.threshold == {"prefill": 1e-3, "decode": 1e-5}

        # Invalid phase key
        with pytest.raises(ValidationError, match="Invalid threshold phases"):
            SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3})

        # Invalid threshold value in dict (negative)
        with pytest.raises(ValidationError, match="must be in range"):
            SparseAttentionAttributeConfig(threshold={"prefill": -1e-3})

        # Invalid threshold value in dict (>= 1.0)
        with pytest.raises(ValidationError, match="must be in range"):
            SparseAttentionAttributeConfig(threshold={"prefill": 1.0})

    def test_threshold_validation_type(self):
        """Test threshold type validation."""
        with pytest.raises(ValidationError, match="Input should be a valid"):
            SparseAttentionAttributeConfig(threshold="invalid")


class TestSparseAttentionConfig:
    """Test SparseAttentionConfig."""

    def test_default_config(self):
        """Test default configuration."""
        config = SparseAttentionConfig()
        assert "sparse_cfg" in config.model_dump()
        # Check default pattern has method
        assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax"

    def test_predefined_config(self):
        """Test pre-defined configuration."""
        assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT
        # Check the pattern key before indexing it, so a missing pattern is
        # reported as an assertion failure instead of a KeyError.
        assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]
        assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"]


class TestFlashSkipSoftmaxConfig:
    """Test FlashSkipSoftmaxConfig."""

    def test_default_values(self):
        """Test default values for flash_skip_softmax config."""
        config = FlashSkipSoftmaxConfig()
        assert "*attention*" in config.sparse_cfg
        assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax"


# Patch header of the next file, carried over from the original collapsed line:
# diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
# new file mode 100644
# index 000000000..8df8fe476
# --- /dev/null
# +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
# @@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse attention conversion and replacement."""

import pytest

# Skip the whole module when transformers is unavailable; this must run before
# the imports below, which is why they are not at the very top.
pytest.importorskip("transformers")

import torch.nn as nn
from _test_utils.torch_sparsity.sparse_attention_common import (
    FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
    SimpleAttentionModel,
    SimpleTransformerEncoderLayer,
)

import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as sparse_attn
from modelopt.torch.sparsity.attention_sparsity.conversion import (
    disable_sparse_attention,
    enable_sparse_attention,
    print_sparse_attention_summary,
)
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule


def _sparse_modules(model):
    """Return every SparseAttentionModule yielded by ``model.modules()``."""
    return [m for m in model.modules() if isinstance(m, SparseAttentionModule)]


def _assert_toggle_round_trip(model):
    """Check enabled -> disabled -> enabled on all sparse modules of ``model``.

    The module list is asserted to be non-empty first; without that, the
    per-module checks would pass vacuously on a model that was never converted.
    """
    assert _sparse_modules(model), "no SparseAttentionModule found to toggle"

    disable_sparse_attention(model, "*")
    assert all(not m.is_enabled for m in _sparse_modules(model))

    enable_sparse_attention(model, "*")
    assert all(m.is_enabled for m in _sparse_modules(model))


class TestSparseAttentionReplacement:
    """Test module replacement logic."""

    def test_basic_replacement(self):
        """Test that attention modules are replaced with sparse versions."""
        model = SimpleAttentionModel()

        # Count original attention modules
        original_attention_count = sum(
            isinstance(m, nn.MultiheadAttention) for m in model.modules()
        )
        assert original_attention_count > 0

        # Apply sparse attention
        sparse_model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

        # Verify replacement occurred
        sparse_attention_count = len(_sparse_modules(sparse_model))
        assert sparse_attention_count > 0

    def test_enable_disable_toggle(self):
        """Test enabling and disabling sparse attention."""
        model = SimpleAttentionModel()
        model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

        # Check initially enabled (and that there is something to check)
        sparse_modules = _sparse_modules(model)
        assert sparse_modules, "no SparseAttentionModule found after sparsify()"
        assert all(m.is_enabled for m in sparse_modules)

        # Disable, then re-enable, all sparse attention modules
        _assert_toggle_round_trip(model)

    def test_pattern_based_replacement(self):
        """Test pattern-based selective replacement."""
        model = SimpleTransformerEncoderLayer()

        # Apply with pattern
        config = {
            "sparse_cfg": {
                "*self_attn*": {
                    "method": "flash_skip_softmax",
                    "threshold": 1e-4,
                    "br": 128,
                    "bc": 128,
                    "enable": True,
                },
                "default": {"enable": False},
            },
        }

        sparse_model = sparse_attn.sparsify(model, config)

        # Verify sparse modules exist
        has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules())
        assert has_sparse


class TestConversionEdgeCases:
    """Test edge cases and error paths in conversion."""

    def test_callable_filter(self):
        """Test using callable filter instead of wildcard."""
        model = SimpleAttentionModel()

        # Use callable filter
        def filter_func(name):
            return "attn" in name

        config = {
            "sparse_cfg": {
                filter_func: {
                    "method": "flash_skip_softmax",
                    "threshold": 1e-3,
                    "enable": True,
                },
            },
        }

        sparse_model = sparse_attn.sparsify(model, config)
        has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules())
        assert has_sparse

    def test_no_matching_modules(self):
        """Test pattern that matches nothing."""
        model = SimpleAttentionModel()

        config = {
            "sparse_cfg": {
                "*nonexistent*": {
                    "method": "flash_skip_softmax",
                    "threshold": 1e-3,
                    "enable": True,
                },
            },
        }

        # Should not error, even with no matches
        sparse_attn.sparsify(model, config)

    def test_disable_enable_functions(self):
        """Test disable/enable utility functions."""
        # disable_sparse_attention / enable_sparse_attention are already
        # imported at module level, so no local re-import is needed.
        model = SimpleAttentionModel()
        model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

        # Disable all, then enable all
        _assert_toggle_round_trip(model)

    def test_print_sparse_attention_summary(self, capsys):
        """Test print_sparse_attention_summary function."""
        model = SimpleAttentionModel()
        model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

        # Print summary
        print_sparse_attention_summary(model)

        # Capture output
        captured = capsys.readouterr()
        assert "Total sparse attention modules:" in captured.out
        assert "Enabled:" in captured.out

    def test_restore_sparse_attention_model(self):
        """Test save/restore via modelopt_state."""
        # Create and sparsify original model
        model_orig = SimpleAttentionModel()
        model_orig = sparse_attn.sparsify(model_orig, FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

        # Save state
        state_dict = mto.modelopt_state(model_orig)

        # Restore to new model
        model_restored = SimpleAttentionModel()
        mto.restore_from_modelopt_state(model_restored, state_dict)

        # Verify restoration
        restored_modules = _sparse_modules(model_restored)
        assert restored_modules

        # Verify module is configured
        for module in restored_modules:
            assert hasattr(module, "_method")
            assert module._method == "flash_skip_softmax"


# Patch header of the next file, carried over from the original collapsed line:
# diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py
# new file mode 100644
# index 000000000..e7e32e153
# --- /dev/null
# +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py
# @@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sparse attention mode registry."""

import pytest

# Skip the module when transformers is missing; must precede the modelopt imports.
pytest.importorskip("transformers")

from modelopt.torch.opt.mode import _ModeRegistryCls
from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry


def test_sparse_attention_mode_exists():
    """The registry reports ``sparse_attention`` as a member."""
    is_registered = "sparse_attention" in SparseAttentionModeRegistry
    assert is_registered


def test_sparse_attention_mode_descriptor():
    """The descriptor resolved across registries exposes the expected attributes."""
    descriptor = _ModeRegistryCls.get_from_any("sparse_attention")
    assert descriptor is not None

    for attribute in ("config_class", "convert"):
        assert hasattr(descriptor, attribute)


def test_mode_registry_get():
    """Indexing the registry by mode name yields a non-None entry."""
    looked_up = SparseAttentionModeRegistry["sparse_attention"]
    assert looked_up is not None