A stand-alone, single-file PyTorch implementation of Native Sparse Attention: an efficient attention mechanism that combines local windowed attention, compressed token attention, and selective block attention to achieve a significant memory reduction (O(n²/stride) vs. O(n²)) while maintaining performance.
- Hybrid Attention: Combines three attention branches (local, compressed, selective) for optimal efficiency-performance trade-off
- Reduced Memory: O(n²/stride) memory complexity, a reduction of standard O(n²) attention by a factor of the compression stride
- Configurable: Extensive configuration options for different use cases
- Drop-in Replacement: Easy integration with existing transformer architectures
- Flash Attention Support: Optional Flash Attention acceleration when available
- Multiple Compression Methods: support for GroupedMLP, Conv1D, and AvgPool compression
```bash
git clone https://github.com/shreyashkar-ml/native_sparse_attention
uv pip install -r requirements.txt
```

Basic usage:

```python
import torch
from native_sparse_attention import NSAConfig, NSABlock
# Configure NSA
config = NSAConfig(
    dim=512,           # Model dimension
    heads=8,           # Number of attention heads
    seq_len=2048,      # Maximum sequence length
    local_window=128,  # Local attention window
    block_size=32,     # Compression block size
    topk_blocks=4,     # Top-K selective blocks
    compression="grouped_mlp"  # Compression method
)
# Create an NSA block
nsa_block = NSABlock(config)
# Use it like any PyTorch module
x = torch.randn(2, 1024, 512) # (batch, seq_len, dim)
output = nsa_block(x)  # Same shape as input
```

The NSAConfig class provides extensive customization options:

```python
NSAConfig(
    dim=512,          # Model dimension
    heads=8,          # Number of attention heads
    dim_head=None,    # Head dimension (auto: dim // heads)
    seq_len=2048,     # Maximum sequence length
)

NSAConfig(
    local_window=128, # Local sliding window size (W)
    block_size=32,    # Block size for compression (B)
    stride=32,        # Stride for block compression (S)
    topk_blocks=4,    # Number of top-K blocks to select (K)
)
```

Choose from three compression strategies:

```python
# Grouped MLP (default - best performance)
config = NSAConfig(compression="grouped_mlp")
# 1D Convolution (memory efficient)
config = NSAConfig(compression="conv1d")
# Average Pooling (fastest)
config = NSAConfig(compression="avgpool")
```

Control how the three attention branches are combined:

```python
# Static gating (fixed weights)
config = NSAConfig(
    gate_mode="static",
    gate_init=(2.0, -2.0, -2.0)  # (local, compressed, selective)
)

# Query-conditioned gating (adaptive weights)
config = NSAConfig(gate_mode="q_cond")

config = NSAConfig(
    dropout=0.1,     # Dropout probability
    use_flash=True,  # Use Flash Attention when available
)
```

NSA combines three attention mechanisms:
- Local Attention: Causal sliding window attention for recent context
- Compressed Attention: Attention over compressed representations of past tokens
- Selective Attention: Dynamic selection of important blocks based on query-key similarities
The outputs are combined using learnable gating weights that can be static or query-conditioned.
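For intuition, here is a minimal sketch of how the three branch outputs could be gated together. The module `BranchGate`, the tensor names `out_local`, `out_cmp`, `out_sel`, and the softmax mixing are illustrative assumptions, not the actual internals of NativeSparseAttention:

```python
import torch
import torch.nn as nn

class BranchGate(nn.Module):
    """Illustrative sketch: mix local / compressed / selective branch outputs."""

    def __init__(self, dim, gate_init=(2.0, -2.0, -2.0), q_cond=False):
        super().__init__()
        self.q_cond = q_cond
        if q_cond:
            # Query-conditioned gating: per-token branch weights predicted from the input
            self.to_gates = nn.Linear(dim, 3)
        else:
            # Static gating: one learnable logit per branch
            self.gate_logits = nn.Parameter(torch.tensor(gate_init))

    def forward(self, x, out_local, out_cmp, out_sel):
        # x and each branch output: (batch, seq_len, dim)
        if self.q_cond:
            gates = self.to_gates(x).softmax(dim=-1)             # (b, n, 3)
        else:
            gates = self.gate_logits.softmax(dim=-1)             # (3,)
            gates = gates.view(1, 1, 3).expand(*x.shape[:2], 3)  # (b, n, 3)
        branches = torch.stack((out_local, out_cmp, out_sel), dim=-1)  # (b, n, d, 3)
        return (branches * gates.unsqueeze(-2)).sum(dim=-1)      # weighted sum -> (b, n, d)

# Example: mix three dummy branch outputs
gate = BranchGate(dim=512)
x = torch.randn(2, 1024, 512)
out = gate(x, torch.randn_like(x), torch.randn_like(x), torch.randn_like(x))
print(out.shape)  # torch.Size([2, 1024, 512])
```

With the static init (2.0, -2.0, -2.0), the softmax puts most of the weight on the local branch at the start of training, matching the "favor local attention initially" setting used later in this README.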
To integrate NSA into an existing model, use NSABlock as a drop-in transformer block:

```python
import torch.nn as nn
from native_sparse_attention import NSABlock, NSAConfig

class YourTransformer(nn.Module):
    def __init__(self, vocab_size, n_layers=6, n_embd=384, n_head=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)

        # Configure NSA
        nsa_config = NSAConfig(
            dim=n_embd,
            heads=n_head,
            seq_len=1024,     # Your max sequence length
            local_window=64,  # Adjust based on your needs
            block_size=32,
            topk_blocks=4,
            compression="grouped_mlp"
        )

        # Replace standard transformer blocks with NSA blocks
        self.blocks = nn.Sequential(*[
            NSABlock(nsa_config) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.blocks(x)  # NSA blocks handle attention + FFN
        x = self.ln_f(x)
        return self.lm_head(x)
```

To swap only the attention layer inside an existing transformer block, use NativeSparseAttention directly:

```python
from native_sparse_attention import NativeSparseAttention, NSAConfig
class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)

        # Replace MultiHeadAttention with NativeSparseAttention
        nsa_config = NSAConfig(dim=n_embd, heads=n_head, seq_len=1024)
        self.attn = NativeSparseAttention(nsa_config)

        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
```

The repository includes nsa_gpt_training.py, a complete example of training a language model with NSA on the Tiny Shakespeare dataset.
- Complete Language Model: Implements a GPT-style model using NSA blocks
- Quick Test Mode: Fast verification that NSA works as a drop-in replacement
- Full Training Mode: Complete training loop with evaluation
- Text Generation: Sample generation to verify model quality
The script configures NSA as follows:

```python
from native_sparse_attention import NSAConfig

nsa_config = NSAConfig(
    dim=384,                     # Model dimension (384/6 = 64 per head)
    heads=6,                     # Number of attention heads
    seq_len=256,                 # Context length
    local_window=64,             # Local attention window
    block_size=32,               # Compression block size
    stride=32,                   # Compression stride
    topk_blocks=4,               # Top-K blocks for selective attention
    compression="grouped_mlp",   # Use GroupedMLP compression
    dropout=0.2,                 # Dropout rate
    use_flash=True,              # Enable Flash Attention
    gate_mode="static",          # Static gating
    gate_init=(2.0, -2.0, -2.0)  # Favor local attention initially
)
```

```bash
# Download Tiny Shakespeare dataset (or use your own text file)
wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O native-sparse-attention/input.txt
# Quick test (few iterations to verify correctness)
python native-sparse-attention/nsa_gpt_training.py # QUICK_TEST=True by default
# Full training (set QUICK_TEST=False in the script)
# Edit the script to set QUICK_TEST = False, then:
python native-sparse-attention/nsa_gpt_training.py
```

The model in the training script is a stack of NSA blocks:

```python
class NSALanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embd = nn.Embedding(vocab_size, n_embd)
        self.position_embd = nn.Embedding(block_size, n_embd)

        # Stack of NSA blocks (drop-in replacement for transformer blocks)
        self.blocks = nn.Sequential(*[
            NSABlock(nsa_config) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
```

Memory complexity:

- Standard Attention: O(n²) memory
- NSA: O(n²/stride) memory, a reduction by a factor of the compression stride (see the quick estimate below)
  - Local branch: O(n × window_size) ≈ O(n)
  - Compressed branch: O(n × n/stride) = O(n²/stride) [dominant term]
  - Selective branch: O(n × topk_blocks × block_size) ≈ O(n)
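To make these terms concrete, here is a rough count of attention-score entries per head for n = 2048 with the default window, stride, block size, and top-K values. This is back-of-the-envelope arithmetic based on the complexity terms above, not a measurement of the implementation:

```python
# Illustrative arithmetic only: attention-score entries per head.
n, window, stride, topk, block = 2048, 128, 32, 4, 32

standard   = n * n               # O(n^2)          -> 4,194,304
local      = n * window          # O(n * window)   ->   262,144
compressed = n * (n // stride)   # O(n^2 / stride) ->   131,072
selective  = n * (topk * block)  # O(n * K * B)    ->   262,144

nsa_total = local + compressed + selective
print(f"standard: {standard:,}  NSA: {nsa_total:,}  ~{standard / nsa_total:.1f}x fewer entries")
```

With these settings the three sparse branches together score roughly 6x fewer entries than the dense attention matrix, and the gap widens as the sequence length grows.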
Suggested starting points by sequence length:

```python
# Short sequences (< 1K tokens)
short_config = NSAConfig(local_window=64, block_size=16, topk_blocks=2)
# Medium sequences (1K - 4K tokens)
medium_config = NSAConfig(local_window=128, block_size=32, topk_blocks=4)
# Long sequences (4K+ tokens)
long_config = NSAConfig(local_window=256, block_size=64, topk_blocks=8)
```

Implement your own compression method:

```python
import torch
import torch.nn as nn

from native_sparse_attention import TokenCompressor, NSAConfig

class CustomCompression(nn.Module):
    def __init__(self, dim: int, block: int):
        super().__init__()
        # Your custom compression logic

    def forward(self, x: torch.Tensor):
        # Compress tokens from (b, n, d) to (b, n // block, d)
        return compressed_x

# Then modify TokenCompressor to use your method
```

To inspect attention patterns directly, disable Flash Attention:

```python
# Access attention patterns (when use_flash=False)
nsa = NativeSparseAttention(config)
nsa.cfg.use_flash = False # Disable flash for pattern access
# Forward pass returns attention weights for analysis
with torch.no_grad():
    output = nsa(x)
# Access internal attention patterns for visualization
```

Requirements:

- Python 3.8+
- PyTorch 2.0+
- einops
- CUDA (optional, for Flash Attention acceleration)
The implementation includes built-in tests:
```bash
python native_sparse_attention.py
```

The tests verify:
- Basic functionality (smoke test)
- Causal masking correctness (see the sketch below)
- Flash Attention parity (when available)
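As a rough illustration of what the causality check looks like, a standalone version might perturb the "future" tokens and assert that earlier outputs are unchanged. This sketch assumes the NativeSparseAttention call signature shown earlier in this README (a (batch, seq_len, dim) tensor in, same shape out); it is not the repository's actual test code:

```python
import torch
from native_sparse_attention import NSAConfig, NativeSparseAttention

def check_causality(seq_len=256, dim=128, heads=4, t=128):
    """Changing tokens at positions >= t must not affect outputs at positions < t."""
    cfg = NSAConfig(dim=dim, heads=heads, seq_len=seq_len)
    attn = NativeSparseAttention(cfg).eval()

    x = torch.randn(1, seq_len, dim)
    x_tampered = x.clone()
    x_tampered[:, t:] = torch.randn_like(x_tampered[:, t:])  # rewrite the "future"

    with torch.no_grad():
        y_ref, y_tam = attn(x), attn(x_tampered)

    assert torch.allclose(y_ref[:, :t], y_tam[:, :t], atol=1e-5), "future leaked into the past"
    print("causal masking OK")

check_causality()
```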