diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 6a9a3d299..ebcf61f90 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -293,9 +293,18 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, @@ -692,9 +701,18 @@ def create_grouped_sharding_infos( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), - total_num_buckets=config.total_num_buckets, - use_virtual_table=config.use_virtual_table, - virtual_table_eviction_policy=config.virtual_table_eviction_policy, + total_num_buckets=( + getattr(config, "total_num_buckets", None) + # TODO: Need to check if attribute exists for BC + ), + use_virtual_table=( + getattr(config, "use_virtual_table", None) + # TODO: Need to check if attribute exists for BC + ), + virtual_table_eviction_policy=( + getattr(config, "virtual_table_eviction_policy", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param,