Skip to content

Commit 58cfbea

Browse files
Convert TF and NumPy ops in whisper_audio_converter.py to Keras ops
1 parent 9d319ff commit 58cfbea

File tree

2 files changed

+76
-66
lines changed

2 files changed

+76
-66
lines changed

keras_hub/src/models/whisper/whisper_audio_converter.py

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1+
import keras.ops as ops
12
import numpy as np
3+
import tensorflow as tf
24

35
from keras_hub.src.api_export import keras_hub_export
46
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
57
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
68

7-
try:
8-
import tensorflow as tf
9-
except ImportError:
10-
tf = None
11-
129

1310
@keras_hub_export("keras_hub.layers.WhisperAudioConverter")
1411
class WhisperAudioConverter(AudioConverter):
@@ -36,7 +33,7 @@ class WhisperAudioConverter(AudioConverter):
3633
3734
Examples:
3835
```python
39-
audio_tensor = tf.ones((8000,), dtype="float32")
36+
audio_tensor = ops.ones((8000,), dtype="float32")
4037
4138
# Compute the log-mel spectrogram.
4239
audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
@@ -45,8 +42,8 @@ class WhisperAudioConverter(AudioConverter):
4542
audio_converter(audio_tensor)
4643
4744
# Compute the log-mel spectrogram for a batch of audio tensors.
48-
audio_tensor_1 = tf.ones((8000,), dtype="float32")
49-
audio_tensor_2 = tf.ones((10000,), dtype="float32")
45+
audio_tensor_1 = ops.ones((8000,), dtype="float32")
46+
audio_tensor_2 = ops.ones((10000,), dtype="float32")
5047
audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
5148
audio_converter(audio_tensor)
5249
```
@@ -84,33 +81,33 @@ def audio_shape(self):
8481
"""Returns the preprocessed size of a single audio sample."""
8582
return (self.max_audio_length, self.num_mels)
8683

84+
def _get_rfftfreq_keras(self):
    """Return the centre frequency of every real-FFT bin.

    Keras-ops counterpart of
    `np.fft.rfftfreq(n=self.num_fft_bins, d=1.0 / self.sampling_rate)`.

    Returns:
        A rank-1 float32 tensor with `self.num_fft_bins // 2 + 1` entries.
        Entry `k` equals `k / (d * n)` with `n = self.num_fft_bins` and
        `d = 1.0 / self.sampling_rate`, so the values are expressed in the
        same unit as `self.sampling_rate`.
    """
    n = self.num_fft_bins
    # Spacing between consecutive samples (inverse of the sampling rate).
    d = 1.0 / self.sampling_rate

    # A single expression covers both parities: for odd `n`,
    # `n // 2 + 1 == (n - 1) // 2 + 1`, so no even/odd branch is needed.
    return ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
94+
8795
def _get_mel_filters(self):
8896
"""
8997
Adapted from Hugging Face
9098
(https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
9199
"""
92-
93-
# TODO: Convert to TensorFlow ops (if possible).
94-
95-
dtype = np.float32
100+
dtype = self.compute_dtype # Use the class's dtype
96101
# Initialize the weights
97-
weights = np.zeros(
102+
weights = ops.zeros(
98103
(self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
99104
)
100-
101105
# Center freqs of each FFT bin
102-
fftfreqs = np.fft.rfftfreq(
103-
n=self.num_fft_bins, d=1.0 / self.sampling_rate
104-
)
105-
106+
fftfreqs = self._get_rfftfreq_keras()
106107
# 'Center freqs' of mel bands - uniformly spaced between limits
107108
min_mel = 0.0
108109
max_mel = 45.245640471924965
109-
110-
mels = np.linspace(min_mel, max_mel, self.num_mels + 2)
111-
112-
mels = np.asanyarray(mels)
113-
110+
mels = ops.linspace(min_mel, max_mel, self.num_mels + 2)
114111
# Fill in the linear scale
115112
f_min = 0.0
116113
f_sp = 200.0 / 3
@@ -119,93 +116,105 @@ def _get_mel_filters(self):
119116
# And now the nonlinear scale
120117
min_log_hz = 1000.0 # beginning of log region (Hz)
121118
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
122-
logstep = np.log(6.4) / 27.0 # step size for log region
123-
119+
logstep = ops.log(6.4) / 27.0 # step size for log region
124120
# If we have vector data, vectorize
125121
log_t = mels >= min_log_mel
126-
freqs[log_t] = min_log_hz * np.exp(
127-
logstep * (mels[log_t] - min_log_mel)
122+
freqs = ops.where(
123+
log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs
128124
)
129-
130125
mel_f = freqs
126+
fdiff = ops.diff(mel_f)
127+
ramps = (
128+
ops.expand_dims(mel_f, axis=1) - fftfreqs
129+
) # keras subtract outer
131130

132-
fdiff = np.diff(mel_f)
133-
ramps = np.subtract.outer(mel_f, fftfreqs)
134-
131+
weights_list = []
135132
for i in range(self.num_mels):
136133
# lower and upper slopes for all bins
137134
lower = -ramps[i] / fdiff[i]
138135
upper = ramps[i + 2] / fdiff[i + 1]
139136

140137
# .. then intersect them with each other and zero
141-
weights[i] = np.maximum(0, np.minimum(lower, upper))
138+
weights_i = ops.maximum(0, ops.minimum(lower, upper))
139+
weights_list.append(weights_i)
140+
141+
weights = ops.stack(weights_list)
142142

143143
# Slaney-style mel is scaled to be approx constant energy per channel
144144
enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
145-
weights *= enorm[:, np.newaxis]
145+
weights *= ops.expand_dims(enorm, axis=1)
146146

147-
weights = np.transpose(weights)
148-
return tf.constant(weights, dtype=self.compute_dtype)
147+
weights = ops.transpose(weights)
148+
return weights
149149

150150
def _extract_audio_features(self, audio):
    """Compute the normalised log-mel spectrogram of a batch of waveforms.

    Args:
        audio: Batch of waveforms. The padding spec and the slicing below
            treat it as rank 2, `(batch, samples)`. It is cast to
            `self.compute_dtype` before any processing.

    Returns:
        A rank-3 tensor `(batch, frames, mels)` holding
        `(max(log10(mel_power), per_sample_max - 8) + 4) / 4`. The size of
        the last axis is set by `self.mel_filters` (presumably
        `self.num_mels` - confirm where `mel_filters` is assigned).
    """
    audio = ops.cast(audio, self.compute_dtype)

    # Reflect-pad half a window on each side of the time axis by hand.
    # `ops.stft` is then called with `center=False` so that it does not pad
    # a second time. Keras documents the pad mode names in lowercase.
    half_window = self.num_fft_bins // 2
    audio = ops.pad(
        audio,
        pad_width=[[0, 0], [half_window, half_window]],
        mode="reflect",
    )

    # `ops.stft` returns the real and imaginary parts as two separate
    # tensors rather than one complex tensor.
    stft_real, stft_imag = ops.stft(
        audio,
        sequence_length=self.num_fft_bins,
        sequence_stride=self.stride,
        fft_length=self.num_fft_bins,
        center=False,
    )

    # Power spectrum |z|^2 = re^2 + im^2. The trailing frame is dropped,
    # exactly as the previous `tf.signal.stft` based code did with
    # `stft[:, :-1, :]`.
    magnitudes = ops.square(stft_real[:, :-1, :]) + ops.square(
        stft_imag[:, :-1, :]
    )

    # Project the power spectrum onto the mel filter bank.
    mel_spec = ops.matmul(magnitudes, self.mel_filters)

    # Clamp the values to a minimum value of 1e-10. This is done to avoid
    # taking the log of 0, i.e., for numerical stability.
    mel_spec = ops.maximum(mel_spec, 1e-10)

    # Calculate the log mel spectrogram: log10(x) = ln(x) / ln(10).
    log_spec = ops.log(mel_spec) / ops.log(ops.cast(10, mel_spec.dtype))

    # Dynamic range compression: nothing may fall more than 8 below the
    # per-sample maximum. `keepdims=True` leaves a `(batch, 1, 1)` tensor
    # that broadcasts against `log_spec`, so no repeat/reshape is needed.
    floor = ops.max(log_spec, axis=(1, 2), keepdims=True) - 8.0
    log_spec = ops.maximum(log_spec, floor)

    # Normalization.
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
210219

211220
def call(self, audio):
@@ -214,7 +223,7 @@ def call(self, audio):
214223

215224
rank_1_input = audio.shape.rank == 1
216225
if rank_1_input:
217-
audio = tf.expand_dims(audio, 0)
226+
audio = ops.expand_dims(audio, 0)
218227

219228
# Convert the tensor to a Ragged Tensor.
220229
if isinstance(audio, tf.Tensor):
@@ -228,7 +237,7 @@ def call(self, audio):
228237
# Find the log mel spectrogram.
229238
log_spec = self._extract_audio_features(audio)
230239
if rank_1_input:
231-
log_spec = tf.squeeze(log_spec, 0)
240+
log_spec = ops.squeeze(log_spec, 0)
232241
return log_spec
233242

234243
def get_config(self):

keras_hub/src/models/whisper/whisper_audio_converter_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tensorflow as tf
2+
import keras.ops as ops
23

34
from keras_hub.src.models.whisper.whisper_audio_converter import (
45
WhisperAudioConverter,
@@ -15,8 +16,8 @@ def setUp(self):
1516
"sampling_rate": 100,
1617
"max_audio_length": 5,
1718
}
18-
audio_tensor_1 = tf.ones((2,), dtype="float32")
19-
audio_tensor_2 = tf.ones((25,), dtype="float32")
19+
audio_tensor_1 = ops.ones((2,), dtype="float32")
20+
audio_tensor_2 = ops.ones((25,), dtype="float32")
2021
self.input_data = tf.ragged.stack(
2122
[audio_tensor_1, audio_tensor_2],
2223
axis=0,
@@ -30,11 +31,11 @@ def test_feature_extractor_basics(self):
3031
)
3132

3233
def test_correctness(self):
    """Check the output shape and the first mel channel for a short clip."""
    converter = WhisperAudioConverter(**self.init_kwargs)
    features = converter(ops.ones((2,), dtype="float32"))

    # Verify shape.
    self.assertEqual(features.shape, (5, 80))

    # Verify output: compare the first mel channel across all frames.
    first_mel_channel = features[:, 0]
    reference = [1.1656, 1.0151, -0.8343, -0.8343, -0.8343]
    self.assertAllClose(first_mel_channel, reference, atol=0.01, rtol=0.01)

0 commit comments

Comments
 (0)