9 changes: 9 additions & 0 deletions src/maxdiffusion/common_types.py
@@ -44,4 +44,13 @@
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

# Logical axis names for sharding self and cross attention independently in the splash kernel
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


WAN_MODEL = "Wan2.1"
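
Note: the new names are plain logical-axis strings, so they can be mapped onto mesh axes through the same kind of rule list the config already uses. A minimal sketch, assuming `common_types` is importable as `maxdiffusion.common_types`; the rule targets mirror the `logical_axis_rules` added to `base_wan_14b.yml` below, and the last rule is a hypothetical extra, not part of this PR:

```python
# Illustrative only: map the new self/cross-attention logical axes onto mesh axes.
from maxdiffusion import common_types

logical_axis_rules = [
    ("activation_batch", "data"),
    (common_types.SELF_ATTN_HEAD, ("fsdp", "tensor")),       # "activation_self_attn_heads"
    (common_types.CROSS_ATTN_Q_LENGTH, ("fsdp", "tensor")),   # "activation_cross_attn_q_length"
    (common_types.SELF_ATTN_Q_LENGTH, "fsdp"),                # hypothetical extra rule
]
```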
16 changes: 14 additions & 2 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -68,10 +68,21 @@ flash_block_sizes: {}
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_kv_dkv_compute" : 1024,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048
# }
# Block sizes tuned for TPU v5p
flash_block_sizes: {
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 3072,
"block_q_dkv" : 1024,
"block_kv_dkv" : 3072,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 3072
}
# GroupNorm groups
norm_num_groups: 32

@@ -132,8 +143,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],

['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
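
Note: since the v5p block sizes above are hand-tuned, a small validation pass can catch values that exceed the sequence lengths before the splash kernel is invoked. A minimal sketch; this check is not part of the PR, and the "block size must not exceed the corresponding sequence length" constraint is an assumption used for illustration:

```python
# Hypothetical sanity check for a flash_block_sizes dict like the v5p block above.
Q_KEYS = {"block_q", "block_q_dkv", "block_q_dq"}
KV_KEYS = {"block_kv", "block_kv_compute", "block_kv_dkv",
           "block_kv_dkv_compute", "block_kv_dq"}

def check_flash_block_sizes(sizes: dict, q_len: int, kv_len: int) -> None:
    for name, value in sizes.items():
        limit = q_len if name in Q_KEYS else kv_len
        if int(value) > limit:
            raise ValueError(f"{name}={value} exceeds sequence length {limit}")
```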
16 changes: 8 additions & 8 deletions src/maxdiffusion/max_utils.py
@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
flash_block_sizes = None
if len(config.flash_block_sizes.keys()) > 0:
flash_block_sizes = splash_attention_kernel.BlockSizes(
block_q=config.flash_block_sizes["block_q"],
block_kv_compute=config.flash_block_sizes["block_kv_compute"],
block_kv=config.flash_block_sizes["block_kv"],
block_q_dkv=config.flash_block_sizes["block_q_dkv"],
block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
block_q_dq=config.flash_block_sizes["block_q_dq"],
block_kv_dq=config.flash_block_sizes["block_kv_dq"],
block_q=int(config.flash_block_sizes["block_q"]),
block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
block_kv=int(config.flash_block_sizes["block_kv"]),
block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
)
return flash_block_sizes

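
Note: the added int() casts most likely guard against block-size values that arrive as strings or floats from YAML/CLI overrides, since splash_attention_kernel.BlockSizes holds integer fields. A hedged usage sketch with a stand-in config object (the SimpleNamespace here is illustrative, not the real config class):

```python
# Hedged usage sketch of get_flash_block_sizes with a stand-in config object.
from types import SimpleNamespace
from maxdiffusion import max_utils

config = SimpleNamespace(flash_block_sizes={
    "block_q": "1024",            # e.g. a CLI override parsed as a string
    "block_kv_compute": 256,
    "block_kv": 3072,
    "block_q_dkv": 1024,
    "block_kv_dkv": 3072,
    "block_kv_dkv_compute": 256,
    "block_q_dq": 1024,
    "block_kv_dq": 3072,
})
block_sizes = max_utils.get_flash_block_sizes(config)  # BlockSizes with int fields
```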
31 changes: 29 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
@@ -45,6 +45,13 @@
EMBED = common_types.EMBED
Quant = quantizations.AqtQuantization

SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH


def _maybe_aqt_einsum(quant: Quant):
return jnp.einsum if quant is None else quant.einsum()
@@ -184,7 +191,8 @@ def _tpu_flash_attention(
kv_max_block_size = key.shape[1]
else:
kv_max_block_size = q_max_block_size
if flash_block_sizes:
# Use the user-configured block sizes only for self attention (q and kv lengths match);
# for cross attention, fall back to the computed defaults below.
if flash_block_sizes and key.shape[1] == query.shape[1]:
block_sizes = flash_block_sizes
else:
block_sizes = splash_attention_kernel.BlockSizes(
@@ -439,7 +447,16 @@ def _apply_attention(
)
elif attention_kernel == "flash":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
query,
key * scale,
value,
heads,
mesh,
axis_names_q,
axis_names_kv,
flash_block_sizes,
dtype,
attention_kernel,
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
@@ -701,6 +718,7 @@ def __init__(
precision: jax.lax.Precision = None,
qkv_bias: bool = False,
quant: Quant = None,
is_self_attention: bool = True,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -717,6 +735,13 @@
self.value_axis_names = value_axis_names
self.out_axis_names = out_axis_names

if is_self_attention:
axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
else:
axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)

self.attention_op = NNXAttentionOp(
mesh=mesh,
attention_kernel=attention_kernel,
@@ -726,6 +751,8 @@
use_memory_efficient_attention=use_memory_efficient_attention,
split_head_dim=split_head_dim,
float32_qk_product=False,
axis_names_q=axis_names_q,
axis_names_kv=axis_names_kv,
flash_min_seq_length=flash_min_seq_length,
flash_block_sizes=flash_block_sizes,
dtype=dtype,
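
Note: the new condition in _tpu_flash_attention trusts the configured block sizes only when the query and key/value sequence lengths match (self attention); cross attention falls back to block sizes computed from the actual lengths. A standalone restatement of that selection logic, simplified; in the real code the fallback builds a splash_attention_kernel.BlockSizes:

```python
# Simplified restatement of the block-size selection added in _tpu_flash_attention.
# `build_default_block_sizes` stands in for the existing fallback construction.
def select_block_sizes(flash_block_sizes, q_len, kv_len, build_default_block_sizes):
    if flash_block_sizes and q_len == kv_len:
        # Self attention: q and kv lengths match, so the tuned sizes apply as-is.
        return flash_block_sizes
    # Cross attention (or no user config): derive sizes from the actual lengths.
    return build_default_block_sizes(q_len, kv_len)
```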
7 changes: 6 additions & 1 deletion src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -282,6 +282,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=True,
)

# 1. Cross-attention
@@ -300,6 +301,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=False,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -351,7 +353,10 @@ def __call__(
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = hidden_states + attn_output

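
Note: with is_self_attention threaded through, attn1 and attn2 pick up different logical axis names and therefore different sharding. A toy illustration of that routing, not the real maxdiffusion classes; the axis-name strings follow the values added in common_types.py, and the head-dim axis is omitted for brevity:

```python
# Toy illustration of routing a block's two attention layers to different
# logical axis names via an is_self_attention flag.
from dataclasses import dataclass

@dataclass
class ToyAttention:
    is_self_attention: bool

    def q_axis_names(self) -> tuple:
        if self.is_self_attention:
            return ("activation_batch", "activation_self_attn_heads",
                    "activation_self_attn_q_length")
        return ("activation_batch", "activation_cross_attn_heads",
                "activation_cross_attn_q_length")

attn1 = ToyAttention(is_self_attention=True)   # self attention over video tokens
attn2 = ToyAttention(is_self_attention=False)  # cross attention over text tokens
print(attn1.q_axis_names(), attn2.q_axis_names())
```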