Upload _SpydazWebAI_Mistral_Transformer_.py
_SpydazWebAI_Mistral_Transformer_.py
CHANGED
@@ -661,6 +661,7 @@ class MistralPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+
 @add_start_docstrings(
     "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
     MISTRAL_START_DOCSTRING,
@@ -673,7 +674,7 @@ class MistralModel(MistralPreTrainedModel):
         config: MistralConfig
     """
 
-    def __init__(self, config):
+    def __init__(self, config: MistralConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -694,8 +695,6 @@ class MistralModel(MistralPreTrainedModel):
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
-
-
 
     @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
     def forward(
@@ -703,12 +702,13 @@ class MistralModel(MistralPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
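For context, a minimal usage sketch of the updated `forward` signature. This is not part of the commit; it assumes the stock Hugging Face `transformers` (>= 4.41) `MistralModel` and `DynamicCache` rather than the file uploaded here, and the checkpoint id is only a hypothetical choice.

import torch
from transformers import AutoTokenizer, MistralModel
from transformers.cache_utils import DynamicCache

model_id = "mistralai/Mistral-7B-v0.1"  # hypothetical checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MistralModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()

inputs = tokenizer("Hello world", return_tensors="pt")
past_key_values = DynamicCache()                              # Cache object instead of the legacy tuple format
cache_position = torch.arange(inputs["input_ids"].shape[1])   # absolute positions of the tokens being fed in

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        past_key_values=past_key_values,
        use_cache=True,
        cache_position=cache_position,
    )

print(outputs.last_hidden_state.shape)
print(outputs.past_key_values.get_seq_length())  # number of tokens now held in the cache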
@@ -719,73 +719,42 @@ class MistralModel(MistralPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
 
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
-            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-            if is_padding_right:
-                raise ValueError(
-                    "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
-                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                )
-
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            return_legacy_cache = True
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
+        )
+
         hidden_states = inputs_embeds
 
         # decoder layers
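The hunk above replaces the old `past_key_values_length` bookkeeping and mask preparation with a `Cache` object, a `cache_position` tensor, and `self._update_causal_mask`. A small standalone sketch of that cache round trip, assuming only `torch` and `transformers.cache_utils.DynamicCache` (illustrative, not taken from the commit):

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each [batch, kv_heads, seq_len, head_dim].
legacy_past = tuple(
    (torch.zeros(1, 8, 4, 128), torch.zeros(1, 8, 4, 128)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_past)   # what forward() now does for tuple inputs
past_seen_tokens = cache.get_seq_length()             # 4 tokens already cached
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + 1)  # position of the next token
position_ids = cache_position.unsqueeze(0)            # default position_ids derived in forward()
legacy_again = cache.to_legacy_cache()                # returned when the caller passed a tuple

print(past_seen_tokens, cache_position, len(legacy_again))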
@@ -801,20 +770,22 @@ class MistralModel(MistralPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    attention_mask,
+                    causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
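Note that the checkpointed branch forwards `causal_mask` and `cache_position` positionally, in the same order as the decoder layer's `forward`. An illustrative sketch of that pattern, under the assumption that `self._gradient_checkpointing_func` wraps `torch.utils.checkpoint.checkpoint` (the usual default in `transformers`; not verified against this file):

import torch
from torch.utils.checkpoint import checkpoint

def layer_fn(hidden_states, causal_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position):
    # Stand-in for MistralDecoderLayer.__call__: positional arguments only,
    # so the checkpoint wrapper can replay them during the backward pass.
    return (hidden_states * 2,)

hidden = torch.randn(1, 4, 16, requires_grad=True)
outputs = checkpoint(layer_fn, hidden, None, None, None, False, False, None, use_reentrant=False)
outputs[0].sum().backward()
print(hidden.grad.shape)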
@@ -831,9 +802,9 @@ class MistralModel(MistralPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -853,7 +824,7 @@ class MistralModel(MistralPreTrainedModel):
         use_cache: bool,
         output_attentions: bool,
     ):
-
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
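The TODO above concerns `torch.compile` with a static KV cache. A hedged sketch of how a caller would opt into that cache, assuming the standard `transformers` generation API (`cache_implementation="static"`, available in recent releases) and a hypothetical checkpoint id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # hypothetical checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
# "static" pre-allocates the KV cache to a fixed shape, so torch.compile sees
# stable tensor sizes across decode steps instead of recapturing graphs.
output_ids = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))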
@@ -952,6 +923,7 @@ class MistralModel(MistralPreTrainedModel):
 
         return causal_mask
 
+
 ############################## LM Heads #################################
 
 