AbstractPhil committed
Commit bca5039 · verified · 1 Parent(s): d0f203c

Create vocab.py

Buggy vocab, not production ready, but it will work.

Files changed (1)
  1. vocab.py +1764 -0
vocab.py ADDED
@@ -0,0 +1,1764 @@
1
+ from __future__ import annotations
2
+ import torch
3
+ import numpy as np
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, Union, Tuple, Optional, Callable, Any, List
6
+ import warnings
7
+ from collections import defaultdict
8
+ import datasets
9
+ from datasets import load_dataset
10
+
11
+ # Optional dependencies for spatial indexing
12
+ try:
13
+ import faiss
14
+ FAISS_AVAILABLE = True
15
+ except ImportError:
16
+ FAISS_AVAILABLE = False
17
+
18
+ try:
19
+ from sklearn.neighbors import NearestNeighbors
20
+ SKLEARN_AVAILABLE = True
21
+ except ImportError:
22
+ SKLEARN_AVAILABLE = False
23
+
24
+
25
+ class SpatialIndex:
26
+ """Spatial indexing for fast similarity search."""
27
+
28
+ def __init__(self, vectors: np.ndarray, token_ids: List[int], method: str = "auto"):
29
+ self.token_ids = np.array(token_ids)
30
+ self.method = method
31
+ self._index = None
32
+
33
+ if method == "auto":
34
+ if FAISS_AVAILABLE and vectors.shape[0] > 1000:
35
+ method = "faiss"
36
+ elif SKLEARN_AVAILABLE:
37
+ method = "sklearn"
38
+ else:
39
+ method = "linear"
40
+
41
+ self._build_index(vectors, method)
42
+
43
+ def _build_index(self, vectors: np.ndarray, method: str):
44
+ if method == "faiss" and FAISS_AVAILABLE:
45
+ # L1 distance approximation using L2 index with normalized vectors
46
+ vectors_l2 = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8)
47
+ self._index = faiss.IndexFlatIP(vectors_l2.shape[1]) # Inner product for normalized vectors
48
+ self._index.add(vectors_l2.astype(np.float32))
49
+ self.method = "faiss"
50
+
51
+ elif method == "sklearn" and SKLEARN_AVAILABLE:
52
+ # Use manhattan distance for true L1
53
+ self._index = NearestNeighbors(
54
+ metric='manhattan',
55
+ algorithm='ball_tree',
56
+ n_jobs=-1
57
+ ).fit(vectors)
+ self._vectors = vectors  # keep the raw vectors; search_radius needs them for exact L1 distances
+ self.method = "sklearn"
59
+ else:
60
+ # Fallback to linear search
61
+ self._vectors = vectors
62
+ self.method = "linear"
63
+
64
+ def search_radius(self, query_vector: np.ndarray, max_distance: float, max_results: int = 1000) -> Tuple[
65
+ List[int], List[float]]:
66
+ """Find all points within max_distance using L1 metric."""
67
+ if self.method == "sklearn":
68
+ indices = self._index.radius_neighbors([query_vector], radius=max_distance)[1][0]
69
+ if len(indices) > max_results:
70
+ # Compute actual distances and take closest
71
+ distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1)
72
+ top_k = np.argsort(distances)[:max_results]
73
+ indices = indices[top_k]
74
+ distances = np.sum(np.abs(self._vectors[indices] - query_vector), axis=1)
75
+ return self.token_ids[indices].tolist(), distances.tolist()
76
+
77
+ elif self.method == "faiss":
78
+ # Approximate search using cosine similarity
79
+ query_l2 = query_vector / (np.linalg.norm(query_vector) + 1e-8)
80
+ similarities, indices = self._index.search(query_l2.reshape(1, -1).astype(np.float32), max_results)
81
+ # Filter by converting similarity threshold to approximate distance
82
+ threshold_sim = 1.0 - max_distance # rough approximation
83
+ mask = similarities[0] >= threshold_sim
84
+ return self.token_ids[indices[0][mask]].tolist(), (1.0 - similarities[0][mask]).tolist()
85
+
86
+ else: # linear
87
+ distances = np.sum(np.abs(self._vectors - query_vector), axis=1)
88
+ mask = distances <= max_distance
89
+ if np.sum(mask) > max_results:
90
+ indices = np.argsort(distances)[:max_results]
91
+ mask = np.zeros_like(distances, dtype=bool)
92
+ mask[indices] = True
93
+ return self.token_ids[mask].tolist(), distances[mask].tolist()
94
+
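+ # Example (sketch, not part of the class API): exercising the linear fallback on a few
+ # random vectors; the shapes and token ids below are illustrative assumptions.
+ #
+ #   vecs = np.random.rand(32, 8).astype(np.float32)
+ #   idx = SpatialIndex(vecs, token_ids=list(range(32)), method="linear")
+ #   ids, dists = idx.search_radius(vecs[0], max_distance=2.5, max_results=5)
+ #   # `ids` and `dists` are plain Python lists of matching token ids and their L1 distances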
95
+
96
+ class GeometricVocab(ABC):
97
+ """
98
+ Optimized geometric vocabulary with spatial indexing and caching.
99
+ """
100
+
101
+ def __init__(self, dim: int):
102
+ self.dim = int(dim)
103
+ self._token_to_id: Dict[str, int] = {}
104
+ self._id_to_token: Dict[int, str] = {}
105
+ self._id_to_vec: Dict[int, np.ndarray] = {}
106
+ self._id_to_volume: Dict[int, float] = {}
107
+ self._id_to_provenance: Dict[int, dict] = {}
108
+ self._valid_token_ids: set[int] = set()
109
+
110
+ # Optimization caches
111
+ self._normalized_cache: Dict[int, np.ndarray] = {}
112
+ self._pooled_cache: Dict[int, np.ndarray] = {}
113
+ self._spatial_index: Optional[SpatialIndex] = None
114
+ self._index_dirty = False
115
+
116
+ # NEW: Character-level cache for Unicode composition
117
+ self._char_cache: Dict[str, np.ndarray] = {}
118
+ self._char_lookups_saved = 0 # Statistics
119
+
120
+ def _invalidate_caches(self):
121
+ """Invalidate caches when vocabulary changes."""
122
+ self._normalized_cache.clear()
123
+ self._pooled_cache.clear()
124
+ self._spatial_index = None
125
+ self._index_dirty = True
126
+ # Keep char cache across vocabulary changes as characters are stable
127
+
128
+ def _ensure_spatial_index(self):
129
+ """Build spatial index if needed."""
130
+ if self._spatial_index is None or self._index_dirty:
131
+ if len(self._valid_token_ids) < 10:
132
+ return # Too few tokens for indexing
133
+
134
+ pooled_vectors = []
135
+ token_ids = []
136
+ for tid in sorted(self._valid_token_ids):
137
+ pooled_vec = self._get_cached_pooled(tid)
138
+ if pooled_vec is not None:
139
+ pooled_vectors.append(pooled_vec)
140
+ token_ids.append(tid)
141
+
142
+ if pooled_vectors:
143
+ self._spatial_index = SpatialIndex(
144
+ np.array(pooled_vectors),
145
+ token_ids,
146
+ method="auto"
147
+ )
148
+ self._index_dirty = False
149
+
150
+ def _get_cached_pooled(self, token_id: int) -> Optional[np.ndarray]:
151
+ """Get pooled vector with caching."""
152
+ if token_id in self._pooled_cache:
153
+ return self._pooled_cache[token_id]
154
+
155
+ if token_id in self._id_to_vec:
156
+ X = self._id_to_vec[token_id]
157
+ pooled = X.mean(axis=0)
158
+ self._pooled_cache[token_id] = pooled
159
+ return pooled
160
+ return None
161
+
162
+ def _get_cached_normalized(self, token_id: int) -> Optional[np.ndarray]:
163
+ """Get L1-normalized pooled vector with caching."""
164
+ if token_id in self._normalized_cache:
165
+ return self._normalized_cache[token_id]
166
+
167
+ pooled = self._get_cached_pooled(token_id)
168
+ if pooled is not None:
169
+ normalized = pooled / (np.abs(pooled).sum() + 1e-8)
170
+ self._normalized_cache[token_id] = normalized
171
+ return normalized
172
+ return None
173
+
174
+ # --------------------------- abstract surface --------------------
175
+ @abstractmethod
176
+ def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
177
+ raise NotImplementedError
178
+
179
+ @abstractmethod
180
+ def get_score(self, token_or_id: Union[str, int]) -> float:
181
+ raise NotImplementedError
182
+
183
+ # --------------------------- basic queries (optimized) -----------------------
184
+ def decode(self, token_id: int, fallback: str = "<unk>") -> Optional[str]:
185
+ if token_id in self._id_to_token:
186
+ return self._id_to_token[token_id]
187
+ return fallback if fallback in self._token_to_id else None
188
+
189
+ def decode_with_provenance(self, token_id: int, fallback: str = "<unk>") -> Tuple[Optional[str], Optional[dict]]:
190
+ tok = self.decode(token_id, fallback=fallback)
191
+ prov = self._id_to_provenance.get(token_id)
192
+ return tok, prov
193
+
194
+ def provenance(self, token_or_id: Union[str, int]) -> Optional[dict]:
195
+ tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
196
+ return self._id_to_provenance.get(tid)
197
+
198
+ def embedding(self, token_or_id: Union[str, int]) -> Optional[np.ndarray]:
199
+ tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
200
+ return self._id_to_vec.get(tid)
201
+
202
+ def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]:
203
+ """Optimized pooled method with character caching"""
204
+
205
+ # Fast path for single characters
206
+ if method == "mean" and isinstance(token_or_id, str) and len(token_or_id) == 1:
207
+ if token_or_id in self._char_cache:
208
+ self._char_lookups_saved += 1
209
+ return self._char_cache[token_or_id].copy() # Return copy to prevent mutation
210
+
211
+ # Regular lookup
212
+ tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
213
+ if tid is None:
214
+ return None
215
+
216
+ if method == "mean":
217
+ pooled = self._get_cached_pooled(tid)
218
+
219
+ # Cache single characters for future use
220
+ if pooled is not None and isinstance(token_or_id, str) and len(token_or_id) == 1:
221
+ self._char_cache[token_or_id] = pooled.copy()
222
+
223
+ return pooled
224
+
225
+ # Fallback for other methods
226
+ X = self._id_to_vec.get(tid)
227
+ if X is None:
228
+ return None
229
+ if method == "first":
230
+ return X[0]
231
+ if method == "sum":
232
+ return X.sum(axis=0)
233
+ raise ValueError(f"Invalid pooling method: {method}")
234
+
235
+ def pooled_batch(self, tokens: List[Union[str, int]], method: str = "mean") -> List[Optional[np.ndarray]]:
236
+ """Batch pooling with character-level caching for efficiency"""
237
+ results = []
238
+
239
+ for token in tokens:
240
+ # Use optimized single pooled method which handles char caching
241
+ results.append(self.pooled(token, method))
242
+
243
+ return results
244
+
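+ # Example (sketch, assuming `vocab` is a populated GeometricVocab subclass): pooled lookups.
+ # Single characters are served from the char cache after the first call.
+ #
+ #   v = vocab.pooled("a")                 # (dim,) mean of the 5 crystal vertices, or None
+ #   vs = vocab.pooled_batch(list("abc"))  # list of pooled vectors (None for unknown tokens)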
245
+ # --------------------------- optimized similarity ---------------------
246
+ def similarity(self, token_a: Union[str, int], token_b: Union[str, int]) -> float:
247
+ """
248
+ Optimized L1-normalized directional similarity using cached vectors.
249
+ """
250
+ tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a)
251
+ tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b)
252
+
253
+ if tid_a is None or tid_b is None:
254
+ return -1.0
255
+
256
+ a_norm = self._get_cached_normalized(tid_a)
257
+ b_norm = self._get_cached_normalized(tid_b)
258
+
259
+ if a_norm is None or b_norm is None:
260
+ return -1.0
261
+
262
+ return float(np.dot(a_norm, b_norm))
263
+
264
+ def similarity_magnitude(self, token_a: Union[str, int], token_b: Union[str, int]) -> float:
265
+ """
266
+ Raw dot-product using cached pooled vectors.
267
+ """
268
+ tid_a = token_a if isinstance(token_a, int) else self._token_to_id.get(token_a)
269
+ tid_b = token_b if isinstance(token_b, int) else self._token_to_id.get(token_b)
270
+
271
+ if tid_a is None or tid_b is None:
272
+ return -1.0
273
+
274
+ a = self._get_cached_pooled(tid_a)
275
+ b = self._get_cached_pooled(tid_b)
276
+
277
+ if a is None or b is None:
278
+ return -1.0
279
+
280
+ return float(np.dot(a, b))
281
+
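+ # Example (sketch, assuming `vocab` is a populated instance): `similarity` is the dot product
+ # of the two L1-normalized pooled vectors, so it stays within roughly [-1, 1] and returns -1.0
+ # for unknown tokens; `similarity_magnitude` skips the normalization.
+ #
+ #   sim = vocab.similarity("cat", "dog")
+ #   raw = vocab.similarity_magnitude("cat", "dog")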
282
+ # --------------------------- optimized spatial search ---------------------
283
+ def extract_band(self, trajectory: np.ndarray, max_angle: float = 0.3, method: str = "pooled") -> Dict[
284
+ str, np.ndarray]:
285
+ """
286
+ Optimized spatial search using indexing when available.
287
+ """
288
+ if trajectory.ndim == 2:
289
+ direction = trajectory.mean(0)
290
+ else:
291
+ direction = trajectory
292
+ direction = direction / (np.abs(direction).sum() + 1e-8)
293
+
294
+ # Try spatial index first
295
+ self._ensure_spatial_index()
296
+ if self._spatial_index is not None:
297
+ try:
298
+ # Convert angle threshold to distance threshold (approximation)
299
+ max_distance = max_angle * 2.0 # rough conversion
300
+ token_ids, distances = self._spatial_index.search_radius(
301
+ direction, max_distance, max_results=1000
302
+ )
303
+
304
+ # Refine results with exact L1 similarity check
305
+ out: Dict[str, np.ndarray] = {}
306
+ for tid in token_ids:
307
+ tok = self._id_to_token.get(tid)
308
+ if tok is None:
309
+ continue
310
+ v_norm = self._get_cached_normalized(tid)
311
+ if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle:
312
+ out[tok] = self._id_to_vec[tid]
313
+ return out
314
+
315
+ except Exception as e:
316
+ warnings.warn(f"Spatial index search failed: {e}, falling back to linear")
317
+
318
+ # Fallback to linear search
319
+ out: Dict[str, np.ndarray] = {}
320
+ for tok, tid in self._token_to_id.items():
321
+ v_norm = self._get_cached_normalized(tid)
322
+ if v_norm is not None and float(np.dot(v_norm, direction)) >= 1.0 - max_angle:
323
+ out[tok] = self._id_to_vec[tid]
324
+ return out
325
+
326
+ def find_similar_tokens(self, token: Union[str, int], k: int = 10, min_similarity: float = 0.5) -> List[
327
+ Tuple[str, float]]:
328
+ """
329
+ Find k most similar tokens using spatial indexing when available.
330
+ """
331
+ tid = token if isinstance(token, int) else self._token_to_id.get(token)
332
+ if tid is None:
333
+ return []
334
+
335
+ query_vec = self._get_cached_normalized(tid)
336
+ if query_vec is None:
337
+ return []
338
+
339
+ self._ensure_spatial_index()
340
+ if self._spatial_index is not None:
341
+ try:
342
+ # Use spatial index for approximate search
343
+ max_distance = (1.0 - min_similarity) * 2.0
344
+ token_ids, _ = self._spatial_index.search_radius(
345
+ query_vec, max_distance, max_results=k * 3 # Get extra for refinement
346
+ )
347
+
348
+ # Compute exact similarities and sort
349
+ similarities = []
350
+ for tid_cand in token_ids:
351
+ if tid_cand == tid: # Skip self
352
+ continue
353
+ sim = self.similarity(tid, tid_cand)
354
+ if sim >= min_similarity:
355
+ tok = self._id_to_token.get(tid_cand)
356
+ if tok:
357
+ similarities.append((tok, sim))
358
+
359
+ return sorted(similarities, key=lambda x: x[1], reverse=True)[:k]
360
+
361
+ except Exception as e:
362
+ warnings.warn(f"Spatial similarity search failed: {e}, falling back to linear")
363
+
364
+ # Linear fallback
365
+ similarities = []
366
+ for tok_cand, tid_cand in self._token_to_id.items():
367
+ if tid_cand == tid:
368
+ continue
369
+ sim = self.similarity(tid, tid_cand)
370
+ if sim >= min_similarity:
371
+ similarities.append((tok_cand, sim))
372
+
373
+ return sorted(similarities, key=lambda x: x[1], reverse=True)[:k]
374
+
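+ # Example (sketch, assuming `vocab` is a populated instance): directional band extraction and
+ # nearest-neighbour lookup; the query vector below is a stand-in for a real trajectory.
+ #
+ #   band = vocab.extract_band(np.random.rand(vocab.dim), max_angle=0.3)       # {token: crystal}
+ #   neighbours = vocab.find_similar_tokens("king", k=5, min_similarity=0.5)   # [(token, sim), ...]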
375
+ # --------------------------- helpers exposed to callbacks --------
376
+ def _helpers(self) -> Dict[str, Callable[..., np.ndarray]]:
377
+ def _emb(x):
378
+ e = self.embedding(x)
379
+ return None if e is None else np.asarray(e, np.float32)
380
+
381
+ def _poo(x):
382
+ p = self.pooled(x)
383
+ return None if p is None else np.asarray(p, np.float32)
384
+
385
+ def _chars(s):
386
+ # Use batch pooling for efficiency
387
+ return self.pooled_batch(list(s)) if isinstance(s, str) else None
388
+
389
+ return {"embedding": _emb, "pooled": _poo, "chars_pooled": _chars}
390
+
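+ # Example (sketch): the helper dict handed to `create_crystal` callbacks.
+ #
+ #   H = vocab._helpers()
+ #   vec = H["pooled"]("a")               # float32 pooled vector, or None if unknown
+ #   per_char = H["chars_pooled"]("ab")   # list of pooled vectors (None entries for unknown chars)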
391
+ # --------------------------- DEFAULT create_crystal (unicode path) ----
392
+ def _default_create_crystal(self, config: dict, callback: Callable[..., np.ndarray]) -> np.ndarray:
393
+ """
394
+ Deterministic default when user leaves callback/create_crystal=None.
395
+ """
396
+ pool_type = config.get("pool_type") or "unicode"
397
+ H = config["helpers"]
398
+ token_plain = str(config["data"]["token"])
399
+ d = int(config["dim"])
400
+
401
+ c_uni = self._compose_unicode_center(token_plain, H, pool_type, d)
402
+ c_defs = self._compose_wordnet_center(config.get("additional_definitions", []), H, pool_type, d)
403
+
404
+ if pool_type == "combination":
405
+ parts = [v for v in (c_uni, c_defs) if v is not None]
406
+ c = np.mean(np.stack(parts, 0), 0) if parts else np.zeros(d, np.float32)
407
+ elif pool_type == "wordnet":
408
+ c = c_defs if c_defs is not None else np.zeros(d, np.float32)
409
+ else:
410
+ c = c_uni if c_uni is not None else np.zeros(d, np.float32)
411
+
412
+ # L1 normalization only
413
+ l1 = float(np.abs(c).sum()) + 1e-8
414
+ c = c / l1
415
+ return self._deterministic_pentachoron(c)
416
+
417
+ def _default_unicode_callback(self, name: str, **kwargs) -> np.ndarray:
418
+ raise NotImplementedError("Default callback is not invoked directly.")
419
+
420
+ # --------------------------- universal builders (overrideable) ---
421
+ def _compose_unicode_center(
422
+ self, token_plain: str, H, pool_type: Optional[str], dim: int
423
+ ) -> Optional[np.ndarray]:
424
+ """
425
+ Build a center vector from the token's Unicode characters - OPTIMIZED.
426
+ """
427
+ # Use batch pooling for all characters at once
428
+ char_list = list(token_plain)
429
+ pooled_chars = self.pooled_batch(char_list)
430
+
431
+ vecs: List[np.ndarray] = []
432
+ for pooled_v in pooled_chars:
433
+ if pooled_v is None:
434
+ continue
435
+ v = np.asarray(pooled_v, np.float32)
436
+ if v.shape[0] != dim:
437
+ raise ValueError(f"Unicode pooled dim mismatch: got {v.shape[0]}, expected {dim}")
438
+ vecs.append(v)
439
+
440
+ if not vecs:
441
+ return None
442
+
443
+ stacked = np.stack(vecs, 0)
444
+
445
+ if pool_type in (None, "unicode", "mean"):
446
+ c = stacked.mean(axis=0)
447
+ elif pool_type == "abs":
448
+ c = np.abs(stacked).mean(axis=0)
449
+ elif pool_type == "dot":
450
+ c = stacked.mean(axis=0)
451
+ c = c / (np.abs(c).sum() + 1e-8) # L1 normalize
452
+ elif pool_type == "mse":
453
+ c = (stacked ** 2).mean(axis=0)
454
+ elif pool_type == "max":
455
+ c = stacked.max(axis=0)
456
+ else:
457
+ raise ValueError(f"Unsupported pool_type '{pool_type}'")
458
+
459
+ return c.astype(np.float32, copy=False)
460
+
461
+ def _compose_wordnet_center(
462
+ self, definitions: List[str], H, pool_type: Optional[str], dim: int
463
+ ) -> Optional[np.ndarray]:
464
+ """Build a center vector from definition text characters - OPTIMIZED."""
465
+ # Collect all characters from all definitions
466
+ all_chars = []
467
+ for text in definitions:
468
+ all_chars.extend(list(str(text)))
469
+
470
+ # Batch lookup
471
+ pooled_chars = self.pooled_batch(all_chars)
472
+
473
+ vecs: List[np.ndarray] = []
474
+ for pooled_v in pooled_chars:
475
+ if pooled_v is None:
476
+ continue
477
+ v = np.asarray(pooled_v, np.float32)
478
+ if v.shape[0] != dim:
479
+ raise ValueError(f"Definition pooled dim mismatch: got {v.shape[0]}, expected {dim}")
480
+ vecs.append(v)
481
+
482
+ if not vecs:
483
+ return None
484
+
485
+ stacked = np.stack(vecs, 0)
486
+
487
+ if pool_type in (None, "unicode", "mean"):
488
+ c = stacked.mean(axis=0)
489
+ elif pool_type == "abs":
490
+ c = np.abs(stacked).mean(axis=0)
491
+ elif pool_type == "dot":
492
+ c = stacked.mean(axis=0)
493
+ c = c / (np.abs(c).sum() + 1e-8) # L1 normalize
494
+ elif pool_type == "mse":
495
+ c = (stacked ** 2).mean(axis=0)
496
+ elif pool_type == "max":
497
+ c = stacked.max(axis=0)
498
+ else:
499
+ raise ValueError(f"Unsupported pool_type '{pool_type}'")
500
+
501
+ return c.astype(np.float32, copy=False)
502
+
503
+ def _deterministic_pentachoron(self, center_vec: np.ndarray) -> np.ndarray:
504
+ """Universal pentachoron inflation (deterministic; overrideable)."""
505
+ d = center_vec.shape[0]
506
+ proposals = np.stack([
507
+ center_vec,
508
+ np.roll(center_vec, 1),
509
+ np.roll(center_vec, 3) * np.sign(center_vec + 1e-8),
510
+ np.roll(center_vec, 7) - center_vec,
511
+ np.roll(center_vec, 11) + center_vec,
512
+ ], 0).astype(np.float32)
513
+
514
+ # L1 row norms
515
+ norms = np.sum(np.abs(proposals), axis=1, keepdims=True) + 1e-8
516
+ Q = proposals / norms
517
+
518
+ # GS orthogonalization with L1 row renorm at each step
519
+ for i in range(5):
520
+ for j in range(i):
521
+ Q[i] -= np.dot(Q[i], Q[j]) * Q[j]
522
+ Q[i] /= (np.sum(np.abs(Q[i])) + 1e-8)
523
+
524
+ gamma = np.array([1.0, 0.9, -0.8, 1.1, 1.2], np.float32)
525
+ X = np.zeros((5, d), np.float32)
526
+ for i in range(5):
527
+ X[i] = center_vec + gamma[i] * Q[i]
528
+ return X - X.mean(0, keepdims=True)
529
+
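+ # Example (sketch): the inflation always yields a mean-centered (5, d) array for a d-dim center.
+ #
+ #   c = np.random.rand(16).astype(np.float32)
+ #   X = vocab._deterministic_pentachoron(c / np.abs(c).sum())
+ #   assert X.shape == (5, 16) and np.allclose(X.mean(axis=0), 0.0, atol=1e-5)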
530
+ # --------------------------- finalize + provenance (overrideable) ----
531
+ def _finalize_crystal(self, X: np.ndarray) -> np.ndarray:
532
+ X = np.asarray(X, np.float32, order='C') # Ensure C-contiguous
533
+ if X.shape != (5, self.dim):
534
+ raise ValueError(f"Crystal must be shape (5, {self.dim}); got {X.shape}.")
535
+ return X - X.mean(0, keepdims=True)
536
+
537
+ def _auto_provenance_from_cfg(self, cfg: Dict[str, Any]) -> dict:
538
+ token = cfg["data"]["token"]
539
+ prov = {
540
+ "source": "special/compose",
541
+ "token": token,
542
+ "pool_type": cfg.get("pool_type") or "unicode",
543
+ "components": list(token),
544
+ "additional_definitions": list(cfg.get("additional_definitions", [])),
545
+ }
546
+ if cfg.get("antonyms"):
547
+ prov["antonyms"] = list(cfg["antonyms"])
548
+ if cfg.get("inversion_formula") is not None:
549
+ prov["inversion_formula"] = "user_supplied"
550
+ return prov
551
+
552
+ def _finalize_crystal_and_provenance(
553
+ self, product: Union[np.ndarray, Dict[str, Any]], cfg: Dict[str, Any]
554
+ ) -> Tuple[np.ndarray, dict]:
555
+ # ndarray path
556
+ if isinstance(product, np.ndarray):
557
+ X = self._finalize_crystal(product)
558
+ prov = self._auto_provenance_from_cfg(cfg)
559
+ return X, prov
560
+
561
+ # dict path
562
+ if not isinstance(product, dict):
563
+ raise TypeError(
564
+ "create_crystal must return ndarray or dict with {'base':..., 'ops':..., 'provenance':...}.")
565
+ base = np.asarray(product["base"], np.float32)
566
+ X = base
567
+ for op in product.get("ops", []):
568
+ name = op.get("name")
569
+ if name == "center":
570
+ X -= X.mean(0, keepdims=True)
571
+ elif name == "scale":
572
+ X *= float(op.get("k", 1.0))
573
+ elif name == "translate":
574
+ t = np.asarray(op.get("t"), np.float32)
575
+ if t.shape != (self.dim,):
576
+ raise ValueError(f"translate.t must be shape ({self.dim},)")
577
+ X = X + t[None, :]
578
+ elif name == "normalize_rows":
579
+ n = np.sum(np.abs(X), axis=1, keepdims=True) + 1e-8
580
+ X = X / n
581
+ elif name == "align_to":
582
+ v = np.asarray(op.get("v"), np.float32)
583
+ if v.shape != (self.dim,):
584
+ raise ValueError(f"align_to.v must be shape ({self.dim},)")
585
+ v = v / (np.abs(v).sum() + 1e-8)
586
+ p = X.mean(0)
587
+ p = p / (np.abs(p).sum() + 1e-8)
588
+ alpha = float(op.get("alpha", 1.0))
589
+ X = X + alpha * (v - p)[None, :]
590
+ else:
591
+ raise ValueError(f"Unsupported op '{name}'")
592
+ prov = dict(product.get("provenance", {})) or self._auto_provenance_from_cfg(cfg)
593
+ return self._finalize_crystal(X), prov
594
+
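+ # Example (sketch): a `create_crystal` callable may return a dict instead of an ndarray; the
+ # keys mirror the ops handled above and the values here are purely illustrative.
+ #
+ #   product = {
+ #       "base": np.zeros((5, dim), np.float32),
+ #       "ops": [{"name": "normalize_rows"}, {"name": "scale", "k": 0.5}, {"name": "center"}],
+ #       "provenance": {"source": "custom"},
+ #   }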
595
+ # --------------------------- universal manifestation routine ----------
596
+ def _manifest_special_tokens(
597
+ self,
598
+ base_set: Dict[str, int],
599
+ create_crystal: Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]],
600
+ callback: Optional[Callable[..., np.ndarray]],
601
+ create_config: Dict[str, Any],
602
+ ) -> None:
603
+ """Universal, deterministic manifestor with character pre-caching."""
604
+
605
+ # NEW: Pre-cache all unique characters that will be needed
606
+ unique_chars = set()
607
+ for name in base_set.keys():
608
+ token_plain = name.strip("<>").strip()
609
+ unique_chars.update(token_plain)
610
+
611
+ print(f"[⚡] Pre-caching {len(unique_chars)} unique characters...")
612
+ for ch in unique_chars:
613
+ _ = self.pooled(ch) # Trigger caching
614
+
615
+ helpers = self._helpers()
616
+
617
+ for name, tid in base_set.items():
618
+ # Keep if already present
619
+ if tid in self._id_to_vec:
620
+ self._token_to_id[name] = tid
621
+ self._id_to_token.setdefault(tid, name)
622
+ self._valid_token_ids.add(tid)
623
+ continue
624
+
625
+ # Build per-token config
626
+ cfg = {
627
+ "dim": self.dim,
628
+ "pool_type": create_config.get("pool_type", None),
629
+ "special_tokens": create_config.get("special_tokens"),
630
+ "additional_definitions": create_config.get("additional_definitions", []),
631
+ "antonyms": create_config.get("antonyms"),
632
+ "inversion_formula": create_config.get("inversion_formula"),
633
+ "data": {"token": name.strip("<>").strip(), "token_id": tid, "origin": "special"},
634
+ "helpers": helpers,
635
+ }
636
+
637
+ if create_crystal is None:
638
+ create_crystal = self._default_create_crystal
639
+
640
+ product = create_crystal(cfg, callback) if callback is not None else create_crystal(cfg,
641
+ self._default_unicode_callback)
642
+ X, prov = self._finalize_crystal_and_provenance(product, cfg)
643
+
644
+ # Register
645
+ self._token_to_id[name] = tid
646
+ self._id_to_token[tid] = name
647
+ self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C')
648
+ self._id_to_provenance[tid] = prov
649
+ self._valid_token_ids.add(tid)
650
+ self._id_to_volume.setdefault(tid, 1.0)
651
+
652
+ # Aliases
653
+ for alias in (cfg.get("special_tokens") or []):
654
+ alias = str(alias)
655
+ self._token_to_id[alias] = tid
656
+ self._id_to_token.setdefault(tid, alias)
657
+ if cfg.get("special_tokens"):
658
+ self._id_to_provenance[tid].setdefault("aliases", list(cfg["special_tokens"]))
659
+
660
+ # Antonyms
661
+ antonyms = cfg.get("antonyms") or []
662
+ invf = cfg.get("inversion_formula")
663
+ if invf:
664
+ for anti in antonyms:
665
+ if anti in base_set:
666
+ anti_id = base_set[anti]
667
+ if anti_id not in self._id_to_vec:
668
+ X_inv = invf(X, cfg) # must be deterministic
669
+ X_inv = self._finalize_crystal(X_inv)
670
+ self._token_to_id[anti] = anti_id
671
+ self._id_to_token[anti_id] = anti
672
+ self._id_to_vec[anti_id] = X_inv.astype(np.float32, copy=False, order='C')
673
+ inv_prov = {
674
+ "source": "inversion",
675
+ "of_token": name,
676
+ "of_token_id": tid,
677
+ "pool_type": cfg.get("pool_type") or "unicode",
678
+ "components": prov.get("components", []),
679
+ "additional_definitions": cfg.get("additional_definitions", []),
680
+ "ops": ["invert"],
681
+ }
682
+ self._id_to_provenance[anti_id] = inv_prov
683
+ self._valid_token_ids.add(anti_id)
684
+ self._id_to_volume.setdefault(anti_id, 1.0)
685
+
686
+ # Invalidate caches after adding tokens
687
+ self._invalidate_caches()
688
+
689
+ if self._char_lookups_saved > 0:
690
+ print(f"[✅] Character cache saved {self._char_lookups_saved} lookups")
691
+
692
+ # --------------------------- basics -------------------------------
693
+ def vocab_size(self) -> int:
694
+ return len(self._token_to_id)
695
+
696
+ def token_to_id(self, token: str) -> Optional[int]:
697
+ return self._token_to_id.get(token)
698
+
699
+ def id_to_token(self, token_id: int) -> Optional[str]:
700
+ return self._id_to_token.get(token_id)
701
+
702
+ def cache_stats(self) -> Dict[str, int]:
703
+ """Get cache statistics."""
704
+ return {
705
+ "normalized_cache_size": len(self._normalized_cache),
706
+ "pooled_cache_size": len(self._pooled_cache),
707
+ "char_cache_size": len(self._char_cache),
708
+ "char_lookups_saved": self._char_lookups_saved,
709
+ "spatial_index_size": len(self._spatial_index.token_ids) if self._spatial_index else 0,
710
+ "vocab_size": len(self._valid_token_ids)
711
+ }
712
+
713
+ def clear_caches(self):
714
+ """Clear all caches to free memory."""
715
+ self._invalidate_caches()
716
+ self._char_cache.clear()
717
+ self._char_lookups_saved = 0
718
+
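+ # Example (sketch, assuming `vocab` is a populated subclass instance): cache introspection.
+ #
+ #   stats = vocab.cache_stats()   # includes pooled/char cache sizes, lookups saved, vocab size
+ #   vocab.clear_caches()          # drops pooled/normalized/char caches and the spatial index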
719
+
720
+ from typing import List, Dict, Union, Optional, Tuple, Callable, Any
721
+
722
+ class PretrainedGeometricVocab(GeometricVocab):
723
+ """
724
+ Parquet-backed deterministic vocab with columnar load, duplicate-mean aggregation,
725
+ pooled caching, and fast path for flat crystals.
726
+ """
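+ # Example (sketch; the repo id below is a placeholder, not a verified dataset):
+ #
+ #   vocab = PretrainedGeometricVocab("user/crystal-vocab", dim=100, split="train_100d")
+ #   X = vocab.encode("hello")                               # (5, dim) crystal; <unk> fallback if present
+ #   batch = vocab.cache(["hello", "world"], device="cpu")   # dict of torch tensors for training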
727
+ def __init__(
728
+ self,
729
+ repo_id: str,
730
+ dim: int,
731
+ *,
732
+ subset: str = "unicode",
733
+ split: str = "train_100d",
734
+ base_set: Optional[Dict[str, int]] = None,
735
+ create_config: Optional[Dict[str, Any]] = None,
736
+ create_crystal: Optional[Callable[[dict, Callable[..., np.ndarray]], Union[np.ndarray, Dict[str, Any]]]] = None,
737
+ callback: Optional[Callable[..., np.ndarray]] = None,
738
+ manifest_specials: bool = True,
739
+ # perf/robustness knobs
740
+ store: str = "full", # "full" | "pooled" | "both"
741
+ reshape_order: str = "C",
742
+ vertex_count: int = 5,
743
+ infer_dim: bool = True,
744
+ strict_shapes: bool = False,
745
+ # new perf knobs
746
+ finalize_mode: str = "post_mean", # "none" | "post_mean"
747
+ cache_pooled: bool = True,
748
+ streaming=False,
749
+ ):
750
+ super().__init__(dim)
751
+ self.repo_id = str(repo_id)
752
+ self._id_to_pooled: Dict[int, np.ndarray] = {} # optional pooled cache
753
+
754
+ # ---------- load split (columnar, minimal columns) ----------
755
+ ds = load_dataset(self.repo_id, split=split)  # NOTE: the `subset` and `streaming` kwargs are accepted above but not forwarded here
756
+ have = set(ds.column_names)
757
+ wanted = ["token_id", "token", "crystal", "volume"]
758
+ keep = [c for c in wanted if c in have]
759
+ drop = [c for c in ds.column_names if c not in keep]
760
+ if drop:
761
+ ds = ds.remove_columns(drop)
762
+ ds = ds.with_format("numpy", columns=keep)
763
+
764
+ ids = ds["token_id"] if "token_id" in keep else np.array([], dtype=np.int64)
765
+ toks = ds["token"] if "token" in keep else np.array([], dtype=object)
766
+ cryst= ds["crystal"] if "crystal" in keep else np.array([], dtype=object)
767
+ vols = ds["volume"] if "volume" in keep else None
768
+
769
+ ids = np.asarray(ids).astype(np.int64, copy=False)
770
+ toks = np.asarray(toks)
771
+
772
+ # --------- shape helpers ----------
773
+ def _coerce(raw: Any) -> np.ndarray:
774
+ X = np.asarray(raw, np.float32)
775
+ if X.ndim == 2:
776
+ V, D = int(X.shape[0]), int(X.shape[1])
777
+ if V != vertex_count:
778
+ raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
779
+ if D != self.dim:
780
+ if infer_dim: self.dim = D
781
+ else: raise ValueError(f"Dim mismatch: got {D}, expected {self.dim}.")
782
+ return X
783
+ if X.ndim == 1:
784
+ n = int(X.size)
785
+ if n == vertex_count * self.dim:
786
+ return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
787
+ if infer_dim and n % vertex_count == 0:
788
+ self.dim = n // vertex_count
789
+ return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
790
+ if n == self.dim:
791
+ c = X / (np.abs(X).sum() + 1e-8)
792
+ return self._deterministic_pentachoron(c)
793
+ raise ValueError(f"Unsupported crystal shape {X.shape if isinstance(X, np.ndarray) else type(X)}.")
794
+
795
+ def _finalize_if_needed(X: np.ndarray) -> np.ndarray:
796
+ if finalize_mode == "none":
797
+ return np.asarray(X, np.float32, order="C")
798
+ elif finalize_mode == "post_mean":
799
+ return self._finalize_crystal(X)
800
+ else:
801
+ raise ValueError(f"finalize_mode must be 'none' or 'post_mean', got {finalize_mode!r}")
802
+
803
+ vols_f = np.asarray(vols, dtype=np.float32) if vols is not None else None
804
+
805
+ # ---------- FAST PATH: flat uniform crystals ----------
806
+ # Try to stack into (N, L); succeeds when each row is the same length.
807
+ fastpath_ok = False
808
+ A = None # (N, L) float32
809
+ try:
810
+ A = np.stack(cryst) # may raise if jagged / object
811
+ if A.ndim == 2 and A.dtype != object:
812
+ A = A.astype(np.float32, copy=False)
813
+ L = A.shape[1]
814
+ if L % vertex_count == 0:
815
+ # infer or validate D
816
+ D = L // vertex_count
817
+ if self.dim != D:
818
+ if infer_dim:
819
+ self.dim = int(D)
820
+ else:
821
+ raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.")
822
+ fastpath_ok = True
823
+ except Exception:
824
+ fastpath_ok = False
825
+
826
+ if fastpath_ok and A is not None and len(ids) > 0:
827
+ # reshape to (N, V, D)
828
+ V = vertex_count
829
+ D = self.dim
830
+ A = A.reshape(-1, V, D, order=reshape_order)
831
+
832
+ # sort by ids and reduceat to mean duplicates in pure NumPy
833
+ order = np.argsort(ids, kind="stable")
834
+ ids_sorted = ids[order]
835
+ A_sorted = A[order]
836
+ vols_sorted = vols_f[order] if vols_f is not None else None
837
+
838
+ uniq_ids, idx, counts = np.unique(ids_sorted, return_index=True, return_counts=True)
839
+ sums = np.add.reduceat(A_sorted, idx, axis=0) # (K, V, D)
840
+ means = sums / counts[:, None, None] # (K, V, D)
841
+
842
+ if vols_sorted is not None:
843
+ v_sums = np.add.reduceat(vols_sorted, idx)
844
+ v_means = v_sums / counts.astype(np.float32)
845
+ else:
846
+ v_means = np.ones_like(uniq_ids, dtype=np.float32)
847
+
848
+ # commit maps
849
+ self._token_to_id.clear(); self._id_to_token.clear()
850
+ self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear()
851
+ self._id_to_pooled.clear()
852
+
853
+ # pick a representative token per id: first occurrence in sorted block
854
+ toks_sorted = toks[order]
855
+ rep_toks = toks_sorted[idx]
856
+
857
+ for tid, tok, X_mean, v_m in zip(uniq_ids.tolist(), rep_toks.tolist(), means, v_means.tolist()):
858
+ # cache pooled BEFORE finalize to preserve signal
859
+ if cache_pooled:
860
+ self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False)
861
+ X_store = _finalize_if_needed(X_mean)
862
+
863
+ self._token_to_id[str(tok)] = tid
864
+ self._id_to_token[tid] = str(tok)
865
+ if store in ("full", "both"):
866
+ self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C")
867
+ elif store == "pooled":
868
+ # store pooled as embedding if desired
869
+ self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled
870
+ else X_mean.mean(axis=0).astype(np.float32, copy=False))
871
+ self._id_to_volume[tid] = float(v_m)
872
+ self._valid_token_ids.add(tid)
873
+
874
+ else:
875
+ # ---------- FALLBACK: per-row coerce + dict mean ----------
876
+ ids_int = ids.tolist()
877
+ toks_str = [str(x) for x in toks.tolist()]
878
+ vols_f = (vols_f.tolist() if vols_f is not None else [1.0] * len(ids_int))
879
+
880
+ x_sum: Dict[int, np.ndarray] = {}
881
+ v_sum: Dict[int, float] = {}
882
+ n_cnt: Dict[int, int] = {}
883
+ tok_pref: Dict[int, str] = {}
884
+
885
+ for tid, tok, raw, vol in zip(ids_int, toks_str, cryst, vols_f):
886
+ X = _coerce(raw) # [V,D] float32
887
+ if tid not in x_sum:
888
+ x_sum[tid] = X.astype(np.float32, copy=True)
889
+ v_sum[tid] = float(vol)
890
+ n_cnt[tid] = 1
891
+ tok_pref[tid] = tok
892
+ else:
893
+ x_sum[tid] += X
894
+ v_sum[tid] += float(vol)
895
+ n_cnt[tid] += 1
896
+
897
+ self._token_to_id.clear(); self._id_to_token.clear()
898
+ self._id_to_vec.clear(); self._id_to_volume.clear(); self._valid_token_ids.clear()
899
+ self._id_to_pooled.clear()
900
+
901
+ for tid in x_sum.keys(): # order not critical; add sorted(tids) if you need determinism
902
+ X_mean = x_sum[tid] / float(n_cnt[tid])
903
+ if cache_pooled:
904
+ self._id_to_pooled[tid] = X_mean.mean(axis=0).astype(np.float32, copy=False)
905
+ X_store = _finalize_if_needed(X_mean)
906
+
907
+ tok = tok_pref[tid]
908
+ vol_m = v_sum[tid] / float(n_cnt[tid])
909
+
910
+ self._token_to_id[tok] = tid
911
+ self._id_to_token[tid] = tok
912
+ if store in ("full", "both"):
913
+ self._id_to_vec[tid] = np.asarray(X_store, np.float32, order="C")
914
+ elif store == "pooled":
915
+ self._id_to_vec[tid] = (self._id_to_pooled[tid] if cache_pooled
916
+ else X_mean.mean(axis=0).astype(np.float32, copy=False))
917
+ self._id_to_volume[tid] = float(vol_m)
918
+ self._valid_token_ids.add(tid)
919
+
920
+ # ---------- specials ----------
921
+ if manifest_specials and base_set:
922
+ self._manifest_special_tokens(
923
+ base_set=base_set,
924
+ create_crystal=create_crystal,
925
+ callback=callback,
926
+ create_config=create_config or {}
927
+ )
928
+
929
+ # -------- override pooled() to use cache (if present) --------
930
+ def pooled(self, token_or_id: Union[str, int], method: str = "mean") -> Optional[np.ndarray]:
931
+ # Favor cached pooled when available; fallback to base (computes mean)
932
+ tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
933
+ if method == "mean" and tid is not None and tid in self._id_to_pooled:
934
+ return self._id_to_pooled[tid]
935
+ return super().pooled(token_or_id, method=method)
936
+
937
+ # -------- SP-like surface --------
938
+ def encode(self, token: str, *, return_id: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
939
+ tid = self._token_to_id.get(token)
940
+ if tid is None:
941
+ unk_id = self._token_to_id.get("<unk>")
942
+ if unk_id is None:
943
+ raise KeyError(f"Token '{token}' not found and '<unk>' missing.")
944
+ X = self._id_to_vec[unk_id]
945
+ return (X, unk_id) if return_id else X
946
+ X = self._id_to_vec[tid]
947
+ return (X, tid) if return_id else X
948
+
949
+ def get_score(self, token_or_id: Union[str, int]) -> float:
950
+ tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id, None)
951
+ if tid is None or tid not in self._valid_token_ids:
952
+ return -100.0
953
+ vol = self._id_to_volume.get(tid, 1.0)
954
+ return float(np.clip(vol / 10.0, 0.01, 1.0))
955
+
956
+ # -------- Torch cache ----------
957
+ def cache(self, tokens: Union[List[str], Dict[str, int]], device: str = "cpu", dtype: torch.dtype = torch.float32):
958
+ tok_list = list(tokens.keys()) if isinstance(tokens, dict) else list(tokens)
959
+ mats, pooled, keep = [], [], []
960
+ for t in tok_list:
961
+ X = self.embedding(t)
962
+ v = self.pooled(t)
963
+ if X is None or v is None:
964
+ continue
965
+ mats.append(torch.as_tensor(X, dtype=dtype))
966
+ pooled.append(torch.as_tensor(v, dtype=dtype))
967
+ keep.append(t)
968
+ if not mats:
969
+ raise ValueError("No valid tokens found in input.")
970
+ return {
971
+ "tokens": keep,
972
+ "crystals": torch.stack(mats, 0).to(device),
973
+ "pooled": torch.stack(pooled, 0).to(device),
974
+ }
975
+
976
+
977
+ def _coerce_crystal_shape(
978
+ self,
979
+ raw: Any,
980
+ *,
981
+ vertex_count: int,
982
+ reshape_order: str,
983
+ infer_dim: bool,
984
+ strict_shapes: bool
985
+ ) -> np.ndarray:
986
+ """
987
+ Accepts raw crystal data and returns [vertex_count, self.dim] float32 C-order.
988
+
989
+ Acceptable inputs:
990
+ - [vertex_count, D]
991
+ - [vertex_count * D] (flat) -> reshaped to [vertex_count, D]
992
+ - [D] (pooled center) -> converted by deterministic pentachoron (fallback)
993
+ """
994
+ X = np.asarray(raw, dtype=np.float32)
995
+
996
+ # Already [V, D]
997
+ if X.ndim == 2:
998
+ V, D = int(X.shape[0]), int(X.shape[1])
999
+ if V != vertex_count:
1000
+ if strict_shapes:
1001
+ raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
1002
+ # Gentle fallback: attempt to treat rows as vertices if divisible
1003
+ if V * D % vertex_count == 0 and infer_dim:
1004
+ # e.g., [10, D] -> try to collapse/average into [5,D]? Not safe.
1005
+ # Safer: hard error to avoid silent geometry change.
1006
+ raise ValueError(f"Unexpected vertex rows {V}; refusing to coerce silently.")
1007
+ else:
1008
+ raise ValueError(f"Crystal has {V} vertices, expected {vertex_count}.")
1009
+ # Update dim if needed
1010
+ if D != self.dim:
1011
+ if infer_dim:
1012
+ self.dim = D
1013
+ else:
1014
+ raise ValueError(f"Dim mismatch: got D={D}, expected dim={self.dim}.")
1015
+ # Ensure mean-centered (finalize handles centering)
1016
+ return X
1017
+
1018
+ # Flat [V*D]
1019
+ if X.ndim == 1:
1020
+ n = int(X.size)
1021
+ # Exact match for flat crystal
1022
+ if n == vertex_count * self.dim:
1023
+ return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
1024
+
1025
+ # Infer D from total length if divisible
1026
+ if infer_dim and n % vertex_count == 0:
1027
+ inferred = n // vertex_count
1028
+ self.dim = int(inferred)
1029
+ return np.reshape(X, (vertex_count, self.dim), order=reshape_order)
1030
+
1031
+ # Pooled [D]: inflate deterministically to [V, D]
1032
+ if n == self.dim:
1033
+ c = X / (np.abs(X).sum() + 1e-8) # L1
1034
+ return self._deterministic_pentachoron(c)
1035
+
1036
+ if strict_shapes:
1037
+ raise ValueError(
1038
+ f"Cannot coerce crystal of length {n}. "
1039
+ f"Expected {vertex_count*self.dim} (flat) or {self.dim} (pooled)."
1040
+ )
1041
+ # Conservative fallback: treat as pooled center with inferred D if reasonable
1042
+ if infer_dim and n > 0:
1043
+ self.dim = n
1044
+ c = X / (np.abs(X).sum() + 1e-8)
1045
+ return self._deterministic_pentachoron(c)
1046
+
1047
+ raise ValueError(f"Unsupported crystal shape {X.shape} (ndim={X.ndim}).")
1048
+
1049
+
1050
+ # -------- Introspection --------
1051
+ def describe(self) -> Dict[str, Union[str, int]]:
1052
+ return {"repo": self.repo_id, "dimension": self.dim, "vocab_size": self.vocab_size()}
1053
+
1054
+
1055
+ # Only OrderedDict is newly needed here; torch, numpy, typing, warnings and datasets are
+ # already imported at the top of the module, and a second `from __future__ import annotations`
+ # this far down the file would raise a SyntaxError.
+ from collections import OrderedDict
1064
+
1065
+ # Global flag for warning suppression
1066
+ SILENT_MODE = False
1067
+
1068
+ def set_silent_mode(silent: bool):
1069
+ """Set global silent mode for token synthesis warnings"""
1070
+ global SILENT_MODE
1071
+ SILENT_MODE = silent
1072
+
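+ # Example (sketch): suppress synthesis warnings globally before bulk encoding.
+ #
+ #   set_silent_mode(True)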
1073
+ class LRUCache(OrderedDict):
1074
+ """Simple LRU cache implementation"""
1075
+ def __init__(self, maxsize=128):
1076
+ super().__init__()
1077
+ self.maxsize = maxsize
1078
+
1079
+ def __getitem__(self, key):
1080
+ value = super().__getitem__(key)
1081
+ self.move_to_end(key)
1082
+ return value
1083
+
1084
+ def __setitem__(self, key, value):
1085
+ if key in self:
1086
+ self.move_to_end(key)
1087
+ super().__setitem__(key, value)
1088
+ if len(self) > self.maxsize:
1089
+ oldest = next(iter(self))
1090
+ del self[oldest]
1091
+
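+ # Example (sketch): reads refresh recency; inserting beyond maxsize evicts the least recently used.
+ #
+ #   lru = LRUCache(maxsize=2)
+ #   lru["a"], lru["b"] = 1, 2
+ #   _ = lru["a"]      # touch "a" so it becomes most recent
+ #   lru["c"] = 3      # evicts "b"
+ #   assert "b" not in lru and "a" in lru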
1092
+
1093
+ class LazyGeometricVocab(GeometricVocab):
1094
+ """
1095
+ Lazy-loading geometric vocabulary that loads tokens on demand.
1096
+ Maintains a small working set in memory with LRU eviction.
1097
+ Supports automatic token synthesis for missing tokens.
1098
+ """
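+ # Example (sketch; repo and config names are placeholders): lazy loading with on-demand synthesis.
+ #
+ #   lazy = LazyGeometricVocab("user/crystal-vocab", dim=100, name="unicode_100d",
+ #                             cache_size=2000, silent=True)
+ #   emb = lazy.embedding("hello")                    # crystal loaded on first access
+ #   emb = lazy.embedding("rareword", generate=True)  # synthesized (negative token id) if missing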
1099
+
1100
+ def __init__(
1101
+ self,
1102
+ repo_id: str,
1103
+ dim: int,
1104
+ *,
1105
+ name: str = "unicode_100d", # Updated default to match new structure
1106
+ split: str = "train", # Updated default to "train"
1107
+ stream: bool = True, # Use streaming by default to avoid bulk downloads
1108
+ base_set: Optional[Dict[str, int]] = None,
1109
+ create_config: Optional[Dict[str, Any]] = None,
1110
+ create_crystal: Optional[Callable] = None,
1111
+ callback: Optional[Callable] = None,
1112
+ manifest_specials: bool = True,
1113
+ # Lazy loading parameters
1114
+ cache_size: int = 1000, # Max tokens to keep in memory
1115
+ preload_tokens: Optional[List[str]] = None, # Critical tokens to preload
1116
+ index_cache_path: Optional[str] = None, # Path to save/load index
1117
+ # Tokenization
1118
+ tokenizer: Optional[Callable[[str], List[str]]] = None, # Custom tokenizer
1119
+ # Synthesis settings
1120
+ silent: bool = False, # Suppress synthesis warnings
1121
+ # Performance knobs
1122
+ store: str = "full",
1123
+ reshape_order: str = "C",
1124
+ vertex_count: int = 5,
1125
+ infer_dim: bool = True,
1126
+ finalize_mode: str = "post_mean",
1127
+ cache_pooled: bool = True,
1128
+ ):
1129
+ super().__init__(dim)
1130
+
1131
+ self.repo_id = repo_id
1132
+ self.name = name
1133
+ self.split = split
1134
+ self.stream = stream
1135
+ self.vertex_count = vertex_count
1136
+ self.reshape_order = reshape_order
1137
+ self.infer_dim = infer_dim
1138
+ self.finalize_mode = finalize_mode
1139
+ self.store = store
1140
+ self.cache_pooled = cache_pooled
1141
+ self.silent = silent
1142
+
1143
+ # Initialize pooled dictionary that may be missing from parent class
1144
+ if not hasattr(self, '_id_to_pooled'):
1145
+ self._id_to_pooled = {}
1146
+
1147
+ # For synthesis
1148
+ self.create_crystal_fn = create_crystal
1149
+ self.callback_fn = callback
1150
+ self.create_config = create_config or {}
1151
+ self._synthesized_tokens: set = set()
1152
+ self._next_synthetic_id = -1 # Use negative IDs for synthetic tokens
1153
+
1154
+ # Tokenizer - default to simple split
1155
+ self.tokenizer = tokenizer or (lambda s: s.split())
1156
+
1157
+ # LRU caches for lazy loading
1158
+ self._crystal_cache = LRUCache(maxsize=cache_size)
1159
+ self._pooled_lru = LRUCache(maxsize=cache_size * 2) # Pooled vectors are smaller
1160
+
1161
+ # Load dataset but don't fetch data yet
1162
+ self._dataset = None
1163
+ self._dataset_stream = None
1164
+ self._token_index: Dict[str, List[int]] = {} # token -> [row indices]
1165
+ self._id_index: Dict[int, List[int]] = {} # token_id -> [row indices]
1166
+ self._row_data: Dict[int, dict] = {} # row -> cached data
1167
+
1168
+ # Initialize index
1169
+ self._build_index(split, name)
1170
+
1171
+ # Pre-load base characters for synthesis
1172
+ self._preload_synthesis_base()
1173
+
1174
+ # Preload critical tokens if specified
1175
+ if preload_tokens:
1176
+ self._preload(preload_tokens)
1177
+
1178
+ # Manifest special tokens
1179
+ if manifest_specials and base_set:
1180
+ self._manifest_special_tokens(
1181
+ base_set=base_set,
1182
+ create_crystal=create_crystal,
1183
+ callback=callback,
1184
+ create_config=create_config or {}
1185
+ )
1186
+
1187
+ def _preload_synthesis_base(self):
1188
+ """Pre-load basic ASCII characters needed for synthesis"""
1189
+ # Essential characters that are commonly used in token synthesis
1190
+ base_chars = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-_()[]{}:;'\"")
1191
+
1192
+ print(f"Pre-loading {len(base_chars)} base characters for synthesis...")
1193
+ loaded = 0
1194
+ for char in base_chars:
1195
+ tid = self._token_to_id.get(char)
1196
+ if tid is not None:  # `if tid:` would wrongly skip a legitimate token id of 0
1197
+ # Pre-load this character's embedding
1198
+ if self._load_crystal(tid) is not None:
1199
+ loaded += 1
1200
+ print(f"Loaded {loaded} base characters")
1201
+
1202
+ def _build_index(self, split: str, name: str):
1203
+ """Build token/id index without loading crystal data"""
1204
+ print(f"Building index for {self.repo_id}/{name}/{split}...")
1205
+
1206
+ if self.stream:
1207
+ try:
1208
+ # Use streaming to avoid downloading all splits
1209
+ # Don't specify columns in streaming mode to avoid schema issues
1210
+ self._dataset_stream = load_dataset(
1211
+ self.repo_id,
1212
+ name=name,
1213
+ split=split,
1214
+ streaming=True
1215
+ )
1216
+
1217
+ # Build index from streaming dataset
1218
+ for idx, row in enumerate(self._dataset_stream):
1219
+ token = str(row["token"])
1220
+ token_id = int(row["token_id"])
1221
+
1222
+ # Token index
1223
+ if token not in self._token_index:
1224
+ self._token_index[token] = []
1225
+ self._token_index[token].append(idx)
1226
+
1227
+ # ID index
1228
+ if token_id not in self._id_index:
1229
+ self._id_index[token_id] = []
1230
+ self._id_index[token_id].append(idx)
1231
+
1232
+ # Update mappings (use first occurrence)
1233
+ if token not in self._token_to_id:
1234
+ self._token_to_id[token] = token_id
1235
+ self._id_to_token[token_id] = token
1236
+ self._valid_token_ids.add(token_id)
1237
+
1238
+ print(f"Index built (streaming): {len(self._token_index)} unique tokens")
1239
+
1240
+ except Exception as e:
1241
+ print(f"Streaming failed: {e}")
1242
+ print("Falling back to non-streaming mode...")
1243
+ self.stream = False
1244
+ # Recursive call with streaming disabled
1245
+ self._build_index(split, name)
1246
+
1247
+ else:
1248
+ # Non-streaming mode - load dataset normally
1249
+ try:
1250
+ # Try with data_files to load only specific split
1251
+ data_files = f"data/{name}/{split}-*.parquet"
1252
+ ds = load_dataset(
1253
+ self.repo_id,
1254
+ data_files=data_files,
1255
+ split="train"
1256
+ )
1257
+ except:
1258
+ # Fallback to normal loading
1259
+ try:
1260
+ ds = load_dataset(
1261
+ self.repo_id,
1262
+ name=name,
1263
+ split=split
1264
+ )
1265
+ except Exception as e:
1266
+ print(f"Failed to load dataset: {e}")
1267
+ raise
1268
+
1269
+ # Build indices
1270
+ for idx, row in enumerate(ds):
1271
+ token = str(row["token"])
1272
+ token_id = int(row["token_id"])
1273
+
1274
+ # Token index
1275
+ if token not in self._token_index:
1276
+ self._token_index[token] = []
1277
+ self._token_index[token].append(idx)
1278
+
1279
+ # ID index
1280
+ if token_id not in self._id_index:
1281
+ self._id_index[token_id] = []
1282
+ self._id_index[token_id].append(idx)
1283
+
1284
+ # Update mappings (use first occurrence)
1285
+ if token not in self._token_to_id:
1286
+ self._token_to_id[token] = token_id
1287
+ self._id_to_token[token_id] = token
1288
+ self._valid_token_ids.add(token_id)
1289
+
1290
+ # Store dataset reference (will lazy load full data)
1291
+ self._dataset = ds
1292
+ print(f"Index built: {len(self._token_index)} unique tokens")
1293
+
1294
+ def _load_row(self, row_idx: int) -> dict:
1295
+ """Load a single row from dataset"""
1296
+ if row_idx in self._row_data:
1297
+ return self._row_data[row_idx]
1298
+
1299
+ # If streaming, need to load the full dataset on first data access
1300
+ if self.stream and self._dataset is None:
1301
+ print(f"Loading full dataset for {self.repo_id}/{self.name}/{self.split}...")
1302
+ try:
1303
+ # Try with data_files first
1304
+ data_files = f"data/{self.name}/{self.split}-*.parquet"
1305
+ self._dataset = load_dataset(
1306
+ self.repo_id,
1307
+ data_files=data_files,
1308
+ split="train"
1309
+ )
1310
+ except:
1311
+ # Fallback to normal loading
1312
+ self._dataset = load_dataset(
1313
+ self.repo_id,
1314
+ name=self.name,
1315
+ split=self.split
1316
+ )
1317
+
1318
+ if self._dataset is None:
1319
+ raise RuntimeError("Dataset not initialized")
1320
+
1321
+ row = self._dataset[row_idx]
1322
+ self._row_data[row_idx] = row
1323
+ return row
1324
+
1325
+ def _load_crystal(self, token_id: int) -> Optional[np.ndarray]:
1326
+ """Load and aggregate crystal for a token_id"""
1327
+ if token_id in self._crystal_cache:
1328
+ return self._crystal_cache[token_id]
1329
+
1330
+ if token_id not in self._id_index:
1331
+ return None
1332
+
1333
+ row_indices = self._id_index[token_id]
1334
+ crystals = []
1335
+ volumes = []
1336
+
1337
+ for idx in row_indices:
1338
+ row = self._load_row(idx)
1339
+
1340
+ # Parse crystal
1341
+ raw_crystal = row.get("crystal")
1342
+ if raw_crystal is not None:
1343
+ X = self._coerce_crystal(raw_crystal)
1344
+ crystals.append(X)
1345
+
1346
+ # Get volume if available
1347
+ vol = row.get("volume", 1.0)
1348
+ volumes.append(float(vol))
1349
+
1350
+ if not crystals:
1351
+ return None
1352
+
1353
+ # Average multiple occurrences
1354
+ if len(crystals) == 1:
1355
+ X_final = crystals[0]
1356
+ vol_final = volumes[0]
1357
+ else:
1358
+ X_final = np.mean(crystals, axis=0)
1359
+ vol_final = np.mean(volumes)
1360
+
1361
+ # Finalize
1362
+ if self.finalize_mode == "post_mean":
1363
+ X_final = self._finalize_crystal(X_final)
1364
+
1365
+ # Cache based on store mode
1366
+ if self.store in ("full", "both"):
1367
+ self._crystal_cache[token_id] = X_final
1368
+ self._id_to_vec[token_id] = X_final
1369
+
1370
+ # Cache pooled if requested
1371
+ if self.cache_pooled:
1372
+ pooled = X_final.mean(axis=0)
1373
+ self._pooled_lru[token_id] = pooled
1374
+ if token_id not in self._id_to_pooled:
1375
+ self._id_to_pooled[token_id] = pooled
1376
+
1377
+ # Store volume
1378
+ self._id_to_volume[token_id] = vol_final
1379
+
1380
+ return X_final
1381
+
1382
+ def _coerce_crystal(self, raw: Any) -> np.ndarray:
1383
+ """Convert raw crystal data to proper shape"""
1384
+ X = np.asarray(raw, dtype=np.float32)
1385
+
1386
+ if X.ndim == 2:
1387
+ V, D = X.shape
1388
+ if V != self.vertex_count:
1389
+ raise ValueError(f"Expected {self.vertex_count} vertices, got {V}")
1390
+ if D != self.dim:
1391
+ if self.infer_dim:
1392
+ self.dim = D
1393
+ else:
1394
+ raise ValueError(f"Dimension mismatch: {D} vs {self.dim}")
1395
+ return X
1396
+
1397
+ if X.ndim == 1:
1398
+ n = X.size
1399
+ if n == self.vertex_count * self.dim:
1400
+ return X.reshape((self.vertex_count, self.dim), order=self.reshape_order)
1401
+ if self.infer_dim and n % self.vertex_count == 0:
1402
+ self.dim = n // self.vertex_count
1403
+ return X.reshape((self.vertex_count, self.dim), order=self.reshape_order)
1404
+ if n == self.dim:
1405
+ # Pooled vector - expand to crystal
1406
+ c = X / (np.abs(X).sum() + 1e-8)
1407
+ return self._deterministic_pentachoron(c)
1408
+
1409
+ raise ValueError(f"Cannot coerce crystal shape {X.shape}")
1410
+
1411
+ def _synthesize_token(self, token: str) -> int:
1412
+ """Synthesize a new token embedding on-the-fly with fallback for missing chars."""
1413
+ # Generate a new ID for synthetic token
1414
+ tid = self._next_synthetic_id
1415
+ self._next_synthetic_id -= 1
1416
+
1417
+ # Warn user unless silenced
1418
+ if not self.silent and not SILENT_MODE:
1419
+ warnings.warn(
1420
+ f"Token '{token}' synthesized - ensure you synthesize your tokens ahead of time.",
1421
+ UserWarning,
1422
+ stacklevel=3
1423
+ )
1424
+
1425
+ # Track as synthesized
1426
+ self._synthesized_tokens.add(token)
1427
+
1428
+ # Try to use character-based synthesis first
1429
+ try:
1430
+ # Check if all characters are available
1431
+ missing_chars = []
1432
+ for char in token:
1433
+ if char not in self._token_to_id and char not in self._char_cache:
1434
+ missing_chars.append(char)
1435
+
1436
+ # If missing chars, try to load or synthesize them first
1437
+ if missing_chars:
1438
+ for char in missing_chars:
1439
+ char_tid = self._token_to_id.get(char)
1440
+ if char_tid is not None:  # token id 0 is valid
1441
+ # Try to load it
1442
+ self._load_crystal(char_tid)
1443
+ else:
1444
+ # Create a simple embedding for this character
1445
+ self._synthesize_simple_char(char)
1446
+
1447
+ # Now try the full synthesis
1448
+ helpers = self._helpers()
1449
+ cfg = {
1450
+ "dim": self.dim,
1451
+ "pool_type": self.create_config.get("pool_type", "unicode"),
1452
+ "data": {"token": token, "token_id": tid, "origin": "synthetic"},
1453
+ "helpers": helpers,
1454
+ }
1455
+
1456
+ if self.create_crystal_fn is not None:
1457
+ product = self.create_crystal_fn(cfg, self.callback_fn)
1458
+ else:
1459
+ product = self._default_create_crystal(cfg, self._default_unicode_callback)
1460
+
1461
+ X, prov = self._finalize_crystal_and_provenance(product, cfg)
1462
+
1463
+ except Exception as e:
1464
+ # Fallback to simple random synthesis
1465
+ print(f"Character-based synthesis failed for '{token}': {e}. Using random synthesis.")
1466
+ X = self._synthesize_random_crystal(token)
1467
+ prov = {"source": "synthetic_random", "token": token}
1468
+
1469
+ prov["synthetic"] = True
1470
+
1471
+ # Register in all maps
1472
+ self._token_to_id[token] = tid
1473
+ self._id_to_token[tid] = token
1474
+ self._id_to_vec[tid] = X.astype(np.float32, copy=False, order='C')
1475
+ self._id_to_provenance[tid] = prov
1476
+ self._valid_token_ids.add(tid)
1477
+ self._id_to_volume[tid] = 1.0
1478
+
1479
+ # Cache
1480
+ self._crystal_cache[tid] = X
1481
+ if self.cache_pooled:
1482
+ pooled = X.mean(axis=0)
1483
+ self._pooled_lru[tid] = pooled
1484
+ self._id_to_pooled[tid] = pooled
1485
+
1486
+ return tid
1487
+
1488
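+     # Editor's sketch (not in the original commit): on-the-fly synthesis from the
+     # caller's side; "glorptastic" is a hypothetical out-of-vocabulary token.
+     #   tid = vocab._synthesize_token("glorptastic")   # warns unless silent/SILENT_MODE
+     #   vocab.is_synthesized("glorptastic")            # -> True
+     #   vocab.embedding(tid).shape                     # -> (vertex_count, dim)
+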
+     def _synthesize_simple_char(self, char: str):
+         """Create a simple deterministic embedding for a single character"""
+         import hashlib
+
+         # Use the character's Unicode codepoint for deterministic generation
+         if len(char) == 1:
+             seed = ord(char)
+         else:
+             seed = int(hashlib.md5(char.encode()).hexdigest()[:8], 16)
+
+         # Local generator: deterministic without reseeding global NumPy state
+         rng = np.random.default_rng(seed)
+
+         # Generate a simple vector based on character properties
+         vec = rng.standard_normal(self.dim).astype(np.float32)
+         vec = vec / (np.abs(vec).sum() + 1e-8)  # L1 normalize
+
+         # Cache it
+         self._char_cache[char] = vec
+
+     def _synthesize_random_crystal(self, token: str) -> np.ndarray:
+         """Fallback: create a deterministic random crystal based on the token string"""
+         import hashlib
+
+         # Create a deterministic seed from the token
+         seed = int(hashlib.md5(token.encode()).hexdigest()[:8], 16)
+         rng = np.random.default_rng(seed)  # local generator, avoids global np.random.seed
+
+         # Generate a random crystal
+         X = rng.standard_normal((self.vertex_count, self.dim)).astype(np.float32)
+         X = self._finalize_crystal(X)
+
+         return X
+
+     def _preload(self, tokens: List[str]):
+         """Preload specific tokens into cache"""
+         print(f"Preloading {len(tokens)} tokens...")
+         for token in tokens:
+             tid = self._token_to_id.get(token)
+             if tid is not None:  # explicit None check: id 0 is a valid token id
+                 self._load_crystal(tid)
+
+     # Override base methods to use lazy loading with synthesis
+
+     def embedding(self, token_or_id: Union[str, int], generate: bool = False) -> Optional[np.ndarray]:
+         """Get embedding, loading if necessary, synthesizing if requested"""
+         # Handle token ID input
+         if isinstance(token_or_id, int):
+             tid = token_or_id
+             token = self._id_to_token.get(tid)
+         else:
+             token = token_or_id
+             tid = self._token_to_id.get(token)
+
+         if tid is not None:
+             # Check cache first
+             if tid in self._id_to_vec:
+                 return self._id_to_vec[tid]
+             # Load on demand
+             return self._load_crystal(tid)
+
+         # Token not found - synthesize if requested
+         if generate and token is not None:
+             tid = self._synthesize_token(token)
+             return self._id_to_vec[tid]
+
+         return None
+
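+     # Editor's sketch (not in the original commit): lazy access through embedding();
+     # "hello" is assumed to exist in the dataset, "qzxv" is assumed missing.
+     #   X = vocab.embedding("hello")                   # loaded on demand, cached afterwards
+     #   vocab.embedding("qzxv")                        # -> None (unknown, generate=False)
+     #   X = vocab.embedding("qzxv", generate=True)     # synthesized under a synthetic id
+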
+     def pooled(self, token_or_id: Union[str, int], method: str = "mean", generate: bool = False) -> Optional[np.ndarray]:
+         """Get pooled vector, loading if necessary, synthesizing if requested"""
+         # Handle token ID input
+         if isinstance(token_or_id, int):
+             tid = token_or_id
+             token = self._id_to_token.get(tid)
+         else:
+             token = token_or_id
+             tid = self._token_to_id.get(token)
+
+         if tid is not None:
+             # Check pooled cache
+             if tid in self._pooled_lru:
+                 return self._pooled_lru[tid]
+             if tid in self._id_to_pooled:
+                 return self._id_to_pooled[tid]
+
+             # Load crystal and compute pooled
+             X = self.embedding(tid, generate=False)
+             if X is not None:
+                 if method == "mean":
+                     pooled = X.mean(axis=0)
+                     self._pooled_lru[tid] = pooled
+                     return pooled
+                 elif method == "first":
+                     return X[0]
+                 elif method == "sum":
+                     return X.sum(axis=0)
+                 else:
+                     raise ValueError(f"Unknown pooling method: {method}")
+
+         # Token not found - synthesize if requested
+         if generate and token is not None:
+             tid = self._synthesize_token(token)
+             return self.pooled(tid, method=method, generate=False)
+
+         return None
+
+     def encode(self, token: str, *, return_id: bool = False, generate: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
+         """Encode token, loading if necessary, synthesizing if requested"""
+         tid = self._token_to_id.get(token)
+
+         if tid is None:
+             if generate:
+                 # Synthesize new token
+                 tid = self._synthesize_token(token)
+                 X = self._id_to_vec[tid]
+             else:
+                 # Fallback to UNK (generate is False on this branch, so synthesis is not an option)
+                 unk_id = self._token_to_id.get("<unk>")
+                 if unk_id is None:
+                     raise KeyError(f"Token '{token}' not found and no <unk> token available")
+                 X = self.embedding(unk_id, generate=False)
+                 tid = unk_id
+         else:
+             X = self.embedding(tid, generate=False)
+             if X is None:
+                 raise RuntimeError(f"Failed to load embedding for token '{token}'")
+
+         return (X, tid) if return_id else X
+
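+     # Editor's sketch (not in the original commit): encode() fallback behaviour,
+     # assuming the vocabulary defines an "<unk>" token.
+     #   X, tid = vocab.encode("hello", return_id=True)                 # known token
+     #   X, tid = vocab.encode("qzxv", return_id=True)                  # unknown -> <unk> embedding/id
+     #   X, tid = vocab.encode("qzxv", return_id=True, generate=True)   # unknown -> synthesized
+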
+     def get_score(self, token_or_id: Union[str, int]) -> float:
+         """Get token score"""
+         tid = token_or_id if isinstance(token_or_id, int) else self._token_to_id.get(token_or_id)
+         if tid is None or tid not in self._valid_token_ids:
+             return -100.0
+
+         # Load volume if needed
+         if tid not in self._id_to_volume:
+             self._load_crystal(tid)
+
+         vol = self._id_to_volume.get(tid, 1.0)
+         return float(np.clip(vol / 10.0, 0.01, 1.0))
+
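+     # Editor's note: the score is clip(volume / 10, 0.01, 1.0), so for example
+     # volume 0.05 -> 0.01, volume 2.5 -> 0.25, volume 37.0 -> 1.0; unknown or
+     # invalid ids return -100.0.
+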
+     def encode_batch(self, tokens: Union[str, List[str]],
+                      *, return_ids: bool = False,
+                      prefetch: bool = True,
+                      generate: bool = False) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[int]]]:
+         """
+         Encode a batch of tokens efficiently.
+
+         Args:
+             tokens: Either a string (will be tokenized) or list of token strings
+             return_ids: Whether to return token IDs alongside embeddings
+             prefetch: Whether to prefetch all tokens before encoding
+             generate: If True, synthesize missing tokens
+
+         Returns:
+             List of embeddings, optionally with list of token IDs
+         """
+         # Handle string input - tokenize it
+         if isinstance(tokens, str):
+             tokens = self.tokenizer(tokens)
+
+         if not isinstance(tokens, list):
+             raise TypeError(f"Expected str or List[str], got {type(tokens)}")
+
+         # Track which tokens need synthesis
+         tokens_to_synthesize = []
+         if generate:
+             for token in tokens:
+                 if token not in self._token_to_id:
+                     tokens_to_synthesize.append(token)
+
+         # Warn about batch synthesis if needed
+         if tokens_to_synthesize and not self.silent and not SILENT_MODE:
+             warnings.warn(
+                 f"{len(tokens_to_synthesize)} tokens synthesized - ensure you synthesize your tokens ahead of time. "
+                 f"Synthesized: {tokens_to_synthesize[:5]}{'...' if len(tokens_to_synthesize) > 5 else ''}",
+                 UserWarning,
+                 stacklevel=2
+             )
+
+         # Prefetch existing tokens if requested
+         if prefetch:
+             self._prefetch_batch([t for t in tokens if t in self._token_to_id])
+
+         # Encode all tokens
+         embeddings = []
+         ids = []
+
+         for token in tokens:
+             if return_ids:
+                 emb, tid = self.encode(token, return_id=True, generate=generate)
+                 embeddings.append(emb)
+                 ids.append(tid)
+             else:
+                 emb = self.encode(token, return_id=False, generate=generate)
+                 embeddings.append(emb)
+
+         return (embeddings, ids) if return_ids else embeddings
+
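+     # Editor's sketch (not in the original commit): batch encoding; passing a plain
+     # string assumes self.tokenizer is configured on the instance.
+     #   embs = vocab.encode_batch(["hello", "world"])                  # list of (V, D) arrays
+     #   embs, ids = vocab.encode_batch("hello world", return_ids=True,
+     #                                  generate=True)                  # missing tokens synthesized
+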
+     def _prefetch_batch(self, tokens: List[str]):
+         """
+         Prefetch a batch of tokens efficiently.
+         """
+         # Collect all token IDs that need loading
+         tokens_to_load = []
+         for token in tokens:
+             tid = self._token_to_id.get(token)
+             if tid is not None and tid not in self._crystal_cache and tid not in self._id_to_vec:
+                 tokens_to_load.append(tid)
+
+         if not tokens_to_load:
+             return  # Everything already cached
+
+         # Load crystals for each token
+         for tid in tokens_to_load:
+             self._load_crystal(tid)
+
+     def cache_stats(self) -> Dict[str, Any]:
+         """Get cache statistics"""
+         stats = super().cache_stats()
+         stats.update({
+             "crystal_cache_size": len(self._crystal_cache),
+             "pooled_lru_size": len(self._pooled_lru),
+             "rows_cached": len(self._row_data),
+             "tokens_indexed": len(self._token_index),
+             "ids_indexed": len(self._id_index),
+             "synthesized_tokens": len(self._synthesized_tokens),
+         })
+         return stats
+
+     def evict_from_cache(self, tokens: Optional[List[str]] = None):
+         """Manually evict tokens from cache to free memory"""
+         if tokens is None:
+             # Clear all caches
+             self._crystal_cache.clear()
+             self._pooled_lru.clear()
+             self._id_to_vec.clear()
+             self._id_to_pooled.clear()
+             self._row_data.clear()
+         else:
+             # Evict specific tokens
+             for token in tokens:
+                 tid = self._token_to_id.get(token)
+                 if tid is not None:  # explicit None check: id 0 is a valid token id
+                     self._crystal_cache.pop(tid, None)
+                     self._pooled_lru.pop(tid, None)
+                     self._id_to_vec.pop(tid, None)
+                     self._id_to_pooled.pop(tid, None)
+
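+     # Editor's sketch (not in the original commit): cache inspection and eviction.
+     #   vocab.cache_stats()["crystal_cache_size"]     # crystals currently resident
+     #   vocab.evict_from_cache(["hello", "world"])    # drop specific tokens
+     #   vocab.evict_from_cache()                      # clear every cache
+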
+     def get_synthesized_tokens(self) -> List[str]:
+         """Get list of all tokens that were synthesized at runtime"""
+         return list(self._synthesized_tokens)
+
+     def is_synthesized(self, token: str) -> bool:
+         """Check if a token was synthesized at runtime"""
+         return token in self._synthesized_tokens
+
+
+
+
+ # Example: 64-dimensional embeddings
+ vocab = LazyGeometricVocab(
+     repo_id="AbstractPhil/geometric-vocab",
+     dim=64,
+     name="unicode_64d",   # Specifies the dimension config
+     split="train",        # Now always "train"
+     stream=False,
+     cache_size=1024,
+     silent=False
+ )
+ FROZEN_VOCAB = []
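+ # Editor's sketch (not in the original commit): minimal use of the instance built
+ # above; token strings are illustrative and the repo download must succeed first.
+ #   hello = vocab.pooled("hello")        # 64-dim pooled vector, lazily loaded
+ #   score = vocab.get_score("hello")     # clip(volume / 10, 0.01, 1.0)
+ #   print(vocab.cache_stats())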