
Commit fbddf02

[tests] properly skip tests instead of return (#11771)
model test updates
1 parent f20b83a commit fbddf02
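
The change replaces bare `return` statements in test guard clauses with `pytest.skip(...)`. A bare `return` makes a test for an unsupported feature count as passed, so the coverage gap is invisible; `pytest.skip` reports the test as skipped, with a reason, in the test summary. A minimal sketch of the difference, using a hypothetical stub model class rather than anything from this diff:

    import pytest

    class _StubModel:
        # Hypothetical stand-in for a diffusers model class.
        _supports_gradient_checkpointing = False

    class TestSkipBehavior:
        model_class = _StubModel

        def test_with_return(self):
            if not self.model_class._supports_gradient_checkpointing:
                return  # reported as PASSED even though nothing was exercised

        def test_with_skip(self):
            if not self.model_class._supports_gradient_checkpointing:
                pytest.skip("Gradient checkpointing is not supported.")  # reported as SKIPPED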


tests/models/test_modeling_common.py

Lines changed: 31 additions & 26 deletions
@@ -30,6 +30,7 @@
 from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
+import pytest
 import requests_mock
 import safetensors.torch
 import torch
@@ -938,8 +939,9 @@ def recursive_check(tuple_object, dict_object):

     @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         init_dict, _ = self.prepare_init_args_and_inputs_for_common()

@@ -957,8 +959,9 @@ def test_enable_disable_gradient_checkpointing(self):

     @require_torch_accelerator_with_training
     def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1015,8 +1018,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
     def test_gradient_checkpointing_is_applied(
         self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
     ):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -1073,7 +1077,7 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         torch.manual_seed(0)
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
@@ -1128,7 +1132,7 @@ def test_lora_wrong_adapter_name_raises_error(self):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1159,7 +1163,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_d
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=rank,
@@ -1196,7 +1200,7 @@ def test_lora_adapter_wrong_metadata_raises_error(self):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1233,10 +1237,10 @@ def test_lora_adapter_wrong_metadata_raises_error(self):

     @require_torch_accelerator
     def test_cpu_offload(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1263,10 +1267,10 @@ def test_cpu_offload(self):

     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1296,10 +1300,10 @@ def test_disk_offload_without_safetensors(self):

     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1324,10 +1328,10 @@ def test_disk_offload_with_safetensors(self):

     @require_torch_multi_accelerator
     def test_model_parallelism(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1426,10 +1430,10 @@ def test_sharded_checkpoints_with_variant(self):

     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)

         torch.manual_seed(0)
@@ -1497,7 +1501,7 @@ def test_variant_sharded_ckpt_right_format(self):
     def test_layerwise_casting_training(self):
         def test_fn(storage_dtype, compute_dtype):
             if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
-                return
+                pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
             init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

             model = self.model_class(**init_dict)
@@ -1617,6 +1621,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
     @parameterized.expand([False, True])
     @require_torch_accelerator
     def test_group_offloading(self, record_stream):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)

@@ -1633,8 +1640,6 @@ def run_forward(model):
             return model(**inputs_dict)[0]

         model = self.model_class(**init_dict)
-        if not getattr(model, "_supports_group_offloading", True):
-            return

         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
@@ -1670,13 +1675,13 @@ def run_forward(model):
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)

-        if not getattr(model, "_supports_group_offloading", True):
-            return
-
         model.to(torch_device)
         model.eval()
         _ = model(**inputs_dict)[0]
@@ -1698,13 +1703,13 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_disk(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)

-        if not getattr(model, "_supports_group_offloading", True):
-            return
-
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
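
Two details of the diff are worth noting. First, the guards now read class-level attributes (`self.model_class._supports_gradient_checkpointing`, `self.model_class._no_split_modules`, `self.model_class._supports_group_offloading`) and run before the model is instantiated, so unsupported models skip without the cost of constructing a model. Second, the old group-offloading guard, `getattr(model, "_supports_group_offloading", True)`, silently defaulted to True when the attribute was missing; the new guard reads the attribute directly. The skip reasons then show up in the pytest summary; assuming a standard pytest setup, something like:

    # -r s adds a summary line with the reason for each skipped test
    pytest tests/models/test_modeling_common.py -rs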
