From 754a68936cf0011eeb01a217e7d5860f4e9aa740 Mon Sep 17 00:00:00 2001 From: Chelsea Zhou Date: Mon, 14 Jul 2025 15:20:28 -0700 Subject: [PATCH] Temporary Commit at 7/14/2025, 3:15:40 PM Differential Revision: D78303516 --- torchrec/distributed/embeddingbag.py | 30 ++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 6a9a3d299..ebcf61f90 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -293,9 +293,18 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, @@ -692,9 +701,18 @@ def create_grouped_sharding_infos( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param,