gbyuvd committed on
Commit cbbc837 · verified · 1 Parent(s): 2901ae7

Add ParetoController for reward weights

Files changed (1)
  1. rl_utils.py +358 -4
rl_utils.py CHANGED
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from torch.distributions import Categorical
 from typing import List, Union, Optional, Tuple, Dict, Any
 import numpy as np
-from collections import Counter
+from collections import Counter, deque

 # Chemistry imports
 from rdkit import Chem
@@ -173,6 +173,7 @@ def compute_sa_reward(selfies_str: str) -> float:
         return -result["score"]  # penalize "Hard"
     except Exception:
         return 0.0
+

 # ========================
 # MOLECULAR REWARD COMPONENTS
@@ -327,7 +328,6 @@ def compute_lipinski_reward(mol) -> float:
 def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     """
     Compute comprehensive reward for a SELFIES string.
-    FIXED: Uses corrected validity checking pipeline.

     Args:
         selfies_str: SELFIES representation of molecule
@@ -337,7 +337,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     """
     smiles = selfies_to_smiles(selfies_str)

-    # Check validity with the fixed pipeline
+    # Check validity first
     is_valid = (smiles is not None and
                 is_valid_smiles(smiles) and
                 passes_durrant_lab_filter(smiles))
@@ -376,7 +376,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     rewards["total"] = weighted_sum / sum(weights.values())

     return rewards
-
+
 def selfies_to_lipinski_reward(selfies_str: str) -> float:
     """Convert SELFIES to SMILES, then compute Lipinski reward."""
     smiles = selfies_to_smiles(selfies_str)
@@ -385,6 +385,258 @@ def selfies_to_lipinski_reward(selfies_str: str) -> float:
     mol = Chem.MolFromSmiles(smiles)
     return compute_lipinski_reward(mol)

+# ========================
+# PARETO-STYLE DYNAMIC REWARD CONTROLLER
+# ========================
+
+class ParetoRewardController:
+    """
+    Dynamic reward mixing based on Pareto optimality principles.
+    Adapts reward weights based on current population performance.
+    """
+
+    def __init__(
+        self,
+        objectives: List[str] = None,
+        history_size: int = 500,
+        adaptation_rate: float = 0.1,
+        min_weight: float = 0.05,
+        max_weight: float = 0.95,
+        pareto_pressure: float = 1.0,
+        exploration_phase_length: int = 100
+    ):
+        """
+        Args:
+            objectives: List of objective names to track
+            history_size: Size of rolling history for Pareto analysis
+            adaptation_rate: How quickly weights adapt (0-1)
+            min_weight: Minimum weight for any objective
+            max_weight: Maximum weight for any objective
+            pareto_pressure: Higher = more aggressive toward Pareto front
+            exploration_phase_length: Steps of pure exploration before adaptation
+        """
+        self.objectives = objectives or ["total", "sa", "validity", "diversity"]
+        self.history_size = history_size
+        self.adaptation_rate = adaptation_rate
+        self.min_weight = min_weight
+        self.max_weight = max_weight
+        self.pareto_pressure = pareto_pressure
+        self.exploration_phase_length = exploration_phase_length
+
+        # Initialize weights equally
+        n_objectives = len(self.objectives)
+        self.weights = {obj: 1.0/n_objectives for obj in self.objectives}
+
+        # History tracking
+        self.objective_history = deque(maxlen=history_size)
+        self.pareto_history = deque(maxlen=100)  # Track Pareto front evolution
+        self.step_count = 0
+
+        # Performance tracking
+        self.objective_trends = {obj: deque(maxlen=50) for obj in self.objectives}
+        self.stagnation_counters = {obj: 0 for obj in self.objectives}
+
+    def update(self, batch_objectives: Dict[str, torch.Tensor]) -> Dict[str, float]:
+        """
+        Update weights based on current batch performance.
+
+        Args:
+            batch_objectives: Dict of objective_name -> tensor of scores
+
+        Returns:
+            Updated weights dictionary
+        """
+        self.step_count += 1
+
+        # Convert to numpy for easier manipulation
+        batch_data = {}
+        for obj_name, tensor_vals in batch_objectives.items():
+            if obj_name in self.objectives:
+                batch_data[obj_name] = tensor_vals.detach().cpu().numpy()
+
+        # Store in history
+        if len(batch_data) > 0:
+            batch_size = len(batch_data[next(iter(batch_data))])
+            for i in range(batch_size):
+                point = {obj: batch_data[obj][i] for obj in self.objectives if obj in batch_data}
+                if len(point) == len(self.objectives):  # Only store complete points
+                    self.objective_history.append(point)
+
+        # Skip adaptation during exploration phase
+        if self.step_count <= self.exploration_phase_length:
+            return self.weights.copy()
+
+        # Compute current Pareto front
+        current_front = self._compute_pareto_front()
+        if len(current_front) > 0:
+            self.pareto_history.append(len(current_front))
+
+        # Adapt weights based on multiple criteria
+        self._adapt_weights_pareto_driven(batch_data)
+        self._adapt_weights_stagnation_driven(batch_data)
+        self._adapt_weights_diversity_driven()
+
+        # Ensure constraints
+        self._normalize_weights()
+
+        return self.weights.copy()
+
+    def _compute_pareto_front(self) -> List[Dict[str, float]]:
+        """Compute current Pareto front from history."""
+        if len(self.objective_history) < 10:
+            return []
+
+        points = list(self.objective_history)
+        pareto_front = []
+
+        for i, point1 in enumerate(points):
+            is_dominated = False
+            for j, point2 in enumerate(points):
+                if i != j and self._dominates(point2, point1):
+                    is_dominated = True
+                    break
+            if not is_dominated:
+                pareto_front.append(point1)
+
+        return pareto_front
+
+    def _dominates(self, point1: Dict[str, float], point2: Dict[str, float]) -> bool:
+        """Check if point1 dominates point2 (higher is better for all objectives)."""
+        better_in_all = True
+        strictly_better_in_one = False
+
+        for obj in self.objectives:
+            if obj in point1 and obj in point2:
+                if point1[obj] < point2[obj]:
+                    better_in_all = False
+                    break
+                elif point1[obj] > point2[obj]:
+                    strictly_better_in_one = True
+
+        return better_in_all and strictly_better_in_one
+
+    def _adapt_weights_pareto_driven(self, batch_data: Dict[str, np.ndarray]):
+        """Adapt weights based on distance to Pareto front."""
+        if len(self.objective_history) < 50:
+            return
+
+        pareto_front = self._compute_pareto_front()
+        if len(pareto_front) == 0:
+            return
+
+        # Compute average distance to Pareto front for each objective
+        obj_distances = {obj: [] for obj in self.objectives}
+
+        for point in list(self.objective_history)[-100:]:  # Recent history
+            min_distance = float('inf')
+            closest_front_point = None
+
+            for front_point in pareto_front:
+                distance = sum((point[obj] - front_point[obj])**2
+                               for obj in self.objectives if obj in point and obj in front_point)
+                if distance < min_distance:
+                    min_distance = distance
+                    closest_front_point = front_point
+
+            if closest_front_point:
+                for obj in self.objectives:
+                    if obj in point and obj in closest_front_point:
+                        obj_distances[obj].append(abs(point[obj] - closest_front_point[obj]))
+
+        # Increase weight for objectives with larger gaps to Pareto front
+        for obj in self.objectives:
+            if obj_distances[obj]:
+                avg_distance = np.mean(obj_distances[obj])
+                # Higher distance = increase weight
+                weight_adjustment = avg_distance * self.adaptation_rate * self.pareto_pressure
+                self.weights[obj] = self.weights[obj] * (1 + weight_adjustment)
+
+    def _adapt_weights_stagnation_driven(self, batch_data: Dict[str, np.ndarray]):
+        """Increase weights for stagnating objectives."""
+        for obj in self.objectives:
+            if obj in batch_data:
+                current_mean = np.mean(batch_data[obj])
+                self.objective_trends[obj].append(current_mean)
+
+                if len(self.objective_trends[obj]) >= 20:
+                    recent_trend = np.array(list(self.objective_trends[obj])[-20:])
+                    # Check for stagnation (low variance)
+                    if np.std(recent_trend) < 0.01:  # Adjust threshold as needed
+                        self.stagnation_counters[obj] += 1
+                        # Boost weight for stagnating objectives
+                        boost = min(0.1, self.stagnation_counters[obj] * 0.02)
+                        self.weights[obj] += boost
+                    else:
+                        self.stagnation_counters[obj] = max(0, self.stagnation_counters[obj] - 1)
+
+    def _adapt_weights_diversity_driven(self):
+        """Adapt weights based on Pareto front diversity."""
+        if len(self.pareto_history) < 10:
+            return
+
+        recent_front_sizes = list(self.pareto_history)[-10:]
+        front_diversity = np.std(recent_front_sizes)
+
+        # If diversity is low, boost exploration objectives
+        if front_diversity < 1.0:  # Adjust threshold
+            exploration_objectives = ["sa", "diversity"]  # Objectives that promote exploration
+            for obj in exploration_objectives:
+                if obj in self.weights:
+                    self.weights[obj] += 0.05 * self.adaptation_rate
+
+    def _normalize_weights(self):
+        """Ensure weights are normalized and within bounds."""
+        # Apply bounds
+        for obj in self.weights:
+            self.weights[obj] = np.clip(self.weights[obj], self.min_weight, self.max_weight)
+
+        # Normalize to sum to 1
+        total = sum(self.weights.values())
+        if total > 0:
+            for obj in self.weights:
+                self.weights[obj] /= total
+        else:
+            # Fallback to equal weights
+            n = len(self.weights)
+            for obj in self.weights:
+                self.weights[obj] = 1.0 / n
+
+    def get_mixed_reward(self, rewards_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Compute mixed reward using current weights.
+
+        Args:
+            rewards_dict: Dictionary of reward tensors
+
+        Returns:
+            Mixed reward tensor
+        """
+        mixed_reward = None
+
+        for obj_name, weight in self.weights.items():
+            if obj_name in rewards_dict:
+                weighted_reward = weight * rewards_dict[obj_name]
+                if mixed_reward is None:
+                    mixed_reward = weighted_reward
+                else:
+                    mixed_reward += weighted_reward
+
+        return mixed_reward if mixed_reward is not None else torch.zeros_like(list(rewards_dict.values())[0])
+
+    def get_status(self) -> Dict[str, any]:
+        """Get current status for logging."""
+        pareto_front = self._compute_pareto_front()
+
+        return {
+            "weights": self.weights.copy(),
+            "step_count": self.step_count,
+            "pareto_front_size": len(pareto_front),
+            "stagnation_counters": self.stagnation_counters.copy(),
+            "history_size": len(self.objective_history),
+            "avg_pareto_size": np.mean(list(self.pareto_history)) if self.pareto_history else 0
+        }
+
+
 # ========================
 # RL TRAINING CONTROLLERS
 # ========================
@@ -589,6 +841,107 @@ def compute_entropy_bonus(action_probs: torch.Tensor) -> torch.Tensor:
 # BATCH REWARD COMPUTATION
 # ========================

+def batch_compute_rewards_pareto(
+    selfies_list: List[str],
+    reward_mode: str = "mix",
+    reward_mix: float = 0.5,
+    pareto_controller: Optional[ParetoRewardController] = None
+) -> Dict[str, torch.Tensor]:
+    """
+    Drop-in replacement for batch_compute_rewards with Pareto support.
+
+    Args:
+        selfies_list: List of SELFIES strings
+        reward_mode: "chemq3", "sa", "mix", or "pareto"
+        reward_mix: Weight for comprehensive rewards when mixing (0-1)
+        pareto_controller: ParetoRewardController instance for "pareto" mode
+
+    Returns:
+        Dictionary containing reward tensors (same format as original)
+    """
+    batch_size = len(selfies_list)
+
+    validity_vals = []
+    lipinski_vals = []
+    total_rewards = []
+    sa_rewards = []
+
+    # Compute all individual rewards
+    for selfies_str in selfies_list:
+        smiles = selfies_to_smiles(selfies_str)
+
+        # Check validity comprehensively
+        is_valid = (smiles is not None and
+                    is_valid_smiles(smiles) and
+                    passes_durrant_lab_filter(smiles))
+
+        if reward_mode in ["chemq3", "mix", "pareto"]:
+            r = compute_comprehensive_reward(selfies_str)
+            validity_vals.append(r.get('validity', 0.0))
+            lipinski_vals.append(r.get('lipinski', 0.0))
+
+        if reward_mode in ["sa", "mix", "pareto"]:
+            sa = compute_sa_reward(selfies_str) if is_valid else 0.0
+            sa_rewards.append(sa)
+
+        # Store individual comprehensive reward for pareto mode
+        if reward_mode in ["chemq3", "pareto"]:
+            total_rewards.append(r.get('total', 0.0))
+        elif reward_mode == "sa":
+            total_rewards.append(sa)
+        elif reward_mode == "mix":
+            r_total = r.get("total", 0.0) if 'r' in locals() else 0.0
+            sa_val = sa if 'sa' in locals() else 0.0
+            mixed = reward_mix * r_total + (1.0 - reward_mix) * sa_val
+            total_rewards.append(mixed)
+
+    # Convert to tensors
+    result = {
+        "total_rewards": torch.tensor(total_rewards, dtype=torch.float32),
+    }
+
+    if validity_vals:
+        result["validity_rewards"] = torch.tensor(validity_vals, dtype=torch.float32)
+    if lipinski_vals:
+        result["lipinski_rewards"] = torch.tensor(lipinski_vals, dtype=torch.float32)
+    if sa_rewards:
+        result["sa_rewards"] = torch.tensor(sa_rewards, dtype=torch.float32)
+
+    # Compute diversity reward
+    valid_smiles = []
+    for selfies_str in selfies_list:
+        smiles = selfies_to_smiles(selfies_str)
+        if smiles and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
+            valid_smiles.append(smiles)
+
+    diversity_score = len(set(valid_smiles)) / max(1, len(valid_smiles))
+    result["diversity_rewards"] = torch.full((batch_size,), diversity_score, dtype=torch.float32)
+
+    # Apply Pareto mixing if requested
+    if reward_mode == "pareto" and pareto_controller is not None:
+        # Prepare objectives for controller update
+        batch_objectives = {
+            "total": result["total_rewards"],
+            "validity": result.get("validity_rewards", torch.zeros(batch_size)),
+            "diversity": result["diversity_rewards"]
+        }

+        if "sa_rewards" in result:
+            batch_objectives["sa"] = result["sa_rewards"]
+
+        # Update controller and get new weights
+        updated_weights = pareto_controller.update(batch_objectives)
+
+        # Compute mixed reward using adaptive weights
+        mixed_reward = pareto_controller.get_mixed_reward(batch_objectives)
+        result["total_rewards"] = mixed_reward
+
+        # Store weights for logging
+        result["pareto_weights"] = updated_weights
+
+    return result
+
+# Legacy
 def batch_compute_rewards(
     selfies_list: List[str],
     reward_mode: str = "chemq3",
@@ -713,4 +1066,5 @@ def compute_training_metrics(
     # Add loss components
     metrics.update(loss_dict)

+
     return metrics
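
For reviewers who want to exercise the new controller outside the training pipeline, here is a minimal, hypothetical usage sketch (not part of this commit). It assumes the file is importable as `rl_utils` and feeds the controller synthetic objective tensors; in real training these tensors would come from `batch_compute_rewards_pareto(..., reward_mode="pareto", pareto_controller=controller)`, which then returns the adaptively mixed values in `"total_rewards"` and the current weights under `"pareto_weights"`.

```python
# Hypothetical smoke test for ParetoRewardController (assumes this module
# is importable as `rl_utils`). Random tensors stand in for the per-sample
# objective scores that batch_compute_rewards_pareto() would produce.
import torch

from rl_utils import ParetoRewardController

controller = ParetoRewardController(
    objectives=["total", "sa", "validity", "diversity"],
    exploration_phase_length=3,  # shortened so adaptation kicks in quickly
)

for step in range(10):
    batch_objectives = {
        "total": torch.rand(16),
        "sa": torch.rand(16),
        "validity": torch.randint(0, 2, (16,)).float(),
        "diversity": torch.rand(16),
    }
    weights = controller.update(batch_objectives)          # adapt weights from batch statistics
    mixed = controller.get_mixed_reward(batch_objectives)  # per-sample weighted sum of objectives

print(controller.get_status())  # weights, Pareto front size, stagnation counters, history size
```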