@@ -994,10 +994,10 @@ def summary_failures_short(tr):
994
994
config .option .tbstyle = orig_tbstyle
995
995
996
996
997
- # Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
997
+ # Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
998
998
def is_flaky (max_attempts : int = 5 , wait_before_retry : Optional [float ] = None , description : Optional [str ] = None ):
999
999
"""
1000
- To decorate flaky tests. They will be retried on failures.
1000
+ To decorate flaky tests (methods or entire classes) . They will be retried on failures.
1001
1001
1002
1002
Args:
1003
1003
max_attempts (`int`, *optional*, defaults to 5):
@@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
1009
1009
etc.)
1010
1010
"""
1011
1011
1012
- def decorator (test_func_ref ):
1013
- @functools .wraps (test_func_ref )
1012
+ def decorator (obj ):
1013
+ # If decorating a class, wrap each test method on it
1014
+ if inspect .isclass (obj ):
1015
+ for attr_name , attr_value in list (obj .__dict__ .items ()):
1016
+ if callable (attr_value ) and attr_name .startswith ("test" ):
1017
+ # recursively decorate the method
1018
+ setattr (obj , attr_name , decorator (attr_value ))
1019
+ return obj
1020
+
1021
+ # Otherwise we're decorating a single test function / method
1022
+ @functools .wraps (obj )
1014
1023
def wrapper (* args , ** kwargs ):
1015
1024
retry_count = 1
1016
-
1017
1025
while retry_count < max_attempts :
1018
1026
try :
1019
- return test_func_ref (* args , ** kwargs )
1020
-
1027
+ return obj (* args , ** kwargs )
1021
1028
except Exception as err :
1022
- print (f"Test failed with { err } at try { retry_count } /{ max_attempts } ." , file = sys .stderr )
1029
+ msg = (
1030
+ f"[FLAKY] { description or obj .__name__ !r} "
1031
+ f"failed on attempt { retry_count } /{ max_attempts } : { err } "
1032
+ )
1033
+ print (msg , file = sys .stderr )
1023
1034
if wait_before_retry is not None :
1024
1035
time .sleep (wait_before_retry )
1025
1036
retry_count += 1
1026
1037
1027
- return test_func_ref (* args , ** kwargs )
1038
+ return obj (* args , ** kwargs )
1028
1039
1029
1040
return wrapper
1030
1041
0 commit comments