1
+ import keras
1
2
import keras .ops as ops
2
- import numpy as np
3
- import tensorflow as tf
4
3
5
4
from keras_hub .src .api_export import keras_hub_export
6
5
from keras_hub .src .layers .preprocessing .audio_converter import AudioConverter
7
6
from keras_hub .src .models .whisper .whisper_backbone import WhisperBackbone
8
7
8
+ try :
9
+ import tensorflow as tf
10
+ except ImportError :
11
+ tf = None
12
+
9
13
10
14
@keras_hub_export ("keras_hub.layers.WhisperAudioConverter" )
11
15
class WhisperAudioConverter (AudioConverter ):
@@ -97,6 +101,7 @@ def _get_mel_filters(self):
97
101
Adapted from Hugging Face
98
102
(https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
99
103
"""
104
+
100
105
dtype = self .compute_dtype # Use the class's dtype
101
106
# Initialize the weights
102
107
weights = ops .zeros (
@@ -123,6 +128,7 @@ def _get_mel_filters(self):
123
128
log_t , min_log_hz * ops .exp (logstep * (mels - min_log_mel )), freqs
124
129
)
125
130
mel_f = freqs
131
+
126
132
fdiff = ops .diff (mel_f )
127
133
ramps = (
128
134
ops .expand_dims (mel_f , axis = 1 ) - fftfreqs
@@ -157,9 +163,8 @@ def _extract_audio_features(self, audio):
157
163
[0 , 0 ],
158
164
[self .num_fft_bins // 2 , self .num_fft_bins // 2 ],
159
165
],
160
- mode = "REFLECT " ,
166
+ mode = "reflect " ,
161
167
)
162
- # Compute the mel spectrogram.
163
168
stft = ops .stft (
164
169
audio ,
165
170
sequence_length = self .num_fft_bins ,
@@ -168,16 +173,12 @@ def _extract_audio_features(self, audio):
168
173
center = False ,
169
174
)
170
175
stft = ops .sum (stft , axis = 0 )
171
- # magnitudes = ops.square(ops.absolute(stft)
172
176
magnitudes = ops .square (ops .absolute (stft [:, :- 1 , :]))
173
- # magnitudes = ops.square(ops.sqrt(ops.square(stft_real) + ops.square(stft_imag)))
174
- # mel_filters_casted = ops.cast(self.mel_filters, dtype=magnitudes.dtype)
175
-
177
+
176
178
mel_spec = ops .matmul (
177
179
magnitudes ,
178
180
self .mel_filters ,
179
181
)
180
- # mel_spec = ops.matmul(magnitudes,mel_filters_casted,)
181
182
182
183
def tf_log10 (x ):
183
184
"""Computes log base 10 of input tensor using TensorFlow."""
@@ -217,29 +218,175 @@ def tf_log10(x):
217
218
)
218
219
return log_spec
219
220
220
- def call (self , audio ):
221
- if not isinstance (audio , (tf .Tensor , tf .RaggedTensor )):
222
- audio = tf .convert_to_tensor (audio )
221
+ # def call(self, audio):
222
+ # if not ops.is_tensor(audio):
223
+ # audio = ops.convert_to_tensor(audio)
224
+
225
+ # rank_1_input = ops.ndim(audio) == 1
226
+ # if rank_1_input:
227
+ # audio = ops.expand_dims(audio, axis=0)
228
+
229
+ # input_shape = ops.shape(audio)
230
+ # audio_len = input_shape[-1]
231
+ # padding_size = self.num_samples - audio_len
232
+
233
+ # if padding_size > 0:
234
+ # audio = ops.pad(audio, ((0, 0), (0, padding_size)))
235
+
236
+ # log_spec = self._extract_audio_features(audio)
237
+
238
+ # if rank_1_input:
239
+ # log_spec = ops.squeeze(log_spec, axis=0)
240
+
241
+ # return log_spec
242
+
243
+ def call (
244
+ self ,
245
+ inputs ,
246
+ padding = None ,
247
+ max_length = None ,
248
+ pad_to_multiple_of = None ,
249
+ ):
250
+ input_shape = keras .ops .shape (inputs )
251
+ input_rank = (
252
+ len (input_shape )
253
+ if isinstance (input_shape , (list , tuple ))
254
+ else input_shape .rank
255
+ )
256
+ rank_1_input = input_rank == 1
223
257
224
- rank_1_input = audio .shape .rank == 1
225
258
if rank_1_input :
226
- audio = ops .expand_dims (audio , 0 )
259
+ inputs = ops .expand_dims (inputs , 0 )
227
260
228
- # Convert the tensor to a Ragged Tensor.
229
- if isinstance (audio , tf .Tensor ):
230
- audio = tf .RaggedTensor .from_tensor (audio )
261
+ # Convert to dense tensor with proper padding/truncation
262
+ processed_inputs = self .variable_length_inputs (
263
+ inputs , padding , max_length , pad_to_multiple_of
264
+ )
231
265
232
- # Pad audio.
233
- audio_shape = audio .shape .as_list ()
234
- audio_shape [- 1 ] = self .num_samples
235
- audio = audio .to_tensor (shape = audio_shape )
266
+ # Extract features
267
+ log_spec = self ._extract_audio_features (processed_inputs )
236
268
237
- # Find the log mel spectrogram.
238
- log_spec = self ._extract_audio_features (audio )
239
269
if rank_1_input :
240
270
log_spec = ops .squeeze (log_spec , 0 )
271
+
241
272
return log_spec
242
273
274
+ # handling variable length inputs
275
+ def variable_length_inputs (
276
+ self , inputs , padding = None , max_length = None , pad_to_multiple_of = None
277
+ ):
278
+ """Handles variable length inputs with padding or truncation."""
279
+
280
+ # Determine the appropriate target length
281
+ if padding == "max_length" and max_length is not None :
282
+ target_length = max_length
283
+ else :
284
+ # Use default max_audio_length
285
+ target_length = self .num_samples
286
+
287
+ if pad_to_multiple_of :
288
+ target_length = (
289
+ (target_length + pad_to_multiple_of - 1 ) // pad_to_multiple_of
290
+ ) * pad_to_multiple_of
291
+
292
+ # Get current shape and length
293
+ audio_shape = keras .ops .shape (inputs )
294
+ audio_length = audio_shape [1 ]
295
+
296
+ if padding == "max_length" and max_length is not None :
297
+ is_padding_required = keras .ops .less (audio_length , target_length )
298
+ is_trunc_required = keras .ops .greater (audio_length , target_length )
299
+
300
+ def pad_fn ():
301
+ padding_amount = target_length - audio_length
302
+ paddings = [[0 , 0 ], [0 , padding_amount ]]
303
+ return keras .ops .pad (
304
+ inputs ,
305
+ paddings ,
306
+ mode = "constant" ,
307
+ constant_values = self .padding_value ,
308
+ )
309
+
310
+ def trunc_fn ():
311
+ return keras .ops .slice (
312
+ inputs ,
313
+ [0 , 0 ],
314
+ [- 1 , target_length ],
315
+ )
316
+
317
+ # Check if we're in symbolic execution
318
+ is_tf_symbolic = (
319
+ tf is not None
320
+ and hasattr (inputs , "graph" )
321
+ and hasattr (inputs .graph , "as_graph_def" )
322
+ )
323
+ use_tf_graph_ops = tf is not None and is_tf_symbolic
324
+
325
+ if use_tf_graph_ops and keras .config .backend () != "torch" :
326
+ processed_inputs = tf .cond (
327
+ is_padding_required ,
328
+ pad_fn ,
329
+ lambda : tf .cond (is_trunc_required , trunc_fn ,lambda : inputs ),
330
+ )
331
+ else :
332
+ is_padding_bool = keras .ops .convert_to_numpy (is_padding_required )
333
+ is_trunc_bool = keras .ops .convert_to_numpy (
334
+ is_trunc_required
335
+ )
336
+
337
+ if is_padding_bool :
338
+ padding_amount = target_length - audio_length
339
+ paddings = [[0 , 0 ], [0 , padding_amount ]]
340
+ processed_inputs = keras .ops .pad (
341
+ inputs ,
342
+ paddings ,
343
+ mode = "constant" ,
344
+ constant_values = self .padding_value ,
345
+ )
346
+ elif is_trunc_bool :
347
+ processed_inputs = inputs [:, :target_length ]
348
+ else :
349
+ processed_inputs = inputs
350
+ else :
351
+ # No explicit padding - just pad/truncate to default max length
352
+ is_padding_required = keras .ops .less (audio_length , target_length )
353
+ is_trunc_required = keras .ops .greater (audio_length , target_length )
354
+
355
+ # Use eager execution approach for simplicity
356
+ is_padding_bool = keras .ops .convert_to_numpy (is_padding_required )
357
+ is_trunc_bool = keras .ops .convert_to_numpy (is_trunc_required )
358
+
359
+ if is_padding_bool :
360
+ padding_amount = target_length - audio_length
361
+ paddings = [[0 , 0 ], [0 , padding_amount ]]
362
+ processed_inputs = keras .ops .pad (
363
+ inputs ,
364
+ paddings ,
365
+ mode = "constant" ,
366
+ constant_values = self .padding_value ,
367
+ )
368
+ elif is_trunc_bool :
369
+ processed_inputs = inputs [:, :target_length ]
370
+ else :
371
+ processed_inputs = inputs
372
+
373
+ return processed_inputs
374
+
375
+ def compute_output_shape (self , input_shape ):
376
+ """Compute output shape for variable-length inputs."""
377
+
378
+ if len (input_shape ) == 1 :
379
+ # For single audio sample - returns 2D shape (frames, mels)
380
+ num_frames = (self .num_samples + self .stride - 1 ) // self .stride
381
+ return (num_frames , self .num_mels )
382
+ elif len (input_shape ) == 2 :
383
+ # For batch of audio samples -returns 3D shape (batch, frames, mels)
384
+ batch_size = input_shape [0 ]
385
+ num_frames = (self .num_samples + self .stride - 1 ) // self .stride
386
+ return (batch_size , num_frames , self .num_mels )
387
+ else :
388
+ raise ValueError ("Input shape must be rank 1 or 2." )
389
+
243
390
def get_config (self ):
244
391
config = super ().get_config ()
245
392
config .update (
@@ -252,3 +399,4 @@ def get_config(self):
252
399
}
253
400
)
254
401
return config
402
+
0 commit comments