- import numpy as np
+ import keras
+ import keras.ops as ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
@@ -33,23 +34,6 @@ class WhisperAudioConverter(AudioConverter):
        max_audio_length: int. The length of each audio chunk in
            seconds. The input audio tensor will be padded/trimmed to
            `max_audio_length * sampling_rate`. Defaults to `30`.
-
-    Examples:
-    ```python
-    audio_tensor = tf.ones((8000,), dtype="float32")
-
-    # Compute the log-mel spectrogram.
-    audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
-        "whisper_base_en",
-    )
-    audio_converter(audio_tensor)
-
-    # Compute the log-mel spectrogram for a batch of audio tensors.
-    audio_tensor_1 = tf.ones((8000,), dtype="float32")
-    audio_tensor_2 = tf.ones((10000,), dtype="float32")
-    audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
-    audio_converter(audio_tensor)
-    ```
    """

    backbone_cls = WhisperBackbone
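The PR drops the tf-specific docstring examples rather than updating them. Since the layer is now backend-agnostic, a replacement example could be added back in the same spot. A minimal sketch, reusing the `whisper_base_en` preset from the deleted docstring and substituting `keras.ops` for `tf` (note there is no ragged-tensor equivalent, so the batch example pads to a common length first):

```python
import keras
import keras_hub

# Compute the log-mel spectrogram of a single audio tensor.
audio_tensor = keras.ops.ones((8000,), dtype="float32")
audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
    "whisper_base_en",
)
audio_converter(audio_tensor)

# Batched inputs must share one length, so pad the shorter clip
# before stacking.
audio_tensor_1 = keras.ops.pad(
    keras.ops.ones((8000,), dtype="float32"), [[0, 2000]]
)
audio_tensor_2 = keras.ops.ones((10000,), dtype="float32")
audio_batch = keras.ops.stack([audio_tensor_1, audio_tensor_2], axis=0)
audio_converter(audio_batch)
```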
@@ -84,33 +68,35 @@ def audio_shape(self):
        """Returns the preprocessed size of a single audio sample."""
        return (self.max_audio_length, self.num_mels)

+    def _get_rfftfreq_keras(self):
+        # Keras port of `np.fft.rfftfreq`: the sample frequencies of the
+        # one-sided FFT for a window of `n` samples at spacing `d`.
+        n = self.num_fft_bins
+        d = 1.0 / self.sampling_rate
+
+        if n % 2 == 0:
+            freqs = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
+        else:
+            freqs = ops.arange(0, (n - 1) // 2 + 1, dtype="float32") / (d * n)
+
+        return freqs
+
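The port can be sanity-checked against NumPy directly. A quick parity check, assuming the usual Whisper settings of `num_fft_bins=400` and `sampling_rate=16000`:

```python
import numpy as np
from keras import ops

n, sr = 400, 16000
expected = np.fft.rfftfreq(n=n, d=1.0 / sr)
actual = ops.arange(0, n // 2 + 1, dtype="float32") / ((1.0 / sr) * n)
np.testing.assert_allclose(
    ops.convert_to_numpy(actual), expected, rtol=1e-6
)
```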
    def _get_mel_filters(self):
        """
        Adapted from Hugging Face
        (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
        """
-        # TODO: Convert to TensorFlow ops (if possible).
-
-        dtype = np.float32
-
-        # Initialize the weights
-        weights = np.zeros(
-            (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
-        )
-
+        # The filter matrix is built row by row and stacked at the end, so
+        # no zero-initialized buffer is needed.
+
        # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(
-            n=self.num_fft_bins, d=1.0 / self.sampling_rate
-        )
-
+        fftfreqs = self._get_rfftfreq_keras()

        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965
-
-        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)
-
-        mels = np.asanyarray(mels)
-
+        mels = ops.linspace(min_mel, max_mel, self.num_mels + 2)

        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
@@ -119,118 +105,256 @@ def _get_mel_filters(self):
        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
-
+        logstep = ops.log(6.4) / 27.0  # step size for log region

        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(
-            logstep * (mels[log_t] - min_log_mel)
+        # Masked assignment is not supported by `keras.ops`, so compute the
+        # log-scale frequencies everywhere and select with `where`.
+        freqs = ops.where(
+            log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs
        )
-
        mel_f = freqs

-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
+        fdiff = ops.diff(mel_f)
+        # Outer subtraction via broadcasting (`keras.ops` has no
+        # `subtract.outer`).
+        ramps = ops.expand_dims(mel_f, axis=1) - fftfreqs
+
+        weights_list = []
        for i in range(self.num_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]

            # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
+            weights_list.append(ops.maximum(0, ops.minimum(lower, upper)))
+
+        weights = ops.stack(weights_list)

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
-        weights *= enorm[:, np.newaxis]
+        weights *= ops.expand_dims(enorm, axis=1)

-        weights = np.transpose(weights)
-        return tf.constant(weights, dtype=self.compute_dtype)
+        weights = ops.transpose(weights)
+        # Keep the cast to the layer's compute dtype from the original
+        # implementation.
+        return ops.cast(weights, self.compute_dtype)
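The `where` rewrite above is the general recipe for porting NumPy masked assignment to `keras.ops`. An illustrative standalone sketch:

```python
from keras import ops

x = ops.arange(0, 5, dtype="float32")
mask = x >= 3.0

# NumPy:  x[mask] = x[mask] * 10.0
# Keras:  compute both branches, then select elementwise.
x = ops.where(mask, x * 10.0, x)
# -> [0., 1., 2., 30., 40.]
```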
    def _extract_audio_features(self, audio):
-        audio = tf.cast(audio, self.compute_dtype)
+        audio = ops.cast(audio, self.compute_dtype)
-        # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
-        # internally.
+        # Apply reflect padding manually; `ops.stft` is called with
+        # `center=False` below, so no extra padding happens inside it.
-        audio = tf.pad(
+        audio = ops.pad(
            audio,
-            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
-            mode="REFLECT",
+            pad_width=[
+                [0, 0],
+                [self.num_fft_bins // 2, self.num_fft_bins // 2],
+            ],
+            mode="reflect",
        )
-
-        # Compute the mel spectrogram.
-        stft = tf.signal.stft(
+        stft = ops.stft(
            audio,
-            frame_length=self.num_fft_bins,
-            frame_step=self.stride,
+            sequence_length=self.num_fft_bins,
+            sequence_stride=self.stride,
            fft_length=self.num_fft_bins,
+            center=False,
        )
-        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))
-
-        mel_spec = tf.matmul(
+        # `ops.stft` returns a (real, imag) pair rather than a complex
+        # tensor; the power spectrum is real**2 + imag**2.
+        real, imag = stft
+        magnitudes = ops.square(real[:, :-1, :]) + ops.square(imag[:, :-1, :])
+
+        mel_spec = ops.matmul(
            magnitudes,
            self.mel_filters,
        )

        def tf_log10(x):
-            """Computes log base 10 of input tensor using TensorFlow."""
-            numerator = tf.math.log(x)
-            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
+            """Computes log base 10 of the input tensor."""
+            numerator = ops.log(x)
+            denominator = ops.log(
+                ops.cast(ops.array(10), dtype=numerator.dtype)
+            )
            return numerator / denominator

        # Clamp the values to a minimum value of 1e-10. This is done to avoid
        # taking the log of 0, i.e., for numerical stability.
-        mel_spec = tf.maximum(mel_spec, 1e-10)
+        mel_spec = ops.maximum(mel_spec, 1e-10)

        # Calculate the log mel spectrogram.
        log_spec = tf_log10(mel_spec)
        # Dynamic range compression.
-        log_spec_shape = tf.shape(log_spec)
-        max_value_minus_eight = tf.math.subtract(
-            tf.math.reduce_max(log_spec, axis=[1, 2]),
-            tf.cast(8, dtype=log_spec.dtype),
+        log_spec_shape = ops.shape(log_spec)
+        max_value_minus_eight = ops.subtract(
+            ops.max(log_spec, axis=[1, 2]),
+            ops.cast(8, dtype=log_spec.dtype),
        )
-        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
-        max_value_minus_eight = tf.repeat(
+        max_value_minus_eight = ops.expand_dims(max_value_minus_eight, axis=1)
+        max_value_minus_eight = ops.repeat(
            max_value_minus_eight,
            repeats=log_spec_shape[1] * log_spec_shape[2],
            axis=1,
        )
-        max_value_minus_eight = tf.reshape(
-            max_value_minus_eight, shape=log_spec_shape
+        max_value_minus_eight = ops.reshape(
+            max_value_minus_eight, newshape=log_spec_shape
        )
-        log_spec = tf.maximum(log_spec, max_value_minus_eight)
+        log_spec = ops.maximum(log_spec, max_value_minus_eight)
        # Normalization.
-        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
-        log_spec = tf.math.divide(
-            tf.math.add(log_spec, type_cast_four),
+        type_cast_four = ops.cast(4, dtype=log_spec.dtype)
+        log_spec = ops.divide(
+            ops.add(log_spec, type_cast_four),
            type_cast_four,
        )
-
        return log_spec
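The `(real, imag)` unpacking above is worth calling out: unlike `tf.signal.stft`, which returns a single complex tensor, `keras.ops.stft` returns two real tensors. A standalone sketch of the power-spectrum computation (shapes assume a 1-second, 16 kHz mono batch):

```python
from keras import ops

x = ops.ones((1, 16000))  # (batch, samples)
real, imag = ops.stft(
    x,
    sequence_length=400,
    sequence_stride=160,
    fft_length=400,
    center=False,
)
# Power spectrogram: |STFT|**2 == real**2 + imag**2.
power = ops.square(real) + ops.square(imag)
```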
-    def call(self, audio):
-        if not isinstance(audio, (tf.Tensor, tf.RaggedTensor)):
-            audio = tf.convert_to_tensor(audio)
+    def call(
+        self,
+        inputs,
+        padding=None,
+        max_length=None,
+        pad_to_multiple_of=None,
+    ):
+        input_shape = keras.ops.shape(inputs)
+        input_rank = (
+            len(input_shape)
+            if isinstance(input_shape, (list, tuple))
+            else input_shape.rank
+        )
+        rank_1_input = input_rank == 1

-        rank_1_input = audio.shape.rank == 1
        if rank_1_input:
-            audio = tf.expand_dims(audio, 0)
+            # Add a single batch dimension; `variable_length_inputs` and
+            # `_extract_audio_features` both expect rank-2 input.
+            inputs = ops.expand_dims(inputs, 0)
+
+        # Convert to a dense tensor with proper padding/truncation.
+        processed_inputs = self.variable_length_inputs(
+            inputs, padding, max_length, pad_to_multiple_of
+        )

-        # Convert the tensor to a Ragged Tensor.
-        if isinstance(audio, tf.Tensor):
-            audio = tf.RaggedTensor.from_tensor(audio)
-
-        # Pad audio.
-        audio_shape = audio.shape.as_list()
-        audio_shape[-1] = self.num_samples
-        audio = audio.to_tensor(shape=audio_shape)
-
-        # Find the log mel spectrogram.
-        log_spec = self._extract_audio_features(audio)
+        # Extract the log-mel spectrogram features.
+        log_spec = self._extract_audio_features(processed_inputs)

        if rank_1_input:
-            log_spec = tf.squeeze(log_spec, 0)
+            log_spec = ops.squeeze(log_spec, 0)
+
        return log_spec
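With the new signature, padding behavior can be controlled per call. A hypothetical usage sketch (argument names as introduced in this PR; `audio_converter` as constructed in the earlier example):

```python
import keras

audio_batch = keras.ops.ones((2, 8000), dtype="float32")

# Pad (or truncate) every clip to exactly 30 s at 16 kHz before
# computing features.
log_mels = audio_converter(
    audio_batch,
    padding="max_length",
    max_length=480000,
)
```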
+    def variable_length_inputs(
+        self, inputs, padding=None, max_length=None, pad_to_multiple_of=None
+    ):
+        """Pads or truncates variable-length inputs to a target length."""
+        # Determine the appropriate target length.
+        if padding == "max_length" and max_length is not None:
+            target_length = max_length
+        else:
+            # Fall back to the default `max_audio_length`.
+            target_length = self.num_samples
+
+        # Round up to the nearest multiple, if requested.
+        if pad_to_multiple_of:
+            target_length = (
+                (target_length + pad_to_multiple_of - 1) // pad_to_multiple_of
+            ) * pad_to_multiple_of
+
+        # Current length of the audio axis.
+        audio_shape = keras.ops.shape(inputs)
+        audio_length = audio_shape[1]
+
+        is_padding_required = keras.ops.less(audio_length, target_length)
+        is_trunc_required = keras.ops.greater(audio_length, target_length)
+
+        def pad_fn():
+            padding_amount = target_length - audio_length
+            paddings = [[0, 0], [0, padding_amount]]
+            return keras.ops.pad(
+                inputs,
+                paddings,
+                mode="constant",
+                constant_values=self.padding_value,
+            )
+
+        def trunc_fn():
+            return keras.ops.slice(
+                inputs,
+                [0, 0],
+                [-1, target_length],
+            )
+
+        # Data-dependent Python branching is not allowed while tracing a
+        # TensorFlow graph, so use `tf.cond` when the input is symbolic.
+        is_tf_symbolic = (
+            tf is not None
+            and hasattr(inputs, "graph")
+            and hasattr(inputs.graph, "as_graph_def")
+        )
+        use_tf_graph_ops = tf is not None and is_tf_symbolic
+
+        if use_tf_graph_ops and keras.config.backend() != "torch":
+            processed_inputs = tf.cond(
+                is_padding_required,
+                pad_fn,
+                lambda: tf.cond(is_trunc_required, trunc_fn, lambda: inputs),
+            )
+        else:
+            # Eager execution: resolve the predicates to Python booleans
+            # and branch directly.
+            is_padding_bool = keras.ops.convert_to_numpy(is_padding_required)
+            is_trunc_bool = keras.ops.convert_to_numpy(is_trunc_required)
+
+            if is_padding_bool:
+                processed_inputs = pad_fn()
+            elif is_trunc_bool:
+                processed_inputs = inputs[:, :target_length]
+            else:
+                processed_inputs = inputs
+
+        return processed_inputs
+
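The `pad_to_multiple_of` rounding is plain ceiling division. A worked example with illustrative numbers:

```python
target_length = 48000
pad_to_multiple_of = 512

rounded = (
    (target_length + pad_to_multiple_of - 1) // pad_to_multiple_of
) * pad_to_multiple_of
# (48000 + 511) // 512 == 94, and 94 * 512 == 48128: the smallest
# multiple of 512 that is >= 48000.
```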
+    def compute_output_shape(self, input_shape):
+        """Computes the output shape for variable-length inputs."""
+        num_frames = (self.num_samples + self.stride - 1) // self.stride
+
+        if len(input_shape) == 1:
+            # Single audio sample -> 2D shape (frames, mels).
+            return (num_frames, self.num_mels)
+        elif len(input_shape) == 2:
+            # Batch of audio samples -> 3D shape (batch, frames, mels).
+            batch_size = input_shape[0]
+            return (batch_size, num_frames, self.num_mels)
+        else:
+            raise ValueError("Input shape must be rank 1 or 2.")
+
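The frame count follows directly from the hop size. A worked example, assuming the usual Whisper settings (30 s of 16 kHz audio, so `num_samples=480000`, with `stride=160` and `num_mels=80`):

```python
num_samples, stride, num_mels = 480000, 160, 80

num_frames = (num_samples + stride - 1) // stride
print(num_frames)  # 3000

# So for these settings:
#   compute_output_shape((16000,))    -> (3000, 80)
#   compute_output_shape((4, 16000))  -> (4, 3000, 80)
```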
    def get_config(self):
        config = super().get_config()
        config.update(
@@ -243,3 +367,5 @@ def get_config(self):
            }
        )
        return config
+
+