# TR2-D2: tr2d2-dna/mcts.py
import numpy as np
import torch
import torch.nn.functional as F
import random as rd
from finetune_utils import to_one_hot
from utils import StepTimer
import noise_schedule
### BEGINNING OF NODE CLASS ###
class Node:
"""
Node class: a partially unmasked sequence in the search tree
- parentNode: Node object at the previous time step
- childNodes: list of M Node objects generated by sampling M distinct unmasking steps
- totalReward: cumulative reward accumulated from rollouts passing through this node
- visits: number of times the node has been visited by an iteration
- tokens: dict with 'seqs' (token array) and 'attention_mask'
- log_rnd: cumulative log importance weight (pretrained vs. policy) of the path to this node
- timestep: the time step at which the sequence was sampled (0 to total_num_steps)
"""
def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
self.args = args
self.parentNode = parentNode
self.childNodes = [] if childNodes is None else childNodes
self.log_rnd = log_rnd # cumulative log_rnd of the path up to this step
self.log_policy_step = log_policy_step # log probability of the unmasking step under the current policy
self.log_pretrained_step = log_pretrained_step # log probability of the unmasking step under the pretrained model
# initialize the total reward to the reward of the rolled-out unmasked sequence
self.totalReward = totalReward # updated as child rollouts are backpropagated
# set initial visits to 1
self.visits = 1
# set timestep (value between 0 and total_num_steps)
self.timestep = timestep
# dict with 'seqs' as token array and 'attention_mask'
self.tokens = tokens
def selectNode(self, rootNode):
"""
Selects a child node to move to by sampling from a softmax over the children's select scores.
Returns: (selected_node, selected_status); falls back to (rootNode, 3) if no child is selectable.
"""
# extract the status of the current node
nodeStatus = self.getExpandStatus()
# if the node is a legal non-leaf node
if (nodeStatus == 3):
# initialize array that will store select score vectors of each child node
selectScores = []
selectable_children = [] # children nodes that can be selected
for childNode in self.childNodes:
childStatus = childNode.getExpandStatus()
# only append child if it is legal leaf node (expandable) or legal non-leaf node
if childStatus == 2 or childStatus == 3:
selectScore = childNode.calcSelectScore()
if torch.is_tensor(selectScore) and selectScore.numel() == 1:
selectScore = selectScore.item()
selectable_children.append(childNode)
selectScores.append(float(selectScore))
# no selectable children
if len(selectable_children) == 0:
return rootNode, 3
selectScores = np.asarray(selectScores, dtype=np.float64)
temp = 1.0
# compute softmax probabilities
m = np.max(selectScores)
exps = np.exp((selectScores - m) / temp)
tot = exps.sum()
if not np.isfinite(tot) or tot <= 0.0:
probs = np.full(len(selectable_children), 1.0 / len(selectable_children))
else:
probs = exps / tot
# choose child index from categorical distribution
idx = np.random.choice(len(selectable_children), p=probs)
selected = selectable_children[idx]
# return selected child node and status
return selected, selected.getExpandStatus()
elif (nodeStatus == 2):
return self, nodeStatus
# otherwise (terminal or invalid node), restart from the root
return rootNode, 3
def selectNodeTopK(self, rootNode, k=3, temp=1.0):
"""
Pick from the top-k by select score.
Returns: (selected_node, selected_status)
"""
nodeStatus = self.getExpandStatus()
# If expandable leaf, return it directly
if nodeStatus == 2:
return self, nodeStatus
if nodeStatus == 3:
selectable_children = []
selectScores = []
# collect candidates
for ch in self.childNodes:
s = ch.getExpandStatus()
if s in (2, 3):
sc = ch.calcSelectScore()
if torch.is_tensor(sc):
sc = sc.item() if sc.numel() == 1 else float(sc.mean().item())
sc = float(sc) if np.isfinite(sc) else -np.inf # push bad scores to -inf
selectable_children.append(ch)
selectScores.append(sc)
if not selectable_children:
return rootNode, 3
scores = np.asarray(selectScores, dtype=np.float64)
# top-k indices (largest scores)
k_eff = min(k, len(scores))
topk_idx = np.argpartition(-scores, kth=k_eff-1)[:k_eff]
# sort the top-k by score desc for stability
topk_idx = topk_idx[np.argsort(-scores[topk_idx])]
# slice down to top-k pool
pool_scores = scores[topk_idx]
pool_children = [selectable_children[i] for i in topk_idx]
# softmax over the top-k
m = np.max(pool_scores)
z = (pool_scores - m) / max(1e-8, temp)
exps = np.exp(np.clip(z, -60, 60))
tot = exps.sum()
if not np.isfinite(tot) or tot <= 0.0:
idx_local = 0 # best
else:
probs = exps / tot
idx_local = int(np.random.choice(len(pool_children), p=probs))
selected = pool_children[idx_local]
return selected, selected.getExpandStatus()
return rootNode, 3
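# Illustrative example of the top-k pooling above: with child select scores
# [0.1, 0.9, 0.5, 0.3] and k = 3, the pool is [0.9, 0.5, 0.3]; at temp = 1.0
# the softmax weights are roughly [0.45, 0.30, 0.25], so the best-scoring
# child is favored but the choice remains stochastic.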
def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
""""
Adds a child node:
log_rnd: log_rnd of the path up to the added child node
log_policy_step: scalar value of the log-prob of sampling the step under the policy
log_pretrained_step: scalar value of the log-prob of sampling the step under the pretrained model
"""
child = Node(args=self.args,
tokens=tokens,
log_rnd=log_rnd,
log_policy_step=log_policy_step,
log_pretrained_step=log_pretrained_step,
parentNode=self,
childNodes=[],
totalReward=totalReward,
timestep=self.timestep+1)
self.childNodes.append(child)
return child
def update_logrnd(self, log_policy_step, log_rnd):
"""
Refreshes the stored per-step policy log-prob and cumulative log_rnd
after re-evaluation under the current policy (see MCTS.select).
"""
self.log_policy_step = log_policy_step
self.log_rnd = log_rnd
def updateNode(self, rewards):
"""
Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
Increments the number of visits to the node.
"""
self.visits += 1
self.totalReward += rewards # singleton tensor
def calcSelectScore(self):
"""
Calculates the select score for the node from the cumulative rewards vector and number of visits.
- c: determines the degree of exploration
- minSelectScore: determines the
"""
# K-dimensional vector of normalized rewards for each objective
normRewards = self.totalReward / self.visits
# scales the cumulative reward by the sampling probability
return normRewards + (self.args.exploration * self.log_policy_step * np.sqrt(self.parentNode.visits) / self.visits)
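# Worked example (illustrative, assuming args.exploration = 1.0): a child with
# totalReward = 2.0 over visits = 4 has normRewards = 0.5; with
# log_policy_step = -0.7 and a parent visited 16 times, the exploration term is
# 1.0 * (-0.7) * sqrt(16) / 4 = -0.7, giving a select score of -0.2. Children
# with higher average reward, or whose step is more likely under the current
# policy (log-prob closer to 0), therefore score higher.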
def getExpandStatus(self):
"""
Returns an integer indicating whether the node is a:
1. terminal node (sequence is fully unmasked)
2. legal leaf node (partially unmasked sequence that can be expanded)
3. legal non-leaf node (already expanded sequence with M child nodes)
"""
if self.timestep == self.args.total_num_steps:
return 1
elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
return 2
return 3
### END OF NODE CLASS ###
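# Minimal usage sketch of the Node API (illustrative only; `args`, `masked_tokens`,
# `child_tokens`, and `reward` are placeholders for objects built as in the MCTS
# class below):
#
#   root = Node(args, tokens=masked_tokens, log_rnd=torch.zeros(()),
#               log_policy_step=torch.zeros(()), log_pretrained_step=torch.zeros(()),
#               totalReward=torch.zeros(()), timestep=0)
#   child = root.addChildNode(tokens=child_tokens, log_rnd=step_log_rnd,
#                             log_policy_step=lp, log_pretrained_step=lq,
#                             totalReward=reward)
#   child.updateNode(reward)                      # bumps visits and totalReward
#   nxt, status = root.selectNode(root)           # samples a child via softmax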
### BEGINNING OF MCTS CLASS ###
class MCTS:
def __init__(self, args, config, policy_model, pretrained, rewardFunc, rootNode=None):
self.timer = StepTimer(policy_model.device)
# debugging
self.buf_stats = {"insert":0, "replace":0, "reject_worse":0,
"reject_dup":0, "reject_nonfinite":0}
self._seen_hashes = set()
self.device = policy_model.device
print(f"MCTS device: {self.device}")
self.args = args
self.config = config
self.noise = noise_schedule.get_noise(config)
self.time_conditioning = args.time_conditioning
self.mask_index = policy_model.mask_index
masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
if rootNode is None:
self.rootNode = Node(self.args, tokens = masked_tokens,
log_rnd=torch.zeros((), device=self.device),
log_policy_step=torch.zeros((), device=self.device),
log_pretrained_step=torch.zeros((), device=self.device),
totalReward=torch.zeros((), device=self.device), timestep=0)
else:
self.rootNode = rootNode # stores the root node of the tree
# replay buffer of dictionaries with keys:
# "x_final": final unmasked sequence (L,)
# "log_rnd": cumulative log importance weight of the trajectory
# "final_reward": reward of the final sequence
self.buffer = [] # List[Dict[str, Any]]
self.buffer_size = args.buffer_size
self.num_steps = args.total_num_steps
self.num_sequences = args.num_sequences
# pretrained model
self.pretrained = pretrained
# the policy model that we want to finetune
self.policy_model = policy_model
#self.tokenizer = policy_model.tokenizer
self.sequence_length = args.seq_length
self.num_iter = args.num_iter
self.num_children = args.num_children
# score functions
self.rewardFunc = rewardFunc
self.iter_num = 0
self.reward_log = []
self.logrnd_log = []
self.policy_model.eval()
self.pretrained.eval()
self.rewardFunc.eval()
def _hash_tokens(self, t):
# t: (L,) torch.long
return tuple(t.detach().cpu().tolist())
def reset(self, resetTree):
self.iter_num = 0
self.buffer = []
self._seen_hashes = set() # Clear the hash set too!
self.reward_log = []
self.logrnd_log = []
# add option to continue with the same tree
if resetTree:
masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
self.rootNode = Node(self.args, tokens = masked_tokens,
log_rnd=torch.zeros((), device=self.device),
log_policy_step=torch.zeros((), device=self.device),
log_pretrained_step=torch.zeros((), device=self.device),
totalReward=torch.zeros((), device=self.device), timestep=0)
def forward(self, resetTree=False):
self.reset(resetTree)
while (self.iter_num < self.num_iter):
self.iter_num += 1
# traverse the tree from the root node until a leaf node
with self.timer.section("select"):
leafNode, _ = self.select(self.rootNode)
# expand leaf node into num_children partially unmasked sequences at the next timestep
with self.timer.section("expand"):
self.expand(leafNode)
final_x, log_rnd, final_rewards = self.consolidateBuffer()
rows = self.timer.summary()
print("\n=== Timing summary (by total time) ===")
for name, cnt, total, mean, p50, p95 in rows:
print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")
# return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )
return final_x, log_rnd, final_rewards
def updateBuffer(self, x_final, log_rnd, final_reward):
B = x_final.shape[0]
for i in range(B):
# Finite check
if not torch.isfinite(final_reward[i]) or not torch.isfinite(log_rnd[i]):
self.buf_stats["reject_nonfinite"] += 1
continue
h = self._hash_tokens(x_final[i])
if h in self._seen_hashes:
self.buf_stats["reject_dup"] += 1
continue
item = {"x_final": x_final[i].clone(),
"log_rnd": log_rnd[i].clone(),
"final_reward": final_reward[i].clone()}
if len(self.buffer) < self.buffer_size:
self.buffer.append(item)
self._seen_hashes.add(h)
self.buf_stats["insert"] += 1
else:
# replace if strictly better, or tie-break with log_rnd
min_idx, min_item = min(
enumerate(self.buffer),
key=lambda kv: (kv[1]["final_reward"].item(), kv[1]["log_rnd"].item())
)
cand_key = (final_reward[i].item(), log_rnd[i].item())
min_key = (min_item["final_reward"].item(), min_item["log_rnd"].item())
if cand_key > min_key: # allow ties via 2nd key
# update hash set
old_h = self._hash_tokens(self.buffer[min_idx]["x_final"])
if old_h in self._seen_hashes:
self._seen_hashes.remove(old_h)
self.buffer[min_idx] = item
self._seen_hashes.add(h)
self.buf_stats["replace"] += 1
else:
self.buf_stats["reject_worse"] += 1
def print_buffer_stats(self):
print("[BUFFER] ",
" ".join(f"{k}={v}" for k,v in self.buf_stats.items()),
f" size={len(self.buffer)}/{self.buffer_size}")
if len(self.buffer):
vals = torch.stack([b["final_reward"] for b in self.buffer]).float()
print(f"[BUFFER] reward min/mean/max: {vals.min():.4f} {vals.mean():.4f} {vals.max():.4f}")
def consolidateBuffer(self):
"""
returns x_final, log_rnd, and final_rewards in tensors
"""
x_final = []
log_rnd = []
final_rewards = []
for item in self.buffer:
x_final.append(item["x_final"])
log_rnd.append(item["log_rnd"])
final_rewards.append(item["final_reward"])
x_final = torch.stack(x_final, dim=0) # (B, L)
log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (B)
final_rewards = torch.stack(final_rewards, dim=0).to(dtype=torch.float32) # (B)
return x_final, log_rnd, final_rewards
def isPathEnd(self, path, maxDepth):
"""
Checks whether the last sequence in the path is completely unmasked (i.e. the end of the path)
or whether the path has reached the maximum depth.
"""
if (path[-1] != self.mask_index).all():
return True
elif len(path) >= maxDepth:
return True
return False
def select(self, currNode, eps=1e-5):
"""
Traverse the tree from the root node until reaching a legal leaf node
"""
#iter = 1
updated_log_rnd = torch.zeros((), device=self.device)
while True:
if self.args.select_topk:
currNode, nodeStatus = currNode.selectNodeTopK(self.rootNode, k=self.args.select_topk_value, temp=1.0)
else:
currNode, nodeStatus = currNode.selectNode(self.rootNode)
if currNode.parentNode is not None:
# compute new log_policy
child_tokens = currNode.tokens['seqs'].to(self.device)
attn_mask = currNode.tokens['attention_mask'].to(self.device)
parent = currNode.parentNode
parent_tokens = parent.tokens['seqs'].to(self.device)
t = torch.ones(1, device = self.device)
dt = (1 - eps) / self.num_steps
with torch.no_grad():
with self.timer.section("select.compute_log_policy"):
updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
child_tokens,
t=t, dt=dt)
updated_log_rnd += (currNode.log_pretrained_step - updated_log_policy_step)
currNode.update_logrnd(updated_log_policy_step, updated_log_rnd) # update log_rnd
# if the node is a legal leaf node, return it for expansion; if terminal, restart from the root
if nodeStatus == 2:
return currNode, nodeStatus
elif nodeStatus == 1:
currNode = self.rootNode
def expand(self, parentNode, eps=1e-5):
"""
Samples num_children unmasking steps from parentNode, adds the resulting
partially unmasked sequences as children, rolls each one out to a fully
unmasked sequence under the current policy, and scores the rollouts.
"""
num_children = self.num_children
# initialize child rewards that will be added to total rewards
allChildReward = torch.zeros((), device=self.device)
# compute number of rollout steps
# (if parentNode.timestep == self.num_steps - 1, then num_rollout_steps = 1)
num_rollout_steps = self.num_steps - parentNode.timestep
# noise schedule of timesteps from t = 1 down to t = eps (num_steps + 1 points)
rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
dt = (1 - eps) / self.num_steps
# initialize x and attn_mask
x = parentNode.tokens['seqs'].to(self.device)
attn_mask = parentNode.tokens['attention_mask'].to(self.device)
parent_log_rnd = parentNode.log_rnd # stores the log_rnd up to parent node
t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device = self.device)
# generate (n_children, seq_length) array of sampled children nodes
# sample M child sequences and compute their log probabilities
with torch.no_grad():
with self.timer.section("expand.batch_mcts_reverse_step"):
child_log_p, x_children, child_log_policy_step, child_log_pretrained_step = \
self.policy_model.batch_mcts_reverse_step(token_array=x,
t=t, dt=dt,
batch_size=num_children,
pretrained=self.pretrained)
# compute weight of the step (num_children, 1)
child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)
x_rollout = x_children
traj_log_rnd = child_log_rnd # initialize log_rnd for entire rolled out trajectory
# rollout under the policy and compute the log ratio at each step
with self.timer.section("expand.rollout_total"):
for i in range(1, num_rollout_steps):
t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device = self.device)
with torch.no_grad():
log_p, x_next, log_policy_step, log_pretrained_step = \
self.policy_model.mcts_reverse_step(x_rollout,
t=t, dt=dt,
pretrained=self.pretrained)
# add the rollout step
traj_log_rnd += log_pretrained_step - log_policy_step
x_rollout = x_next
# if mask token remains, fully unmask
mask_positions = (x_rollout == self.mask_index) # (B, L) bool
# does **any** mask remain in any sequence
any_mask_global = mask_positions.any().item() # true if mask remains
if any_mask_global:
with torch.no_grad():
with self.timer.section("expand.noise_removal"):
log_p, x_next, log_policy_step, log_pretrained_step = \
self.policy_model.mcts_noise_removal(x_rollout,
t=t, dt=dt,
pretrained=self.pretrained)
traj_log_rnd += log_pretrained_step - log_policy_step
x_rollout = x_next
x_final = x_rollout # final sequences (B, L)
# get final rewards
x_one_hot = to_one_hot(x_final)
x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
reward_preds = self.rewardFunc(x_one_hot_reward).squeeze(-1) # (num_children, 4)
rewards_value = reward_preds[:, 0] # (num_children,)
if self.args.reward_clip:
rewards = torch.clamp(rewards_value, max=self.args.reward_clip_value)
else:
rewards = rewards_value
traj_log_rnd += rewards / self.args.alpha
self.reward_log.append(rewards.detach().cpu().numpy())
self.logrnd_log.append(traj_log_rnd.detach().cpu().numpy())
# update buffer
with self.timer.section("expand.update_buffer"):
self.updateBuffer(x_final, traj_log_rnd, rewards)
for i in range(num_children):
# accumulate this child's reward for backpropagation
allChildReward += rewards[i]
# create node for sequence and add to the children node of parent
childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
parentNode.addChildNode(tokens=childTokens,
log_rnd=child_log_rnd[i],
log_policy_step=child_log_policy_step[i],
log_pretrained_step=child_log_pretrained_step[i],
totalReward=rewards[i])
# backpropagate all child rewards
with self.timer.section("expand.backprop"):
self.backprop(parentNode, allChildReward)
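# Note on the accumulated importance weight (summarizing the code above):
# for each rolled-out child,
#   traj_log_rnd = parent.log_rnd
#                + sum over unmasking steps of (log_pretrained_step - log_policy_step)
#                + final_reward / args.alpha
# i.e. the log-ratio between the pretrained model and the current policy over the
# trajectory, tilted by the scaled reward; this is the quantity updateBuffer stores.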
def backprop(self, node, allChildReward):
"""Backpropagates the summed child rewards from this node up to the root."""
while node:
node.updateNode(allChildReward)
node = node.parentNode
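### END OF MCTS CLASS ###
# Usage sketch (illustrative only): constructor arguments follow the signatures
# above, but the model/reward objects and the `args`/`config` fields are assumed
# to be supplied by the surrounding fine-tuning code, not defined here.
#
#   mcts = MCTS(args, config, policy_model, pretrained, rewardFunc)
#   x_final, log_rnd, rewards = mcts.forward(resetTree=True)
#   # x_final: (B, L) token ids of fully unmasked sequences
#   # log_rnd: (B,) cumulative log importance weights
#   # rewards: (B,) reward-model scores, with B <= args.buffer_size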