Skip to content

Commit e6a8063

Browse files
authored
Update expected values (after switching to A10) - part 8 - Final (#39220)
* fix * fix --------- Co-authored-by: ydshieh <[email protected]>
1 parent cd8a041 commit e6a8063

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

tests/models/moonshine/test_modeling_moonshine.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -494,14 +494,16 @@ def test_tiny_logits_batch(self):
494494
inputs.to(torch_device)
495495
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
496496
# fmt: off
497-
EXPECTED_LOGITS = torch.tensor([
498-
[-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394],
499-
[-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008],
500-
[-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229],
501-
[-10.8078, 4.0030, -0.0633, -5.0505, -5.3906, -5.4590, -5.2420, -5.4746, -5.2665, -5.3158]
502-
])
497+
EXPECTED_LOGITS = torch.tensor(
498+
[
499+
[-8.5966, 4.8608, 5.8849, -6.6183, -7.0378, -7.7121, -7.0640, -7.3839, -7.8330, -7.6116],
500+
[-4.3147, -2.4953, 8.4924, -6.4803, -7.0949, -6.7498, -6.1081, -6.6481, -6.9866, -6.5916],
501+
[-10.0088, 3.2862, 0.7342, -6.5559, -6.8514, -6.5309, -6.4173, -6.9485, -6.6215, -6.6230],
502+
[-11.1002, 3.9398, 0.6674, -5.0146, -5.3936, -5.4099, -5.2236, -5.4404, -5.2200, -5.2702],
503+
],
504+
)
503505
# fmt: on
504-
torch.testing.assert_close(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
506+
torch.testing.assert_close(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, rtol=2e-4, atol=2e-4)
505507

506508
@slow
507509
def test_base_logits_batch(self):
@@ -513,15 +515,16 @@ def test_base_logits_batch(self):
513515
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
514516

515517
# fmt: off
516-
EXPECTED_LOGITS = torch.tensor([
517-
[-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549],
518-
[-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137],
519-
[-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719],
520-
[-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873]
521-
])
522-
518+
EXPECTED_LOGITS = torch.tensor(
519+
[
520+
[-6.3602, 1.8383, 5.2615, -7.9576, -7.8442, -7.8238, -7.9014, -7.8645, -8.0550, -7.9963],
521+
[-6.1725, -0.6274, 8.1798, -6.8570, -6.8078, -6.7915, -6.9099, -6.8980, -6.9760, -6.8264],
522+
[-7.3186, 3.1192, 3.8938, -5.7208, -5.8429, -5.7610, -5.9997, -5.8213, -5.8616, -5.8720],
523+
[-7.3432, 1.0402, 3.9912, -5.4177, -5.4890, -5.4573, -5.6516, -5.4776, -5.5079, -5.5391],
524+
]
525+
)
523526
# fmt: on
524-
torch.testing.assert_close(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
527+
torch.testing.assert_close(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, rtol=2e-4, atol=2e-4)
525528

526529
@slow
527530
def test_tiny_generation_single(self):

tests/models/regnet/test_modeling_regnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from transformers import RegNetConfig
1919
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
20-
from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
20+
from transformers.testing_utils import Expectations, is_flaky, require_torch, require_vision, slow, torch_device
2121

2222
from ...test_configuration_common import ConfigTester
2323
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -146,6 +146,7 @@ def setUp(self):
146146
def test_config(self):
147147
self.config_tester.run_common_tests()
148148

149+
@is_flaky(description="Larger difference with A10. Still flaky after setting larger tolerance")
149150
def test_batching_equivalence(self, atol=3e-5, rtol=3e-5):
150151
super().test_batching_equivalence(atol=atol, rtol=rtol)
151152

0 commit comments

Comments
 (0)