seawolf2357 commited on
Commit
83d9107
·
verified ·
1 Parent(s): 068c039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -48
app.py CHANGED
@@ -1,10 +1,12 @@
1
  """
2
  🔮 PHOENIX Retention Research Platform
3
- Real Implementation - Attention Replacement
4
 
5
  L40S GPU + Persistent Storage (SQLite + ChromaDB)
6
  Base Model: IBM Granite 4.0 H 350M (Attention → Retention)
7
  VIDraft AI Research Lab
 
 
8
  """
9
 
10
  import gradio as gr
@@ -45,13 +47,15 @@ print(f"💾 Storage: {STORAGE_PATH}")
45
  print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
46
 
47
  # =====================================================
48
- # PHOENIX Retention Attention (핵심!)
49
  # =====================================================
50
 
51
  class MultiScaleRetention(nn.Module):
52
  """
53
  진짜 Retention Attention
54
  Transformer의 Self-Attention을 완전히 교체
 
 
55
  """
56
 
57
  def __init__(self, config, layer_idx=0):
@@ -109,6 +113,7 @@ class MultiScaleRetention(nn.Module):
109
  ):
110
  """
111
  O(n) 복잡도 Retention 메커니즘
 
112
  """
113
  batch_size, seq_len, input_dim = hidden_states.shape
114
 
@@ -123,46 +128,72 @@ class MultiScaleRetention(nn.Module):
123
  past_key_value = past_key_values
124
 
125
  # Q, K, V 계산
126
- query_states = self.q_proj(hidden_states) # [B, L, H]
127
- key_states = self.k_proj(hidden_states) # [B, L, H]
128
- value_states = self.v_proj(hidden_states) # [B, L, H]
 
 
 
129
 
130
- # Projection 후 크기 확인
131
- assert query_states.shape[-1] == self.hidden_size, \
132
- f"Q projection output is {query_states.shape[-1]}, expected {self.hidden_size}"
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # ✅ Multi-head reshape
135
- # [B, L, H] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
136
  query_states = query_states.view(
137
- batch_size, seq_len, self.num_heads, self.head_dim
138
  ).transpose(1, 2)
139
 
140
  key_states = key_states.view(
141
- batch_size, seq_len, self.num_heads, self.head_dim
142
  ).transpose(1, 2)
143
 
144
  value_states = value_states.view(
145
- batch_size, seq_len, self.num_heads, self.head_dim
146
  ).transpose(1, 2)
147
 
148
  # Retention 계산
149
  retention_states = self._compute_retention(
150
- query_states, key_states, value_states, past_key_value
 
151
  )
152
 
153
- # Reshape back: [B, num_heads, L, head_dim] -> [B, L, H]
154
  retention_states = retention_states.transpose(1, 2).contiguous()
155
  retention_states = retention_states.reshape(
156
- batch_size, seq_len, self.hidden_size
157
  )
158
 
159
- # Group norm
160
- retention_states = self.group_norm(
161
- retention_states.transpose(1, 2)
162
- ).transpose(1, 2)
 
 
 
 
 
163
 
164
  # Output projection
165
- attn_output = self.o_proj(retention_states)
 
 
 
 
 
 
166
 
167
  return (attn_output, None, past_key_value)
168
 
@@ -171,10 +202,15 @@ class MultiScaleRetention(nn.Module):
171
  queries: torch.Tensor, # [B, H, L, D]
172
  keys: torch.Tensor, # [B, H, L, D]
173
  values: torch.Tensor, # [B, H, L, D]
174
- past_state: Optional[Tuple] = None
 
175
  ):
176
  """O(n) Retention 계산"""
177
- batch_size, num_heads, seq_len, head_dim = queries.shape
 
 
 
 
178
 
179
  # State 초기화
180
  if past_state is not None:
@@ -301,12 +337,13 @@ class HierarchicalRetention(nn.Module):
301
 
302
 
303
  # =====================================================
304
- # 모델 변환 함수
305
  # =====================================================
306
 
307
  def replace_attention_with_retention(model, use_hierarchical=True):
308
  """
309
  Transformer의 Attention을 PHOENIX Retention으로 교체
 
310
  """
311
  print("🔄 Starting Attention → Retention conversion...")
312
 
@@ -353,24 +390,37 @@ def replace_attention_with_retention(model, use_hierarchical=True):
353
  else:
354
  new_retention = MultiScaleRetention(model.config, layer_idx)
355
 
356
- # ✅ 가중치 복사
357
  if hasattr(old_attn, 'q_proj'):
358
- # Shape 확인
359
- if (old_attn.q_proj.weight.shape ==
360
- new_retention.base_retention.q_proj.weight.shape):
 
 
 
361
 
362
- new_retention.base_retention.q_proj.weight.data = \
363
- old_attn.q_proj.weight.data.clone()
364
- new_retention.base_retention.k_proj.weight.data = \
365
- old_attn.k_proj.weight.data.clone()
366
- new_retention.base_retention.v_proj.weight.data = \
367
- old_attn.v_proj.weight.data.clone()
368
- new_retention.base_retention.o_proj.weight.data = \
369
- old_attn.o_proj.weight.data.clone()
370
 
371
- print(f" ✅ Layer {layer_idx}: Weights copied")
372
- else:
373
- print(f" ⚠️ Layer {layer_idx}: Shape mismatch, random init")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
  # 교체
376
  layer.self_attn = new_retention
@@ -1020,7 +1070,7 @@ def get_database_statistics():
1020
  # =====================================================
1021
 
1022
  with gr.Blocks(
1023
- title="🔮 PHOENIX Retention Research Platform - Real Implementation",
1024
  theme=gr.themes.Soft(),
1025
  ) as demo:
1026
 
@@ -1029,9 +1079,12 @@ with gr.Blocks(
1029
 
1030
  **Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
1031
 
1032
- ## 🔥 진짜 PHOENIX - Attention → Retention 완전 교체
1033
 
1034
- 버전은 Transformer의 Self-Attention을 PHOENIX Retention으로 **실제로 교체**합니다.
 
 
 
1035
 
1036
  ---
1037
  """)
@@ -1175,14 +1228,14 @@ with gr.Blocks(
1175
  ## 🔥 PHOENIX 핵심 차이점
1176
 
1177
  ### 이전 버전 (가짜)
1178
- ```
1179
  입력 → Granite Attention (O(n²)) → PHOENIX 후처리 → 출력
1180
- ```
1181
 
1182
  ### 현재 버전 (진짜)
1183
- ```
1184
  입력 → PHOENIX Retention (O(n)) → 출력
1185
- ```
1186
 
1187
  ## ⏱️ 예상 변환 시간 (350M 모델)
1188
 
@@ -1196,7 +1249,7 @@ with gr.Blocks(
1196
  - `Qwen/Qwen2.5-0.5B` (500M)
1197
  - `meta-llama/Llama-3.2-1B` (1B)
1198
 
1199
- **VIDraft AI Research Lab** | Real PHOENIX Implementation 🔥
1200
  """)
1201
 
1202
  if __name__ == "__main__":
@@ -1205,4 +1258,4 @@ if __name__ == "__main__":
1205
  server_name="0.0.0.0",
1206
  server_port=7860,
1207
  share=False
1208
- )
 
1
  """
2
  🔮 PHOENIX Retention Research Platform
3
+ Real Implementation - Attention Replacement (FIXED)
4
 
5
  L40S GPU + Persistent Storage (SQLite + ChromaDB)
6
  Base Model: IBM Granite 4.0 H 350M (Attention → Retention)
7
  VIDraft AI Research Lab
8
+
9
+ ✅ FIX: Shape mismatch 문제 해결
10
  """
11
 
12
  import gradio as gr
 
47
  print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
48
 
49
  # =====================================================
50
+ # PHOENIX Retention Attention (핵심! - FIXED)
51
  # =====================================================
52
 
53
  class MultiScaleRetention(nn.Module):
54
  """
55
  진짜 Retention Attention
56
  Transformer의 Self-Attention을 완전히 교체
57
+
58
+ ✅ FIX: Adaptive dimension handling
59
  """
60
 
61
  def __init__(self, config, layer_idx=0):
 
113
  ):
114
  """
115
  O(n) 복잡도 Retention 메커니즘
116
+ ✅ FIX: Adaptive dimension handling
117
  """
118
  batch_size, seq_len, input_dim = hidden_states.shape
119
 
 
128
  past_key_value = past_key_values
129
 
130
  # Q, K, V 계산
131
+ query_states = self.q_proj(hidden_states) # [B, L, ?]
132
+ key_states = self.k_proj(hidden_states) # [B, L, ?]
133
+ value_states = self.v_proj(hidden_states) # [B, L, ?]
134
+
135
+ # ✅ 실제 projection output 차원 확인
136
+ actual_proj_dim = query_states.shape[-1]
137
 
138
+ if actual_proj_dim != self.hidden_size:
139
+ print(f" ⚠️ Layer {self.layer_idx} Projection dim mismatch:")
140
+ print(f" Expected: {self.hidden_size}, Got: {actual_proj_dim}")
141
+
142
+ # Adaptive head_dim 계산
143
+ if actual_proj_dim % self.num_heads != 0:
144
+ raise ValueError(
145
+ f"Projection output {actual_proj_dim} not divisible by "
146
+ f"num_heads {self.num_heads}"
147
+ )
148
+ adaptive_head_dim = actual_proj_dim // self.num_heads
149
+ print(f" 🔧 Using adaptive head_dim: {adaptive_head_dim}")
150
+ else:
151
+ adaptive_head_dim = self.head_dim
152
 
153
+ # ✅ Multi-head reshape (adaptive)
154
+ # [B, L, actual_proj_dim] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
155
  query_states = query_states.view(
156
+ batch_size, seq_len, self.num_heads, adaptive_head_dim
157
  ).transpose(1, 2)
158
 
159
  key_states = key_states.view(
160
+ batch_size, seq_len, self.num_heads, adaptive_head_dim
161
  ).transpose(1, 2)
162
 
163
  value_states = value_states.view(
164
+ batch_size, seq_len, self.num_heads, adaptive_head_dim
165
  ).transpose(1, 2)
166
 
167
  # Retention 계산
168
  retention_states = self._compute_retention(
169
+ query_states, key_states, value_states, past_key_value,
170
+ adaptive_head_dim
171
  )
172
 
173
+ # Reshape back: [B, num_heads, L, head_dim] -> [B, L, actual_proj_dim]
174
  retention_states = retention_states.transpose(1, 2).contiguous()
175
  retention_states = retention_states.reshape(
176
+ batch_size, seq_len, actual_proj_dim
177
  )
178
 
179
+ # Group norm (actual_proj_dim 사용)
180
+ if actual_proj_dim == self.hidden_size:
181
+ retention_states = self.group_norm(
182
+ retention_states.transpose(1, 2)
183
+ ).transpose(1, 2)
184
+ else:
185
+ # Adaptive normalization
186
+ norm = nn.GroupNorm(self.num_heads, actual_proj_dim).to(retention_states.device)
187
+ retention_states = norm(retention_states.transpose(1, 2)).transpose(1, 2)
188
 
189
  # Output projection
190
+ # actual_proj_dim -> hidden_size 변환 필요
191
+ if actual_proj_dim != self.hidden_size:
192
+ # Adaptive projection
193
+ adaptive_o_proj = nn.Linear(actual_proj_dim, self.hidden_size, bias=False).to(retention_states.device)
194
+ attn_output = adaptive_o_proj(retention_states)
195
+ else:
196
+ attn_output = self.o_proj(retention_states)
197
 
198
  return (attn_output, None, past_key_value)
199
 
 
202
  queries: torch.Tensor, # [B, H, L, D]
203
  keys: torch.Tensor, # [B, H, L, D]
204
  values: torch.Tensor, # [B, H, L, D]
205
+ past_state: Optional[Tuple] = None,
206
+ head_dim: Optional[int] = None
207
  ):
208
  """O(n) Retention 계산"""
209
+ batch_size, num_heads, seq_len, actual_head_dim = queries.shape
210
+
211
+ # ✅ Use provided head_dim or infer from queries
212
+ if head_dim is None:
213
+ head_dim = actual_head_dim
214
 
215
  # State 초기화
216
  if past_state is not None:
 
337
 
338
 
339
  # =====================================================
340
+ # 모델 변환 함수 (FIXED)
341
  # =====================================================
342
 
343
  def replace_attention_with_retention(model, use_hierarchical=True):
344
  """
345
  Transformer의 Attention을 PHOENIX Retention으로 교체
346
+ ✅ FIX: Better weight copying and dimension handling
347
  """
348
  print("🔄 Starting Attention → Retention conversion...")
349
 
 
390
  else:
391
  new_retention = MultiScaleRetention(model.config, layer_idx)
392
 
393
+ # ✅ 가중치 복사 (improved)
394
  if hasattr(old_attn, 'q_proj'):
395
+ try:
396
+ # Get target retention module
397
+ if use_hierarchical:
398
+ target_retention = new_retention.base_retention
399
+ else:
400
+ target_retention = new_retention
401
 
402
+ # Shape 확인 및 복사
403
+ old_q_shape = old_attn.q_proj.weight.shape
404
+ new_q_shape = target_retention.q_proj.weight.shape
 
 
 
 
 
405
 
406
+ if old_q_shape == new_q_shape:
407
+ target_retention.q_proj.weight.data = \
408
+ old_attn.q_proj.weight.data.clone()
409
+ target_retention.k_proj.weight.data = \
410
+ old_attn.k_proj.weight.data.clone()
411
+ target_retention.v_proj.weight.data = \
412
+ old_attn.v_proj.weight.data.clone()
413
+ target_retention.o_proj.weight.data = \
414
+ old_attn.o_proj.weight.data.clone()
415
+
416
+ print(f" ✅ Layer {layer_idx}: Weights copied (shape: {old_q_shape})")
417
+ else:
418
+ print(f" ⚠️ Layer {layer_idx}: Shape mismatch")
419
+ print(f" Old: {old_q_shape}, New: {new_q_shape}")
420
+ print(f" Using random initialization")
421
+
422
+ except Exception as e:
423
+ print(f" ⚠️ Layer {layer_idx}: Weight copy failed - {e}")
424
 
425
  # 교체
426
  layer.self_attn = new_retention
 
1070
  # =====================================================
1071
 
1072
  with gr.Blocks(
1073
+ title="🔮 PHOENIX Retention Research Platform - Real Implementation (FIXED)",
1074
  theme=gr.themes.Soft(),
1075
  ) as demo:
1076
 
 
1079
 
1080
  **Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
1081
 
1082
+ ## 🔥 진짜 PHOENIX - Attention → Retention 완전 교체 (FIXED)
1083
 
1084
+ **FIX**: Shape mismatch 문제 해결
1085
+ - Adaptive dimension handling
1086
+ - Better weight copying
1087
+ - Dynamic projection adjustment
1088
 
1089
  ---
1090
  """)
 
1228
  ## 🔥 PHOENIX 핵심 차이점
1229
 
1230
  ### 이전 버전 (가짜)
1231
+ ```
1232
  입력 → Granite Attention (O(n²)) → PHOENIX 후처리 → 출력
1233
+ ```
1234
 
1235
  ### 현재 버전 (진짜)
1236
+ ```
1237
  입력 → PHOENIX Retention (O(n)) → 출력
1238
+ ```
1239
 
1240
  ## ⏱️ 예상 변환 시간 (350M 모델)
1241
 
 
1249
  - `Qwen/Qwen2.5-0.5B` (500M)
1250
  - `meta-llama/Llama-3.2-1B` (1B)
1251
 
1252
+ **VIDraft AI Research Lab** | Real PHOENIX Implementation 🔥 (FIXED)
1253
  """)
1254
 
1255
  if __name__ == "__main__":
 
1258
  server_name="0.0.0.0",
1259
  server_port=7860,
1260
  share=False
1261
+ )