From 9802f8a1347a74567ffa8b4ddd63642bc503b1ea Mon Sep 17 00:00:00 2001 From: James Dong Date: Fri, 18 Jul 2025 13:38:22 -0700 Subject: [PATCH] Fix skipped tests in MultiRankDMPDynamicShardingTest Summary: Fix skipped tests. See more context in D78355780. https://www.internalfb.com/intern/test/562950182314458?ref_report_id=0 Differential Revision: D78583353 --- torchrec/distributed/tests/test_dynamic_sharding.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index eddf15faa..f421ae1bb 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -583,11 +583,10 @@ def test_sharding( """ Tests resharding from DMP module interface, rather than EBC level. """ - if ( - self.device == torch.device("cpu") - and kernel_type != EmbeddingComputeKernel.FUSED.value - ): - self.skipTest("CPU does not support uvm.") + assume( + self.device != torch.device("cpu") + or kernel_type == EmbeddingComputeKernel.FUSED.value + ) assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value