seawolf2357 committed
Commit 2c0487e · verified · 1 Parent(s): d655ec6

Update app.py

Files changed (1):
  1. app.py +102 -103
app.py CHANGED
@@ -61,27 +61,26 @@ class MultiScaleRetention(nn.Module):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

-        # ✅ Compute the head dimension safely
-        if hasattr(config, 'head_dim'):
-            self.head_dim = config.head_dim
-        else:
-            self.head_dim = self.hidden_size // self.num_heads

-        # ✅ Handle the case where it does not divide evenly
        if self.hidden_size % self.num_heads != 0:
-            print(f" ⚠️ Layer {layer_idx}: hidden_size ({self.hidden_size}) not divisible by num_heads ({self.num_heads})")
-            # Round up to the nearest multiple
-            self.head_dim = (self.hidden_size + self.num_heads - 1) // self.num_heads
-            self.effective_hidden = self.head_dim * self.num_heads
-            print(f" Adjusted: head_dim={self.head_dim}, effective_hidden={self.effective_hidden}")
-        else:
-            self.effective_hidden = self.hidden_size

-        # Q, K, V projections
-        self.q_proj = nn.Linear(self.hidden_size, self.effective_hidden, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.effective_hidden, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.effective_hidden, bias=False)
-        self.o_proj = nn.Linear(self.effective_hidden, self.hidden_size, bias=False)

        # Retention-specific parameters
        decay_values = torch.linspace(0.8, 0.95, self.num_heads)
@@ -90,7 +89,7 @@ class MultiScaleRetention(nn.Module):
        # Group normalization
        self.group_norm = nn.GroupNorm(
            num_groups=self.num_heads,
-            num_channels=self.effective_hidden
        )

    def forward(
@@ -118,21 +117,42 @@ class MultiScaleRetention(nn.Module):
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

-        # ✅ Multi-head reshape (done defensively)
        try:
            query_states = query_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
-            ).transpose(1, 2)
            key_states = key_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
        except RuntimeError as e:
-            print(f" ⚠️ Reshape error: {e}")
-            print(f" query_states shape: {query_states.shape}")
-            print(f" Expected: [B={batch_size}, L={seq_len}, H={self.num_heads}, D={self.head_dim}]")
            raise

        # Retention computation
@@ -144,7 +164,7 @@ class MultiScaleRetention(nn.Module):
        # Reshape back
        retention_states = retention_states.transpose(1, 2).contiguous()
        retention_states = retention_states.reshape(
-            batch_size, seq_len, self.effective_hidden
        )

        # Group norm
@@ -159,14 +179,19 @@ class MultiScaleRetention(nn.Module):

    def _compute_retention(
        self,
-        queries: torch.Tensor,
-        keys: torch.Tensor,
-        values: torch.Tensor,
        past_state: Optional[Tuple] = None
    ):
        """O(n) retention computation"""
        batch_size, num_heads, seq_len, head_dim = queries.shape

        # Initialize state
        if past_state is not None:
            state = past_state
@@ -180,22 +205,24 @@ class MultiScaleRetention(nn.Module):

        # Sequential processing (O(n))
        for t in range(seq_len):
-            q_t = queries[:, :, t, :]
-            k_t = keys[:, :, t, :]
-            v_t = values[:, :, t, :]

            # Apply decay
            decay = torch.sigmoid(self.decay).view(1, -1, 1, 1)
            state = decay * state

-            # Update state
            state = state + torch.einsum('bhd,bhe->bhde', k_t, v_t)

-            # Output
            output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(output_t)

-        output = torch.stack(outputs, dim=2)

        return output
@@ -327,41 +354,55 @@ def replace_attention_with_retention(model, use_hierarchical=True):
                old_attn = layer.self_attn
                config = model.config

-                # Check the actual hidden_size
-                print(f" 📐 Layer {layer_idx} config:")
-                print(f" - hidden_size: {config.hidden_size}")
-                print(f" - num_attention_heads: {config.num_attention_heads}")

                # ✅ Check the actual weight shapes
                if hasattr(old_attn, 'q_proj'):
-                    actual_hidden_size = old_attn.q_proj.weight.shape[0]
-                    actual_input_size = old_attn.q_proj.weight.shape[1]
-                    print(f" - Actual Q proj: {old_attn.q_proj.weight.shape}")
-                    print(f" - Actual hidden: {actual_hidden_size}")

-                    # Update the config
-                    if actual_hidden_size != config.hidden_size:
-                        print(f" ⚠️ Config mismatch! Using actual size: {actual_hidden_size}")
-                        # Create a temporary config
-                        temp_config = type('Config', (), {})()
-                        temp_config.hidden_size = actual_hidden_size
-                        temp_config.num_attention_heads = config.num_attention_heads

-                        # Recompute the head dimension
-                        temp_config.head_dim = actual_hidden_size // config.num_attention_heads

-                        config = temp_config

-                # Replace with PHOENIX Retention
                if use_hierarchical:
                    new_retention = HierarchicalRetention(config, layer_idx)
                else:
                    new_retention = MultiScaleRetention(config, layer_idx)

-                # ✅ Copy the weights (with an added shape check)
                if hasattr(old_attn, 'q_proj'):
-                    # Copy only after checking shapes
-                    if old_attn.q_proj.weight.shape == new_retention.base_retention.q_proj.weight.shape:
                        new_retention.base_retention.q_proj.weight.data = \
                            old_attn.q_proj.weight.data.clone()
                        new_retention.base_retention.k_proj.weight.data = \
@@ -370,11 +411,10 @@ def replace_attention_with_retention(model, use_hierarchical=True):
                            old_attn.v_proj.weight.data.clone()
                        new_retention.base_retention.o_proj.weight.data = \
                            old_attn.o_proj.weight.data.clone()
-                        print(f" ✅ Layer {layer_idx}: Weights copied")
                    else:
-                        print(f" ⚠️ Layer {layer_idx}: Shape mismatch, using random init")
-                        print(f" Old: {old_attn.q_proj.weight.shape}")
-                        print(f" New: {new_retention.base_retention.q_proj.weight.shape}")

                # Swap in the replacement
                layer.self_attn = new_retention
@@ -382,50 +422,9 @@ def replace_attention_with_retention(model, use_hierarchical=True):

                print(f" ✅ Layer {layer_idx}: Attention → Retention")

-            elif hasattr(layer, 'attn'):
-                # Alternative structure
-                old_attn = layer.attn
-                config = model.config
-
-                # ✅ Check the actual size
-                if hasattr(old_attn, 'c_attn'):
-                    actual_size = old_attn.c_attn.weight.shape[0] // 3
-                    print(f" 📐 Layer {layer_idx} actual hidden_size: {actual_size}")
-
-                    if actual_size != config.hidden_size:
-                        temp_config = type('Config', (), {})()
-                        temp_config.hidden_size = actual_size
-                        temp_config.num_attention_heads = config.num_attention_heads
-                        config = temp_config
-
-                if use_hierarchical:
-                    new_retention = HierarchicalRetention(config, layer_idx)
-                else:
-                    new_retention = MultiScaleRetention(config, layer_idx)
-
-                # Copy the weights
-                if hasattr(old_attn, 'c_attn'):
-                    qkv_weight = old_attn.c_attn.weight.data
-                    hidden_size = config.hidden_size
-
-                    new_retention.base_retention.q_proj.weight.data = \
-                        qkv_weight[:hidden_size, :].clone()
-                    new_retention.base_retention.k_proj.weight.data = \
-                        qkv_weight[hidden_size:2*hidden_size, :].clone()
-                    new_retention.base_retention.v_proj.weight.data = \
-                        qkv_weight[2*hidden_size:, :].clone()
-
-                if hasattr(old_attn, 'c_proj'):
-                    new_retention.base_retention.o_proj.weight.data = \
-                        old_attn.c_proj.weight.data.clone()
-
-                layer.attn = new_retention
-                replaced_count += 1
-
-                print(f" ✅ Layer {layer_idx}: Attention → Retention")
-
        except Exception as e:
-            print(f" ⚠️ Layer {layer_idx}: Conversion failed - {e}")
            import traceback
            traceback.print_exc()
            continue
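The removed elif branch handled GPT-2-style blocks, where Q, K and V share a single fused c_attn projection and the output projection is c_proj. A minimal sketch of the slicing that branch performed, assuming the fused weight is laid out as [3 * hidden_size, hidden_size] the way the removed code indexed it (the size below is hypothetical):

import torch

hidden_size = 768                                        # hypothetical width
qkv_weight = torch.randn(3 * hidden_size, hidden_size)   # stand-in for c_attn.weight

q_w = qkv_weight[:hidden_size, :]                        # first block  -> q_proj
k_w = qkv_weight[hidden_size:2 * hidden_size, :]         # second block -> k_proj
v_w = qkv_weight[2 * hidden_size:, :]                    # third block  -> v_proj
assert q_w.shape == k_w.shape == v_w.shape == (hidden_size, hidden_size)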
 
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

+        # ✅ Compute the head dimension exactly
+        self.head_dim = self.hidden_size // self.num_heads

+        # ✅ Check divisibility
        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(
+                f"hidden_size ({self.hidden_size}) must be divisible by "
+                f"num_attention_heads ({self.num_heads})"
+            )

+        print(f" 📐 Layer {layer_idx} Retention config:")
+        print(f" - hidden_size: {self.hidden_size}")
+        print(f" - num_heads: {self.num_heads}")
+        print(f" - head_dim: {self.head_dim}")
+
+        # Q, K, V projections (hidden_size → hidden_size)
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Retention-specific parameters
        decay_values = torch.linspace(0.8, 0.95, self.num_heads)
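The new initializer drops the old rounding fallback and simply requires hidden_size to be a multiple of num_heads; both the later view into [batch, seq_len, num_heads, head_dim] and the GroupNorm with num_groups=num_heads depend on that. A standalone illustration of the constraint (the sizes are hypothetical, not taken from app.py):

import torch.nn as nn

hidden_size, num_heads = 2048, 16           # hypothetical sizes
if hidden_size % num_heads != 0:
    raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_heads})")
head_dim = hidden_size // num_heads         # 128
assert num_heads * head_dim == hidden_size  # the multi-head view relies on this
norm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)  # also needs divisibility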
 
        # Group normalization
        self.group_norm = nn.GroupNorm(
            num_groups=self.num_heads,
+            num_channels=self.hidden_size
        )

    def forward(
 
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

+        # ✅ Shape debugging
+        print(f"\n 🔍 Retention forward shapes:")
+        print(f" - Input hidden_states: {hidden_states.shape}")
+        print(f" - After projection Q: {query_states.shape}")
+        print(f" - Expected reshape: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]")
+
+        # ✅ Multi-head reshape - with the exact dimensions
        try:
            query_states = query_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
+            ).transpose(1, 2)  # [B, H, L, D]
+
            key_states = key_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
+
            value_states = value_states.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
+
+            print(f" - After reshape Q: {query_states.shape}")
+            print(f" ✅ Reshape successful!")
+
        except RuntimeError as e:
+            print(f"\n Reshape failed!")
+            print(f" - query_states shape: {query_states.shape}")
+            print(f" - query_states size: {query_states.numel()}")
+            print(f" - Target shape: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]")
+            print(f" - Target size: {batch_size * seq_len * self.num_heads * self.head_dim}")
+            print(f" - Error: {e}")
+
+            # ✅ Work out the actual size
+            actual_total = query_states.numel()
+            actual_per_token = actual_total // (batch_size * seq_len)
+            print(f" - Actual hidden per token: {actual_per_token}")
+
            raise

        # Retention computation
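The reshape above is the standard multi-head split [batch, seq_len, num_heads * head_dim] → [batch, num_heads, seq_len, head_dim], and the later transpose/contiguous/reshape reverses it. A short sketch of the round trip with made-up sizes:

import torch

B, L, H, D = 2, 8, 4, 16                               # hypothetical sizes
hidden = torch.randn(B, L, H * D)                      # e.g. the output of q_proj
heads = hidden.view(B, L, H, D).transpose(1, 2)        # [B, H, L, D]
restored = heads.transpose(1, 2).contiguous().reshape(B, L, H * D)
assert torch.equal(hidden, restored)                   # the round trip is lossless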
 
        # Reshape back
        retention_states = retention_states.transpose(1, 2).contiguous()
        retention_states = retention_states.reshape(
+            batch_size, seq_len, self.hidden_size
        )

        # Group norm
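nn.GroupNorm normalizes over a channels-first layout, so applying the group_norm defined above to [batch, seq_len, hidden_size] activations usually involves a transpose. A sketch of one common pattern (this is an assumption for illustration, not a claim about how app.py applies it):

import torch
import torch.nn as nn

B, L, num_heads, hidden_size = 2, 8, 16, 2048          # hypothetical sizes
group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

x = torch.randn(B, L, hidden_size)
y = group_norm(x.transpose(1, 2)).transpose(1, 2)      # [B, C, L] -> normalize -> [B, L, C]
assert y.shape == x.shape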
 

    def _compute_retention(
        self,
+        queries: torch.Tensor,  # [B, H, L, D]
+        keys: torch.Tensor,     # [B, H, L, D]
+        values: torch.Tensor,   # [B, H, L, D]
        past_state: Optional[Tuple] = None
    ):
        """O(n) retention computation"""
        batch_size, num_heads, seq_len, head_dim = queries.shape

+        print(f" 🔄 Computing retention:")
+        print(f" - queries: {queries.shape}")
+        print(f" - keys: {keys.shape}")
+        print(f" - values: {values.shape}")
+
        # Initialize state
        if past_state is not None:
            state = past_state
 

        # Sequential processing (O(n))
        for t in range(seq_len):
+            q_t = queries[:, :, t, :]  # [B, H, D]
+            k_t = keys[:, :, t, :]     # [B, H, D]
+            v_t = values[:, :, t, :]   # [B, H, D]

            # Apply decay
            decay = torch.sigmoid(self.decay).view(1, -1, 1, 1)
            state = decay * state

+            # Update state: S = decay * S + k @ v^T
            state = state + torch.einsum('bhd,bhe->bhde', k_t, v_t)

+            # Output: q @ S
            output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(output_t)

+        output = torch.stack(outputs, dim=2)  # [B, H, L, D]
+
+        print(f" - output: {output.shape}")

        return output
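The loop above is the recurrent, O(n)-in-sequence-length form of retention: each head keeps a head_dim × head_dim state that decays geometrically and accumulates k_t v_tᵀ outer products, and the output at step t is q_t applied to that state. A self-contained sketch of the same recurrence with hypothetical sizes:

import torch

B, H, L, D = 2, 4, 8, 16                               # batch, heads, seq len, head dim
queries, keys, values = (torch.randn(B, H, L, D) for _ in range(3))
decay = torch.sigmoid(torch.linspace(0.8, 0.95, H)).view(1, H, 1, 1)  # per-head decay

state = torch.zeros(B, H, D, D)                        # per-head outer-product memory
outputs = []
for t in range(L):
    q_t, k_t, v_t = queries[:, :, t, :], keys[:, :, t, :], values[:, :, t, :]
    state = decay * state                              # forget older context
    state = state + torch.einsum('bhd,bhe->bhde', k_t, v_t)    # S += k_t v_t^T
    outputs.append(torch.einsum('bhd,bhde->bhe', q_t, state))  # o_t = q_t S
output = torch.stack(outputs, dim=2)                   # [B, H, L, D]
assert output.shape == (B, H, L, D)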
 
 
                old_attn = layer.self_attn
                config = model.config

+                print(f"\n 📐 Layer {layer_idx} - Original Attention:")

                # ✅ Check the actual weight shapes
                if hasattr(old_attn, 'q_proj'):
+                    print(f" - Q weight: {old_attn.q_proj.weight.shape}")
+                    print(f" - K weight: {old_attn.k_proj.weight.shape}")
+                    print(f" - V weight: {old_attn.v_proj.weight.shape}")
+                    print(f" - O weight: {old_attn.o_proj.weight.shape}")
+
+                    # ✅ Check the actual output size
+                    actual_hidden = old_attn.q_proj.weight.shape[0]
+                    actual_input = old_attn.q_proj.weight.shape[1]

+                    print(f" - Actual output dim: {actual_hidden}")
+                    print(f" - Actual input dim: {actual_input}")
+                    print(f" - Config hidden_size: {config.hidden_size}")
+
+                    # Adjust if the config does not match
+                    if actual_hidden != config.hidden_size or actual_input != config.hidden_size:
+                        print(f" ⚠️ Dimension mismatch detected!")
+                        print(f" Using actual dimensions: {actual_input} → {actual_hidden}")

+                        # Create a new config
+                        class CustomConfig:
+                            def __init__(self, hidden, heads):
+                                self.hidden_size = hidden
+                                self.num_attention_heads = heads

+                        config = CustomConfig(actual_hidden, model.config.num_attention_heads)
+
+                # Create PHOENIX Retention
+                print(f"\n 🔄 Creating PHOENIX Retention for layer {layer_idx}...")

                if use_hierarchical:
                    new_retention = HierarchicalRetention(config, layer_idx)
                else:
                    new_retention = MultiScaleRetention(config, layer_idx)

+                # ✅ Copy the weights (verify shapes exactly)
                if hasattr(old_attn, 'q_proj'):
+                    old_q_shape = old_attn.q_proj.weight.shape
+                    new_q_shape = new_retention.base_retention.q_proj.weight.shape
+
+                    print(f"\n 📋 Weight copy:")
+                    print(f" - Old Q: {old_q_shape}")
+                    print(f" - New Q: {new_q_shape}")
+
+                    if old_q_shape == new_q_shape:
+                        # Shapes match - copy
                        new_retention.base_retention.q_proj.weight.data = \
                            old_attn.q_proj.weight.data.clone()
                        new_retention.base_retention.k_proj.weight.data = \

                            old_attn.v_proj.weight.data.clone()
                        new_retention.base_retention.o_proj.weight.data = \
                            old_attn.o_proj.weight.data.clone()
+
+                        print(f" ✅ Weights copied successfully")
                    else:
+                        print(f" ⚠️ Shape mismatch - using random initialization")

                # Swap in the replacement
                layer.self_attn = new_retention

                print(f" ✅ Layer {layer_idx}: Attention → Retention")
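The copy above only reuses the pretrained projection weights when the old and new matrices have identical shapes; otherwise the retention layer keeps its random initialization. A small sketch of that guard using two stand-in Linear layers (the sizes and names are hypothetical):

import torch.nn as nn

old_q = nn.Linear(2048, 2048, bias=False)    # stand-in for old_attn.q_proj
new_q = nn.Linear(2048, 2048, bias=False)    # stand-in for the retention q_proj

if old_q.weight.shape == new_q.weight.shape:
    new_q.weight.data = old_q.weight.data.clone()    # reuse the pretrained weights
else:
    print("Shape mismatch - keeping random initialization")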
        except Exception as e:
+            print(f"\n Layer {layer_idx}: Conversion failed")
+            print(f" Error: {e}")
            import traceback
            traceback.print_exc()
            continue