# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for Bert2D.""" import math from typing import List, Optional, Union # import os # Only needed for __main__ # import shutil # Only needed for __main__ from transformers.tokenization_utils_base import ( BatchEncoding, PaddingStrategy, PreTokenizedInput, TensorType, TextInput, TruncationStrategy, ) from transformers.utils import logging from transformers import BertTokenizer logger = logging.get_logger(__name__) # Set logger level to DEBUG for more verbose output if needed during development/testing # logger.setLevel(logging.DEBUG) # Comment out for production # Helper functions ported from Bert2DTokenizerFast def is_subword(token: str, subword_prefix="##") -> bool: """Returns if a token is a subword""" return token.startswith(subword_prefix) def create_word_ids( tokens: List[str], restart_new_sentence=False, seperator_token="[SEP]", padding_token="[PAD]" ) -> List[int]: """Creates word ids for given tokens, matching the logic from Bert2DTokenizerFast tests.""" word_ids: List[int] = [] current_word_id: int = -1 sentence_restart_flag = False actual_restart_new_sentence = restart_new_sentence and tokens.count(seperator_token) >= 2 for token in tokens: if token == padding_token: # Pad tokens get word_id 0 word_ids.append(0) # current_word_id = 0 # Resetting current_word_id for padding might be complex if padding is not last elif actual_restart_new_sentence and not sentence_restart_flag and token == seperator_token: if current_word_id == -1: # First token is SEP current_word_id = 0 word_ids.append(current_word_id) else: # SEP after some content current_word_id += 1 word_ids.append(current_word_id) current_word_id = -1 # Reset for the new sentence (will become 0 at first non-subword) sentence_restart_flag = True elif not is_subword(token): current_word_id += 1 word_ids.append(current_word_id) elif current_word_id == -1: # First token of a sequence (or after reset SEP) is a subword current_word_id = 0 word_ids.append(current_word_id) else: # Subword of an existing word word_ids.append(current_word_id) return word_ids def col_round(x: float) -> int: """Colloquial rounding where 0.5 rounds to 1""" frac = x - math.floor(x) if frac < 0.5: return math.floor(x) return math.ceil(x) def get_uniform_id(si: int, max_intermediate_subwords: int, num_intermediate_subwords: int) -> int: """Calculates uniform id for the given subword index, si, and max and number of intermediate subwords""" if num_intermediate_subwords == 0: return 0 effective_max_pos = max(0, max_intermediate_subwords - 1) return col_round(si * effective_max_pos / num_intermediate_subwords) def get_ids_from_subwords( num_subwords_in_current_word: int, max_intermediate_subword_positions_per_word: int, subword_embedding_order: str, intermediate_subword_distribution_strategy: str, current_word_starts_with_subword: bool = False, ) -> List[int]: """Calculate subword ids for the tokens of a single word.""" if num_subwords_in_current_word == 0: return [] if current_word_starts_with_subword: if num_subwords_in_current_word == 1: return [1] elif num_subwords_in_current_word == 1: return [0] if subword_embedding_order == "ending_first": subword_ids: List[int] = [] has_explicit_root = not current_word_starts_with_subword and num_subwords_in_current_word > 0 has_explicit_last = num_subwords_in_current_word > 1 or ( current_word_starts_with_subword and num_subwords_in_current_word == 1 ) if has_explicit_root: subword_ids.append(0) num_tokens_for_intermediate_and_last = num_subwords_in_current_word - 1 else: num_tokens_for_intermediate_and_last = num_subwords_in_current_word if has_explicit_last: num_intermediate_tokens = num_tokens_for_intermediate_and_last - 1 else: num_intermediate_tokens = num_tokens_for_intermediate_and_last if num_intermediate_tokens < 0: num_intermediate_tokens = 0 if num_intermediate_tokens > 0: if num_intermediate_tokens <= max_intermediate_subword_positions_per_word: for si in range(num_intermediate_tokens): subword_ids.append(2 + si) else: if intermediate_subword_distribution_strategy == "uniform": for si in range(num_intermediate_tokens): subword_ids.append( 2 + get_uniform_id(si, max_intermediate_subword_positions_per_word, num_intermediate_tokens) ) elif intermediate_subword_distribution_strategy == "leftover_as_last": for si in range(max_intermediate_subword_positions_per_word): subword_ids.append(2 + si) for _ in range(num_intermediate_tokens - max_intermediate_subword_positions_per_word): subword_ids.append(1) else: raise ValueError( f"Unsupported intermediate subword distribution strategy: {intermediate_subword_distribution_strategy}" ) if has_explicit_last: subword_ids.append(1) return subword_ids else: raise ValueError(f"Unsupported subword embedding order: {subword_embedding_order}") def create_subword_ids( tokens: List[str], max_intermediate_subword_positions_per_word: int, subword_embedding_order: str, intermediate_subword_distribution_strategy: str, cls_token="[CLS]", sep_token="[SEP]", pad_token="[PAD]", ) -> List[int]: """Creates subword ids for the given tokens and parameters.""" if not tokens: return [] all_subword_ids: List[int] = [] current_word_segment_tokens: List[str] = [] first_content_token_is_subword = False if tokens: for token_val in tokens: if token_val not in [cls_token, sep_token, pad_token]: first_content_token_is_subword = is_subword(token_val) break first_content_word_processed = False for token_idx, token in enumerate(tokens): if token in [cls_token, sep_token, pad_token]: if current_word_segment_tokens: is_this_segment_the_very_first_content_word_and_starts_with_subword = ( first_content_token_is_subword and not first_content_word_processed and is_subword(current_word_segment_tokens[0]) ) generated_ids = get_ids_from_subwords( num_subwords_in_current_word=len(current_word_segment_tokens), max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word, subword_embedding_order=subword_embedding_order, intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy, current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword, ) all_subword_ids.extend(generated_ids) if not first_content_word_processed and current_word_segment_tokens: first_content_word_processed = True current_word_segment_tokens = [] all_subword_ids.append(0) elif not is_subword(token): if current_word_segment_tokens: is_this_segment_the_very_first_content_word_and_starts_with_subword = ( first_content_token_is_subword and not first_content_word_processed and is_subword(current_word_segment_tokens[0]) ) generated_ids = get_ids_from_subwords( num_subwords_in_current_word=len(current_word_segment_tokens), max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word, subword_embedding_order=subword_embedding_order, intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy, current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword, ) all_subword_ids.extend(generated_ids) if not first_content_word_processed and current_word_segment_tokens: first_content_word_processed = True current_word_segment_tokens = [token] else: current_word_segment_tokens.append(token) if current_word_segment_tokens: is_this_segment_the_very_first_content_word_and_starts_with_subword = ( first_content_token_is_subword and not first_content_word_processed and is_subword(current_word_segment_tokens[0]) ) generated_ids = get_ids_from_subwords( num_subwords_in_current_word=len(current_word_segment_tokens), max_intermediate_subword_positions_per_word=max_intermediate_subword_positions_per_word, subword_embedding_order=subword_embedding_order, intermediate_subword_distribution_strategy=intermediate_subword_distribution_strategy, current_word_starts_with_subword=is_this_segment_the_very_first_content_word_and_starts_with_subword, ) all_subword_ids.extend(generated_ids) return all_subword_ids class Bert2DTokenizer(BertTokenizer): r""" Construct a BERT2D tokenizer. Based on WordPiece. This tokenizer inherits from [`BertTokenizer`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. Bert2DTokenizer adds functionality for generating `word_ids` and `subword_ids` which are used for 2D positional embeddings. Args: vocab_file (`str`): File containing the vocabulary. do_lower_case (`bool`, *optional*, defaults to `True`): Whether or not to lowercase the input when tokenizing. do_basic_tokenize (`bool`, *optional*, defaults to `True`): Whether or not to do basic tokenization before WordPiece. never_split (`Iterable`, *optional*): Collection of tokens which will never be split during tokenization. Only has an effect when `do_basic_tokenize=True`. unk_token (`str`, *optional*, defaults to `"[UNK]"`): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. sep_token (`str`, *optional*, defaults to `"[SEP]"`): The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens. pad_token (`str`, *optional*, defaults to `"[PAD]"`): The token used for padding, for example when batching sequences of different lengths. cls_token (`str`, *optional*, defaults to `"[CLS]"`): The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens. mask_token (`str`, *optional*, defaults to `"[MASK]"`): The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict. tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this [issue](https://github.com/huggingface/transformers/issues/328)). strip_accents (`bool`, *optional*): Whether or not to strip all accents. If this option is not specified, then it will be determined by the value for `lowercase` (as in the original BERT). max_intermediate_subword_positions_per_word (`int`, *optional*, defaults to `1`): The maximum number of intermediate subword positions per word. This is used to determine how many subword positions are allowed for each word in the tokenization process. subword_embedding_order (`str`, *optional*, defaults to `"ending_first"`): The order in which subword embeddings are processed. Can be `"ending_first"`. intermediate_subword_distribution_strategy (`str`, *optional*, defaults to `"uniform"`): The strategy for distributing intermediate subword positions. Can be `"uniform"` or `"leftover_as_last"`. """ model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask", "word_ids", "subword_ids"] def __init__( self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", mask_token="[MASK]", tokenize_chinese_chars=True, strip_accents=None, max_intermediate_subword_positions_per_word=1, subword_embedding_order="ending_first", intermediate_subword_distribution_strategy="uniform", **kwargs, ): # Step 1: Store Bert2D-specific args from init signature into local variables local_max_intermediate = max_intermediate_subword_positions_per_word local_subword_order = subword_embedding_order local_intermediate_strategy = intermediate_subword_distribution_strategy # Step 2: Remove Bert2D-specific args from kwargs to prevent passing them to super() kwargs.pop("max_intermediate_subword_positions_per_word", None) kwargs.pop("subword_embedding_order", None) kwargs.pop("intermediate_subword_distribution_strategy", None) # Step 3: Call super().__init__(), explicitly passing Bert2DTokenizer.model_input_names # This ensures that PreTrainedTokenizerBase uses our desired model_input_names. super().__init__( vocab_file=vocab_file, do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, never_split=never_split, unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, **kwargs, ) # Step 4: Set Bert2D specific attributes on self using the stored local variables self.max_intermediate_subword_positions_per_word = local_max_intermediate self.subword_embedding_order = local_subword_order self.intermediate_subword_distribution_strategy = local_intermediate_strategy if subword_embedding_order != "ending_first": logger.warning( f"Bert2DTokenizer slow currently only fully supports 'ending_first' for subword_embedding_order. Received: {subword_embedding_order}" ) # Step 5: Update init_kwargs if it exists (for serialization/reconstruction by base classes) # This makes sure that when the tokenizer is saved and reloaded, these custom parameters are preserved. if hasattr(self, "init_kwargs") and isinstance(self.init_kwargs, dict): self.init_kwargs["max_intermediate_subword_positions_per_word"] = ( self.max_intermediate_subword_positions_per_word ) self.init_kwargs["subword_embedding_order"] = self.subword_embedding_order self.init_kwargs["intermediate_subword_distribution_strategy"] = ( self.intermediate_subword_distribution_strategy ) # else: # This case might occur if the superclass doesn't initialize init_kwargs, # which would be unusual for PreTrainedTokenizer based classes. # logger.warning("self.init_kwargs not found or not a dict during Bert2DTokenizer __init__. Custom params might not be saved.") def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None, text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None, text_pair_target: Optional[ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] ] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, padding_side: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ) -> BatchEncoding: batch_encoding_super = super().__call__( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=None, # Get lists first to allow modification return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) input_ids_processed = batch_encoding_super["input_ids"] is_batched = bool( (isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))) or ( isinstance(text, (list, tuple)) and text_pair is not None and isinstance(text_pair, (list, tuple)) and text and isinstance(text[0], str) ) ) if not is_batched and isinstance(text, (list, tuple)) and not (text and isinstance(text[0], (list, tuple))): if ( isinstance(input_ids_processed, list) and bool(input_ids_processed) and isinstance(input_ids_processed[0], list) ): is_batched = True elif not is_batched and text_pair is not None and isinstance(text_pair, (list, tuple)): if ( isinstance(input_ids_processed, list) and bool(input_ids_processed) and isinstance(input_ids_processed[0], list) ): is_batched = True list_of_input_ids_for_processing: List[List[int]] if not is_batched: list_of_input_ids_for_processing = [input_ids_processed] else: list_of_input_ids_for_processing = input_ids_processed all_word_ids: List[List[int]] = [] all_subword_ids: List[List[int]] = [] for ids_for_one_sequence in list_of_input_ids_for_processing: tokens = self.convert_ids_to_tokens(ids_for_one_sequence, skip_special_tokens=False) should_restart_word_ids_heuristic = text_pair is not None word_ids_for_sequence = create_word_ids( tokens, restart_new_sentence=should_restart_word_ids_heuristic, seperator_token=self.sep_token, padding_token=self.pad_token, ) subword_ids_for_sequence = create_subword_ids( tokens, max_intermediate_subword_positions_per_word=self.max_intermediate_subword_positions_per_word, subword_embedding_order=self.subword_embedding_order, intermediate_subword_distribution_strategy=self.intermediate_subword_distribution_strategy, cls_token=self.cls_token, sep_token=self.sep_token, pad_token=self.pad_token, ) all_word_ids.append(word_ids_for_sequence) all_subword_ids.append(subword_ids_for_sequence) if not is_batched: batch_encoding_super["word_ids"] = all_word_ids[0] batch_encoding_super["subword_ids"] = all_subword_ids[0] else: batch_encoding_super["word_ids"] = all_word_ids batch_encoding_super["subword_ids"] = all_subword_ids # Custom tensor conversion to ensure proper dimensions for pipeline compatibility if return_tensors is not None: if return_tensors == "pt": import torch # Convert to tensors first using standard method batch_encoding_super = batch_encoding_super.convert_to_tensors(tensor_type=return_tensors) # If this is a single input (not batched), ensure all tensors have batch dimension if not is_batched: for key, value in batch_encoding_super.items(): if isinstance(value, torch.Tensor) and value.ndim == 1: # Add batch dimension to make it 2D [1, seq_len] instead of [seq_len] batch_encoding_super[key] = value.unsqueeze(0) else: # For other tensor types (tf, np, etc.), use standard conversion batch_encoding_super = batch_encoding_super.convert_to_tensors(tensor_type=return_tensors) return batch_encoding_super def encode_plus( self, text: Union[TextInput, PreTokenizedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ) -> BatchEncoding: """ Tokenize and prepare for the model a sequence or a pair of sequences. This method includes the generation of word_ids and subword_ids specific to Bert2D. """ # Call parent's encode_plus first to get standard tokenization result = super().encode_plus( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=None, # Process as lists first, then convert to tensor if needed return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) # Check if we have overflow tokens (multiple sequences in result) has_overflow = return_overflowing_tokens and "overflowing_tokens" in result # Determine if result is batched (could be batched if overflow tokens are present) is_batched_output = ( isinstance(result["input_ids"], list) and result["input_ids"] and isinstance(result["input_ids"][0], list) ) # If we have overflow tokens OR the result is already batched if has_overflow or is_batched_output: # We'll need to process each sequence separately batch_size = len(result["input_ids"]) if is_batched_output else 1 batch_word_ids = [] batch_subword_ids = [] for i in range(batch_size): # Get tokens for this sequence tokens = self.convert_ids_to_tokens( result["input_ids"][i] if is_batched_output else result["input_ids"], skip_special_tokens=False ) # Determine if this sequence contains multiple sentences by counting SEP tokens should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2 word_ids = create_word_ids( tokens, restart_new_sentence=should_restart_word_ids_heuristic, seperator_token=self.sep_token, padding_token=self.pad_token, ) subword_ids = create_subword_ids( tokens, max_intermediate_subword_positions_per_word=self.max_intermediate_subword_positions_per_word, subword_embedding_order=self.subword_embedding_order, intermediate_subword_distribution_strategy=self.intermediate_subword_distribution_strategy, cls_token=self.cls_token, sep_token=self.sep_token, pad_token=self.pad_token, ) batch_word_ids.append(word_ids) batch_subword_ids.append(subword_ids) # Add to result if is_batched_output: result["word_ids"] = batch_word_ids result["subword_ids"] = batch_subword_ids else: # If input was single but we have overflow tokens, result should still be a list result["word_ids"] = batch_word_ids[0] result["subword_ids"] = batch_subword_ids[0] else: # Standard case - no overflow, single sequence tokens = self.convert_ids_to_tokens(result["input_ids"], skip_special_tokens=False) # Determine if this sequence contains multiple sentences by counting SEP tokens should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2 word_ids = create_word_ids( tokens, restart_new_sentence=should_restart_word_ids_heuristic, seperator_token=self.sep_token, padding_token=self.pad_token, ) subword_ids = create_subword_ids( tokens, max_intermediate_subword_positions_per_word=self.max_intermediate_subword_positions_per_word, subword_embedding_order=self.subword_embedding_order, intermediate_subword_distribution_strategy=self.intermediate_subword_distribution_strategy, cls_token=self.cls_token, sep_token=self.sep_token, pad_token=self.pad_token, ) # Add custom fields to result result["word_ids"] = word_ids result["subword_ids"] = subword_ids # Custom tensor conversion to ensure proper dimensions for pipeline compatibility if return_tensors is not None: if return_tensors == "pt": import torch # Convert to tensors first using standard method result = result.convert_to_tensors(tensor_type=return_tensors) # For single inputs (not batched or with overflow), ensure all tensors have batch dimension if not is_batched_output and not has_overflow: for key, value in result.items(): if isinstance(value, torch.Tensor) and value.ndim == 1: # Add batch dimension to make it 2D [1, seq_len] instead of [seq_len] result[key] = value.unsqueeze(0) else: # For other tensor types, use standard conversion result = result.convert_to_tensors(tensor_type=return_tensors) return result def batch_encode_plus( self, batch_text_or_text_pairs: Union[ List[TextInput], List[PreTokenizedInput], List[Union[TextInput, PreTokenizedInput]], ], add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ) -> BatchEncoding: """ Tokenize and prepare a batch of sequences or a batch of sequence pairs for the model. This method includes the generation of word_ids and subword_ids specific to Bert2D. """ # Call the parent's batch_encode_plus first to get standard tokenization result = super().batch_encode_plus( batch_text_or_text_pairs, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=None, # Process as lists first, then convert to tensor if needed return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, **kwargs, ) # Generate word_ids and subword_ids for each item in the batch # Use the actual batch size from result["input_ids"], which includes overflow sequences batch_size = len(result["input_ids"]) batch_word_ids = [] batch_subword_ids = [] for i in range(batch_size): # Get tokens for this batch item tokens = self.convert_ids_to_tokens(result["input_ids"][i], skip_special_tokens=False) # Determine if this sequence contains multiple sentences by counting SEP tokens should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2 word_ids = create_word_ids( tokens, restart_new_sentence=should_restart_word_ids_heuristic, seperator_token=self.sep_token, padding_token=self.pad_token, ) subword_ids = create_subword_ids( tokens, max_intermediate_subword_positions_per_word=self.max_intermediate_subword_positions_per_word, subword_embedding_order=self.subword_embedding_order, intermediate_subword_distribution_strategy=self.intermediate_subword_distribution_strategy, cls_token=self.cls_token, sep_token=self.sep_token, pad_token=self.pad_token, ) batch_word_ids.append(word_ids) batch_subword_ids.append(subword_ids) # Add custom fields to result result["word_ids"] = batch_word_ids result["subword_ids"] = batch_subword_ids # Convert to tensors if requested - for batched inputs we don't need special handling # as they're already in the right format (list of lists) if return_tensors is not None: result = result.convert_to_tensors(tensor_type=return_tensors) return result def prepare_for_model( self, ids: List[int], pair_ids: Optional[List[int]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy, None] = None, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, prepend_batch_axis: bool = False, **kwargs, ) -> BatchEncoding: """ Prepares a sequence of input id, or a pair of sequences of inputs ids, for the model. This method adds `word_ids` and `subword_ids` specific to Bert2D. """ # Get the standard outputs from the parent class prepared_inputs = super().prepare_for_model( ids=ids, pair_ids=pair_ids, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_tensors=None, # Process as lists first return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, prepend_batch_axis=prepend_batch_axis, **kwargs, ) # Convert input_ids to tokens to generate word_ids and subword_ids tokens = self.convert_ids_to_tokens(prepared_inputs["input_ids"]) # Heuristic to check if we have a sentence pair should_restart_word_ids_heuristic = tokens.count(self.sep_token) >= 2 # Create and add word_ids prepared_inputs["word_ids"] = create_word_ids( tokens, restart_new_sentence=should_restart_word_ids_heuristic, seperator_token=self.sep_token, padding_token=self.pad_token, ) # Create and add subword_ids prepared_inputs["subword_ids"] = create_subword_ids( tokens, self.max_intermediate_subword_positions_per_word, self.subword_embedding_order, self.intermediate_subword_distribution_strategy, cls_token=self.cls_token, sep_token=self.sep_token, pad_token=self.pad_token, ) # Custom tensor conversion to ensure proper dimensions for pipeline compatibility if return_tensors is not None: if return_tensors == "pt": import torch # Convert to tensors first using standard method prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors) # If prepend_batch_axis is False, we need to ensure all tensors have batch dimension if not prepend_batch_axis: for key, value in prepared_inputs.items(): if isinstance(value, torch.Tensor) and value.ndim == 1: # Add batch dimension to make it 2D [1, seq_len] instead of [seq_len] prepared_inputs[key] = value.unsqueeze(0) else: # For other tensor types, use standard conversion prepared_inputs = prepared_inputs.convert_to_tensors(tensor_type=return_tensors) return prepared_inputs def apply_chat_template( self, conversation, chat_template=None, tools=None, documents=None, add_generation_prompt=False, tokenize=True, padding=False, truncation=None, max_length=None, return_tensors=None, return_dict=False, return_assistant_tokens_mask=False, tokenizer_kwargs=None, **kwargs, ): """ Override apply_chat_template to fix tensor dimension issues when return_tensors="pt" is used with single conversations and return_assistant_tokens_mask=True. """ # Check if we need to apply the fix needs_tensor_fix = ( return_tensors == "pt" and return_assistant_tokens_mask and return_dict and tokenize and not isinstance(conversation[0], list) if conversation else False # Single conversation, not batched ) if needs_tensor_fix: # For single conversations with tensor output, temporarily disable tensor conversion # and handle it manually after the call result = super().apply_chat_template( conversation=conversation, chat_template=chat_template, tools=tools, documents=documents, add_generation_prompt=add_generation_prompt, tokenize=tokenize, padding=padding, truncation=truncation, max_length=max_length, return_tensors=None, # Disable tensor conversion temporarily return_dict=return_dict, return_assistant_tokens_mask=return_assistant_tokens_mask, tokenizer_kwargs=tokenizer_kwargs, **kwargs, ) # Now manually convert to tensors ensuring proper dimensions if return_tensors == "pt": import torch # Convert each field to tensors with proper dimensions for key, value in result.items(): if isinstance(value, list): # Convert list to tensor and ensure it has a batch dimension tensor = torch.tensor(value) # Ensure we have at least 1D (for sequences) and add batch dimension if needed if tensor.dim() == 0: tensor = tensor.unsqueeze(0) if tensor.dim() == 1: tensor = tensor.unsqueeze(0) # Add batch dimension result[key] = tensor return result else: # For all other cases, use the parent implementation return super().apply_chat_template( conversation=conversation, chat_template=chat_template, tools=tools, documents=documents, add_generation_prompt=add_generation_prompt, tokenize=tokenize, padding=padding, truncation=truncation, max_length=max_length, return_tensors=return_tensors, return_dict=return_dict, return_assistant_tokens_mask=return_assistant_tokens_mask, tokenizer_kwargs=tokenizer_kwargs, **kwargs, ) __all__ = [ "Bert2DTokenizer", ]