seawolf2357 commited on
Commit
28f2970
·
verified ·
1 Parent(s): ae03ea7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +984 -0
app.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🔮 PHOENIX Retention Research Platform
3
+ Complete Integration - Single File
4
+
5
+ L40S GPU + Persistent Storage (SQLite + ChromaDB)
6
+ VIDraft AI Research Lab
7
+ """
8
+
9
+ import gradio as gr
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import sqlite3
14
+ import json
15
+ import time
16
+ import numpy as np
17
+ from datetime import datetime
18
+ from pathlib import Path
19
+ import plotly.graph_objects as go
20
+ import plotly.express as px
21
+ import pandas as pd
22
+ from typing import Dict, List, Any, Tuple, Optional
23
+ import chromadb
24
+ from chromadb.config import Settings
25
+ from einops import rearrange, repeat
26
+
27
+ # =====================================================
28
+ # 전역 설정
29
+ # =====================================================
30
+
31
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ STORAGE_PATH = "/data" # HF Spaces 영구 스토리지
33
+ DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
34
+ VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
35
+
36
+ # 디렉토리 생성
37
+ Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
38
+ Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
39
+
40
+ print(f"🚀 PHOENIX Platform initialized on {DEVICE}")
41
+ print(f"💾 Storage: {STORAGE_PATH}")
42
+
43
+ # =====================================================
44
+ # 데이터베이스 관리 클래스
45
+ # =====================================================
46
+
47
+ class ExperimentDatabase:
48
+ """SQLite 데이터베이스 관리"""
49
+
50
+ def __init__(self, db_path: str):
51
+ self.db_path = db_path
52
+ self.init_database()
53
+
54
+ def init_database(self):
55
+ """데이터베이스 초기화"""
56
+ with sqlite3.connect(self.db_path) as conn:
57
+ cursor = conn.cursor()
58
+
59
+ # 실험 테이블
60
+ cursor.execute("""
61
+ CREATE TABLE IF NOT EXISTS experiments (
62
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
63
+ model_type TEXT NOT NULL,
64
+ sequence_length INTEGER,
65
+ power_mode TEXT,
66
+ compression_level REAL,
67
+ use_hierarchical BOOLEAN,
68
+ elapsed_time REAL,
69
+ memory_mb REAL,
70
+ throughput REAL,
71
+ avg_retention REAL,
72
+ compression_ratio REAL,
73
+ config_json TEXT,
74
+ metrics_json TEXT,
75
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
76
+ )
77
+ """)
78
+
79
+ # 인덱스 생성
80
+ cursor.execute("""
81
+ CREATE INDEX IF NOT EXISTS idx_model_type
82
+ ON experiments(model_type)
83
+ """)
84
+
85
+ cursor.execute("""
86
+ CREATE INDEX IF NOT EXISTS idx_timestamp
87
+ ON experiments(timestamp DESC)
88
+ """)
89
+
90
+ conn.commit()
91
+ print("✅ Database initialized")
92
+
93
+ def save_experiment(self, config: Dict, metrics: Dict) -> int:
94
+ """실험 저장"""
95
+ with sqlite3.connect(self.db_path) as conn:
96
+ cursor = conn.cursor()
97
+
98
+ cursor.execute("""
99
+ INSERT INTO experiments (
100
+ model_type, sequence_length, power_mode,
101
+ compression_level, use_hierarchical, elapsed_time,
102
+ memory_mb, throughput, avg_retention, compression_ratio,
103
+ config_json, metrics_json
104
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
105
+ """, (
106
+ config.get('model_type'),
107
+ config.get('sequence_length'),
108
+ config.get('power_mode'),
109
+ config.get('compression_level'),
110
+ config.get('use_hierarchical'),
111
+ metrics.get('elapsed_time'),
112
+ metrics.get('memory_mb'),
113
+ metrics.get('throughput'),
114
+ metrics.get('avg_retention'),
115
+ metrics.get('compression_ratio'),
116
+ json.dumps(config),
117
+ json.dumps(metrics)
118
+ ))
119
+
120
+ conn.commit()
121
+ return cursor.lastrowid
122
+
123
+ def get_experiment(self, exp_id: int) -> Optional[Dict]:
124
+ """실험 조회"""
125
+ with sqlite3.connect(self.db_path) as conn:
126
+ conn.row_factory = sqlite3.Row
127
+ cursor = conn.cursor()
128
+
129
+ cursor.execute("SELECT * FROM experiments WHERE id = ?", (exp_id,))
130
+ row = cursor.fetchone()
131
+ return dict(row) if row else None
132
+
133
+ def get_recent_experiments(self, limit: int = 20) -> List[Dict]:
134
+ """최근 실험 조회"""
135
+ with sqlite3.connect(self.db_path) as conn:
136
+ conn.row_factory = sqlite3.Row
137
+ cursor = conn.cursor()
138
+
139
+ cursor.execute("""
140
+ SELECT * FROM experiments
141
+ ORDER BY timestamp DESC
142
+ LIMIT ?
143
+ """, (limit,))
144
+
145
+ rows = cursor.fetchall()
146
+ return [dict(row) for row in rows]
147
+
148
+ def get_statistics(self) -> Dict:
149
+ """통계 조회"""
150
+ with sqlite3.connect(self.db_path) as conn:
151
+ cursor = conn.cursor()
152
+
153
+ cursor.execute("SELECT COUNT(*) FROM experiments")
154
+ total = cursor.fetchone()[0]
155
+
156
+ cursor.execute("""
157
+ SELECT model_type, COUNT(*) as count
158
+ FROM experiments
159
+ GROUP BY model_type
160
+ """)
161
+ by_model = dict(cursor.fetchall())
162
+
163
+ return {
164
+ 'total_experiments': total,
165
+ 'by_model': by_model
166
+ }
167
+
168
+ class RetentionVectorStore:
169
+ """ChromaDB 벡터 저장소"""
170
+
171
+ def __init__(self, persist_directory: str):
172
+ self.client = chromadb.Client(Settings(
173
+ persist_directory=persist_directory,
174
+ anonymized_telemetry=False
175
+ ))
176
+
177
+ self.collection = self.client.get_or_create_collection(
178
+ name="retention_states",
179
+ metadata={"description": "PHOENIX Retention states"}
180
+ )
181
+ print("✅ Vector store initialized")
182
+
183
+ def add_retention_state(self, experiment_id: int, states: Dict, metadata: Dict):
184
+ """Retention state 저장"""
185
+ # State를 벡터로 변환
186
+ state_vector = self._states_to_vector(states)
187
+
188
+ self.collection.add(
189
+ embeddings=[state_vector.tolist()],
190
+ metadatas=[{**metadata, 'experiment_id': experiment_id}],
191
+ ids=[f"exp_{experiment_id}"]
192
+ )
193
+
194
+ def search(self, query: str, top_k: int = 10) -> List[Dict]:
195
+ """실험 검색"""
196
+ query_vector = self._text_to_vector(query)
197
+
198
+ results = self.collection.query(
199
+ query_embeddings=[query_vector.tolist()],
200
+ n_results=top_k
201
+ )
202
+
203
+ if not results['ids'][0]:
204
+ return []
205
+
206
+ formatted_results = []
207
+ for i in range(len(results['ids'][0])):
208
+ formatted_results.append({
209
+ 'experiment_id': results['metadatas'][0][i].get('experiment_id'),
210
+ 'score': 1.0 - results['distances'][0][i],
211
+ 'metadata': results['metadatas'][0][i]
212
+ })
213
+
214
+ return formatted_results
215
+
216
+ def _states_to_vector(self, states: Dict) -> np.ndarray:
217
+ """States를 고정 크기 벡터로 변환"""
218
+ vectors = []
219
+ for key, value in states.items():
220
+ if isinstance(value, (int, float)):
221
+ vectors.append(float(value))
222
+ elif isinstance(value, torch.Tensor):
223
+ vectors.append(value.mean().item())
224
+ vectors.append(value.std().item())
225
+
226
+ # 고정 크기로 패딩/자르기
227
+ target_size = 128
228
+ if len(vectors) < target_size:
229
+ vectors.extend([0.0] * (target_size - len(vectors)))
230
+ else:
231
+ vectors = vectors[:target_size]
232
+
233
+ return np.array(vectors)
234
+
235
+ def _text_to_vector(self, text: str) -> np.ndarray:
236
+ """텍스트를 벡터로 변환 (간단한 해시 기반)"""
237
+ # 실제로는 sentence-transformers 사용 권장
238
+ hash_val = hash(text) % (2**31)
239
+ np.random.seed(hash_val)
240
+ return np.random.randn(128)
241
+
242
+ # =====================================================
243
+ # PHOENIX Retention 모델 구현
244
+ # =====================================================
245
+
246
+ class HierarchicalRetention(nn.Module):
247
+ """계층적 Retention (단기/중기/장기)"""
248
+
249
+ def __init__(self, d_model, d_state):
250
+ super().__init__()
251
+ self.d_model = d_model
252
+ self.d_state = d_state
253
+
254
+ # 3-tier states
255
+ self.short_decay = 0.5
256
+ self.medium_decay = 0.8
257
+ self.long_decay = 0.95
258
+
259
+ # Projection layers
260
+ self.proj_short = nn.Linear(d_model, d_state)
261
+ self.proj_medium = nn.Linear(d_state, d_state)
262
+ self.proj_long = nn.Linear(d_state, d_state * 2)
263
+
264
+ # Fusion
265
+ self.fusion = nn.Linear(d_state * 4, d_model)
266
+
267
+ def forward(self, x):
268
+ batch_size, seq_len, _ = x.shape
269
+
270
+ # Initialize states
271
+ short_state = torch.zeros(batch_size, self.d_state).to(x.device)
272
+ medium_state = torch.zeros(batch_size, self.d_state).to(x.device)
273
+ long_state = torch.zeros(batch_size, self.d_state * 2).to(x.device)
274
+
275
+ outputs = []
276
+
277
+ for t in range(seq_len):
278
+ x_t = x[:, t, :]
279
+
280
+ # Short-term update (every token)
281
+ short_input = self.proj_short(x_t)
282
+ short_state = self.short_decay * short_state + short_input
283
+
284
+ # Medium-term update (every 8 tokens)
285
+ if t % 8 == 0:
286
+ medium_state = self.medium_decay * medium_state + self.proj_medium(short_state)
287
+
288
+ # Long-term update (every 64 tokens)
289
+ if t % 64 == 0:
290
+ long_state = self.long_decay * long_state + self.proj_long(medium_state)
291
+
292
+ # Fuse all tiers
293
+ combined = torch.cat([short_state, medium_state, long_state], dim=-1)
294
+ output_t = self.fusion(combined)
295
+ outputs.append(output_t)
296
+
297
+ outputs = torch.stack(outputs, dim=1)
298
+
299
+ return outputs, {
300
+ 'short_state': short_state,
301
+ 'medium_state': medium_state,
302
+ 'long_state': long_state
303
+ }
304
+
305
+ class AdaptiveCompression(nn.Module):
306
+ """적응적 압축"""
307
+
308
+ def __init__(self, d_state):
309
+ super().__init__()
310
+ self.importance_net = nn.Linear(d_state, 1)
311
+ self.compressor = nn.Sequential(
312
+ nn.Linear(d_state, d_state // 2),
313
+ nn.GELU(),
314
+ nn.Linear(d_state // 2, d_state)
315
+ )
316
+
317
+ def forward(self, state, importance_threshold=0.5):
318
+ importance = torch.sigmoid(self.importance_net(state))
319
+
320
+ # 중요도에 따라 압축
321
+ mask = (importance > importance_threshold).float()
322
+ compressed = state * mask + self.compressor(state) * (1 - mask)
323
+
324
+ return compressed, importance.mean().item()
325
+
326
+ class DynamicPowerRetention(nn.Module):
327
+ """동적 Power 조절"""
328
+
329
+ def __init__(self, d_model):
330
+ super().__init__()
331
+ self.power_predictor = nn.Sequential(
332
+ nn.Linear(d_model, 64),
333
+ nn.ReLU(),
334
+ nn.Linear(64, 1),
335
+ nn.Sigmoid()
336
+ )
337
+
338
+ self.min_power = 1.5
339
+ self.max_power = 5.0
340
+
341
+ def compute_power(self, x):
342
+ power_ratio = self.power_predictor(x.mean(dim=1, keepdim=True))
343
+ power = self.min_power + power_ratio * (self.max_power - self.min_power)
344
+ return power.mean().item()
345
+
346
+ class PHOENIXRetention(nn.Module):
347
+ """PHOENIX Retention 통합 모델"""
348
+
349
+ def __init__(self, d_model=512, d_state=256, num_layers=12, device='cuda'):
350
+ super().__init__()
351
+ self.d_model = d_model
352
+ self.d_state = d_state
353
+ self.num_layers = num_layers
354
+ self.device = device
355
+
356
+ # Core components
357
+ self.hierarchical = HierarchicalRetention(d_model, d_state)
358
+ self.compressor = AdaptiveCompression(d_state)
359
+ self.power_adapter = DynamicPowerRetention(d_model)
360
+
361
+ # Layer norm
362
+ self.norm = nn.LayerNorm(d_model)
363
+
364
+ self.to(device)
365
+
366
+ def forward(self, x, return_states=True):
367
+ # Hierarchical retention
368
+ h_out, states = self.hierarchical(x)
369
+
370
+ # Adaptive compression
371
+ compressed_state = states['short_state']
372
+ compressed, compression_ratio = self.compressor(compressed_state)
373
+
374
+ # Dynamic power
375
+ power = self.power_adapter.compute_power(x)
376
+
377
+ # Normalize output
378
+ output = self.norm(h_out)
379
+
380
+ if return_states:
381
+ return output, {
382
+ 'short_state': states['short_state'],
383
+ 'medium_state': states['medium_state'],
384
+ 'long_state': states['long_state'],
385
+ 'compression_ratio': compression_ratio,
386
+ 'dynamic_power': power
387
+ }
388
+ return output
389
+
390
+ class BrumbyRetention(nn.Module):
391
+ """Brumby 베이스라인"""
392
+
393
+ def __init__(self, d_model=512, d_state=256, power=2, device='cuda'):
394
+ super().__init__()
395
+ self.d_model = d_model
396
+ self.d_state = d_state
397
+ self.power = power
398
+ self.device = device
399
+
400
+ self.proj_q = nn.Linear(d_model, d_state)
401
+ self.proj_k = nn.Linear(d_model, d_state)
402
+ self.proj_v = nn.Linear(d_model, d_state)
403
+ self.proj_out = nn.Linear(d_state, d_model)
404
+
405
+ self.to(device)
406
+
407
+ def forward(self, x, return_states=True):
408
+ batch_size, seq_len, _ = x.shape
409
+
410
+ Q = self.proj_q(x)
411
+ K = self.proj_k(x)
412
+ V = self.proj_v(x)
413
+
414
+ # Simple retention (simplified)
415
+ state = torch.zeros(batch_size, self.d_state).to(x.device)
416
+ outputs = []
417
+
418
+ for t in range(seq_len):
419
+ state = 0.9 * state + V[:, t, :] @ K[:, t, :].T
420
+ output_t = state @ Q[:, t, :].unsqueeze(-1)
421
+ outputs.append(output_t.squeeze(-1))
422
+
423
+ outputs = torch.stack(outputs, dim=1)
424
+ outputs = self.proj_out(outputs)
425
+
426
+ if return_states:
427
+ return outputs, {
428
+ 'state': state,
429
+ 'power': self.power
430
+ }
431
+ return outputs
432
+
433
+ # =====================================================
434
+ # 유틸리티 함수들
435
+ # =====================================================
436
+
437
+ def calculate_metrics(output, states):
438
+ """메트릭 계산"""
439
+ metrics = {}
440
+
441
+ # 메모리 사용량 (대략적)
442
+ total_params = sum(p.numel() for p in [output] if isinstance(p, torch.Tensor))
443
+ metrics['memory_mb'] = (total_params * 4) / (1024 * 1024) # float32 = 4 bytes
444
+
445
+ # Retention 비율
446
+ if 'short_state' in states:
447
+ metrics['avg_retention'] = states['short_state'].abs().mean().item()
448
+ else:
449
+ metrics['avg_retention'] = 0.5
450
+
451
+ # 압축률
452
+ if 'compression_ratio' in states:
453
+ metrics['compression_ratio'] = states['compression_ratio']
454
+ else:
455
+ metrics['compression_ratio'] = 0.5
456
+
457
+ # State 크기
458
+ if 'short_state' in states:
459
+ metrics['state_size'] = states['short_state'].shape[-1]
460
+ else:
461
+ metrics['state_size'] = 256
462
+
463
+ return metrics
464
+
465
+ def plot_retention_states(states):
466
+ """Retention states 시각화"""
467
+ fig = go.Figure()
468
+
469
+ if 'short_state' in states:
470
+ short = states['short_state'].detach().cpu().numpy().flatten()
471
+ fig.add_trace(go.Scatter(
472
+ y=short[:100],
473
+ mode='lines',
474
+ name='Short-term',
475
+ line=dict(color='red', width=2)
476
+ ))
477
+
478
+ if 'medium_state' in states:
479
+ medium = states['medium_state'].detach().cpu().numpy().flatten()
480
+ fig.add_trace(go.Scatter(
481
+ y=medium[:100],
482
+ mode='lines',
483
+ name='Medium-term',
484
+ line=dict(color='blue', width=2)
485
+ ))
486
+
487
+ if 'long_state' in states:
488
+ long = states['long_state'].detach().cpu().numpy().flatten()
489
+ fig.add_trace(go.Scatter(
490
+ y=long[:100],
491
+ mode='lines',
492
+ name='Long-term',
493
+ line=dict(color='green', width=2)
494
+ ))
495
+
496
+ fig.update_layout(
497
+ title='Retention State Visualization',
498
+ xaxis_title='Dimension',
499
+ yaxis_title='Activation',
500
+ hovermode='x unified',
501
+ template='plotly_white'
502
+ )
503
+
504
+ return fig
505
+
506
+ def plot_memory_usage(metrics):
507
+ """메모리 사용량 시각화"""
508
+ fig = go.Figure(go.Bar(
509
+ x=['Memory (MB)', 'State Size', 'Compression Ratio'],
510
+ y=[
511
+ metrics.get('memory_mb', 0),
512
+ metrics.get('state_size', 0) / 10, # Scale down
513
+ metrics.get('compression_ratio', 0) * 100 # Percentage
514
+ ],
515
+ marker_color=['lightblue', 'lightgreen', 'lightyellow']
516
+ ))
517
+
518
+ fig.update_layout(
519
+ title='Memory & Compression Metrics',
520
+ yaxis_title='Value',
521
+ template='plotly_white'
522
+ )
523
+
524
+ return fig
525
+
526
+ def plot_performance_comparison(df):
527
+ """성능 비교 시각화"""
528
+ fig = go.Figure()
529
+
530
+ # 속도 비교
531
+ fig.add_trace(go.Bar(
532
+ name='Execution Time (s)',
533
+ x=df['model'],
534
+ y=df['time'],
535
+ marker_color='indianred'
536
+ ))
537
+
538
+ # 처리량 비교
539
+ fig.add_trace(go.Bar(
540
+ name='Throughput (tokens/s)',
541
+ x=df['model'],
542
+ y=df['throughput'],
543
+ marker_color='lightsalmon',
544
+ yaxis='y2'
545
+ ))
546
+
547
+ fig.update_layout(
548
+ title='Model Performance Comparison',
549
+ xaxis_title='Model',
550
+ yaxis_title='Time (s)',
551
+ yaxis2=dict(
552
+ title='Throughput',
553
+ overlaying='y',
554
+ side='right'
555
+ ),
556
+ barmode='group',
557
+ template='plotly_white'
558
+ )
559
+
560
+ return fig
561
+
562
+ # =====================================================
563
+ # 모델 초기화
564
+ # =====================================================
565
+
566
+ def initialize_models():
567
+ """모델들 초기화"""
568
+ models = {}
569
+
570
+ try:
571
+ models['phoenix_small'] = PHOENIXRetention(
572
+ d_model=512,
573
+ d_state=256,
574
+ num_layers=12,
575
+ device=DEVICE
576
+ )
577
+
578
+ models['phoenix_medium'] = PHOENIXRetention(
579
+ d_model=1024,
580
+ d_state=512,
581
+ num_layers=24,
582
+ device=DEVICE
583
+ )
584
+
585
+ models['brumby_baseline'] = BrumbyRetention(
586
+ d_model=512,
587
+ d_state=256,
588
+ power=2,
589
+ device=DEVICE
590
+ )
591
+
592
+ print("✅ Models initialized successfully")
593
+ return models
594
+
595
+ except Exception as e:
596
+ print(f"❌ Model initialization failed: {e}")
597
+ return {}
598
+
599
+ # 데이터베이스 및 모델 초기화
600
+ db = ExperimentDatabase(DB_PATH)
601
+ vector_store = RetentionVectorStore(VECTOR_DB_PATH)
602
+ MODELS = initialize_models()
603
+
604
+ # =====================================================
605
+ # Gradio 인터페이스 함수들
606
+ # =====================================================
607
+
608
+ def run_retention_experiment(
609
+ model_type, input_text, sequence_length,
610
+ power_mode, compression_level, use_hierarchical
611
+ ):
612
+ """PHOENIX Retention 실험 실행"""
613
+ try:
614
+ start_time = time.time()
615
+
616
+ if model_type not in MODELS:
617
+ return "❌ 모델을 찾을 수 없습니다.", None, None
618
+
619
+ model = MODELS[model_type]
620
+
621
+ # 실험 설정
622
+ config = {
623
+ 'model_type': model_type,
624
+ 'sequence_length': sequence_length,
625
+ 'power_mode': power_mode,
626
+ 'compression_level': compression_level,
627
+ 'use_hierarchical': use_hierarchical,
628
+ 'timestamp': datetime.now().isoformat()
629
+ }
630
+
631
+ # 더미 입력 생성
632
+ x = torch.randn(1, sequence_length, model.d_model).to(DEVICE)
633
+
634
+ # Forward pass
635
+ with torch.no_grad():
636
+ output, states = model(x, return_states=True)
637
+
638
+ elapsed_time = time.time() - start_time
639
+
640
+ # 메트릭 계산
641
+ metrics = calculate_metrics(output, states)
642
+ metrics['elapsed_time'] = elapsed_time
643
+ metrics['throughput'] = sequence_length / elapsed_time
644
+
645
+ # 데이터베이스에 저장
646
+ experiment_id = db.save_experiment(config, metrics)
647
+
648
+ # 벡터 저장소에 저장
649
+ vector_store.add_retention_state(experiment_id, states, config)
650
+
651
+ # 결과 텍스트
652
+ result_text = f"""
653
+ ## 🎯 실험 결과 (ID: {experiment_id})
654
+
655
+ ### ⚙️ 설정
656
+ - **모델**: {model_type}
657
+ - **시퀀스 길이**: {sequence_length} 토큰
658
+ - **Power 모드**: {power_mode}
659
+ - **압축 레벨**: {compression_level}
660
+ - **계층적 사용**: {"✅" if use_hierarchical else "❌"}
661
+
662
+ ### 📊 성능 메트릭
663
+ - **실행 시간**: {elapsed_time:.3f}초
664
+ - **처리 속도**: {metrics['throughput']:.1f} 토큰/초
665
+ - **메모리 사용**: {metrics['memory_mb']:.1f} MB
666
+ - **State 크기**: {metrics['state_size']} 차원
667
+
668
+ ### 🧠 Retention 분석
669
+ - **평균 Retention 비율**: {metrics['avg_retention']:.3f}
670
+ - **압축률**: {metrics['compression_ratio']:.2%}
671
+ - **동적 Power**: {states.get('dynamic_power', 2.0):.2f}
672
+
673
+ ✅ **실험이 성공적으로 완료되었습니다!**
674
+ """
675
+
676
+ # 시각화
677
+ fig_states = plot_retention_states(states)
678
+ fig_memory = plot_memory_usage(metrics)
679
+
680
+ return result_text, fig_states, fig_memory
681
+
682
+ except Exception as e:
683
+ return f"❌ 실험 실패: {str(e)}", None, None
684
+
685
+ def compare_retention_methods(input_text, sequence_length, benchmark_tasks):
686
+ """모델 비교"""
687
+ try:
688
+ results = []
689
+
690
+ for model_name, model in MODELS.items():
691
+ start_time = time.time()
692
+
693
+ x = torch.randn(1, sequence_length, model.d_model).to(DEVICE)
694
+
695
+ with torch.no_grad():
696
+ output, states = model(x, return_states=True)
697
+
698
+ elapsed_time = time.time() - start_time
699
+ metrics = calculate_metrics(output, states)
700
+
701
+ results.append({
702
+ 'model': model_name,
703
+ 'time': elapsed_time,
704
+ 'memory': metrics.get('memory_mb', 0),
705
+ 'throughput': sequence_length / elapsed_time
706
+ })
707
+
708
+ df = pd.DataFrame(results)
709
+ fig = plot_performance_comparison(df)
710
+
711
+ comparison_text = f"""
712
+ ## 🏆 모델 비교 결과
713
+
714
+ ### ⚡ 속도 순위
715
+ {df.sort_values('time')[['model', 'time']].to_markdown(index=False)}
716
+
717
+ ### 🚀 처리량 순위
718
+ {df.sort_values('throughput', ascending=False)[['model', 'throughput']].to_markdown(index=False)}
719
+
720
+ ### 💾 메모리 효율성
721
+ {df.sort_values('memory')[['model', 'memory']].to_markdown(index=False)}
722
+ """
723
+
724
+ return comparison_text, fig
725
+
726
+ except Exception as e:
727
+ return f"❌ 비교 실패: {str(e)}", None
728
+
729
+ def search_experiments(query, top_k=10):
730
+ """실험 검색"""
731
+ try:
732
+ results = vector_store.search(query, top_k=top_k)
733
+
734
+ if not results:
735
+ return "🔍 검색 결과가 없습니다."
736
+
737
+ search_text = "## 🔍 검색 결과\n\n"
738
+
739
+ for i, result in enumerate(results, 1):
740
+ exp_id = result['experiment_id']
741
+ score = result['score']
742
+ metadata = result['metadata']
743
+
744
+ search_text += f"""
745
+ ### {i}. 실험 #{exp_id} (유사도: {score:.3f})
746
+ - **모델**: {metadata.get('model_type', 'N/A')}
747
+ - **시퀀스 길이**: {metadata.get('sequence_length', 'N/A')}
748
+ - **시간**: {metadata.get('timestamp', 'N/A')}
749
+ ---
750
+ """
751
+
752
+ return search_text
753
+
754
+ except Exception as e:
755
+ return f"❌ 검색 실패: {str(e)}"
756
+
757
+ def view_experiment_history(limit=20):
758
+ """실험 이력 조회"""
759
+ try:
760
+ experiments = db.get_recent_experiments(limit=limit)
761
+
762
+ if not experiments:
763
+ return "📭 실험 이력이 없습니다.", None
764
+
765
+ df = pd.DataFrame(experiments)
766
+
767
+ # 시간별 성능 추이
768
+ fig = px.line(
769
+ df,
770
+ x='timestamp',
771
+ y='elapsed_time',
772
+ color='model_type',
773
+ title='모델별 실행 시간 추이'
774
+ )
775
+
776
+ history_text = f"""
777
+ ## 📊 실험 이력 ({len(df)}개)
778
+
779
+ {df[['id', 'model_type', 'sequence_length', 'elapsed_time', 'throughput', 'timestamp']].to_markdown(index=False)}
780
+ """
781
+
782
+ return history_text, fig
783
+
784
+ except Exception as e:
785
+ return f"❌ 이력 조회 실패: {str(e)}", None
786
+
787
+ def get_database_statistics():
788
+ """데이터베이스 통계"""
789
+ try:
790
+ stats = db.get_statistics()
791
+
792
+ stats_text = f"""
793
+ ## 📊 데이터베이스 통계
794
+
795
+ ### 전체 현황
796
+ - **총 실험 수**: {stats['total_experiments']}
797
+
798
+ ### 모델별 실험 수
799
+ """
800
+ for model, count in stats['by_model'].items():
801
+ stats_text += f"- **{model}**: {count}개\n"
802
+
803
+ return stats_text
804
+
805
+ except Exception as e:
806
+ return f"❌ 통계 조회 실패: {str(e)}"
807
+
808
+ # =====================================================
809
+ # Gradio UI 구성
810
+ # =====================================================
811
+
812
+ with gr.Blocks(
813
+ title="🔮 PHOENIX Retention Research Platform",
814
+ theme=gr.themes.Soft(),
815
+ ) as demo:
816
+
817
+ gr.Markdown("""
818
+ # 🔮 PHOENIX Retention Research Platform
819
+
820
+ **Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
821
+
822
+ Brumby를 뛰어넘는 차세대 Attention-Free 아키텍처 연구 플랫폼
823
+
824
+ ---
825
+ """)
826
+
827
+ with gr.Tabs():
828
+
829
+ # Tab 1: 실험 실행
830
+ with gr.Tab("🧪 실험 실행"):
831
+ with gr.Row():
832
+ with gr.Column(scale=1):
833
+ model_select = gr.Dropdown(
834
+ choices=list(MODELS.keys()),
835
+ value='phoenix_small',
836
+ label="모델 선택"
837
+ )
838
+
839
+ input_text = gr.Textbox(
840
+ label="입력 텍스트",
841
+ placeholder="실험할 텍스트를 입력하세요...",
842
+ lines=5,
843
+ value="PHOENIX Retention hierarchical memory system"
844
+ )
845
+
846
+ sequence_length = gr.Slider(
847
+ minimum=16, maximum=1024, value=128, step=16,
848
+ label="시퀀스 길이"
849
+ )
850
+
851
+ power_mode = gr.Radio(
852
+ choices=["Fixed (2)", "Dynamic", "Adaptive"],
853
+ value="Dynamic",
854
+ label="Power 모드"
855
+ )
856
+
857
+ compression_level = gr.Slider(
858
+ minimum=0.0, maximum=1.0, value=0.5, step=0.1,
859
+ label="압축 레벨"
860
+ )
861
+
862
+ use_hierarchical = gr.Checkbox(
863
+ value=True,
864
+ label="계층적 Retention 사용"
865
+ )
866
+
867
+ run_btn = gr.Button("🚀 실험 실행", variant="primary")
868
+
869
+ with gr.Column(scale=2):
870
+ result_output = gr.Markdown(label="실험 결과")
871
+
872
+ with gr.Row():
873
+ states_plot = gr.Plot(label="Retention States")
874
+ memory_plot = gr.Plot(label="메모리 사용량")
875
+
876
+ run_btn.click(
877
+ fn=run_retention_experiment,
878
+ inputs=[model_select, input_text, sequence_length,
879
+ power_mode, compression_level, use_hierarchical],
880
+ outputs=[result_output, states_plot, memory_plot]
881
+ )
882
+
883
+ # Tab 2: 모델 비교
884
+ with gr.Tab("⚔️ 모델 비교"):
885
+ with gr.Row():
886
+ with gr.Column(scale=1):
887
+ compare_text = gr.Textbox(
888
+ label="비교 텍스트",
889
+ lines=5,
890
+ value="Performance comparison test"
891
+ )
892
+
893
+ compare_length = gr.Slider(
894
+ minimum=64, maximum=2048, value=512, step=64,
895
+ label="시퀀스 길이"
896
+ )
897
+
898
+ benchmark_tasks = gr.CheckboxGroup(
899
+ choices=["속도", "메모리", "처리량"],
900
+ value=["속도", "메모리"],
901
+ label="벤치마크 항목"
902
+ )
903
+
904
+ compare_btn = gr.Button("⚔️ 비교 시작", variant="primary")
905
+
906
+ with gr.Column(scale=2):
907
+ compare_result = gr.Markdown(label="비교 결과")
908
+ compare_plot = gr.Plot(label="성능 비교")
909
+
910
+ compare_btn.click(
911
+ fn=compare_retention_methods,
912
+ inputs=[compare_text, compare_length, benchmark_tasks],
913
+ outputs=[compare_result, compare_plot]
914
+ )
915
+
916
+ # Tab 3: 실험 이력
917
+ with gr.Tab("📊 실험 이력"):
918
+ with gr.Row():
919
+ with gr.Column(scale=1):
920
+ history_limit = gr.Slider(
921
+ minimum=10, maximum=100, value=20, step=10,
922
+ label="조회 개수"
923
+ )
924
+
925
+ history_btn = gr.Button("📊 이력 조회", variant="primary")
926
+
927
+ gr.Markdown("---")
928
+
929
+ search_query = gr.Textbox(
930
+ label="실험 검색",
931
+ placeholder="검색어 입력..."
932
+ )
933
+
934
+ search_btn = gr.Button("🔍 검색", variant="secondary")
935
+
936
+ gr.Markdown("---")
937
+
938
+ stats_btn = gr.Button("📈 통계 보기", variant="secondary")
939
+
940
+ with gr.Column(scale=2):
941
+ history_output = gr.Markdown(label="결과")
942
+ history_plot = gr.Plot(label="추이 그래프")
943
+
944
+ history_btn.click(
945
+ fn=view_experiment_history,
946
+ inputs=[history_limit],
947
+ outputs=[history_output, history_plot]
948
+ )
949
+
950
+ search_btn.click(
951
+ fn=search_experiments,
952
+ inputs=[search_query],
953
+ outputs=[history_output]
954
+ )
955
+
956
+ stats_btn.click(
957
+ fn=get_database_statistics,
958
+ outputs=[history_output]
959
+ )
960
+
961
+ gr.Markdown("""
962
+ ---
963
+
964
+ ### 🔥 PHOENIX 핵심 혁신
965
+
966
+ 1. **계층적 기억** - 단기/중기/장기 메모리 분리
967
+ 2. **적응적 압축** - 중요도 기반 동적 압축
968
+ 3. **동적 Power** - 입력 따라 자동 최적화
969
+ 4. **병렬 경로** - 다중 전략 동시 운영
970
+
971
+ **VIDraft AI Research Lab** | L40S GPU + Persistent Storage
972
+ """)
973
+
974
+ # =====================================================
975
+ # 앱 실행
976
+ # =====================================================
977
+
978
+ if __name__ == "__main__":
979
+ demo.queue(max_size=20)
980
+ demo.launch(
981
+ server_name="0.0.0.0",
982
+ server_port=7860,
983
+ share=False
984
+ )