From bbb5e376386c8b5803603d165bef5585d85aac1d Mon Sep 17 00:00:00 2001
From: Isuru Janith Ranawaka
Date: Wed, 16 Jul 2025 07:12:35 -0700
Subject: [PATCH] Enable Changing the # of shards for CW resharding (#3188)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3188

Currently, Dynamic Sharding assumes the # of shards per embedding table stays the same:
- https://www.internalfb.com/code/fbsource/[6d270632037a1e8bca7f63500dd07fd0b213e572]/fbcode/torchrec/distributed/sharding/dynamic_sharding.py?lines=140

E.g.
- `table_0` originally sharded on ranks: [0, 1]
- The Reshard API currently supports moving `table_0` shards to ranks [1, 2],
  where the shard on rank 0 moves to rank 1 and the shard on rank 1 moves to rank 2.

We want to support changing the # of shards:
- e.g. `table_0` originally on ranks [0, 1] --> reshard to [0]
- or reshard to [0, 1, 2, 3]

Here's the unit test you can modify to check whether your use case passes:
- https://www.internalfb.com/code/fbsource/[4d0d74b9f3c441e7aa35ce7102200fa0ca8c95cf]/fbcode/torchrec/distributed/tests/test_dynamic_sharding.py?lines=453-459
- Basically, change the new sharding plan to use a different # of ranks than the original sharding plan.

Note: the new total number of ranks for each embedding table should be a factor of dimension 0 of that embedding table
- e.g. an emb_table of size [4, 8] can only be sharded across 1, 2, or 4 ranks, not 3.

Differential Revision: D78291717
---
 .../distributed/sharding/dynamic_sharding.py  | 390 +++++++++++++-----
 .../distributed/test_utils/test_sharding.py   |  34 +-
 .../tests/test_dynamic_sharding.py            |  28 +-
 3 files changed, 346 insertions(+), 106 deletions(-)

diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index 7f5e2c9eb..caaa9752b 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -20,25 +20,26 @@
     ShardedModule,
     ShardedTensor,
     ShardingEnv,
+    ShardingType,
 )
 
 OrderedShardNamesWithSizes = List[Tuple[str, List[int]]]
 """
-A type alias to represent an ordered shard name and the corresponding shard_size
+A type alias to represent an ordered shard name and the corresponding shard_size
 in dim 0 & 1 that were sent to the current rank.
-This is a flattened and pruned nested list, which orders the shards names and
+This is a flattened and pruned nested list, which orders the shards names and
 sizes in the following priority:
 1. Rank order
 2. Table order
 3. Shard order
 
- in below examples represent the 2d tensor correlated to a
-certain table `x`, allocated to rank `z`. The `y` here denotes the order of shards
-in the module attributes such as state_dict, sharding_plan, etc..
+ in below examples represent the 2d tensor correlated to a
+certain table `x`, allocated to rank `z`. The `y` here denotes the order of shards
+in the module attributes such as state_dict, sharding_plan, etc..
 
 `z` != `y` numerically, but the order of shards is based on the order of ranks allocated
 
-Example 1 NOTE: the ordering by rank:
+Example 1 NOTE: the ordering by rank:
 Rank 0 sends table_0, shard_0 to Rank 1.
 Rank 2 sends table_1, shard_0 to Rank 1.
Rank 2 sends table_1, shard_1 to Rank 0 @@ -60,8 +61,8 @@ Rank 0 sends table_2 to Rank 1 output_tensor = [ - , - + , + ] Example 3: NOTE: ordered by shard if table and rank are the same @@ -69,12 +70,239 @@ Rank 0 sends table_1, shard_1 to Rank 1 Rank 1: output_tensor = [ - , - + , + ] """ +def _generate_shard_allocation_metadata( + shard_name: str, + source_params: ParameterSharding, + destination_params: ParameterSharding, +) -> Dict[int, List[Tuple[int, List[int]]]]: + """ + Generates a mapping of shards to ranks for redistribution of data. + + This function creates a mapping from source ranks to destination ranks + based on the sharding specifications provided in the source and destination + parameters. It calculates the shard dimensions and allocates them to the + appropriate ranks in a greedy manner. + + Args: + shard_name (str): The name of the shard being processed. + source_params (ParameterSharding): The sharding parameters for the source. + destination_params (ParameterSharding): The sharding parameters for the destination. + + Returns: + Dict[int, List[Tuple[int, List[int]]]]: A dictionary mapping source ranks to a list of tuples, + where each tuple contains a destination rank and the corresponding shard offsets. + """ + shard_to_rank_mapping: Dict[int, List[Tuple[int, List[int]]]] = {} + src_rank_index = 0 + dst_rank_index = 0 + curr_source_offset = 0 + curr_dst_offset = 0 + + assert source_params.ranks is not None + assert destination_params.ranks is not None + + assert source_params.sharding_spec is not None + assert destination_params.sharding_spec is not None + + # Initialize dictionary keys for all source ranks + # Pyre-ignore + for rank in source_params.ranks: + shard_to_rank_mapping[rank] = [] + + # Pyre-ignore + while src_rank_index < len(source_params.ranks) and dst_rank_index < len( + destination_params.ranks # Pyre-ignore + ): + # Pyre-ignore + src_shard_size = source_params.sharding_spec.shards[src_rank_index].shard_sizes + dst_shard_size = destination_params.sharding_spec.shards[ + dst_rank_index + ].shard_sizes + + shard_dim = min( + src_shard_size[1] - curr_source_offset, dst_shard_size[1] - curr_dst_offset + ) + + next_source_offset = curr_source_offset + shard_dim + next_dst_offset = curr_dst_offset + shard_dim + + # Greedy way of allocating shards to ranks + # Pyre-ignore + shard_to_rank_mapping[source_params.ranks[src_rank_index]].append( + ( + destination_params.ranks[dst_rank_index], + [curr_source_offset, next_source_offset], + ) + ) + curr_source_offset = next_source_offset + curr_dst_offset = next_dst_offset + + if next_source_offset >= src_shard_size[1]: + src_rank_index += 1 + curr_source_offset = 0 + + if next_dst_offset >= dst_shard_size[1]: + dst_rank_index += 1 + curr_dst_offset = 0 + return shard_to_rank_mapping + + +def _process_shard_redistribution_metadata( + rank: int, + shard_name: str, + max_dim_0: int, + max_dim_1: int, + shard_to_rank_mapping: Dict[int, List[Tuple[int, List[int]]]], + sharded_tensor: ShardedTensor, + input_splits_per_rank: List[List[int]], + output_splits_per_rank: List[List[int]], + shard_names_to_lengths_by_src_rank: List[List[Tuple[str, List[int]]]], + local_table_to_input_tensor_by_dst_rank: List[List[torch.Tensor]], + local_table_to_opt_by_dst_rank: List[List[torch.Tensor]], + optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None, + extend_shard_name: Callable[[str], str] = lambda x: x, +) -> Tuple[int, int]: + """ + calculates shard redistribution metadata across ranks and processes 
optimizer state if present. + + This function handles the redistribution of tensor shards from source ranks to destination ranks + based on the provided shard-to-rank mapping. It also processes optimizer state if available, + ensuring that the data is correctly padded and split for communication between ranks. + + Args: + rank (int): The current rank of the process. + shard_name (str): The name of the shard being processed. + max_dim_0 (int): The maximum dimension size of dim 0 for padding. + max_dim_1 (int): The maximum dimension size of dim 1 for padding. + shard_to_rank_mapping (Dict[int, List[Tuple[int, List[int]]]]): Mapping of source ranks to destination ranks and split offsets. + sharded_tensor (ShardedTensor): The sharded tensor to be redistributed. + input_splits_per_rank (List[List[int]]): Input split sizes for each rank. + output_splits_per_rank (List[List[int]]): Output split sizes for each rank. + shard_names_to_lengths_by_src_rank (List[List[Tuple[str, List[int]]]]): List of shard names and sizes by source rank. + local_table_to_input_tensor_by_dst_rank (List[List[torch.Tensor]]): Local input tensors by destination rank. + local_table_to_opt_by_dst_rank (List[List[torch.Tensor]]): Local optimizer tensors by destination rank. + optimizer_state (Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]]): Optimizer state if available. + extend_shard_name (Callable[[str], str]): Function to extend shard names. + + Returns: + Tuple[int, int]: Counts of output tensors and optimizer tensors processed. + """ + + output_tensor_count = 0 + output_optimizer_count = 0 + has_optimizer = optimizer_state is not None + + # Process each shard mapping from source to destination + for src_rank, dsts in shard_to_rank_mapping.items(): + + for dst_rank, split_offsets in dsts: + + # Get shard metadata + shard_metadata = sharded_tensor.metadata().shards_metadata[0] + shard_size = shard_metadata.shard_sizes + + assert split_offsets[0] >= 0 + assert split_offsets[1] <= shard_size[1] + # Update the shard size with new size + shard_size = [shard_size[0], split_offsets[1] - split_offsets[0]] + # Update split sizes for communication + input_splits_per_rank[src_rank][dst_rank] += max_dim_0 + output_splits_per_rank[dst_rank][src_rank] += max_dim_0 + if has_optimizer: + input_splits_per_rank[src_rank][dst_rank] += max_dim_0 + output_splits_per_rank[dst_rank][src_rank] += max_dim_0 + + # Process data being sent from current rank + if src_rank == rank: + # Handle optimizer state if present + if has_optimizer and optimizer_state is not None: + + local_optimizer_shards = optimizer_state["state"][ + extend_shard_name(shard_name) + ][tmp_momentum_extender(shard_name)].local_shards() + assert ( + len(local_optimizer_shards) == 1 + ), "Expected exactly one local optimizer shard" + + local_optimizer_tensor = local_optimizer_shards[0].tensor + if len(local_optimizer_tensor.size()) == 1: # 1D Optimizer Tensor + # Convert to 2D Tensor, transpose, for AllToAll + local_optimizer_tensor = local_optimizer_tensor.view( + local_optimizer_tensor.size(0), 1 + ) + padded_optimizer_tensor = pad_tensor_to_max_dims( + local_optimizer_tensor, max_dim_0, max_dim_1 + ) + local_table_to_opt_by_dst_rank[dst_rank].append( + padded_optimizer_tensor + ) + + # Handle main tensor data + local_shards = sharded_tensor.local_shards() + assert len(local_shards) == 1, "Expected exactly one local shard" + + # cut the tensor based on split points + dst_t = local_shards[0].tensor[:, split_offsets[0] : split_offsets[1]] + + padded_tensor = 
pad_tensor_to_max_dims(dst_t, max_dim_0, max_dim_1) + local_table_to_input_tensor_by_dst_rank[dst_rank].append(padded_tensor) + + # Process data being received at current rank + if dst_rank == rank: + shard_names_to_lengths_by_src_rank[src_rank].append( + (shard_name, shard_size) + ) + output_tensor_count += max_dim_0 + if has_optimizer: + output_optimizer_count += max_dim_0 + + return output_tensor_count, output_optimizer_count + + +def _create_local_shard_tensors( + ordered_shard_names_and_lengths: OrderedShardNamesWithSizes, + output_tensor: torch.Tensor, + max_dim_0: int, +) -> Dict[str, torch.Tensor]: + """ + Creates local shard tensors from the output tensor based on the ordered shard names and lengths. + + This function slices the output tensor into smaller tensors (shards) according to the specified + dimensions in `ordered_shard_names_and_lengths`. It pads each shard to the maximum dimensions + and concatenates them if multiple shards exist for the same shard name. + + Args: + ordered_shard_names_and_lengths (OrderedShardNamesWithSizes): A list of tuples containing shard names + and their corresponding sizes. + output_tensor (torch.Tensor): The tensor containing all shards received by the current rank. + max_dim_0 (int): The maximum dimension size of dim 0 for slicing the output tensor. + + Returns: + Dict[str, torch.Tensor]: A dictionary mapping shard names to their corresponding local output tensors. + """ + slice_index = 0 + shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} + for shard_name, shard_size in ordered_shard_names_and_lengths: + end_slice_index = slice_index + max_dim_0 + cur_t = output_tensor[slice_index:end_slice_index] + cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1]) + if shard_name not in shard_name_to_local_output_tensor.keys(): + shard_name_to_local_output_tensor[shard_name] = cur_t + else: + # CW sharding may have multiple shards per rank in many to one case, so we need to concatenate them + shard_name_to_local_output_tensor[shard_name] = torch.cat( + (shard_name_to_local_output_tensor[shard_name], cur_t), dim=1 + ) + slice_index = end_slice_index + return shard_name_to_local_output_tensor + + def shards_all_to_all( module: ShardedModule[Any, Any, Any, Any], # pyre-ignore state_dict: Dict[str, ShardedTensor], @@ -134,71 +362,49 @@ def shards_all_to_all( shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)] local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)] local_table_to_opt_by_dst_rank = [[] for _ in range(world_size)] + for shard_name, param in changed_sharding_params.items(): sharded_t = state_dict[extend_shard_name(shard_name)] assert param.ranks is not None - dst_ranks = param.ranks - # pyre-ignore - src_ranks = module.module_sharding_plan[shard_name].ranks - - # TODO: Implement changing rank sizes for beyond TW sharding - assert len(dst_ranks) == len(src_ranks) - - # index needed to distinguish between multiple shards - # within the same shardedTensor for each table - for i in range(len(src_ranks)): - # 1 to 1 mapping from src to dst - dst_rank = dst_ranks[i] - src_rank = src_ranks[i] - - shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes - input_splits_per_rank[src_rank][dst_rank] += max_dim_0 - output_splits_per_rank[dst_rank][src_rank] += max_dim_0 - if has_optimizer: - input_splits_per_rank[src_rank][dst_rank] += max_dim_0 - output_splits_per_rank[dst_rank][src_rank] += max_dim_0 - - # If sending from current rank - if src_rank == rank: - if has_optimizer: - # 
pyre-ignore - local_optimizer = optimizer_state["state"][ - extend_shard_name(shard_name) - ][tmp_momentum_extender(shard_name)].local_shards() - assert len(local_optimizer) == 1 - local_optimizer_tensor = local_optimizer[0].tensor - if len(local_optimizer_tensor.size()) == 1: # 1D Optimizer Tensor - # Convert to 2D Tensor, transpose, for AllToAll - local_optimizer_tensor = local_optimizer_tensor.view( - local_optimizer_tensor.size(0), 1 - ) - padded_local_optimizer = pad_tensor_to_max_dims( - local_optimizer_tensor, max_dim_0, max_dim_1 - ) - local_table_to_opt_by_dst_rank[dst_rank].append( - padded_local_optimizer - ) - local_shards = sharded_t.local_shards() - assert len(local_shards) == 1 - cur_t = pad_tensor_to_max_dims( - local_shards[0].tensor, max_dim_0, max_dim_1 - ) - local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t) - - # If recieving from current rank - if dst_rank == rank: - shard_names_to_lengths_by_src_rank[src_rank].append( - (shard_name, shard_size) - ) - output_tensor_tensor_count += max_dim_0 - if has_optimizer: - output_optimizer_tensor_count += max_dim_0 + # pyre-ignore + src_params = module.module_sharding_plan[shard_name] + + assert ( + param.sharding_type == ShardingType.COLUMN_WISE.value + or param.sharding_type == ShardingType.TABLE_WISE.value + ) + + shard_mapping = _generate_shard_allocation_metadata( + shard_name=shard_name, + source_params=src_params, + destination_params=param, + ) + + tensor_count, optimizer_count = _process_shard_redistribution_metadata( + rank=rank, + shard_name=shard_name, + max_dim_0=max_dim_0, + max_dim_1=max_dim_1, + shard_to_rank_mapping=shard_mapping, + sharded_tensor=sharded_t, + input_splits_per_rank=input_splits_per_rank, + output_splits_per_rank=output_splits_per_rank, + shard_names_to_lengths_by_src_rank=shard_names_to_lengths_by_src_rank, + local_table_to_input_tensor_by_dst_rank=local_table_to_input_tensor_by_dst_rank, + local_table_to_opt_by_dst_rank=local_table_to_opt_by_dst_rank, + optimizer_state=optimizer_state, + extend_shard_name=extend_shard_name, + ) + + output_tensor_tensor_count += tensor_count + if has_optimizer: + output_optimizer_tensor_count += optimizer_count local_input_splits = input_splits_per_rank[rank] local_output_splits = output_splits_per_rank[rank] - local_input_tensor = torch.empty([0], device=device) + local_input_tensor = torch.empty([0, max_dim_1], device=device) for sub_l in local_table_to_input_tensor_by_dst_rank: for shard_info in sub_l: local_input_tensor = torch.cat( @@ -219,10 +425,11 @@ def shards_all_to_all( dim=0, ) + receive_count = output_tensor_tensor_count + output_optimizer_tensor_count max_embedding_size = max_dim_1 local_output_tensor = torch.empty( [ - output_tensor_tensor_count + output_optimizer_tensor_count, + receive_count, max_embedding_size, ], device=device, @@ -230,6 +437,7 @@ def shards_all_to_all( assert sum(local_output_splits) == len(local_output_tensor) assert sum(local_input_splits) == len(local_input_tensor) + dist.all_to_all_single( output=local_output_tensor, input=local_input_tensor, @@ -283,16 +491,12 @@ def update_state_dict_post_resharding( Returns: Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards. 
""" - slice_index = 0 - shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} - - for shard_name, shard_size in ordered_shard_names_and_lengths: - end_slice_index = slice_index + max_dim_0 - cur_t = output_tensor[slice_index:end_slice_index] - cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1]) - shard_name_to_local_output_tensor[shard_name] = cur_t - slice_index = end_slice_index + shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = ( + _create_local_shard_tensors( + ordered_shard_names_and_lengths, output_tensor, max_dim_0 + ) + ) for shard_name, param in new_sharding_params.items(): extended_name = extend_shard_name(shard_name) @@ -302,18 +506,22 @@ def update_state_dict_post_resharding( r = param.ranks[i] sharded_t = state_dict[extended_name] # Update placements - sharded_t.metadata().shards_metadata[i].placement = ( - torch.distributed._remote_device(f"rank:{r}/cuda:{r}") - ) + + if len(sharded_t.metadata().shards_metadata) > i: + # pyre-ignore + sharded_t.metadata().shards_metadata[i] = param.sharding_spec.shards[i] + else: + sharded_t.metadata().shards_metadata.append( + param.sharding_spec.shards[i] + ) + # Update local shards if r == curr_rank: assert len(output_tensor) > 0 # slice output tensor for correct size. sharded_t._local_shards = [ Shard( tensor=shard_name_to_local_output_tensor[shard_name], - metadata=state_dict[extended_name] - .metadata() - .shards_metadata[i], + metadata=param.sharding_spec.shards[i], ) ] break @@ -334,14 +542,12 @@ def update_optimizer_state_post_resharding( old_opt_state_state = old_opt_state["state"] # Remove padding and store tensors by shard name - slice_index = 0 - shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} - for shard_name, shard_size in ordered_shard_names_and_lengths: - end_slice_index = slice_index + max_dim_0 - cur_t = output_tensor[slice_index:end_slice_index] - cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1]) - shard_name_to_local_output_tensor[shard_name] = cur_t - slice_index = end_slice_index + + shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = ( + _create_local_shard_tensors( + ordered_shard_names_and_lengths, output_tensor, max_dim_0 + ) + ) for extended_shard_name, item in new_opt_state_state.items(): if extended_shard_name in old_opt_state_state: @@ -353,7 +559,7 @@ def update_optimizer_state_post_resharding( momentum_name = tmp_momentum_extender(shard_name) sharded_t = item[momentum_name] assert len(sharded_t._local_shards) == 1 - # TODO: support multiple shards in CW sharding + # local_tensor is updated in-pace for CW sharding local_tensor = shard_name_to_local_output_tensor[shard_name] if len(sharded_t._local_shards[0].tensor.size()) == 1: # Need to transpose 1D optimizer tensor, due to previous conversion diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index b78159f37..9d55360f4 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -364,6 +364,7 @@ def dynamic_sharding_test( random.randint(0, batch_size) if allow_zero_batch_size else batch_size ) num_steps = 2 + embedding_dim = 16 # Generate model & inputs. 
(global_model, inputs) = gen_model_and_input( model_class=model_class, @@ -381,7 +382,7 @@ def dynamic_sharding_test( # weighted_tables=weighted_tables, embedding_groups=embedding_groups, world_size=world_size, - num_float_features=16, + num_float_features=embedding_dim, variable_batch_size=variable_batch_size, batch_size=batch_size, feature_processor_modules=feature_processor_modules, @@ -408,7 +409,7 @@ def dynamic_sharding_test( embedding_groups=embedding_groups, dense_device=ctx.device, sparse_device=torch.device("meta"), - num_float_features=16, + num_float_features=embedding_dim, feature_processor_modules=feature_processor_modules, ) @@ -419,7 +420,7 @@ def dynamic_sharding_test( embedding_groups=embedding_groups, dense_device=ctx.device, sparse_device=torch.device("meta"), - num_float_features=16, + num_float_features=embedding_dim, feature_processor_modules=feature_processor_modules, ) @@ -492,11 +493,25 @@ def dynamic_sharding_test( assert ctx.pg is not None num_tables = len(tables) + ranks_per_tables = [1 for _ in range(num_tables)] + + # CW sharding + valid_candidates = [ + i for i in range(1, world_size + 1) if embedding_dim % i == 0 + ] + ranks_per_tables_for_CW = [ + random.choice(valid_candidates) for _ in range(num_tables) + ] + new_ranks = generate_rank_placements( world_size, num_tables, ranks_per_tables, random_seed ) + new_ranks_cw = generate_rank_placements( + world_size, num_tables, ranks_per_tables_for_CW, random_seed + ) + new_per_param_sharding = {} assert len(sharders) == 1 @@ -514,10 +529,15 @@ def dynamic_sharding_test( sharding_type_constructor = get_sharding_constructor_from_type( sharding_type ) - # TODO: CW sharding constructor takes in different args - new_per_param_sharding[table_name] = sharding_type_constructor( - rank=new_ranks[i][0], compute_kernel=kernel_type - ) + + if sharding_type == ShardingType.TABLE_WISE: + new_per_param_sharding[table_name] = sharding_type_constructor( + rank=new_ranks[i][0], compute_kernel=kernel_type + ) + elif sharding_type == ShardingType.COLUMN_WISE: + new_per_param_sharding[table_name] = sharding_type_constructor( + ranks=new_ranks_cw[i] + ) new_module_sharding_plan = construct_module_sharding_plan( local_m2.sparse.ebc, diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index d1bb8526b..eddf15faa 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -457,6 +457,7 @@ def test_dynamic_sharding_ebc_tw( num_tables=st.sampled_from([2, 3, 4]), data_type=st.sampled_from([DataType.FP32, DataType.FP16]), world_size=st.sampled_from([3, 4]), + embedding_dim=st.sampled_from([16]), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_dynamic_sharding_ebc_cw( @@ -464,16 +465,21 @@ def test_dynamic_sharding_ebc_cw( num_tables: int, data_type: DataType, world_size: int, + embedding_dim: int, ) -> None: # Tests EBC dynamic sharding implementation for CW # Force the ranks per table to be consistent - ranks_per_tables = [ - random.randint(1, world_size - 1) for _ in range(num_tables) - ] + valid_ranks = [i for i in range(1, world_size) if embedding_dim % i == 0] + ranks_per_tables = [random.choice(valid_ranks) for _ in range(num_tables)] + new_ranks_per_tables = [random.choice(valid_ranks) for _ in range(num_tables)] + + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size old_ranks = generate_rank_placements(world_size, num_tables, 
ranks_per_tables) - new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + new_ranks = generate_rank_placements( + world_size, num_tables, new_ranks_per_tables + ) # Cannot include old/new rank generation with hypothesis library due to depedency on world_size while new_ranks == old_ranks: @@ -497,6 +503,7 @@ def test_dynamic_sharding_ebc_cw( num_tables, world_size, data_type, + embedding_dim, ) @@ -513,6 +520,12 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ] + ), kernel_type=st.sampled_from( [ EmbeddingComputeKernel.DENSE.value, @@ -554,9 +567,10 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): random_seed=st.integers(0, 1000), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) - def test_sharding_tw( + def test_sharding( self, sharder_type: str, + sharding_type: str, kernel_type: str, qcomms_config: Optional[QCommsConfig], apply_optimizer_in_backward_config: Optional[ @@ -575,12 +589,12 @@ def test_sharding_tw( ): self.skipTest("CPU does not support uvm.") - sharding_type = ShardingType.TABLE_WISE.value assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size ) + sharding_type_e = ShardingType(sharding_type) self._test_dynamic_sharding( # pyre-ignore[6] sharders=[ @@ -597,7 +611,7 @@ def test_sharding_tw( apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, data_type=data_type, - sharding_type=ShardingType.TABLE_WISE, + sharding_type=sharding_type_e, random_seed=random_seed, )
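
For context, below is a minimal sketch of the kind of plan change this patch is meant to support: the same table resharded column-wise onto a different number of ranks, modeled on the updated test utilities above. The table name, sizes, world size, and the use of construct_module_sharding_plan / column_wise here are illustrative assumptions, not code taken from this patch.

# Illustrative sketch only -- not part of this patch. It builds an "old" and a
# "new" column-wise plan with a different # of shards, which is the scenario
# the Reshard API is expected to handle after this change. Table name, sizes,
# and world size are made-up values.
import torch

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

world_size = 4
embedding_dim = 16  # the new CW shard count must divide this evenly

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            feature_names=["feature_0"],
            num_embeddings=4,
            embedding_dim=embedding_dim,
        )
    ],
    device=torch.device("meta"),
)

# Original plan: table_0 column-wise sharded across 2 ranks.
old_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": column_wise(ranks=[0, 1])},
    local_size=world_size,
    world_size=world_size,
    device_type="cuda",
)

# New plan: the same table resharded onto 4 ranks (a 1-rank plan such as
# column_wise(ranks=[0]) would be built the same way).
new_num_shards = 4
assert embedding_dim % new_num_shards == 0, "shard count must divide embedding_dim"
new_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": column_wise(ranks=list(range(new_num_shards)))},
    local_size=world_size,
    world_size=world_size,
    device_type="cuda",
)

# new_plan plays the role of the "new sharding plan" that
# test_dynamic_sharding_ebc_cw feeds into the resharding path exercised above.

The divisibility assert mirrors the valid_candidates / valid_ranks checks added to test_sharding.py and test_dynamic_sharding.py: for column-wise resharding, the new shard count has to evenly divide the table's embedding dimension.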