@@ -228,50 +228,16 @@ def __init__(
228228 if torch .cuda .is_available ():
229229 self .register_buffer ("mask_cuda" , mask .cuda (), persistent = False )
230230
231-
232- def forward (self , tokens : Tensor , audio_features : Tensor ) -> Tensor :
231+ def forward (self , tokens : Tensor , audio_features : Tensor , kv_cache : Optional [dict ] = None ):
233232 """
234233 Args:
235234 tokens: (n_batch, n_token)
236235 audio_features: (n_batch, n_audio_ctx, n_audio_state)
236+ kv_cache: Optional cache for key/value tensors
237237
238238 Returns:
239239 logits: (n_batch, n_token, n_vocab)
240240 """
241- n_batch , n_token = tokens .shape
242- n_audio_ctx , n_audio_state = audio_features .shape [1 :]
243-
244- x = self .token_embedding (tokens ) + self .positional_embedding [:n_token ]
245-
246- # Optimisation: Move audio_features to GPU once here.
247- if torch .cuda .is_available ():
248- audio_features = audio_features .cuda ()
249-
250-
251- for block in self .blocks :
252- x = block (x , audio_features )
253-
254- x = self .ln (x )
255- logits = x @ self .token_embedding .weight .T
256-
257- # Optimisation: Apply the precomputed CUDA mask if available.
258- if torch .cuda .is_available ():
259- mask = self .mask_cuda [:n_token , :n_token ]
260- else :
261- mask = self .mask [:n_token , :n_token ]
262-
263- logits = logits + mask
264-
265- return logits
266-
267-
268- def forward (self , x : Tensor , xa : Tensor , kv_cache : Optional [dict ] = None ):
269- """
270- Args:
271- tokens: (n_batch, n_token) or x tensor
272- audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
273- kv_cache: Optional cache for key/value tensors
274- """
275241 if kv_cache is not None :
276242 # Handle the kv_cache case
277243 offset = next (iter (kv_cache .values ())).shape [1 ] if kv_cache else 0
@@ -313,7 +279,6 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
313279
314280 return logits
315281
316-
317282# The Whisper class has been moved outside of TextDecoder and is now a top-level class
318283class Whisper (nn .Module ):
319284 def __init__ (self , dims : ModelDimensions ):
0 commit comments