seawolf2357 committed on
Commit 068c039 · verified · 1 Parent(s): 2c0487e

Update app.py

Files changed (1):
  1. app.py +64 -109
app.py CHANGED
@@ -58,10 +58,12 @@ class MultiScaleRetention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+
+        # ✅ Read the actual hidden_size
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
 
-        # ✅ Compute the head dimension exactly
+        # ✅ Compute the head dimension
         self.head_dim = self.hidden_size // self.num_heads
 
         # ✅ Check that hidden_size is divisible by num_heads
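
For orientation, the divisibility guard and head-dimension arithmetic above reduce to a few lines; a minimal standalone sketch with illustrative sizes (the concrete numbers are assumptions, not values from this commit):

    hidden_size = 2048           # illustrative only
    num_attention_heads = 32     # illustrative only

    assert hidden_size % num_attention_heads == 0, \
        "hidden_size must be divisible by num_attention_heads"
    head_dim = hidden_size // num_attention_heads   # 64 in this example
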
@@ -71,12 +73,13 @@ class MultiScaleRetention(nn.Module):
                 f"num_attention_heads ({self.num_heads})"
             )
 
-        print(f" 📐 Layer {layer_idx} Retention config:")
+        print(f" 📐 Layer {layer_idx} Retention initialized:")
         print(f"   - hidden_size: {self.hidden_size}")
         print(f"   - num_heads: {self.num_heads}")
         print(f"   - head_dim: {self.head_dim}")
 
-        # Q, K, V projections (hidden_size -> hidden_size)
+        # Projections - input and output sizes made explicit
+        # input: hidden_size -> output: hidden_size
         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
@@ -107,61 +110,47 @@ class MultiScaleRetention(nn.Module):
         """
         O(n)-complexity Retention mechanism
         """
-        batch_size, seq_len, _ = hidden_states.shape
+        batch_size, seq_len, input_dim = hidden_states.shape
+
+        # ✅ Validate the input dimension
+        if input_dim != self.hidden_size:
+            raise ValueError(
+                f"Input hidden_states has dimension {input_dim} "
+                f"but model expects {self.hidden_size}"
+            )
 
         if past_key_values is not None:
             past_key_value = past_key_values
 
         # Compute Q, K, V
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)  # [B, L, H]
+        key_states = self.k_proj(hidden_states)    # [B, L, H]
+        value_states = self.v_proj(hidden_states)  # [B, L, H]
+
+        # ✅ Check the size after projection
+        assert query_states.shape[-1] == self.hidden_size, \
+            f"Q projection output is {query_states.shape[-1]}, expected {self.hidden_size}"
 
-        # Shape debugging
-        print(f"\n 🔍 Retention forward shapes:")
-        print(f"   - Input hidden_states: {hidden_states.shape}")
-        print(f"   - After projection Q: {query_states.shape}")
-        print(f"   - Expected reshape: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]")
-
-        # Multi-head reshape - with exact dimensions
-        try:
-            query_states = query_states.view(
-                batch_size, seq_len, self.num_heads, self.head_dim
-            ).transpose(1, 2)  # [B, H, L, D]
-
-            key_states = key_states.view(
-                batch_size, seq_len, self.num_heads, self.head_dim
-            ).transpose(1, 2)
-
-            value_states = value_states.view(
-                batch_size, seq_len, self.num_heads, self.head_dim
-            ).transpose(1, 2)
-
-            print(f"   - After reshape Q: {query_states.shape}")
-            print(f"   ✅ Reshape successful!")
-
-        except RuntimeError as e:
-            print(f"\n ❌ Reshape failed!")
-            print(f"   - query_states shape: {query_states.shape}")
-            print(f"   - query_states size: {query_states.numel()}")
-            print(f"   - Target shape: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]")
-            print(f"   - Target size: {batch_size * seq_len * self.num_heads * self.head_dim}")
-            print(f"   - Error: {e}")
-
-            # ✅ Compute the actual size
-            actual_total = query_states.numel()
-            actual_per_token = actual_total // (batch_size * seq_len)
-            print(f"   - Actual hidden per token: {actual_per_token}")
-
-            raise
+        # ✅ Multi-head reshape
+        # [B, L, H] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
+        query_states = query_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+
+        key_states = key_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+
+        value_states = value_states.view(
+            batch_size, seq_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
 
         # Compute retention
         retention_states = self._compute_retention(
-            query_states, key_states, value_states,
-            past_key_value
+            query_states, key_states, value_states, past_key_value
         )
 
-        # Reshape back
+        # Reshape back: [B, num_heads, L, head_dim] -> [B, L, H]
         retention_states = retention_states.transpose(1, 2).contiguous()
         retention_states = retention_states.reshape(
             batch_size, seq_len, self.hidden_size
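
The forward path now reshapes the projected tensors from [B, L, H] to [B, num_heads, L, head_dim] and back. A self-contained sketch of that round trip in plain PyTorch (sizes are illustrative, not taken from app.py):

    import torch

    B, L, H, num_heads = 2, 16, 2048, 32      # illustrative sizes
    head_dim = H // num_heads

    x = torch.randn(B, L, H)                                  # a projected Q/K/V tensor
    x = x.view(B, L, num_heads, head_dim).transpose(1, 2)     # [B, num_heads, L, head_dim]
    assert x.shape == (B, num_heads, L, head_dim)

    # inverse path, as done after _compute_retention
    y = x.transpose(1, 2).contiguous().reshape(B, L, H)       # back to [B, L, H]
    assert y.shape == (B, L, H)
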
@@ -187,11 +176,6 @@ class MultiScaleRetention(nn.Module):
         """O(n) retention computation"""
         batch_size, num_heads, seq_len, head_dim = queries.shape
 
-        print(f" 🔄 Computing retention:")
-        print(f"   - queries: {queries.shape}")
-        print(f"   - keys: {keys.shape}")
-        print(f"   - values: {values.shape}")
-
         # Initialize the state
         if past_state is not None:
             state = past_state
@@ -222,12 +206,8 @@ class MultiScaleRetention(nn.Module):
 
         output = torch.stack(outputs, dim=2)  # [B, H, L, D]
 
-        print(f"   - output: {output.shape}")
-
         return output
 
-
-
 class HierarchicalRetention(nn.Module):
     """
     PHOENIX's hierarchical Retention
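
The hunks above strip the debug prints around _compute_retention but keep its O(n) sequential scan and the final torch.stack(outputs, dim=2). A rough, self-contained sketch of that style of recurrence (the decay constant and the exact state update are assumptions, not the commit's code):

    import torch

    def retention_recurrence(q, k, v, state=None, decay=0.97):
        # q, k, v: [B, num_heads, L, head_dim]
        B, H, L, D = q.shape
        if state is None:
            state = q.new_zeros(B, H, D, D)        # running K^T V state
        outputs = []
        for t in range(L):                          # O(L) scan instead of O(L^2) attention
            k_t = k[:, :, t, :].unsqueeze(-1)       # [B, H, D, 1]
            v_t = v[:, :, t, :].unsqueeze(-2)       # [B, H, 1, D]
            state = decay * state + k_t @ v_t       # rank-1 state update
            q_t = q[:, :, t, :].unsqueeze(-2)       # [B, H, 1, D]
            outputs.append((q_t @ state).squeeze(-2))
        return torch.stack(outputs, dim=2), state   # [B, H, L, D], final state
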
@@ -263,16 +243,15 @@ class HierarchicalRetention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        cache_position: Optional[torch.Tensor] = None,  # ✅ added
-        past_key_values: Optional[Tuple[torch.Tensor]] = None,  # ✅ added
-        **kwargs  # ✅ added - accept all remaining arguments
+        cache_position: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs
     ):
         """
         Forward method compatible with the Granite model
         """
         batch_size, seq_len, hidden_size = hidden_states.shape
 
-        # Handle past_key_values and past_key_value together
         if past_key_values is not None:
             past_key_value = past_key_values
 
@@ -347,62 +326,39 @@ def replace_attention_with_retention(model, use_hierarchical=True):
 
     total_layers = len(layers)
 
+    # ✅ Detect the actual hidden_size from the first layer
+    first_layer = layers[0]
+    if hasattr(first_layer, 'self_attn') and hasattr(first_layer.self_attn, 'q_proj'):
+        actual_output_dim = first_layer.self_attn.q_proj.weight.shape[0]
+        actual_input_dim = first_layer.self_attn.q_proj.weight.shape[1]
+
+        print(f"\n📐 Detected dimensions from first layer:")
+        print(f"   - Input dim: {actual_input_dim}")
+        print(f"   - Output dim: {actual_output_dim}")
+        print(f"   - Config hidden_size: {model.config.hidden_size}")
+
+        # ✅ Update the config
+        if actual_output_dim != model.config.hidden_size:
+            print(f" ⚠️ Updating config to match actual dimensions")
+            model.config.hidden_size = actual_output_dim
+
     for layer_idx, layer in enumerate(layers):
         try:
-            # Find the attention layer
             if hasattr(layer, 'self_attn'):
                 old_attn = layer.self_attn
-                config = model.config
-
-                print(f"\n 📐 Layer {layer_idx} - Original Attention:")
-
-                # ✅ Check the actual weight shapes
-                if hasattr(old_attn, 'q_proj'):
-                    print(f"   - Q weight: {old_attn.q_proj.weight.shape}")
-                    print(f"   - K weight: {old_attn.k_proj.weight.shape}")
-                    print(f"   - V weight: {old_attn.v_proj.weight.shape}")
-                    print(f"   - O weight: {old_attn.o_proj.weight.shape}")
-
-                    # ✅ Check the actual output size
-                    actual_hidden = old_attn.q_proj.weight.shape[0]
-                    actual_input = old_attn.q_proj.weight.shape[1]
-
-                    print(f"   - Actual output dim: {actual_hidden}")
-                    print(f"   - Actual input dim: {actual_input}")
-                    print(f"   - Config hidden_size: {config.hidden_size}")
-
-                    # ✅ Adjust when the config does not match
-                    if actual_hidden != config.hidden_size or actual_input != config.hidden_size:
-                        print(f" ⚠️ Dimension mismatch detected!")
-                        print(f"   Using actual dimensions: {actual_input} → {actual_hidden}")
-
-                        # Create a new config
-                        class CustomConfig:
-                            def __init__(self, hidden, heads):
-                                self.hidden_size = hidden
-                                self.num_attention_heads = heads
-
-                        config = CustomConfig(actual_hidden, model.config.num_attention_heads)
 
                 # Create the PHOENIX Retention module
-                print(f"\n 🔄 Creating PHOENIX Retention for layer {layer_idx}...")
-
                 if use_hierarchical:
-                    new_retention = HierarchicalRetention(config, layer_idx)
+                    new_retention = HierarchicalRetention(model.config, layer_idx)
                 else:
-                    new_retention = MultiScaleRetention(config, layer_idx)
+                    new_retention = MultiScaleRetention(model.config, layer_idx)
 
-                # ✅ Copy the weights (with a strict shape check)
+                # ✅ Copy the weights
                 if hasattr(old_attn, 'q_proj'):
-                    old_q_shape = old_attn.q_proj.weight.shape
-                    new_q_shape = new_retention.base_retention.q_proj.weight.shape
-
-                    print(f"\n 📋 Weight copy:")
-                    print(f"   - Old Q: {old_q_shape}")
-                    print(f"   - New Q: {new_q_shape}")
-
-                    if old_q_shape == new_q_shape:
-                        # Shapes match - copy
+                    # Check the shapes
+                    if (old_attn.q_proj.weight.shape ==
+                            new_retention.base_retention.q_proj.weight.shape):
+
                         new_retention.base_retention.q_proj.weight.data = \
                             old_attn.q_proj.weight.data.clone()
                         new_retention.base_retention.k_proj.weight.data = \
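
Reading the model's real width off the first layer's q_proj relies on nn.Linear storing its weight as [out_features, in_features]. A tiny sanity check of that convention (the sizes are illustrative, not Granite's):

    import torch.nn as nn

    proj = nn.Linear(1536, 2048, bias=False)    # in_features=1536, out_features=2048
    out_dim, in_dim = proj.weight.shape         # weight is [out_features, in_features]
    assert (out_dim, in_dim) == (2048, 1536)
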
@@ -412,9 +368,9 @@ def replace_attention_with_retention(model, use_hierarchical=True):
                         new_retention.base_retention.o_proj.weight.data = \
                             old_attn.o_proj.weight.data.clone()
 
-                        print(f" ✅ Weights copied successfully")
+                        print(f"   Layer {layer_idx}: Weights copied")
                     else:
-                        print(f" ⚠️ Shape mismatch - using random initialization")
+                        print(f" ⚠️ Layer {layer_idx}: Shape mismatch, random init")
 
                 # Replace the attention module
                 layer.self_attn = new_retention
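
The copy path only fires when the old and new projection weights have identical shapes; otherwise the retention layer keeps its random initialization. In isolation the guard looks like this (the two Linear modules are generic stand-ins for old_attn.q_proj and new_retention.base_retention.q_proj):

    import torch.nn as nn

    old_proj = nn.Linear(2048, 2048, bias=False)   # stand-in for the pretrained projection
    new_proj = nn.Linear(2048, 2048, bias=False)   # stand-in for the new retention projection

    if old_proj.weight.shape == new_proj.weight.shape:
        new_proj.weight.data = old_proj.weight.data.clone()   # carry over pretrained weights
        assert bool((new_proj.weight == old_proj.weight).all())
    else:
        print("Shape mismatch, keeping random init")
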
@@ -423,8 +379,7 @@ def replace_attention_with_retention(model, use_hierarchical=True):
                 print(f" ✅ Layer {layer_idx}: Attention → Retention")
 
         except Exception as e:
-            print(f"\n ❌ Layer {layer_idx}: Conversion failed")
-            print(f"   Error: {e}")
+            print(f" ❌ Layer {layer_idx}: Failed - {e}")
             import traceback
             traceback.print_exc()
             continue
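
Taken together, the commit turns replace_attention_with_retention into a single conversion pass over a loaded model, mutating it in place via layer.self_attn = new_retention. A hedged usage sketch (the checkpoint name is a placeholder; the commit does not show which model app.py actually loads):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.1-2b-instruct")  # placeholder

    # swap every self_attn for a PHOENIX (Hierarchical) Retention block, in place
    replace_attention_with_retention(model, use_hierarchical=True)
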
 