diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index a73cf2a68..eb9877869 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -532,11 +532,12 @@ def build_train_valid_test_data_loaders(neox_args):
         else:
             pipe_load = True

-        # Data loader only on rank 0 of each model parallel group.
+        # Data loader only on rank 0 of each model and context parallel group.
         if (
             pipe_load
             and (neox_args.dataset_impl == "online")
             and (mpu.get_model_parallel_rank() == 0)
+            and (mpu.get_context_parallel_rank() == 0)
         ):
             # Can skip most of the work...
             train_iters = neox_args.train_iters
@@ -721,11 +722,17 @@ def build_train_valid_test_data_loaders(neox_args):
         # broadcast globally instead of just the model parallel group.
         torch.distributed.broadcast(flags, src=0)
     else:
+        # The same data should be used across the model parallel and context parallel groups.
         torch.distributed.broadcast(
             flags,
             mpu.get_model_parallel_src_rank(),
             group=mpu.get_model_parallel_group(),
         )
+        torch.distributed.broadcast(
+            flags,
+            mpu.get_context_parallel_src_rank(),
+            group=mpu.get_context_parallel_group(),
+        )
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
diff --git a/megatron/initialize.py b/megatron/initialize.py
index 29afe7f9a..7ba11f161 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -158,16 +158,20 @@ def _initialize_distributed(neox_args):
     # Setup 3D topology.
     pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
     mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
+    cp = neox_args.context_parallel_size if neox_args.context_parallel_size >= 1 else 1
+    assert (
+        neox_args.world_size % (pp * mp * cp) == 0
+    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, cp={cp}"
     assert (
         neox_args.world_size % (pp * mp) == 0
     ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
+    # The data parallel ranks will also be used for context parallelism,
+    # piggybacking on the data parallel group for the gradient all-reduce.
     dp = neox_args.world_size // (pp * mp)
+    assert dp % cp == 0
+    from deepspeed.runtime.pipe.topology import ProcessTopology

-    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
-
-    # this does pipe on the most outside, then data, then model.
-    # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
-    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)
+    topo = ProcessTopology(axes=["pipe", "data", "model"], dims=[pp, dp, mp])

     # Offset base seeds for the interior pipeline stages.
     # TODO: adjust last stage too once IO is improved.
@@ -186,6 +190,8 @@ def _initialize_distributed(neox_args):
     else:
         mpu.initialize_model_parallel(
             neox_args.model_parallel_size,
+            neox_args.pipe_parallel_size,
+            neox_args.context_parallel_size,
             topology=topo,
             fp32_allreduce=neox_args.fp32_allreduce,
         )
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 3fd251147..08cb607ad 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -37,7 +37,7 @@ def __init__(
         normalized_shape,
         eps=1e-5,
         no_persist_layer_norm=True,
-        sequence_parallel=False,
+        context_parallel=False,
         apply_layernorm_1p=False,
         mem_efficient_ln=True,
     ):
@@ -92,11 +92,11 @@ def __init__(
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
-        self.sequence_parallel = sequence_parallel
+        self.context_parallel = context_parallel

         # set sequence parallelism flag on weight and bias parameters
-        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
-        setattr(self.bias, "sequence_parallel", self.sequence_parallel)
+        setattr(self.weight, "context_parallel", self.context_parallel)
+        setattr(self.bias, "context_parallel", self.context_parallel)

     def reset_parameters(self):
diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 7899048db..b7252ee25 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -74,7 +74,30 @@ def cross_entropy(output, labels, _fp16=False):
     else:
         losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
     loss_mask = loss_mask.view(-1)
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    loss_mask_sum = loss_mask.sum()
+    if mpu.get_context_parallel_world_size() > 1:
+        dt = loss_mask_sum.dtype
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.float()
+        torch.distributed.all_reduce(
+            loss_mask_sum,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.bfloat16()
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.float()
+        torch.distributed.all_reduce(
+            loss,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.bfloat16()
+    else:
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
     return loss
diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py
index 072aad8b4..74abf7d32 100644
--- a/megatron/model/positional_embeddings.py
+++ b/megatron/model/positional_embeddings.py
@@ -14,6 +14,7 @@

 import torch
 import math
+import megatron.mpu as mpu


 class SinusoidalPositionalEmbedding(torch.nn.Module):
@@ -37,7 +38,13 @@ def forward(self, x, seq_dim=1):

 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
+        self,
+        dim,
+        max_seq_len,
+        base=10000,
+        precision=torch.half,
+        save_inv_freqs=False,
+        zigzag=True,
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
@@ -49,6 +56,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
+        self.zigzag = zigzag  # zigzag sharding of the sequence for context parallelism

         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(
@@ -64,6 +72,19 @@ def _prepare_cache(self, seq_len, precision, base):
         inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

         t = torch.arange(seq_len).type_as(inv_freq)
+        if mpu.get_context_parallel_world_size() > 1:
+            if not self.zigzag:
+                t_chunks = torch.chunk(t, mpu.get_context_parallel_world_size())
+                t = t_chunks[mpu.get_context_parallel_rank()].contiguous()
+            else:
+                t_chunks = torch.chunk(t, 2 * mpu.get_context_parallel_world_size())
+                t = torch.cat(
+                    (
+                        t_chunks[mpu.get_context_parallel_rank()],
+                        t_chunks[-(mpu.get_context_parallel_rank() + 1)],
+                    ),
+                    dim=0,
+                ).contiguous()
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e60fbbe41..b64dc328b 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -459,6 +459,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -467,7 +468,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse
@@ -496,6 +497,12 @@ def __init__(
                 self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
                 self.flash_qkv_fn = flash_attn_func
                 self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
@@ -743,6 +750,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
@@ -818,7 +915,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
@@ -929,6 +1026,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 5515c41f5..40e45b868 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -373,14 +373,14 @@ def reduce_weight_grads_from_model_parallel_region(input_):

     # Bf16 convert
     dt = input_.dtype
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.float()

     # All-reduce.
     dist.all_reduce(input_, group=mpu.get_model_parallel_group())

     # Bf16 convert
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.bfloat16()

     return input_
diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py
index 780fb33e8..8ae2a26f9 100644
--- a/megatron/mpu/__init__.py
+++ b/megatron/mpu/__init__.py
@@ -57,3 +57,10 @@
 from .utils import divide
 from .utils import split_tensor_along_last_dim
+from .data import zigzag_data
+from .initialize import (
+    get_context_parallel_group,
+    get_context_parallel_rank,
+    get_context_parallel_world_size,
+    get_context_parallel_src_rank,
+)
diff --git a/megatron/mpu/data.py b/megatron/mpu/data.py
index 87e2a9615..0b708171b 100644
--- a/megatron/mpu/data.py
+++ b/megatron/mpu/data.py
@@ -17,6 +17,10 @@
 from .initialize import get_model_parallel_group
 from .initialize import get_model_parallel_rank
 from .initialize import get_model_parallel_src_rank
+from .initialize import get_context_parallel_src_rank
+from .initialize import get_context_parallel_group
+from .initialize import get_context_parallel_rank
+from .initialize import get_context_parallel_world_size

 _MAX_DATA_DIM = 4
@@ -38,7 +42,7 @@ def _build_key_size_numel_dictionaries(keys, data):
     sizes = [0 for _ in range(max_dim) for _ in keys]

     # Pack the sizes on rank zero.
-    if get_model_parallel_rank() == 0:
+    if (get_model_parallel_rank() == 0) and (get_context_parallel_rank() == 0):
         offset = 0
         for key in keys:
             assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
@@ -52,6 +56,9 @@ def _build_key_size_numel_dictionaries(keys, data):
     torch.distributed.broadcast(
         sizes_cuda, get_model_parallel_src_rank(), group=get_model_parallel_group()
     )
+    torch.distributed.broadcast(
+        sizes_cuda, get_context_parallel_src_rank(), group=get_context_parallel_group()
+    )

     # Move back to cpu and unpack.
     sizes_cpu = sizes_cuda.cpu()
@@ -76,7 +83,7 @@ def _build_key_size_numel_dictionaries(keys, data):
     return key_size, key_numel, total_numel


-def broadcast_data(keys, data, datatype):
+def broadcast_data(keys, data, datatype, zigzag=False):
     """Broadcast data from rank zero of each model parallel group to the
     members of the same model parallel group.
@@ -91,7 +98,7 @@ def broadcast_data(keys, data, datatype):
     key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

     # Pack on rank zero.
-    if get_model_parallel_rank() == 0:
+    if (get_model_parallel_rank() == 0) and (get_context_parallel_rank() == 0):
         # Check that all keys have the same data type.
         _check_data_types(keys, data, datatype)
         # Flatten the data associated with the keys
@@ -107,6 +114,11 @@ def broadcast_data(keys, data, datatype):
     torch.distributed.broadcast(
         flatten_data, get_model_parallel_src_rank(), group=get_model_parallel_group()
     )
+    torch.distributed.broadcast(
+        flatten_data,
+        get_context_parallel_src_rank(),
+        group=get_context_parallel_group(),
+    )

     # Unpack
     output = {}
@@ -117,4 +129,23 @@ def broadcast_data(keys, data, datatype):
         output[key] = flatten_data.narrow(0, offset, numel).view(size)
         offset += numel

-    return output
+    return output if not zigzag else {key: zigzag_data(output[key]) for key in keys}
+
+
+def zigzag_data(data, seq_dim=1):
+    """Zigzag the data along the sequence dimension.
+    Arguments:
+        data: tensor to shard along the sequence dimension.
+        seq_dim: the sequence dimension to zigzag.
+    """
+    worldsize = get_context_parallel_world_size()
+    # first check if we can just skip it...
+    if worldsize == 1:
+        return data
+    # otherwise prepare for zigzagging
+    seq_chunks = torch.chunk(data, 2 * worldsize, dim=seq_dim)
+    data = [
+        torch.cat((seq_chunks[i], seq_chunks[-(i + 1)]), dim=seq_dim)
+        for i in range(worldsize)
+    ]
+    return data[get_context_parallel_rank()].contiguous()
diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py
index 19d231524..c3adc3f3d 100644
--- a/megatron/mpu/initialize.py
+++ b/megatron/mpu/initialize.py
@@ -28,6 +28,8 @@
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPE_PARALLEL_GROUP = None
+# Context parallel group that the current rank belongs to.
+_CONTEXT_PARALLEL_GROUP = None

 # A group used to sync during the IO process. Usually this is data_parallel_group(),
 # but with pipeline parallelism it must also involve the last stage (which is not in the
@@ -38,7 +40,7 @@
 _MPU_WORLD_SIZE = None
 _MPU_RANK = None

-# Used to query 3D topology
+# Used to query 4D topology
 _MPU_TOPOLOGY = None

 # Get fp32_allreduce flag
@@ -50,12 +52,27 @@ def is_unitialized():
     return _DATA_PARALLEL_GROUP is None


-def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False):
+def initialize_model_parallel(
+    model_parallel_size,
+    pipe_parallel_size,
+    context_parallel_size,
+    topology=None,
+    fp32_allreduce=False,
+):
     """
     Initialize model data parallel groups.

     Arguments:
-        model_parallel_size: number of GPUs used to parallelize model.
+        model_parallel_size: number of GPUs used for model parallelism.
+        pipe_parallel_size: number of GPUs used for pipeline parallelism.
+        context_parallel_size: number of GPUs used for context parallelism.
+        topology: topology if it exists.
+        fp32_allreduce: whether or not to do the all-reduce in fp32.
+
+    Adjacent ranks are ordered by model parallel, then context parallel,
+    then data parallel. Context parallelism duplicates weights among the GPUs in
+    a context parallel group, so we piggyback on the data parallel group
+    for the gradient all-reduce.

     Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use
     2 GPUs to parallelize the model. The present function will
@@ -74,9 +91,11 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    if world_size < model_parallel_size:
-        raise ValueError("world size cannot be smaller than model parallel size")
-    ensure_divisibility(world_size, model_parallel_size)
+    if world_size < model_parallel_size * context_parallel_size:
+        raise ValueError(
+            "world size cannot be smaller than (model parallel size) * (context parallel size)"
+        )
+    ensure_divisibility(world_size, model_parallel_size * context_parallel_size)
     rank = torch.distributed.get_rank()

     global _MPU_TOPOLOGY
@@ -87,15 +106,18 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
     if topology:
-        for dp_group in topology.get_axis_comm_lists("data"):
+        dp_groups = topology.get_axis_comm_lists("data")
+        for dp_group in dp_groups:
             group = torch.distributed.new_group(ranks=dp_group)
             if rank == 0:
                 print(f"MPU DP:", dp_group)
             if rank in dp_group:
                 _DATA_PARALLEL_GROUP = group
     else:
+        dp_groups = []
         for i in range(model_parallel_size):
             ranks = range(i, world_size, model_parallel_size)
+            dp_groups.append(list(ranks))
             group = torch.distributed.new_group(ranks)
             if i == (rank % model_parallel_size):
                 _DATA_PARALLEL_GROUP = group
@@ -110,22 +132,7 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
             if rank in pp_group:
                 _PIPE_PARALLEL_GROUP = group

-    # Build IO group
-    global _IO_PARALLEL_GROUP
-    if topology and topology.get_dim("pipe") > 1:
-        io_stages = [0, topology.get_dim("pipe") - 1]
-        io_group = []
-        for stage in io_stages:
-            io_group.extend(topology.filter_match(pipe=stage, model=0))
-        if rank == 0:
-            print(f"MPU IO:", io_group)
-        group = torch.distributed.new_group(ranks=io_group)
-        if rank in io_group:
-            _IO_PARALLEL_GROUP = group
-    else:
-        _IO_PARALLEL_GROUP = get_data_parallel_group()
-
-    # Build the model parallel groups.
+    # Build the model parallel groups
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
     if topology:
@@ -138,8 +145,6 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
                     print(f"MPU MP:", [group_rank])
                 if rank == group_rank:
                     _MODEL_PARALLEL_GROUP = group
-            return
-
         for mp_group in topology.get_axis_comm_lists("model"):
             group = torch.distributed.new_group(ranks=mp_group)
             if rank == 0:
@@ -154,6 +159,50 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
             if i == (rank // model_parallel_size):
                 _MODEL_PARALLEL_GROUP = group

+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "context parallel group is already initialized"
+    for dp_group in dp_groups:
+        for start in range(0, len(dp_group), context_parallel_size):
+            ranks = [dp_group[i] for i in range(start, start + context_parallel_size)]
+            group = torch.distributed.new_group(ranks)
+            if rank in ranks:
+                _CONTEXT_PARALLEL_GROUP = group
+
+    # Build IO group
+    global _IO_PARALLEL_GROUP
+    if topology and topology.get_dim("pipe") > 1:
+        if context_parallel_size > 1:
+            raise ValueError("Context parallel not tested with pipeline parallelism")
+        io_stages = [0, topology.get_dim("pipe") - 1]
+        io_group = []
+        for stage in io_stages:
+            io_group.extend(topology.filter_match(pipe=stage, model=0))
+        if rank == 0:
+            print(f"MPU IO:", io_group)
+        group = torch.distributed.new_group(ranks=io_group)
+        if rank in io_group:
+            _IO_PARALLEL_GROUP = group
+    else:
+        if context_parallel_size > 1:
+            if pipe_parallel_size > 1:
+                raise ValueError(
+                    "Context parallel not tested with pipeline parallelism"
+                )
+            for dp_group in dp_groups:
+                for start in range(0, len(dp_group) // context_parallel_size):
+                    ranks = [
+                        dp_group[i]
+                        for i in range(start, len(dp_group), context_parallel_size)
+                    ]
+                    group = torch.distributed.new_group(ranks)
+                    if rank in ranks:
+                        _IO_PARALLEL_GROUP = group
+        else:
+            _IO_PARALLEL_GROUP = _DATA_PARALLEL_GROUP
+
     global _FP32_ALLREDUCE
     assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized"
     _FP32_ALLREDUCE = fp32_allreduce
@@ -184,6 +233,14 @@ def get_io_parallel_group():
     return _IO_PARALLEL_GROUP


+def get_context_parallel_group():
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
     global _MPU_WORLD_SIZE
@@ -220,6 +277,30 @@ def get_model_parallel_src_rank():
     return (global_rank // local_world_size) * local_world_size


+def get_context_parallel_world_size():
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank():
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
+
+
+def get_context_parallel_src_rank():
+    """Calculate the global rank corresponding to a local rank zero
+    in the context parallel group."""
+    global_rank = torch.distributed.get_rank()
+    # Model parallel and context parallel ranks are scheduled together as a group,
+    # with model parallel occupying adjacent ranks.
+    local_world_size = (
+        get_model_parallel_world_size() * get_context_parallel_world_size()
+    )
+    return (
+        global_rank // local_world_size
+    ) * local_world_size + get_model_parallel_rank()
+
+
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to a local rank zero
     in the data parallel group."""
@@ -238,7 +319,9 @@ def get_data_parallel_src_rank():

 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
-    return torch.distributed.get_world_size(group=get_data_parallel_group())
+    return torch.distributed.get_world_size(
+        group=get_data_parallel_group()
+    ) // torch.distributed.get_world_size(group=get_context_parallel_group())


 def get_data_parallel_rank():
@@ -316,6 +399,8 @@ def destroy_model_parallel():
     _MPU_TOPOLOGY = None
     global _FP32_ALLREDUCE
     _FP32_ALLREDUCE = None
+    global _CONTEXT_PARALLEL_GROUP
+    _CONTEXT_PARALLEL_GROUP = None


 def get_fp32_allreduce():
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index f3daacd4d..c9b72333d 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -814,8 +814,11 @@ def configure_distributed_args(self):
         if self.rank == 0:
             print(
                 self.__class__.__name__
-                + ".configure_distributed_args() using world size: {} and model-parallel size: {} ".format(
-                    self.world_size, self.model_parallel_size
+                + ".configure_distributed_args() using world size: {}, pipe-parallel size: {}, context-parallel size: {}, and model-parallel size: {} ".format(
+                    self.world_size,
+                    self.pipe_parallel_size,
+                    self.context_parallel_size,
+                    self.model_parallel_size,
                 ),
                 flush=True,
             )
@@ -918,13 +921,16 @@ def calculate_derived(self):
         pp_size = pp_size if pp_size >= 1 else 1
         mp_size = self.model_parallel_size
         mp_size = mp_size if mp_size >= 1 else 1
+        cp_size = self.context_parallel_size
+        cp_size = cp_size if cp_size >= 1 else 1
         self.update_value("model_parallel_size", mp_size)
+        self.update_value("context_parallel_size", cp_size)

-        # pp_size and mp_size are only used here to compute dp world size and nowhere else.
-        dp_world_size = (global_num_gpus / pp_size) / mp_size
+        # pp_size, mp_size, and cp_size are only used here to compute dp world size and nowhere else.
+        dp_world_size = (global_num_gpus / pp_size) / (mp_size * cp_size)
         if not (dp_world_size % 1 == 0):
             error_message = (
-                f"{ERROR}"
+                "ERROR"
                 + self.__class__.__name__
                 + ".calculate_derived() "
                 + f"(global_num_gpus / pp_size) / mp_size [({global_num_gpus} / {pp_size}) / {mp_size}] must be a whole number"
@@ -1081,6 +1087,11 @@ def calculate_derived(self):
         # if we set pipe_parallel_size to 0, GPT2ModelPipe.to_sequential() is called, and we run training with
         # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
         self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
+        # update 'is context parallel' flag
+        self.update_value(
+            "is_context_parallel",
+            self.context_parallel_size > 1 and self.moe_num_experts == 1,
+        )
         if self.moe_num_experts > 1:
             assert not (
                 self.is_pipe_parallel or self.pipe_parallel_size > 1
@@ -1098,6 +1109,13 @@ def calculate_derived(self):
             "attention_config",
             expand_attention_types(self.attention_config, self.num_layers),
         )
+        self.update_value(
+            "requires_attention_mask",
+            not all([item in ["ring", "flash"] for item in self.attention_config]),
+        )
+        assert all([item == "ring" for item in self.attention_config]) or (
+            not self.is_context_parallel
+        ), "Context parallel requires ring attention!"
         assert (
             len(self.attention_config) == self.num_layers
         ), "Length of attention config list must equal num_layers"
@@ -1143,7 +1161,9 @@ def calculate_derived(self):
                 not self.sparsity_config
             ), "Sparse attention not compatible with GQA or MQA"
             assert all(
-                (attn_type == "flash") or (attn_type == "global")
+                (attn_type == "flash")
+                or (attn_type == "global")
+                or (attn_type == "ring")
                 for attn_type in self.attention_config
             ), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention"
             assert (
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 9c8d3635f..2c1cdf281 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -38,6 +38,7 @@
     "flash",
     "rwkv",
    "mamba",
+    "ring",
 ]
@@ -67,6 +68,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate):
     Size of the model parallelism.
     """

+    context_parallel_size: int = 1
+    """
+    Size of the context parallelism.
+    """
+
     pipe_partition_method: str = "type:transformer|mlp"
     """
     method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number
@@ -89,7 +95,12 @@ class NeoXArgsParallelism(NeoXArgsTemplate):
     """
     flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
     (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1.
-    **Set by user, in contrast to neox_args.is_pipe_parallel.**
+    """
+
+    is_context_parallel: bool = False
+    """
+    flag to determine whether context parallelism is on. Shouldn't be set by user; it is automatically
+    determined from the context parallel size.
     """

     expert_interval: int = 2
@@ -239,7 +250,7 @@ class NeoXArgsModel(NeoXArgsTemplate):
     The first item in the list specifies the attention type(s), and should be a list of strings. The second item
     specifies the number of times to repeat those attention types in the full list.

-    attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"]
+    attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv", "ring"]

     So a 12 layer network with only global attention could be specified like:
         [[[`global`], 12]]
@@ -251,6 +262,12 @@ class NeoXArgsModel(NeoXArgsTemplate):
         [[[`global`], n_layers]]
     """

+    requires_attention_mask: bool = True
+    """
+    If true, the model requires an attention mask to be passed in.
+    Automatically configured based on attention type.
+ """ + sparsity_config: dict = None """ diff --git a/megatron/training.py b/megatron/training.py index 3def74860..de25449a6 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -379,11 +379,19 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype, label_mask_zero=False eod_token=neox_args.tokenizer.eod, eod_mask_loss=neox_args.eod_mask_loss, sliding_window_width=neox_args.sliding_window_width, + requires_mask=neox_args.requires_attention_mask, ) - # combine loss masks from get_ltor_masks_and_position_ids with loss masks from data loss_mask = label_mask.to(loss_mask.dtype) * loss_mask - return tokens, labels, loss_mask, attention_mask, position_ids + return ( + mpu.zigzag_data(tokens), + mpu.zigzag_data(labels), + mpu.zigzag_data(loss_mask), + mpu.zigzag_data(attention_mask, -2) + if neox_args.requires_attention_mask + else None, + mpu.zigzag_data(position_ids), + ) def get_batch(neox_args, data_iterator): @@ -526,6 +534,7 @@ def get_batch_sequential(forward_input, neox_args): data=forward_input[0], eod_token=neox_args.tokenizer.eod, eod_mask_loss=neox_args.eod_mask_loss, + requires_mask=neox_args.requires_attention_mask, ) return (forward_input[0], forward_input[1], attention_mask) @@ -1374,6 +1383,9 @@ def backward_step(neox_args, timers, optimizer, model, loss): raise ValueError("Must be using deepspeed to run neox") +train_step_counter = 0 + + def train_step( neox_args, timers, @@ -1550,7 +1562,6 @@ def train( # to monitor if we've skipped many iterations in a row and trigger an early exit overflow_monitor = OverflowMonitor(optimizer) - if neox_args.profile: schedule = torch.profiler.schedule( wait=neox_args.profile_step_start, diff --git a/megatron/utils.py b/megatron/utils.py index fc2f80dad..7e39e8402 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -81,22 +81,29 @@ def get_attn_mask(seq_length, device, sliding_window_width): def get_ltor_masks_and_position_ids( - data, - eod_token, - eod_mask_loss=False, - sliding_window_width=None, + data, eod_token, eod_mask_loss=False, sliding_window_width=None, requires_mask=True ): """Build masks and position id for left to right model.""" # Extract batch size and sequence length. batch_size, seq_length = data.size() - # Attention mask (lower triangular). - attention_mask = get_attn_mask( - seq_length=seq_length, - device=data.device, - sliding_window_width=sliding_window_width, - ) + if requires_mask: + # Attention mask (lower triangular). + attention_mask = get_attn_mask( + seq_length=seq_length, + device=data.device, + sliding_window_width=sliding_window_width, + ) + else: + # Need this to actually do long context, 128k**2 is v big. + # Give it a dummy value + # Surely there is a better way to do this... + attention_mask = get_attn_mask( + seq_length=64, + device=data.device, + sliding_window_width=sliding_window_width, + ) # Loss mask. loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) diff --git a/requirements/requirements-ringattention.txt b/requirements/requirements-ringattention.txt new file mode 100644 index 000000000..27636efcc --- /dev/null +++ b/requirements/requirements-ringattention.txt @@ -0,0 +1 @@ +git+https://github.com/zhuzilin/ring-flash-attention.git