Native Sparse Attention (NSA)

A stand-alone, single-file (.py) PyTorch implementation of Native Sparse Attention - an efficient attention mechanism that combines local windowed attention, compressed token attention, and selective block attention to reduce memory from O(n²) to O(n²/stride) while maintaining performance.

Features

  • Hybrid Attention: Combines three attention branches (local, compressed, selective) for optimal efficiency-performance trade-off
  • Reduced Memory: O(n²/stride) memory complexity - a stride-fold reduction over standard O(n²) attention
  • Configurable: Extensive configuration options for different use cases
  • Drop-in Replacement: Easy integration with existing transformer architectures
  • Flash Attention Support: Optional Flash Attention acceleration when available
  • Multiple Compression Methods: GroupedMLP, Conv1D, and AvgPool token compression

Installation

git clone https://github.com/shreyashkar-ml/native_sparse_attention
cd native_sparse_attention
uv pip install -r requirements.txt

Quick Start

import torch
from native_sparse_attention import NSAConfig, NSABlock

# Configure NSA
config = NSAConfig(
    dim=512,                    # Model dimension
    heads=8,                    # Number of attention heads
    seq_len=2048,               # Maximum sequence length
    local_window=128,           # Local attention window
    block_size=32,              # Compression block size
    topk_blocks=4,              # Top-K selective blocks
    compression="grouped_mlp"   # Compression method
)

# Create an NSA block
nsa_block = NSABlock(config)

# Use it like any PyTorch module
x = torch.randn(2, 1024, 512)  # (batch, seq_len, dim)
output = nsa_block(x)          # Same shape as input

Configuration Options

The NSAConfig class provides extensive customization options:

Basic Parameters

NSAConfig(
    dim=512,                    # Model dimension
    heads=8,                    # Number of attention heads  
    dim_head=None,              # Head dimension (auto: dim // heads)
    seq_len=2048,               # Maximum sequence length
)

Sparse Attention Parameters

NSAConfig(
    local_window=128,           # Local sliding window size (W)
    block_size=32,              # Block size for compression (B)
    stride=32,                  # Stride for block compression (S)
    topk_blocks=4,              # Number of top-K blocks to select (K)
)

Compression Methods

Choose from three compression strategies:

# Grouped MLP (default - best performance)
config = NSAConfig(compression="grouped_mlp")

# 1D Convolution (memory efficient)
config = NSAConfig(compression="conv1d") 

# Average Pooling (fastest)
config = NSAConfig(compression="avgpool")

Gating Options

Control how the three attention branches are combined:

# Static gating (fixed weights)
config = NSAConfig(
    gate_mode="static",
    gate_init=(2.0, -2.0, -2.0)  # (local, compressed, selective)
)

# Query-conditioned gating (adaptive weights)
config = NSAConfig(gate_mode="q_cond")

Performance Options

config = NSAConfig(
    dropout=0.1,                # Dropout probability
    use_flash=True,             # Use Flash Attention when available
)

Architecture Overview

NSA combines three attention mechanisms:

  1. Local Attention: Causal sliding window attention for recent context
  2. Compressed Attention: Attention over compressed representations of past tokens
  3. Selective Attention: Dynamic selection of important blocks based on query-key similarities

The outputs are combined using learnable gating weights that can be static or query-conditioned.
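To illustrate, here is a minimal sketch of this combination step, assuming a softmax over three gate logits (the repository's internal code may differ, and the names out_local, out_cmp, out_sel, and gate_logits are hypothetical):

import torch

def combine_branches(out_local, out_cmp, out_sel, gate_logits):
    # out_*: (batch, seq_len, dim); gate_logits: (3,) for static gating
    # or (batch, seq_len, 3) for query-conditioned gating
    gates = torch.softmax(gate_logits, dim=-1)
    if gates.dim() == 1:
        gates = gates.view(1, 1, 3)                               # broadcast static gates
    stacked = torch.stack([out_local, out_cmp, out_sel], dim=-1)  # (b, n, d, 3)
    return (stacked * gates.unsqueeze(-2)).sum(dim=-1)            # (b, n, d)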

Integration as Drop-in Replacement

Replace Standard Transformer Blocks

import torch.nn as nn
from native_sparse_attention import NSABlock, NSAConfig

class YourTransformer(nn.Module):
    def __init__(self, vocab_size, n_layers=6, n_embd=384, n_head=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)
        
        # Configure NSA
        nsa_config = NSAConfig(
            dim=n_embd,
            heads=n_head,
            seq_len=1024,           # Your max sequence length
            local_window=64,        # Adjust based on your needs
            block_size=32,
            topk_blocks=4,
            compression="grouped_mlp"
        )
        
        # Replace standard transformer blocks with NSA blocks
        self.blocks = nn.Sequential(*[
            NSABlock(nsa_config) for _ in range(n_layers)
        ])
        
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        x = self.blocks(x)          # NSA blocks handle attention + FFN
        x = self.ln_f(x)
        return self.lm_head(x)

Replace Only Attention Layers

from native_sparse_attention import NativeSparseAttention, NSAConfig

class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        
        # Replace MultiHeadAttention with NativeSparseAttention
        nsa_config = NSAConfig(dim=n_embd, heads=n_head, seq_len=1024)
        self.attn = NativeSparseAttention(nsa_config)
        
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )
    
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

Training Example

The repository includes nsa_gpt_training.py - a complete example of training a language model with NSA on the Tiny Shakespeare dataset.

Key Features of the Training Script

  1. Complete Language Model: Implements a GPT-style model using NSA blocks
  2. Quick Test Mode: Fast verification that NSA works as drop-in replacement
  3. Full Training Mode: Complete training loop with evaluation
  4. Text Generation: Sample generation to verify model quality

Configuration Used in Training

from native_sparse_attention import NSAConfig

nsa_config = NSAConfig(
    dim=384,                    # Model dimension (384/6 = 64 per head)
    heads=6,                    # Number of attention heads
    seq_len=256,                # Context length
    local_window=64,            # Local attention window
    block_size=32,              # Compression block size
    stride=32,                  # Compression stride
    topk_blocks=4,              # Top-K blocks for selective attention
    compression="grouped_mlp",  # Use GroupedMLP compression
    dropout=0.2,                # Dropout rate
    use_flash=True,             # Enable Flash Attention
    gate_mode="static",         # Static gating
    gate_init=(2.0, -2.0, -2.0) # Favor local attention initially
)

Running the Training Example

# Download Tiny Shakespeare dataset (or use your own text file)
wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O native-sparse-attention/input.txt

# Quick test (few iterations to verify correctness)
python native-sparse-attention/nsa_gpt_training.py  # QUICK_TEST=True by default

# Full training: edit the script to set QUICK_TEST = False, then:
python native-sparse-attention/nsa_gpt_training.py

Model Architecture in Training Example

class NSALanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embd = nn.Embedding(vocab_size, n_embd)
        self.position_embd = nn.Embedding(block_size, n_embd)
        
        # Stack of NSA blocks (drop-in replacement for transformer blocks)
        self.blocks = nn.Sequential(*[
            NSABlock(nsa_config) for _ in range(n_layer)
        ])
        
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
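
    # A sketch of the corresponding forward pass (illustrative; the actual
    # training script may differ); idx is a (batch, seq_len) tensor of token ids
    def forward(self, idx):
        b, t = idx.shape
        tok = self.token_embd(idx)                                    # (b, t, n_embd)
        pos = self.position_embd(torch.arange(t, device=idx.device))  # (t, n_embd)
        x = self.blocks(tok + pos)    # NSA blocks handle attention + FFN
        x = self.ln_f(x)
        return self.lm_head(x)        # (b, t, vocab_size) logits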

Performance Characteristics

Memory Complexity

  • Standard Attention: O(n²) memory
  • NSA: O(n²/stride) memory - multiplicative reduction by compression stride
    • Local branch: O(n × window_size) ≈ O(n)
    • Compressed branch: O(n × n/stride) = O(n²/stride) [dominant]
    • Selective branch: O(n × topk_blocks × block_size) ≈ O(n)
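
To make the reduction concrete, here is a rough count of attention-score entries for an illustrative setting (constant factors in the actual implementation will differ):

n, stride, window, topk, block = 2048, 32, 128, 4, 32

dense      = n * n              # standard attention: 4,194,304 entries
local      = n * window         # 262,144
compressed = n * (n // stride)  # 131,072 (the dominant term as n grows)
selective  = n * (topk * block) # 262,144
print(dense, local + compressed + selective)  # 4194304 vs 655360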

Recommended Settings by Sequence Length

# Short sequences (< 1K tokens)
short_config = NSAConfig(local_window=64, block_size=16, topk_blocks=2)

# Medium sequences (1K - 4K tokens)  
medium_config = NSAConfig(local_window=128, block_size=32, topk_blocks=4)

# Long sequences (4K+ tokens)
long_config = NSAConfig(local_window=256, block_size=64, topk_blocks=8)

Advanced Usage

Custom Compression

Implement your own compression method:

import torch
import torch.nn as nn

from native_sparse_attention import TokenCompressor, NSAConfig

class CustomCompression(nn.Module):
    def __init__(self, dim: int, block: int):
        super().__init__()
        self.block = block
        # Example compression: project each flattened block of `block` tokens back to dim
        self.proj = nn.Linear(dim * block, dim)

    def forward(self, x: torch.Tensor):
        # Compress tokens from (b, n, d) to (b, n // block, d); assumes n % block == 0
        b, n, d = x.shape
        blocks = x.reshape(b, n // self.block, self.block * d)
        return self.proj(blocks)

# Then modify TokenCompressor to use your method
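
A quick shape check of the example module above (purely illustrative):

comp = CustomCompression(dim=512, block=32)
x = torch.randn(2, 1024, 512)
print(comp(x).shape)  # torch.Size([2, 32, 512])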

Visualization and Analysis

# Access attention patterns (when use_flash=False)
nsa = NativeSparseAttention(config)
nsa.cfg.use_flash = False  # Disable flash for pattern access

# Run a forward pass with Flash disabled so attention weights are computed explicitly
with torch.no_grad():
    output = nsa(x)
    # Inspect the module's internal attention patterns for visualization

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • einops
  • CUDA (optional, for Flash Attention acceleration)

Testing

The implementation includes built-in tests:

python native_sparse_attention.py

Tests verify:

  • Basic functionality (smoke test)
  • Causal masking correctness
  • Flash Attention parity (when available)
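
As an additional, external sanity check of causal masking (a sketch using only the NSAConfig/NSABlock interface shown in the Quick Start), perturbing a future token should leave the outputs at all earlier positions unchanged:

import torch
from native_sparse_attention import NSAConfig, NSABlock

config = NSAConfig(dim=128, heads=4, seq_len=256)
block = NSABlock(config).eval()   # eval() disables dropout

x = torch.randn(1, 256, 128)
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0         # change only the last token

with torch.no_grad():
    y, y_pert = block(x), block(x_perturbed)

# With causal attention, earlier positions must be unaffected by the change
assert torch.allclose(y[:, :-1], y_pert[:, :-1], atol=1e-5)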
