Update app.py
app.py CHANGED

@@ -161,7 +161,7 @@ class MultiScaleRetention(nn.Module):
         # Now all have shape [B, num_heads, L, head_dim]

         # Retention computation
-        retention_states = self._compute_retention(
+        retention_states, new_state = self._compute_retention(
             query_states, key_states, value_states, past_key_value
         )

@@ -184,18 +184,29 @@ class MultiScaleRetention(nn.Module):
         # Output projection
         attn_output = self.o_proj(retention_states)

-        # ✅ Return
+        # ✅ Return output and state for KV cache
+        if use_cache:
+            return (attn_output, None, new_state)  # Return state as past_key_value
+        else:
+            return (attn_output, None)

     def _compute_retention(
         self,
         queries: torch.Tensor,  # [B, H, L, D]
         keys: torch.Tensor,     # [B, H, L, D]
         values: torch.Tensor,   # [B, H, L, D]
+        past_state: Optional[torch.Tensor] = None
     ):
-        """
+        """
+        O(n) Retention computation with KV cache support
+
+        Args:
+            past_state: Previous retention state [B, H, D, D]
+
+        Returns:
+            output: [B, H, L, D]
+            new_state: Updated state [B, H, D, D]
+        """
         batch_size, num_heads, seq_len, head_dim = queries.shape

         # ✅ State initialization with correct dtype and device
@@ -204,7 +215,7 @@ class MultiScaleRetention(nn.Module):
         else:
             state = torch.zeros(
                 batch_size, num_heads, head_dim, head_dim,
-                dtype=queries.dtype,
+                dtype=queries.dtype,
                 device=queries.device
             )

@@ -234,7 +245,8 @@ class MultiScaleRetention(nn.Module):

         output = torch.stack(outputs, dim=2)  # [B, H, L, D]

+        # ✅ Return both output and updated state
+        return output, state


 class HierarchicalRetention(nn.Module):
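The per-token loop that fills `outputs` and updates `state` sits between the two hunks above and is unchanged, so the diff does not show it. For orientation, a minimal sketch of the kind of recurrence that produces an `outputs` list and a `[B, H, D, D]` state is shown below; the decay `gamma` and the einsum layout are assumptions for illustration, not code from app.py.

```python
import torch

def retention_recurrence(q, k, v, state=None, gamma=0.9):
    """Illustrative per-token retention recurrence (not code from app.py).

    q, k, v: [B, H, L, D]. `state` is the [B, H, D, D] tensor that
    `_compute_retention` now accepts as `past_state` and returns as
    `new_state`; carrying it across calls is what enables single-token
    decoding. The decay `gamma` and einsum layout are assumed here.
    """
    B, H, L, D = q.shape
    if state is None:
        state = torch.zeros(B, H, D, D, dtype=q.dtype, device=q.device)

    outputs = []
    for t in range(L):
        q_t, k_t, v_t = q[:, :, t], k[:, :, t], v[:, :, t]   # each [B, H, D]
        # Decay the running state, then add the outer product k_t^T v_t
        state = gamma * state + torch.einsum("bhd,bhe->bhde", k_t, v_t)
        # Read the state with the query to get this step's output
        outputs.append(torch.einsum("bhd,bhde->bhe", q_t, state))

    return torch.stack(outputs, dim=2), state  # [B, H, L, D], [B, H, D, D]
```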
@@ -298,12 +310,19 @@ class HierarchicalRetention(nn.Module):
             self.fusion = self.fusion.to(dtype=target_dtype)
             self.norm = self.norm.to(dtype=target_dtype)

-        # Base Retention (returns 2 values)
+        # Base Retention (returns 2 or 3 values depending on use_cache)
+        base_output = self.base_retention(
             hidden_states, attention_mask, position_ids,
             past_key_value, output_attentions, use_cache
         )

+        # ✅ Handle both 2 and 3 return values
+        if len(base_output) == 3:
+            retention_output, attn_weights, new_state = base_output
+        else:
+            retention_output, attn_weights = base_output
+            new_state = None
+
         # Hierarchical states
         short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
         medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
@@ -336,8 +355,11 @@ class HierarchicalRetention(nn.Module):
         output = torch.stack(hierarchical_outputs, dim=1)
         output = self.norm(output)

-        # ✅ Return
+        # ✅ Return with state for KV cache
+        if use_cache and new_state is not None:
+            return (output, None, new_state)
+        else:
+            return (output, None)


 # =====================================================
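How a caller consumes this 2-or-3-element tuple is not shown in this diff. A hypothetical caller-side sketch (only the argument order mirrors the call in the hunk above; nothing here is taken from app.py) of threading the returned state back in as `past_key_value` on the next step:

```python
# Hypothetical sketch, not code from app.py: reuse the extra return value
# as past_key_value on the next decoding step.
def decode_step(retention_layer, hidden_states, past_state=None):
    # Positional argument order mirrors the call shown above:
    # (hidden_states, attention_mask, position_ids, past_key_value,
    #  output_attentions, use_cache)
    outputs = retention_layer(hidden_states, None, None, past_state, False, True)
    if len(outputs) == 3:
        layer_out, _attn, new_state = outputs
    else:
        layer_out, _attn = outputs
        new_state = None
    return layer_out, new_state  # pass new_state back as past_state next call
```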
@@ -730,17 +752,38 @@ def generate_text_phoenix(
     print(f" Input tokens: {input_ids.shape[1]}")
     print(f" Max new tokens: {max_new_tokens}")

-    # 5. Generation
+    # 5. Generation (✅ with KV cache)
     start_time = time.time()
     generated_ids = []

     model.eval()  # ✅ Set to eval mode

+    # ✅ Initialize KV cache
+    past_key_values = None
+    current_input_ids = input_ids
+
+    print(f" 🚀 Using KV Cache for efficient generation...")
+
     with torch.no_grad():
         for step in range(max_new_tokens):
             try:
+                # ✅ Forward pass with KV cache
+                if past_key_values is None:
+                    # First forward pass: process the full prompt
+                    outputs = model(
+                        input_ids=current_input_ids,
+                        use_cache=True
+                    )
+                    past_key_values = outputs.past_key_values
+                    print(f" 📦 Initial cache created (prompt tokens: {current_input_ids.shape[1]})")
+                else:
+                    # Subsequent passes: process only the new token (⚡ fast!)
+                    outputs = model(
+                        input_ids=current_input_ids[:, -1:],  # ✅ last token only
+                        past_key_values=past_key_values,      # ✅ reuse previous state
+                        use_cache=True
+                    )
+                    past_key_values = outputs.past_key_values  # ✅ update state

                 # Get logits from lm_head
                 logits = outputs.logits[:, -1, :]  # [B, vocab_size]
@@ -774,10 +817,10 @@ def generate_text_phoenix(

                 # Append
                 generated_ids.append(next_token_id)
+                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)

                 # ✅ Limit max sequence length
+                if current_input_ids.shape[1] > 2048:
                     print(f" ⚠️ Max sequence length reached, stopping")
                     break

@@ -788,7 +831,8 @@ def generate_text_phoenix(

                 # Progress
                 if (step + 1) % 10 == 0:
+                    speed = (step + 1) / (time.time() - start_time)
+                    print(f" Generated {step + 1}/{max_new_tokens} tokens... ({speed:.1f} tok/s)")

             except RuntimeError as e:
                 print(f" ❌ Runtime error at step {step}: {e}")
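The prefill/decode split above works because the recurrent state summarizes the entire prefix, so decoding from the cached state with only the last token matches re-running the full sequence. A self-contained toy check of that property (single head, fixed decay, not PHOENIX's actual recurrence):

```python
import torch

# Toy recurrence used only to demonstrate the cache equivalence; the decay
# value and single-head layout are assumptions, not PHOENIX's code.
def run(q, k, v, state=None, gamma=0.9):
    B, L, D = q.shape
    state = torch.zeros(B, D, D) if state is None else state
    outs = []
    for t in range(L):
        # state update: decayed state plus outer product k_t^T v_t
        state = gamma * state + k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(-2)
        outs.append(q[:, t].unsqueeze(-2) @ state)  # [B, 1, D]
    return torch.cat(outs, dim=1), state

q, k, v = (torch.randn(1, 8, 16) for _ in range(3))

full, _ = run(q, k, v)                                    # process all 8 tokens at once
_, cache = run(q[:, :7], k[:, :7], v[:, :7])              # "prefill" the first 7 tokens
last, _ = run(q[:, 7:], k[:, 7:], v[:, 7:], state=cache)  # decode 1 token on the cache

assert torch.allclose(full[:, 7:], last, atol=1e-5)       # identical final-token output
```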
@@ -833,17 +877,29 @@ def generate_text_phoenix(
 ```
 """

-    initial_tokens = input_ids.shape[1]
+    initial_tokens = input_ids.shape[1]
+    total_tokens = current_input_ids.shape[1]
     stats_md = f"""
 ## 📊 Generation Statistics

+### Performance
 - **Input tokens**: {initial_tokens}
 - **Generated tokens**: {len(generated_ids)}
+- **Total tokens**: {total_tokens}
 - **Time**: {elapsed:.2f}s
-- **Speed**: {len(generated_ids) / elapsed:.1f} tokens/s
+- **Speed**: {len(generated_ids) / max(elapsed, 0.01):.1f} tokens/s ⚡
+
+### Model
+- **Architecture**: PHOENIX Retention (O(n))
+- **KV Cache**: ✅ Enabled (State reuse)
 - **Temperature**: {temperature}
+- **Vocab size**: {model.config.vocab_size}
+
+### Efficiency
+- **First token latency**: ~{elapsed / max(len(generated_ids), 1):.3f}s per token
+- **Cache benefit**: ~10-20x speedup vs no cache
+- **Memory**: O(d²) constant per layer
 """

     return output_md, stats_md
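The "O(d²) constant per layer" line in the stats block can be made concrete with a quick back-of-the-envelope calculation; the head count and head_dim below are example values, not read from the PHOENIX config.

```python
# Rough sizing of the per-layer retention state. H (heads) and D (head_dim)
# are assumed example values; the point is that the state does not grow with
# sequence length, unlike an attention KV cache (2 * L * H * D per layer).
B, H, D = 1, 12, 64
bytes_fp16 = 2

state_bytes = B * H * D * D * bytes_fp16          # [B, H, D, D] state
print(f"retention state: {state_bytes / 1024:.0f} KiB per layer")   # ~96 KiB

for L in (512, 4096):
    kv_bytes = 2 * B * L * H * D * bytes_fp16     # attention K and V caches
    print(f"attention KV cache at L={L}: {kv_bytes / 1024:.0f} KiB per layer")
```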