# Remote code: configuration and modeling for NSA

```python
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


class NSAByteTokenizer(PreTrainedTokenizer):
    """A simple byte-level tokenizer with a fixed vocab size of 256.

    - Encodes the UTF-8 bytes of the input string as token ids 0..255.
    - No special tokens by default; EOS/PAD can be configured via the
      special tokens map.
    - Decoding uses UTF-8 with replacement for invalid sequences.
    """

    def __init__(self, **kwargs):
        # Build a stable 256-entry vocab mapping before base init
        # (the base __init__ may query the vocab).
        self._vocab: Dict[str, int] = {f"<{i}>": i for i in range(256)}
        self._ids_to_tokens: Dict[int, str] = {i: f"<{i}>" for i in range(256)}
        super().__init__(**kwargs)
        # Only return input_ids and attention_mask to avoid unused
        # token_type_ids during generation.
        self.model_input_names = ["input_ids", "attention_mask"]

    @property
    def vocab_size(self) -> int:  # type: ignore[override]
        return 256

    def get_vocab(self) -> Dict[str, int]:  # type: ignore[override]
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
        data = text.encode("utf-8", errors="replace")
        return [f"<{b}>" for b in data]

    def _convert_token_to_id(self, token: str) -> int:  # type: ignore[override]
        if token in self._vocab:
            return self._vocab[token]
        # Fallback: try to parse the numeric id inside <...>.
        if token.startswith("<") and token.endswith(">"):
            try:
                v = int(token[1:-1])
            except ValueError:
                pass
            else:
                if 0 <= v < 256:
                    return v
        return 0

    def _convert_id_to_token(self, index: int) -> str:  # type: ignore[override]
        # Wrap out-of-range ids into the byte range rather than failing.
        return self._ids_to_tokens.get(int(index) % 256, "<0>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
        bs: List[int] = []
        for t in tokens:
            if t in self._vocab:
                bs.append(self._vocab[t])
                continue
            # Fallback for tokens outside the vocab dict; skip anything unparsable.
            if t.startswith("<") and t.endswith(">"):
                try:
                    v = int(t[1:-1])
                except ValueError:
                    continue
                if 0 <= v < 256:
                    bs.append(v)
        return bytes(bs).decode("utf-8", errors="replace")

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:  # type: ignore[override]
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str, ...]:  # type: ignore[override]
        # Nothing to save besides the special tokens map, which the base
        # class handles; return an empty tuple of saved file paths.
        return ()
```
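
Since the tokenizer is byte-for-byte reversible, a round trip is the easiest sanity check. The snippet below is a minimal sketch that instantiates the class directly for a local test; loading it through `AutoTokenizer` as remote code would additionally need an `auto_map` entry in the repo's `tokenizer_config.json`, which is assumed rather than shown here.

```python
# Minimal round-trip sketch; instantiating directly for a local test.
# (Loading via AutoTokenizer.from_pretrained(..., trust_remote_code=True)
# would also require auto_map wiring in tokenizer_config.json; that setup
# is an assumption and not shown here.)
tok = NSAByteTokenizer()

sample = "héllo, NSA"
ids = tok.encode(sample)                         # UTF-8 bytes -> ids in 0..255
assert all(0 <= i < 256 for i in ids)
assert len(ids) == len(sample.encode("utf-8"))   # exactly one id per byte

text = tok.decode(ids)                           # ids -> bytes -> str
assert text == sample
```

Because every byte maps to exactly one fixed id, there is no learned vocabulary to ship with the model; this is why `save_vocabulary` returns an empty tuple and only the special tokens map is persisted by the base class.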