
Commit ee12be5

Convert TF and Numpy ops in whisper_audio_convert.py to Keras Ops
1 parent 16f6e03 commit ee12be5

2 files changed: +297 -105 lines changed
whisper_audio_convert.py: 211 additions & 85 deletions
@@ -1,4 +1,5 @@
-import numpy as np
+import keras
+import keras.ops as ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
@@ -33,23 +34,6 @@ class WhisperAudioConverter(AudioConverter):
         max_audio_length: int. The length of each audio chunk in
             seconds. The input audio tensor will be padded/trimmed to
             `max_audio_length * sampling_rate`. Defaults to `30`.
-
-    Examples:
-    ```python
-    audio_tensor = tf.ones((8000,), dtype="float32")
-
-    # Compute the log-mel spectrogram.
-    audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
-        "whisper_base_en",
-    )
-    audio_converter(audio_tensor)
-
-    # Compute the log-mel spectrogram for a batch of audio tensors.
-    audio_tensor_1 = tf.ones((8000,), dtype="float32")
-    audio_tensor_2 = tf.ones((10000,), dtype="float32")
-    audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
-    audio_converter(audio_tensor)
-    ```
     """
 
     backbone_cls = WhisperBackbone
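
The docstring examples removed above were written against TensorFlow tensors. The commit does not add a replacement, but a backend-agnostic sketch of the same usage could look like this (the `whisper_base_en` preset is taken from the removed text; plain NumPy arrays stand in for `tf.ones`):

```python
import numpy as np
import keras_hub

audio_tensor = np.ones((8000,), dtype="float32")

# Compute the log-mel spectrogram.
audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
    "whisper_base_en",
)
audio_converter(audio_tensor)

# Batched inputs are now dense arrays of equal length; the
# `tf.ragged.stack` path from the old example no longer applies.
audio_batch = np.ones((2, 8000), dtype="float32")
audio_converter(audio_batch)
```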
@@ -84,33 +68,35 @@ def audio_shape(self):
         """Returns the preprocessed size of a single audio sample."""
         return (self.max_audio_length, self.num_mels)
 
+    def _get_rfftfreq_keras(self):
+        n = self.num_fft_bins
+        d = 1.0 / self.sampling_rate
+
+        if n % 2 == 0:
+            freqs = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
+        else:
+            freqs = ops.arange(0, (n - 1) // 2 + 1, dtype="float32") / (d * n)
+
+        return freqs
+
     def _get_mel_filters(self):
         """
         Adapted from Hugging Face
         (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
         """
 
-        # TODO: Convert to TensorFlow ops (if possible).
 
-        dtype = np.float32
+        dtype = self.compute_dtype  # Use the class's dtype
         # Initialize the weights
-        weights = np.zeros(
+        weights = ops.zeros(
             (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
        )
-
         # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(
-            n=self.num_fft_bins, d=1.0 / self.sampling_rate
-        )
-
+        fftfreqs = self._get_rfftfreq_keras()
         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965
-
-        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)
-
-        mels = np.asanyarray(mels)
-
+        mels = ops.linspace(min_mel, max_mel, self.num_mels + 2)
         # Fill in the linear scale
         f_min = 0.0
         f_sp = 200.0 / 3
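
`_get_rfftfreq_keras` above mirrors `np.fft.rfftfreq`, whose bins are k/(d·n) for k = 0 … n//2. A quick equivalence check of the even-`n` branch (the sizes here are illustrative, not taken from the diff):

```python
import numpy as np
import keras.ops as ops

n, sr = 400, 16000  # illustrative FFT size and sampling rate
d = 1.0 / sr

# Same formula the helper uses for even n: k / (d * n), k = 0 .. n // 2.
freqs = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
np.testing.assert_allclose(
    ops.convert_to_numpy(freqs), np.fft.rfftfreq(n=n, d=d), rtol=1e-5
)
```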
@@ -119,118 +105,256 @@ def _get_mel_filters(self):
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
-
+        logstep = ops.log(6.4) / 27.0  # step size for log region
         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(
-            logstep * (mels[log_t] - min_log_mel)
+        freqs = ops.where(
+            log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs
         )
-
         mel_f = freqs
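
NumPy's masked assignment (`freqs[log_t] = ...`) has no counterpart for immutable backend tensors, so the commit switches to `ops.where`, which evaluates both branches and selects elementwise. The pattern in isolation:

```python
import keras.ops as ops

x = ops.arange(0, 5, dtype="float32")
mask = x >= 3.0

# NumPy would write: x[mask] = 10.0 * x[mask]
# The backend-agnostic form computes both branches, then selects:
x = ops.where(mask, 10.0 * x, x)
```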
 
-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
 
+        fdiff = ops.diff(mel_f)
+        ramps = (
+            ops.expand_dims(mel_f, axis=1) - fftfreqs
+        )  # keras subtract outer
+
+        weights_list = []
         for i in range(self.num_mels):
             # lower and upper slopes for all bins
             lower = -ramps[i] / fdiff[i]
             upper = ramps[i + 2] / fdiff[i + 1]
 
             # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
+            weights_i = ops.maximum(0, ops.minimum(lower, upper))
+            weights_list.append(weights_i)
+
+        weights = ops.stack(weights_list)
 
         # Slaney-style mel is scaled to be approx constant energy per channel
         enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
-        weights *= enorm[:, np.newaxis]
+        weights *= ops.expand_dims(enorm, axis=1)
 
-        weights = np.transpose(weights)
-        return tf.constant(weights, dtype=self.compute_dtype)
+        weights = ops.transpose(weights)
+        return weights
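
Two NumPy idioms disappear here for the same reason: `np.subtract.outer` becomes broadcasting against an inserted axis, and row-by-row assignment into `weights` becomes a Python list stacked once at the end. A minimal check of the outer-difference replacement:

```python
import numpy as np
import keras.ops as ops

a = np.array([1.0, 2.0, 3.0], dtype="float32")
b = np.array([10.0, 20.0], dtype="float32")

# np.subtract.outer(a, b)[i, j] == a[i] - b[j]; broadcasting an
# expanded axis produces the same (3, 2) result.
got = ops.expand_dims(ops.array(a), axis=1) - ops.array(b)
np.testing.assert_allclose(ops.convert_to_numpy(got), np.subtract.outer(a, b))
```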
 
     def _extract_audio_features(self, audio):
-        audio = tf.cast(audio, self.compute_dtype)
+        audio = ops.cast(audio, self.compute_dtype)
         # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
         # internally.
-        audio = tf.pad(
+        audio = ops.pad(
             audio,
-            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
-            mode="REFLECT",
+            pad_width=[
+                [0, 0],
+                [self.num_fft_bins // 2, self.num_fft_bins // 2],
+            ],
+            mode="reflect",
         )
-
-        # Compute the mel spectrogram.
-        stft = tf.signal.stft(
+        stft = ops.stft(
             audio,
-            frame_length=self.num_fft_bins,
-            frame_step=self.stride,
+            sequence_length=self.num_fft_bins,
+            sequence_stride=self.stride,
             fft_length=self.num_fft_bins,
+            center=False,
         )
-        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))
-
-        mel_spec = tf.matmul(
+        stft = ops.sum(stft, axis=0)
+        magnitudes = ops.square(ops.absolute(stft[:, :-1, :]))
+
+        mel_spec = ops.matmul(
             magnitudes,
             self.mel_filters,
         )
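
Unlike `tf.signal.stft`, `keras.ops.stft` returns a `(real, imag)` pair rather than a single complex tensor, which is why the new code reduces over an extra leading axis before taking magnitudes. For reference, a power spectrum computed directly from the pair would look like this (a sketch with illustrative sizes, not code from the commit):

```python
import keras.ops as ops

x = ops.ones((1, 16000))
real, imag = ops.stft(
    x, sequence_length=400, sequence_stride=160, fft_length=400
)
# Power spectrum from the real/imag pair.
power = ops.square(real) + ops.square(imag)
```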
 
         def tf_log10(x):
             """Computes log base 10 of input tensor using TensorFlow."""
-            numerator = tf.math.log(x)
-            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
+            numerator = ops.log(x)
+            denominator = ops.log(
+                ops.cast(ops.array(10), dtype=numerator.dtype)
+            )
             return numerator / denominator
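
The nested helper keeps its old name but now computes log base 10 via the change-of-base identity log10(x) = log(x) / log(10). A quick numerical check:

```python
import numpy as np
import keras.ops as ops

x = ops.array([1.0, 10.0, 100.0])
log10_x = ops.log(x) / ops.log(ops.array(10.0))
np.testing.assert_allclose(
    ops.convert_to_numpy(log10_x), np.array([0.0, 1.0, 2.0]), atol=1e-6
)
```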
 
         # Clamp the values to a minimum value of 1e-10. This is done to avoid
         # taking the log of 0, i.e., for numerical stability.
-        mel_spec = tf.maximum(mel_spec, 1e-10)
+        mel_spec = ops.maximum(mel_spec, 1e-10)
 
         # Calculate the log mel spectrogram.
         log_spec = tf_log10(mel_spec)
         # Dynamic range compression.
-        log_spec_shape = tf.shape(log_spec)
-        max_value_minus_eight = tf.math.subtract(
-            tf.math.reduce_max(log_spec, axis=[1, 2]),
-            tf.cast(8, dtype=log_spec.dtype),
+        log_spec_shape = ops.shape(log_spec)
+        max_value_minus_eight = ops.subtract(
+            ops.max(log_spec, axis=[1, 2]),
+            ops.cast(8, dtype=log_spec.dtype),
         )
-        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
-        max_value_minus_eight = tf.repeat(
+        max_value_minus_eight = ops.expand_dims(max_value_minus_eight, axis=1)
+        max_value_minus_eight = ops.repeat(
             max_value_minus_eight,
             repeats=log_spec_shape[1] * log_spec_shape[2],
             axis=1,
         )
-        max_value_minus_eight = tf.reshape(
-            max_value_minus_eight, shape=log_spec_shape
+        max_value_minus_eight = ops.reshape(
+            max_value_minus_eight, newshape=log_spec_shape
         )
-        log_spec = tf.maximum(log_spec, max_value_minus_eight)
+        log_spec = ops.maximum(log_spec, max_value_minus_eight)
         # Normalization.
-        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
-        log_spec = tf.math.divide(
-            tf.math.add(log_spec, type_cast_four),
+        type_cast_four = ops.cast(4, dtype=log_spec.dtype)
+        log_spec = ops.divide(
+            ops.add(log_spec, type_cast_four),
             type_cast_four,
         )
-
         return log_spec
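
The expand/repeat/reshape sequence above only exists to broadcast a per-sample maximum back to the spectrogram's shape. In plain NumPy the whole compression-and-normalization step reduces to (a sketch of the same math, not code from the commit):

```python
import numpy as np

log_spec = np.random.uniform(-12.0, 2.0, size=(1, 3000, 80)).astype("float32")

# Floor each sample at 8 below its own maximum...
floor = log_spec.max(axis=(1, 2), keepdims=True) - 8.0
log_spec = np.maximum(log_spec, floor)
# ...then map into roughly [-1, 1].
log_spec = (log_spec + 4.0) / 4.0
```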
 
-    def call(self, audio):
-        if not isinstance(audio, (tf.Tensor, tf.RaggedTensor)):
-            audio = tf.convert_to_tensor(audio)
+    def call(
+        self,
+        inputs,
+        padding=None,
+        max_length=None,
+        pad_to_multiple_of=None,
+    ):
+        input_shape = keras.ops.shape(inputs)
+        input_rank = (
+            len(input_shape)
+            if isinstance(input_shape, (list, tuple))
+            else input_shape.rank
+        )
+        rank_1_input = input_rank == 1
 
-        rank_1_input = audio.shape.rank == 1
         if rank_1_input:
-            audio = tf.expand_dims(audio, 0)
+            inputs = ops.expand_dims(inputs, 0)
+
+        # Convert to dense tensor with proper padding/truncation
+        processed_inputs = self.variable_length_inputs(
+            inputs, padding, max_length, pad_to_multiple_of
+        )
 
-        # Convert the tensor to a Ragged Tensor.
-        if isinstance(audio, tf.Tensor):
-            audio = tf.RaggedTensor.from_tensor(audio)
+        # Extract features
+        log_spec = self._extract_audio_features(processed_inputs)
 
-        # Pad audio.
-        audio_shape = audio.shape.as_list()
-        audio_shape[-1] = self.num_samples
-        audio = audio.to_tensor(shape=audio_shape)
-
-        # Find the log mel spectrogram.
-        log_spec = self._extract_audio_features(audio)
         if rank_1_input:
-            log_spec = tf.squeeze(log_spec, 0)
+            log_spec = ops.squeeze(log_spec, 0)
+
         return log_spec
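
A sketch of how the reworked `call` signature might be exercised; `converter` is hypothetical here (any constructed `WhisperAudioConverter`, e.g. via `from_preset`), and the keyword arguments are the ones this commit adds:

```python
import numpy as np

audio = np.ones((8000,), dtype="float32")

# Default behavior: pad/trim to max_audio_length * sampling_rate samples.
features = converter(audio)

# Explicit control over the target length:
features = converter(
    audio,
    padding="max_length",
    max_length=16000,
    pad_to_multiple_of=512,
)
```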
 
+    # handling variable length inputs
+    def variable_length_inputs(
+        self, inputs, padding=None, max_length=None, pad_to_multiple_of=None
+    ):
+        """Handles variable length inputs with padding or truncation."""
+
+        # Determine the appropriate target length
+        if padding == "max_length" and max_length is not None:
+            target_length = max_length
+        else:
+            # Use default max_audio_length
+            target_length = self.num_samples
+
+        if pad_to_multiple_of:
+            target_length = (
+                (target_length + pad_to_multiple_of - 1) // pad_to_multiple_of
+            ) * pad_to_multiple_of
+
+        # Get current shape and length
+        audio_shape = keras.ops.shape(inputs)
+        audio_length = audio_shape[1]
+
+        if padding == "max_length" and max_length is not None:
+            is_padding_required = keras.ops.less(audio_length, target_length)
+            is_trunc_required = keras.ops.greater(audio_length, target_length)
+
+            def pad_fn():
+                padding_amount = target_length - audio_length
+                paddings = [[0, 0], [0, padding_amount]]
+                return keras.ops.pad(
+                    inputs,
+                    paddings,
+                    mode="constant",
+                    constant_values=self.padding_value,
+                )
+
+            def trunc_fn():
+                return keras.ops.slice(
+                    inputs,
+                    [0, 0],
+                    [-1, target_length],
+                )
+
+            # Check if we're in symbolic execution
+            is_tf_symbolic = (
+                tf is not None
+                and hasattr(inputs, "graph")
+                and hasattr(inputs.graph, "as_graph_def")
+            )
+            use_tf_graph_ops = tf is not None and is_tf_symbolic
+
+            if use_tf_graph_ops and keras.config.backend() != "torch":
+                processed_inputs = tf.cond(
+                    is_padding_required,
+                    pad_fn,
+                    lambda: tf.cond(is_trunc_required, trunc_fn, lambda: inputs),
+                )
+            else:
+                is_padding_bool = keras.ops.convert_to_numpy(is_padding_required)
+                is_trunc_bool = keras.ops.convert_to_numpy(is_trunc_required)
+
+                if is_padding_bool:
+                    padding_amount = target_length - audio_length
+                    paddings = [[0, 0], [0, padding_amount]]
+                    processed_inputs = keras.ops.pad(
+                        inputs,
+                        paddings,
+                        mode="constant",
+                        constant_values=self.padding_value,
+                    )
+                elif is_trunc_bool:
+                    processed_inputs = inputs[:, :target_length]
+                else:
+                    processed_inputs = inputs
+        else:
+            # No explicit padding - just pad/truncate to default max length
+            is_padding_required = keras.ops.less(audio_length, target_length)
+            is_trunc_required = keras.ops.greater(audio_length, target_length)
+
+            # Use eager execution approach for simplicity
+            is_padding_bool = keras.ops.convert_to_numpy(is_padding_required)
+            is_trunc_bool = keras.ops.convert_to_numpy(is_trunc_required)
+
+            if is_padding_bool:
+                padding_amount = target_length - audio_length
+                paddings = [[0, 0], [0, padding_amount]]
+                processed_inputs = keras.ops.pad(
+                    inputs,
+                    paddings,
+                    mode="constant",
+                    constant_values=self.padding_value,
+                )
+            elif is_trunc_bool:
+                processed_inputs = inputs[:, :target_length]
+            else:
+                processed_inputs = inputs
+
+        return processed_inputs
+
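
The round-up to `pad_to_multiple_of` is the usual ceiling-division trick: the smallest multiple of `m` that is >= `t` is `((t + m - 1) // m) * m`. For example:

```python
def round_up(target_length: int, multiple: int) -> int:
    # Ceiling division, scaled back up to a multiple.
    return ((target_length + multiple - 1) // multiple) * multiple

assert round_up(16000, 512) == 16384  # 31.25 multiples -> 32
assert round_up(16384, 512) == 16384  # already a multiple, unchanged
```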
+    def compute_output_shape(self, input_shape):
+        """Compute output shape for variable-length inputs."""
+
+        if len(input_shape) == 1:
+            # For single audio sample - returns 2D shape (frames, mels)
+            num_frames = (self.num_samples + self.stride - 1) // self.stride
+            return (num_frames, self.num_mels)
+        elif len(input_shape) == 2:
+            # For batch of audio samples - returns 3D shape (batch, frames, mels)
+            batch_size = input_shape[0]
+            num_frames = (self.num_samples + self.stride - 1) // self.stride
+            return (batch_size, num_frames, self.num_mels)
+        else:
+            raise ValueError("Input shape must be rank 1 or 2.")
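
With Whisper's usual settings (30 s of 16 kHz audio, so `num_samples = 480000`, and a hop of `stride = 160`), the ceiling division in `compute_output_shape` gives the familiar 3000 frames:

```python
num_samples = 30 * 16000  # max_audio_length * sampling_rate
stride = 160              # hop length in samples
num_frames = (num_samples + stride - 1) // stride
assert num_frames == 3000
```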
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -243,3 +367,5 @@ def get_config(self):
             }
         )
         return config
+
+
