 layer_tp_plan = {
     "attention_norm": SequenceParallel(),
     "attention": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
+        input_layouts=(Shard(1), Replicate()),
+        desired_input_layouts=(Replicate(), Replicate()),
     ),
-    "attention.wq": ColwiseParallel(),
-    "attention.wk": ColwiseParallel(),
-    "attention.wv": ColwiseParallel(),
+    "attention.wq": ColwiseParallel(use_local_output=False),
+    "attention.wk": ColwiseParallel(use_local_output=False),
+    "attention.wv": ColwiseParallel(use_local_output=False),
     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
     "ffn_norm": SequenceParallel(),
     "feed_forward": PrepareModuleInput(
     ...
     "feed_forward.w3": ColwiseParallel(),
 }

-# Adjust attention module to use the local number of heads
-attn_layer = transformer_block.attention
-attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
 # Custom parallelization plan for the model
 parallelize_module(
     module=transformer_block,