seanmanifest committed · Commit abd96ca (verified) · 1 Parent(s): 4036b68

Add files using upload-large-folder tool

Files changed (1):
  1. modeling_brumby.py +159 -43
modeling_brumby.py CHANGED
@@ -21,17 +21,11 @@ import torch
 from torch import nn

 from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
 from transformers.generation import GenerationMixin
 from transformers.integrations import use_kernel_forward_from_hub
-from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_layers import (
-    GenericForQuestionAnswering,
-    GenericForSequenceClassification,
-    GenericForTokenClassification,
-    GradientCheckpointingLayer,
-)
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -377,7 +371,7 @@ class BrumbyAttention(nn.Module):
         return attn_output, attn_weights


-class BrumbyDecoderLayer(GradientCheckpointingLayer):
+class BrumbyDecoderLayer(nn.Module):
     def __init__(self, config: BrumbyConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -529,24 +523,9 @@ class BrumbyModel(BrumbyPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # It may already have been prepared by e.g. `generate`
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            # Prepare mask arguments
-            mask_kwargs = {
-                "config": self.config,
-                "input_embeds": inputs_embeds,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "position_ids": position_ids,
-            }
-            # Create the masks
-            causal_mask_mapping = {
-                "full_attention": create_causal_mask(**mask_kwargs),
-            }
-            # The sliding window alternating layers are not always activated depending on the config
-            if self.has_sliding_layers:
-                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions=False
+        )

         hidden_states = inputs_embeds

@@ -556,7 +535,7 @@ class BrumbyModel(BrumbyPreTrainedModel):
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
@@ -571,6 +550,158 @@ class BrumbyModel(BrumbyPreTrainedModel):
             past_key_values=past_key_values if use_cache else None,
         )

+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
+                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: BrumbyConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+            config (`Qwen3Config`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify is current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+

 @auto_docstring
 class BrumbyForCausalLM(BrumbyPreTrainedModel, GenerationMixin):
@@ -719,23 +850,8 @@ class BrumbyForCausalLM(BrumbyPreTrainedModel, GenerationMixin):
         )


-class BrumbyForSequenceClassification(GenericForSequenceClassification, BrumbyPreTrainedModel):
-    pass
-
-
-class BrumbyForTokenClassification(GenericForTokenClassification, BrumbyPreTrainedModel):
-    pass
-
-
-class BrumbyForQuestionAnswering(GenericForQuestionAnswering, BrumbyPreTrainedModel):
-    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
-
-
 __all__ = [
     "BrumbyForCausalLM",
-    "BrumbyForQuestionAnswering",
     "BrumbyPreTrainedModel",
     "BrumbyModel",
-    "BrumbyForSequenceClassification",
-    "BrumbyForTokenClassification",
 ]
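
Note: the sketch below is a minimal, standalone illustration (not part of the diff) of the causal, sliding-window-aware mask that the newly added _prepare_4d_causal_attention_mask_with_cache_position constructs. The toy values (sequence_length = target_length = 4, sliding_window = 2, batch_size = 1, no padding mask, no cache) are assumptions chosen only for this example.

# Standalone sketch (assumed toy values); mirrors the masking arithmetic added in this commit.
import torch

dtype = torch.float32
device = torch.device("cpu")
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, sliding_window, batch_size = 4, 4, 2, 1
cache_position = torch.arange(sequence_length, device=device)

# Start fully blocked (min_dtype everywhere), then clear the attendable positions.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
# True where the key position lies in the future of the query position.
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
# True where the key position is further back than the sliding window allows.
sliding_attend_mask = torch.arange(target_length, device=device) <= (cache_position.reshape(-1, 1) - sliding_window)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
# Blocked positions keep min_dtype, attendable positions become 0.0.
causal_mask *= diagonal_attend_mask
# Expand to the (batch_size, 1, query_length, key_value_length) layout the decoder layers receive.
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
print(causal_mask[0, 0])  # rows are query positions; 0.0 = attend, min_dtype = masked

Row i of the printout allows keys at positions (i - sliding_window, i], i.e. causal attention clipped to the window. In the model, a 2D padding mask would additionally be folded in, and fully masked rows are re-opened for SDPA's memory-efficient path via AttentionMaskConverter._unmask_unattended.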