From aaa13065e8cfb626b5f59b6f958032ef4d1cd131 Mon Sep 17 00:00:00 2001 From: Chelsea Zhou Date: Fri, 11 Jul 2025 16:39:24 -0700 Subject: [PATCH] Fix compatibility issue with EmbeddingBag for IEN Publish (#3181) Summary: Discovered when working on fix a ien e2e test https://fb.workplace.com/groups/gpuinference/permalink/3213711332110846/ This is not ideal fix, new field will introduce incompatibility again in the future. Following up in post for better solution Reviewed By: hannaxu, faran928 Differential Revision: D78125516 --- torchrec/distributed/embeddingbag.py | 30 ++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 45553b3e6..f2ed1de1e 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -293,9 +293,18 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, @@ -692,9 +701,18 @@ def create_grouped_sharding_infos( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param,