Skip to content

Commit 301129d

Browse files
Update whisper_audio_converter.py
1 parent 51bfeb2 commit 301129d

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

keras_hub/src/models/whisper/whisper_audio_converter.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,9 @@ def _extract_audio_features(self, audio):
174174
)
175175
stft = ops.sum(stft, axis=0)
176176
magnitudes = ops.square(ops.absolute(stft[:, :-1, :]))
177-
177+
# magnitudes = ops.square(ops.sqrt(ops.square(stft_real) + ops.square(stft_imag)))
178+
# mel_filters_casted = ops.cast(self.mel_filters, dtype=magnitudes.dtype)
179+
178180
mel_spec = ops.matmul(
179181
magnitudes,
180182
self.mel_filters,
@@ -258,10 +260,14 @@ def call(
258260
if rank_1_input:
259261
inputs = ops.expand_dims(inputs, 0)
260262

261-
# Convert to dense tensor with proper padding/truncation
262-
processed_inputs = self.variable_length_inputs(
263-
inputs, padding, max_length, pad_to_multiple_of
264-
)
263+
# Convert the tensor to a Ragged Tensor.
264+
if isinstance(audio, tf.Tensor):
265+
audio = tf.RaggedTensor.from_tensor(audio)
266+
267+
# Pad audio.
268+
audio_shape = audio.shape.as_list()
269+
audio_shape[-1] = self.num_samples
270+
audio = audio.to_tensor(shape=audio_shape)
265271

266272
# Extract features
267273
log_spec = self._extract_audio_features(processed_inputs)

0 commit comments

Comments
 (0)