import random as rd

import numpy as np
import torch
import torch.nn.functional as F

from finetune_utils import to_one_hot
from utils import StepTimer
import noise_schedule

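# Monte Carlo tree search over partially unmasked sequences: `Node` stores one
# partially unmasked sequence together with its per-step log-probabilities and
# backed-up rewards, and `MCTS` repeatedly selects a leaf, expands it with sampled
# unmasking steps, rolls the children out to completion, scores them with the
# reward model, and keeps the best trajectories in a bounded buffer.
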
class Node:
    """
    Node class: a partially unmasked sequence (e.g. a SMILES string) in the search tree.

    - tokens: dict holding the node's token sequence ('seqs') and attention mask ('attention_mask')
    - parentNode: Node object at the previous time step
    - childNodes: list of M Node objects generated by sampling M distinct unmasking schemes
    - log_rnd: running sum of (pretrained log-prob - policy log-prob) over the unmasking steps along the path to this node
    - log_policy_step / log_pretrained_step: log-probability of the unmasking step that produced this node
      under the policy and under the pretrained model, respectively
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration
    - timestep: the time step at which the sequence was sampled (the root is fully masked at timestep 0)
    """

    def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None,
                 parentNode=None, childNodes=None, totalReward=None, timestep=None):
        self.args = args
        self.parentNode = parentNode
        self.childNodes = [] if childNodes is None else childNodes

        self.log_rnd = log_rnd
        self.log_policy_step = log_policy_step
        self.log_pretrained_step = log_pretrained_step

        self.totalReward = totalReward
        self.visits = 1
        self.timestep = timestep
        self.tokens = tokens

    def selectNode(self, rootNode):
        """
        Selects the next node to move to among the child nodes, sampling in
        proportion to a softmax over their select scores.
        Returns: (selected_node, selected_status)
        """
        nodeStatus = self.getExpandStatus()

        if nodeStatus == 3:
            selectScores = []
            selectable_children = []

            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    if torch.is_tensor(selectScore) and selectScore.numel() == 1:
                        selectScore = selectScore.item()
                    selectable_children.append(childNode)
                    selectScores.append(float(selectScore))

            if len(selectable_children) == 0:
                return rootNode, 3

            selectScores = np.asarray(selectScores, dtype=np.float64)

            # Softmax over select scores (temperature 1.0); fall back to a uniform
            # distribution if the normalizer is degenerate.
            temp = 1.0
            m = np.max(selectScores)
            exps = np.exp((selectScores - m) / temp)
            tot = exps.sum()

            if not np.isfinite(tot) or tot <= 0.0:
                probs = np.full(len(selectable_children), 1.0 / len(selectable_children))
            else:
                probs = exps / tot

            idx = np.random.choice(len(selectable_children), p=probs)
            selected = selectable_children[idx]
            return selected, selected.getExpandStatus()

        elif nodeStatus == 2:
            return self, nodeStatus

        return rootNode, 3

    def selectNodeTopK(self, rootNode, k=3, temp=1.0):
        """
        Picks among the top-k children by select score, sampling from a
        temperature softmax over the top-k scores.
        Returns: (selected_node, selected_status)
        """
        nodeStatus = self.getExpandStatus()

        if nodeStatus == 2:
            return self, nodeStatus

        if nodeStatus == 3:
            selectable_children = []
            selectScores = []

            for ch in self.childNodes:
                s = ch.getExpandStatus()
                if s in (2, 3):
                    sc = ch.calcSelectScore()
                    if torch.is_tensor(sc):
                        sc = sc.item() if sc.numel() == 1 else float(sc.mean().item())
                    sc = float(sc) if np.isfinite(sc) else -np.inf
                    selectable_children.append(ch)
                    selectScores.append(sc)

            if not selectable_children:
                return rootNode, 3

            scores = np.asarray(selectScores, dtype=np.float64)

            # Indices of the k highest scores, ordered from best to worst.
            k_eff = min(k, len(scores))
            topk_idx = np.argpartition(-scores, kth=k_eff - 1)[:k_eff]
            topk_idx = topk_idx[np.argsort(-scores[topk_idx])]

            pool_scores = scores[topk_idx]
            pool_children = [selectable_children[i] for i in topk_idx]

            # Temperature softmax over the top-k pool; fall back to the best child
            # if the normalizer is degenerate.
            m = np.max(pool_scores)
            z = (pool_scores - m) / max(1e-8, temp)
            exps = np.exp(np.clip(z, -60, 60))
            tot = exps.sum()
            if not np.isfinite(tot) or tot <= 0.0:
                idx_local = 0
            else:
                probs = exps / tot
                idx_local = int(np.random.choice(len(pool_children), p=probs))

            selected = pool_children[idx_local]
            return selected, selected.getExpandStatus()

        return rootNode, 3

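    # Worked example for selectNodeTopK: with child select scores [2.0, 1.0, 0.5],
    # k=2 and temp=1.0, the pool is [2.0, 1.0] and the sampling probabilities are
    # softmax([0.0, -1.0]) ~= [0.73, 0.27] over those two children.
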
    def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
        """
        Adds a child node.

        - log_rnd: log_rnd accumulated along the path up to the added child node
        - log_policy_step: scalar log-probability of sampling the step under the policy
        - log_pretrained_step: scalar log-probability of sampling the step under the pretrained model
        """
        child = Node(args=self.args,
                     tokens=tokens,
                     log_rnd=log_rnd,
                     log_policy_step=log_policy_step,
                     log_pretrained_step=log_pretrained_step,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1)

        self.childNodes.append(child)
        return child

    def update_logrnd(self, log_policy_step, log_rnd):
        self.log_policy_step = log_policy_step
        self.log_rnd = log_rnd

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector backed up from a
        descendant leaf node, and increments the number of visits to the node.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative rewards vector and the number of visits.
        - self.args.exploration: constant that determines the degree of exploration
        """
        normRewards = self.totalReward / self.visits

        return normRewards + (self.args.exploration * self.log_policy_step * np.sqrt(self.parentNode.visits) / self.visits)

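    # The select score above is a UCT-style rule, written out with the attributes
    # defined in __init__ (the step log-probability plays the role of a prior-like
    # weight on the exploration term):
    #
    #   score = totalReward / visits
    #           + args.exploration * log_policy_step * sqrt(parentNode.visits) / visits
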
    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with M child nodes)
        """
        if self.timestep == self.args.total_num_steps:
            return 1
        elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
            return 2
        return 3


class MCTS:
    def __init__(self, args, config, policy_model, pretrained, rewardFunc, rootNode=None):
        self.timer = StepTimer(policy_model.device)

        self.buf_stats = {"insert": 0, "replace": 0, "reject_worse": 0,
                          "reject_dup": 0, "reject_nonfinite": 0}
        self._seen_hashes = set()

        self.device = policy_model.device
        print(f"MCTS device: {self.device}")

        self.args = args
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = args.time_conditioning

        self.mask_index = policy_model.mask_index
        masked_seq = torch.ones((self.args.seq_length), device=self.device) * self.mask_index
        masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
        if rootNode is None:
            self.rootNode = Node(self.args, tokens=masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=torch.zeros((), device=self.device), timestep=0)
        else:
            self.rootNode = rootNode

        self.buffer = []
        self.buffer_size = args.buffer_size

        self.num_steps = args.total_num_steps
        self.num_sequences = args.num_sequences

        self.pretrained = pretrained
        self.policy_model = policy_model

        self.sequence_length = args.seq_length
        self.num_iter = args.num_iter
        self.num_children = args.num_children

        self.rewardFunc = rewardFunc
        self.iter_num = 0

        self.reward_log = []
        self.logrnd_log = []

        self.policy_model.eval()
        self.pretrained.eval()
        self.rewardFunc.eval()

    def _hash_tokens(self, t):
        # Hashable key used to deduplicate sequences in the buffer.
        return tuple(t.detach().cpu().tolist())

    def reset(self, resetTree):
        self.iter_num = 0
        self.buffer = []
        self._seen_hashes = set()
        self.reward_log = []
        self.logrnd_log = []

        if resetTree:
            masked_seq = torch.ones((self.args.seq_length), device=self.device) * self.mask_index
            masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
            self.rootNode = Node(self.args, tokens=masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=torch.zeros((), device=self.device), timestep=0)

    def forward(self, resetTree=False):
        self.reset(resetTree)

        while self.iter_num < self.num_iter:
            self.iter_num += 1

            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            # Expansion also scores the rollouts and backpropagates the rewards.
            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards = self.consolidateBuffer()

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        return final_x, log_rnd, final_rewards

    def updateBuffer(self, x_final, log_rnd, final_reward):
        B = x_final.shape[0]
        for i in range(B):
            # Skip candidates with non-finite rewards or weights.
            if not torch.isfinite(final_reward[i]) or not torch.isfinite(log_rnd[i]):
                self.buf_stats["reject_nonfinite"] += 1
                continue

            # Skip exact duplicates of sequences already in the buffer.
            h = self._hash_tokens(x_final[i])
            if h in self._seen_hashes:
                self.buf_stats["reject_dup"] += 1
                continue

            item = {"x_final": x_final[i].clone(),
                    "log_rnd": log_rnd[i].clone(),
                    "final_reward": final_reward[i].clone()}

            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
                self._seen_hashes.add(h)
                self.buf_stats["insert"] += 1
            else:
                # Buffer is full: find the worst incumbent by (reward, log_rnd).
                min_idx, min_item = min(
                    enumerate(self.buffer),
                    key=lambda kv: (kv[1]["final_reward"].item(), kv[1]["log_rnd"].item())
                )
                cand_key = (final_reward[i].item(), log_rnd[i].item())
                min_key = (min_item["final_reward"].item(), min_item["log_rnd"].item())

                if cand_key > min_key:
                    # Evict the worst incumbent and keep the hash set in sync.
                    old_h = self._hash_tokens(self.buffer[min_idx]["x_final"])
                    if old_h in self._seen_hashes:
                        self._seen_hashes.remove(old_h)
                    self.buffer[min_idx] = item
                    self._seen_hashes.add(h)
                    self.buf_stats["replace"] += 1
                else:
                    self.buf_stats["reject_worse"] += 1

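    # Replacement policy in updateBuffer: once the buffer is full, the incumbent with
    # the smallest (final_reward, log_rnd) key is evicted only if the candidate's key
    # is strictly larger. E.g. if the worst incumbent has key (0.42, -3.1), a candidate
    # with (0.42, -2.5) wins the tie on reward via its larger log_rnd and replaces it.
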
    def print_buffer_stats(self):
        print("[BUFFER] ",
              " ".join(f"{k}={v}" for k, v in self.buf_stats.items()),
              f" size={len(self.buffer)}/{self.buffer_size}")
        if len(self.buffer):
            vals = torch.stack([b["final_reward"] for b in self.buffer]).float()
            print(f"[BUFFER] reward min/mean/max: {vals.min():.4f} {vals.mean():.4f} {vals.max():.4f}")

    def consolidateBuffer(self):
        """
        Returns x_final, log_rnd, and final_rewards from the buffer as stacked tensors.
        """
        x_final = []
        log_rnd = []
        final_rewards = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])

        x_final = torch.stack(x_final, dim=0)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32)
        final_rewards = torch.stack(final_rewards, dim=0).to(dtype=torch.float32)

        return x_final, log_rnd, final_rewards

    def isPathEnd(self, path, maxDepth):
        """
        Checks whether the last node on the path is completely unmasked (i.e. end of path)
        or whether the path has reached the maximum depth.
        """
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """
        Traverses the tree from the root node until reaching a legal leaf node.
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            if self.args.select_topk:
                currNode, nodeStatus = currNode.selectNodeTopK(self.rootNode, k=self.args.select_topk_value, temp=1.0)
            else:
                currNode, nodeStatus = currNode.selectNode(self.rootNode)

            if currNode.parentNode is not None:
                # Re-compute the step log-probability under the current policy and
                # refresh the accumulated log ratio along the path.
                child_tokens = currNode.tokens['seqs'].to(self.device)
                attn_mask = currNode.tokens['attention_mask'].to(self.device)
                parent = currNode.parentNode
                parent_tokens = parent.tokens['seqs'].to(self.device)
                t = torch.ones(1, device=self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
                                                                                       child_tokens,
                                                                                       t=t, dt=dt)
                updated_log_rnd += (currNode.log_pretrained_step - updated_log_policy_step)

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd)

            if nodeStatus == 2:
                return currNode, nodeStatus
            elif nodeStatus == 1:
                currNode = self.rootNode

    def expand(self, parentNode, eps=1e-5):
        """
        Samples unmasking steps from the pre-trained MDLM and adds num_children
        partially unmasked sequences as children of parentNode, then rolls each
        child out to a fully unmasked sequence, scores it, and backpropagates.
        """
        num_children = self.num_children

        allChildReward = torch.zeros((), device=self.device)

        num_rollout_steps = self.num_steps - parentNode.timestep

        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device=self.device)

        # One reverse step per child: sample num_children unmaskings of the parent
        # sequence, recording the step log-probs under both the policy and the
        # pretrained model.
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                child_log_p, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children
        traj_log_rnd = child_log_rnd

        # Roll each child out to the end of the unmasking schedule, accumulating the
        # per-step log ratio between the pretrained model and the policy.
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device=self.device)

                with torch.no_grad():
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                traj_log_rnd += log_pretrained_step - log_policy_step
                x_rollout = x_next

        # If any positions are still masked after the schedule, run a final
        # noise-removal step.
        mask_positions = (x_rollout == self.mask_index)
        any_mask_global = mask_positions.any().item()
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step
            x_rollout = x_next

        x_final = x_rollout

        # Score the completed sequences and tilt the trajectory weights by reward / alpha.
        x_one_hot = to_one_hot(x_final)
        x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
        reward_preds = self.rewardFunc(x_one_hot_reward).squeeze(-1)
        rewards_value = reward_preds[:, 0]

        if self.args.reward_clip:
            rewards = torch.clamp(rewards_value, max=self.args.reward_clip_value)
        else:
            rewards = rewards_value

        traj_log_rnd += rewards / self.args.alpha

        self.reward_log.append(rewards.detach().cpu().numpy())
        self.logrnd_log.append(traj_log_rnd.detach().cpu().numpy())

        with self.timer.section("expand.update_buffer"):
            self.updateBuffer(x_final, traj_log_rnd, rewards)

        for i in range(num_children):
            allChildReward += rewards[i]

            childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[i],
                                    log_policy_step=child_log_policy_step[i],
                                    log_pretrained_step=child_log_pretrained_step[i],
                                    totalReward=rewards[i])

        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)

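    # Sketch of the trajectory weight accumulated in expand() above, in this module's
    # notation (alpha and the clipping threshold come from args):
    #
    #   traj_log_rnd = parentNode.log_rnd
    #                  + sum over rollout steps of (log_pretrained_step - log_policy_step)
    #                  + rewards / alpha
    #
    # i.e. the accumulated log ratio between the pretrained model and the policy along
    # the rollout, tilted by the (optionally clipped) reward scaled by 1/alpha.
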
    def backprop(self, node, allChildReward):
        # Propagate the summed child rewards up to the root, incrementing visit counts.
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
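

# Example usage (a sketch, not executed here; assumes an `args` namespace exposing the
# fields read above -- e.g. seq_length, total_num_steps, num_sequences, buffer_size,
# num_iter, num_children, exploration, alpha, reward_clip, reward_clip_value,
# select_topk, select_topk_value, time_conditioning -- plus a diffusion `config` and
# models implementing the sampling interface used in select()/expand()):
#
#   mcts = MCTS(args, config, policy_model, pretrained_model, reward_model)
#   x_final, log_rnd, final_rewards = mcts.forward(resetTree=True)
#   mcts.print_buffer_stats()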