from typing import Dict, List, Optional, Tuple, Union

import numpy as np
+ import pytest
import requests_mock
import safetensors.torch
import torch
@@ -938,8 +939,9 @@ def recursive_check(tuple_object, dict_object):

    @require_torch_accelerator_with_training
    def test_enable_disable_gradient_checkpointing(self):
+       # Skip test if model does not support gradient checkpointing
        if not self.model_class._supports_gradient_checkpointing:
-           return  # Skip test if model does not support gradient checkpointing
+           pytest.skip("Gradient checkpointing is not supported.")

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -957,8 +959,9 @@ def test_enable_disable_gradient_checkpointing(self):

    @require_torch_accelerator_with_training
    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+       # Skip test if model does not support gradient checkpointing
        if not self.model_class._supports_gradient_checkpointing:
-           return  # Skip test if model does not support gradient checkpointing
+           pytest.skip("Gradient checkpointing is not supported.")

        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1015,8 +1018,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
    def test_gradient_checkpointing_is_applied(
        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
    ):
+       # Skip test if model does not support gradient checkpointing
        if not self.model_class._supports_gradient_checkpointing:
-           return  # Skip test if model does not support gradient checkpointing
+           pytest.skip("Gradient checkpointing is not supported.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1073,7 +1077,7 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
        model = self.model_class(**init_dict).to(torch_device)

        if not issubclass(model.__class__, PeftAdapterMixin):
-           return
+           pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

        torch.manual_seed(0)
        output_no_lora = model(**inputs_dict, return_dict=False)[0]
@@ -1128,7 +1132,7 @@ def test_lora_wrong_adapter_name_raises_error(self):
        model = self.model_class(**init_dict).to(torch_device)

        if not issubclass(model.__class__, PeftAdapterMixin):
-           return
+           pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

        denoiser_lora_config = LoraConfig(
            r=4,
@@ -1159,7 +1163,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_d
        model = self.model_class(**init_dict).to(torch_device)

        if not issubclass(model.__class__, PeftAdapterMixin):
-           return
+           pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

        denoiser_lora_config = LoraConfig(
            r=rank,
@@ -1196,7 +1200,7 @@ def test_lora_adapter_wrong_metadata_raises_error(self):
        model = self.model_class(**init_dict).to(torch_device)

        if not issubclass(model.__class__, PeftAdapterMixin):
-           return
+           pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

        denoiser_lora_config = LoraConfig(
            r=4,
@@ -1233,10 +1237,10 @@ def test_lora_adapter_wrong_metadata_raises_error(self):

    @require_torch_accelerator
    def test_cpu_offload(self):
+       if self.model_class._no_split_modules is None:
+           pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
-       if model._no_split_modules is None:
-           return

        model = model.to(torch_device)
@@ -1263,10 +1267,10 @@ def test_cpu_offload(self):

    @require_torch_accelerator
    def test_disk_offload_without_safetensors(self):
+       if self.model_class._no_split_modules is None:
+           pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
-       if model._no_split_modules is None:
-           return

        model = model.to(torch_device)
@@ -1296,10 +1300,10 @@ def test_disk_offload_without_safetensors(self):

    @require_torch_accelerator
    def test_disk_offload_with_safetensors(self):
+       if self.model_class._no_split_modules is None:
+           pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
-       if model._no_split_modules is None:
-           return

        model = model.to(torch_device)
@@ -1324,10 +1328,10 @@ def test_disk_offload_with_safetensors(self):

    @require_torch_multi_accelerator
    def test_model_parallelism(self):
+       if self.model_class._no_split_modules is None:
+           pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
-       if model._no_split_modules is None:
-           return

        model = model.to(torch_device)
@@ -1426,10 +1430,10 @@ def test_sharded_checkpoints_with_variant(self):

    @require_torch_accelerator
    def test_sharded_checkpoints_device_map(self):
+       if self.model_class._no_split_modules is None:
+           pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
-       if model._no_split_modules is None:
-           return
        model = model.to(torch_device)

        torch.manual_seed(0)
@@ -1497,7 +1501,7 @@ def test_variant_sharded_ckpt_right_format(self):
    def test_layerwise_casting_training(self):
        def test_fn(storage_dtype, compute_dtype):
            if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
-               return
+               pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

            model = self.model_class(**init_dict)
@@ -1617,6 +1621,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
    @parameterized.expand([False, True])
    @require_torch_accelerator
    def test_group_offloading(self, record_stream):
+       if not self.model_class._supports_group_offloading:
+           pytest.skip("Model does not support group offloading.")
+
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        torch.manual_seed(0)
@@ -1633,8 +1640,6 @@ def run_forward(model):
            return model(**inputs_dict)[0]

        model = self.model_class(**init_dict)
-       if not getattr(model, "_supports_group_offloading", True):
-           return

        model.to(torch_device)
        output_without_group_offloading = run_forward(model)
@@ -1670,13 +1675,13 @@ def run_forward(model):
    @require_torch_accelerator
    @torch.no_grad()
    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+       if not self.model_class._supports_group_offloading:
+           pytest.skip("Model does not support group offloading.")
+
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

-       if not getattr(model, "_supports_group_offloading", True):
-           return
-
        model.to(torch_device)
        model.eval()
        _ = model(**inputs_dict)[0]
@@ -1698,13 +1703,13 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
    @require_torch_accelerator
    @torch.no_grad()
    def test_group_offloading_with_disk(self, record_stream, offload_type):
+       if not self.model_class._supports_group_offloading:
+           pytest.skip("Model does not support group offloading.")
+
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

-       if not getattr(model, "_supports_group_offloading", True):
-           return
-
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
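
Every hunk applies the same pattern: a silent `return` early-exit becomes an explicit `pytest.skip(...)`, so tests on unsupported configurations are reported as SKIPPED instead of as spurious passes. In the offload tests the capability check also moves from the built instance to the class attribute, so the skip fires before a model is needlessly constructed. A minimal sketch of the pattern under pytest (the DummyModel/TestSkipPattern names are hypothetical, not part of the diff):

import pytest


class DummyModel:
    # Hypothetical stand-in for a model class that lacks support.
    _supports_gradient_checkpointing = False


class TestSkipPattern:
    model_class = DummyModel

    def test_enable_disable_gradient_checkpointing(self):
        # A bare `return` here would make pytest report PASSED even though
        # nothing was exercised; pytest.skip() reports SKIPPED with a reason.
        if not self.model_class._supports_gradient_checkpointing:
            pytest.skip("Gradient checkpointing is not supported.")
        assert self.model_class._supports_gradient_checkpointing  # real assertions go here

Running `pytest -rs` on a file containing this class prints the skip reason in the summary, which is the observable difference this change buys.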