Add ParetoRewardController for dynamic reward weights
rl_utils.py (+358 -4)
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from torch.distributions import Categorical
 from typing import List, Union, Optional, Tuple, Dict, Any
 import numpy as np
-from collections import Counter
+from collections import Counter, deque
 
 # Chemistry imports
 from rdkit import Chem
@@ -173,6 +173,7 @@ def compute_sa_reward(selfies_str: str) -> float:
         return -result["score"]  # penalize "Hard"
     except Exception:
         return 0.0
+
 
 # ========================
 # MOLECULAR REWARD COMPONENTS
@@ -327,7 +328,6 @@ def compute_lipinski_reward(mol) -> float:
 def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     """
     Compute comprehensive reward for a SELFIES string.
-    FIXED: Uses corrected validity checking pipeline.
 
     Args:
         selfies_str: SELFIES representation of molecule
@@ -337,7 +337,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     """
     smiles = selfies_to_smiles(selfies_str)
 
-    # Check validity
+    # Check validity first
     is_valid = (smiles is not None and
                 is_valid_smiles(smiles) and
                 passes_durrant_lab_filter(smiles))
@@ -376,7 +376,7 @@ def compute_comprehensive_reward(selfies_str: str) -> Dict[str, float]:
     rewards["total"] = weighted_sum / sum(weights.values())
 
     return rewards
-
+
 def selfies_to_lipinski_reward(selfies_str: str) -> float:
     """Convert SELFIES to SMILES, then compute Lipinski reward."""
     smiles = selfies_to_smiles(selfies_str)
@@ -385,6 +385,258 @@ def selfies_to_lipinski_reward(selfies_str: str) -> float:
     mol = Chem.MolFromSmiles(smiles)
     return compute_lipinski_reward(mol)
 
+# ========================
+# PARETO-STYLE DYNAMIC REWARD CONTROLLER
+# ========================
+
+class ParetoRewardController:
+    """
+    Dynamic reward mixing based on Pareto optimality principles.
+    Adapts reward weights based on current population performance.
+    """
+
+    def __init__(
+        self,
+        objectives: Optional[List[str]] = None,
+        history_size: int = 500,
+        adaptation_rate: float = 0.1,
+        min_weight: float = 0.05,
+        max_weight: float = 0.95,
+        pareto_pressure: float = 1.0,
+        exploration_phase_length: int = 100
+    ):
+        """
+        Args:
+            objectives: List of objective names to track
+            history_size: Size of rolling history for Pareto analysis
+            adaptation_rate: How quickly weights adapt (0-1)
+            min_weight: Minimum weight for any objective
+            max_weight: Maximum weight for any objective
+            pareto_pressure: Higher = more aggressive toward Pareto front
+            exploration_phase_length: Steps of pure exploration before adaptation
+        """
+        self.objectives = objectives or ["total", "sa", "validity", "diversity"]
+        self.history_size = history_size
+        self.adaptation_rate = adaptation_rate
+        self.min_weight = min_weight
+        self.max_weight = max_weight
+        self.pareto_pressure = pareto_pressure
+        self.exploration_phase_length = exploration_phase_length
+
+        # Initialize weights equally
+        n_objectives = len(self.objectives)
+        self.weights = {obj: 1.0 / n_objectives for obj in self.objectives}
+
+        # History tracking
+        self.objective_history = deque(maxlen=history_size)
+        self.pareto_history = deque(maxlen=100)  # Track Pareto front evolution
+        self.step_count = 0
+
+        # Performance tracking
+        self.objective_trends = {obj: deque(maxlen=50) for obj in self.objectives}
+        self.stagnation_counters = {obj: 0 for obj in self.objectives}
+
+    def update(self, batch_objectives: Dict[str, torch.Tensor]) -> Dict[str, float]:
+        """
+        Update weights based on current batch performance.
+
+        Args:
+            batch_objectives: Dict of objective_name -> tensor of scores
+
+        Returns:
+            Updated weights dictionary
+        """
+        self.step_count += 1
+
+        # Convert to numpy for easier manipulation
+        batch_data = {}
+        for obj_name, tensor_vals in batch_objectives.items():
+            if obj_name in self.objectives:
+                batch_data[obj_name] = tensor_vals.detach().cpu().numpy()
+
+        # Store in history
+        if len(batch_data) > 0:
+            batch_size = len(batch_data[next(iter(batch_data))])
+            for i in range(batch_size):
+                point = {obj: batch_data[obj][i] for obj in self.objectives if obj in batch_data}
+                if len(point) == len(self.objectives):  # Only store complete points
+                    self.objective_history.append(point)
+
+        # Skip adaptation during exploration phase
+        if self.step_count <= self.exploration_phase_length:
+            return self.weights.copy()
+
+        # Compute current Pareto front
+        current_front = self._compute_pareto_front()
+        if len(current_front) > 0:
+            self.pareto_history.append(len(current_front))
+
+        # Adapt weights based on multiple criteria
+        self._adapt_weights_pareto_driven(batch_data)
+        self._adapt_weights_stagnation_driven(batch_data)
+        self._adapt_weights_diversity_driven()
+
+        # Ensure constraints
+        self._normalize_weights()
+
+        return self.weights.copy()
+
+    def _compute_pareto_front(self) -> List[Dict[str, float]]:
+        """Compute current Pareto front from history."""
+        if len(self.objective_history) < 10:
+            return []
+
+        points = list(self.objective_history)
+        pareto_front = []
+
+        for i, point1 in enumerate(points):
+            is_dominated = False
+            for j, point2 in enumerate(points):
+                if i != j and self._dominates(point2, point1):
+                    is_dominated = True
+                    break
+            if not is_dominated:
+                pareto_front.append(point1)
+
+        return pareto_front
+
+    def _dominates(self, point1: Dict[str, float], point2: Dict[str, float]) -> bool:
+        """Check if point1 dominates point2 (higher is better for all objectives)."""
+        better_in_all = True
+        strictly_better_in_one = False
+
+        for obj in self.objectives:
+            if obj in point1 and obj in point2:
+                if point1[obj] < point2[obj]:
+                    better_in_all = False
+                    break
+                elif point1[obj] > point2[obj]:
+                    strictly_better_in_one = True
+
+        return better_in_all and strictly_better_in_one
+
+    def _adapt_weights_pareto_driven(self, batch_data: Dict[str, np.ndarray]):
+        """Adapt weights based on distance to Pareto front."""
+        if len(self.objective_history) < 50:
+            return
+
+        pareto_front = self._compute_pareto_front()
+        if len(pareto_front) == 0:
+            return
+
+        # Compute average distance to Pareto front for each objective
+        obj_distances = {obj: [] for obj in self.objectives}
+
+        for point in list(self.objective_history)[-100:]:  # Recent history
+            min_distance = float('inf')
+            closest_front_point = None
+
+            for front_point in pareto_front:
+                distance = sum((point[obj] - front_point[obj]) ** 2
+                               for obj in self.objectives if obj in point and obj in front_point)
+                if distance < min_distance:
+                    min_distance = distance
+                    closest_front_point = front_point
+
+            if closest_front_point:
+                for obj in self.objectives:
+                    if obj in point and obj in closest_front_point:
+                        obj_distances[obj].append(abs(point[obj] - closest_front_point[obj]))
+
+        # Increase weight for objectives with larger gaps to Pareto front
+        for obj in self.objectives:
+            if obj_distances[obj]:
+                avg_distance = np.mean(obj_distances[obj])
+                # Higher distance = increase weight
+                weight_adjustment = avg_distance * self.adaptation_rate * self.pareto_pressure
+                self.weights[obj] = self.weights[obj] * (1 + weight_adjustment)
+
+    def _adapt_weights_stagnation_driven(self, batch_data: Dict[str, np.ndarray]):
+        """Increase weights for stagnating objectives."""
+        for obj in self.objectives:
+            if obj in batch_data:
+                current_mean = np.mean(batch_data[obj])
+                self.objective_trends[obj].append(current_mean)
+
+                if len(self.objective_trends[obj]) >= 20:
+                    recent_trend = np.array(list(self.objective_trends[obj])[-20:])
+                    # Check for stagnation (low variance)
+                    if np.std(recent_trend) < 0.01:  # Adjust threshold as needed
+                        self.stagnation_counters[obj] += 1
+                        # Boost weight for stagnating objectives
+                        boost = min(0.1, self.stagnation_counters[obj] * 0.02)
+                        self.weights[obj] += boost
+                    else:
+                        self.stagnation_counters[obj] = max(0, self.stagnation_counters[obj] - 1)
+
+    def _adapt_weights_diversity_driven(self):
+        """Adapt weights based on Pareto front diversity."""
+        if len(self.pareto_history) < 10:
+            return
+
+        recent_front_sizes = list(self.pareto_history)[-10:]
+        front_diversity = np.std(recent_front_sizes)
+
+        # If diversity is low, boost exploration objectives
+        if front_diversity < 1.0:  # Adjust threshold
+            exploration_objectives = ["sa", "diversity"]  # Objectives that promote exploration
+            for obj in exploration_objectives:
+                if obj in self.weights:
+                    self.weights[obj] += 0.05 * self.adaptation_rate
+
+    def _normalize_weights(self):
+        """Ensure weights are normalized and within bounds."""
+        # Apply bounds
+        for obj in self.weights:
+            self.weights[obj] = np.clip(self.weights[obj], self.min_weight, self.max_weight)
+
+        # Normalize to sum to 1
+        total = sum(self.weights.values())
+        if total > 0:
+            for obj in self.weights:
+                self.weights[obj] /= total
+        else:
+            # Fallback to equal weights
+            n = len(self.weights)
+            for obj in self.weights:
+                self.weights[obj] = 1.0 / n
+
+    def get_mixed_reward(self, rewards_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Compute mixed reward using current weights.
+
+        Args:
+            rewards_dict: Dictionary of reward tensors
+
+        Returns:
+            Mixed reward tensor
+        """
+        mixed_reward = None
+
+        for obj_name, weight in self.weights.items():
+            if obj_name in rewards_dict:
+                weighted_reward = weight * rewards_dict[obj_name]
+                if mixed_reward is None:
+                    mixed_reward = weighted_reward
+                else:
+                    mixed_reward += weighted_reward
+
+        return mixed_reward if mixed_reward is not None else torch.zeros_like(list(rewards_dict.values())[0])
+
+    def get_status(self) -> Dict[str, Any]:
+        """Get current status for logging."""
+        pareto_front = self._compute_pareto_front()
+
+        return {
+            "weights": self.weights.copy(),
+            "step_count": self.step_count,
+            "pareto_front_size": len(pareto_front),
+            "stagnation_counters": self.stagnation_counters.copy(),
+            "history_size": len(self.objective_history),
+            "avg_pareto_size": np.mean(list(self.pareto_history)) if self.pareto_history else 0
+        }
+
+
 # ========================
 # RL TRAINING CONTROLLERS
 # ========================
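Note: the following is a minimal usage sketch, not part of the diff. It assumes ParetoRewardController is importable from rl_utils as added above; the synthetic reward tensors, batch size, and loop are illustrative placeholders only.

import torch

from rl_utils import ParetoRewardController  # assumes rl_utils.py is on the path

# Shortened exploration phase so weight adaptation becomes visible after a few steps.
controller = ParetoRewardController(
    objectives=["total", "sa", "validity", "diversity"],
    history_size=200,
    exploration_phase_length=5,
)

for step in range(20):
    # Synthetic per-sample objective scores standing in for real reward tensors.
    batch = {
        "total": torch.rand(32),
        "sa": torch.rand(32) * 0.1,   # deliberately weak objective
        "validity": torch.ones(32),   # saturated objective
        "diversity": torch.rand(32),
    }
    weights = controller.update(batch)          # returned unchanged for the first 5 steps
    mixed = controller.get_mixed_reward(batch)  # weighted sum, shape (32,)

print(controller.get_status()["weights"])

During the exploration phase update() only logs points and returns the current weights; afterwards the Pareto-gap, stagnation, and diversity adjustments are applied, then the weights are clipped to [min_weight, max_weight] and re-normalized to sum to 1.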
@@ -589,6 +841,107 @@ def compute_entropy_bonus(action_probs: torch.Tensor) -> torch.Tensor:
 # BATCH REWARD COMPUTATION
 # ========================
 
+def batch_compute_rewards_pareto(
+    selfies_list: List[str],
+    reward_mode: str = "mix",
+    reward_mix: float = 0.5,
+    pareto_controller: Optional[ParetoRewardController] = None
+) -> Dict[str, torch.Tensor]:
+    """
+    Drop-in replacement for batch_compute_rewards with Pareto support.
+
+    Args:
+        selfies_list: List of SELFIES strings
+        reward_mode: "chemq3", "sa", "mix", or "pareto"
+        reward_mix: Weight for comprehensive rewards when mixing (0-1)
+        pareto_controller: ParetoRewardController instance for "pareto" mode
+
+    Returns:
+        Dictionary containing reward tensors (same format as original)
+    """
+    batch_size = len(selfies_list)
+
+    validity_vals = []
+    lipinski_vals = []
+    total_rewards = []
+    sa_rewards = []
+
+    # Compute all individual rewards
+    for selfies_str in selfies_list:
+        smiles = selfies_to_smiles(selfies_str)
+
+        # Check validity comprehensively
+        is_valid = (smiles is not None and
+                    is_valid_smiles(smiles) and
+                    passes_durrant_lab_filter(smiles))
+
+        if reward_mode in ["chemq3", "mix", "pareto"]:
+            r = compute_comprehensive_reward(selfies_str)
+            validity_vals.append(r.get('validity', 0.0))
+            lipinski_vals.append(r.get('lipinski', 0.0))
+
+        if reward_mode in ["sa", "mix", "pareto"]:
+            sa = compute_sa_reward(selfies_str) if is_valid else 0.0
+            sa_rewards.append(sa)
+
+        # Store individual comprehensive reward for pareto mode
+        if reward_mode in ["chemq3", "pareto"]:
+            total_rewards.append(r.get('total', 0.0))
+        elif reward_mode == "sa":
+            total_rewards.append(sa)
+        elif reward_mode == "mix":
+            # In "mix" mode both r and sa were computed above
+            r_total = r.get("total", 0.0)
+            mixed = reward_mix * r_total + (1.0 - reward_mix) * sa
+            total_rewards.append(mixed)
+
+    # Convert to tensors
+    result = {
+        "total_rewards": torch.tensor(total_rewards, dtype=torch.float32),
+    }
+
+    if validity_vals:
+        result["validity_rewards"] = torch.tensor(validity_vals, dtype=torch.float32)
+    if lipinski_vals:
+        result["lipinski_rewards"] = torch.tensor(lipinski_vals, dtype=torch.float32)
+    if sa_rewards:
+        result["sa_rewards"] = torch.tensor(sa_rewards, dtype=torch.float32)
+
+    # Compute diversity reward
+    valid_smiles = []
+    for selfies_str in selfies_list:
+        smiles = selfies_to_smiles(selfies_str)
+        if smiles and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
+            valid_smiles.append(smiles)
+
+    diversity_score = len(set(valid_smiles)) / max(1, len(valid_smiles))
+    result["diversity_rewards"] = torch.full((batch_size,), diversity_score, dtype=torch.float32)
+
+    # Apply Pareto mixing if requested
+    if reward_mode == "pareto" and pareto_controller is not None:
+        # Prepare objectives for controller update
+        batch_objectives = {
+            "total": result["total_rewards"],
+            "validity": result.get("validity_rewards", torch.zeros(batch_size)),
+            "diversity": result["diversity_rewards"]
+        }
+
+        if "sa_rewards" in result:
+            batch_objectives["sa"] = result["sa_rewards"]
+
+        # Update controller and get new weights
+        updated_weights = pareto_controller.update(batch_objectives)
+
+        # Compute mixed reward using adaptive weights
+        mixed_reward = pareto_controller.get_mixed_reward(batch_objectives)
+        result["total_rewards"] = mixed_reward
+
+        # Store weights for logging
+        result["pareto_weights"] = updated_weights
+
+    return result
+
+# Legacy
 def batch_compute_rewards(
     selfies_list: List[str],
     reward_mode: str = "chemq3",
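Likewise, a sketch of wiring batch_compute_rewards_pareto into an RL loop in "pareto" mode. Only ParetoRewardController and batch_compute_rewards_pareto come from this commit (assumed importable from rl_utils); sample_selfies_batch and the logging are hypothetical stand-ins for the surrounding training code.

from rl_utils import ParetoRewardController, batch_compute_rewards_pareto

pareto_controller = ParetoRewardController(
    objectives=["total", "sa", "validity", "diversity"],
    adaptation_rate=0.1,
)

def sample_selfies_batch(step):
    # Hypothetical placeholder: a real loop would decode SELFIES sampled from the policy.
    return ["[C][C][O]"] * 16

for step in range(200):
    selfies_batch = sample_selfies_batch(step)
    rewards = batch_compute_rewards_pareto(
        selfies_batch,
        reward_mode="pareto",
        pareto_controller=pareto_controller,
    )
    # total_rewards now holds the adaptively mixed reward instead of a fixed reward_mix blend.
    returns = rewards["total_rewards"]
    if step % 50 == 0:
        status = pareto_controller.get_status()
        print(step, rewards.get("pareto_weights"), status["pareto_front_size"])

The controller is stateful (rolling histories and stagnation counters), so the same instance has to be reused across batches rather than constructed per call.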
@@ -713,4 +1066,5 @@ compute_training_metrics(
     # Add loss components
     metrics.update(loss_dict)
 
+
     return metrics