|
|
from __future__ import annotations |
|
|
import torch |
|
|
import numpy as np |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Union, Tuple, Optional, Callable, Any, List |
|
|
import warnings |
|
|
from collections import OrderedDict  # backs the LRUCache defined further below
|
|
import datasets |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
try: |
|
|
import faiss |
|
|
FAISS_AVAILABLE = True |
|
|
except ImportError: |
|
|
FAISS_AVAILABLE = False |
|
|
|
|
|
try: |
|
|
from sklearn.neighbors import NearestNeighbors |
|
|
SKLEARN_AVAILABLE = True |
|
|
except ImportError: |
|
|
SKLEARN_AVAILABLE = False |
|
|
|
|
|
|
|
|
class SpatialIndex: |
|
|
"""Spatial indexing for fast similarity search.""" |
|
|
|
|
|
def __init__(self, vectors: np.ndarray, token_ids: List[int], method: str = "auto"): |
|
|
self.token_ids = np.array(token_ids) |
|
|
self.method = method |
|
|
self._index = None |
|
|
|
|
|
if method == "auto": |
|
|
if FAISS_AVAILABLE and vectors.shape[0] > 1000: |
|
|
method = "faiss" |
|
|
elif SKLEARN_AVAILABLE: |
|
|
method = "sklearn" |
|
|
else: |
|
|
method = "linear" |
|
|
|
|
|
self._build_index(vectors, method) |
|
|
|
|
|
def _build_index(self, vectors: np.ndarray, method: str): |
|
|
if method == "faiss" and FAISS_AVAILABLE: |
|
|
|
|
|
vectors_l2 = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8) |
|
|
self._index = faiss.IndexFlatIP(vectors_l2.shape[1]) |
|
|
self._index.add(vectors_l2.astype(np.float32)) |
|
|
self.method = "faiss" |
|
|
|
|
|
elif method == "sklearn" and SKLEARN_AVAILABLE: |
|
|
|
|
|
self._index = NearestNeighbors( |
|
|
metric='manhattan', |
|
|
algorithm='ball_tree', |
|
|
n_jobs=-1 |
|
|
).fit(vectors) |
|
|
self.method = "sklearn" |
|
|
else: |
|
|
|
|
|
self._vectors = vectors |
|
|
self.method = "linear" |
|
|
|
|
|
    def search_radius(self, query_vector: np.ndarray, max_distance: float,
                      max_results: int = 1000) -> Tuple[List[int], List[float]]:
        """Find all points within max_distance of the query.

        The sklearn and linear backends use the exact L1 (manhattan) metric;
        the FAISS backend approximates the radius with a cosine threshold.
        """
|
|
if self.method == "sklearn": |
|
|
indices = self._index.radius_neighbors([query_vector], radius=max_distance)[1][0] |
|
|
if len(indices) > max_results: |
|
|
|
|
|
distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1) |
|
|
top_k = np.argsort(distances)[:max_results] |
|
|
indices = indices[top_k] |
|
|
distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1) |
|
|
return self.token_ids[indices].tolist(), distances.tolist() |
|
|
|
|
|
elif self.method == "faiss": |
|
|
|
|
|
query_l2 = query_vector / (np.linalg.norm(query_vector) + 1e-8) |
|
|
similarities, indices = self._index.search(query_l2.reshape(1, -1).astype(np.float32), max_results) |
|
|
|
|
|
threshold_sim = 1.0 - max_distance |
|
|
mask = similarities[0] >= threshold_sim |
|
|
return self.token_ids[indices[0][mask]].tolist(), (1.0 - similarities[0][mask]).tolist() |
|
|
|
|
|
else: |
|
|
distances = np.sum(np.abs(self._vectors - query_vector), axis=1) |
|
|
mask = distances <= max_distance |
|
|
if np.sum(mask) > max_results: |
|
|
indices = np.argsort(distances)[:max_results] |
|
|
mask = np.zeros_like(distances, dtype=bool) |
|
|
mask[indices] = True |
|
|
return self.token_ids[mask].tolist(), distances[mask].tolist() |
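

# Minimal usage sketch for SpatialIndex (illustrative only; the toy data is
# made up). Defining the function has no side effects; call it manually to
# see backend selection and a radius query.
def _demo_spatial_index() -> None:
    rng = np.random.default_rng(0)
    vectors = rng.standard_normal((50, 8)).astype(np.float32)
    index = SpatialIndex(vectors, token_ids=list(range(50)), method="auto")
    ids, dists = index.search_radius(vectors[0], max_distance=4.0, max_results=5)
    print(f"backend={index.method}, hits={len(ids)}")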
|
|
|
|
|
|
|
|
class GeometricVocab(ABC): |
|
|
""" |
|
|
Optimized geometric vocabulary with spatial indexing and caching. |
|
|
""" |
|
|
|
|
|
def __init__(self, dim: int): |
|
|
self.dim = int(dim) |
|
|
self._token_to_id: Dict[str, int] = {} |
|
|
self._id_to_token: Dict[int, str] = {} |
|
|
self._id_to_vec: Dict[int, np.ndarray] = {} |
|
|
self._id_to_volume: Dict[int, float] = {} |
|
|
self._id_to_provenance: Dict[int, dict] = {} |
|
|
self._valid_token_ids: set[int] = set() |
|
|
|
|
|
|
|
|
self._normalized_cache: Dict[int, np.ndarray] = {} |
|
|
self._pooled_cache: Dict[int, np.ndarray] = {} |
|
|
self._spatial_index: Optional[SpatialIndex] = None |
|
|
self._index_dirty = False |
|
|
|
|
|
|
|
|
self._char_cache: Dict[str, np.ndarray] = {} |
|
|
self._char_lookups_saved = 0 |
|
|
|
|
|
def _invalidate_caches(self): |
|
|
"""Invalidate caches when vocabulary changes.""" |
|
|
self._normalized_cache.clear() |
|
|
self._pooled_cache.clear() |
|
|
self._spatial_index = None |
|
|
self._index_dirty = True |
|
|
|
|
|
|
|
|
def _ensure_spatial_index(self): |
|
|
"""Build spatial index if needed.""" |
|
|
if self._spatial_index is None or self._index_dirty: |
|
|
if len(self._valid_token_ids) < 10: |
|
|
return |
|
|
|
|
|
pooled_vectors = [] |
|
|
token_ids = [] |
|
|
for tid in sorted(self._valid_token_ids): |
|
|
pooled_vec = self._get_cached_pooled(tid) |
|
|
if pooled_vec is not None: |
|
|
pooled_vectors.append(pooled_vec) |
|
|
token_ids.append(tid) |
|
|
|
|
|
if pooled_vectors: |
|
|
self._spatial_index = SpatialIndex( |
|
|
np.array(pooled_vectors), |
|
|
token_ids, |
|
|
method="auto" |
|
|
) |
|
|
self._index_dirty = False |
|
|
|
|
|
def _get_cached_pooled(self, token_id: int) -> Optional[np.ndarray]: |
|
|
"""Get pooled vector with caching.""" |
|
|
if token_id in self._pooled_cache: |
|
|
return self._pooled_cache[token_id] |
|
|
|
|
|
if token_id in self._id_to_vec: |
|
|
X = self._id_to_vec[token_id] |
|
|
pooled = X.mean(axis=0) |
|
|
self._pooled_cache[token_id] = pooled |
|
|
return pooled |
|
|
return None |
|
|
|
|
|
def _get_cached_normalized(self, token_id: int) -> Optional[np.ndarray]: |
|
|
"""Get L1-normalized pooled vector with caching.""" |
|
|
if token_id in self._normalized_cache: |
|
|
return self._normalized_cache[token_id] |
|
|
|
|
|
pooled = self._get_cached_pooled(token_id) |
|
|
if pooled is not None: |
|
|
normalized = pooled / (np.abs(pooled).sum() + 1e-8) |
|
|
self._normalized_cache[token_id] = normalized |
|
|
return normalized |
|
|
return None |
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]: |
|
|
raise NotImplementedError |
|
|
|
|
|
@abstractmethod |
|
|
def get_score(self, token_or_id: Union[str, int]) -> float: |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
def decode(self, token_id: int, fallback: str = "<unk>") -> Optional[str]: |
|
|
if token_id in self._id_to_token: |
|
|
return self._id_to_token[token_id] |
|
|
return fallback if fallback in self._token_to_id else None |
|
|
|
|
|
def decode_with_provenance(self, token_id: int, fallback: str = "<unk>") -> Tuple[Optional[str], Optional[dict]]: |
|
|
tok = self.decode(token_id, fallback=fallback) |
|
|
prov = self._id_to_provenance.get(token_id) |
|
|
return tok, prov |
|
|
|
|
|
def provenance(self, token_or_id: Union[str, int]) -> Optional[dict]: |
|
|
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id) |
|
|
return self._id_to_provenance.get(tid) |
|
|
|
|
|
def embedding(self, token_or_id: Union[str, int]) -> Optional[np.ndarray]: |
|
|
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id) |
|
|
return self._id_to_vec.get(tid) |
|
|
|
|
|
def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]: |
|
|
"""Optimized pooled method with character caching""" |
|
|
|
|
|
|
|
|
        # Fast path: single characters are memoized. The cache stores
        # mean-pooled vectors, so it only applies to method == "mean".
        if method == "mean" and isinstance(token_or_id, str) and len(token_or_id) == 1:
            if token_or_id in self._char_cache:
                self._char_lookups_saved += 1
                return self._char_cache[token_or_id].copy()
|
|
|
|
|
|
|
|
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id) |
|
|
if tid is None: |
|
|
return None |
|
|
|
|
|
if method == "mean": |
|
|
pooled = self._get_cached_pooled(tid) |
|
|
|
|
|
|
|
|
if pooled is not None and isinstance(token_or_id, str) and len(token_or_id) == 1: |
|
|
self._char_cache[token_or_id] = pooled.copy() |
|
|
|
|
|
return pooled |
|
|
|
|
|
|
|
|
X = self._id_to_vec.get(tid) |
|
|
if X is None: |
|
|
return None |
|
|
if method == "first": |
|
|
return X[0] |
|
|
if method == "sum": |
|
|
return X.sum(axis=0) |
|
|
raise ValueError(f"Invalid pooling method: {method}") |
|
|
|
|
|
def pooled_batch(self, tokens: List[Union[str, int]], method: str = "mean") -> List[Optional[np.ndarray]]: |
|
|
"""Batch pooling with character-level caching for efficiency""" |
|
|
results = [] |
|
|
|
|
|
for token in tokens: |
|
|
|
|
|
results.append(self.pooled(token, method)) |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def similarity(self, token_a: Union[str, int], token_b: Union[str, int]) -> float: |
|
|
""" |
|
|
Optimized L1-normalized directional similarity using cached vectors. |
|
|
""" |
|
|
tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a) |
|
|
tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b) |
|
|
|
|
|
if tid_a is None or tid_b is None: |
|
|
return -1.0 |
|
|
|
|
|
a_norm = self._get_cached_normalized(tid_a) |
|
|
b_norm = self._get_cached_normalized(tid_b) |
|
|
|
|
|
if a_norm is None or b_norm is None: |
|
|
return -1.0 |
|
|
|
|
|
return float(np.dot(a_norm, b_norm)) |
|
|
|
|
|
def similarity_magnitude(self, token_a: Union[str, int], token_b: Union[str, int]) -> float: |
|
|
""" |
|
|
Raw dot-product using cached pooled vectors. |
|
|
""" |
|
|
tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a) |
|
|
tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b) |
|
|
|
|
|
if tid_a is None or tid_b is None: |
|
|
return -1.0 |
|
|
|
|
|
a = self._get_cached_pooled(tid_a) |
|
|
b = self._get_cached_pooled(tid_b) |
|
|
|
|
|
if a is None or b is None: |
|
|
return -1.0 |
|
|
|
|
|
return float(np.dot(a, b)) |
|
|
|
|
|
|
|
|
def extract_band(self, trajectory: np.ndarray, max_angle: float = 0.3, method: str = "pooled") -> Dict[ |
|
|
str, np.ndarray]: |
|
|
""" |
|
|
Optimized spatial search using indexing when available. |
|
|
""" |
|
|
if trajectory.ndim == 2: |
|
|
direction = trajectory.mean(0) |
|
|
else: |
|
|
direction = trajectory |
|
|
direction = direction / (np.abs(direction).sum() + 1e-8) |
|
|
|
|
|
|
|
|
self._ensure_spatial_index() |
|
|
if self._spatial_index is not None: |
|
|
try: |
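                # Heuristic: widen the index radius, then re-verify candidates
                # against the exact cosine band below.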
|
|
|
|
|
max_distance = max_angle * 2.0 |
|
|
token_ids, distances = self._spatial_index.search_radius( |
|
|
direction, max_distance, max_results=1000 |
|
|
) |
|
|
|
|
|
|
|
|
out: Dict[str, np.ndarray] = {} |
|
|
for tid in token_ids: |
|
|
tok = self._id_to_token.get(tid) |
|
|
if tok is None: |
|
|
continue |
|
|
v_norm = self._get_cached_normalized(tid) |
|
|
if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle: |
|
|
out[tok] = self._id_to_vec[tid] |
|
|
return out |
|
|
|
|
|
except Exception as e: |
|
|
warnings.warn(f"Spatial index search failed: {e}, falling back to linear") |
|
|
|
|
|
|
|
|
out: Dict[str, np.ndarray] = {} |
|
|
for tok, tid in self._token_to_id.items(): |
|
|
v_norm = self._get_cached_normalized(tid) |
|
|
if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle: |
|
|
out[tok] = self._id_to_vec[tid] |
|
|
return out |
|
|
|
|
|
def find_similar_tokens(self, token: Union[str, int], k: int = 10, min_similarity: float = 0.5) -> List[ |
|
|
Tuple[str, float]]: |
|
|
""" |
|
|
Find k most similar tokens using spatial indexing when available. |
|
|
""" |
|
|
tid = token if isinstance(token, int) else self._token_to_id.get(token) |
|
|
if tid is None: |
|
|
return [] |
|
|
|
|
|
query_vec = self._get_cached_normalized(tid) |
|
|
if query_vec is None: |
|
|
return [] |
|
|
|
|
|
self._ensure_spatial_index() |
|
|
if self._spatial_index is not None: |
|
|
try: |
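                # Convert the similarity floor to an index radius and over-fetch
                # (k * 3 candidates) since the radius is only an approximation.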
|
|
|
|
|
max_distance = (1.0 - min_similarity) * 2.0 |
|
|
token_ids, _ = self._spatial_index.search_radius( |
|
|
query_vec, max_distance, max_results=k * 3 |
|
|
) |
|
|
|
|
|
|
|
|
similarities = [] |
|
|
for tid_cand in token_ids: |
|
|
if tid_cand == tid: |
|
|
continue |
|
|
sim = self.similarity(tid, tid_cand) |
|
|
if sim >= min_similarity: |
|
|
tok = self._id_to_token.get(tid_cand) |
|
|
if tok: |
|
|
similarities.append((tok, sim)) |
|
|
|
|
|
return sorted(similarities, key=lambda x: x[1], reverse=True)[:k] |
|
|
|
|
|
except Exception as e: |
|
|
warnings.warn(f"Spatial similarity search failed: {e}, falling back to linear") |
|
|
|
|
|
|
|
|
similarities = [] |
|
|
for tok_cand, tid_cand in self._token_to_id.items(): |
|
|
if tid_cand == tid: |
|
|
continue |
|
|
sim = self.similarity(tid, tid_cand) |
|
|
if sim >= min_similarity: |
|
|
similarities.append((tok_cand, sim)) |
|
|
|
|
|
return sorted(similarities, key=lambda x: x[1], reverse=True)[:k] |
|
|
|
|
|
|
|
|
def _helpers(self) -> Dict[str, Callable[..., np.ndarray]]: |
|
|
def _emb(x): |
|
|
e = self.embedding(x) |
|
|
return None if e is None else np.asarray(e, np.float32) |
|
|
|
|
|
def _poo(x): |
|
|
p = self.pooled(x) |
|
|
return None if p is None else np.asarray(p, np.float32) |
|
|
|
|
|
def _chars(s): |
|
|
|
|
|
return self.pooled_batch(list(s)) if isinstance(s, str) else None |
|
|
|
|
|
return {"embedding": _emb, "pooled": _poo, "chars_pooled": _chars} |
|
|
|
|
|
|
|
|
def _default_create_crystal(self, config: dict, callback: Callable[..., np.ndarray]) -> np.ndarray: |
|
|
""" |
|
|
Deterministic default when user leaves callback/create_crystal=None. |
|
|
""" |
|
|
pool_type = config.get("pool_type") or "unicode" |
|
|
H = config["helpers"] |
|
|
token_plain = str(config["data"]["token"]) |
|
|
d = int(config["dim"]) |
|
|
|
|
|
c_uni = self._compose_unicode_center(token_plain, H, pool_type, d) |
|
|
c_defs = self._compose_wordnet_center(config.get("additional_definitions", []), H, pool_type, d) |
|
|
|
|
|
if pool_type == "combination": |
|
|
parts = [v for v in (c_uni, c_defs) if v is not None] |
|
|
c = np.mean(np.stack(parts, 0), 0) if parts else np.zeros(d, np.float32) |
|
|
elif pool_type == "wordnet": |
|
|
c = c_defs if c_defs is not None else np.zeros(d, np.float32) |
|
|
else: |
|
|
c = c_uni if c_uni is not None else np.zeros(d, np.float32) |
|
|
|
|
|
|
|
|
l1 = float(np.abs(c).sum()) + 1e-8 |
|
|
c = c / l1 |
|
|
return self._deterministic_pentachoron(c) |
|
|
|
|
|
def _default_unicode_callback(self, name: str, **kwargs) -> np.ndarray: |
|
|
raise NotImplementedError("Default callback is not invoked directly.") |
|
|
|
|
|
|
|
|
def _compose_unicode_center( |
|
|
self, token_plain: str, H, pool_type: Optional[str], dim: int |
|
|
) -> Optional[np.ndarray]: |
|
|
""" |
|
|
Build a center vector from the token's Unicode characters - OPTIMIZED. |
|
|
""" |
|
|
|
|
|
char_list = list(token_plain) |
|
|
pooled_chars = self.pooled_batch(char_list) |
|
|
|
|
|
vecs: List[np.ndarray] = [] |
|
|
for pooled_v in pooled_chars: |
|
|
if pooled_v is None: |
|
|
continue |
|
|
v = np.asarray(pooled_v, np.float32) |
|
|
if v.shape[0] != dim: |
|
|
raise ValueError(f"Unicode pooled dim mismatch: got {v.shape[0]}, expected {dim}") |
|
|
vecs.append(v) |
|
|
|
|
|
if not vecs: |
|
|
return None |
|
|
|
|
|
stacked = np.stack(vecs, 0) |
|
|
|
|
|
if pool_type in (None, "unicode", "mean"): |
|
|
c = stacked.mean(axis=0) |
|
|
elif pool_type == "abs": |
|
|
c = np.abs(stacked).mean(axis=0) |
|
|
elif pool_type == "dot": |
|
|
c = stacked.mean(axis=0) |
|
|
c = c / (np.abs(c).sum() + 1e-8) |
|
|
elif pool_type == "mse": |
|
|
c = (stacked ** 2).mean(axis=0) |
|
|
elif pool_type == "max": |
|
|
c = stacked.max(axis=0) |
|
|
else: |
|
|
raise ValueError(f"Unsupported pool_type '{pool_type}'") |
|
|
|
|
|
return c.astype(np.float32, copy=False) |
|
|
|
|
|
def _compose_wordnet_center( |
|
|
self, definitions: List[str], H, pool_type: Optional[str], dim: int |
|
|
) -> Optional[np.ndarray]: |
|
|
"""Build a center vector from definition text characters - OPTIMIZED.""" |
|
|
|
|
|
all_chars = [] |
|
|
for text in definitions: |
|
|
all_chars.extend(list(str(text))) |
|
|
|
|
|
|
|
|
pooled_chars = self.pooled_batch(all_chars) |
|
|
|
|
|
vecs: List[np.ndarray] = [] |
|
|
for pooled_v in pooled_chars: |
|
|
if pooled_v is None: |
|
|
continue |
|
|
v = np.asarray(pooled_v, np.float32) |
|
|
if v.shape[0] != dim: |
|
|
raise ValueError(f"Definition pooled dim mismatch: got {v.shape[0]}, expected {dim}") |
|
|
vecs.append(v) |
|
|
|
|
|
if not vecs: |
|
|
return None |
|
|
|
|
|
stacked = np.stack(vecs, 0) |
|
|
|
|
|
if pool_type in (None, "unicode", "mean"): |
|
|
c = stacked.mean(axis=0) |
|
|
elif pool_type == "abs": |
|
|
c = np.abs(stacked).mean(axis=0) |
|
|
elif pool_type == "dot": |
|
|
c = stacked.mean(axis=0) |
|
|
c = c / (np.abs(c).sum() + 1e-8) |
|
|
elif pool_type == "mse": |
|
|
c = (stacked ** 2).mean(axis=0) |
|
|
elif pool_type == "max": |
|
|
c = stacked.max(axis=0) |
|
|
else: |
|
|
raise ValueError(f"Unsupported pool_type '{pool_type}'") |
|
|
|
|
|
return c.astype(np.float32, copy=False) |
|
|
|
|
|
def _deterministic_pentachoron(self, center_vec: np.ndarray) -> np.ndarray: |
|
|
"""Universal pentachoron inflation (deterministic; overrideable).""" |
|
|
d = center_vec.shape[0] |
|
|
proposals = np.stack([ |
|
|
center_vec, |
|
|
np.roll(center_vec, 1), |
|
|
np.roll(center_vec, 3) * np.sign(center_vec + 1e-8), |
|
|
np.roll(center_vec, 7) - center_vec, |
|
|
np.roll(center_vec, 11) + center_vec, |
|
|
], 0).astype(np.float32) |
|
|
|
|
|
|
|
|
norms = np.sum(np.abs(proposals), axis=1, keepdims=True) + 1e-8 |
|
|
Q = proposals / norms |
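        # Gram-Schmidt-style pass: project out earlier rows (L2 inner product),
        # then renormalize each row under the L1 norm used throughout.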
|
|
|
|
|
|
|
|
for i in range(5): |
|
|
for j in range(i): |
|
|
Q[i] -= np.dot(Q[i], Q[j]) * Q[j] |
|
|
Q[i] /= (np.sum(np.abs(Q[i])) + 1e-8) |
|
|
|
|
|
gamma = np.array([1.0, 0.9, -0.8, 1.1, 1.2], np.float32) |
|
|
X = np.zeros((5, d), np.float32) |
|
|
for i in range(5): |
|
|
X[i] = center_vec + gamma[i] * Q[i] |
|
|
return X - X.mean(0, keepdims=True) |
|
|
|
|
|
|
|
|
def _finalize_crystal(self, X: np.ndarray) -> np.ndarray: |
|
|
X = np.asarray(X, np.float32, order='C') |
|
|
if X.shape != (5, self.dim): |
|
|
raise ValueError(f"Crystal must be shape (5, {self.dim}); got {X.shape}.") |
|
|
return X - X.mean(0, keepdims=True) |
|
|
|
|
|
def _auto_provenance_from_cfg(self, cfg: Dict[str, Any]) -> dict: |
|
|
token = cfg["data"]["token"] |
|
|
prov = { |
|
|
"source": "special/compose", |
|
|
"token": token, |
|
|
"pool_type": cfg.get("pool_type") or "unicode", |
|
|
"components": list(token), |
|
|
"additional_definitions": list(cfg.get("additional_definitions", [])), |
|
|
} |
|
|
if cfg.get("antonyms"): |
|
|
prov["antonyms"] = list(cfg["antonyms"]) |
|
|
if cfg.get("inversion_formula") is not None: |
|
|
prov["inversion_formula"] = "user_supplied" |
|
|
return prov |
|
|
|
|
|
def _finalize_crystal_and_provenance( |
|
|
self, product: Union[np.ndarray, Dict[str, Any]], cfg: Dict[str, Any] |
|
|
) -> Tuple[np.ndarray, dict]: |
|
|
|
|
|
if isinstance(product, np.ndarray): |
|
|
X = self._finalize_crystal(product) |
|
|
prov = self._auto_provenance_from_cfg(cfg) |
|
|
return X, prov |
|
|
|
|
|
|
|
|
if not isinstance(product, dict): |
|
|
raise TypeError( |
|
|
"create_crystal must return ndarray or dict with {'base':..., 'ops':..., 'provenance':...}.") |
|
|
base = np.asarray(product["base"], np.float32) |
|
|
X = base |
|
|
for op in product.get("ops", []): |
|
|
name = op.get("name") |
|
|
if name == "center": |
|
|
X -= X.mean(0, keepdims=True) |
|
|
elif name == "scale": |
|
|
X *= float(op.get("k", 1.0)) |
|
|
elif name == "translate": |
|
|
t = np.asarray(op.get("t"), np.float32) |
|
|
if t.shape != (self.dim,): |
|
|
raise ValueError(f"translate.t must be shape ({self.dim},)") |
|
|
X = X + t[None, :] |
|
|
elif name == "normalize_rows": |
|
|
n = np.sum(np.abs(X), axis=1, keepdims=True) + 1e-8 |
|
|
X = X / n |
|
|
elif name == "align_to": |
|
|
v = np.asarray(op.get("v"), np.float32) |
|
|
if v.shape != (self.dim,): |
|
|
raise ValueError(f"align_to.v must be shape ({self.dim},)") |
|
|
v = v / (np.abs(v).sum() + 1e-8) |
|
|
p = X.mean(0) |
|
|
p = p / (np.abs(p).sum() + 1e-8) |
|
|
alpha = float(op.get("alpha", 1.0)) |
|
|
X = X + alpha * (v - p)[None, :] |
|
|
else: |
|
|
raise ValueError(f"Unsupported op '{name}'") |
|
|
prov = dict(product.get("provenance", {})) or self._auto_provenance_from_cfg(cfg) |
|
|
return self._finalize_crystal(X), prov |
|
|
|
|
|
|
|
|
def _manifest_special_tokens( |
|
|
self, |
|
|
base_set: Dict[str, int], |
|
|
create_crystal: Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]], |
|
|
callback: Optional[Callable[..., np.ndarray]], |
|
|
create_config: Dict[str, Any], |
|
|
) -> None: |
|
|
"""Universal, deterministic manifestor with character pre-caching.""" |
|
|
|
|
|
|
|
|
unique_chars = set() |
|
|
for name in base_set.keys(): |
|
|
token_plain = name.strip("<>").strip() |
|
|
unique_chars.update(token_plain) |
|
|
|
|
|
print(f"[⚡] Pre-caching {len(unique_chars)} unique characters...") |
|
|
for ch in unique_chars: |
|
|
_ = self.pooled(ch) |
|
|
|
|
|
helpers = self._helpers() |
|
|
|
|
|
for name, tid in base_set.items(): |
|
|
|
|
|
if tid in self._id_to_vec: |
|
|
self._token_to_id[name] = tid |
|
|
self._id_to_token.setdefault(tid, name) |
|
|
self._valid_token_ids.add(tid) |
|
|
continue |
|
|
|
|
|
|
|
|
cfg = { |
|
|
"dim": self.dim, |
|
|
"pool_type": create_config.get("pool_type", None), |
|
|
"special_tokens": create_config.get("special_tokens"), |
|
|
"additional_definitions": create_config.get("additional_definitions", []), |
|
|
"antonyms": create_config.get("antonyms"), |
|
|
"inversion_formula": create_config.get("inversion_formula"), |
|
|
"data": {"token": name.strip("<>").strip(), "token_id": tid, "origin": "special"}, |
|
|
"helpers": helpers, |
|
|
} |
|
|
|
|
|
if create_crystal is None: |
|
|
create_crystal = self._default_create_crystal |
|
|
|
|
|
            cb = callback if callback is not None else self._default_unicode_callback
            product = create_crystal(cfg, cb)
|
|
X, prov = self._finalize_crystal_and_provenance(product, cfg) |
|
|
|
|
|
|
|
|
self._token_to_id[name] = tid |
|
|
self._id_to_token[tid] = name |
|
|
self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C') |
|
|
self._id_to_provenance[tid] = prov |
|
|
self._valid_token_ids.add(tid) |
|
|
self._id_to_volume.setdefault(tid, 1.0) |
|
|
|
|
|
|
|
|
for alias in (cfg.get("special_tokens") or []): |
|
|
alias = str(alias) |
|
|
self._token_to_id[alias] = tid |
|
|
self._id_to_token.setdefault(tid, alias) |
|
|
if cfg.get("special_tokens"): |
|
|
self._id_to_provenance[tid].setdefault("aliases", list(cfg["special_tokens"])) |
|
|
|
|
|
|
|
|
antonyms = cfg.get("antonyms") or [] |
|
|
invf = cfg.get("inversion_formula") |
|
|
if invf: |
|
|
for anti in antonyms: |
|
|
if anti in base_set: |
|
|
anti_id = base_set[anti] |
|
|
if anti_id not in self._id_to_vec: |
|
|
X_inv = invf(X, cfg) |
|
|
X_inv = self._finalize_crystal(X_inv) |
|
|
self._token_to_id[anti] = anti_id |
|
|
self._id_to_token[anti_id] = anti |
|
|
self._id_to_vec[anti_id] = X_inv.astype(np.float32, copy=False, order='C') |
|
|
inv_prov = { |
|
|
"source": "inversion", |
|
|
"of_token": name, |
|
|
"of_token_id": tid, |
|
|
"pool_type": cfg.get("pool_type") or "unicode", |
|
|
"components": prov.get("components", []), |
|
|
"additional_definitions": cfg.get("additional_definitions", []), |
|
|
"ops": ["invert"], |
|
|
} |
|
|
self._id_to_provenance[anti_id] = inv_prov |
|
|
self._valid_token_ids.add(anti_id) |
|
|
self._id_to_volume.setdefault(anti_id, 1.0) |
|
|
|
|
|
|
|
|
self._invalidate_caches() |
|
|
|
|
|
if self._char_lookups_saved > 0: |
|
|
print(f"[✅] Character cache saved {self._char_lookups_saved} lookups") |
|
|
|
|
|
|
|
|
def vocab_size(self) -> int: |
|
|
return len(self._token_to_id) |
|
|
|
|
|
def token_to_id(self, token: str) -> Optional[int]: |
|
|
return self._token_to_id.get(token) |
|
|
|
|
|
def id_to_token(self, token_id: int) -> Optional[str]: |
|
|
return self._id_to_token.get(token_id) |
|
|
|
|
|
def cache_stats(self) -> Dict[str, int]: |
|
|
"""Get cache statistics.""" |
|
|
return { |
|
|
"normalized_cache_size": len(self._normalized_cache), |
|
|
"pooled_cache_size": len(self._pooled_cache), |
|
|
"char_cache_size": len(self._char_cache), |
|
|
"char_lookups_saved": self._char_lookups_saved, |
|
|
"spatial_index_size": len(self._spatial_index.token_ids) if self._spatial_index else 0, |
|
|
"vocab_size": len(self._valid_token_ids) |
|
|
} |
|
|
|
|
|
def clear_caches(self): |
|
|
"""Clear all caches to free memory.""" |
|
|
self._invalidate_caches() |
|
|
self._char_cache.clear() |
|
|
self._char_lookups_saved = 0 |
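

# Illustrative-only sketch: a minimal concrete GeometricVocab subclass plus a
# demo exercising the cached similarity path. _ToyVocab is hypothetical and
# not part of the public API; defining it has no side effects.
class _ToyVocab(GeometricVocab):
    def encode(self, token: str, *, return_id: bool = False):
        tid = self._token_to_id[token]
        X = self._id_to_vec[tid]
        return (X, tid) if return_id else X

    def get_score(self, token_or_id):
        return 1.0


def _demo_toy_vocab() -> None:
    v = _ToyVocab(dim=8)
    rng = np.random.default_rng(1)
    for tid, tok in enumerate(["a", "b", "c"]):
        v._token_to_id[tok] = tid
        v._id_to_token[tid] = tok
        v._id_to_vec[tid] = rng.standard_normal((5, 8)).astype(np.float32)
        v._valid_token_ids.add(tid)
    # similarity() mean-pools each crystal, L1-normalizes, then dots.
    print("sim(a, b) =", v.similarity("a", "b"))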
|
|
|
|
|
|
|
|
|
|
|
|
|
class PretrainedGeometricVocab(GeometricVocab): |
|
|
""" |
|
|
Parquet-backed deterministic vocab with columnar load, duplicate-mean aggregation, |
|
|
pooled caching, and fast path for flat crystals. |
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
repo_id: str, |
|
|
dim: int, |
|
|
*, |
|
|
subset: str = "unicode", |
|
|
split: str = "train_100d", |
|
|
base_set: Optional[Dict[str, int]] = None, |
|
|
create_config: Optional[Dict[str, Any]] = None, |
|
|
create_crystal: Optional[Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]]] = None, |
|
|
callback: Optional[Callable[..., np.ndarray]] = None, |
|
|
manifest_specials: bool = True, |
|
|
|
|
|
store: str = "full", |
|
|
reshape_order: str = "C", |
|
|
vertex_count: int = 5, |
|
|
infer_dim: bool = True, |
|
|
strict_shapes: bool = False, |
|
|
|
|
|
finalize_mode: str = "post_mean", |
|
|
cache_pooled: bool = True, |
|
|
streaming=False, |
|
|
): |
|
|
super().__init__(dim) |
|
|
self.repo_id = str(repo_id) |
|
|
self._id_to_pooled: Dict[int, np.ndarray] = {} |
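
        # NOTE: `subset` and `streaming` are accepted for API symmetry with the
        # lazy loader but are not applied here; the split is fully materialized.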
|
|
|
|
|
|
|
|
ds = load_dataset(self.repo_id, split=split) |
|
|
have = set(ds.column_names) |
|
|
wanted = ["token_id", "token", "crystal", "volume"] |
|
|
keep = [c for c in wanted if c in have] |
|
|
drop = [c for c in ds.column_names if c not in keep] |
|
|
if drop: |
|
|
ds = ds.remove_columns(drop) |
|
|
ds = ds.with_format("numpy", columns=keep) |
|
|
|
|
|
ids = ds["token_id"] if "token_id" in keep else np.array([], dtype=np.int64) |
|
|
toks = ds["token"] if "token" in keep else np.array([], dtype=object) |
|
|
cryst= ds["crystal"] if "crystal" in keep else np.array([], dtype=object) |
|
|
vols = ds["volume"] if "volume" in keep else None |
|
|
|
|
|
ids = np.asarray(ids).astype(np.int64, copy=False) |
|
|
toks = np.asarray(toks) |
|
|
|
|
|
|
|
|
def _coerce(raw: Any) -> np.ndarray: |
|
|
X = np.asarray(raw, np.float32) |
|
|
if X.ndim == 2: |
|
|
V, D = int(X.shape[0]), int(X.shape[1]) |
|
|
if V != vertex_count: |
|
|
raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.") |
|
|
                if D != self.dim:
                    if infer_dim:
                        self.dim = D
                    else:
                        raise ValueError(f"Dim mismatch: got {D}, expected {self.dim}.")
|
|
return X |
|
|
if X.ndim == 1: |
|
|
n = int(X.size) |
|
|
if n == vertex_count * self.dim: |
|
|
return np.reshape(X, (vertex_count, self.dim), order=reshape_order) |
|
|
if infer_dim and n % vertex_count == 0: |
|
|
self.dim = n // vertex_count |
|
|
return np.reshape(X, (vertex_count, self.dim), order=reshape_order) |
|
|
if n == self.dim: |
|
|
c = X / (np.abs(X).sum() + 1e-8) |
|
|
return self._deterministic_pentachoron(c) |
|
|
raise ValueError(f"Unsupported crystal shape {X.shape if isinstance(X, np.ndarray) else type(X)}.") |
|
|
|
|
|
def _finalize_if_needed(X: np.ndarray) -> np.ndarray: |
|
|
if finalize_mode == "none": |
|
|
return np.asarray(X, np.float32, order="C") |
|
|
elif finalize_mode == "post_mean": |
|
|
return self._finalize_crystal(X) |
|
|
else: |
|
|
raise ValueError(f"finalize_mode must be 'none' or 'post_mean', got {finalize_mode!r}") |
|
|
|
|
|
vols_f = np.asarray(vols, dtype=np.float32) if vols is not None else None |
|
|
|
|
|
|
|
|
|
|
|
fastpath_ok = False |
|
|
A = None |
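        # Fast path: if every crystal is a flat float vector of equal length,
        # np.stack yields a clean 2-D block that can be reshaped in one shot.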
|
|
try: |
|
|
A = np.stack(cryst) |
|
|
if A.ndim == 2 and A.dtype != object: |
|
|
A = A.astype(np.float32, copy=False) |
|
|
L = A.shape[1] |
|
|
if L % vertex_count == 0: |
|
|
|
|
|
D = L // vertex_count |
|
|
if self.dim != D: |
|
|
if infer_dim: |
|
|
self.dim = int(D) |
|
|
else: |
|
|
raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.") |
|
|
fastpath_ok = True |
|
|
except Exception: |
|
|
fastpath_ok = False |
|
|
|
|
|
if fastpath_ok and A is not None and len(ids) > 0: |
|
|
|
|
|
V = vertex_count |
|
|
D = self.dim |
|
|
A = A.reshape(-1, V, D, order=reshape_order) |
|
|
|
|
|
|
|
|
order = np.argsort(ids, kind="stable") |
|
|
ids_sorted = ids[order] |
|
|
A_sorted = A[order] |
|
|
vols_sorted = vols_f[order] if vols_f is not None else None |
|
|
|
|
|
uniq_ids, idx, counts = np.unique(ids_sorted, return_index=True, return_counts=True) |
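            # Duplicate-mean aggregation: sorting groups duplicate ids into
            # contiguous runs, reduceat sums each run in one vectorized pass,
            # and dividing by run lengths yields per-token means.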
|
|
sums = np.add.reduceat(A_sorted, idx, axis=0) |
|
|
means = sums / counts[:, None, None] |
|
|
|
|
|
if vols_sorted is not None: |
|
|
v_sums = np.add.reduceat(vols_sorted, idx) |
|
|
v_means = v_sums / counts.astype(np.float32) |
|
|
else: |
|
|
v_means = np.ones_like(uniq_ids, dtype=np.float32) |
|
|
|
|
|
|
|
|
self._token_to_id.clear(); self._id_to_token.clear() |
|
|
self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear() |
|
|
self._id_to_pooled.clear() |
|
|
|
|
|
|
|
|
toks_sorted = toks[order] |
|
|
rep_toks = toks_sorted[idx] |
|
|
|
|
|
for tid, tok, X_mean, v_m in zip(uniq_ids.tolist(), rep_toks.tolist(), means, v_means.tolist()): |
|
|
|
|
|
if cache_pooled: |
|
|
self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False) |
|
|
X_store = _finalize_if_needed(X_mean) |
|
|
|
|
|
self._token_to_id[str(tok)] = tid |
|
|
self._id_to_token[tid] = str(tok) |
|
|
if store in ("full", "both"): |
|
|
self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C") |
|
|
elif store == "pooled": |
|
|
|
|
|
self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled |
|
|
else X_mean.mean(axis=0).astype(np.float32, copy=False)) |
|
|
self._id_to_volume[tid] = float(v_m) |
|
|
self._valid_token_ids.add(tid) |
|
|
|
|
|
else: |
|
|
|
|
|
ids_int = ids.tolist() |
|
|
toks_str = [str(x) for x in toks.tolist()] |
|
|
vols_f = (vols_f.tolist() if vols_f is not None else [1.0] * len(ids_int)) |
|
|
|
|
|
x_sum: Dict[int, np.ndarray] = {} |
|
|
v_sum: Dict[int, float] = {} |
|
|
n_cnt: Dict[int, int] = {} |
|
|
tok_pref: Dict[int, str] = {} |
|
|
|
|
|
for tid, tok, raw, vol in zip(ids_int, toks_str, cryst, vols_f): |
|
|
X = _coerce(raw) |
|
|
if tid not in x_sum: |
|
|
x_sum[tid] = X.astype(np.float32, copy=True) |
|
|
v_sum[tid] = float(vol) |
|
|
n_cnt[tid] = 1 |
|
|
tok_pref[tid] = tok |
|
|
else: |
|
|
x_sum[tid] += X |
|
|
v_sum[tid] += float(vol) |
|
|
n_cnt[tid] += 1 |
|
|
|
|
|
self._token_to_id.clear(); self._id_to_token.clear() |
|
|
self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear() |
|
|
self._id_to_pooled.clear() |
|
|
|
|
|
for tid in x_sum.keys(): |
|
|
X_mean = x_sum[tid] / float(n_cnt[tid]) |
|
|
if cache_pooled: |
|
|
self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False) |
|
|
X_store = _finalize_if_needed(X_mean) |
|
|
|
|
|
tok = tok_pref[tid] |
|
|
vol_m = v_sum[tid] / float(n_cnt[tid]) |
|
|
|
|
|
self._token_to_id[tok] = tid |
|
|
self._id_to_token[tid] = tok |
|
|
if store in ("full", "both"): |
|
|
self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C") |
|
|
elif store == "pooled": |
|
|
self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled |
|
|
else X_mean.mean(axis=0).astype(np.float32, copy=False)) |
|
|
self._id_to_volume[tid] = float(vol_m) |
|
|
self._valid_token_ids.add(tid) |
|
|
|
|
|
|
|
|
if manifest_specials and base_set: |
|
|
self._manifest_special_tokens( |
|
|
base_set=base_set, |
|
|
create_crystal=create_crystal, |
|
|
callback=callback, |
|
|
create_config=create_config or {} |
|
|
) |
|
|
|
|
|
|
|
|
    def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]:
        # The precomputed cache holds mean-pooled centers only; any other
        # pooling method falls through to the base implementation.
        if method == "mean":
            tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
            if tid is not None and tid in self._id_to_pooled:
                return self._id_to_pooled[tid]
        return super().pooled(token_or_id, method=method)
|
|
|
|
|
|
|
|
def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]: |
|
|
tid = self._token_to_id.get(token) |
|
|
if tid is None: |
|
|
unk_id = self._token_to_id.get("<unk>") |
|
|
if unk_id is None: |
|
|
raise KeyError(f"Token '{token}' not found and '<unk>' missing.") |
|
|
X = self._id_to_vec[unk_id] |
|
|
return (X, unk_id) if return_id else X |
|
|
X = self._id_to_vec[tid] |
|
|
return (X, tid) if return_id else X |
|
|
|
|
|
def get_score(self, token_or_id: Union[str, int]) -> float: |
|
|
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id, None) |
|
|
if tid is None or tid not in self._valid_token_ids: |
|
|
return -100.0 |
|
|
vol = self._id_to_volume.get(tid, 1.0) |
|
|
return float(np.clip(vol / 10.0, 0.01, 1.0)) |
|
|
|
|
|
|
|
|
def cache(self, tokens: Union[List[str], Dict[str, int]], device: str = "cpu", dtype: torch.dtype = torch.float32): |
|
|
tok_list = list(tokens.keys()) if isinstance(tokens, dict) else list(tokens) |
|
|
mats, pooled, keep = [], [], [] |
|
|
for t in tok_list: |
|
|
X = self.embedding(t) |
|
|
v = self.pooled(t) |
|
|
if X is None or v is None: |
|
|
continue |
|
|
mats.append(torch.as_tensor(X, dtype=dtype)) |
|
|
pooled.append(torch.as_tensor(v, dtype=dtype)) |
|
|
keep.append(t) |
|
|
if not mats: |
|
|
raise ValueError("No valid tokens found in input.") |
|
|
return { |
|
|
"tokens": keep, |
|
|
"crystals": torch.stack(mats, 0).to(device), |
|
|
"pooled": torch.stack(pooled, 0).to(device), |
|
|
} |
|
|
|
|
|
|
|
|
def _coerce_crystal_shape( |
|
|
self, |
|
|
raw: Any, |
|
|
*, |
|
|
vertex_count: int, |
|
|
reshape_order: str, |
|
|
infer_dim: bool, |
|
|
strict_shapes: bool |
|
|
) -> np.ndarray: |
|
|
""" |
|
|
Accepts raw crystal data and returns [vertex_count, self.dim] float32 C-order. |
|
|
|
|
|
Acceptable inputs: |
|
|
- [vertex_count, D] |
|
|
- [vertex_count * D] (flat) -> reshaped to [vertex_count, D] |
|
|
- [D] (pooled center) -> converted by deterministic pentachoron (fallback) |
|
|
""" |
|
|
X = np.asarray(raw, dtype=np.float32) |
|
|
|
|
|
|
|
|
        if X.ndim == 2:
            V, D = int(X.shape[0]), int(X.shape[1])
            if V != vertex_count:
                # A wrong vertex count is never coerced silently, regardless of
                # strict_shapes: reinterpreting rows would scramble the crystal.
                raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
|
|
|
|
|
if D != self.dim: |
|
|
if infer_dim: |
|
|
self.dim = D |
|
|
else: |
|
|
raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.") |
|
|
|
|
|
return X |
|
|
|
|
|
|
|
|
if X.ndim == 1: |
|
|
n = int(X.size) |
|
|
|
|
|
if n == vertex_count * self.dim: |
|
|
return np.reshape(X, (vertex_count, self.dim), order=reshape_order) |
|
|
|
|
|
|
|
|
if infer_dim and n % vertex_count == 0: |
|
|
inferred = n // vertex_count |
|
|
self.dim = int(inferred) |
|
|
return np.reshape(X, (vertex_count, self.dim), order=reshape_order) |
|
|
|
|
|
|
|
|
if n == self.dim: |
|
|
c = X / (np.abs(X).sum() + 1e-8) |
|
|
return self._deterministic_pentachoron(c) |
|
|
|
|
|
if strict_shapes: |
|
|
raise ValueError( |
|
|
f"Cannot coerce crystal of length {n}. " |
|
|
f"Expected {vertex_count*self.dim} (flat) or {self.dim} (pooled)." |
|
|
) |
|
|
|
|
|
if infer_dim and n > 0: |
|
|
self.dim = n |
|
|
c = X / (np.abs(X).sum() + 1e-8) |
|
|
return self._deterministic_pentachoron(c) |
|
|
|
|
|
raise ValueError(f"Unsupported crystal shape {X.shape} (ndim={X.ndim}).") |
|
|
|
|
|
|
|
|
|
|
|
def describe(self) -> Dict[str, Union[str, int]]: |
|
|
return {"repo": self.repo_id, "dimension": self.dim, "vocab_size": self.vocab_size()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SILENT_MODE = False |
|
|
|
|
|
def set_silent_mode(silent: bool): |
|
|
"""Set global silent mode for token synthesis warnings""" |
|
|
global SILENT_MODE |
|
|
SILENT_MODE = silent |
|
|
|
|
|
class LRUCache(OrderedDict): |
|
|
"""Simple LRU cache implementation""" |
|
|
def __init__(self, maxsize=128): |
|
|
super().__init__() |
|
|
self.maxsize = maxsize |
|
|
|
|
|
def __getitem__(self, key): |
|
|
value = super().__getitem__(key) |
|
|
self.move_to_end(key) |
|
|
return value |
|
|
|
|
|
def __setitem__(self, key, value): |
|
|
if key in self: |
|
|
self.move_to_end(key) |
|
|
super().__setitem__(key, value) |
|
|
if len(self) > self.maxsize: |
|
|
oldest = next(iter(self)) |
|
|
del self[oldest] |
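

# Quick sanity sketch for LRUCache (illustrative): reads refresh recency, and
# the stalest key is evicted once maxsize is exceeded.
def _demo_lru_cache() -> None:
    c = LRUCache(maxsize=2)
    c["a"] = 1
    c["b"] = 2
    _ = c["a"]      # touch "a"; "b" is now the eviction candidate
    c["c"] = 3      # evicts "b"
    print(list(c.keys()))  # ['a', 'c']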
|
|
|
|
|
|
|
|
class LazyGeometricVocab(GeometricVocab): |
|
|
""" |
|
|
Lazy-loading geometric vocabulary that loads tokens on demand. |
|
|
Maintains a small working set in memory with LRU eviction. |
|
|
Supports automatic token synthesis for missing tokens. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
repo_id: str, |
|
|
dim: int, |
|
|
*, |
|
|
name: str = "unicode_100d", |
|
|
split: str = "train", |
|
|
stream: bool = True, |
|
|
base_set: Optional[Dict[str, int]] = None, |
|
|
create_config: Optional[Dict[str, Any]] = None, |
|
|
create_crystal: Optional[Callable] = None, |
|
|
callback: Optional[Callable] = None, |
|
|
manifest_specials: bool = True, |
|
|
|
|
|
cache_size: int = 1000, |
|
|
preload_tokens: Optional[List[str]] = None, |
|
|
index_cache_path: Optional[str] = None, |
|
|
|
|
|
tokenizer: Optional[Callable[[str], List[str]]] = None, |
|
|
|
|
|
silent: bool = False, |
|
|
|
|
|
store: str = "full", |
|
|
reshape_order: str = "C", |
|
|
vertex_count: int = 5, |
|
|
infer_dim: bool = True, |
|
|
finalize_mode: str = "post_mean", |
|
|
cache_pooled: bool = True, |
|
|
): |
|
|
super().__init__(dim) |
|
|
|
|
|
self.repo_id = repo_id |
|
|
self.name = name |
|
|
self.split = split |
|
|
self.stream = stream |
|
|
self.vertex_count = vertex_count |
|
|
self.reshape_order = reshape_order |
|
|
self.infer_dim = infer_dim |
|
|
self.finalize_mode = finalize_mode |
|
|
self.store = store |
|
|
self.cache_pooled = cache_pooled |
|
|
self.silent = silent |
|
|
|
|
|
|
|
|
if not hasattr(self, '_id_to_pooled'): |
|
|
self._id_to_pooled = {} |
|
|
|
|
|
|
|
|
self.create_crystal_fn = create_crystal |
|
|
self.callback_fn = callback |
|
|
self.create_config = create_config or {} |
|
|
self._synthesized_tokens: set = set() |
|
|
self._next_synthetic_id = -1 |
|
|
|
|
|
|
|
|
self.tokenizer = tokenizer or (lambda s: s.split()) |
|
|
|
|
|
|
|
|
self._crystal_cache = LRUCache(maxsize=cache_size) |
|
|
self._pooled_lru = LRUCache(maxsize=cache_size * 2) |
|
|
|
|
|
|
|
|
self._dataset = None |
|
|
self._dataset_stream = None |
|
|
self._token_index: Dict[str, List[int]] = {} |
|
|
self._id_index: Dict[int, List[int]] = {} |
|
|
self._row_data: Dict[int, dict] = {} |
|
|
|
|
|
|
|
|
self._build_index(split, name) |
|
|
|
|
|
|
|
|
self._preload_synthesis_base() |
|
|
|
|
|
|
|
|
if preload_tokens: |
|
|
self._preload(preload_tokens) |
|
|
|
|
|
|
|
|
if manifest_specials and base_set: |
|
|
self._manifest_special_tokens( |
|
|
base_set=base_set, |
|
|
create_crystal=create_crystal, |
|
|
callback=callback, |
|
|
create_config=create_config or {} |
|
|
) |
|
|
|
|
|
def _preload_synthesis_base(self): |
|
|
"""Pre-load basic ASCII characters needed for synthesis""" |
|
|
|
|
|
base_chars = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-_()[]{}:;'\"") |
|
|
|
|
|
print(f"Pre-loading {len(base_chars)} base characters for synthesis...") |
|
|
loaded = 0 |
|
|
        for char in base_chars:
            tid = self._token_to_id.get(char)
            if tid is not None:  # token id 0 is valid, so compare against None
                if self._load_crystal(tid) is not None:
                    loaded += 1
|
|
print(f"Loaded {loaded} base characters") |
|
|
|
|
|
def _build_index(self, split: str, name: str): |
|
|
"""Build token/id index without loading crystal data""" |
|
|
print(f"Building index for {self.repo_id}/{name}/{split}...") |
|
|
|
|
|
if self.stream: |
|
|
try: |
|
|
|
|
|
|
|
|
self._dataset_stream = load_dataset( |
|
|
self.repo_id, |
|
|
name=name, |
|
|
split=split, |
|
|
streaming=True |
|
|
) |
|
|
|
|
|
|
|
|
for idx, row in enumerate(self._dataset_stream): |
|
|
token = str(row["token"]) |
|
|
token_id = int(row["token_id"]) |
|
|
|
|
|
|
|
|
if token not in self._token_index: |
|
|
self._token_index[token] = [] |
|
|
self._token_index[token].append(idx) |
|
|
|
|
|
|
|
|
if token_id not in self._id_index: |
|
|
self._id_index[token_id] = [] |
|
|
self._id_index[token_id].append(idx) |
|
|
|
|
|
|
|
|
if token not in self._token_to_id: |
|
|
self._token_to_id[token] = token_id |
|
|
self._id_to_token[token_id] = token |
|
|
self._valid_token_ids.add(token_id) |
|
|
|
|
|
print(f"Index built (streaming): {len(self._token_index)} unique tokens") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Streaming failed: {e}") |
|
|
print("Falling back to non-streaming mode...") |
|
|
self.stream = False |
|
|
|
|
|
self._build_index(split, name) |
|
|
|
|
|
else: |
|
|
|
|
|
try: |
|
|
|
|
|
data_files = f"data/{name}/{split}-*.parquet" |
|
|
ds = load_dataset( |
|
|
self.repo_id, |
|
|
data_files=data_files, |
|
|
split="train" |
|
|
) |
|
|
            except Exception:
|
|
|
|
|
try: |
|
|
ds = load_dataset( |
|
|
self.repo_id, |
|
|
name=name, |
|
|
split=split |
|
|
) |
|
|
except Exception as e: |
|
|
print(f"Failed to load dataset: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
for idx, row in enumerate(ds): |
|
|
token = str(row["token"]) |
|
|
token_id = int(row["token_id"]) |
|
|
|
|
|
|
|
|
if token not in self._token_index: |
|
|
self._token_index[token] = [] |
|
|
self._token_index[token].append(idx) |
|
|
|
|
|
|
|
|
if token_id not in self._id_index: |
|
|
self._id_index[token_id] = [] |
|
|
self._id_index[token_id].append(idx) |
|
|
|
|
|
|
|
|
if token not in self._token_to_id: |
|
|
self._token_to_id[token] = token_id |
|
|
self._id_to_token[token_id] = token |
|
|
self._valid_token_ids.add(token_id) |
|
|
|
|
|
|
|
|
self._dataset = ds |
|
|
print(f"Index built: {len(self._token_index)} unique tokens") |
|
|
|
|
|
def _load_row(self, row_idx: int) -> dict: |
|
|
"""Load a single row from dataset""" |
|
|
if row_idx in self._row_data: |
|
|
return self._row_data[row_idx] |
|
|
|
|
|
|
|
|
if self.stream and self._dataset is None: |
|
|
print(f"Loading full dataset for {self.repo_id}/{self.name}/{self.split}...") |
|
|
try: |
|
|
|
|
|
data_files = f"data/{self.name}/{self.split}-*.parquet" |
|
|
self._dataset = load_dataset( |
|
|
self.repo_id, |
|
|
data_files=data_files, |
|
|
split="train" |
|
|
) |
|
|
            except Exception:
|
|
|
|
|
self._dataset = load_dataset( |
|
|
self.repo_id, |
|
|
name=self.name, |
|
|
split=self.split |
|
|
) |
|
|
|
|
|
if self._dataset is None: |
|
|
raise RuntimeError("Dataset not initialized") |
|
|
|
|
|
row = self._dataset[row_idx] |
|
|
self._row_data[row_idx] = row |
|
|
return row |
|
|
|
|
|
def _load_crystal(self, token_id: int) -> Optional[np.ndarray]: |
|
|
"""Load and aggregate crystal for a token_id""" |
|
|
if token_id in self._crystal_cache: |
|
|
return self._crystal_cache[token_id] |
|
|
|
|
|
if token_id not in self._id_index: |
|
|
return None |
|
|
|
|
|
row_indices = self._id_index[token_id] |
|
|
crystals = [] |
|
|
volumes = [] |
|
|
|
|
|
for idx in row_indices: |
|
|
row = self._load_row(idx) |
|
|
|
|
|
|
|
|
raw_crystal = row.get("crystal") |
|
|
if raw_crystal is not None: |
|
|
X = self._coerce_crystal(raw_crystal) |
|
|
crystals.append(X) |
|
|
|
|
|
|
|
|
vol = row.get("volume", 1.0) |
|
|
volumes.append(float(vol)) |
|
|
|
|
|
if not crystals: |
|
|
return None |
|
|
|
|
|
|
|
|
if len(crystals) == 1: |
|
|
X_final = crystals[0] |
|
|
vol_final = volumes[0] |
|
|
else: |
|
|
X_final = np.mean(crystals, axis=0) |
|
|
vol_final = np.mean(volumes) |
|
|
|
|
|
|
|
|
if self.finalize_mode == "post_mean": |
|
|
X_final = self._finalize_crystal(X_final) |
|
|
|
|
|
|
|
|
if self.store in ("full", "both"): |
|
|
self._crystal_cache[token_id] = X_final |
|
|
self._id_to_vec[token_id] = X_final |
|
|
|
|
|
|
|
|
if self.cache_pooled: |
|
|
pooled = X_final.mean(axis=0) |
|
|
self._pooled_lru[token_id] = pooled |
|
|
if token_id not in self._id_to_pooled: |
|
|
self._id_to_pooled[token_id] = pooled |
|
|
|
|
|
|
|
|
self._id_to_volume[token_id] = vol_final |
|
|
|
|
|
return X_final |
|
|
|
|
|
def _coerce_crystal(self, raw: Any) -> np.ndarray: |
|
|
"""Convert raw crystal data to proper shape""" |
|
|
X = np.asarray(raw, dtype=np.float32) |
|
|
|
|
|
if X.ndim == 2: |
|
|
V, D = X.shape |
|
|
if V != self.vertex_count: |
|
|
raise ValueError(f"Expected {self.vertex_count} vertices, got {V}") |
|
|
if D != self.dim: |
|
|
if self.infer_dim: |
|
|
self.dim = D |
|
|
else: |
|
|
raise ValueError(f"Dimension mismatch: {D} vs {self.dim}") |
|
|
return X |
|
|
|
|
|
if X.ndim == 1: |
|
|
n = X.size |
|
|
if n == self.vertex_count * self.dim: |
|
|
return X.reshape((self.vertex_count, self.dim), order=self.reshape_order) |
|
|
if self.infer_dim and n % self.vertex_count == 0: |
|
|
self.dim = n // self.vertex_count |
|
|
return X.reshape((self.vertex_count, self.dim), order=self.reshape_order) |
|
|
if n == self.dim: |
|
|
|
|
|
c = X / (np.abs(X).sum() + 1e-8) |
|
|
return self._deterministic_pentachoron(c) |
|
|
|
|
|
raise ValueError(f"Cannot coerce crystal shape {X.shape}") |
|
|
|
|
|
def _synthesize_token(self, token: str) -> int: |
|
|
"""Synthesize a new token embedding on-the-fly with fallback for missing chars.""" |
|
|
|
|
|
tid = self._next_synthetic_id |
|
|
self._next_synthetic_id -= 1 |
|
|
|
|
|
|
|
|
if not self.silent and not SILENT_MODE: |
|
|
warnings.warn( |
|
|
f"Token '{token}' synthesized - ensure you synthesize your tokens ahead of time.", |
|
|
UserWarning, |
|
|
stacklevel=3 |
|
|
) |
|
|
|
|
|
|
|
|
self._synthesized_tokens.add(token) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
            # Ensure every character of the token has a usable vector: load
            # known characters from the dataset, synthesize unknown ones.
            for char in token:
                if char in self._char_cache:
                    continue
                char_tid = self._token_to_id.get(char)
                if char_tid is not None:
                    self._load_crystal(char_tid)
                else:
                    self._synthesize_simple_char(char)
|
|
|
|
|
|
|
|
helpers = self._helpers() |
|
|
cfg = { |
|
|
"dim": self.dim, |
|
|
"pool_type": self.create_config.get("pool_type", "unicode"), |
|
|
"data": {"token": token, "token_id": tid, "origin": "synthetic"}, |
|
|
"helpers": helpers, |
|
|
} |
|
|
|
|
|
if self.create_crystal_fn is not None: |
|
|
product = self.create_crystal_fn(cfg, self.callback_fn) |
|
|
else: |
|
|
product = self._default_create_crystal(cfg, self._default_unicode_callback) |
|
|
|
|
|
X, prov = self._finalize_crystal_and_provenance(product, cfg) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"Character-based synthesis failed for '{token}': {e}. Using random synthesis.") |
|
|
X = self._synthesize_random_crystal(token) |
|
|
prov = {"source": "synthetic_random", "token": token} |
|
|
|
|
|
prov["synthetic"] = True |
|
|
|
|
|
|
|
|
self._token_to_id[token] = tid |
|
|
self._id_to_token[tid] = token |
|
|
self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C') |
|
|
self._id_to_provenance[tid] = prov |
|
|
self._valid_token_ids.add(tid) |
|
|
self._id_to_volume[tid] = 1.0 |
|
|
|
|
|
|
|
|
self._crystal_cache[tid] = X |
|
|
if self.cache_pooled: |
|
|
pooled = X.mean(axis=0) |
|
|
self._pooled_lru[tid] = pooled |
|
|
self._id_to_pooled[tid] = pooled |
|
|
|
|
|
return tid |
|
|
|
|
|
def _synthesize_simple_char(self, char: str): |
|
|
"""Create a simple deterministic embedding for a single character""" |
|
|
import hashlib |
|
|
|
|
|
|
|
|
if len(char) == 1: |
|
|
seed = ord(char) |
|
|
else: |
|
|
seed = int(hashlib.md5(char.encode()).hexdigest()[:8], 16) |
|
|
|
|
|
        # Use a local generator: deterministic per character without mutating
        # the global NumPy RNG state.
        rng = np.random.default_rng(seed)

        # Random direction, L1-normalized to match other pooled vectors.
        vec = rng.standard_normal(self.dim).astype(np.float32)
        vec = vec / (np.abs(vec).sum() + 1e-8)
|
|
|
|
|
|
|
|
self._char_cache[char] = vec |
|
|
|
|
|
def _synthesize_random_crystal(self, token: str) -> np.ndarray: |
|
|
"""Fallback: create a deterministic random crystal based on token string""" |
|
|
import hashlib |
|
|
|
|
|
|
|
|
        seed = int(hashlib.md5(token.encode()).hexdigest()[:8], 16)
        rng = np.random.default_rng(seed)  # deterministic; global RNG untouched

        X = rng.standard_normal((self.vertex_count, self.dim)).astype(np.float32)
        X = self._finalize_crystal(X)
|
|
|
|
|
return X |
|
|
|
|
|
def _preload(self, tokens: List[str]): |
|
|
"""Preload specific tokens into cache""" |
|
|
print(f"Preloading {len(tokens)} tokens...") |
|
|
        for token in tokens:
            tid = self._token_to_id.get(token)
            if tid is not None:  # token id 0 is valid
                self._load_crystal(tid)
|
|
|
|
|
|
|
|
|
|
|
def embedding(self, token_or_id: Union[str, int], generate: bool = False) -> Optional[np.ndarray]: |
|
|
"""Get embedding, loading if necessary, synthesizing if requested""" |
|
|
|
|
|
if isinstance(token_or_id, int): |
|
|
tid = token_or_id |
|
|
token = self._id_to_token.get(tid) |
|
|
else: |
|
|
token = token_or_id |
|
|
tid = self._token_to_id.get(token) |
|
|
|
|
|
if tid is not None: |
|
|
|
|
|
if tid in self._id_to_vec: |
|
|
return self._id_to_vec[tid] |
|
|
|
|
|
return self._load_crystal(tid) |
|
|
|
|
|
|
|
|
if generate and token is not None: |
|
|
tid = self._synthesize_token(token) |
|
|
return self._id_to_vec[tid] |
|
|
|
|
|
return None |
|
|
|
|
|
def pooled(self, token_or_id: Union[str, int], method: str = "mean", generate: bool = False) -> Optional[np.ndarray]: |
|
|
"""Get pooled vector, loading if necessary, synthesizing if requested""" |
|
|
|
|
|
if isinstance(token_or_id, int): |
|
|
tid = token_or_id |
|
|
token = self._id_to_token.get(tid) |
|
|
else: |
|
|
token = token_or_id |
|
|
tid = self._token_to_id.get(token) |
|
|
|
|
|
        if tid is not None:
            # Both caches hold mean-pooled vectors only; other pooling methods
            # must recompute from the full crystal below.
            if method == "mean":
                if tid in self._pooled_lru:
                    return self._pooled_lru[tid]
                if tid in self._id_to_pooled:
                    return self._id_to_pooled[tid]
|
|
|
|
|
|
|
|
X = self.embedding(tid, generate=False) |
|
|
if X is not None: |
|
|
if method == "mean": |
|
|
pooled = X.mean(axis=0) |
|
|
self._pooled_lru[tid] = pooled |
|
|
return pooled |
|
|
elif method == "first": |
|
|
return X[0] |
|
|
elif method == "sum": |
|
|
return X.sum(axis=0) |
|
|
else: |
|
|
raise ValueError(f"Unknown pooling method: {method}") |
|
|
|
|
|
|
|
|
if generate and token is not None: |
|
|
tid = self._synthesize_token(token) |
|
|
return self.pooled(tid, method=method, generate=False) |
|
|
|
|
|
return None |
|
|
|
|
|
def encode(self, token: str, *, return_id: bool = False, generate: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]: |
|
|
"""Encode token, loading if necessary, synthesizing if requested""" |
|
|
tid = self._token_to_id.get(token) |
|
|
|
|
|
if tid is None: |
|
|
if generate: |
|
|
|
|
|
tid = self._synthesize_token(token) |
|
|
X = self._id_to_vec[tid] |
|
|
else: |
|
|
|
|
|
unk_id = self._token_to_id.get("<unk>") |
|
|
if unk_id is None: |
|
|
|
|
|
if generate: |
|
|
tid = self._synthesize_token(token) |
|
|
X = self._id_to_vec[tid] |
|
|
else: |
|
|
raise KeyError(f"Token '{token}' not found and no <unk> token available") |
|
|
else: |
|
|
X = self.embedding(unk_id, generate=False) |
|
|
tid = unk_id |
|
|
else: |
|
|
X = self.embedding(tid, generate=False) |
|
|
if X is None: |
|
|
raise RuntimeError(f"Failed to load embedding for token '{token}'") |
|
|
|
|
|
return (X, tid) if return_id else X |
|
|
|
|
|
def get_score(self, token_or_id: Union[str, int]) -> float: |
|
|
"""Get token score""" |
|
|
tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id) |
|
|
if tid is None or tid not in self._valid_token_ids: |
|
|
return -100.0 |
|
|
|
|
|
|
|
|
if tid not in self._id_to_volume: |
|
|
self._load_crystal(tid) |
|
|
|
|
|
vol = self._id_to_volume.get(tid, 1.0) |
|
|
return float(np.clip(vol / 10.0, 0.01, 1.0)) |
|
|
|
|
|
def encode_batch(self, tokens: Union[str, List[str]], |
|
|
*, return_ids: bool = False, |
|
|
prefetch: bool = True, |
|
|
generate: bool = False) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[int]]]: |
|
|
""" |
|
|
Encode a batch of tokens efficiently. |
|
|
|
|
|
Args: |
|
|
tokens: Either a string (will be tokenized) or list of token strings |
|
|
return_ids: Whether to return token IDs alongside embeddings |
|
|
prefetch: Whether to prefetch all tokens before encoding |
|
|
generate: If True, synthesize missing tokens |
|
|
|
|
|
Returns: |
|
|
List of embeddings, optionally with list of token IDs |
|
|
""" |
|
|
|
|
|
if isinstance(tokens, str): |
|
|
tokens = self.tokenizer(tokens) |
|
|
|
|
|
if not isinstance(tokens, list): |
|
|
raise TypeError(f"Expected str or List[str], got {type(tokens)}") |
|
|
|
|
|
|
|
|
tokens_to_synthesize = [] |
|
|
if generate: |
|
|
for token in tokens: |
|
|
if token not in self._token_to_id: |
|
|
tokens_to_synthesize.append(token) |
|
|
|
|
|
|
|
|
if tokens_to_synthesize and not self.silent and not SILENT_MODE: |
|
|
warnings.warn( |
|
|
f"{len(tokens_to_synthesize)} tokens synthesized - ensure you synthesize your tokens ahead of time. " |
|
|
f"Synthesized: {tokens_to_synthesize[:5]}{'...' if len(tokens_to_synthesize) > 5 else ''}", |
|
|
UserWarning, |
|
|
stacklevel=2 |
|
|
) |
|
|
|
|
|
|
|
|
if prefetch: |
|
|
self._prefetch_batch([t for t in tokens if t in self._token_to_id]) |
|
|
|
|
|
|
|
|
embeddings = [] |
|
|
ids = [] |
|
|
|
|
|
for token in tokens: |
|
|
if return_ids: |
|
|
emb, tid = self.encode(token, return_id=True, generate=generate) |
|
|
embeddings.append(emb) |
|
|
ids.append(tid) |
|
|
else: |
|
|
emb = self.encode(token, return_id=False, generate=generate) |
|
|
embeddings.append(emb) |
|
|
|
|
|
return (embeddings, ids) if return_ids else embeddings |
|
|
|
|
|
def _prefetch_batch(self, tokens: List[str]): |
|
|
""" |
|
|
Prefetch a batch of tokens efficiently. |
|
|
""" |
|
|
|
|
|
tokens_to_load = [] |
|
|
for token in tokens: |
|
|
            tid = self._token_to_id.get(token)
            if tid is not None and tid not in self._crystal_cache and tid not in self._id_to_vec:
                tokens_to_load.append(tid)
|
|
|
|
|
if not tokens_to_load: |
|
|
return |
|
|
|
|
|
|
|
|
for tid in tokens_to_load: |
|
|
self._load_crystal(tid) |
|
|
|
|
|
def cache_stats(self) -> Dict[str, Any]: |
|
|
"""Get cache statistics""" |
|
|
stats = super().cache_stats() |
|
|
stats.update({ |
|
|
"crystal_cache_size": len(self._crystal_cache), |
|
|
"pooled_lru_size": len(self._pooled_lru), |
|
|
"rows_cached": len(self._row_data), |
|
|
"tokens_indexed": len(self._token_index), |
|
|
"ids_indexed": len(self._id_index), |
|
|
"synthesized_tokens": len(self._synthesized_tokens), |
|
|
}) |
|
|
return stats |
|
|
|
|
|
def evict_from_cache(self, tokens: Optional[List[str]] = None): |
|
|
"""Manually evict tokens from cache to free memory""" |
|
|
if tokens is None: |
|
|
|
|
|
self._crystal_cache.clear() |
|
|
self._pooled_lru.clear() |
|
|
self._id_to_vec.clear() |
|
|
self._id_to_pooled.clear() |
|
|
self._row_data.clear() |
|
|
else: |
|
|
|
|
|
for token in tokens: |
|
|
tid = self._token_to_id.get(token) |
|
|
                if tid is not None:
|
|
self._crystal_cache.pop(tid, None) |
|
|
self._pooled_lru.pop(tid, None) |
|
|
self._id_to_vec.pop(tid, None) |
|
|
self._id_to_pooled.pop(tid, None) |
|
|
|
|
|
def get_synthesized_tokens(self) -> List[str]: |
|
|
"""Get list of all tokens that were synthesized at runtime""" |
|
|
return list(self._synthesized_tokens) |
|
|
|
|
|
def is_synthesized(self, token: str) -> bool: |
|
|
"""Check if a token was synthesized at runtime""" |
|
|
return token in self._synthesized_tokens |
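

# Sketch of batch encoding with on-the-fly synthesis (illustrative; assumes an
# already-constructed LazyGeometricVocab). Missing tokens receive negative
# synthetic ids and are tracked via get_synthesized_tokens().
def _demo_encode_batch(vocab: "LazyGeometricVocab") -> None:
    embs, ids = vocab.encode_batch("hello world", return_ids=True, generate=True)
    print([e.shape for e in embs], ids)
    print("synthesized:", vocab.get_synthesized_tokens())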
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: construct a lazy vocab against the published dataset.
    # Guarded so importing this module does not trigger a dataset download.
    vocab = LazyGeometricVocab(
        repo_id="AbstractPhil/geometric-vocab",
        dim=64,
        name="unicode_64d",
        split="train",
        stream=False,
        cache_size=1024,
        silent=False,
    )
    FROZEN_VOCAB = []