# max-vit-goliath / vocab.py
from __future__ import annotations
import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, Union, Tuple, Optional, Callable, Any, List
import warnings
from collections import defaultdict
import datasets
from datasets import load_dataset
# Optional dependencies for spatial indexing
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
try:
from sklearn.neighbors import NearestNeighbors
SKLEARN_AVAILABLE = True
except ImportError:
SKLEARN_AVAILABLE = False
class SpatialIndex:
"""Spatial indexing for fast similarity search."""
def __init__(self, vectors: np.ndarray, token_ids: List[int], method: str = "auto"):
self.token_ids = np.array(token_ids)
self.method = method
self._index = None
if method == "auto":
if FAISS_AVAILABLE and vectors.shape[0] > 1000:
method = "faiss"
elif SKLEARN_AVAILABLE:
method = "sklearn"
else:
method = "linear"
self._build_index(vectors, method)
def _build_index(self, vectors: np.ndarray, method: str):
if method == "faiss" and FAISS_AVAILABLE:
# Approximate L1 search: inner product over L2-normalized vectors (cosine similarity) stands in for the L1 metric
vectors_l2 = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8)
self._index = faiss.IndexFlatIP(vectors_l2.shape[1]) # Inner product for normalized vectors
self._index.add(vectors_l2.astype(np.float32))
self.method = "faiss"
elif method == "sklearn" and SKLEARN_AVAILABLE:
# Use manhattan distance for true L1
self._index = NearestNeighbors(
metric='manhattan',
algorithm='ball_tree',
n_jobs=-1
).fit(vectors)
self.method = "sklearn"
else:
# Fallback to linear search
self._vectors = vectors
self.method = "linear"
def search_radius(self, query_vector: np.ndarray, max_distance: float, max_results: int = 1000) -> Tuple[
List[int], List[float]]:
"""Find all points within max_distance using L1 metric."""
if self.method == "sklearn":
indices = self._index.radius_neighbors([query_vector], radius=max_distance)[1][0]
if len(indices) > max_results:
# Compute actual distances and take closest
distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1)
top_k = np.argsort(distances)[:max_results]
indices = indices[top_k]
distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1)
return self.token_ids[indices].tolist(), distances.tolist()
elif self.method == "faiss":
# Approximate search using cosine similarity
query_l2 = query_vector / (np.linalg.norm(query_vector) + 1e-8)
similarities, indices = self._index.search(query_l2.reshape(1, -1).astype(np.float32), max_results)
# Filter by converting similarity threshold to approximate distance
threshold_sim = 1.0 - max_distance # rough approximation
mask = (similarities[0] >= threshold_sim) & (indices[0] >= 0)  # drop FAISS padding entries (label -1)
return self.token_ids[indices[0][mask]].tolist(), (1.0 - similarities[0][mask]).tolist()
else: # linear
distances = np.sum(np.abs(self._vectors - query_vector), axis=1)
mask = distances <= max_distance
if np.sum(mask) > max_results:
indices = np.argsort(distances)[:max_results]
mask = np.zeros_like(distances, dtype=bool)
mask[indices] = True
return self.token_ids[mask].tolist(), distances[mask].tolist()
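# Minimal usage sketch (illustrative only; the random vectors and the forced
# "linear" method are assumptions made for demonstration, not part of the API):
def _demo_spatial_index(n: int = 200, d: int = 16) -> Tuple[List[int], List[float]]:
    vecs = np.random.rand(n, d).astype(np.float32)
    index = SpatialIndex(vecs, token_ids=list(range(n)), method="linear")
    # Every token id whose pooled vector lies within L1 distance 2.0 of the first vector
    return index.search_radius(vecs[0], max_distance=2.0, max_results=10)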
class GeometricVocab(ABC):
"""
Optimized geometric vocabulary with spatial indexing and caching.
"""
def __init__(self, dim: int):
self.dim = int(dim)
self._token_to_id: Dict[str, int] = {}
self._id_to_token: Dict[int, str] = {}
self._id_to_vec: Dict[int, np.ndarray] = {}
self._id_to_volume: Dict[int, float] = {}
self._id_to_provenance: Dict[int, dict] = {}
self._valid_token_ids: set[int] = set()
# Optimization caches
self._normalized_cache: Dict[int, np.ndarray] = {}
self._pooled_cache: Dict[int, np.ndarray] = {}
self._spatial_index: Optional[SpatialIndex] = None
self._index_dirty = False
# NEW: Character-level cache for Unicode composition
self._char_cache: Dict[str, np.ndarray] = {}
self._char_lookups_saved = 0 # Statistics
def _invalidate_caches(self):
"""Invalidate caches when vocabulary changes."""
self._normalized_cache.clear()
self._pooled_cache.clear()
self._spatial_index = None
self._index_dirty = True
# Keep char cache across vocabulary changes as characters are stable
def _ensure_spatial_index(self):
"""Build spatial index if needed."""
if self._spatial_index is None or self._index_dirty:
if len(self._valid_token_ids) < 10:
return # Too few tokens for indexing
pooled_vectors = []
token_ids = []
for tid in sorted(self._valid_token_ids):
pooled_vec = self._get_cached_pooled(tid)
if pooled_vec is not None:
pooled_vectors.append(pooled_vec)
token_ids.append(tid)
if pooled_vectors:
self._spatial_index = SpatialIndex(
np.array(pooled_vectors),
token_ids,
method="auto"
)
self._index_dirty = False
def _get_cached_pooled(self, token_id: int) -> Optional[np.ndarray]:
"""Get pooled vector with caching."""
if token_id in self._pooled_cache:
return self._pooled_cache[token_id]
if token_id in self._id_to_vec:
X = self._id_to_vec[token_id]
pooled = X.mean(axis=0)
self._pooled_cache[token_id] = pooled
return pooled
return None
def _get_cached_normalized(self, token_id: int) -> Optional[np.ndarray]:
"""Get L1-normalized pooled vector with caching."""
if token_id in self._normalized_cache:
return self._normalized_cache[token_id]
pooled = self._get_cached_pooled(token_id)
if pooled is not None:
normalized = pooled / (np.abs(pooled).sum() + 1e-8)
self._normalized_cache[token_id] = normalized
return normalized
return None
# --------------------------- abstract surface --------------------
@abstractmethod
def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
raise NotImplementedError
@abstractmethod
def get_score(self, token_or_id: Union[str, int]) -> float:
raise NotImplementedError
# --------------------------- basic queries (optimized) -----------------------
def decode(self, token_id: int, fallback: str = "<unk>") -> Optional[str]:
if token_id in self._id_to_token:
return self._id_to_token[token_id]
return fallback if fallback in self._token_to_id else None
def decode_with_provenance(self, token_id: int, fallback: str = "<unk>") -> Tuple[Optional[str], Optional[dict]]:
tok = self.decode(token_id, fallback=fallback)
prov = self._id_to_provenance.get(token_id)
return tok, prov
def provenance(self, token_or_id: Union[str, int]) -> Optional[dict]:
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
return self._id_to_provenance.get(tid)
def embedding(self, token_or_id: Union[str, int]) -> Optional[np.ndarray]:
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
return self._id_to_vec.get(tid)
def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]:
"""Optimized pooled method with character caching"""
# Fast path for single characters
if isinstance(token_or_id, str) and len(token_or_id) == 1:
if token_or_id in self._char_cache:
self._char_lookups_saved += 1
return self._char_cache[token_or_id].copy() # Return copy to prevent mutation
# Regular lookup
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
if tid is None:
return None
if method == "mean":
pooled = self._get_cached_pooled(tid)
# Cache single characters for future use
if pooled is not None and isinstance(token_or_id, str) and len(token_or_id) == 1:
self._char_cache[token_or_id] = pooled.copy()
return pooled
# Fallback for other methods
X = self._id_to_vec.get(tid)
if X is None:
return None
if method == "first":
return X[0]
if method == "sum":
return X.sum(axis=0)
raise ValueError(f"Invalid pooling method: {method}")
def pooled_batch(self, tokens: List[Union[str, int]], method: str = "mean") -> List[Optional[np.ndarray]]:
"""Batch pooling with character-level caching for efficiency"""
results = []
for token in tokens:
# Use optimized single pooled method which handles char caching
results.append(self.pooled(token, method))
return results
# --------------------------- optimized similarity ---------------------
def similarity(self, token_a: Union[str, int], token_b: Union[str, int]) -> float:
"""
Optimized L1-normalized directional similarity using cached vectors.
"""
tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a)
tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b)
if tid_a is None or tid_b is None:
return -1.0
a_norm = self._get_cached_normalized(tid_a)
b_norm = self._get_cached_normalized(tid_b)
if a_norm is None or b_norm is None:
return -1.0
return float(np.dot(a_norm, b_norm))
def similarity_magnitude(self, token_a: Union[str, int], token_b: Union[str, int]) -> float:
"""
Raw dot-product using cached pooled vectors.
"""
tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a)
tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b)
if tid_a is None or tid_b is None:
return -1.0
a = self._get_cached_pooled(tid_a)
b = self._get_cached_pooled(tid_b)
if a is None or b is None:
return -1.0
return float(np.dot(a, b))
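# Worked example (comments only, values illustrative): with pooled vectors a = [2, 2]
# and b = [1, 3], L1 normalization gives a_norm = [0.5, 0.5] and b_norm = [0.25, 0.75],
# so similarity(a, b) = 0.5*0.25 + 0.5*0.75 = 0.5, while similarity_magnitude(a, b)
# is the raw dot product 2*1 + 2*3 = 8.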
# --------------------------- optimized spatial search ---------------------
def extract_band(self, trajectory: np.ndarray, max_angle: float = 0.3, method: str = "pooled") -> Dict[
str, np.ndarray]:
"""
Optimized spatial search using indexing when available.
"""
if trajectory.ndim == 2:
direction = trajectory.mean(0)
else:
direction = trajectory
direction = direction / (np.abs(direction).sum() + 1e-8)
# Try spatial index first
self._ensure_spatial_index()
if self._spatial_index is not None:
try:
# Convert angle threshold to distance threshold (approximation)
max_distance = max_angle * 2.0 # rough conversion
token_ids, distances = self._spatial_index.search_radius(
direction, max_distance, max_results=1000
)
# Refine results with exact L1 similarity check
out: Dict[str, np.ndarray] = {}
for tid in token_ids:
tok = self._id_to_token.get(tid)
if tok is None:
continue
v_norm = self._get_cached_normalized(tid)
if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle:
out[tok] = self._id_to_vec[tid]
return out
except Exception as e:
warnings.warn(f"Spatial index search failed: {e}, falling back to linear")
# Fallback to linear search
out: Dict[str, np.ndarray] = {}
for tok, tid in self._token_to_id.items():
v_norm = self._get_cached_normalized(tid)
if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle:
out[tok] = self._id_to_vec[tid]
return out
def find_similar_tokens(self, token: Union[str, int], k: int = 10, min_similarity: float = 0.5) -> List[
Tuple[str, float]]:
"""
Find k most similar tokens using spatial indexing when available.
"""
tid = token if isinstance(token, int) else self._token_to_id.get(token)
if tid is None:
return []
query_vec = self._get_cached_normalized(tid)
if query_vec is None:
return []
self._ensure_spatial_index()
if self._spatial_index is not None:
try:
# Use spatial index for approximate search
max_distance = (1.0 - min_similarity) * 2.0
token_ids, _ = self._spatial_index.search_radius(
query_vec, max_distance, max_results=k * 3 # Get extra for refinement
)
# Compute exact similarities and sort
similarities = []
for tid_cand in token_ids:
if tid_cand == tid: # Skip self
continue
sim = self.similarity(tid, tid_cand)
if sim >= min_similarity:
tok = self._id_to_token.get(tid_cand)
if tok:
similarities.append((tok, sim))
return sorted(similarities, key=lambda x: x[1], reverse=True)[:k]
except Exception as e:
warnings.warn(f"Spatial similarity search failed: {e}, falling back to linear")
# Linear fallback
similarities = []
for tok_cand, tid_cand in self._token_to_id.items():
if tid_cand == tid:
continue
sim = self.similarity(tid, tid_cand)
if sim >= min_similarity:
similarities.append((tok_cand, sim))
return sorted(similarities, key=lambda x: x[1], reverse=True)[:k]
# --------------------------- helpers exposed to callbacks --------
def _helpers(self) -> Dict[str, Callable[..., np.ndarray]]:
def _emb(x):
e = self.embedding(x)
return None if e is None else np.asarray(e, np.float32)
def _poo(x):
p = self.pooled(x)
return None if p is None else np.asarray(p, np.float32)
def _chars(s):
# Use batch pooling for efficiency
return self.pooled_batch(list(s)) if isinstance(s, str) else None
return {"embedding": _emb, "pooled": _poo, "chars_pooled": _chars}
# --------------------------- DEFAULT create_crystal (unicode path) ----
def _default_create_crystal(self, config: dict, callback: Callable[..., np.ndarray]) -> np.ndarray:
"""
Deterministic default when user leaves callback/create_crystal=None.
"""
pool_type = config.get("pool_type") or "unicode"
H = config["helpers"]
token_plain = str(config["data"]["token"])
d = int(config["dim"])
c_uni = self._compose_unicode_center(token_plain, H, pool_type, d)
c_defs = self._compose_wordnet_center(config.get("additional_definitions", []), H, pool_type, d)
if pool_type == "combination":
parts = [v for v in (c_uni, c_defs) if v is not None]
c = np.mean(np.stack(parts, 0), 0) if parts else np.zeros(d, np.float32)
elif pool_type == "wordnet":
c = c_defs if c_defs is not None else np.zeros(d, np.float32)
else:
c = c_uni if c_uni is not None else np.zeros(d, np.float32)
# L1 normalization only
l1 = float(np.abs(c).sum()) + 1e-8
c = c / l1
return self._deterministic_pentachoron(c)
def _default_unicode_callback(self, name: str, **kwargs) -> np.ndarray:
raise NotImplementedError("Default callback is not invoked directly.")
# --------------------------- universal builders (overrideable) ---
def _compose_unicode_center(
self, token_plain: str, H, pool_type: Optional[str], dim: int
) -> Optional[np.ndarray]:
"""
Build a center vector from the token's Unicode characters - OPTIMIZED.
"""
# Use batch pooling for all characters at once
char_list = list(token_plain)
pooled_chars = self.pooled_batch(char_list)
vecs: List[np.ndarray] = []
for pooled_v in pooled_chars:
if pooled_v is None:
continue
v = np.asarray(pooled_v, np.float32)
if v.shape[0] != dim:
raise ValueError(f"Unicode pooled dim mismatch: got {v.shape[0]}, expected {dim}")
vecs.append(v)
if not vecs:
return None
stacked = np.stack(vecs, 0)
if pool_type in (None, "unicode", "mean"):
c = stacked.mean(axis=0)
elif pool_type == "abs":
c = np.abs(stacked).mean(axis=0)
elif pool_type == "dot":
c = stacked.mean(axis=0)
c = c / (np.abs(c).sum() + 1e-8) # L1 normalize
elif pool_type == "mse":
c = (stacked ** 2).mean(axis=0)
elif pool_type == "max":
c = stacked.max(axis=0)
else:
raise ValueError(f"Unsupported pool_type '{pool_type}'")
return c.astype(np.float32, copy=False)
def _compose_wordnet_center(
self, definitions: List[str], H, pool_type: Optional[str], dim: int
) -> Optional[np.ndarray]:
"""Build a center vector from definition text characters - OPTIMIZED."""
# Collect all characters from all definitions
all_chars = []
for text in definitions:
all_chars.extend(list(str(text)))
# Batch lookup
pooled_chars = self.pooled_batch(all_chars)
vecs: List[np.ndarray] = []
for pooled_v in pooled_chars:
if pooled_v is None:
continue
v = np.asarray(pooled_v, np.float32)
if v.shape[0] != dim:
raise ValueError(f"Definition pooled dim mismatch: got {v.shape[0]}, expected {dim}")
vecs.append(v)
if not vecs:
return None
stacked = np.stack(vecs, 0)
if pool_type in (None, "unicode", "mean"):
c = stacked.mean(axis=0)
elif pool_type == "abs":
c = np.abs(stacked).mean(axis=0)
elif pool_type == "dot":
c = stacked.mean(axis=0)
c = c / (np.abs(c).sum() + 1e-8) # L1 normalize
elif pool_type == "mse":
c = (stacked ** 2).mean(axis=0)
elif pool_type == "max":
c = stacked.max(axis=0)
else:
raise ValueError(f"Unsupported pool_type '{pool_type}'")
return c.astype(np.float32, copy=False)
def _deterministic_pentachoron(self, center_vec: np.ndarray) -> np.ndarray:
"""Universal pentachoron inflation (deterministic; overrideable)."""
d = center_vec.shape[0]
proposals = np.stack([
center_vec,
np.roll(center_vec, 1),
np.roll(center_vec, 3) * np.sign(center_vec + 1e-8),
np.roll(center_vec, 7) - center_vec,
np.roll(center_vec, 11) + center_vec,
], 0).astype(np.float32)
# L1 row norms
norms = np.sum(np.abs(proposals), axis=1, keepdims=True) + 1e-8
Q = proposals / norms
# GS orthogonalization with L1 row renorm at each step
for i in range(5):
for j in range(i):
Q[i] -= np.dot(Q[i], Q[j]) * Q[j]
Q[i] /= (np.sum(np.abs(Q[i])) + 1e-8)
gamma = np.array([1.0, 0.9, -0.8, 1.1, 1.2], np.float32)
X = np.zeros((5, d), np.float32)
for i in range(5):
X[i] = center_vec + gamma[i] * Q[i]
return X - X.mean(0, keepdims=True)
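# Construction summary (comments only): the five proposals are cyclic shifts and
# sign/offset variants of the center, L1-normalized, Gram-Schmidt orthogonalized
# with L1 re-normalization, scaled by the fixed gamma weights, added back to the
# center, and finally mean-centered so the crystal's vertex mean is zero.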
# --------------------------- finalize + provenance (overrideable) ----
def _finalize_crystal(self, X: np.ndarray) -> np.ndarray:
X = np.asarray(X, np.float32, order='C') # Ensure C-contiguous
if X.shape != (5, self.dim):
raise ValueError(f"Crystal must be shape (5, {self.dim}); got {X.shape}.")
return X - X.mean(0, keepdims=True)
def _auto_provenance_from_cfg(self, cfg: Dict[str, Any]) -> dict:
token = cfg["data"]["token"]
prov = {
"source": "special/compose",
"token": token,
"pool_type": cfg.get("pool_type") or "unicode",
"components": list(token),
"additional_definitions": list(cfg.get("additional_definitions", [])),
}
if cfg.get("antonyms"):
prov["antonyms"] = list(cfg["antonyms"])
if cfg.get("inversion_formula") is not None:
prov["inversion_formula"] = "user_supplied"
return prov
def _finalize_crystal_and_provenance(
self, product: Union[np.ndarray, Dict[str, Any]], cfg: Dict[str, Any]
) -> Tuple[np.ndarray, dict]:
# ndarray path
if isinstance(product, np.ndarray):
X = self._finalize_crystal(product)
prov = self._auto_provenance_from_cfg(cfg)
return X, prov
# dict path
if not isinstance(product, dict):
raise TypeError(
"create_crystal must return ndarray or dict with {'base':..., 'ops':..., 'provenance':...}.")
base = np.asarray(product["base"], np.float32)
X = base
for op in product.get("ops", []):
name = op.get("name")
if name == "center":
X -= X.mean(0, keepdims=True)
elif name == "scale":
X *= float(op.get("k", 1.0))
elif name == "translate":
t = np.asarray(op.get("t"), np.float32)
if t.shape != (self.dim,):
raise ValueError(f"translate.t must be shape ({self.dim},)")
X = X + t[None, :]
elif name == "normalize_rows":
n = np.sum(np.abs(X), axis=1, keepdims=True) + 1e-8
X = X / n
elif name == "align_to":
v = np.asarray(op.get("v"), np.float32)
if v.shape != (self.dim,):
raise ValueError(f"align_to.v must be shape ({self.dim},)")
v = v / (np.abs(v).sum() + 1e-8)
p = X.mean(0)
p = p / (np.abs(p).sum() + 1e-8)
alpha = float(op.get("alpha", 1.0))
X = X + alpha * (v - p)[None, :]
else:
raise ValueError(f"Unsupported op '{name}'")
prov = dict(product.get("provenance", {})) or self._auto_provenance_from_cfg(cfg)
return self._finalize_crystal(X), prov
# --------------------------- universal manifestation routine ----------
def _manifest_special_tokens(
self,
base_set: Dict[str, int],
create_crystal: Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]],
callback: Optional[Callable[..., np.ndarray]],
create_config: Dict[str, Any],
) -> None:
"""Universal, deterministic manifestor with character pre-caching."""
# NEW: Pre-cache all unique characters that will be needed
unique_chars = set()
for name in base_set.keys():
token_plain = name.strip("<>").strip()
unique_chars.update(token_plain)
print(f"[⚡] Pre-caching {len(unique_chars)} unique characters...")
for ch in unique_chars:
_ = self.pooled(ch) # Trigger caching
helpers = self._helpers()
for name, tid in base_set.items():
# Keep if already present
if tid in self._id_to_vec:
self._token_to_id[name] = tid
self._id_to_token.setdefault(tid, name)
self._valid_token_ids.add(tid)
continue
# Build per-token config
cfg = {
"dim": self.dim,
"pool_type": create_config.get("pool_type", None),
"special_tokens": create_config.get("special_tokens"),
"additional_definitions": create_config.get("additional_definitions", []),
"antonyms": create_config.get("antonyms"),
"inversion_formula": create_config.get("inversion_formula"),
"data": {"token": name.strip("<>").strip(), "token_id": tid, "origin": "special"},
"helpers": helpers,
}
if create_crystal is None:
create_crystal = self._default_create_crystal
product = create_crystal(cfg, callback) if callback is not None else create_crystal(cfg,
self._default_unicode_callback)
X, prov = self._finalize_crystal_and_provenance(product, cfg)
# Register
self._token_to_id[name] = tid
self._id_to_token[tid] = name
self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C')
self._id_to_provenance[tid] = prov
self._valid_token_ids.add(tid)
self._id_to_volume.setdefault(tid, 1.0)
# Aliases
for alias in (cfg.get("special_tokens") or []):
alias = str(alias)
self._token_to_id[alias] = tid
self._id_to_token.setdefault(tid, alias)
if cfg.get("special_tokens"):
self._id_to_provenance[tid].setdefault("aliases", list(cfg["special_tokens"]))
# Antonyms
antonyms = cfg.get("antonyms") or []
invf = cfg.get("inversion_formula")
if invf:
for anti in antonyms:
if anti in base_set:
anti_id = base_set[anti]
if anti_id not in self._id_to_vec:
X_inv = invf(X, cfg) # must be deterministic
X_inv = self._finalize_crystal(X_inv)
self._token_to_id[anti] = anti_id
self._id_to_token[anti_id] = anti
self._id_to_vec[anti_id] = X_inv.astype(np.float32, copy=False, order='C')
inv_prov = {
"source": "inversion",
"of_token": name,
"of_token_id": tid,
"pool_type": cfg.get("pool_type") or "unicode",
"components": prov.get("components", []),
"additional_definitions": cfg.get("additional_definitions", []),
"ops": ["invert"],
}
self._id_to_provenance[anti_id] = inv_prov
self._valid_token_ids.add(anti_id)
self._id_to_volume.setdefault(anti_id, 1.0)
# Invalidate caches after adding tokens
self._invalidate_caches()
if self._char_lookups_saved > 0:
print(f"[✅] Character cache saved {self._char_lookups_saved} lookups")
# --------------------------- basics -------------------------------
def vocab_size(self) -> int:
return len(self._token_to_id)
def token_to_id(self, token: str) -> Optional[int]:
return self._token_to_id.get(token)
def id_to_token(self, token_id: int) -> Optional[str]:
return self._id_to_token.get(token_id)
def cache_stats(self) -> Dict[str, int]:
"""Get cache statistics."""
return {
"normalized_cache_size": len(self._normalized_cache),
"pooled_cache_size": len(self._pooled_cache),
"char_cache_size": len(self._char_cache),
"char_lookups_saved": self._char_lookups_saved,
"spatial_index_size": len(self._spatial_index.token_ids) if self._spatial_index else 0,
"vocab_size": len(self._valid_token_ids)
}
def clear_caches(self):
"""Clear all caches to free memory."""
self._invalidate_caches()
self._char_cache.clear()
self._char_lookups_saved = 0
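# Minimal concrete subclass sketch (illustrative only, not one of the shipped
# classes below): it shows the two abstract methods a GeometricVocab subclass
# must supply; the trivial lookup behaviour here is an assumption for demonstration.
class _ToyGeometricVocab(GeometricVocab):
    def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
        tid = self._token_to_id.get(token)
        if tid is None or tid not in self._id_to_vec:
            raise KeyError(f"Token '{token}' is not registered in this toy vocab.")
        X = self._id_to_vec[tid]
        return (X, tid) if return_id else X

    def get_score(self, token_or_id: Union[str, int]) -> float:
        tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
        return 1.0 if tid in self._valid_token_ids else -100.0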
class PretrainedGeometricVocab(GeometricVocab):
"""
Parquet-backed deterministic vocab with columnar load, duplicate-mean aggregation,
pooled caching, and fast path for flat crystals.
"""
def __init__(
self,
repo_id: str,
dim: int,
*,
subset: str = "unicode",
split: str = "train_100d",
base_set: Optional[Dict[str, int]] = None,
create_config: Optional[Dict[str, Any]] = None,
create_crystal: Optional[Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]]] = None,
callback: Optional[Callable[..., np.ndarray]] = None,
manifest_specials: bool = True,
# perf/robustness knobs
store: str = "full", # "full" | "pooled" | "both"
reshape_order: str = "C",
vertex_count: int = 5,
infer_dim: bool = True,
strict_shapes: bool = False,
# new perf knobs
finalize_mode: str = "post_mean", # "none" | "post_mean"
cache_pooled: bool = True,
streaming=False,
):
super().__init__(dim)
self.repo_id = str(repo_id)
self._id_to_pooled: Dict[int, np.ndarray] = {} # optional pooled cache
# ---------- load split (columnar, minimal columns) ----------
ds = load_dataset(self.repo_id, split=split)
have = set(ds.column_names)
wanted = ["token_id", "token", "crystal", "volume"]
keep = [c for c in wanted if c in have]
drop = [c for c in ds.column_names if c not in keep]
if drop:
ds = ds.remove_columns(drop)
ds = ds.with_format("numpy", columns=keep)
ids = ds["token_id"] if "token_id" in keep else np.array([], dtype=np.int64)
toks = ds["token"] if "token" in keep else np.array([], dtype=object)
cryst= ds["crystal"] if "crystal" in keep else np.array([], dtype=object)
vols = ds["volume"] if "volume" in keep else None
ids = np.asarray(ids).astype(np.int64, copy=False)
toks = np.asarray(toks)
# --------- shape helpers ----------
def _coerce(raw: Any) -> np.ndarray:
X = np.asarray(raw, np.float32)
if X.ndim == 2:
V, D = int(X.shape[0]), int(X.shape[1])
if V != vertex_count:
raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
if D != self.dim:
if infer_dim: self.dim = D
else: raise ValueError(f"Dim mismatch: got {D}, expected {self.dim}.")
return X
if X.ndim == 1:
n = int(X.size)
if n == vertex_count * self.dim:
return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
if infer_dim and n % vertex_count == 0:
self.dim = n // vertex_count
return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
if n == self.dim:
c = X / (np.abs(X).sum() + 1e-8)
return self._deterministic_pentachoron(c)
raise ValueError(f"Unsupported crystal shape {X.shape if isinstance(X, np.ndarray) else type(X)}.")
def _finalize_if_needed(X: np.ndarray) -> np.ndarray:
if finalize_mode == "none":
return np.asarray(X, np.float32, order="C")
elif finalize_mode == "post_mean":
return self._finalize_crystal(X)
else:
raise ValueError(f"finalize_mode must be 'none' or 'post_mean', got {finalize_mode!r}")
vols_f = np.asarray(vols, dtype=np.float32) if vols is not None else None
# ---------- FAST PATH: flat uniform crystals ----------
# Try to stack into (N, L); succeeds when each row is the same length.
fastpath_ok = False
A = None # (N, L) float32
try:
A = np.stack(cryst) # may raise if jagged / object
if A.ndim == 2 and A.dtype != object:
A = A.astype(np.float32, copy=False)
L = A.shape[1]
if L % vertex_count == 0:
# infer or validate D
D = L // vertex_count
if self.dim != D:
if infer_dim:
self.dim = int(D)
else:
raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.")
fastpath_ok = True
except Exception:
fastpath_ok = False
if fastpath_ok and A is not None and len(ids) > 0:
# reshape to (N, V, D)
V = vertex_count
D = self.dim
A = A.reshape(-1, V, D, order=reshape_order)
# sort by ids and reduceat to mean duplicates in pure NumPy
order = np.argsort(ids, kind="stable")
ids_sorted = ids[order]
A_sorted = A[order]
vols_sorted = vols_f[order] if vols_f is not None else None
uniq_ids, idx, counts = np.unique(ids_sorted, return_index=True, return_counts=True)
sums = np.add.reduceat(A_sorted, idx, axis=0) # (K, V, D)
means = sums / counts[:, None, None] # (K, V, D)
if vols_sorted is not None:
v_sums = np.add.reduceat(vols_sorted, idx)
v_means = v_sums / counts.astype(np.float32)
else:
v_means = np.ones_like(uniq_ids, dtype=np.float32)
# commit maps
self._token_to_id.clear(); self._id_to_token.clear()
self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear()
self._id_to_pooled.clear()
# pick a representative token per id: first occurrence in sorted block
toks_sorted = toks[order]
rep_toks = toks_sorted[idx]
for tid, tok, X_mean, v_m in zip(uniq_ids.tolist(), rep_toks.tolist(), means, v_means.tolist()):
# cache pooled BEFORE finalize to preserve signal
if cache_pooled:
self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False)
X_store = _finalize_if_needed(X_mean)
self._token_to_id[str(tok)] = tid
self._id_to_token[tid] = str(tok)
if store in ("full", "both"):
self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C")
elif store == "pooled":
# store pooled as embedding if desired
self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled
else X_mean.mean(axis=0).astype(np.float32, copy=False))
self._id_to_volume[tid] = float(v_m)
self._valid_token_ids.add(tid)
else:
# ---------- FALLBACK: per-row coerce + dict mean ----------
ids_int = ids.tolist()
toks_str = [str(x) for x in toks.tolist()]
vols_f = (vols_f.tolist() if vols_f is not None else [1.0] * len(ids_int))
x_sum: Dict[int, np.ndarray] = {}
v_sum: Dict[int, float] = {}
n_cnt: Dict[int, int] = {}
tok_pref: Dict[int, str] = {}
for tid, tok, raw, vol in zip(ids_int, toks_str, cryst, vols_f):
X = _coerce(raw) # [V,D] float32
if tid not in x_sum:
x_sum[tid] = X.astype(np.float32, copy=True)
v_sum[tid] = float(vol)
n_cnt[tid] = 1
tok_pref[tid] = tok
else:
x_sum[tid] += X
v_sum[tid] += float(vol)
n_cnt[tid] += 1
self._token_to_id.clear(); self._id_to_token.clear()
self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear()
self._id_to_pooled.clear()
for tid in x_sum.keys(): # order not critical; add sorted(tids) if you need determinism
X_mean = x_sum[tid] / float(n_cnt[tid])
if cache_pooled:
self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False)
X_store = _finalize_if_needed(X_mean)
tok = tok_pref[tid]
vol_m = v_sum[tid] / float(n_cnt[tid])
self._token_to_id[tok] = tid
self._id_to_token[tid] = tok
if store in ("full", "both"):
self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C")
elif store == "pooled":
self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled
else X_mean.mean(axis=0).astype(np.float32, copy=False))
self._id_to_volume[tid] = float(vol_m)
self._valid_token_ids.add(tid)
# ---------- specials ----------
if manifest_specials and base_set:
self._manifest_special_tokens(
base_set=base_set,
create_crystal=create_crystal,
callback=callback,
create_config=create_config or {}
)
# -------- override pooled() to use cache (if present) --------
def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]:
# Favor the cached mean-pooled vector when available; other pooling methods fall back to the base class
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
if method == "mean" and tid is not None and tid in self._id_to_pooled:
return self._id_to_pooled[tid]
return super().pooled(token_or_id, method=method)
# -------- SP-like surface --------
def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
tid = self._token_to_id.get(token)
if tid is None:
unk_id = self._token_to_id.get("<unk>")
if unk_id is None:
raise KeyError(f"Token '{token}' not found and '<unk>' missing.")
X = self._id_to_vec[unk_id]
return (X, unk_id) if return_id else X
X = self._id_to_vec[tid]
return (X, tid) if return_id else X
def get_score(self, token_or_id: Union[str, int]) -> float:
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id, None)
if tid is None or tid not in self._valid_token_ids:
return -100.0
vol = self._id_to_volume.get(tid, 1.0)
return float(np.clip(vol / 10.0, 0.01, 1.0))
# -------- Torch cache ----------
def cache(self, tokens: Union[List[str], Dict[str, int]], device: str = "cpu", dtype: torch.dtype = torch.float32):
tok_list = list(tokens.keys()) if isinstance(tokens, dict) else list(tokens)
mats, pooled, keep = [], [], []
for t in tok_list:
X = self.embedding(t)
v = self.pooled(t)
if X is None or v is None:
continue
mats.append(torch.as_tensor(X, dtype=dtype))
pooled.append(torch.as_tensor(v, dtype=dtype))
keep.append(t)
if not mats:
raise ValueError("No valid tokens found in input.")
return {
"tokens": keep,
"crystals": torch.stack(mats, 0).to(device),
"pooled": torch.stack(pooled, 0).to(device),
}
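# Usage sketch (comments only; token strings are illustrative): given a loaded
# vocab v, batch = v.cache(["the", "cat"], device="cpu") returns a dict whose
# "crystals" tensor has shape (2, 5, dim) and whose "pooled" tensor has shape
# (2, dim); tokens without an embedding or pooled vector are silently skipped.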
def _coerce_crystal_shape(
self,
raw: Any,
*,
vertex_count: int,
reshape_order: str,
infer_dim: bool,
strict_shapes: bool
) -> np.ndarray:
"""
Accepts raw crystal data and returns [vertex_count, self.dim] float32 C-order.
Acceptable inputs:
- [vertex_count, D]
- [vertex_count * D] (flat) -> reshaped to [vertex_count, D]
- [D] (pooled center) -> converted by deterministic pentachoron (fallback)
"""
X = np.asarray(raw, dtype=np.float32)
# Already [V, D]
if X.ndim == 2:
V, D = int(X.shape[0]), int(X.shape[1])
if V != vertex_count:
# Refuse to coerce a different vertex count: collapsing or splitting rows would silently change the geometry
raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
# Update dim if needed
if D != self.dim:
if infer_dim:
self.dim = D
else:
raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.")
# Ensure mean-centered (finalize handles centering)
return X
# Flat [V*D]
if X.ndim == 1:
n = int(X.size)
# Exact match for flat crystal
if n == vertex_count * self.dim:
return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
# Infer D from total length if divisible
if infer_dim and n % vertex_count == 0:
inferred = n // vertex_count
self.dim = int(inferred)
return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
# Pooled [D]: inflate deterministically to [V, D]
if n == self.dim:
c = X / (np.abs(X).sum() + 1e-8) # L1
return self._deterministic_pentachoron(c)
if strict_shapes:
raise ValueError(
f"Cannot coerce crystal of length {n}. "
f"Expected {vertex_count*self.dim} (flat) or {self.dim} (pooled)."
)
# Conservative fallback: treat as pooled center with inferred D if reasonable
if infer_dim and n > 0:
self.dim = n
c = X / (np.abs(X).sum() + 1e-8)
return self._deterministic_pentachoron(c)
raise ValueError(f"Unsupported crystal shape {X.shape} (ndim={X.ndim}).")
# -------- Introspection --------
def describe(self) -> Dict[str, Union[str, int]]:
return {"repo": self.repo_id, "dimension": self.dim, "vocab_size": self.vocab_size()}
# --------------------------- lazy-loading variant ---------------------
from collections import OrderedDict
# Global flag for warning suppression
SILENT_MODE = False
def set_silent_mode(silent: bool):
"""Set global silent mode for token synthesis warnings"""
global SILENT_MODE
SILENT_MODE = silent
class LRUCache(OrderedDict):
"""Simple LRU cache implementation"""
def __init__(self, maxsize=128):
super().__init__()
self.maxsize = maxsize
def __getitem__(self, key):
value = super().__getitem__(key)
self.move_to_end(key)
return value
def __setitem__(self, key, value):
if key in self:
self.move_to_end(key)
super().__setitem__(key, value)
if len(self) > self.maxsize:
oldest = next(iter(self))
del self[oldest]
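# Quick illustration of the eviction behaviour (a sketch; keys are arbitrary):
def _demo_lru_cache() -> list:
    cache = LRUCache(maxsize=2)
    cache["a"] = 1
    cache["b"] = 2
    _ = cache["a"]      # touching "a" makes "b" the least recently used entry
    cache["c"] = 3      # exceeds maxsize, so "b" is evicted
    return list(cache)  # -> ["a", "c"]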
class LazyGeometricVocab(GeometricVocab):
"""
Lazy-loading geometric vocabulary that loads tokens on demand.
Maintains a small working set in memory with LRU eviction.
Supports automatic token synthesis for missing tokens.
"""
def __init__(
self,
repo_id: str,
dim: int,
*,
name: str = "unicode_100d", # Updated default to match new structure
split: str = "train", # Updated default to "train"
stream: bool = True, # Use streaming by default to avoid bulk downloads
base_set: Optional[Dict[str, int]] = None,
create_config: Optional[Dict[str, Any]] = None,
create_crystal: Optional[Callable] = None,
callback: Optional[Callable] = None,
manifest_specials: bool = True,
# Lazy loading parameters
cache_size: int = 1000, # Max tokens to keep in memory
preload_tokens: Optional[List[str]] = None, # Critical tokens to preload
index_cache_path: Optional[str] = None, # Path to save/load index
# Tokenization
tokenizer: Optional[Callable[[str], List[str]]] = None, # Custom tokenizer
# Synthesis settings
silent: bool = False, # Suppress synthesis warnings
# Performance knobs
store: str = "full",
reshape_order: str = "C",
vertex_count: int = 5,
infer_dim: bool = True,
finalize_mode: str = "post_mean",
cache_pooled: bool = True,
):
super().__init__(dim)
self.repo_id = repo_id
self.name = name
self.split = split
self.stream = stream
self.vertex_count = vertex_count
self.reshape_order = reshape_order
self.infer_dim = infer_dim
self.finalize_mode = finalize_mode
self.store = store
self.cache_pooled = cache_pooled
self.silent = silent
# Initialize pooled dictionary that may be missing from parent class
if not hasattr(self, '_id_to_pooled'):
self._id_to_pooled = {}
# For synthesis
self.create_crystal_fn = create_crystal
self.callback_fn = callback
self.create_config = create_config or {}
self._synthesized_tokens: set = set()
self._next_synthetic_id = -1 # Use negative IDs for synthetic tokens
# Tokenizer - default to simple split
self.tokenizer = tokenizer or (lambda s: s.split())
# LRU caches for lazy loading
self._crystal_cache = LRUCache(maxsize=cache_size)
self._pooled_lru = LRUCache(maxsize=cache_size * 2) # Pooled vectors are smaller
# Load dataset but don't fetch data yet
self._dataset = None
self._dataset_stream = None
self._token_index: Dict[str, List[int]] = {} # token -> [row indices]
self._id_index: Dict[int, List[int]] = {} # token_id -> [row indices]
self._row_data: Dict[int, dict] = {} # row -> cached data
# Initialize index
self._build_index(split, name)
# Pre-load base characters for synthesis
self._preload_synthesis_base()
# Preload critical tokens if specified
if preload_tokens:
self._preload(preload_tokens)
# Manifest special tokens
if manifest_specials and base_set:
self._manifest_special_tokens(
base_set=base_set,
create_crystal=create_crystal,
callback=callback,
create_config=create_config or {}
)
def _preload_synthesis_base(self):
"""Pre-load basic ASCII characters needed for synthesis"""
# Essential characters that are commonly used in token synthesis
base_chars = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-_()[]{}:;'\"")
print(f"Pre-loading {len(base_chars)} base characters for synthesis...")
loaded = 0
for char in base_chars:
tid = self._token_to_id.get(char)
if tid is not None:  # explicit None check so a token id of 0 is not skipped
# Pre-load this character's embedding
if self._load_crystal(tid) is not None:
loaded += 1
print(f"Loaded {loaded} base characters")
def _build_index(self, split: str, name: str):
"""Build token/id index without loading crystal data"""
print(f"Building index for {self.repo_id}/{name}/{split}...")
if self.stream:
try:
# Use streaming to avoid downloading all splits
# Don't specify columns in streaming mode to avoid schema issues
self._dataset_stream = load_dataset(
self.repo_id,
name=name,
split=split,
streaming=True
)
# Build index from streaming dataset
for idx, row in enumerate(self._dataset_stream):
token = str(row["token"])
token_id = int(row["token_id"])
# Token index
if token not in self._token_index:
self._token_index[token] = []
self._token_index[token].append(idx)
# ID index
if token_id not in self._id_index:
self._id_index[token_id] = []
self._id_index[token_id].append(idx)
# Update mappings (use first occurrence)
if token not in self._token_to_id:
self._token_to_id[token] = token_id
self._id_to_token[token_id] = token
self._valid_token_ids.add(token_id)
print(f"Index built (streaming): {len(self._token_index)} unique tokens")
except Exception as e:
print(f"Streaming failed: {e}")
print("Falling back to non-streaming mode...")
self.stream = False
# Recursive call with streaming disabled
self._build_index(split, name)
else:
# Non-streaming mode - load dataset normally
try:
# Try with data_files to load only specific split
data_files = f"data/{name}/{split}-*.parquet"
ds = load_dataset(
self.repo_id,
data_files=data_files,
split="train"
)
except Exception:
# Fallback to normal loading
try:
ds = load_dataset(
self.repo_id,
name=name,
split=split
)
except Exception as e:
print(f"Failed to load dataset: {e}")
raise
# Build indices
for idx, row in enumerate(ds):
token = str(row["token"])
token_id = int(row["token_id"])
# Token index
if token not in self._token_index:
self._token_index[token] = []
self._token_index[token].append(idx)
# ID index
if token_id not in self._id_index:
self._id_index[token_id] = []
self._id_index[token_id].append(idx)
# Update mappings (use first occurrence)
if token not in self._token_to_id:
self._token_to_id[token] = token_id
self._id_to_token[token_id] = token
self._valid_token_ids.add(token_id)
# Store dataset reference (will lazy load full data)
self._dataset = ds
print(f"Index built: {len(self._token_index)} unique tokens")
def _load_row(self, row_idx: int) -> dict:
"""Load a single row from dataset"""
if row_idx in self._row_data:
return self._row_data[row_idx]
# If streaming, need to load the full dataset on first data access
if self.stream and self._dataset is None:
print(f"Loading full dataset for {self.repo_id}/{self.name}/{self.split}...")
try:
# Try with data_files first
data_files = f"data/{self.name}/{self.split}-*.parquet"
self._dataset = load_dataset(
self.repo_id,
data_files=data_files,
split="train"
)
except Exception:
# Fallback to normal loading
self._dataset = load_dataset(
self.repo_id,
name=self.name,
split=self.split
)
if self._dataset is None:
raise RuntimeError("Dataset not initialized")
row = self._dataset[row_idx]
self._row_data[row_idx] = row
return row
def _load_crystal(self, token_id: int) -> Optional[np.ndarray]:
"""Load and aggregate crystal for a token_id"""
if token_id in self._crystal_cache:
return self._crystal_cache[token_id]
if token_id not in self._id_index:
return None
row_indices = self._id_index[token_id]
crystals = []
volumes = []
for idx in row_indices:
row = self._load_row(idx)
# Parse crystal
raw_crystal = row.get("crystal")
if raw_crystal is not None:
X = self._coerce_crystal(raw_crystal)
crystals.append(X)
# Get volume if available
vol = row.get("volume", 1.0)
volumes.append(float(vol))
if not crystals:
return None
# Average multiple occurrences
if len(crystals) == 1:
X_final = crystals[0]
vol_final = volumes[0]
else:
X_final = np.mean(crystals, axis=0)
vol_final = np.mean(volumes)
# Finalize
if self.finalize_mode == "post_mean":
X_final = self._finalize_crystal(X_final)
# Cache based on store mode
if self.store in ("full", "both"):
self._crystal_cache[token_id] = X_final
self._id_to_vec[token_id] = X_final
# Cache pooled if requested
if self.cache_pooled:
pooled = X_final.mean(axis=0)
self._pooled_lru[token_id] = pooled
if token_id not in self._id_to_pooled:
self._id_to_pooled[token_id] = pooled
# Store volume
self._id_to_volume[token_id] = vol_final
return X_final
def _coerce_crystal(self, raw: Any) -> np.ndarray:
"""Convert raw crystal data to proper shape"""
X = np.asarray(raw, dtype=np.float32)
if X.ndim == 2:
V, D = X.shape
if V != self.vertex_count:
raise ValueError(f"Expected {self.vertex_count} vertices, got {V}")
if D != self.dim:
if self.infer_dim:
self.dim = D
else:
raise ValueError(f"Dimension mismatch: {D} vs {self.dim}")
return X
if X.ndim == 1:
n = X.size
if n == self.vertex_count * self.dim:
return X.reshape((self.vertex_count, self.dim), order=self.reshape_order)
if self.infer_dim and n % self.vertex_count == 0:
self.dim = n // self.vertex_count
return X.reshape((self.vertex_count, self.dim), order=self.reshape_order)
if n == self.dim:
# Pooled vector - expand to crystal
c = X / (np.abs(X).sum() + 1e-8)
return self._deterministic_pentachoron(c)
raise ValueError(f"Cannot coerce crystal shape {X.shape}")
def _synthesize_token(self, token: str) -> int:
"""Synthesize a new token embedding on-the-fly with fallback for missing chars."""
# Generate a new ID for synthetic token
tid = self._next_synthetic_id
self._next_synthetic_id -= 1
# Warn user unless silenced
if not self.silent and not SILENT_MODE:
warnings.warn(
f"Token '{token}' synthesized - ensure you synthesize your tokens ahead of time.",
UserWarning,
stacklevel=3
)
# Track as synthesized
self._synthesized_tokens.add(token)
# Try to use character-based synthesis first
try:
# Check if all characters are available
missing_chars = []
for char in token:
if char not in self._token_to_id and char not in self._char_cache:
missing_chars.append(char)
# If missing chars, try to load or synthesize them first
if missing_chars:
for char in missing_chars:
char_tid = self._token_to_id.get(char)
if char_tid is not None:  # explicit None check so a token id of 0 is not skipped
# Try to load it
self._load_crystal(char_tid)
else:
# Create a simple embedding for this character
self._synthesize_simple_char(char)
# Now try the full synthesis
helpers = self._helpers()
cfg = {
"dim": self.dim,
"pool_type": self.create_config.get("pool_type", "unicode"),
"data": {"token": token, "token_id": tid, "origin": "synthetic"},
"helpers": helpers,
}
if self.create_crystal_fn is not None:
product = self.create_crystal_fn(cfg, self.callback_fn)
else:
product = self._default_create_crystal(cfg, self._default_unicode_callback)
X, prov = self._finalize_crystal_and_provenance(product, cfg)
except Exception as e:
# Fallback to simple random synthesis
print(f"Character-based synthesis failed for '{token}': {e}. Using random synthesis.")
X = self._synthesize_random_crystal(token)
prov = {"source": "synthetic_random", "token": token}
prov["synthetic"] = True
# Register in all maps
self._token_to_id[token] = tid
self._id_to_token[tid] = token
self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C')
self._id_to_provenance[tid] = prov
self._valid_token_ids.add(tid)
self._id_to_volume[tid] = 1.0
# Cache
self._crystal_cache[tid] = X
if self.cache_pooled:
pooled = X.mean(axis=0)
self._pooled_lru[tid] = pooled
self._id_to_pooled[tid] = pooled
return tid
def _synthesize_simple_char(self, char: str):
"""Create a simple deterministic embedding for a single character"""
import hashlib
# Use character's unicode codepoint for deterministic generation
if len(char) == 1:
seed = ord(char)
else:
seed = int(hashlib.md5(char.encode()).hexdigest()[:8], 16)
rng = np.random.RandomState(seed)  # local RNG so the global NumPy random state is left untouched
# Generate a simple vector based on character properties
vec = rng.randn(self.dim).astype(np.float32)
vec = vec / (np.abs(vec).sum() + 1e-8) # L1 normalize
# Cache it
self._char_cache[char] = vec
def _synthesize_random_crystal(self, token: str) -> np.ndarray:
"""Fallback: create a deterministic random crystal based on token string"""
import hashlib
# Create deterministic seed from token
seed = int(hashlib.md5(token.encode()).hexdigest()[:8], 16)
rng = np.random.RandomState(seed)  # local RNG so the global NumPy random state is left untouched
# Generate a random crystal
X = rng.randn(self.vertex_count, self.dim).astype(np.float32)
X = self._finalize_crystal(X)
return X
def _preload(self, tokens: List[str]):
"""Preload specific tokens into cache"""
print(f"Preloading {len(tokens)} tokens...")
for token in tokens:
tid = self._token_to_id.get(token)
if tid is not None:  # explicit None check so a token id of 0 is not skipped
self._load_crystal(tid)
# Override base methods to use lazy loading with synthesis
def embedding(self, token_or_id: Union[str, int], generate: bool = False) -> Optional[np.ndarray]:
"""Get embedding, loading if necessary, synthesizing if requested"""
# Handle token ID input
if isinstance(token_or_id, int):
tid = token_or_id
token = self._id_to_token.get(tid)
else:
token = token_or_id
tid = self._token_to_id.get(token)
if tid is not None:
# Check cache first
if tid in self._id_to_vec:
return self._id_to_vec[tid]
# Load on demand
return self._load_crystal(tid)
# Token not found - synthesize if requested
if generate and token is not None:
tid = self._synthesize_token(token)
return self._id_to_vec[tid]
return None
def pooled(self, token_or_id: Union[str, int], method: str = "mean", generate: bool = False) -> Optional[np.ndarray]:
"""Get pooled vector, loading if necessary, synthesizing if requested"""
# Handle token ID input
if isinstance(token_or_id, int):
tid = token_or_id
token = self._id_to_token.get(tid)
else:
token = token_or_id
tid = self._token_to_id.get(token)
if tid is not None:
# Check pooled caches (they hold mean-pooled vectors, so they only apply to method="mean")
if method == "mean":
if tid in self._pooled_lru:
return self._pooled_lru[tid]
if tid in self._id_to_pooled:
return self._id_to_pooled[tid]
# Load crystal and compute pooled
X = self.embedding(tid, generate=False)
if X is not None:
if method == "mean":
pooled = X.mean(axis=0)
self._pooled_lru[tid] = pooled
return pooled
elif method == "first":
return X[0]
elif method == "sum":
return X.sum(axis=0)
else:
raise ValueError(f"Unknown pooling method: {method}")
# Token not found - synthesize if requested
if generate and token is not None:
tid = self._synthesize_token(token)
return self.pooled(tid, method=method, generate=False)
return None
def encode(self, token: str, *, return_id: bool = False, generate: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
"""Encode token, loading if necessary, synthesizing if requested"""
tid = self._token_to_id.get(token)
if tid is None:
if generate:
# Synthesize new token
tid = self._synthesize_token(token)
X = self._id_to_vec[tid]
else:
# Fallback to UNK
unk_id = self._token_to_id.get("<unk>")
if unk_id is None:
# No UNK token - try to synthesize if allowed
if generate:
tid = self._synthesize_token(token)
X = self._id_to_vec[tid]
else:
raise KeyError(f"Token '{token}' not found and no <unk> token available")
else:
X = self.embedding(unk_id, generate=False)
tid = unk_id
else:
X = self.embedding(tid, generate=False)
if X is None:
raise RuntimeError(f"Failed to load embedding for token '{token}'")
return (X, tid) if return_id else X
def get_score(self, token_or_id: Union[str, int]) -> float:
"""Get token score"""
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
if tid is None or tid not in self._valid_token_ids:
return -100.0
# Load volume if needed
if tid not in self._id_to_volume:
self._load_crystal(tid)
vol = self._id_to_volume.get(tid, 1.0)
return float(np.clip(vol / 10.0, 0.01, 1.0))
def encode_batch(self, tokens: Union[str, List[str]],
*, return_ids: bool = False,
prefetch: bool = True,
generate: bool = False) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[int]]]:
"""
Encode a batch of tokens efficiently.
Args:
tokens: Either a string (will be tokenized) or list of token strings
return_ids: Whether to return token IDs alongside embeddings
prefetch: Whether to prefetch all tokens before encoding
generate: If True, synthesize missing tokens
Returns:
List of embeddings, optionally with list of token IDs
"""
# Handle string input - tokenize it
if isinstance(tokens, str):
tokens = self.tokenizer(tokens)
if not isinstance(tokens, list):
raise TypeError(f"Expected str or List[str], got {type(tokens)}")
# Track which tokens need synthesis
tokens_to_synthesize = []
if generate:
for token in tokens:
if token not in self._token_to_id:
tokens_to_synthesize.append(token)
# Warn about batch synthesis if needed
if tokens_to_synthesize and not self.silent and not SILENT_MODE:
warnings.warn(
f"{len(tokens_to_synthesize)} tokens synthesized - ensure you synthesize your tokens ahead of time. "
f"Synthesized: {tokens_to_synthesize[:5]}{'...' if len(tokens_to_synthesize) > 5 else ''}",
UserWarning,
stacklevel=2
)
# Prefetch existing tokens if requested
if prefetch:
self._prefetch_batch([t for t in tokens if t in self._token_to_id])
# Encode all tokens
embeddings = []
ids = []
for token in tokens:
if return_ids:
emb, tid = self.encode(token, return_id=True, generate=generate)
embeddings.append(emb)
ids.append(tid)
else:
emb = self.encode(token, return_id=False, generate=generate)
embeddings.append(emb)
return (embeddings, ids) if return_ids else embeddings
def _prefetch_batch(self, tokens: List[str]):
"""
Prefetch a batch of tokens efficiently.
"""
# Collect all token IDs that need loading
tokens_to_load = []
for token in tokens:
tid = self._token_to_id.get(token)
if tid is not None and tid not in self._crystal_cache and tid not in self._id_to_vec:
tokens_to_load.append(tid)
if not tokens_to_load:
return # Everything already cached
# Load crystals for each token
for tid in tokens_to_load:
self._load_crystal(tid)
def cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
stats = super().cache_stats()
stats.update({
"crystal_cache_size": len(self._crystal_cache),
"pooled_lru_size": len(self._pooled_lru),
"rows_cached": len(self._row_data),
"tokens_indexed": len(self._token_index),
"ids_indexed": len(self._id_index),
"synthesized_tokens": len(self._synthesized_tokens),
})
return stats
def evict_from_cache(self, tokens: Optional[List[str]] = None):
"""Manually evict tokens from cache to free memory"""
if tokens is None:
# Clear all caches
self._crystal_cache.clear()
self._pooled_lru.clear()
self._id_to_vec.clear()
self._id_to_pooled.clear()
self._row_data.clear()
else:
# Evict specific tokens
for token in tokens:
tid = self._token_to_id.get(token)
if tid is not None:
self._crystal_cache.pop(tid, None)
self._pooled_lru.pop(tid, None)
self._id_to_vec.pop(tid, None)
self._id_to_pooled.pop(tid, None)
def get_synthesized_tokens(self) -> List[str]:
"""Get list of all tokens that were synthesized at runtime"""
return list(self._synthesized_tokens)
def is_synthesized(self, token: str) -> bool:
"""Check if a token was synthesized at runtime"""
return token in self._synthesized_tokens
# Example: load the 64-dimensional unicode vocabulary lazily
vocab = LazyGeometricVocab(
repo_id="AbstractPhil/geometric-vocab",
dim=64,
name="unicode_64d", # Specifies the dimension config
split="train", # Now always "train"
stream=False,
cache_size=1024,
silent=False
)
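# Usage sketch (comments only; the sentence is illustrative): with the vocab above,
#   embs, ids = vocab.encode_batch("hello world", return_ids=True, generate=True)
# tokenizes with the default whitespace tokenizer, lazily loads known tokens,
# synthesizes any missing ones (negative synthetic ids, with a warning unless
# silent mode is enabled), and returns one (5, dim) crystal per token.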
FROZEN_VOCAB = []