# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for Bert."""
import math
from typing import Dict, List, Optional, Union
from transformers import BatchEncoding
from transformers.tokenization_utils_base import EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType, logging
from transformers import BertTokenizerFast
logger = logging.get_logger(__name__)
class Bert2DTokenizerFast(BertTokenizerFast):
r"""
    Construct a "fast" Bert2D tokenizer (backed by HuggingFace's *tokenizers* library), based on WordPiece. In
    addition to the standard BERT inputs, it produces `word_ids` and `subword_ids`, which Bert2D uses as word-level
    and subword-level position inputs.
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
File containing the vocabulary.
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
clean_text (`bool`, *optional*, defaults to `True`):
Whether or not to clean the text before tokenization by removing any control characters and replacing all
whitespaces by the classic one.
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
issue](https://github.com/huggingface/transformers/issues/328)).
strip_accents (`bool`, *optional*):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for `lowercase` (as in the original BERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
        max_intermediate_subword_positions_per_word (`int`, *optional*, defaults to `1`):
            The maximum number of distinct intermediate subword positions per word, i.e. how many position ids are
            available for the tokens between the first and last subword of a word.
        subword_embedding_order (`str`, *optional*, defaults to `"ending_first"`):
            The order in which subword position ids are assigned. Only `"ending_first"` is currently implemented.
        intermediate_subword_distribution_strategy (`str`, *optional*, defaults to `"uniform"`):
            The strategy for distributing intermediate subwords over the available positions when a word has more
            intermediate subwords than `max_intermediate_subword_positions_per_word`. Can be `"uniform"` or
            `"leftover_as_last"`.
"""
model_input_names: List[str] = ["input_ids", "token_type_ids", "word_ids", "subword_ids", "attention_mask"]
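    # In addition to the standard BERT inputs, Bert2D consumes "word_ids" (word-level positions) and
    # "subword_ids" (positions of subwords within their word); both are generated in __call__, encode_plus,
    # batch_encode_plus and prepare_for_model below.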
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
max_intermediate_subword_positions_per_word=1,
subword_embedding_order="ending_first",
intermediate_subword_distribution_strategy="uniform",
**kwargs,
):
super().__init__(
vocab_file=vocab_file, # Ensure vocab_file is passed correctly
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
self.max_intermediate_subword_positions_per_word = max_intermediate_subword_positions_per_word
self.subword_embedding_order = subword_embedding_order
self.intermediate_subword_distribution_strategy = intermediate_subword_distribution_strategy
# Ensure init_kwargs includes Bert2D specific parameters for correct saving and loading
self.init_kwargs["max_intermediate_subword_positions_per_word"] = max_intermediate_subword_positions_per_word
self.init_kwargs["subword_embedding_order"] = subword_embedding_order
self.init_kwargs["intermediate_subword_distribution_strategy"] = intermediate_subword_distribution_strategy
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
text_pair_target: Optional[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy, None] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[str] = None, # Keep this for explicit passing
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
# Use keyword arguments for the super call (as per previous fix)
result = super().__call__(
text=text,
text_pair=text_pair,
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding, # Pass padding argument to super
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,  # Forward so the base padding side matches the custom word/subword id padding
return_tensors=None, # Process as lists first, then convert to tensor if needed
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
# Determine batch characteristics from the processed 'result'
# Input_ids can be a list (single example) or list of lists (batch)
is_batched_output = (
isinstance(result["input_ids"], list) and result["input_ids"] and isinstance(result["input_ids"][0], list)
)
actual_batch_size: int
seq_length: int # This will be the length after padding/truncation by super().__call__
if is_batched_output:
actual_batch_size = len(result["input_ids"])
seq_length = len(result["input_ids"][0]) if actual_batch_size > 0 else 0
else: # Single, non-batched list of ints
actual_batch_size = 1
seq_length = len(result["input_ids"])
        # Temporarily wrap a single example as a batch of one so the loop below can treat everything uniformly
        if not is_batched_output:
            for key in result:  # type: ignore
                if isinstance(result[key], list):  # type: ignore
                    result[key] = [result[key]]  # type: ignore
# Generate word_ids and subword_ids as lists of lists
list_of_list_word_ids: List[List[int]] = []
list_of_list_subword_ids: List[List[int]] = []
for i in range(actual_batch_size):
# .tokens() is a method of BatchEncoding from a fast tokenizer
# It requires an index if the BatchEncoding is from a batch.
# If BatchEncoding is from a single example, it doesn't take an index.
# The `result` from super().__call__ is a BatchEncoding.
# If the original input to __call__ was a batch, result.tokens(i) is correct.
# If the original input was single, result.tokens() is correct.
# The BatchEncoding object itself handles this.
current_tokens = result.tokens(i) # type: ignore
# Determine if this sequence contains multiple sentences by checking for SEP tokens
should_restart_word_ids_heuristic = current_tokens.count(self.sep_token) >= 2
list_of_list_word_ids.append(
create_word_ids(
current_tokens,
restart_new_sentence=should_restart_word_ids_heuristic,
seperator_token=self.sep_token,
padding_token=self.pad_token,
)
)
list_of_list_subword_ids.append(
create_subword_ids(
current_tokens,
self.max_intermediate_subword_positions_per_word,
self.subword_embedding_order,
self.intermediate_subword_distribution_strategy,
cls_token=self.cls_token,
sep_token=self.sep_token,
pad_token=self.pad_token,
)
)
padding_value_for_ids = 0 # Standard padding for word_ids/subword_ids
effective_padding_side = padding_side if padding_side is not None else self.padding_side
# Pad word_ids and subword_ids to seq_length if padding was enabled
if padding_strategy_uses_max_length(padding, max_length):
for i in range(actual_batch_size):
for id_list in [list_of_list_word_ids[i], list_of_list_subword_ids[i]]:
current_len = len(id_list)
pad_len = seq_length - current_len
if pad_len > 0:
if effective_padding_side == "right":
id_list.extend([padding_value_for_ids] * pad_len)
else: # padding_side == "left"
for _ in range(pad_len):
id_list.insert(0, padding_value_for_ids)
elif pad_len < 0: # Truncate if longer (should ideally not happen if tokens were truncated)
if effective_padding_side == "right": # or truncation_side
del id_list[seq_length:]
else:
del id_list[:-seq_length]
if not is_batched_output: # Unwrap back to single list if original was single
result["word_ids"] = list_of_list_word_ids[0]
result["subword_ids"] = list_of_list_subword_ids[0]
# Unwrap other keys if they were wrapped
for key in list(result.keys()): # Iterate over a copy of keys
if isinstance(result[key], list) and len(result[key]) == 1 and key not in ["word_ids", "subword_ids"]:
# Check if it was a list of lists that became a list of one list
if isinstance(result[key][0], list):
result[key] = result[key][0] # type: ignore
else:
result["word_ids"] = list_of_list_word_ids
result["subword_ids"] = list_of_list_subword_ids
# Custom tensor conversion to ensure proper dimensions for pipeline compatibility
if return_tensors is not None:
if return_tensors == "pt":
import torch
# Convert to tensors first using standard method
result = result.convert_to_tensors(tensor_type=return_tensors)
# If this is a single input (not batched), ensure all tensors have batch dimension
if not is_batched_output:
for key, value in result.items():
if isinstance(value, torch.Tensor) and value.ndim == 1:
# Add batch dimension to make it 2D [1, seq_len] instead of [seq_len]
result[key] = value.unsqueeze(0)
else:
# For other tensor types (tf, np, etc.), use standard conversion
result = result.convert_to_tensors(tensor_type=return_tensors)
return result # type: ignore
def encode_plus(
self,
text: Union[TextInput, PreTokenizedInput],
text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy, None] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
"""
Tokenize and prepare for the model a sequence or a pair of sequences.
This method includes the generation of word_ids and subword_ids specific to Bert2D.
"""
# Call parent's encode_plus first to get standard tokenization
result = super().encode_plus(
text=text,
text_pair=text_pair,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=None, # Process as lists first, then convert to tensor if needed
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
# Check if we have overflow tokens (multiple sequences in result)
has_overflow = return_overflowing_tokens and "overflowing_tokens" in result
# Determine if result is batched (could be batched if overflow tokens are present)
is_batched_output = (
isinstance(result["input_ids"], list) and result["input_ids"] and isinstance(result["input_ids"][0], list)
)
# If we have overflow tokens OR the result is already batched
if has_overflow or is_batched_output:
# We'll need to process each sequence separately
batch_size = len(result["input_ids"]) if is_batched_output else 1 + len(result["overflowing_tokens"])
batch_word_ids = []
batch_subword_ids = []
for i in range(batch_size):
# Get tokens for this sequence
tokens = result.tokens(i)
# Determine if this sequence contains multiple sentences
should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2
word_ids = create_word_ids(
tokens,
restart_new_sentence=should_restart_word_ids_heuristic,
seperator_token=self.sep_token,
padding_token=self.pad_token,
)
subword_ids = create_subword_ids(
tokens,
self.max_intermediate_subword_positions_per_word,
self.subword_embedding_order,
self.intermediate_subword_distribution_strategy,
cls_token=self.cls_token,
sep_token=self.sep_token,
pad_token=self.pad_token,
)
batch_word_ids.append(word_ids)
batch_subword_ids.append(subword_ids)
# Add to result
result["word_ids"] = batch_word_ids
result["subword_ids"] = batch_subword_ids
else:
# Standard case - no overflow, single sequence
tokens = result.tokens()
# Determine if this sequence contains multiple sentences
should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2
word_ids = create_word_ids(
tokens,
restart_new_sentence=should_restart_word_ids_heuristic,
seperator_token=self.sep_token,
padding_token=self.pad_token,
)
subword_ids = create_subword_ids(
tokens,
self.max_intermediate_subword_positions_per_word,
self.subword_embedding_order,
self.intermediate_subword_distribution_strategy,
cls_token=self.cls_token,
sep_token=self.sep_token,
pad_token=self.pad_token,
)
# Add custom fields to result
result["word_ids"] = word_ids
result["subword_ids"] = subword_ids
# Custom tensor conversion to ensure proper dimensions for pipeline compatibility
if return_tensors is not None:
if return_tensors == "pt":
import torch
# Convert to tensors first using standard method
result = result.convert_to_tensors(tensor_type=return_tensors)
# For single inputs (not batched or with overflow), ensure all tensors have batch dimension
if not is_batched_output and not has_overflow:
for key, value in result.items():
if isinstance(value, torch.Tensor) and value.ndim == 1:
# Add batch dimension to make it 2D [1, seq_len] instead of [seq_len]
result[key] = value.unsqueeze(0)
else:
# For other tensor types, use standard conversion
result = result.convert_to_tensors(tensor_type=return_tensors)
return result
def batch_encode_plus(
self,
batch_text_or_text_pairs: Union[
List[TextInput],
List[PreTokenizedInput],
List[Union[TextInput, PreTokenizedInput]],
],
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy, None] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
"""
Tokenize and prepare a batch of sequences or a batch of sequence pairs for the model.
This method includes the generation of word_ids and subword_ids specific to Bert2D.
"""
# Call the parent's batch_encode_plus first to get standard tokenization
result = super().batch_encode_plus(
batch_text_or_text_pairs,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=None, # Process as lists first, then convert to tensor if needed
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
# Generate word_ids and subword_ids for each item in the batch
# Use the actual batch size from result["input_ids"], which includes overflow sequences
batch_size = len(result["input_ids"])
batch_word_ids = []
batch_subword_ids = []
for i in range(batch_size):
# Get tokens for this batch item
tokens = result.tokens(i)
# Determine if this sequence contains multiple sentences
should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2
word_ids = create_word_ids(
tokens,
restart_new_sentence=should_restart_word_ids_heuristic,
seperator_token=self.sep_token,
padding_token=self.pad_token,
)
subword_ids = create_subword_ids(
tokens,
self.max_intermediate_subword_positions_per_word,
self.subword_embedding_order,
self.intermediate_subword_distribution_strategy,
cls_token=self.cls_token,
sep_token=self.sep_token,
pad_token=self.pad_token,
)
batch_word_ids.append(word_ids)
batch_subword_ids.append(subword_ids)
# Add custom fields to result
result["word_ids"] = batch_word_ids
result["subword_ids"] = batch_subword_ids
# Convert to tensors if requested - for batched inputs we don't need special handling
# as they're already in the right format (list of lists)
if return_tensors is not None:
result = result.convert_to_tensors(tensor_type=return_tensors)
return result
def prepare_for_model(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy, None] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
prepend_batch_axis: bool = False,
**kwargs,
) -> BatchEncoding:
"""
        Prepares a sequence of input ids, or a pair of sequences of input ids, for the model.
This method adds `word_ids` and `subword_ids` specific to Bert2D.
"""
# Get the standard outputs from the parent class
prepared_inputs = super().prepare_for_model(
ids=ids,
pair_ids=pair_ids,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=None, # Process as lists first
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
prepend_batch_axis=prepend_batch_axis,
**kwargs,
)
# Convert input_ids to tokens to generate word_ids and subword_ids
tokens = self.convert_ids_to_tokens(prepared_inputs["input_ids"])
# Heuristic to check if we have a sentence pair
should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2
# Create and add word_ids
prepared_inputs["word_ids"] = create_word_ids(
tokens,
restart_new_sentence=should_restart_word_ids_heuristic,
seperator_token=self.sep_token,
padding_token=self.pad_token,
)
# Create and add subword_ids
prepared_inputs["subword_ids"] = create_subword_ids(
tokens,
self.max_intermediate_subword_positions_per_word,
self.subword_embedding_order,
self.intermediate_subword_distribution_strategy,
cls_token=self.cls_token,
sep_token=self.sep_token,
pad_token=self.pad_token,
)
# Custom tensor conversion to ensure proper dimensions for pipeline compatibility
if return_tensors is not None:
if return_tensors == "pt":
import torch
# Convert to tensors first using standard method
prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors)
# If prepend_batch_axis is False, we need to ensure all tensors have batch dimension
if not prepend_batch_axis:
for key, value in prepared_inputs.items():
if isinstance(value, torch.Tensor) and value.ndim == 1:
# Add batch dimension to make it 2D [1, seq_len] instead of [seq_len]
prepared_inputs[key] = value.unsqueeze(0)
else:
# For other tensor types, use standard conversion
prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors)
return prepared_inputs
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
padding_side: Optional[str] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
effective_padding_side = padding_side if padding_side is not None else self.padding_side
# Separate word_ids and subword_ids if they are lists.
# Tensors are assumed to be handled by the __call__ method's padding logic
# or are already in the correct shape if this method is reached.
word_ids_list = None
if "word_ids" in encoded_inputs and isinstance(encoded_inputs["word_ids"], list):
word_ids_list = encoded_inputs.pop("word_ids")
subword_ids_list = None
if "subword_ids" in encoded_inputs and isinstance(encoded_inputs["subword_ids"], list):
subword_ids_list = encoded_inputs.pop("subword_ids")
# Call the superclass's _pad method to handle standard keys like input_ids, attention_mask, etc.
# CRITICAL: Pass all relevant arguments, especially `padding_side`.
padded_standard_inputs = super()._pad(
encoded_inputs, # This now only contains standard keys if custom keys were lists and popped
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=effective_padding_side, # Pass the determined padding_side
return_attention_mask=return_attention_mask,
)
# Now, handle padding for word_ids and subword_ids if they were lists and were popped.
# This padding should align with how input_ids were padded by the super()._pad call.
# This logic is primarily for cases where inputs are not yet tensors (e.g. return_tensors=None).
if padding_strategy != PaddingStrategy.DO_NOT_PAD and max_length is not None:
main_input_name = self.model_input_names[0] # usually "input_ids"
if main_input_name not in padded_standard_inputs:
# This case should ideally not happen if _pad is called correctly.
# Fallback to adding custom IDs without padding if main input is missing.
if word_ids_list is not None:
padded_standard_inputs["word_ids"] = word_ids_list
if subword_ids_list is not None:
padded_standard_inputs["subword_ids"] = subword_ids_list
return padded_standard_inputs
padded_len = len(padded_standard_inputs[main_input_name])
padding_val = 0 # Standard padding value for these custom IDs
if word_ids_list is not None:
current_len = len(word_ids_list)
diff = padded_len - current_len
if diff > 0:
if effective_padding_side == "right":
padded_standard_inputs["word_ids"] = word_ids_list + [padding_val] * diff
else: # left
padded_standard_inputs["word_ids"] = [padding_val] * diff + word_ids_list
else: # No padding needed or truncation occurred (list might be longer)
padded_standard_inputs["word_ids"] = word_ids_list[:padded_len] # Ensure it's not longer
# If word_ids was not a list (e.g., a tensor passed through), and was not popped,
# it might still be in `encoded_inputs` (the original dict passed to _pad).
# If so, and it wasn't popped, ensure it's in the output.
elif "word_ids" in encoded_inputs:
padded_standard_inputs["word_ids"] = encoded_inputs["word_ids"]
if subword_ids_list is not None:
current_len = len(subword_ids_list)
diff = padded_len - current_len
if diff > 0:
if effective_padding_side == "right":
padded_standard_inputs["subword_ids"] = subword_ids_list + [padding_val] * diff
else: # left
padded_standard_inputs["subword_ids"] = [padding_val] * diff + subword_ids_list
else:
padded_standard_inputs["subword_ids"] = subword_ids_list[:padded_len]
elif "subword_ids" in encoded_inputs:
padded_standard_inputs["subword_ids"] = encoded_inputs["subword_ids"]
else: # No padding was applied to standard inputs, or no max_length specified
if word_ids_list is not None:
padded_standard_inputs["word_ids"] = word_ids_list
elif "word_ids" in encoded_inputs: # Ensure it's carried over if not popped
padded_standard_inputs["word_ids"] = encoded_inputs["word_ids"]
if subword_ids_list is not None:
padded_standard_inputs["subword_ids"] = subword_ids_list
elif "subword_ids" in encoded_inputs: # Ensure it's carried over if not popped
padded_standard_inputs["subword_ids"] = encoded_inputs["subword_ids"]
return padded_standard_inputs
def apply_chat_template(
self,
conversation,
chat_template=None,
tools=None,
documents=None,
add_generation_prompt=False,
tokenize=True,
padding=False,
truncation=None,
max_length=None,
return_tensors=None,
return_dict=False,
return_assistant_tokens_mask=False,
tokenizer_kwargs=None,
**kwargs,
):
"""
Override apply_chat_template to fix tensor dimension issues when
return_tensors="pt" is used with single conversations and return_assistant_tokens_mask=True.
"""
# Check if we need to apply the fix
        needs_tensor_fix = bool(conversation) and (
            return_tensors == "pt"
            and return_assistant_tokens_mask
            and return_dict
            and tokenize
            and not isinstance(conversation[0], list)  # Single conversation, not batched
        )
if needs_tensor_fix:
# For single conversations with tensor output, temporarily disable tensor conversion
# and handle it manually after the call
result = super().apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=None, # Disable tensor conversion temporarily
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
tokenizer_kwargs=tokenizer_kwargs,
**kwargs,
)
# Now manually convert to tensors ensuring proper dimensions
if return_tensors == "pt":
import torch
# Convert each field to tensors with proper dimensions
for key, value in result.items():
if isinstance(value, list):
# Convert list to tensor and ensure it has a batch dimension
tensor = torch.tensor(value)
# Ensure we have at least 1D (for sequences) and add batch dimension if needed
if tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
if tensor.dim() == 1:
tensor = tensor.unsqueeze(0) # Add batch dimension
result[key] = tensor
return result
else:
# For all other cases, use the parent implementation
return super().apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
tokenizer_kwargs=tokenizer_kwargs,
**kwargs,
)
def is_subword(token: str, subword_prefix="##") -> bool:
"""Returns if a token is a subword"""
return token.startswith(subword_prefix)
def create_word_ids(
tokens: List[str], restart_new_sentence=False, seperator_token="[SEP]", padding_token="[PAD]"
) -> List[int]:
"""Creates word ids for given tokens, matching the logic from Bert2DTokenizerFast tests."""
word_ids: List[int] = []
current_word_id: int = -1
sentence_restart_flag = False
actual_restart_new_sentence = restart_new_sentence and tokens.count(seperator_token) >= 2
for token in tokens:
if token == padding_token: # Pad tokens get word_id 0
word_ids.append(0)
# current_word_id = 0 # Resetting current_word_id for padding might be complex if padding is not last
elif actual_restart_new_sentence and not sentence_restart_flag and token == seperator_token:
if current_word_id == -1: # First token is SEP
current_word_id = 0
word_ids.append(current_word_id)
else: # SEP after some content
current_word_id += 1
word_ids.append(current_word_id)
current_word_id = -1 # Reset for the new sentence (will become 0 at first non-subword)
sentence_restart_flag = True
elif not is_subword(token):
current_word_id += 1
word_ids.append(current_word_id)
elif current_word_id == -1: # First token of a sequence (or after reset SEP) is a subword
current_word_id = 0
word_ids.append(current_word_id)
else: # Subword of an existing word
word_ids.append(current_word_id)
return word_ids
def col_round(x: float) -> int:
"""Colloquial rounding where 0.5 rounds to 1"""
frac = x - math.floor(x)
if frac < 0.5:
return math.floor(x)
return math.ceil(x)
def get_uniform_id(si: int, max_intermediate_subwords: int, num_intermediate_subwords: int) -> int:
"""Calculates uniform id for the given subword index, si, and max and number of intermediate subwords"""
if num_intermediate_subwords == 0: # Avoid division by zero if there are no intermediate subwords
return 0
# Effective max position is max_intermediate_subwords - 1 because positions are 0-indexed
# e.g., if max_intermediate_subwords is 1, effective_max_pos is 0.
# if max_intermediate_subwords is 2, effective_max_pos is 1 (positions 0, 1).
effective_max_pos = max(0, max_intermediate_subwords - 1) # Ensure non-negative
return col_round(si * effective_max_pos / num_intermediate_subwords)
def get_ids_from_subwords(
num_subwords_in_current_word: int,
max_intermediate_subword_positions_per_word: int,
subword_embedding_order: str,
intermediate_subword_distribution_strategy: str,
current_word_starts_with_subword: bool = False,
) -> List[int]:
"""Calculate subword ids for the tokens of a single word."""
if num_subwords_in_current_word == 0:
return []
# Handle cases where the "word" is just one token
if current_word_starts_with_subword: # Word like "##ing"
if num_subwords_in_current_word == 1:
return [1] # Treat as "last" subword if it's the only token and starts with ##
elif num_subwords_in_current_word == 1: # Word like "run"
return [0] # Treat as "root" subword
# For multi-token words
if subword_embedding_order == "ending_first":
subword_ids: List[int] = []
has_explicit_root = not current_word_starts_with_subword and num_subwords_in_current_word > 0
# "Last" subword exists if there's more than one token,
# OR if it's a single token starting with ## (handled above, but for clarity here)
has_explicit_last = num_subwords_in_current_word > 1 or (
current_word_starts_with_subword and num_subwords_in_current_word == 1
)
if has_explicit_root:
subword_ids.append(0) # Root token (e.g., "run" in "running")
# Tokens remaining for intermediate and last part
num_tokens_for_intermediate_and_last = num_subwords_in_current_word - 1
else: # Word starts with subword (e.g., "##run" in "##running")
# All tokens contribute to intermediate and last part (or just last if only one ##token)
num_tokens_for_intermediate_and_last = num_subwords_in_current_word
num_intermediate_tokens = 0
if has_explicit_last:
# If there's a distinct "last" token, subtract it from the count
num_intermediate_tokens = num_tokens_for_intermediate_and_last - 1
else:
# If no distinct "last" token (e.g. "##word" - only one token, no root),
# then all (remaining) tokens are considered intermediate.
# This case should be rare if num_subwords_in_current_word > 1
num_intermediate_tokens = num_tokens_for_intermediate_and_last
# Ensure non-negative, can happen if num_subwords_in_current_word is 1 and has_explicit_last is true.
if num_intermediate_tokens < 0:
num_intermediate_tokens = 0
# Assign IDs to intermediate tokens
if num_intermediate_tokens > 0:
if num_intermediate_tokens <= max_intermediate_subword_positions_per_word:
# If fewer or equal intermediate tokens than available slots, assign unique IDs
for si in range(num_intermediate_tokens):
subword_ids.append(2 + si) # IDs 2, 3, ..., (2+max_intermediate_subword_positions_per_word-1)
else: # More intermediate tokens than available slots
if intermediate_subword_distribution_strategy == "uniform":
for si in range(num_intermediate_tokens):
subword_ids.append(
2
+ get_uniform_id(si, max_intermediate_subword_positions_per_word, num_intermediate_tokens)
)
elif intermediate_subword_distribution_strategy == "leftover_as_last":
# Fill available intermediate slots
for si in range(max_intermediate_subword_positions_per_word):
subword_ids.append(2 + si)
# Assign remaining intermediate tokens as "last" (ID 1)
for _ in range(num_intermediate_tokens - max_intermediate_subword_positions_per_word):
subword_ids.append(1)
else:
raise ValueError(
f"Unsupported intermediate subword distribution strategy: {intermediate_subword_distribution_strategy}"
)
if has_explicit_last:
subword_ids.append(1) # Last token (e.g., "##ing" in "running")
return subword_ids
else:
raise ValueError(f"Unsupported subword embedding order: {subword_embedding_order}")
def create_subword_ids(
tokens: List[str],
max_intermediate_subword_positions_per_word: int,
subword_embedding_order: str,
intermediate_subword_distribution_strategy: str,
cls_token="[CLS]",
sep_token="[SEP]",
pad_token="[PAD]",
) -> List[int]:
"""Creates subword ids for the given tokens and parameters."""
if not tokens:
return []
all_subword_ids: List[int] = []
current_word_segment_tokens: List[str] = []
# Determine if the very first content token (non-special) is a subword.
# This helps decide if the first word itself starts with a subword prefix.
first_content_token_is_subword = False
if tokens: # Check if tokens list is not empty
for token_val in tokens:
if token_val not in [cls_token, sep_token, pad_token]:
first_content_token_is_subword = is_subword(token_val)
break # Found the first content token
first_content_word_processed = False # Flag to track if we've processed the first actual word
for token_idx, token in enumerate(tokens):
if token in [cls_token, sep_token, pad_token]: # Special tokens
# If there was an ongoing word segment, process it first
if current_word_segment_tokens:
# Determine if this segment is the very first content word AND it starts with a subword
is_this_segment_the_very_first_content_word_and_starts_with_subword = (
first_content_token_is_subword
and not first_content_word_processed
and is_subword(current_word_segment_tokens[0])
)
generated_ids = get_ids_from_subwords(
num_subwords_in_current_word=len(current_word_segment_tokens),
max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
subword_embedding_order=subword_embedding_order,
intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
)
all_subword_ids.extend(generated_ids)
if (
not first_content_word_processed and current_word_segment_tokens
): # Mark first content word as processed
first_content_word_processed = True
current_word_segment_tokens = [] # Reset for next word
all_subword_ids.append(0) # Special tokens get subword_id 0
elif not is_subword(token): # Token is a root word (doesn't start with ##)
# If there was an ongoing word segment (which must have been all subwords), process it
if current_word_segment_tokens:
is_this_segment_the_very_first_content_word_and_starts_with_subword = (
first_content_token_is_subword
and not first_content_word_processed
and is_subword(current_word_segment_tokens[0])
)
generated_ids = get_ids_from_subwords(
num_subwords_in_current_word=len(current_word_segment_tokens),
max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
subword_embedding_order=subword_embedding_order,
intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
)
all_subword_ids.extend(generated_ids)
if not first_content_word_processed and current_word_segment_tokens:
first_content_word_processed = True
current_word_segment_tokens = [token] # Start a new word segment with this root token
else: # Token is a subword (starts with ##)
current_word_segment_tokens.append(token)
# After loop, process any remaining word segment
if current_word_segment_tokens:
is_this_segment_the_very_first_content_word_and_starts_with_subword = (
first_content_token_is_subword
and not first_content_word_processed
and is_subword(current_word_segment_tokens[0])
)
generated_ids = get_ids_from_subwords(
num_subwords_in_current_word=len(current_word_segment_tokens),
max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word,
subword_embedding_order=subword_embedding_order,
intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy,
current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword,
)
all_subword_ids.extend(generated_ids)
return all_subword_ids
def padding_strategy_uses_max_length(
padding_strategy: Union[bool, str, PaddingStrategy], max_length: Optional[int]
) -> bool:
"""Helper to determine if padding will occur up to a max_length."""
if padding_strategy is False or padding_strategy == PaddingStrategy.DO_NOT_PAD:
return False
if padding_strategy is True or padding_strategy == PaddingStrategy.LONGEST:
# Padding to longest in batch still implies a fixed length for that batch
return True
if padding_strategy == PaddingStrategy.MAX_LENGTH:
return max_length is not None
return False
__all__ = ["Bert2DTokenizerFast"]