|
19 | 19 | from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
|
20 | 20 | PartiallyMaterializedTensor,
|
21 | 21 | )
|
22 |
| -from hypothesis import given, settings, strategies as st, Verbosity |
| 22 | +from hypothesis import assume, given, settings, strategies as st, Verbosity |
23 | 23 | from torch import distributed as dist
|
24 | 24 | from torch.distributed._shard.sharded_tensor import ShardedTensor
|
25 | 25 | from torch.distributed._tensor import DTensor
|
@@ -624,11 +624,10 @@ def test_load_state_dict(
|
624 | 624 | kernel_type: str,
|
625 | 625 | is_training: bool,
|
626 | 626 | ) -> None:
|
627 |
| - if ( |
628 |
| - self.device == torch.device("cpu") |
629 |
| - and kernel_type != EmbeddingComputeKernel.FUSED.value |
630 |
| - ): |
631 |
| - self.skipTest("CPU does not support uvm.") |
| 627 | + assume( |
| 628 | + self.device != torch.device("cpu") |
| 629 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 630 | + ) |
632 | 631 |
|
633 | 632 | sharders = [
|
634 | 633 | cast(
|
@@ -683,11 +682,10 @@ def test_optimizer_load_state_dict(
|
683 | 682 | sharding_type: str,
|
684 | 683 | kernel_type: str,
|
685 | 684 | ) -> None:
|
686 |
| - if ( |
687 |
| - self.device == torch.device("cpu") |
688 |
| - and kernel_type != EmbeddingComputeKernel.FUSED.value |
689 |
| - ): |
690 |
| - self.skipTest("CPU does not support uvm.") |
| 685 | + assume( |
| 686 | + self.device != torch.device("cpu") |
| 687 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 688 | + ) |
691 | 689 |
|
692 | 690 | sharders = [
|
693 | 691 | cast(
|
@@ -800,11 +798,10 @@ def test_load_state_dict_dp(
|
800 | 798 | def test_load_state_dict_prefix(
|
801 | 799 | self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
|
802 | 800 | ) -> None:
|
803 |
| - if ( |
804 |
| - self.device == torch.device("cpu") |
805 |
| - and kernel_type != EmbeddingComputeKernel.FUSED.value |
806 |
| - ): |
807 |
| - self.skipTest("CPU does not support uvm.") |
| 801 | + assume( |
| 802 | + self.device != torch.device("cpu") |
| 803 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 804 | + ) |
808 | 805 |
|
809 | 806 | sharders = [
|
810 | 807 | cast(
|
@@ -855,11 +852,10 @@ def test_load_state_dict_prefix(
|
855 | 852 | def test_params_and_buffers(
|
856 | 853 | self, sharder_type: str, sharding_type: str, kernel_type: str
|
857 | 854 | ) -> None:
|
858 |
| - if ( |
859 |
| - self.device == torch.device("cpu") |
860 |
| - and kernel_type != EmbeddingComputeKernel.FUSED.value |
861 |
| - ): |
862 |
| - self.skipTest("CPU does not support uvm.") |
| 855 | + assume( |
| 856 | + self.device != torch.device("cpu") |
| 857 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 858 | + ) |
863 | 859 |
|
864 | 860 | sharders = [
|
865 | 861 | create_test_sharder(sharder_type, sharding_type, kernel_type),
|
@@ -897,11 +893,10 @@ def test_params_and_buffers(
|
897 | 893 | def test_load_state_dict_cw_multiple_shards(
|
898 | 894 | self, sharder_type: str, sharding_type: str, kernel_type: str, is_training: bool
|
899 | 895 | ) -> None:
|
900 |
| - if ( |
901 |
| - self.device == torch.device("cpu") |
902 |
| - and kernel_type != EmbeddingComputeKernel.FUSED.value |
903 |
| - ): |
904 |
| - self.skipTest("CPU does not support uvm.") |
| 896 | + assume( |
| 897 | + self.device != torch.device("cpu") |
| 898 | + or kernel_type == EmbeddingComputeKernel.FUSED.value |
| 899 | + ) |
905 | 900 |
|
906 | 901 | sharders = [
|
907 | 902 | cast(
|
|
0 commit comments