
Commit 8c1555d

wwwjn authored and soumith committed
remove manual n_heads change
1 parent 58370dc commit 8c1555d

File tree

1 file changed: +5 -10 lines changed

distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 5 additions & 10 deletions
@@ -121,12 +121,12 @@
     layer_tp_plan = {
         "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), Replicate()),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
         "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
@@ -138,11 +138,6 @@
         "feed_forward.w3": ColwiseParallel(),
     }
 
-    # Adjust attention module to use the local number of heads
-    attn_layer = transformer_block.attention
-    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
     # Custom parallelization plan for the model
     parallelize_module(
         module=transformer_block,
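
For context (not part of the commit): with use_local_output=False the wq/wk/wv projections return DTensors instead of plain local tensors, so the attention module keeps seeing the global n_heads and n_kv_heads and the manual division by tp_mesh.size() is no longer needed; annotating the second attention input as Replicate() instead of None likewise wraps the rotary-frequency tensor in a replicated DTensor so it composes with those DTensor activations. The minimal sketch below shows how the updated layer_tp_plan would be applied; transformer_block and tp_mesh stand for the decoder layer and the 1-D tensor-parallel DeviceMesh set up earlier in fsdp_tp_example.py, and import paths may vary slightly across PyTorch versions.

# A minimal sketch (not part of the commit) of how the updated plan is applied.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

layer_tp_plan = {
    "attention_norm": SequenceParallel(),
    # The second attention input (the rotary frequencies in this example's
    # Llama-style model) is now annotated as Replicate() rather than None,
    # so it is wrapped into a replicated DTensor.
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), Replicate()),
        desired_input_layouts=(Replicate(), Replicate()),
    ),
    # use_local_output=False keeps the q/k/v outputs as DTensors whose logical
    # shape still spans all heads, so n_heads and n_kv_heads no longer need to
    # be divided by tp_mesh.size() by hand.
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    # The feed_forward entries follow the same pattern as in the original file.
}

parallelize_module(
    module=transformer_block,        # one decoder layer from the model
    device_mesh=tp_mesh,             # 1-D tensor-parallel mesh
    parallelize_plan=layer_tp_plan,
)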
