Skip to content

Commit 0d6ef90

Browse files
jeffkbkim authored and facebook-github-bot committed
Fix skipped tests for test_model_parallel_gloo (#3196)
Summary: Pull Request resolved: #3196 The `ModelParallelStateDictTestGloo: test_optimizer_load_state_dict` test is frequently reported as skipped because some of the examples generated by the framework hit the skipTest() condition, which is: - Using CPU with UVM kernel modes (FUSED_UVM, FUSED_UVM_CACHING). While iterating through the generated examples, the test will mark the entire test as "skipped" if any single example hits the skipTest condition. Instead, we should skip only that example (via hypothesis `assume()`) so that hypothesis can generate the next valid example. Reviewed By: jd7-tr Differential Revision: D78355780 fbshipit-source-id: 8ad17b3953e3b1cb2bffcf4e25d6e0537410b66c
1 parent a65363e commit 0d6ef90

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
2020
PartiallyMaterializedTensor,
2121
)
22-
from hypothesis import given, settings, strategies as st, Verbosity
22+
from hypothesis import assume, given, settings, strategies as st, Verbosity
2323
from torch import distributed as dist
2424
from torch.distributed._shard.sharded_tensor import ShardedTensor
2525
from torch.distributed._tensor import DTensor
@@ -624,11 +624,10 @@ def test_load_state_dict(
624624
kernel_type: str,
625625
is_training: bool,
626626
) -> None:
627-
if (
628-
self.device == torch.device("cpu")
629-
and kernel_type != EmbeddingComputeKernel.FUSED.value
630-
):
631-
self.skipTest("CPU does not support uvm.")
627+
assume(
628+
self.device != torch.device("cpu")
629+
or kernel_type == EmbeddingComputeKernel.FUSED.value
630+
)
632631

633632
sharders = [
634633
cast(
@@ -683,11 +682,10 @@ def test_optimizer_load_state_dict(
683682
sharding_type: str,
684683
kernel_type: str,
685684
) -> None:
686-
if (
687-
self.device == torch.device("cpu")
688-
and kernel_type != EmbeddingComputeKernel.FUSED.value
689-
):
690-
self.skipTest("CPU does not support uvm.")
685+
assume(
686+
self.device != torch.device("cpu")
687+
or kernel_type == EmbeddingComputeKernel.FUSED.value
688+
)
691689

692690
sharders = [
693691
cast(
@@ -800,11 +798,10 @@ def test_load_state_dict_dp(
800798
def test_load_state_dict_prefix(
801799
self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
802800
) -> None:
803-
if (
804-
self.device == torch.device("cpu")
805-
and kernel_type != EmbeddingComputeKernel.FUSED.value
806-
):
807-
self.skipTest("CPU does not support uvm.")
801+
assume(
802+
self.device != torch.device("cpu")
803+
or kernel_type == EmbeddingComputeKernel.FUSED.value
804+
)
808805

809806
sharders = [
810807
cast(
@@ -855,11 +852,10 @@ def test_load_state_dict_prefix(
855852
def test_params_and_buffers(
856853
self, sharder_type: str, sharding_type: str, kernel_type: str
857854
) -> None:
858-
if (
859-
self.device == torch.device("cpu")
860-
and kernel_type != EmbeddingComputeKernel.FUSED.value
861-
):
862-
self.skipTest("CPU does not support uvm.")
855+
assume(
856+
self.device != torch.device("cpu")
857+
or kernel_type == EmbeddingComputeKernel.FUSED.value
858+
)
863859

864860
sharders = [
865861
create_test_sharder(sharder_type, sharding_type, kernel_type),
@@ -897,11 +893,10 @@ def test_params_and_buffers(
897893
def test_load_state_dict_cw_multiple_shards(
898894
self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
899895
) -> None:
900-
if (
901-
self.device == torch.device("cpu")
902-
and kernel_type != EmbeddingComputeKernel.FUSED.value
903-
):
904-
self.skipTest("CPU does not support uvm.")
896+
assume(
897+
self.device != torch.device("cpu")
898+
or kernel_type == EmbeddingComputeKernel.FUSED.value
899+
)
905900

906901
sharders = [
907902
cast(

0 commit comments

Comments (0)