gbyuvd committed
Commit b1be276 · verified · 1 Parent(s): a527a2a

Update rl_utils

Files changed (1)
  1. rl_utils.py +108 -173
rl_utils.py CHANGED
@@ -2,6 +2,8 @@
  # RL_UTILS.PY
  # Chemistry RL Training Utilities for ChemQ3-MTP
  # by gbyuvd
  # ========================

  import torch
@@ -288,231 +290,164 @@ def selfies_to_lipinski_reward(selfies_str: str) -> float:
  # ========================

  class AdaptiveKLController:
-     """
-     Adaptive KL divergence controller for PPO training.
-     Increases or decreases β so that E[KL] stays ≈ target_kl.
-     """
-
-     def __init__(
-         self,
-         init_kl_coef: float = 0.1,
-         target_kl: float = 0.01,
-         kl_horizon: int = 1000,
-         increase_rate: float = 1.5,
-         decrease_rate: float = 0.8
-     ):
-         self.kl_coef = init_kl_coef
-         self.target_kl = target_kl
-         self.kl_horizon = kl_horizon
-         self.inc = increase_rate
-         self.dec = decrease_rate
-         self.buffer = []

      def update(self, kl: float) -> float:
-         """Update KL coefficient based on observed KL divergence."""
-         self.buffer.append(kl)
-
          if len(self.buffer) >= self.kl_horizon:
              avg_kl = sum(self.buffer) / len(self.buffer)
              self.buffer.clear()
-
              if avg_kl > self.target_kl * 1.5:
                  self.kl_coef *= self.inc
-                 print(f"KL too high ({avg_kl:.4f}), increasing β to {self.kl_coef:.4f}")
              elif avg_kl < self.target_kl * 0.5:
                  self.kl_coef *= self.dec
-                 print(f"KL too low ({avg_kl:.4f}), decreasing β to {self.kl_coef:.4f}")
-
          return self.kl_coef

  class EnhancedEntropyController:
-     """
-     Enhanced entropy controller with dynamic targets and temperature scheduling.
-     """
-
-     def __init__(
-         self,
-         min_entropy: float = 0.5,
-         max_entropy: float = 3.0,
-         target_entropy: float = 1.5,
-         adaptation_rate: float = 0.01
-     ):
          self.min_entropy = min_entropy
          self.max_entropy = max_entropy
          self.target_entropy = target_entropy
-         self.adaptation_rate = adaptation_rate
-         self.entropy_history = []
-         self.entropy_weight = 0.01  # Starting weight
-
      def update_entropy_weight(self, current_entropy: float) -> float:
-         """Dynamically adjust entropy weight based on current entropy levels."""
-         self.entropy_history.append(current_entropy)
-
-         # Keep rolling window
          if len(self.entropy_history) > 100:
              self.entropy_history = self.entropy_history[-100:]
-
          if len(self.entropy_history) >= 10:
              avg_entropy = np.mean(self.entropy_history[-10:])
-
-             # If entropy too low, increase weight to encourage exploration
              if avg_entropy < self.target_entropy * 0.8:
                  self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
-             # If entropy too high, decrease weight
              elif avg_entropy > self.target_entropy * 1.2:
                  self.entropy_weight = max(0.001, self.entropy_weight * 0.95)
-
-         return self.entropy_weight
-
-     def compute_entropy_reward(self, entropy: float) -> float:
-         """Reward function for entropy - prefer target range."""
-         if self.min_entropy <= entropy <= self.max_entropy:
-             # Gaussian reward centered at target
-             distance = abs(entropy - self.target_entropy)
-             max_distance = max(
-                 self.target_entropy - self.min_entropy,
-                 self.max_entropy - self.target_entropy
-             )
-             return np.exp(-(distance / max_distance) ** 2)
-         else:
-             return 0.1  # Small penalty for being outside range

  class CurriculumManager:
-     """
-     Curriculum learning manager for progressive training.
-     Gradually increases max_new_tokens from start_len → max_len, then cycles.
-     """
-
-     def __init__(
-         self,
-         start_len: int = 10,
-         max_len: int = 30,
-         step_increase: int = 5,
-         steps_per_level: int = 30
-     ):
          self.start_len = start_len
          self.max_len = max_len
          self.step_increase = step_increase
          self.steps_per_level = steps_per_level
-         self.step_counter = 0
          self.current_max_len = start_len

      def get_max_new_tokens(self) -> int:
-         """Get current maximum new tokens."""
          return self.current_max_len

      def step(self) -> int:
-         """Update curriculum and return new max_new_tokens."""
          self.step_counter += 1
-
          if self.step_counter % self.steps_per_level == 0:
-             if self.current_max_len < self.max_len:
-                 self.current_max_len = min(
-                     self.current_max_len + self.step_increase,
-                     self.max_len
-                 )
              else:
-                 # Reset cycle
-                 self.current_max_len = self.start_len
-                 print(f"🔄 Cycle reset: max_new_tokens -> {self.current_max_len}")
-
-         if self.current_max_len < self.max_len:
-             print(f"📈 Curriculum Update: max_new_tokens = {self.current_max_len}")
-
          return self.current_max_len

  # ========================
- # PPO TRAINING UTILITIES
  # ========================

- def compute_ppo_loss(
-     old_log_probs: torch.Tensor,
-     new_log_probs: torch.Tensor,
-     rewards: torch.Tensor,
-     clip_epsilon: float = 0.2,
-     baseline: Optional[torch.Tensor] = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """
-     Compute PPO clipped loss with numerical stability (improved version).
-     Note: This function computes the PPO surrogate loss only.
-     The KL penalty should be computed separately and added to the total loss
-     in your training loop using the KL coefficient (beta).
-
-     Args:
-         old_log_probs: Log probabilities from old policy [B, T]
-         new_log_probs: Log probabilities from new policy [B, T]
-         rewards: Reward values [B] (or advantages if baseline is provided)
-         clip_epsilon: Clipping parameter
-         baseline: Optional baseline for advantage computation [B]
-
-     Returns:
-         Tuple of (ppo_loss, advantage)
-     """
-     # Compute advantage
      if baseline is not None:
-         advantage = rewards - baseline.detach()
      else:
-         advantage = rewards
-
-     # Clip advantages to prevent extreme values
-     advantage = torch.clamp(advantage, -2.0, 2.0)
-
-     # Compute log probability ratio per step for numerical stability
-     # This avoids summing log probs first, which can lead to large exponents
-     log_ratio_per_step = new_log_probs - old_log_probs  # [B, T]
-
-     # Clamp log ratios per step to prevent extreme ratios before exponentiating
-     log_ratio_per_step = torch.clamp(log_ratio_per_step, -5.0, 5.0)
-
-     # Exponentiate to get the ratio per step
-     ratio_per_step = torch.exp(log_ratio_per_step)  # [B, T]
-
-     # Calculate surrogate objectives per step
-     surr1_per_step = ratio_per_step * advantage.unsqueeze(1)  # [B, T] * [B, 1] -> [B, T]
-     surr2_per_step = torch.clamp(ratio_per_step, 1 - clip_epsilon, 1 + clip_epsilon) * advantage.unsqueeze(1)  # [B, T]
-
-     # Take the minimum per step, sum over the sequence length for each example, then average over the batch
-     ppo_loss_per_example = -torch.min(surr1_per_step, surr2_per_step).sum(dim=1)  # [B, T] -> [B]
-     ppo_loss = ppo_loss_per_example.mean()  # scalar
-
      return ppo_loss, advantage.detach()

- def compute_kl_divergence(
-     old_action_probs: torch.Tensor,
-     new_action_probs: torch.Tensor
- ) -> torch.Tensor:
-     """
-     Compute KL divergence between old and new action distributions.
-
-     Args:
-         old_action_probs: Old action probabilities [B, T, V]
-         new_action_probs: New action probabilities [B, T, V]
-
-     Returns:
-         KL divergence per example [B]
-     """
      old_probs = old_action_probs.clamp_min(1e-12)
      new_probs = new_action_probs.clamp_min(1e-12)
-
-     # KL(old || new) = sum(old * log(old / new)) calculated as sum(old * (log(old) - log(new)))
-     kl_per_step = (old_probs * (torch.log(old_probs) - torch.log(new_probs))).sum(dim=-1)  # [B, T, V] -> [B, T]
-     kl_per_example = kl_per_step.sum(dim=1)  # [B, T] -> [B]
-     return kl_per_example  # [B]

  def compute_entropy_bonus(action_probs: torch.Tensor) -> torch.Tensor:
-     """
-     Compute entropy bonus for exploration.
-
-     Args:
-         action_probs: Action probabilities [B, T, V]
-
-     Returns:
-         Entropy per example [B]
-     """
      probs = action_probs.clamp_min(1e-12)
-     entropy_per_step = -(probs * torch.log(probs)).sum(dim=-1)  # [B, T, V] -> [B, T]
-     entropy_per_example = entropy_per_step.sum(dim=1)  # [B, T] -> [B]
-     return entropy_per_example  # [B]

  # ========================
  # BATCH REWARD COMPUTATION
@@ -635,4 +570,4 @@ def compute_training_metrics(
      # Add loss components
      metrics.update(loss_dict)

-     return metrics
@@ -2,6 +2,8 @@
  # RL_UTILS.PY
  # Chemistry RL Training Utilities for ChemQ3-MTP
  # by gbyuvd
+ # Patched: reward normalization, KL/entropy reset per phase,
+ # entropy target annealing, and symmetric curriculum (kept old naming).
  # ========================

  import torch
@@ -288,231 +290,164 @@ def selfies_to_lipinski_reward(selfies_str: str) -> float:
  # ========================

  class AdaptiveKLController:
+     def __init__(self, init_kl_coef: float = 0.1, target_kl: float = 0.01,
+                  kl_horizon: int = 200, increase_rate: float = 2.0,
+                  decrease_rate: float = 0.7):
+         self.kl_coef = float(init_kl_coef)
+         self.target_kl = float(target_kl)
+         self.kl_horizon = int(kl_horizon)
+         self.inc = float(increase_rate)
+         self.dec = float(decrease_rate)
+         self.buffer: List[float] = []

      def update(self, kl: float) -> float:
+         self.buffer.append(float(kl))
          if len(self.buffer) >= self.kl_horizon:
              avg_kl = sum(self.buffer) / len(self.buffer)
              self.buffer.clear()
              if avg_kl > self.target_kl * 1.5:
                  self.kl_coef *= self.inc
+                 print(f"KL too high ({avg_kl:.6f}), increasing β to {self.kl_coef:.6f}")
              elif avg_kl < self.target_kl * 0.5:
                  self.kl_coef *= self.dec
+                 print(f"KL too low ({avg_kl:.6f}), decreasing β to {self.kl_coef:.6f}")
          return self.kl_coef

+     def reset(self):
+         self.buffer.clear()
+
+
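For reference, a minimal sketch of how the adaptive β is typically consumed in a PPO step; the batch loop, measure_mean_kl, and policy_step names below are illustrative placeholders, not part of rl_utils.py.

# Illustrative wiring of AdaptiveKLController (helper names are hypothetical).
kl_ctrl = AdaptiveKLController(init_kl_coef=0.1, target_kl=0.01, kl_horizon=200)
for batch in batches:
    mean_kl = measure_mean_kl(batch)   # observed E[KL(old || new)] for this batch
    beta = kl_ctrl.update(mean_kl)     # β grows when KL runs hot, shrinks when cold
    policy_step(batch, kl_coef=beta)   # e.g. total loss = ppo_loss + beta * kl
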
  class EnhancedEntropyController:
+     def __init__(self, min_entropy: float = 0.5, max_entropy: float = 3.0,
+                  target_entropy: float = 1.5):
          self.min_entropy = min_entropy
          self.max_entropy = max_entropy
          self.target_entropy = target_entropy
+         self.entropy_history: List[float] = []
+         self.entropy_weight = 0.01
+
      def update_entropy_weight(self, current_entropy: float) -> float:
+         self.entropy_history.append(float(current_entropy))
          if len(self.entropy_history) > 100:
              self.entropy_history = self.entropy_history[-100:]
          if len(self.entropy_history) >= 10:
              avg_entropy = np.mean(self.entropy_history[-10:])
              if avg_entropy < self.target_entropy * 0.8:
                  self.entropy_weight = min(0.05, self.entropy_weight * 1.1)
              elif avg_entropy > self.target_entropy * 1.2:
                  self.entropy_weight = max(0.001, self.entropy_weight * 0.95)
+         return float(self.entropy_weight)
+
+     def adjust_for_seq_len(self, seq_len: int, base_entropy: float = 1.5):
+         seq_len = max(1, int(seq_len))
+         self.target_entropy = float(base_entropy * np.log1p(seq_len) / np.log1p(10))
+         self.target_entropy = float(np.clip(self.target_entropy, self.min_entropy, self.max_entropy))
+
+     def reset(self):
+         self.entropy_history.clear()
+         self.entropy_weight = 0.01
+
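As a rough illustration of the new target annealing: adjust_for_seq_len scales the target with log1p(seq_len) against a 10-token reference and clips it into [min_entropy, max_entropy]; with the defaults above the target moves roughly as sketched here.

# Sketch: annealed entropy target vs. curriculum length (default settings assumed).
ent_ctrl = EnhancedEntropyController()
for seq_len in (10, 15, 20, 25):
    ent_ctrl.adjust_for_seq_len(seq_len)
    print(seq_len, round(ent_ctrl.target_entropy, 2))
# roughly 1.5, 1.73, 1.9, 2.04 — longer sequences get a higher entropy target.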
 
  class CurriculumManager:
+     """Symmetric curriculum: 10→15→20→25→20→15→10→..."""
+     def __init__(self, start_len: int = 10, max_len: int = 25,
+                  step_increase: int = 5, steps_per_level: int = 30):
          self.start_len = start_len
          self.max_len = max_len
          self.step_increase = step_increase
          self.steps_per_level = steps_per_level
          self.current_max_len = start_len
+         self.step_counter = 0
+         self.direction = +1

      def get_max_new_tokens(self) -> int:
          return self.current_max_len

      def step(self) -> int:
          self.step_counter += 1
          if self.step_counter % self.steps_per_level == 0:
+             if self.direction == +1:
+                 if self.current_max_len < self.max_len:
+                     self.current_max_len += self.step_increase
+                 else:
+                     self.direction = -1
+                     self.current_max_len -= self.step_increase
              else:
+                 if self.current_max_len > self.start_len:
+                     self.current_max_len -= self.step_increase
+                 else:
+                     self.direction = +1
+                     self.current_max_len += self.step_increase
+             print(f"📈 Curriculum Update: max_new_tokens = {self.current_max_len}")
          return self.current_max_len
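A small sanity check of the symmetric schedule, assuming the defaults above (start_len=10, max_len=25, step_increase=5, steps_per_level=30):

# Sketch: the length only changes every steps_per_level calls to step().
cm = CurriculumManager()
seen = [cm.get_max_new_tokens()]
for i in range(1, 211):
    new_len = cm.step()
    if i % cm.steps_per_level == 0:
        seen.append(new_len)
print(seen)  # expected: [10, 15, 20, 25, 20, 15, 10, 15]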
  # ========================
+ # HELPERS
  # ========================

+ def normalize_rewards(rewards: torch.Tensor, seq_len: int, mode: str = "sqrt") -> torch.Tensor:
+     if seq_len <= 1 or mode == "none":
+         return rewards
+     if mode == "per_token":
+         return rewards / float(seq_len)
+     elif mode == "sqrt":
+         return rewards / float(np.sqrt(seq_len))
+     else:
+         raise ValueError(f"Unknown normalization mode: {mode}")
+
+
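For intuition, a quick worked example of the scaling modes (with seq_len=25, "sqrt" divides by 5 and "per_token" divides by 25):

# Sketch: "sqrt" dampens the length dependence; "per_token" removes it entirely.
r = torch.tensor([5.0, -2.0])
normalize_rewards(r, seq_len=25, mode="sqrt")       # tensor([ 1.0000, -0.4000])
normalize_rewards(r, seq_len=25, mode="per_token")  # tensor([ 0.2000, -0.0800])
normalize_rewards(r, seq_len=1,  mode="sqrt")       # returned unchanged
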
+ def reset_controllers_on_phase_change(prev_len: Optional[int], new_len: int,
+                                       kl_controller: Optional[AdaptiveKLController] = None,
+                                       entropy_controller: Optional[EnhancedEntropyController] = None,
+                                       entropy_base: float = 1.5):
+     if prev_len is None or prev_len == new_len:
+         return
+     if kl_controller is not None:
+         kl_controller.reset()
+     if entropy_controller is not None:
+         entropy_controller.reset()
+         entropy_controller.adjust_for_seq_len(new_len, base_entropy=entropy_base)
+
+
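A minimal sketch of where the phase-change reset sits relative to the curriculum; the surrounding loop, num_updates, curriculum, kl_ctrl, and ent_ctrl are illustrative placeholders.

# Illustrative: clear KL/entropy state whenever the generation length changes.
prev_len = None
for _ in range(num_updates):
    new_len = curriculum.step()
    reset_controllers_on_phase_change(prev_len, new_len,
                                      kl_controller=kl_ctrl,
                                      entropy_controller=ent_ctrl)
    prev_len = new_len
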
+ # ========================
+ # PPO LOSS
+ # ========================
+
+ def compute_ppo_loss(old_log_probs: torch.Tensor, new_log_probs: torch.Tensor,
+                      rewards: torch.Tensor, clip_epsilon: float = 0.2,
+                      baseline: Optional[torch.Tensor] = None,
+                      seq_len: int = 1, reward_norm: str = "sqrt",
+                      adv_clip: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+     normed_rewards = normalize_rewards(rewards, seq_len, mode=reward_norm)
      if baseline is not None:
+         advantage = normed_rewards - baseline.detach()
      else:
+         advantage = normed_rewards
+     if adv_clip is not None:
+         advantage = torch.clamp(advantage, -float(adv_clip), float(adv_clip))
+     else:
+         default_clip = 2.0 * np.sqrt(max(1, seq_len))
+         advantage = torch.clamp(advantage, -default_clip, default_clip)
+     log_ratio = torch.clamp(new_log_probs - old_log_probs, -10.0, 10.0)
+     ratio = torch.exp(log_ratio)
+     adv_expanded = advantage.unsqueeze(1) if advantage.dim() == 1 else advantage
+     surr1 = ratio * adv_expanded
+     surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * adv_expanded
+     ppo_loss = -torch.min(surr1, surr2).sum(dim=1).mean()
      return ppo_loss, advantage.detach()

+
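As in the previous version, compute_ppo_loss returns only the clipped surrogate; the KL penalty (weighted by the adaptive β) and the entropy bonus are still combined in the training loop. A hedged sketch of that composition — the tensor names and the .mean() reductions are illustrative choices, not fixed by this file:

# Sketch: total objective = surrogate + β·KL − entropy_weight·entropy.
ppo_loss, adv = compute_ppo_loss(old_log_probs, new_log_probs, rewards,
                                 seq_len=max_new_tokens, reward_norm="sqrt")
kl = compute_kl_divergence(old_probs, new_probs).mean()        # [B] -> scalar
entropy = compute_entropy_bonus(new_probs).mean()              # [B] -> scalar
total_loss = ppo_loss + beta * kl - entropy_weight * entropy
# beta comes from AdaptiveKLController.update(...); entropy_weight from
# EnhancedEntropyController.update_entropy_weight(...).
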
+ def compute_kl_divergence(old_action_probs: torch.Tensor, new_action_probs: torch.Tensor) -> torch.Tensor:
      old_probs = old_action_probs.clamp_min(1e-12)
      new_probs = new_action_probs.clamp_min(1e-12)
+     kl_per_step = (old_probs * (torch.log(old_probs) - torch.log(new_probs))).sum(dim=-1)
+     return kl_per_step.sum(dim=1)
+

  def compute_entropy_bonus(action_probs: torch.Tensor) -> torch.Tensor:
      probs = action_probs.clamp_min(1e-12)
+     entropy_per_step = -(probs * torch.log(probs)).sum(dim=-1)
+     return entropy_per_step.sum(dim=1)

  # ========================
  # BATCH REWARD COMPUTATION
@@ -635,4 +570,4 @@ def compute_training_metrics(
      # Add loss components
      metrics.update(loss_dict)

+     return metrics