
Inconsistent input_feature length and attention_mask length in WhisperFeatureExtractor #39214

Open
@BakerBunker

Description


System Info

transformers main branch

Who can help?

@eustlb @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoProcessor
import numpy as np

audios = [np.random.randn(16000 * 5)]  # 5 seconds of audio at 16 kHz
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
print(processor(
    [audios[0][: 160 * 5 - 1]], return_attention_mask=True, sampling_rate=16000, padding=False
)["attention_mask"].shape)  # 799 samples
print(processor(
    [audios[0][: 160 * 5]], return_attention_mask=True, sampling_rate=16000, padding=False
)["attention_mask"].shape)  # 800 samples
print(processor(
    [audios[0][: 160 * 5 + 1]], return_attention_mask=True, sampling_rate=16000, padding=False
)["attention_mask"].shape)  # 801 samples

Expected behavior

The input_features and attention_mask lengths should both be audio_length // hop_length, but the code here:

if return_attention_mask:
    # rescale from sample (48000) to feature (3000)
    padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]

makes the attention_mask length equal to (audio_length + hop_length - 1) // hop_length. It should be changed to:

diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index 68c52c6eb3..b9f3b4cb35 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -326,7 +326,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
 
         if return_attention_mask:
             # rescale from sample (48000) to feature (3000)
-            padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
+            padded_inputs["attention_mask"] = padded_inputs["attention_mask"][
+                :, self.hop_length - 1 :: self.hop_length
+            ]
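For illustration, a minimal sketch (not part of the patch) of how the two slicings differ on a sample-level mask, assuming hop_length = 160:

import numpy as np

hop = 160
mask = np.ones((1, 799), dtype=np.int64)  # sample-level attention mask for 799 samples
print(mask[:, ::hop].shape)               # (1, 5): ceil(799 / 160), current behavior
print(mask[:, hop - 1 :: hop].shape)      # (1, 4): 799 // 160, matches the expected frame count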
