Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
"""
|
| 2 |
🔮 PHOENIX Retention Research Platform
|
| 3 |
-
Real Implementation - Attention Replacement
|
| 4 |
|
| 5 |
L40S GPU + Persistent Storage (SQLite + ChromaDB)
|
| 6 |
Base Model: IBM Granite 4.0 H 350M (Attention → Retention)
|
| 7 |
VIDraft AI Research Lab
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import gradio as gr
|
|
@@ -45,13 +47,15 @@ print(f"💾 Storage: {STORAGE_PATH}")
|
|
| 45 |
print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
|
| 46 |
|
| 47 |
# =====================================================
|
| 48 |
-
# PHOENIX Retention Attention (핵심!)
|
| 49 |
# =====================================================
|
| 50 |
|
| 51 |
class MultiScaleRetention(nn.Module):
|
| 52 |
"""
|
| 53 |
진짜 Retention Attention
|
| 54 |
Transformer의 Self-Attention을 완전히 교체
|
|
|
|
|
|
|
| 55 |
"""
|
| 56 |
|
| 57 |
def __init__(self, config, layer_idx=0):
|
|
@@ -109,6 +113,7 @@ class MultiScaleRetention(nn.Module):
|
|
| 109 |
):
|
| 110 |
"""
|
| 111 |
O(n) 복잡도 Retention 메커니즘
|
|
|
|
| 112 |
"""
|
| 113 |
batch_size, seq_len, input_dim = hidden_states.shape
|
| 114 |
|
|
@@ -123,46 +128,72 @@ class MultiScaleRetention(nn.Module):
|
|
| 123 |
past_key_value = past_key_values
|
| 124 |
|
| 125 |
# Q, K, V 계산
|
| 126 |
-
query_states = self.q_proj(hidden_states) # [B, L,
|
| 127 |
-
key_states = self.k_proj(hidden_states) # [B, L,
|
| 128 |
-
value_states = self.v_proj(hidden_states) # [B, L,
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
# ✅ Multi-head reshape
|
| 135 |
-
# [B, L,
|
| 136 |
query_states = query_states.view(
|
| 137 |
-
batch_size, seq_len, self.num_heads,
|
| 138 |
).transpose(1, 2)
|
| 139 |
|
| 140 |
key_states = key_states.view(
|
| 141 |
-
batch_size, seq_len, self.num_heads,
|
| 142 |
).transpose(1, 2)
|
| 143 |
|
| 144 |
value_states = value_states.view(
|
| 145 |
-
batch_size, seq_len, self.num_heads,
|
| 146 |
).transpose(1, 2)
|
| 147 |
|
| 148 |
# Retention 계산
|
| 149 |
retention_states = self._compute_retention(
|
| 150 |
-
query_states, key_states, value_states, past_key_value
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
-
# Reshape back: [B, num_heads, L, head_dim] -> [B, L,
|
| 154 |
retention_states = retention_states.transpose(1, 2).contiguous()
|
| 155 |
retention_states = retention_states.reshape(
|
| 156 |
-
batch_size, seq_len,
|
| 157 |
)
|
| 158 |
|
| 159 |
-
# Group norm
|
| 160 |
-
|
| 161 |
-
retention_states.
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
# Output projection
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
return (attn_output, None, past_key_value)
|
| 168 |
|
|
@@ -171,10 +202,15 @@ class MultiScaleRetention(nn.Module):
|
|
| 171 |
queries: torch.Tensor, # [B, H, L, D]
|
| 172 |
keys: torch.Tensor, # [B, H, L, D]
|
| 173 |
values: torch.Tensor, # [B, H, L, D]
|
| 174 |
-
past_state: Optional[Tuple] = None
|
|
|
|
| 175 |
):
|
| 176 |
"""O(n) Retention 계산"""
|
| 177 |
-
batch_size, num_heads, seq_len,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# State 초기화
|
| 180 |
if past_state is not None:
|
|
@@ -301,12 +337,13 @@ class HierarchicalRetention(nn.Module):
|
|
| 301 |
|
| 302 |
|
| 303 |
# =====================================================
|
| 304 |
-
# 모델 변환 함수
|
| 305 |
# =====================================================
|
| 306 |
|
| 307 |
def replace_attention_with_retention(model, use_hierarchical=True):
|
| 308 |
"""
|
| 309 |
Transformer의 Attention을 PHOENIX Retention으로 교체
|
|
|
|
| 310 |
"""
|
| 311 |
print("🔄 Starting Attention → Retention conversion...")
|
| 312 |
|
|
@@ -353,24 +390,37 @@ def replace_attention_with_retention(model, use_hierarchical=True):
|
|
| 353 |
else:
|
| 354 |
new_retention = MultiScaleRetention(model.config, layer_idx)
|
| 355 |
|
| 356 |
-
# ✅ 가중치 복사
|
| 357 |
if hasattr(old_attn, 'q_proj'):
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
old_attn.k_proj.weight.data.clone()
|
| 366 |
-
new_retention.base_retention.v_proj.weight.data = \
|
| 367 |
-
old_attn.v_proj.weight.data.clone()
|
| 368 |
-
new_retention.base_retention.o_proj.weight.data = \
|
| 369 |
-
old_attn.o_proj.weight.data.clone()
|
| 370 |
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
# 교체
|
| 376 |
layer.self_attn = new_retention
|
|
@@ -1020,7 +1070,7 @@ def get_database_statistics():
|
|
| 1020 |
# =====================================================
|
| 1021 |
|
| 1022 |
with gr.Blocks(
|
| 1023 |
-
title="🔮 PHOENIX Retention Research Platform - Real Implementation",
|
| 1024 |
theme=gr.themes.Soft(),
|
| 1025 |
) as demo:
|
| 1026 |
|
|
@@ -1029,9 +1079,12 @@ with gr.Blocks(
|
|
| 1029 |
|
| 1030 |
**Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
|
| 1031 |
|
| 1032 |
-
## 🔥 진짜 PHOENIX - Attention → Retention 완전 교체
|
| 1033 |
|
| 1034 |
-
|
|
|
|
|
|
|
|
|
|
| 1035 |
|
| 1036 |
---
|
| 1037 |
""")
|
|
@@ -1175,14 +1228,14 @@ with gr.Blocks(
|
|
| 1175 |
## 🔥 PHOENIX 핵심 차이점
|
| 1176 |
|
| 1177 |
### 이전 버전 (가짜)
|
| 1178 |
-
```
|
| 1179 |
입력 → Granite Attention (O(n²)) → PHOENIX 후처리 → 출력
|
| 1180 |
-
```
|
| 1181 |
|
| 1182 |
### 현재 버전 (진짜)
|
| 1183 |
-
```
|
| 1184 |
입력 → PHOENIX Retention (O(n)) → 출력
|
| 1185 |
-
```
|
| 1186 |
|
| 1187 |
## ⏱️ 예상 변환 시간 (350M 모델)
|
| 1188 |
|
|
@@ -1196,7 +1249,7 @@ with gr.Blocks(
|
|
| 1196 |
- `Qwen/Qwen2.5-0.5B` (500M)
|
| 1197 |
- `meta-llama/Llama-3.2-1B` (1B)
|
| 1198 |
|
| 1199 |
-
**VIDraft AI Research Lab** | Real PHOENIX Implementation 🔥
|
| 1200 |
""")
|
| 1201 |
|
| 1202 |
if __name__ == "__main__":
|
|
@@ -1205,4 +1258,4 @@ if __name__ == "__main__":
|
|
| 1205 |
server_name="0.0.0.0",
|
| 1206 |
server_port=7860,
|
| 1207 |
share=False
|
| 1208 |
-
)
|
|
|
|
| 1 |
"""
|
| 2 |
🔮 PHOENIX Retention Research Platform
|
| 3 |
+
Real Implementation - Attention Replacement (FIXED)
|
| 4 |
|
| 5 |
L40S GPU + Persistent Storage (SQLite + ChromaDB)
|
| 6 |
Base Model: IBM Granite 4.0 H 350M (Attention → Retention)
|
| 7 |
VIDraft AI Research Lab
|
| 8 |
+
|
| 9 |
+
✅ FIX: Shape mismatch 문제 해결
|
| 10 |
"""
|
| 11 |
|
| 12 |
import gradio as gr
|
|
|
|
| 47 |
print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
|
| 48 |
|
| 49 |
# =====================================================
|
| 50 |
+
# PHOENIX Retention Attention (핵심! - FIXED)
|
| 51 |
# =====================================================
|
| 52 |
|
| 53 |
class MultiScaleRetention(nn.Module):
|
| 54 |
"""
|
| 55 |
진짜 Retention Attention
|
| 56 |
Transformer의 Self-Attention을 완전히 교체
|
| 57 |
+
|
| 58 |
+
✅ FIX: Adaptive dimension handling
|
| 59 |
"""
|
| 60 |
|
| 61 |
def __init__(self, config, layer_idx=0):
|
|
|
|
| 113 |
):
|
| 114 |
"""
|
| 115 |
O(n) 복잡도 Retention 메커니즘
|
| 116 |
+
✅ FIX: Adaptive dimension handling
|
| 117 |
"""
|
| 118 |
batch_size, seq_len, input_dim = hidden_states.shape
|
| 119 |
|
|
|
|
| 128 |
past_key_value = past_key_values
|
| 129 |
|
| 130 |
# Q, K, V 계산
|
| 131 |
+
query_states = self.q_proj(hidden_states) # [B, L, ?]
|
| 132 |
+
key_states = self.k_proj(hidden_states) # [B, L, ?]
|
| 133 |
+
value_states = self.v_proj(hidden_states) # [B, L, ?]
|
| 134 |
+
|
| 135 |
+
# ✅ 실제 projection output 차원 확인
|
| 136 |
+
actual_proj_dim = query_states.shape[-1]
|
| 137 |
|
| 138 |
+
if actual_proj_dim != self.hidden_size:
|
| 139 |
+
print(f" ⚠️ Layer {self.layer_idx} Projection dim mismatch:")
|
| 140 |
+
print(f" Expected: {self.hidden_size}, Got: {actual_proj_dim}")
|
| 141 |
+
|
| 142 |
+
# Adaptive head_dim 계산
|
| 143 |
+
if actual_proj_dim % self.num_heads != 0:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"Projection output {actual_proj_dim} not divisible by "
|
| 146 |
+
f"num_heads {self.num_heads}"
|
| 147 |
+
)
|
| 148 |
+
adaptive_head_dim = actual_proj_dim // self.num_heads
|
| 149 |
+
print(f" 🔧 Using adaptive head_dim: {adaptive_head_dim}")
|
| 150 |
+
else:
|
| 151 |
+
adaptive_head_dim = self.head_dim
|
| 152 |
|
| 153 |
+
# ✅ Multi-head reshape (adaptive)
|
| 154 |
+
# [B, L, actual_proj_dim] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
|
| 155 |
query_states = query_states.view(
|
| 156 |
+
batch_size, seq_len, self.num_heads, adaptive_head_dim
|
| 157 |
).transpose(1, 2)
|
| 158 |
|
| 159 |
key_states = key_states.view(
|
| 160 |
+
batch_size, seq_len, self.num_heads, adaptive_head_dim
|
| 161 |
).transpose(1, 2)
|
| 162 |
|
| 163 |
value_states = value_states.view(
|
| 164 |
+
batch_size, seq_len, self.num_heads, adaptive_head_dim
|
| 165 |
).transpose(1, 2)
|
| 166 |
|
| 167 |
# Retention 계산
|
| 168 |
retention_states = self._compute_retention(
|
| 169 |
+
query_states, key_states, value_states, past_key_value,
|
| 170 |
+
adaptive_head_dim
|
| 171 |
)
|
| 172 |
|
| 173 |
+
# Reshape back: [B, num_heads, L, head_dim] -> [B, L, actual_proj_dim]
|
| 174 |
retention_states = retention_states.transpose(1, 2).contiguous()
|
| 175 |
retention_states = retention_states.reshape(
|
| 176 |
+
batch_size, seq_len, actual_proj_dim
|
| 177 |
)
|
| 178 |
|
| 179 |
+
# ✅ Group norm (actual_proj_dim 사용)
|
| 180 |
+
if actual_proj_dim == self.hidden_size:
|
| 181 |
+
retention_states = self.group_norm(
|
| 182 |
+
retention_states.transpose(1, 2)
|
| 183 |
+
).transpose(1, 2)
|
| 184 |
+
else:
|
| 185 |
+
# Adaptive normalization
|
| 186 |
+
norm = nn.GroupNorm(self.num_heads, actual_proj_dim).to(retention_states.device)
|
| 187 |
+
retention_states = norm(retention_states.transpose(1, 2)).transpose(1, 2)
|
| 188 |
|
| 189 |
# Output projection
|
| 190 |
+
# ✅ actual_proj_dim -> hidden_size 변환 필요
|
| 191 |
+
if actual_proj_dim != self.hidden_size:
|
| 192 |
+
# Adaptive projection
|
| 193 |
+
adaptive_o_proj = nn.Linear(actual_proj_dim, self.hidden_size, bias=False).to(retention_states.device)
|
| 194 |
+
attn_output = adaptive_o_proj(retention_states)
|
| 195 |
+
else:
|
| 196 |
+
attn_output = self.o_proj(retention_states)
|
| 197 |
|
| 198 |
return (attn_output, None, past_key_value)
|
| 199 |
|
|
|
|
| 202 |
queries: torch.Tensor, # [B, H, L, D]
|
| 203 |
keys: torch.Tensor, # [B, H, L, D]
|
| 204 |
values: torch.Tensor, # [B, H, L, D]
|
| 205 |
+
past_state: Optional[Tuple] = None,
|
| 206 |
+
head_dim: Optional[int] = None
|
| 207 |
):
|
| 208 |
"""O(n) Retention 계산"""
|
| 209 |
+
batch_size, num_heads, seq_len, actual_head_dim = queries.shape
|
| 210 |
+
|
| 211 |
+
# ✅ Use provided head_dim or infer from queries
|
| 212 |
+
if head_dim is None:
|
| 213 |
+
head_dim = actual_head_dim
|
| 214 |
|
| 215 |
# State 초기화
|
| 216 |
if past_state is not None:
|
|
|
|
| 337 |
|
| 338 |
|
| 339 |
# =====================================================
|
| 340 |
+
# 모델 변환 함수 (FIXED)
|
| 341 |
# =====================================================
|
| 342 |
|
| 343 |
def replace_attention_with_retention(model, use_hierarchical=True):
|
| 344 |
"""
|
| 345 |
Transformer의 Attention을 PHOENIX Retention으로 교체
|
| 346 |
+
✅ FIX: Better weight copying and dimension handling
|
| 347 |
"""
|
| 348 |
print("🔄 Starting Attention → Retention conversion...")
|
| 349 |
|
|
|
|
| 390 |
else:
|
| 391 |
new_retention = MultiScaleRetention(model.config, layer_idx)
|
| 392 |
|
| 393 |
+
# ✅ 가중치 복사 (improved)
|
| 394 |
if hasattr(old_attn, 'q_proj'):
|
| 395 |
+
try:
|
| 396 |
+
# Get target retention module
|
| 397 |
+
if use_hierarchical:
|
| 398 |
+
target_retention = new_retention.base_retention
|
| 399 |
+
else:
|
| 400 |
+
target_retention = new_retention
|
| 401 |
|
| 402 |
+
# Shape 확인 및 복사
|
| 403 |
+
old_q_shape = old_attn.q_proj.weight.shape
|
| 404 |
+
new_q_shape = target_retention.q_proj.weight.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
+
if old_q_shape == new_q_shape:
|
| 407 |
+
target_retention.q_proj.weight.data = \
|
| 408 |
+
old_attn.q_proj.weight.data.clone()
|
| 409 |
+
target_retention.k_proj.weight.data = \
|
| 410 |
+
old_attn.k_proj.weight.data.clone()
|
| 411 |
+
target_retention.v_proj.weight.data = \
|
| 412 |
+
old_attn.v_proj.weight.data.clone()
|
| 413 |
+
target_retention.o_proj.weight.data = \
|
| 414 |
+
old_attn.o_proj.weight.data.clone()
|
| 415 |
+
|
| 416 |
+
print(f" ✅ Layer {layer_idx}: Weights copied (shape: {old_q_shape})")
|
| 417 |
+
else:
|
| 418 |
+
print(f" ⚠️ Layer {layer_idx}: Shape mismatch")
|
| 419 |
+
print(f" Old: {old_q_shape}, New: {new_q_shape}")
|
| 420 |
+
print(f" Using random initialization")
|
| 421 |
+
|
| 422 |
+
except Exception as e:
|
| 423 |
+
print(f" ⚠️ Layer {layer_idx}: Weight copy failed - {e}")
|
| 424 |
|
| 425 |
# 교체
|
| 426 |
layer.self_attn = new_retention
|
|
|
|
| 1070 |
# =====================================================
|
| 1071 |
|
| 1072 |
with gr.Blocks(
|
| 1073 |
+
title="🔮 PHOENIX Retention Research Platform - Real Implementation (FIXED)",
|
| 1074 |
theme=gr.themes.Soft(),
|
| 1075 |
) as demo:
|
| 1076 |
|
|
|
|
| 1079 |
|
| 1080 |
**Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
|
| 1081 |
|
| 1082 |
+
## 🔥 진짜 PHOENIX - Attention → Retention 완전 교체 (FIXED)
|
| 1083 |
|
| 1084 |
+
✅ **FIX**: Shape mismatch 문제 해결
|
| 1085 |
+
- Adaptive dimension handling
|
| 1086 |
+
- Better weight copying
|
| 1087 |
+
- Dynamic projection adjustment
|
| 1088 |
|
| 1089 |
---
|
| 1090 |
""")
|
|
|
|
| 1228 |
## 🔥 PHOENIX 핵심 차이점
|
| 1229 |
|
| 1230 |
### 이전 버전 (가짜)
|
| 1231 |
+
```
|
| 1232 |
입력 → Granite Attention (O(n²)) → PHOENIX 후처리 → 출력
|
| 1233 |
+
```
|
| 1234 |
|
| 1235 |
### 현재 버전 (진짜)
|
| 1236 |
+
```
|
| 1237 |
입력 → PHOENIX Retention (O(n)) → 출력
|
| 1238 |
+
```
|
| 1239 |
|
| 1240 |
## ⏱️ 예상 변환 시간 (350M 모델)
|
| 1241 |
|
|
|
|
| 1249 |
- `Qwen/Qwen2.5-0.5B` (500M)
|
| 1250 |
- `meta-llama/Llama-3.2-1B` (1B)
|
| 1251 |
|
| 1252 |
+
**VIDraft AI Research Lab** | Real PHOENIX Implementation 🔥 (FIXED)
|
| 1253 |
""")
|
| 1254 |
|
| 1255 |
if __name__ == "__main__":
|
|
|
|
| 1258 |
server_name="0.0.0.0",
|
| 1259 |
server_port=7860,
|
| 1260 |
share=False
|
| 1261 |
+
)
|