Update app.py
Browse files
app.py
CHANGED
|
@@ -202,14 +202,16 @@ class MultiScaleRetention(nn.Module):
|
|
| 202 |
|
| 203 |
outputs = []
|
| 204 |
|
|
|
|
|
|
|
|
|
|
| 205 |
# Sequential processing (O(n))
|
| 206 |
for t in range(seq_len):
|
| 207 |
q_t = queries[:, :, t, :] # [B, H, D]
|
| 208 |
k_t = keys[:, :, t, :] # [B, H, D]
|
| 209 |
v_t = values[:, :, t, :] # [B, H, D]
|
| 210 |
|
| 211 |
-
# Decay
|
| 212 |
-
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1)
|
| 213 |
state = decay * state
|
| 214 |
|
| 215 |
# State update: S = decay * S + k @ v^T
|
|
|
|
| 202 |
|
| 203 |
outputs = []
|
| 204 |
|
| 205 |
+
# ✅ Move decay to the same device as the input
|
| 206 |
+
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(queries.device)
|
| 207 |
+
|
| 208 |
# Sequential processing (O(n))
|
| 209 |
for t in range(seq_len):
|
| 210 |
q_t = queries[:, :, t, :] # [B, H, D]
|
| 211 |
k_t = keys[:, :, t, :] # [B, H, D]
|
| 212 |
v_t = values[:, :, t, :] # [B, H, D]
|
| 213 |
|
| 214 |
+
# Decay application
|
|
|
|
| 215 |
state = decay * state
|
| 216 |
|
| 217 |
# State update: S = decay * S + k @ v^T
|