AtAndDev committed on
Commit 50e657e · verified · 1 Parent(s): af0d841

Upload ultravox_model.py with huggingface_hub

Files changed (1):
  1. ultravox_model.py +92 -145

ultravox_model.py CHANGED
@@ -33,13 +33,10 @@ SHARED_PRETRAINED_KWARGS = [
 class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
     The Ultravox model which consists of an audio encoder and a language model.
-
     Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
     projected to the language model's embedding space using a few linear layers.
     The text is embedded by the language model as usual and then the audio and text embeddings are merged together.
-
     A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.
-
     Parameters:
         config: Model configuration class with all the parameters of the model.
     """
@@ -59,11 +56,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.keep_params: Set[str] = set()
         self.vocab_size = config.vocab_size
 
-        if not config.llm_only_training:
-            self.audio_tower = self._create_audio_tower(config)
-            self.multi_modal_projector = self._create_multi_modal_projector(config)
-            self.audio_tower_context_length = self.audio_tower.max_context_length
+        self.audio_tower = self._create_audio_tower(config)
+        self.audio_tower_context_length: Optional[int] = None
+        self.audio_tower_context_length = self.audio_tower.max_context_length
 
+        self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)
 
         if self.language_model._tied_weights_keys is not None:
@@ -72,16 +69,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             ]
 
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
+        # FSDP throws an error if some of the layer types are not found in the model.
         # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
-        # FSDP throws an error if some of the layer types are not found in the model, and they need to be filted out.
-        # 1. Get the names the language model *wants* to keep intact
-        candidate_names = set(
-            getattr(self.language_model, "_no_split_modules", []) or []
-        )
-        # 2. Names that actually exist in the current model
-        present_names = {m.__class__.__name__ for m in self.modules()}
-        # 3. Keep only those that are both requested and present
-        self._no_split_modules = list(candidate_names & present_names)
+        self._no_split_modules = self.language_model._no_split_modules
 
         self.loss_config = LossConfig()
         self.post_init()
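For context on the `_no_split_modules` comment above: an FSDP auto-wrap policy consumes such a list of layer-class names, and the removed code filtered that list against classes actually present in the model. The helper below is a hedged sketch of that idea, not code from the repository; `transformer_auto_wrap_policy` is the standard PyTorch FSDP helper.

    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    def resolve_no_split_classes(model, no_split_names):
        # Map class names such as "LlamaDecoderLayer" to classes actually present in the
        # module tree; per the diff's comment, FSDP errors out on names that are absent.
        present = {m.__class__.__name__: m.__class__ for m in model.modules()}
        return {present[name] for name in (no_split_names or []) if name in present}

    # auto_wrap_policy = functools.partial(
    #     transformer_auto_wrap_policy,
    #     transformer_layer_cls=resolve_no_split_classes(model, model._no_split_modules),
    # )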
@@ -159,17 +149,13 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self, labels: Optional[torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get boolean masks for positions where we want to compute KL divergence.
-
         For each label position, we want the position before it since that's where
         the model makes the prediction for that label.
-
         Additionally, we want to identify the position right before the EOT token
         (the last token with label != -100).
-
         Args:
             labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
                 with -100 for masked positions and token ids for label positions
-
         Returns:
             Tuple containing:
             - pred_mask: Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
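A minimal sketch with toy tensors, not the repository implementation, of the mask logic this docstring describes: a position is a prediction position when the next token carries a label, and the EOT prediction position is the one just before the last labeled token.

    import torch

    labels = torch.tensor([[-100, -100, 51, 52, 53, -100]])   # (B, T)
    label_mask = labels != -100

    pred_mask = torch.zeros_like(label_mask)
    pred_mask[:, :-1] = label_mask[:, 1:]        # position i predicts the label at i + 1

    t_idx = torch.arange(labels.shape[1])
    last_label_idx = (label_mask * t_idx).argmax(dim=-1)   # index of the last labeled token
    eot_mask = torch.zeros_like(label_mask)
    eot_mask[torch.arange(labels.shape[0]), last_label_idx - 1] = True   # position just before it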
@@ -239,32 +225,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             )
 
         # Compute the KL divergence loss for EOT token positions if any exist
-        if self.loss_config.eot_loss_weight > 0:
-            eot_loss = F.kl_div(
-                F.log_softmax(
-                    lm_output.logits[eot_mask] / self.loss_config.kl_temperature,
-                    dim=-1,
-                ),
-                F.softmax(
-                    alt_lm_output.logits[alt_eot_mask]
-                    / self.loss_config.kl_temperature,
-                    dim=-1,
-                ),
-                reduction="batchmean",
-            )
-            kl_loss += self.loss_config.eot_loss_weight * eot_loss
+        eot_loss = F.kl_div(
+            F.log_softmax(
+                lm_output.logits[eot_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            F.softmax(
+                alt_lm_output.logits[alt_eot_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            reduction="batchmean",
+        )
 
-        return kl_loss
+        return {"loss": kl_loss + self.loss_config.eot_loss_weight * eot_loss}
 
     def _audio_iter(
         self, audio_batch_size: torch.Tensor
     ) -> Generator[Tuple[int, int], None, None]:
         """
         Iterate over the audio batch size and yield the batch index and audio index of each audio item.
-
         Args:
             audio_batch_size: A tensor of shape (B,) where B is the batch size.
-
         Returns:
             A generator that yields a tuple of (start index, length) for each audio item.
         """
@@ -277,8 +258,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def forward(
         self,
         input_ids: torch.Tensor,
-        audio_values: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        audio_values: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
@@ -291,16 +272,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         alt_attention_mask: Optional[torch.Tensor] = None,
         alt_labels: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> transformers.modeling_outputs.CausalLMOutputWithPast:
+    ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
         """
         Forward pass for the Ultravox model.
-
         `input_ids` are the tokenized text input. They are embedded by the language model as usual.
         `audio_values` are processed by the audio encoder and then every `stack_factor` frames are stacked together and
         projected to the language model's embedding space using a few linear layers.
         The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
         of the audio embeddings in the merged embeddings.
-
         Args:
             input_ids: The tokenized text input.
             audio_values: The processed audio values.
@@ -316,14 +295,36 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             inputs_embeds = self.get_input_embeddings().forward(input_ids)
 
         if audio_values is not None and len(audio_values) > 0:
-            inputs_embeds = self._prepare_audio_embeds(
-                inputs_embeds=inputs_embeds,
-                audio_values=audio_values,
-                audio_token_start_idx=audio_token_start_idx,
-                audio_lens=audio_lens,
-                audio_token_len=audio_token_len,
-                audio_batch_size=audio_batch_size,
-            )
+            assert (
+                audio_token_start_idx is not None
+                and audio_token_len is not None
+                and audio_lens is not None
+                and audio_batch_size is not None
+            ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
+            assert (
+                len(audio_token_start_idx)
+                == len(audio_token_len)
+                == len(audio_lens)
+                == len(audio_values)
+            ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
+            assert len(audio_batch_size) == len(
+                inputs_embeds
+            ), "audio_batch_size and inputs_embeds must have the same batch size."
+
+            # B x A/3200 x (D=max-audio-length-in-batch)
+            audio_tower_output = self.audio_tower.forward(
+                audio_values.to(self.audio_tower.dtype),
+                audio_len=audio_lens,
+            ).last_hidden_state
+            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
+            audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
+
+            # combine audio and text embeddings
+            for i_b, i_a in self._audio_iter(audio_batch_size):
+                start_idx = audio_token_start_idx[i_a]
+                token_len = audio_token_len[i_a]
+                item_embedding = audio_embeds[i_a][:token_len]
+                inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
 
         lm_output = self.language_model.forward(
             inputs_embeds=inputs_embeds,
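The merge loop added above indexes audio items through `self._audio_iter(audio_batch_size)`, whose body is not shown in this diff. A hedged sketch of what that iteration plausibly yields, consistent with how the loop uses it:

    from typing import Generator, Tuple
    import torch

    def audio_iter(audio_batch_size: torch.Tensor) -> Generator[Tuple[int, int], None, None]:
        """Yield (batch index, running audio index) for every audio item in the batch."""
        audio_index = 0
        for batch_index, num_audios in enumerate(audio_batch_size.tolist()):
            for _ in range(int(num_audios)):
                yield batch_index, audio_index
                audio_index += 1

    print(list(audio_iter(torch.tensor([2, 0, 1]))))   # [(0, 0), (0, 1), (2, 2)]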
@@ -334,9 +335,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         )
         if self.training:
             if self.loss_config.loss_function == LossFunction.CrossEntropy:
-                pass
+                return lm_output
             elif self.loss_config.loss_function == LossFunction.KL_Divergence:
-                lm_output.loss = self._compute_kl_loss(
+                return self._compute_kl_loss(
                     lm_output=lm_output,
                     labels=labels,
                     past_key_values=past_key_values,
@@ -349,82 +350,52 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                 raise ValueError(
                     f"Unsupported loss function: {self.loss_config.loss_function}"
                 )
-        return lm_output
+        else:
+            return lm_output
 
-    def _prepare_audio_embeds(
-        self,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        audio_values: Optional[torch.Tensor] = None,
-        audio_token_start_idx: Optional[torch.Tensor] = None,
-        audio_lens: Optional[torch.Tensor] = None,
-        audio_token_len: Optional[torch.Tensor] = None,
-        audio_batch_size: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert (
-            inputs_embeds is not None
-            and audio_values is not None
-            and audio_token_start_idx is not None
-            and audio_token_len is not None
-            and audio_lens is not None
-            and audio_batch_size is not None
-        ), "inputs_embeds/audio_values/audio_token_start_idx/audio_token_len/audio_lens/audio_batch_size must be provided."
-        assert (
-            len(audio_token_start_idx)
-            == len(audio_token_len)
-            == len(audio_lens)
-            == len(audio_values)
-        ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
-        assert len(audio_batch_size) == len(
-            inputs_embeds
-        ), "audio_batch_size and inputs_embeds must have the same batch size."
-
-        # B x A/3200 x (D=max-audio-length-in-batch)
-        audio_tower_output = self.audio_tower.forward(
-            audio_values.to(self.audio_tower.dtype),
-            audio_len=audio_lens,
-        ).last_hidden_state
-        audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
-        audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
-
-        # combine audio and text embeddings
-        for i_b, i_a in self._audio_iter(audio_batch_size):
-            start_idx = audio_token_start_idx[i_a]
-            token_len = audio_token_len[i_a]
-            item_embedding = audio_embeds[i_a][:token_len]
-            inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
-
-        return inputs_embeds
-
-    def generate(
+    def prepare_inputs_for_generation(
         self,
         input_ids: torch.Tensor,
-        audio_values: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        audio_values: Optional[torch.FloatTensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
-        audio_lens: Optional[torch.Tensor] = None,
         audio_token_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
         audio_batch_size: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings().forward(input_ids)
-
-        if audio_values is not None and len(audio_values) > 0:
-            inputs_embeds = self._prepare_audio_embeds(
-                inputs_embeds=inputs_embeds,
-                audio_values=audio_values,
-                audio_token_start_idx=audio_token_start_idx,
-                audio_lens=audio_lens,
-                audio_token_len=audio_token_len,
-                audio_batch_size=audio_batch_size,
-            )
-
-        return self.language_model.generate(
+    ) -> Dict[str, Any]:
+        model_input = self.language_model.prepare_inputs_for_generation(
             input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
             **kwargs,
         )
 
+        # include audio information in model_input only when it is needed during prefilling
+        # audio_token_start_idx should always be relative to the current cache position
+        prefill_start_idx: int | torch.Tensor = (
+            0 if cache_position is None else cache_position[0]
+        )
+        if (
+            audio_values is not None
+            and audio_token_start_idx is not None
+            and prefill_start_idx <= torch.max(audio_token_start_idx)
+        ):
+            model_input["audio_values"] = audio_values
+            model_input["audio_token_start_idx"] = (
+                audio_token_start_idx - prefill_start_idx
+            )
+            model_input["audio_token_len"] = audio_token_len
+            model_input["audio_batch_size"] = audio_batch_size
+            model_input["audio_lens"] = audio_lens
+
+        return model_input
+
     @classmethod
     def _create_multi_modal_projector(
         cls, config: UltravoxConfig
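A small illustration with toy values, not the commit's code, of the prefill check in `prepare_inputs_for_generation`: audio tensors are only forwarded while the cache position has not advanced past the last audio placeholder, and the start indices are re-based onto the current window.

    import torch

    audio_token_start_idx = torch.tensor([3, 10])    # placeholder starts for two audio items
    cache_position = torch.tensor([0, 1, 2, 3])      # token positions written in this forward call

    prefill_start_idx = cache_position[0]
    needs_audio = prefill_start_idx <= torch.max(audio_token_start_idx)   # True only during prefill
    relative_start_idx = audio_token_start_idx - prefill_start_idx        # re-based start positions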
@@ -454,9 +425,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             audio_tower.init_latency_mask(
                 config.audio_latency_block_size, dtype=config.torch_dtype
             )
-            audio_tower.init_latency_mask(
-                config.audio_latency_block_size, dtype=config.torch_dtype
-            )
         else:
             assert config.audio_latency_block_size in (
                 None,
@@ -539,9 +507,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                 )
             )
 
-        if hasattr(self, "audio_tower") and isinstance(
-            self.audio_tower, peft.PeftModel
-        ):
+        if isinstance(self.audio_tower, peft.PeftModel):
             self.audio_tower = self.audio_tower.merge_and_unload()
             # no need to download base audio model weights anymore, so we can remove the id
             self.config.audio_model_id = None
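For context, `merge_and_unload()` is the standard PEFT call for folding LoRA weights back into the wrapped base module. A hedged helper sketch, not from the repository:

    import peft

    def maybe_merge_lora(module):
        # If the module is wrapped in a PEFT adapter, fold the LoRA weights into the
        # base weights and drop the wrapper (mirrors the branch in the diff above).
        if isinstance(module, peft.PeftModel):
            module = module.merge_and_unload()
        return module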
@@ -607,33 +573,18 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         )
 
         lm_trainable_params, lm_all_params = count_params(self.language_model)
-        if hasattr(self, "audio_tower") and self.audio_tower is not None:
-            audio_trainable_params, audio_all_params = count_params(self.audio_tower)
-        else:
-            audio_trainable_params, audio_all_params = 0, 0
+        audio_trainable_params, audio_all_params = count_params(self.audio_tower)
 
         projector_trainable_params = (
             trainable_params - lm_trainable_params - audio_trainable_params
         )
         projector_all_params = all_param - lm_all_params - audio_all_params
 
-        # Calculate percentages only if the total parameters are non-zero
-        audio_percent = (
-            0.0
-            if audio_all_params == 0
-            else 100 * audio_trainable_params / audio_all_params
-        )
-        projector_percent = (
-            0.0
-            if projector_all_params == 0
-            else 100 * projector_trainable_params / projector_all_params
-        )
-
         logging.info(
             f"Trainable%: "
             f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
-            f" || Audio Encoder: {audio_percent:.1f}%"
-            f" || Projector: {projector_percent:.1f}%"
+            f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
+            f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
         )
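`count_params` itself is not shown in this diff; a plain-PyTorch sketch of the trainable/total counting that the logging above implies:

    from typing import Tuple
    import torch.nn as nn

    def count_params(module: nn.Module) -> Tuple[int, int]:
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        total = sum(p.numel() for p in module.parameters())
        return trainable, total

    # trainable, total = count_params(model.language_model)
    # print(f"LLM trainable%: {100 * trainable / total:.1f}%")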
@@ -770,7 +721,6 @@ class UltravoxProjector(nn.Module):
     Takes in audio features from the audio tower and projects them to the text model's embedding space.
     It reduces the number of frames by a factor of `stack_factor` and increases the number of channels by the same factor.
     If the number of audio frames are not a multiple of the stack factor, the last few frames will be padded with zeros.
-
     Input shape:
         audio_features: B, T*S, C
     Output shape:
@@ -784,7 +734,6 @@
             C: number of channels out of the encoder (aka audio tower)
             H: hidden size of the projector (config.hidden_size)
             D: dimension of the text model (config.text_config.hidden_size)
-
         """
         # B, F, C -> B, T, C*S
         audio_features = self._pad_and_stack(audio_features)
@@ -805,13 +754,11 @@ class ModifiedWhisperEncoder(
 ):
     """
     Encoder portion of OpenAI's Whisper model.
-
     This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
     1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
     2. allow less than 30 second of audio padding to be passed in:
         - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
         - embed_pos is now sliced to match the length of `inputs_embeds`
-
     Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
     """
@@ -913,7 +860,7 @@ class ModifiedWhisperEncoder(
         # This masking ensures consistent behavior between training and inference
         # by preventing the model from attending to padding tokens in both cases
         attention_mask = None
-        if audio_len is not None:
+        if audio_len != None:
             audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
             max_seq_len = hidden_states.shape[1]
             attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
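Finally, a toy version, not the commit's code, of the length-based attention mask built after the `audio_len` check above: positions beyond each item's valid feature length are masked out.

    import torch

    audio_feature_len = torch.tensor([5, 3])                 # valid encoder frames per batch item
    max_seq_len = 6
    positions = torch.arange(max_seq_len)                    # (max_seq_len,)
    attention_mask = positions[None, :] < audio_feature_len[:, None]   # (B, max_seq_len) boolean
    # tensor([[ True,  True,  True,  True,  True, False],
    #         [ True,  True,  True, False, False, False]])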