Skip to content

sync layer norms #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ def _add_network_size_args(parser):
', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
help='Layer norm epsilon.')
group.add_argument('--layernorm-tp-auto-sync', action='store_true',
help='Force syncing layernorm params across TP ranks in forward. '
'This is a workaround for an unresolved bug leading to TP ranks '
'getting out of sync with each other.')
group.add_argument('--apply-residual-connection-post-layernorm',
action='store_true',
help='If set, use original BERT residula connection '
Expand Down
3 changes: 2 additions & 1 deletion megatron/data/data_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int)
'target_tokens': array([5])
}
]

Output:
decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
Expand Down Expand Up @@ -139,6 +139,7 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
generator=torch.Generator().manual_seed(args.seed),
collate_fn=collate_fn,
pin_memory=True
)
Expand Down
17 changes: 12 additions & 5 deletions megatron/model/fused_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@

import numbers


from megatron import get_args
from megatron import mpu
from packaging import version
import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
import importlib

from megatron import get_args
import torch
import torch.nn.functional as F

global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(self, normalized_shape, eps=1e-5):
self.reset_parameters()

args = get_args()
self.layernorm_tp_auto_sync = args.layernorm_tp_auto_sync

self.use_meg_ds_fused_layer_norm = (
args.bf16 # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm
Expand All @@ -97,6 +99,11 @@ def reset_parameters(self):


def forward(self, input):

if self.layernorm_tp_auto_sync:
torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

if self.use_meg_ds_fused_layer_norm:
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
Expand Down
1 change: 1 addition & 0 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def get_variation_config(self, variation, output_dir, n_samples=None):
--clip-grad 1.0
--weight-decay 1e-1
--embed-layernorm
--layernorm-tp-auto-sync
--fp16

--log-level debug
Expand Down