|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
import random |
|
|
import sys |
|
|
from diffusion import Diffusion |
|
|
import hydra |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import os |
|
|
import seaborn as sns |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import argparse |
|
|
|
|
|
from diffusion import Diffusion |
|
|
from hydra import initialize, compose |
|
|
from hydra.core.global_hydra import GlobalHydra |
|
|
import numpy as np |
|
|
import torch |
|
|
import argparse |
|
|
import os |
|
|
import datetime |
|
|
from utils.utils import str2bool, set_seed |
|
|
|
|
|
|
|
|
from utils.app import PeptideAnalyzer |
|
|
from peptide_mcts import MCTS |
|
|
|
|
|
@torch.no_grad()
def generate_mcts(args, cfg, policy_model, pretrained, prot=None, prot_name=None, filename=None):
    """Run MCTS-guided peptide generation against a single protein target.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        args: Parsed command-line arguments, forwarded to ``MCTS``.
        cfg: Hydra config, forwarded to ``MCTS``.
        policy_model: Model handed to ``MCTS`` as the policy model.
        pretrained: Pretrained reference model handed to ``MCTS``.
        prot: Target protein sequence; passed as a one-element ``prot_seqs`` list.
        prot_name: Not used inside this function; accepted for call-site compatibility.
        filename: Not used inside this function; accepted for call-site compatibility.

    Returns:
        The 5-tuple ``(final_x, log_rnd, final_rewards, score_vectors, sequences)``
        produced by ``MCTS.forward()``.
    """
    # Objectives scored during the search, in the order MCTS receives them.
    objectives = [
        'binding_affinity1',
        'solubility',
        'hemolysis',
        'nonfouling',
        'permeability',
    ]

    search = MCTS(args, cfg, policy_model, pretrained, objectives, prot_seqs=[prot])

    # Unpack explicitly so a wrong-arity result from forward() fails here, and the
    # caller always receives a tuple.
    x, rnd, rewards, scores, seqs = search.forward()
    return x, rnd, rewards, scores, seqs
|
|
|
|
|
def save_logs_to_file(reward_log, logrnd_log,
                      valid_fraction_log, affinity1_log,
                      sol_log, hemo_log, nf_log,
                      permeability_log, output_path):
    """
    Save per-iteration logs to a CSV file.

    One row is written per iteration. The 1-based ``Iteration`` column takes its
    length from ``valid_fraction_log``.

    Parameters:
        reward_log (list): Log of rewards over iterations.
        logrnd_log (list): Log of log-RND values over iterations.
        valid_fraction_log (list): Log of valid fractions over iterations.
        affinity1_log (list): Log of binding affinity over iterations.
        sol_log (list): Log of solubility over iterations.
        hemo_log (list): Log of hemolysis over iterations.
        nf_log (list): Log of nonfouling over iterations.
        permeability_log (list): Log of membrane permeability over iterations.
        output_path (str): Path to save the log CSV file. Missing parent
            directories are created. A bare filename is written to the current
            working directory.

    Raises:
        ValueError: raised by pandas if the logs do not all have the same length.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Reward": reward_log,
        "Log RND": logrnd_log,
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity1_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log,
        "Nonfouling": nf_log,
        "Permeability": permeability_log
    }

    df = pd.DataFrame(log_data)
    df.to_csv(output_path, index=False)
|
|
|
|
|
# Command-line interface. Each entry is (flag, keyword arguments for add_argument).
# Entries are registered in table order, so --help lists them in the same sequence.
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
for _flag, _kwargs in (
    # Paths, optimisation and general run settings.
    ('--base_path', dict(type=str, default='')),
    ('--learning_rate', dict(type=float, default=1e-4)),
    ('--num_epochs', dict(type=int, default=1000)),
    ('--num_accum_steps', dict(type=int, default=4)),
    ('--truncate_steps', dict(type=int, default=50)),
    ('--truncate_kl', dict(type=str2bool, default=False)),
    ('--gumbel_temp', dict(type=float, default=1.0)),
    ('--gradnorm_clip', dict(type=float, default=1.0)),
    ('--batch_size', dict(type=int, default=32)),
    ('--name', dict(type=str, default='debug')),
    ('--total_num_steps', dict(type=int, default=128)),
    ('--copy_flag_temp', dict(type=float, default=None)),
    ('--save_every_n_epochs', dict(type=int, default=50)),
    ('--alpha_schedule_warmup', dict(type=int, default=0)),
    ('--seed', dict(type=int, default=0)),
    ('--run_name', dict(type=str, default='drakes')),
    ('--device', dict(default="cuda", type=str)),
    # Sampling / search settings.
    ('--num_sequences', dict(type=int, default=100)),
    ('--num_children', dict(type=int, default=20)),
    ('--num_iter', dict(type=int, default=100)),
    ('--seq_length', dict(type=int, default=200)),
    ('--time_conditioning', dict(action='store_true', default=False)),
    ('--mcts_sampling', dict(type=int, default=0)),
    ('--buffer_size', dict(type=int, default=100)),
    ('--wdce_num_replicates', dict(type=int, default=16)),
    ('--noise_removal', dict(action='store_true', default=False)),
    ('--exploration', dict(type=float, default=0.1)),
    ('--reset_every_n_step', dict(type=int, default=100)),
    ('--alpha', dict(type=float, default=0.01)),
    ('--scalarization', dict(type=str, default='sum')),
    ('--no_mcts', dict(action='store_true', default=False)),
    ('--centering', dict(action='store_true', default=False)),
    ('--num_obj', dict(type=int, default=5)),
    # Target protein selection.
    ('--prot_seq', dict(type=str, default=None)),
    ('--prot_name', dict(type=str, default=None)),
):
    argparser.add_argument(_flag, **_kwargs)
# Do not leave loop temporaries behind in the module namespace.
del _flag, _kwargs

args = argparser.parse_args()
print(args)
|
|
|
|
|
|
|
|
# Pretrained PepTune checkpoint, resolved relative to --base_path. Joining with '/'
# reproduces f'{base_path}/TR2-D2/...' exactly, including when base_path is ''.
ckpt_path = '/'.join(
    (args.base_path, 'TR2-D2', 'tr2d2-pep', 'pretrained', 'peptune-pretrained.ckpt')
)

# Reset any existing Hydra global state before initializing it again, then load
# the model config from ./configs/config.yaml.
GlobalHydra.instance().clear()
initialize(config_path="configs", job_name="load_model")
cfg = compose(config_name="config.yaml")

# Run timestamp, e.g. 20240101_120000 (datetime format spec == strftime).
curr_time = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"

set_seed(args.seed, use_cuda=True)
|
|
|
|
|
|
|
|
# Built-in target protein sequences (one-letter amino-acid strings).
# When --prot_seq is not supplied, the script falls back to `tfr`.
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
|
|
|
|
|
# Resolve the target protein: (sequence, name) -> `prot`, `prot_name`, `filename`.
# `prot_name` also names the output directory, so it must never be None.
#   1. --prot_seq given: use it, named by --prot_name (or "custom" if omitted).
#   2. --prot_name matches a built-in target (case-insensitive): use that sequence.
#   3. Otherwise: fall back to the built-in `tfr` target, as before.
_builtin_targets = {
    'amhr': amhr,
    'tfr': tfr,
    'gfap': gfap,
    'glp1': glp1,
    'glast': glast,
    'ncam': ncam,
    'cereblon': cereblon,
    'ligase': ligase,
    'skp2': skp2,
}

if args.prot_seq is not None:
    prot = args.prot_seq
    # Without this fallback the results would land in "None-peptune-baseline".
    prot_name = args.prot_name if args.prot_name is not None else "custom"
    filename = prot_name
elif args.prot_name is not None and args.prot_name.lower() in _builtin_targets:
    prot_name = args.prot_name.lower()
    prot = _builtin_targets[prot_name]
    filename = prot_name
else:
    if args.prot_name is not None:
        # Previously --prot_name was silently ignored here; at least say so.
        print(f"Unknown built-in target '{args.prot_name}' and no --prot_seq given; defaulting to tfr.")
    prot = tfr
    prot_name = "tfr"
    filename = "tfr"
|
|
|
|
|
|
|
|
# Load the same pretrained checkpoint twice. The first copy is passed to
# generate_mcts as `policy_model`, the second as `pretrained`.
_load_kwargs = dict(config=cfg, strict=False, map_location=args.device)
new_model = Diffusion.load_from_checkpoint(ckpt_path, **_load_kwargs)
old_model = Diffusion.load_from_checkpoint(ckpt_path, **_load_kwargs)
del _load_kwargs

with torch.no_grad():
    (final_x, log_rnd, final_rewards,
     score_vectors, sequences) = generate_mcts(
        args,
        cfg,
        new_model,
        old_model,
        prot=prot,
        prot_name=prot_name,
    )

# Bring results to the CPU. log_rnd is cast to float32 and flattened to 1-D.
final_x = final_x.detach().to('cpu')
log_rnd = log_rnd.detach().to('cpu').float().view(-1)
|
|
|
|
|
|
|
|
print("loaded models...")
analyzer = PeptideAnalyzer()

# One row per generated sample, matching the DataFrame columns below.
generation_results = []

for i in range(final_x.shape[0]):
    sequence = sequences[i]
    # log_rnd is a flattened float tensor. Without .item() each entry is a 0-d
    # tensor, and the CSV would contain "tensor(...)" strings instead of numbers.
    log_rnd_single = log_rnd[i].item()
    final_reward = final_rewards[i]

    aa_seq, seq_length = analyzer.analyze_structure(sequence)

    # Per-objective scores. Assumed to follow the order of score_func_names in
    # generate_mcts (binding affinity, solubility, hemolysis, nonfouling,
    # permeability).
    scores = score_vectors[i]
    binding1 = scores[0]
    solubility = scores[1]
    hemo = scores[2]
    nonfouling = scores[3]
    permeability = scores[4]

    generation_results.append([sequence, aa_seq, final_reward, log_rnd_single, binding1, solubility, hemo, nonfouling, permeability])
    print(f"length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
    sys.stdout.flush()

df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Peptide Sequence', 'Final Reward', 'Log RND', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])

# Nothing in this script creates the per-target output directory. Create it here
# so to_csv does not fail with FileNotFoundError on a fresh checkout.
out_dir = f'{args.base_path}/TR2-D2/tr2d2-pep/plots/{prot_name}-peptune-baseline'
os.makedirs(out_dir, exist_ok=True)
df.to_csv(f'{out_dir}/generation_results.csv', index=False)
|
|
|