Commit bdc1770 (verified) · AbstractPhil committed · 1 Parent(s): c605588

Updated the model to correctly reflect the fixes.

Files changed (1):
  1. model.py +265 -171
model.py CHANGED
@@ -1,6 +1,7 @@
 """
 PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
 Enhanced with Geometric Attention for improved head cohesion and generalization
+FIXED: CLS tokens now properly reference and utilize vocabulary embeddings
 """
 
 import torch
@@ -35,7 +36,7 @@ class PentachoraConfig:
     aux_loss_weight: float = 0.3
     geo_loss_weight: float = 0.1
     vocab: Optional[Any] = None
-
+
     @property
     def num_patches(self) -> int:
         return (self.img_size // self.patch_size) ** 2
@@ -74,33 +75,33 @@ class GeometricConfig:
 
 class GeometricNavigator(nn.Module):
     """Maps inputs to geometric regions in 4D space."""
-
+
     def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig):
         super().__init__()
         self.input_dim = input_dim
         self.num_regions = num_regions
         self.config = config
-
+
         self.to_nav = nn.Linear(input_dim, 4, bias=False)
         self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5))
-
+
         # Initialize geometry after module is created
         self.register_parameter('D', None)
         self.register_parameter('S', None)
-
+
     def _lazy_init_geometry(self, device):
         """Initialize geometry on first forward pass."""
         if self.D is not None:
             return
-
+
         base = perfect_4simplex(device)
-
+
         D = torch.zeros(self.num_regions, 5, 4, device=device)
         S = torch.zeros(self.num_regions, 5, 4, device=device)
-
+
         for r in range(self.num_regions):
             D[r] = base + self.config.jitter * torch.randn_like(base)
-
+
             theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device)
             rot = torch.eye(4, device=device)
             c, s_val = torch.cos(theta), torch.sin(theta)
@@ -108,67 +109,67 @@ class GeometricNavigator(nn.Module):
             rot[1, 0] = s_val; rot[1, 1] = c
             S[r] = (base @ rot) + self.config.shift
             S[r] += self.config.jitter * torch.randn_like(S[r])
-
+
         self.D = nn.Parameter(D)
         self.S = nn.Parameter(S)
-
+
     def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
         """Navigate inputs through geometric space."""
         self._lazy_init_geometry(x.device)
-
+
         nav_x = self.to_nav(x)
         nav_x_exp = nav_x[:, None, None, :]
         D_exp = self.D[None, :, :, :]
-
+
         d_disp = torch.norm(nav_x_exp - D_exp, dim=-1)
         s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)
-
+
         w = F.softmax(self.vertex_w, dim=1)
         phase_scores = []
-
+
         for phase in self.config.phases:
             phase_tensor = torch.tensor(phase, device=x.device)
             ct = torch.cos(phase_tensor)
             st = torch.sin(phase_tensor)
-
+
             Vt = ct * self.D + st * self.S
             w_expanded = w.unsqueeze(-1)
             Vt_mean = Vt.mean(dim=1, keepdim=True)
             Vt = (1.0 - w_expanded) * Vt + w_expanded * Vt_mean
-
+
             Vt_exp = Vt[None, :, :, :]
             d_ribbon = torch.norm(nav_x_exp - Vt_exp, dim=-1)
             s_ribbon = -softmin_over_last(d_ribbon, self.config.softmin_tau)
             phase_scores.append(s_ribbon)
-
+
         s_ribbon = torch.stack(phase_scores).mean(dim=0)
         scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp
-
+
         diagnostics = {
             'dispatcher_scores': s_disp.detach(),
             'ribbon_scores': s_ribbon.detach()
         }
-
+
         return {'scores': scores, 'diagnostics': diagnostics}
 
 class GeometricAttention(nn.Module):
     """Multi-head geometric attention with Q-K alignment."""
-
+
     def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                  config: Optional[GeometricConfig] = None, dropout: float = 0.0):
         super().__init__()
         self.dim = dim
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
-
+
         if num_regions is None:
             num_regions = min(self.head_dim, 16)
         if config is None:
             config = GeometricConfig()
-
+
         self.config = config
         self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
-
+
         self.q_navigators = nn.ModuleList([
             GeometricNavigator(self.head_dim, num_regions, config)
             for _ in range(num_heads)
@@ -177,53 +178,53 @@ class GeometricAttention(nn.Module):
             GeometricNavigator(self.head_dim, num_regions, config)
             for _ in range(num_heads)
         ])
-
+
         self.out_proj = nn.Linear(dim, dim)
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                 return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
         B, T, D = x.shape
-
+
         qkv = self.to_qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
-
+
         q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
-
+
         outputs = []
         all_diagnostics = [] if return_diagnostics else None
-
+
         for h in range(self.num_heads):
             q_h_flat = q[:, h].reshape(B * T, self.head_dim)
             k_h_flat = k[:, h].reshape(B * T, self.head_dim)
-
+
             q_nav = self.q_navigators[h].navigate(q_h_flat)
             k_nav = self.k_navigators[h].navigate(k_h_flat)
-
+
             q_scores = q_nav['scores'].reshape(B, T, -1)
             k_scores = k_nav['scores'].reshape(B, T, -1)
-
+
             attn = torch.bmm(q_scores, k_scores.transpose(1, 2))
             attn = attn / math.sqrt(q_scores.size(-1))
-
+
             if mask is not None:
                 attn = attn.masked_fill(mask.unsqueeze(1) == 0, -1e9)
-
+
             attn = F.softmax(attn, dim=-1)
             attn = self.dropout(attn)
-
+
             out = torch.bmm(attn, v[:, h])
             outputs.append(out)
-
+
             if return_diagnostics:
                 all_diagnostics.append({'q': q_nav['diagnostics'], 'k': k_nav['diagnostics']})
-
+
         output = torch.stack(outputs, dim=1).transpose(1, 2).reshape(B, T, D)
         output = self.out_proj(output)
         output = self.dropout(output)
-
+
         if return_diagnostics:
             return output, {'head_diagnostics': all_diagnostics}
         return output, None
@@ -249,103 +250,156 @@ class DropPath(nn.Module):
         return output
 
 # ============================================
-# HIERARCHICAL CLS WITH PENTACHORA
+# HIERARCHICAL CLS WITH PENTACHORA (FIXED)
 # ============================================
 
 class HierarchicalPentachoronCLS(nn.Module):
     """
     Hierarchical CLS structure with pentachoron geometry.
-    Creates global, vertex-level, and class-specific representations.
+    FIXED: Now properly uses vocabulary embeddings for CLS tokens.
     """
     def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
         super().__init__()
         self.dim = dim  # Model's internal dimension
         self.vocab_dim = vocab_dim  # Vocabulary's dimension
         self.num_classes = num_classes
-
-        # Hierarchical CLS tokens (in model dimension)
-        self.global_cls = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
-        self.vertex_cls = nn.Parameter(torch.randn(1, 5, dim) * 0.02)
-
-        # Class-specific pentachora (in vocabulary dimension)
+
+        # Class-specific pentachora from vocabulary (in vocabulary dimension)
         self.class_pentachora = nn.Parameter(torch.randn(num_classes, 5, vocab_dim) * 0.02)
 
-        # Projection layer to align vocab_dim with model dim if they differ
+        # Projection from vocabulary dimension to model dimension
        if vocab_dim != dim:
-            self.vocab_projection = nn.Linear(vocab_dim, dim)
+            self.vocab_to_model = nn.Linear(vocab_dim, dim)
        else:
-            self.vocab_projection = nn.Identity()
+            self.vocab_to_model = nn.Identity()
 
-        # Aggregation layers
-        self.vertex_to_global = nn.Linear(dim * 5, dim)
-        self.norm = nn.LayerNorm(dim)
-
-    def forward(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Generate CLS tokens for batch."""
-        global_cls = self.global_cls.expand(batch_size, -1, -1)
-        vertex_cls = self.vertex_cls.expand(batch_size, -1, -1)
+        # Learnable aggregation weights for creating global CLS from vertices
+        self.vertex_weights = nn.Parameter(torch.ones(5) / 5)
+
+        # Optional learnable offset for global CLS
+        self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))
+
+        # Layer norms
+        self.vertex_norm = nn.LayerNorm(dim)
+        self.global_norm = nn.LayerNorm(dim)
+
+    def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Generate CLS tokens for batch.
+
+        Args:
+            batch_size: Batch size
+            class_indices: Optional class indices for class-specific initialization
+
+        Returns:
+            global_cls: [B, 1, D] - Global CLS tokens
+            vertex_cls: [B, 5, D] - Vertex CLS tokens
+        """
+        if class_indices is not None and class_indices.shape[0] == batch_size:
+            # Use class-specific pentachora when class indices are provided
+            # This would typically be used during training with labels
+            vertex_cls_vocab = self.class_pentachora[class_indices]  # [B, 5, vocab_dim]
+        else:
+            # Use mean of all class pentachora when no specific classes provided
+            # This is used during inference or when class is unknown
+            vertex_cls_vocab = self.class_pentachora.mean(dim=0, keepdim=True)  # [1, 5, vocab_dim]
+            vertex_cls_vocab = vertex_cls_vocab.expand(batch_size, -1, -1)  # [B, 5, vocab_dim]
+
+        # Project from vocabulary dimension to model dimension
+        vertex_cls = self.vocab_to_model(vertex_cls_vocab)  # [B, 5, dim]
+        vertex_cls = self.vertex_norm(vertex_cls)
+
+        # Create global CLS as weighted combination of vertices
+        weights = F.softmax(self.vertex_weights, dim=0)
+        global_cls = torch.einsum('bvd,v->bd', vertex_cls, weights).unsqueeze(1)  # [B, 1, dim]
+        global_cls = global_cls + self.global_offset
+        global_cls = self.global_norm(global_cls)
         return global_cls, vertex_cls
-
-    def aggregate_vertices(self, vertex_cls: torch.Tensor) -> torch.Tensor:
-        """Aggregate vertex representations to global."""
-        B = vertex_cls.shape[0]
-        flattened = vertex_cls.reshape(B, -1)
-        aggregated = self.vertex_to_global(flattened).unsqueeze(1)
-        return self.norm(aggregated)
+
+    def get_class_prototypes(self) -> torch.Tensor:
+        """
+        Get class prototypes in model dimension.
+
+        Returns:
+            prototypes: [num_classes, dim] - Class prototype vectors
+        """
+        # Project class pentachora to model dimension
+        pentachora_model = self.vocab_to_model(self.class_pentachora)  # [C, 5, dim]
+
+        # Aggregate vertices to get class prototypes
+        weights = F.softmax(self.vertex_weights, dim=0)
+        prototypes = torch.einsum('cvd,v->cd', pentachora_model, weights)  # [C, dim]
+
+        return prototypes
 
 # ============================================
-# GEOMETRIC PROJECTION LAYER
+# GEOMETRIC PROJECTION LAYER (ENHANCED)
 # ============================================
 
 class GeometricProjection(nn.Module):
-    """Project patches onto pentachoron geometry."""
+    """
+    Project patches onto pentachoron geometry.
+    ENHANCED: Now provides better integration with vocabulary.
+    """
     def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
         super().__init__()
         self.dim = dim  # Model dimension
         self.vocab_dim = vocab_dim  # Vocabulary dimension
         self.num_classes = num_classes
+
+        # Projection from model dim to vocab dim for alignment
+        self.to_vocab_space = nn.Linear(dim, vocab_dim)
 
-        # Separate projection for each vertex (project from model dim to vocab dim for alignment)
+        # Vertex-specific projections for fine-grained alignment
         self.vertex_projections = nn.ModuleList([
-            nn.Linear(dim, vocab_dim, bias=False) for _ in range(5)
+            nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
         ])
 
+        # Temperature for alignment scores
+        self.temperature = nn.Parameter(torch.ones(1))
+
         self.norm = nn.LayerNorm(dim)
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
         """
         Compute alignment between patches and class pentachora.
-
+
         Args:
-            patches: [B, N, D] - patch embeddings
-            pentachora: [C, 5, vocab_dim] - class pentachora
-
+            patches: [B, N, D] - patch embeddings in model dimension
+            pentachora: [C, 5, vocab_dim] - class pentachora in vocabulary dimension
+
         Returns:
             [B, N, C] - alignment scores
         """
         B, N, D = patches.shape
         C = pentachora.shape[0]
-
+
+        # Normalize patches
         patches = self.norm(patches)
 
+        # Project patches to vocabulary space
+        patches_vocab = self.to_vocab_space(patches)  # [B, N, vocab_dim]
+        patches_vocab = F.normalize(patches_vocab, dim=-1)
+
         # Compute alignment with each vertex
         alignments = []
         for v in range(5):
-            # Project patches through vertex-specific projection
-            patches_proj = self.vertex_projections[v](patches)
-            patches_proj = F.normalize(patches_proj, dim=-1)
+            # Apply vertex-specific transformation
+            patches_v = self.vertex_projections[v](patches_vocab)
+            patches_v = F.normalize(patches_v, dim=-1)
 
             # Get vertex v of all classes
-            vertex_v = F.normalize(pentachora[:, v, :], dim=-1)
+            vertex_v = F.normalize(pentachora[:, v, :], dim=-1)  # [C, vocab_dim]
 
             # Compute alignment scores
-            alignment = torch.matmul(patches_proj, vertex_v.T)
+            alignment = torch.matmul(patches_v, vertex_v.T) / self.temperature  # [B, N, C]
             alignments.append(alignment)
-
+
         # Average alignments across vertices
-        alignments = torch.stack(alignments, dim=-1).mean(dim=-1)
-
+        alignments = torch.stack(alignments, dim=-1).mean(dim=-1)  # [B, N, C]
+
         return self.dropout(alignments)
 
 # ============================================
@@ -359,13 +413,13 @@ class MLP(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = nn.GELU()
         self.drop1 = nn.Dropout(dropout)
         self.fc2 = nn.Linear(hidden_features, out_features)
         self.drop2 = nn.Dropout(dropout)
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.act(x)
@@ -385,7 +439,7 @@ class PentachoronViTBlock(nn.Module):
                  drop_path: float = 0.):
         super().__init__()
         self.norm1 = nn.LayerNorm(dim)
-
+
         # Use GeometricAttention for structured layers, standard for others
         if use_mesh:
             self.attn = GeometricAttention(
@@ -398,15 +452,15 @@
         else:
             # Standard multi-head attention for later layers
             self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)
-
+
         self.use_mesh = use_mesh
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
+
         self.norm2 = nn.LayerNorm(dim)
         mlp_hidden = int(dim * mlp_ratio)
         self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden, dropout=dropout)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
+
     def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
         if self.use_mesh:
             # GeometricAttention
@@ -417,7 +471,7 @@
             normalized = self.norm1(x)
             attn_out, _ = self.attn(normalized, normalized, normalized)
             x = x + self.drop_path1(attn_out)
-
+
         x = x + self.drop_path2(self.mlp(self.norm2(x)))
         return x
 
@@ -433,10 +487,10 @@ class PatchEmbed(nn.Module):
         self.img_size = img_size
         self.patch_size = patch_size
         self.num_patches = (img_size // patch_size) ** 2
-
+
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
         self.norm = nn.LayerNorm(embed_dim)
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         x = rearrange(x, 'b c h w -> b (h w) c')
@@ -444,60 +498,61 @@
         return x
 
 # ============================================
-# PENTACHORA VISION TRANSFORMER
+# PENTACHORA VISION TRANSFORMER (FIXED)
 # ============================================
 
 class PentachoraViT(nn.Module):
     """
     Vision Transformer with pentachoron-based hierarchical CLS tokens
     and geometric vocabulary integration.
+    FIXED: CLS tokens now properly reference vocabulary embeddings.
     """
     def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
         super().__init__()
-
+
         # Use config or kwargs
         if config is not None:
             cfg = config
         else:
            cfg = PentachoraConfig(**kwargs)
-
+
        self.config = cfg
        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer
-
-        # Set vocabulary dimension - from config, kwargs, or default to model dim
+
+        # Set vocabulary dimension
        if cfg.vocab_dim is not None:
            self.vocab_dim = cfg.vocab_dim
        elif 'vocab_dim' in kwargs:
            self.vocab_dim = kwargs['vocab_dim']
        else:
            self.vocab_dim = cfg.dim
-
+
        # Patch embedding
        self.patch_embed = PatchEmbed(
            cfg.img_size, cfg.patch_size, 3, cfg.dim
        )
        num_patches = self.patch_embed.num_patches
-
+
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)
-
+
        # CLS tokens with pentachoron structure
        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)
-
-        # Geometric projection layer - CREATE BEFORE vocab init
+
+        # Geometric projection layer
        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)
-
-        # Initialize from vocabulary AFTER creating all components
+
+        # Initialize from vocabulary if provided
        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)
-
+
        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]
-
+
        # Transformer blocks with geometric attention
        self.blocks = nn.ModuleList([
            PentachoronViTBlock(
@@ -511,17 +566,26 @@ class PentachoraViT(nn.Module):
            )
            for i in range(cfg.depth)
        ])
-
+
        # Final norm
        self.norm = nn.LayerNorm(cfg.dim)
-
+
        # Classification heads
-        self.head = nn.Linear(cfg.dim, cfg.num_classes)
-        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)
+        # Primary head uses prototypes for classification
+        self.use_prototype_classifier = True
+        if self.use_prototype_classifier:
+            # No learnable parameters - uses class prototypes directly
+            self.head = None
+        else:
+            # Traditional linear head
+            self.head = nn.Linear(cfg.dim, cfg.num_classes)
 
+        # Auxiliary head for vertex tokens
+        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)
+
        # Initialize weights
        self.apply(self._init_weights)
-
+
    def _init_weights(self, m: nn.Module):
        """Initialize model weights."""
        if isinstance(m, nn.Linear):
@@ -535,79 +599,81 @@ class PentachoraViT(nn.Module):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
-
+
    def _init_from_vocab(self, vocab):
        """Initialize class pentachora from geometric vocabulary."""
        try:
            print("Initializing pentachora from vocabulary...")
-
+
            if not hasattr(vocab, 'encode_batch'):
                print("Vocabulary provided but encode_batch method not found, using random initialization")
                return
-
+
            # Get CIFAR-100 class names
            class_names = self._get_cifar100_classes()
-
+
            # Generate pentachora for all classes
            pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
            pentachora = np.stack(pentachora_list, axis=0)
-
+
            # Get actual dimensions from the encoded data
            actual_vocab_dim = pentachora.shape[-1]
-
+
            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")
-
+
            # Validate basic shape requirements
            if pentachora.shape[0] != self.num_classes or pentachora.shape[1] != 5:
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return
-
-            # Update all components to use the actual vocabulary dimension
+
+            # Update vocabulary dimension
            self.vocab_dim = actual_vocab_dim
            self.cls_tokens.vocab_dim = actual_vocab_dim
            self.geometric_proj.vocab_dim = actual_vocab_dim
-
+
            # Replace class_pentachora with the loaded vocabulary
            self.cls_tokens.class_pentachora = nn.Parameter(
                torch.tensor(pentachora, dtype=torch.float32)
            )
-
+
            # Update/create projection layer if dimensions differ
            if actual_vocab_dim != self.dim:
-                self.cls_tokens.vocab_projection = nn.Linear(actual_vocab_dim, self.dim)
+                self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
            else:
-                self.cls_tokens.vocab_projection = nn.Identity()
-
-            # Rebuild geometric projection layers with correct dimensions
+                self.cls_tokens.vocab_to_model = nn.Identity()
+
+            # Rebuild geometric projection components
+            self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
            self.geometric_proj.vertex_projections = nn.ModuleList([
-                nn.Linear(self.dim, actual_vocab_dim, bias=False) for _ in range(5)
+                nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
            ])
-
+
            # Re-initialize the new layers
+            nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
            for proj in self.geometric_proj.vertex_projections:
                nn.init.xavier_uniform_(proj.weight)
            if actual_vocab_dim != self.dim:
-                nn.init.xavier_uniform_(self.cls_tokens.vocab_projection.weight)
+                nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)
-
+
            print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
            print(f"  Vocabulary dimension: {actual_vocab_dim}")
            print(f"  Model internal dimension: {self.dim}")
-            print(f"  Projection: {actual_vocab_dim} {self.dim}")
+            print(f"  CLS tokens now reference vocabulary embeddings")
-
+
        except Exception as e:
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")
-
+
    def _get_cifar100_classes(self):
        """Get CIFAR-100 class names."""
        return [
-            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
-            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
-            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
-            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
-            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
+            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
+            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
+            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
+            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
+            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
@@ -618,59 +684,86 @@ class PentachoraViT(nn.Module):
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]
-
-    def forward_features(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Extract features from input."""
-        B = x.shape[0]
+
+    def forward_features(self, x: torch.Tensor, class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        """
+        Extract features from input.
 
+        Args:
+            x: Input images [B, 3, H, W]
+            class_indices: Optional class indices for class-aware CLS tokens [B]
+        """
+        B = x.shape[0]
+
        # Patch embedding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
-
-        # Get hierarchical CLS tokens
-        global_cls, vertex_cls = self.cls_tokens(B)
-
+
+        # Get hierarchical CLS tokens (now properly using vocabulary)
+        global_cls, vertex_cls = self.cls_tokens(B, class_indices)
+
        # Concatenate CLS tokens with patches
        x = torch.cat([global_cls, vertex_cls, x], dim=1)
-
+
        # Apply transformer blocks
        for i, block in enumerate(self.blocks):
            preserve = i < self.preserve_structure_until_layer
            x = block(x, preserve_structure=preserve)
-
+
        # Apply final norm
        x = self.norm(x)
-
+
        # Split tokens
        global_cls = x[:, 0]
        vertex_cls = x[:, 1:6]
        patches = x[:, 6:]
-
+
        return {
            'global_cls': global_cls,
            'vertex_cls': vertex_cls,
            'patches': patches
        }
-
-    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Forward pass through the model."""
-        features = self.forward_features(x)
+
+    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass through the model.
 
-        # Primary classification using global CLS
-        logits = self.head(features['global_cls'])
+        Args:
+            x: Input images [B, 3, H, W]
+            targets: Optional target labels for class-aware processing [B]
+        """
+        # During training, use target labels for class-specific CLS initialization
+        class_indices = targets if self.training and targets is not None else None
 
+        features = self.forward_features(x, class_indices)
+
+        # Primary classification using prototype matching
+        if self.use_prototype_classifier:
+            # Get class prototypes from vocabulary
+            prototypes = self.cls_tokens.get_class_prototypes()  # [C, D]
+            prototypes = F.normalize(prototypes, dim=-1)
+
+            # Normalize global CLS tokens
+            global_cls_norm = F.normalize(features['global_cls'], dim=-1)  # [B, D]
+
+            # Compute similarity to prototypes
+            logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0  # Scale for better gradients
+        else:
+            # Traditional linear classification
+            logits = self.head(features['global_cls'])
+
        # Auxiliary classification using vertex tokens
        B = features['vertex_cls'].shape[0]
        vertex_flat = features['vertex_cls'].reshape(B, -1)
        aux_logits = self.head_aux(vertex_flat)
-
+
        # Geometric alignment scores
        geometric_alignments = self.geometric_proj(
-            features['patches'],
+            features['patches'],
            self.cls_tokens.class_pentachora
        )
-
+
        return {
            'logits': logits,
            'aux_logits': aux_logits,
@@ -692,24 +785,24 @@ class PentachoraLoss(nn.Module):
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)
-
+
    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Compute combined loss."""
        # Primary classification loss
        loss = self.criterion(outputs['logits'], targets)
-
+
        # Auxiliary loss from vertex tokens
        if 'aux_logits' in outputs and self.aux_weight > 0:
            aux_loss = self.criterion(outputs['aux_logits'], targets)
            loss = loss + self.aux_weight * aux_loss
-
+
        # Geometric alignment loss
        if 'geometric_alignments' in outputs and self.geo_weight > 0:
            # Average over patches
            geo_logits = outputs['geometric_alignments'].mean(dim=1)
            geo_loss = self.criterion(geo_logits, targets)
            loss = loss + self.geo_weight * geo_loss
-
+
        return loss
 
 # ============================================
@@ -718,8 +811,8 @@ class PentachoraLoss(nn.Module):
 
 MODEL_CONFIGS = {
     'pentachora_spark': PentachoraConfig(
-        dim=64, depth=5, heads=4, mlp_ratio=4.0,
-        preserve_structure_until_layer=2,
+        dim=100, depth=5, heads=4, mlp_ratio=4.0,
+        preserve_structure_until_layer=1,
         dropout_rate=0.0, drop_path_rate=0.0
     ),
     'pentachora_tiny': PentachoraConfig(
@@ -749,31 +842,32 @@ def create_pentachora_vit(variant: str = 'pentachora_small',
                          **kwargs) -> PentachoraViT:
     """
     Create PentachoraViT model.
-
+
     Args:
         variant: Model variant name
         pretrained: Whether to load pretrained weights
         **kwargs: Override config parameters (including vocab_dim)
-
+
     Returns:
         PentachoraViT model
     """
     if variant not in MODEL_CONFIGS:
         raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")
-
+
     config = MODEL_CONFIGS[variant]
-
+
     # Override config with kwargs
     for key, value in kwargs.items():
         setattr(config, key, value)
-
+
     model = PentachoraViT(config)
-
+
     if pretrained:
         warnings.warn("Pretrained weights not available yet")
-
+
     return model
 
+# Convenience functions for each variant
 def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
     """Create spark variant (smallest)."""
     return create_pentachora_vit('pentachora_spark', pretrained=pretrained, **kwargs)
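
Usage sketch (not part of the commit): a minimal example of how the updated entry points compose, assuming this file is saved as model.py, that torch is installed, that inputs are 32x32 CIFAR-100-style images, and that PentachoraLoss accepts aux_weight/geo_weight keywords matching its attributes; treat it as an illustration of the new forward(x, targets) signature, not a verified script.

import torch
from model import create_pentachora_vit, PentachoraLoss  # assumes this file is importable as model.py

model = create_pentachora_vit('pentachora_spark')            # dim=100, depth=5 per MODEL_CONFIGS above
criterion = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)   # assumed kwargs mirroring the config weights

images = torch.randn(8, 3, 32, 32)                           # assumed CIFAR-100-sized inputs
targets = torch.randint(0, 100, (8,))

# Training: passing targets lets forward() select class-specific pentachora for the CLS tokens
model.train()
outputs = model(images, targets)                             # dict with 'logits', 'aux_logits', 'geometric_alignments'
loss = criterion(outputs, targets)
loss.backward()

# Inference: without targets the CLS tokens fall back to the mean pentachoron
model.eval()
with torch.no_grad():
    preds = model(images)['logits'].argmax(dim=-1)

With the prototype classifier path enabled, 'logits' come from cosine similarity between the global CLS token and the vocabulary-derived class prototypes (scaled by 20.0), so no separate linear head is trained for the primary output.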