1
+ import keras .ops as ops
1
2
import numpy as np
3
+ import tensorflow as tf
2
4
3
5
from keras_hub .src .api_export import keras_hub_export
4
6
from keras_hub .src .layers .preprocessing .audio_converter import AudioConverter
5
7
from keras_hub .src .models .whisper .whisper_backbone import WhisperBackbone
6
8
7
- try :
8
- import tensorflow as tf
9
- except ImportError :
10
- tf = None
11
-
12
9
13
10
@keras_hub_export ("keras_hub.layers.WhisperAudioConverter" )
14
11
class WhisperAudioConverter (AudioConverter ):
@@ -36,7 +33,7 @@ class WhisperAudioConverter(AudioConverter):
36
33
37
34
Examples:
38
35
```python
39
- audio_tensor = tf .ones((8000,), dtype="float32")
36
+ audio_tensor = ops .ones((8000,), dtype="float32")
40
37
41
38
# Compute the log-mel spectrogram.
42
39
audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
@@ -45,8 +42,8 @@ class WhisperAudioConverter(AudioConverter):
45
42
audio_converter(audio_tensor)
46
43
47
44
# Compute the log-mel spectrogram for a batch of audio tensors.
48
- audio_tensor_1 = tf .ones((8000,), dtype="float32")
49
- audio_tensor_2 = tf .ones((10000,), dtype="float32")
45
+ audio_tensor_1 = ops .ones((8000,), dtype="float32")
46
+ audio_tensor_2 = ops .ones((10000,), dtype="float32")
50
47
audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
51
48
audio_converter(audio_tensor)
52
49
```
@@ -84,33 +81,33 @@ def audio_shape(self):
84
81
"""Returns the preprocessed size of a single audio sample."""
85
82
return (self .max_audio_length , self .num_mels )
86
83
84
+ def _get_rfftfreq_keras (self ):
85
+ n = self .num_fft_bins
86
+ d = 1.0 / self .sampling_rate
87
+
88
+ if n % 2 == 0 :
89
+ freqs = ops .arange (0 , n // 2 + 1 , dtype = "float32" ) / (d * n )
90
+ else :
91
+ freqs = ops .arange (0 , (n - 1 ) // 2 + 1 , dtype = "float32" ) / (d * n )
92
+
93
+ return freqs
94
+
87
95
def _get_mel_filters (self ):
88
96
"""
89
97
Adapted from Hugging Face
90
98
(https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
91
99
"""
92
-
93
- # TODO: Convert to TensorFlow ops (if possible).
94
-
95
- dtype = np .float32
100
+ dtype = self .compute_dtype # Use the class's dtype
96
101
# Initialize the weights
97
- weights = np .zeros (
102
+ weights = ops .zeros (
98
103
(self .num_mels , int (1 + self .num_fft_bins // 2 )), dtype = dtype
99
104
)
100
-
101
105
# Center freqs of each FFT bin
102
- fftfreqs = np .fft .rfftfreq (
103
- n = self .num_fft_bins , d = 1.0 / self .sampling_rate
104
- )
105
-
106
+ fftfreqs = self ._get_rfftfreq_keras ()
106
107
# 'Center freqs' of mel bands - uniformly spaced between limits
107
108
min_mel = 0.0
108
109
max_mel = 45.245640471924965
109
-
110
- mels = np .linspace (min_mel , max_mel , self .num_mels + 2 )
111
-
112
- mels = np .asanyarray (mels )
113
-
110
+ mels = ops .linspace (min_mel , max_mel , self .num_mels + 2 )
114
111
# Fill in the linear scale
115
112
f_min = 0.0
116
113
f_sp = 200.0 / 3
@@ -119,93 +116,105 @@ def _get_mel_filters(self):
119
116
# And now the nonlinear scale
120
117
min_log_hz = 1000.0 # beginning of log region (Hz)
121
118
min_log_mel = (min_log_hz - f_min ) / f_sp # same (Mels)
122
- logstep = np .log (6.4 ) / 27.0 # step size for log region
123
-
119
+ logstep = ops .log (6.4 ) / 27.0 # step size for log region
124
120
# If we have vector data, vectorize
125
121
log_t = mels >= min_log_mel
126
- freqs [ log_t ] = min_log_hz * np . exp (
127
- logstep * (mels [ log_t ] - min_log_mel )
122
+ freqs = ops . where (
123
+ log_t , min_log_hz * ops . exp ( logstep * (mels - min_log_mel )), freqs
128
124
)
129
-
130
125
mel_f = freqs
126
+ fdiff = ops .diff (mel_f )
127
+ ramps = (
128
+ ops .expand_dims (mel_f , axis = 1 ) - fftfreqs
129
+ ) # keras subtract outer
131
130
132
- fdiff = np .diff (mel_f )
133
- ramps = np .subtract .outer (mel_f , fftfreqs )
134
-
131
+ weights_list = []
135
132
for i in range (self .num_mels ):
136
133
# lower and upper slopes for all bins
137
134
lower = - ramps [i ] / fdiff [i ]
138
135
upper = ramps [i + 2 ] / fdiff [i + 1 ]
139
136
140
137
# .. then intersect them with each other and zero
141
- weights [i ] = np .maximum (0 , np .minimum (lower , upper ))
138
+ weights_i = ops .maximum (0 , ops .minimum (lower , upper ))
139
+ weights_list .append (weights_i )
140
+
141
+ weights = ops .stack (weights_list )
142
142
143
143
# Slaney-style mel is scaled to be approx constant energy per channel
144
144
enorm = 2.0 / (mel_f [2 : self .num_mels + 2 ] - mel_f [: self .num_mels ])
145
- weights *= enorm [:, np . newaxis ]
145
+ weights *= ops . expand_dims ( enorm , axis = 1 )
146
146
147
- weights = np .transpose (weights )
148
- return tf . constant ( weights , dtype = self . compute_dtype )
147
+ weights = ops .transpose (weights )
148
+ return weights
149
149
150
150
def _extract_audio_features (self , audio ):
151
- audio = tf .cast (audio , self .compute_dtype )
151
+ audio = ops .cast (audio , self .compute_dtype )
152
152
# Use "reflection" padding - `tf.signal.stft` uses symmetric padding
153
153
# internally.
154
- audio = tf .pad (
154
+ audio = ops .pad (
155
155
audio ,
156
- paddings = [[0 , 0 ], [self .num_fft_bins // 2 , self .num_fft_bins // 2 ]],
156
+ pad_width = [
157
+ [0 , 0 ],
158
+ [self .num_fft_bins // 2 , self .num_fft_bins // 2 ],
159
+ ],
157
160
mode = "REFLECT" ,
158
161
)
159
-
160
162
# Compute the mel spectrogram.
161
- stft = tf . signal .stft (
163
+ stft = ops .stft (
162
164
audio ,
163
- frame_length = self .num_fft_bins ,
164
- frame_step = self .stride ,
165
+ sequence_length = self .num_fft_bins ,
166
+ sequence_stride = self .stride ,
165
167
fft_length = self .num_fft_bins ,
168
+ center = False ,
166
169
)
167
- magnitudes = tf .square (tf .abs (stft [:, :- 1 , :]))
170
+ stft = ops .sum (stft , axis = 0 )
171
+ # magnitudes = ops.square(ops.absolute(stft)
172
+ magnitudes = ops .square (ops .absolute (stft [:, :- 1 , :]))
173
+ # magnitudes = ops.square(ops.sqrt(ops.square(stft_real) + ops.square(stft_imag)))
174
+ # mel_filters_casted = ops.cast(self.mel_filters, dtype=magnitudes.dtype)
168
175
169
- mel_spec = tf .matmul (
176
+ mel_spec = ops .matmul (
170
177
magnitudes ,
171
178
self .mel_filters ,
172
179
)
180
+ # mel_spec = ops.matmul(magnitudes,mel_filters_casted,)
173
181
174
182
def tf_log10 (x ):
175
183
"""Computes log base 10 of input tensor using TensorFlow."""
176
- numerator = tf .math .log (x )
177
- denominator = tf .math .log (tf .constant (10 , dtype = numerator .dtype ))
184
+ numerator = ops .log (x )
185
+ denominator = ops .log (
186
+ ops .cast (ops .array (10 ), dtype = numerator .dtype )
187
+ )
178
188
return numerator / denominator
179
189
180
190
# Clamp the values to a minimum value of 1e-10. This is done to avoid
181
191
# taking the log of 0, i.e., for numerical stability.
182
- mel_spec = tf .maximum (mel_spec , 1e-10 )
192
+ mel_spec = ops .maximum (mel_spec , 1e-10 )
183
193
184
194
# Calculate the log mel spectrogram.
185
195
log_spec = tf_log10 (mel_spec )
186
196
# Dynamic range compression.
187
- log_spec_shape = tf .shape (log_spec )
188
- max_value_minus_eight = tf . math .subtract (
189
- tf . math . reduce_max (log_spec , axis = [1 , 2 ]),
190
- tf .cast (8 , dtype = log_spec .dtype ),
197
+ log_spec_shape = ops .shape (log_spec )
198
+ max_value_minus_eight = ops .subtract (
199
+ ops . max (log_spec , axis = [1 , 2 ]),
200
+ ops .cast (8 , dtype = log_spec .dtype ),
191
201
)
192
- max_value_minus_eight = tf .expand_dims (max_value_minus_eight , axis = 1 )
193
- max_value_minus_eight = tf .repeat (
202
+ max_value_minus_eight = ops .expand_dims (max_value_minus_eight , axis = 1 )
203
+ max_value_minus_eight = ops .repeat (
194
204
max_value_minus_eight ,
195
205
repeats = log_spec_shape [1 ] * log_spec_shape [2 ],
196
206
axis = 1 ,
197
207
)
198
- max_value_minus_eight = tf .reshape (
199
- max_value_minus_eight , shape = log_spec_shape
208
+ max_value_minus_eight = ops .reshape (
209
+ max_value_minus_eight , newshape = log_spec_shape
200
210
)
201
- log_spec = tf .maximum (log_spec , max_value_minus_eight )
211
+ log_spec = ops .maximum (log_spec , max_value_minus_eight )
202
212
# Normalization.
203
- type_cast_four = tf .cast (4 , dtype = log_spec .dtype )
204
- log_spec = tf . math .divide (
205
- tf . math .add (log_spec , type_cast_four ),
213
+ type_cast_four = ops .cast (4 , dtype = log_spec .dtype )
214
+ log_spec = ops .divide (
215
+ ops .add (log_spec , type_cast_four ),
206
216
type_cast_four ,
207
217
)
208
-
209
218
return log_spec
210
219
211
220
def call (self , audio ):
@@ -214,7 +223,7 @@ def call(self, audio):
214
223
215
224
rank_1_input = audio .shape .rank == 1
216
225
if rank_1_input :
217
- audio = tf .expand_dims (audio , 0 )
226
+ audio = ops .expand_dims (audio , 0 )
218
227
219
228
# Convert the tensor to a Ragged Tensor.
220
229
if isinstance (audio , tf .Tensor ):
@@ -228,7 +237,7 @@ def call(self, audio):
228
237
# Find the log mel spectrogram.
229
238
log_spec = self ._extract_audio_features (audio )
230
239
if rank_1_input :
231
- log_spec = tf .squeeze (log_spec , 0 )
240
+ log_spec = ops .squeeze (log_spec , 0 )
232
241
return log_spec
233
242
234
243
def get_config (self ):
0 commit comments