Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,14 @@ def __call__(

if return_attention_mask:
# rescale from sample (48000) to feature (3000)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length]

# The STFT computation produces L//hop_length + 1 frames, but we skip the last frame (see `_torch_extract_fbank_features`).
# This means we need to trim the rescaled attention mask to match the actual number of frames (L//hop_length) when the input length
# is not perfectly divisible by the hop length.
if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0:
rescaled_attention_mask = rescaled_attention_mask[:, :-1]
padded_inputs["attention_mask"] = rescaled_attention_mask

if return_token_timestamps is not None:
logger.warning_once(
Expand Down
33 changes: 33 additions & 0 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,39 @@ def test_dither(self):
self.assertTrue(np.abs(diff).mean() <= 1e-4)
self.assertTrue(np.abs(diff).max() <= 5e-3)

def test_feature_shape(self):
    """Check that the attention mask length matches the number of feature frames.

    The STFT produces ``num_samples // hop_length + 1`` frames but the extractor
    drops the trailing partial frame, so for an input of ``num_samples`` samples
    the rescaled attention mask must have ``num_samples // hop_length`` entries.
    """
    feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
    hop_length = feature_extractor.hop_length
    test_inputs = np.random.randn(16000)

    # (number of input samples, expected number of feature frames): one sample
    # above, exactly at, and one sample below a multiple of the hop length.
    for num_samples, expected_frames in (
        (hop_length * 5 + 1, 5),
        (hop_length * 5, 5),
        (hop_length * 5 - 1, 4),
    ):
        with self.subTest(num_samples=num_samples):
            attention_mask = feature_extractor(
                [test_inputs[:num_samples]],
                return_attention_mask=True,
                padding=False,
                return_tensors="np",
            ).attention_mask
            # assertEqual (not assertTrue on an == expression) so a failure
            # reports the actual vs. expected frame counts.
            self.assertEqual(attention_mask.shape[-1], expected_frames)

@require_torch
def test_double_precision_pad(self):
import torch
Expand Down
2 changes: 1 addition & 1 deletion tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,7 +2158,7 @@ def test_tiny_token_timestamp_generation_longform(self):
torch.tensor([44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400, 50.5400]),
torch.tensor([50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400, 52.9600]),
torch.tensor([52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1600, 58.5200, 58.6400, 58.8200, 59.4200, 59.4200]),
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.3800, 62.4400])
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.4200, 62.4200])
]
# fmt: on

Expand Down