seawolf2357 commited on
Commit
b5c73ce
·
verified ·
1 Parent(s): 0772ae3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -202,14 +202,16 @@ class MultiScaleRetention(nn.Module):
202
 
203
  outputs = []
204
 
 
 
 
205
  # Sequential processing (O(n))
206
  for t in range(seq_len):
207
  q_t = queries[:, :, t, :] # [B, H, D]
208
  k_t = keys[:, :, t, :] # [B, H, D]
209
  v_t = values[:, :, t, :] # [B, H, D]
210
 
211
- # Decay
212
- decay = torch.sigmoid(self.decay).view(1, -1, 1, 1)
213
  state = decay * state
214
 
215
  # State update: S = decay * S + k @ v^T
 
202
 
203
  outputs = []
204
 
205
+ # ✅ Decay를 입력과 같은 device로 이동
206
+ decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(queries.device)
207
+
208
  # Sequential processing (O(n))
209
  for t in range(seq_len):
210
  q_t = queries[:, :, t, :] # [B, H, D]
211
  k_t = keys[:, :, t, :] # [B, H, D]
212
  v_t = values[:, :, t, :] # [B, H, D]
213
 
214
+ # Decay application
 
215
  state = decay * state
216
 
217
  # State update: S = decay * S + k @ v^T