Skip to content

Commit a325409

Browse files
authored
Expectations re-order and corrected FA3 skip (#39195)
* Fix Expectations and a FA3 skip
* Fixed docstring
* Added context for Default expectation
1 parent b0a8e0b commit a325409

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

src/transformers/testing_utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3338,17 +3338,25 @@ def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
33383338
return [(unpack_device_properties(k), v) for k, v in self.data.items()]
33393339

33403340
@staticmethod
3341-
def is_default(properties: DeviceProperties) -> bool:
3342-
return all(p is None for p in properties)
3341+
def is_default(expectation_key: PackedDeviceProperties) -> bool:
3342+
"""
3343+
This function returns True if the expectation_key is the Default expectation (None, None).
3344+
When an Expectation dict contains a Default value, it is generally because the test existed before Expectations.
3345+
When we modify a test to use Expectations for a specific hardware, we don't want to affect the tests on other
3346+
hardware. Thus we set the previous value as the Default expectation with key (None, None) and add a value for
3347+
the specific hardware with key (hardware_type, (major, minor)).
3348+
"""
3349+
return all(p is None for p in expectation_key)
33433350

33443351
@staticmethod
33453352
def score(properties: DeviceProperties, other: DeviceProperties) -> float:
33463353
"""
33473354
Returns a score indicating how similar two instances of the `DeviceProperties` tuple are.
33483355
Rules are as follows:
3349-
* Matching `type` adds one point, semi-matching `type` adds half a point (e.g. cuda and rocm).
3356+
* Matching `type` adds one point, semi-matching `type` adds 0.1 points (e.g. cuda and rocm).
33503357
* If types match, matching `major` adds another point, and then matching `minor` adds another.
3351-
* Default expectation (if present) is worth 0.1 point to distinguish it from a straight-up zero.
3358+
* The Default expectation (None, None) is worth 0.5 points, which is better than semi-matching. More on this
3359+
in the `is_default` function.
33523360
"""
33533361
device_type, major, minor = properties
33543362
other_device_type, other_major, other_minor = other
@@ -3361,13 +3369,13 @@ def score(properties: DeviceProperties, other: DeviceProperties) -> float:
33613369
score += 1
33623370
if minor is not None and minor == other_minor:
33633371
score += 1
3364-
# Semi-matching device type
3372+
# Semi-matching device type, which carries less importance than the default expectation
33653373
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
3366-
score = 0.5
3374+
score = 0.1
33673375

33683376
# Default expectation
33693377
if Expectations.is_default(other):
3370-
score = 0.1
3378+
score = 0.5
33713379

33723380
return score
33733381

tests/test_modeling_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4306,7 +4306,7 @@ def flash_attn_from_config(self, attn_implementation: str):
43064306
def test_flash_attn_2_from_config(self):
43074307
self.flash_attn_from_config(attn_implementation="flash_attention_2")
43084308

4309-
@require_flash_attn
4309+
@require_flash_attn_3
43104310
@require_torch_gpu
43114311
@mark.flash_attn_3_test
43124312
@slow

0 commit comments

Comments
 (0)