"""Fast Tokenization classes for Bert.""" |
|
|
|
|
|
import math |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
from transformers import BatchEncoding |
|
|
from transformers.tokenization_utils_base import EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy |
|
|
from transformers.utils import PaddingStrategy, TensorType, logging |
|
|
from transformers import BertTokenizerFast |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class Bert2DTokenizerFast(BertTokenizerFast):
    r"""
    Construct a "fast" Bert2D tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is also used as the
            last token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        clean_text (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing
            all whitespaces by the classic one.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
        max_intermediate_subword_positions_per_word (`int`, *optional*, defaults to `1`):
            The maximum number of intermediate subword positions per word. This determines how many distinct
            intermediate positions are available to the subwords of each word during tokenization.
        subword_embedding_order (`str`, *optional*, defaults to `"ending_first"`):
            The order in which subword position ids are assigned. Only `"ending_first"` is currently supported.
        intermediate_subword_distribution_strategy (`str`, *optional*, defaults to `"uniform"`):
            The strategy for distributing intermediate subword positions. Can be `"uniform"` or `"leftover_as_last"`.
    """

    model_input_names: List[str] = ["input_ids", "token_type_ids", "word_ids", "subword_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        max_intermediate_subword_positions_per_word=1,
        subword_embedding_order="ending_first",
        intermediate_subword_distribution_strategy="uniform",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
        self.max_intermediate_subword_positions_per_word = max_intermediate_subword_positions_per_word
        self.subword_embedding_order = subword_embedding_order
        self.intermediate_subword_distribution_strategy = intermediate_subword_distribution_strategy

        # Persist the Bert2D-specific arguments so they are saved with the tokenizer configuration.
        self.init_kwargs["max_intermediate_subword_positions_per_word"] = max_intermediate_subword_positions_per_word
        self.init_kwargs["subword_embedding_order"] = subword_embedding_order
        self.init_kwargs["intermediate_subword_distribution_strategy"] = intermediate_subword_distribution_strategy

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # Run the standard fast-tokenizer pipeline first; tensor conversion is deferred until
        # word_ids and subword_ids have been added.
        result = super().__call__(
            text=text,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=None,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        is_batched_output = (
            isinstance(result["input_ids"], list) and result["input_ids"] and isinstance(result["input_ids"][0], list)
        )

        actual_batch_size: int
        seq_length: int

        if is_batched_output:
            actual_batch_size = len(result["input_ids"])
            seq_length = len(result["input_ids"][0]) if actual_batch_size > 0 else 0
        else:
            actual_batch_size = 1
            seq_length = len(result["input_ids"])

        # Temporarily wrap single-sequence outputs in a batch dimension so the loop below is uniform.
        if "input_ids" in result and not isinstance(result["input_ids"][0], list):
            for key in result:
                if isinstance(result[key], list):
                    result[key] = [result[key]]

        list_of_list_word_ids: List[List[int]] = []
        list_of_list_subword_ids: List[List[int]] = []

        for i in range(actual_batch_size):
            current_tokens = result.tokens(i)

            # Restart word ids for the second sentence only when two [SEP] tokens are present.
            should_restart_word_ids_heuristic = current_tokens.count(self.sep_token) >= 2

            list_of_list_word_ids.append(
                create_word_ids(
                    current_tokens,
                    restart_new_sentence=should_restart_word_ids_heuristic,
                    seperator_token=self.sep_token,
                    padding_token=self.pad_token,
                )
            )
            list_of_list_subword_ids.append(
                create_subword_ids(
                    current_tokens,
                    self.max_intermediate_subword_positions_per_word,
                    self.subword_embedding_order,
                    self.intermediate_subword_distribution_strategy,
                    cls_token=self.cls_token,
                    sep_token=self.sep_token,
                    pad_token=self.pad_token,
                )
            )

        padding_value_for_ids = 0
        effective_padding_side = padding_side if padding_side is not None else self.padding_side

        # Pad or truncate the generated id lists so they line up with the (already padded) input_ids.
        if padding_strategy_uses_max_length(padding, max_length):
            for i in range(actual_batch_size):
                for id_list in [list_of_list_word_ids[i], list_of_list_subword_ids[i]]:
                    current_len = len(id_list)
                    pad_len = seq_length - current_len
                    if pad_len > 0:
                        if effective_padding_side == "right":
                            id_list.extend([padding_value_for_ids] * pad_len)
                        else:
                            for _ in range(pad_len):
                                id_list.insert(0, padding_value_for_ids)
                    elif pad_len < 0:
                        if effective_padding_side == "right":
                            del id_list[seq_length:]
                        else:
                            del id_list[:-seq_length]

        if not is_batched_output:
            result["word_ids"] = list_of_list_word_ids[0]
            result["subword_ids"] = list_of_list_subword_ids[0]

            # Undo the temporary batch wrapping applied above.
            for key in list(result.keys()):
                if isinstance(result[key], list) and len(result[key]) == 1 and key not in ["word_ids", "subword_ids"]:
                    if isinstance(result[key][0], list):
                        result[key] = result[key][0]
        else:
            result["word_ids"] = list_of_list_word_ids
            result["subword_ids"] = list_of_list_subword_ids

        if return_tensors is not None:
            if return_tensors == "pt":
                import torch

                result = result.convert_to_tensors(tensor_type=return_tensors)

                if not is_batched_output:
                    for key, value in result.items():
                        if isinstance(value, torch.Tensor) and value.ndim == 1:
                            result[key] = value.unsqueeze(0)
            else:
                result = result.convert_to_tensors(tensor_type=return_tensors)

        return result

    def encode_plus(
        self,
        text: Union[TextInput, PreTokenizedInput],
        text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Tokenize and prepare for the model a sequence or a pair of sequences.
        This method includes the generation of word_ids and subword_ids specific to Bert2D.
        """
        result = super().encode_plus(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=None,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        has_overflow = return_overflowing_tokens and "overflowing_tokens" in result

        is_batched_output = (
            isinstance(result["input_ids"], list) and result["input_ids"] and isinstance(result["input_ids"][0], list)
        )

        if has_overflow or is_batched_output:
            batch_size = len(result["input_ids"]) if is_batched_output else 1 + len(result["overflowing_tokens"])
            batch_word_ids = []
            batch_subword_ids = []

            for i in range(batch_size):
                tokens = result.tokens(i)

                should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2

                word_ids = create_word_ids(
                    tokens,
                    restart_new_sentence=should_restart_word_ids_heuristic,
                    seperator_token=self.sep_token,
                    padding_token=self.pad_token,
                )
                subword_ids = create_subword_ids(
                    tokens,
                    self.max_intermediate_subword_positions_per_word,
                    self.subword_embedding_order,
                    self.intermediate_subword_distribution_strategy,
                    cls_token=self.cls_token,
                    sep_token=self.sep_token,
                    pad_token=self.pad_token,
                )

                batch_word_ids.append(word_ids)
                batch_subword_ids.append(subword_ids)

            result["word_ids"] = batch_word_ids
            result["subword_ids"] = batch_subword_ids
        else:
            tokens = result.tokens()

            should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2

            word_ids = create_word_ids(
                tokens,
                restart_new_sentence=should_restart_word_ids_heuristic,
                seperator_token=self.sep_token,
                padding_token=self.pad_token,
            )
            subword_ids = create_subword_ids(
                tokens,
                self.max_intermediate_subword_positions_per_word,
                self.subword_embedding_order,
                self.intermediate_subword_distribution_strategy,
                cls_token=self.cls_token,
                sep_token=self.sep_token,
                pad_token=self.pad_token,
            )

            result["word_ids"] = word_ids
            result["subword_ids"] = subword_ids

        if return_tensors is not None:
            if return_tensors == "pt":
                import torch

                result = result.convert_to_tensors(tensor_type=return_tensors)

                if not is_batched_output and not has_overflow:
                    for key, value in result.items():
                        if isinstance(value, torch.Tensor) and value.ndim == 1:
                            result[key] = value.unsqueeze(0)
            else:
                result = result.convert_to_tensors(tensor_type=return_tensors)

        return result

    def batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[PreTokenizedInput],
            List[Union[TextInput, PreTokenizedInput]],
        ],
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Tokenize and prepare a batch of sequences or a batch of sequence pairs for the model.
        This method includes the generation of word_ids and subword_ids specific to Bert2D.
        """
        result = super().batch_encode_plus(
            batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=None,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        batch_size = len(result["input_ids"])
        batch_word_ids = []
        batch_subword_ids = []

        for i in range(batch_size):
            tokens = result.tokens(i)

            should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2

            word_ids = create_word_ids(
                tokens,
                restart_new_sentence=should_restart_word_ids_heuristic,
                seperator_token=self.sep_token,
                padding_token=self.pad_token,
            )
            subword_ids = create_subword_ids(
                tokens,
                self.max_intermediate_subword_positions_per_word,
                self.subword_embedding_order,
                self.intermediate_subword_distribution_strategy,
                cls_token=self.cls_token,
                sep_token=self.sep_token,
                pad_token=self.pad_token,
            )

            batch_word_ids.append(word_ids)
            batch_subword_ids.append(subword_ids)

        result["word_ids"] = batch_word_ids
        result["subword_ids"] = batch_subword_ids

        if return_tensors is not None:
            result = result.convert_to_tensors(tensor_type=return_tensors)

        return result

    def prepare_for_model(
        self,
        ids: List[int],
        pair_ids: Optional[List[int]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy, None] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        prepend_batch_axis: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input ids, or a pair of sequences of input ids, for the model.
        This method adds `word_ids` and `subword_ids` specific to Bert2D.
        """
        prepared_inputs = super().prepare_for_model(
            ids=ids,
            pair_ids=pair_ids,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=None,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            prepend_batch_axis=prepend_batch_axis,
            **kwargs,
        )

        tokens = self.convert_ids_to_tokens(prepared_inputs["input_ids"])

        should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2

        prepared_inputs["word_ids"] = create_word_ids(
            tokens,
            restart_new_sentence=should_restart_word_ids_heuristic,
            seperator_token=self.sep_token,
            padding_token=self.pad_token,
        )
        prepared_inputs["subword_ids"] = create_subword_ids(
            tokens,
            self.max_intermediate_subword_positions_per_word,
            self.subword_embedding_order,
            self.intermediate_subword_distribution_strategy,
            cls_token=self.cls_token,
            sep_token=self.sep_token,
            pad_token=self.pad_token,
        )

        if return_tensors is not None:
            if return_tensors == "pt":
                import torch

                prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors)

                if not prepend_batch_axis:
                    for key, value in prepared_inputs.items():
                        if isinstance(value, torch.Tensor) and value.ndim == 1:
                            prepared_inputs[key] = value.unsqueeze(0)
            else:
                prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors)

        return prepared_inputs

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        effective_padding_side = padding_side if padding_side is not None else self.padding_side

        # Remove word_ids/subword_ids before calling the parent implementation, which does not know about them.
        word_ids_list = None
        if "word_ids" in encoded_inputs and isinstance(encoded_inputs["word_ids"], list):
            word_ids_list = encoded_inputs.pop("word_ids")

        subword_ids_list = None
        if "subword_ids" in encoded_inputs and isinstance(encoded_inputs["subword_ids"], list):
            subword_ids_list = encoded_inputs.pop("subword_ids")

        padded_standard_inputs = super()._pad(
            encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=effective_padding_side,
            return_attention_mask=return_attention_mask,
        )

        if padding_strategy != PaddingStrategy.DO_NOT_PAD and max_length is not None:
            main_input_name = self.model_input_names[0]
            if main_input_name not in padded_standard_inputs:
                # Nothing to align against; restore the id lists unchanged.
                if word_ids_list is not None:
                    padded_standard_inputs["word_ids"] = word_ids_list
                if subword_ids_list is not None:
                    padded_standard_inputs["subword_ids"] = subword_ids_list
                return padded_standard_inputs

            padded_len = len(padded_standard_inputs[main_input_name])
            padding_val = 0

            # Pad (or truncate) word_ids to the padded length of the main input.
            if word_ids_list is not None:
                current_len = len(word_ids_list)
                diff = padded_len - current_len
                if diff > 0:
                    if effective_padding_side == "right":
                        padded_standard_inputs["word_ids"] = word_ids_list + [padding_val] * diff
                    else:
                        padded_standard_inputs["word_ids"] = [padding_val] * diff + word_ids_list
                else:
                    padded_standard_inputs["word_ids"] = word_ids_list[:padded_len]
            elif "word_ids" in encoded_inputs:
                padded_standard_inputs["word_ids"] = encoded_inputs["word_ids"]

            # Pad (or truncate) subword_ids the same way.
            if subword_ids_list is not None:
                current_len = len(subword_ids_list)
                diff = padded_len - current_len
                if diff > 0:
                    if effective_padding_side == "right":
                        padded_standard_inputs["subword_ids"] = subword_ids_list + [padding_val] * diff
                    else:
                        padded_standard_inputs["subword_ids"] = [padding_val] * diff + subword_ids_list
                else:
                    padded_standard_inputs["subword_ids"] = subword_ids_list[:padded_len]
            elif "subword_ids" in encoded_inputs:
                padded_standard_inputs["subword_ids"] = encoded_inputs["subword_ids"]

        else:
            # No padding requested: put the id lists back untouched.
            if word_ids_list is not None:
                padded_standard_inputs["word_ids"] = word_ids_list
            elif "word_ids" in encoded_inputs:
                padded_standard_inputs["word_ids"] = encoded_inputs["word_ids"]

            if subword_ids_list is not None:
                padded_standard_inputs["subword_ids"] = subword_ids_list
            elif "subword_ids" in encoded_inputs:
                padded_standard_inputs["subword_ids"] = encoded_inputs["subword_ids"]

        return padded_standard_inputs

    def apply_chat_template(
        self,
        conversation,
        chat_template=None,
        tools=None,
        documents=None,
        add_generation_prompt=False,
        tokenize=True,
        padding=False,
        truncation=None,
        max_length=None,
        return_tensors=None,
        return_dict=False,
        return_assistant_tokens_mask=False,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """
        Override apply_chat_template to fix tensor dimension issues when
        return_tensors="pt" is used with single conversations and return_assistant_tokens_mask=True.
        """
        needs_tensor_fix = (
            return_tensors == "pt"
            and return_assistant_tokens_mask
            and return_dict
            and tokenize
            and not isinstance(conversation[0], list)
            if conversation
            else False
        )

        if needs_tensor_fix:
            # Tokenize without tensor conversion first, then convert manually with a batch dimension.
            result = super().apply_chat_template(
                conversation=conversation,
                chat_template=chat_template,
                tools=tools,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                tokenize=tokenize,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=None,
                return_dict=return_dict,
                return_assistant_tokens_mask=return_assistant_tokens_mask,
                tokenizer_kwargs=tokenizer_kwargs,
                **kwargs,
            )

            if return_tensors == "pt":
                import torch

                for key, value in result.items():
                    if isinstance(value, list):
                        tensor = torch.tensor(value)
                        if tensor.dim() == 0:
                            tensor = tensor.unsqueeze(0)
                        if tensor.dim() == 1:
                            tensor = tensor.unsqueeze(0)
                        result[key] = tensor

            return result
        else:
            return super().apply_chat_template(
                conversation=conversation,
                chat_template=chat_template,
                tools=tools,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                tokenize=tokenize,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_dict=return_dict,
                return_assistant_tokens_mask=return_assistant_tokens_mask,
                tokenizer_kwargs=tokenizer_kwargs,
                **kwargs,
            )


def is_subword(token: str, subword_prefix="##") -> bool:
    """Returns whether a token is a WordPiece subword continuation (i.e. starts with the subword prefix)."""
    return token.startswith(subword_prefix)


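# For reference, a couple of illustrative values (standard WordPiece-style tokens assumed):
#   is_subword("##ing") -> True
#   is_subword("play")  -> False

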
def create_word_ids(
    tokens: List[str], restart_new_sentence=False, seperator_token="[SEP]", padding_token="[PAD]"
) -> List[int]:
    """Creates word ids for the given tokens, matching the logic from the Bert2DTokenizerFast tests."""
    word_ids: List[int] = []
    current_word_id: int = -1
    sentence_restart_flag = False

    # Only restart word ids at the first separator if the sequence really contains two sentences.
    actual_restart_new_sentence = restart_new_sentence and tokens.count(seperator_token) >= 2

    for token in tokens:
        if token == padding_token:
            word_ids.append(0)
        elif actual_restart_new_sentence and not sentence_restart_flag and token == seperator_token:
            if current_word_id == -1:
                current_word_id = 0
                word_ids.append(current_word_id)
            else:
                current_word_id += 1
                word_ids.append(current_word_id)
            # Word ids restart from 0 for the second sentence.
            current_word_id = -1
            sentence_restart_flag = True
        elif not is_subword(token):
            current_word_id += 1
            word_ids.append(current_word_id)
        elif current_word_id == -1:
            current_word_id = 0
            word_ids.append(current_word_id)
        else:
            word_ids.append(current_word_id)
    return word_ids


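# A rough worked example of create_word_ids (values traced by hand, assuming standard BERT special tokens):
#   tokens   = ["[CLS]", "play", "##ing", "outside", "[SEP]"]
#   word_ids = [0, 1, 1, 2, 3]
# With two sentences and restart_new_sentence=True, ids restart after the first [SEP]:
#   tokens   = ["[CLS]", "how", "are", "you", "[SEP]", "i", "am", "fine", "[SEP]"]
#   word_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3]

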
def col_round(x: float) -> int:
    """Colloquial rounding where 0.5 rounds to 1."""
    frac = x - math.floor(x)
    if frac < 0.5:
        return math.floor(x)
    return math.ceil(x)


def get_uniform_id(si: int, max_intermediate_subwords: int, num_intermediate_subwords: int) -> int:
    """Calculates the uniform id for subword index `si`, given the max and actual number of intermediate subwords."""
    if num_intermediate_subwords == 0:
        return 0

    # Spread the intermediate subwords uniformly over the available intermediate positions.
    effective_max_pos = max(0, max_intermediate_subwords - 1)
    return col_round(si * effective_max_pos / num_intermediate_subwords)


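# Sketch of how get_uniform_id spreads intermediates (hand-checked): with max_intermediate_subwords=2 and
# num_intermediate_subwords=3, the indices si=0, 1, 2 map to offsets 0, 0, 1 within the intermediate range.

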
def get_ids_from_subwords(
    num_subwords_in_current_word: int,
    max_intermediate_subword_positions_per_word: int,
    subword_embedding_order: str,
    intermediate_subword_distribution_strategy: str,
    current_word_starts_with_subword: bool = False,
) -> List[int]:
    """Calculate subword ids for the tokens of a single word."""
    if num_subwords_in_current_word == 0:
        return []

    # Single-token words: position 1 if the word starts mid-word (a bare "##" piece), otherwise the root position 0.
    if current_word_starts_with_subword:
        if num_subwords_in_current_word == 1:
            return [1]
    elif num_subwords_in_current_word == 1:
        return [0]

    if subword_embedding_order == "ending_first":
        subword_ids: List[int] = []

        # Position 0 is the root (first) piece, position 1 the last piece, positions 2+ the intermediates.
        has_explicit_root = not current_word_starts_with_subword and num_subwords_in_current_word > 0
        has_explicit_last = num_subwords_in_current_word > 1 or (
            current_word_starts_with_subword and num_subwords_in_current_word == 1
        )

        if has_explicit_root:
            subword_ids.append(0)
            num_tokens_for_intermediate_and_last = num_subwords_in_current_word - 1
        else:
            num_tokens_for_intermediate_and_last = num_subwords_in_current_word

        if has_explicit_last:
            num_intermediate_tokens = num_tokens_for_intermediate_and_last - 1
        else:
            num_intermediate_tokens = num_tokens_for_intermediate_and_last

        if num_intermediate_tokens < 0:
            num_intermediate_tokens = 0

        if num_intermediate_tokens > 0:
            if num_intermediate_tokens <= max_intermediate_subword_positions_per_word:
                # Enough positions: each intermediate subword gets its own position.
                for si in range(num_intermediate_tokens):
                    subword_ids.append(2 + si)
            else:
                if intermediate_subword_distribution_strategy == "uniform":
                    # Too many intermediates: spread them uniformly over the available positions.
                    for si in range(num_intermediate_tokens):
                        subword_ids.append(
                            2
                            + get_uniform_id(si, max_intermediate_subword_positions_per_word, num_intermediate_tokens)
                        )
                elif intermediate_subword_distribution_strategy == "leftover_as_last":
                    # Fill the available intermediate positions, then map the leftovers to the last position (1).
                    for si in range(max_intermediate_subword_positions_per_word):
                        subword_ids.append(2 + si)
                    for _ in range(num_intermediate_tokens - max_intermediate_subword_positions_per_word):
                        subword_ids.append(1)
                else:
                    raise ValueError(
                        f"Unsupported intermediate subword distribution strategy: {intermediate_subword_distribution_strategy}"
                    )

        if has_explicit_last:
            subword_ids.append(1)

        return subword_ids
    else:
        raise ValueError(f"Unsupported subword embedding order: {subword_embedding_order}")


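# Rough examples of the "ending_first" layout (hand-traced; 0 = root, 1 = last, 2+ = intermediates):
#   num_subwords_in_current_word=2, max=1                      -> [0, 1]
#   num_subwords_in_current_word=3, max=1                      -> [0, 2, 1]
#   num_subwords_in_current_word=5, max=2, strategy="uniform"  -> [0, 2, 2, 3, 1]

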
def create_subword_ids(
    tokens: List[str],
    max_intermediate_subword_positions_per_word: int,
    subword_embedding_order: str,
    intermediate_subword_distribution_strategy: str,
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token="[PAD]",
) -> List[int]:
    """Creates subword ids for the given tokens and parameters."""
    if not tokens:
        return []

    all_subword_ids: List[int] = []
    current_word_segment_tokens: List[str] = []

    # Check whether the very first content token (ignoring special tokens) is already a "##" continuation,
    # which happens when a sequence has been truncated or split mid-word.
    first_content_token_is_subword = False
    for token_val in tokens:
        if token_val not in [cls_token, sep_token, pad_token]:
            first_content_token_is_subword = is_subword(token_val)
            break

    first_content_word_processed = False

    for token in tokens:
        if token in [cls_token, sep_token, pad_token]:
            # Flush the pending word segment, then give the special token id 0.
            if current_word_segment_tokens:
                is_this_segment_the_very_first_content_word_and_starts_with_subword = (
                    first_content_token_is_subword
                    and not first_content_word_processed
                    and is_subword(current_word_segment_tokens[0])
                )

                generated_ids = get_ids_from_subwords(
                    num_subwords_in_current_word=len(current_word_segment_tokens),
                    max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
                    subword_embedding_order=subword_embedding_order,
                    intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
                    current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
                )
                all_subword_ids.extend(generated_ids)

                if not first_content_word_processed and current_word_segment_tokens:
                    first_content_word_processed = True
                current_word_segment_tokens = []

            all_subword_ids.append(0)
        elif not is_subword(token):
            # A new root token: flush the previous word segment and start a new one.
            if current_word_segment_tokens:
                is_this_segment_the_very_first_content_word_and_starts_with_subword = (
                    first_content_token_is_subword
                    and not first_content_word_processed
                    and is_subword(current_word_segment_tokens[0])
                )

                generated_ids = get_ids_from_subwords(
                    num_subwords_in_current_word=len(current_word_segment_tokens),
                    max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
                    subword_embedding_order=subword_embedding_order,
                    intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
                    current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
                )
                all_subword_ids.extend(generated_ids)
                if not first_content_word_processed and current_word_segment_tokens:
                    first_content_word_processed = True
            current_word_segment_tokens = [token]
        else:
            # A "##" continuation: extend the current word segment.
            current_word_segment_tokens.append(token)

    # Flush the final word segment, if any.
    if current_word_segment_tokens:
        is_this_segment_the_very_first_content_word_and_starts_with_subword = (
            first_content_token_is_subword
            and not first_content_word_processed
            and is_subword(current_word_segment_tokens[0])
        )

        generated_ids = get_ids_from_subwords(
            num_subwords_in_current_word=len(current_word_segment_tokens),
            max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
            subword_embedding_order=subword_embedding_order,
            intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
            current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
        )
        all_subword_ids.extend(generated_ids)

    return all_subword_ids


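# End-to-end sketch (hand-traced, default settings): special tokens get 0, word roots get 0, final pieces get 1:
#   tokens      = ["[CLS]", "play", "##ing", "outside", "[SEP]"]
#   subword_ids = [0, 0, 1, 0, 0]

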
def padding_strategy_uses_max_length(
    padding_strategy: Union[bool, str, PaddingStrategy], max_length: Optional[int]
) -> bool:
    """Helper to determine if padding will occur up to a max_length."""
    if padding_strategy is False or padding_strategy == PaddingStrategy.DO_NOT_PAD:
        return False
    if padding_strategy is True or padding_strategy == PaddingStrategy.LONGEST:
        # Padding to the longest sequence in the batch also fixes a common length to align the id lists with.
        return True
    if padding_strategy == PaddingStrategy.MAX_LENGTH:
        return max_length is not None
    return False


__all__ = ["Bert2DTokenizerFast"]
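

# Illustrative usage (a sketch, not part of the API; assumes a BERT-compatible vocabulary such as
# "bert-base-uncased" is available):
#
#   tokenizer = Bert2DTokenizerFast.from_pretrained("bert-base-uncased")
#   enc = tokenizer("playing outside", return_tensors="pt")
#   # enc contains "word_ids" and "subword_ids" aligned with "input_ids" in addition to the usual outputs.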