Skip to content

Commit aaa1306

Browse files
chelqxzfacebook-github-bot
authored andcommitted
Fix compatibility issue with EmbeddingBag for IEN Publish (#3181)
Summary: Discovered when working on fix a ien e2e test https://fb.workplace.com/groups/gpuinference/permalink/3213711332110846/ This is not ideal fix, new field will introduce incompatibility again in the future. Following up in post for better solution Reviewed By: hannaxu, faran928 Differential Revision: D78125516
1 parent d95f247 commit aaa1306

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,18 @@ def create_sharding_infos_by_sharding_device_group(
293293
getattr(config, "num_embeddings_post_pruning", None)
294294
# TODO: Need to check if attribute exists for BC
295295
),
296-
total_num_buckets=config.total_num_buckets,
297-
use_virtual_table=config.use_virtual_table,
298-
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
296+
total_num_buckets=(
297+
getattr(config, "total_num_buckets", None)
298+
# TODO: Need to check if attribute exists for BC
299+
),
300+
use_virtual_table=(
301+
getattr(config, "use_virtual_table", None)
302+
# TODO: Need to check if attribute exists for BC
303+
),
304+
virtual_table_eviction_policy=(
305+
getattr(config, "virtual_table_eviction_policy", None)
306+
# TODO: Need to check if attribute exists for BC
307+
),
299308
),
300309
param_sharding=parameter_sharding,
301310
param=param,
@@ -692,9 +701,18 @@ def create_grouped_sharding_infos(
692701
getattr(config, "num_embeddings_post_pruning", None)
693702
# TODO: Need to check if attribute exists for BC
694703
),
695-
total_num_buckets=config.total_num_buckets,
696-
use_virtual_table=config.use_virtual_table,
697-
virtual_table_eviction_policy=config.virtual_table_eviction_policy,
704+
total_num_buckets=(
705+
getattr(config, "total_num_buckets", None)
706+
# TODO: Need to check if attribute exists for BC
707+
),
708+
use_virtual_table=(
709+
getattr(config, "use_virtual_table", None)
710+
# TODO: Need to check if attribute exists for BC
711+
),
712+
virtual_table_eviction_policy=(
713+
getattr(config, "virtual_table_eviction_policy", None)
714+
# TODO: Need to check if attribute exists for BC
715+
),
698716
),
699717
param_sharding=parameter_sharding,
700718
param=param,

0 commit comments

Comments
 (0)