"""Decode all 80 generated sequences and test them with HMD-AMP."""
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from Bio import SeqIO |
|
|
from Bio.SeqRecord import SeqRecord |
|
|
from Bio.Seq import Seq |
|
|
import os |
|
|
from datetime import datetime |
|
|
from tqdm import tqdm |
|
|
import sys |
|
|
|
|
|
|
|
|
from final_sequence_decoder import EmbeddingToSequenceConverter |
|
|
|
|
|
|
|
|
sys.path.append('/home/edwardsun/flow/HMD-AMP') |
|
|
from sklearn.utils import shuffle |
|
|
import esm |
|
|
from deepforest import CascadeForestClassifier |
|
|
from src.utils import * |
|
|
|
|
|
def load_generated_embeddings():
    """Load all generated embeddings from today."""
    base_path = '/data2/edwardsun/generated_samples'
    today = '20250829'

    files = [
        f'generated_amps_best_model_no_cfg_{today}.pt',
        f'generated_amps_best_model_weak_cfg_{today}.pt',
        f'generated_amps_best_model_strong_cfg_{today}.pt',
        f'generated_amps_best_model_very_strong_cfg_{today}.pt'
    ]

    all_embeddings = []
    all_labels = []

    for file in files:
        file_path = os.path.join(base_path, file)
        if not os.path.exists(file_path):
            print(f"Warning: {file_path} not found, skipping")
            continue

        print(f"Loading {file}...")
        embeddings = torch.load(file_path, map_location='cpu')

        # Derive the CFG setting from the filename (check 'very_strong_cfg' before 'strong_cfg').
        if 'no_cfg' in file:
            cfg_type = 'no_cfg'
        elif 'weak_cfg' in file:
            cfg_type = 'weak_cfg'
        elif 'very_strong_cfg' in file:
            cfg_type = 'very_strong_cfg'
        elif 'strong_cfg' in file:
            cfg_type = 'strong_cfg'
        else:
            cfg_type = 'unknown_cfg'

        # Keep one entry per generated sample, labelled by CFG setting and index.
        for i in range(embeddings.shape[0]):
            all_embeddings.append(embeddings[i])
            all_labels.append(f"{cfg_type}_{i+1}")

    print(f"Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels
|
|
def decode_embeddings_to_sequences(embeddings, labels):
    """Decode embeddings to sequences."""
    print("Initializing sequence decoder...")
    decoder = EmbeddingToSequenceConverter(device='cuda')

    sequences = []
    sequence_ids = []

    print("Decoding embeddings to sequences...")
    for i, (embedding, label) in enumerate(tqdm(zip(embeddings, labels), total=len(embeddings))):
        sequence = decoder.embedding_to_sequence(
            embedding,
            method='diverse',
            temperature=0.8
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{i+1}_{label}")

    return sequences, sequence_ids
|
|
def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Save sequences as a FASTA file."""
    records = []
    for seq_id, seq in zip(sequence_ids, sequences):
        record = SeqRecord(Seq(seq), id=seq_id, description="")
        records.append(record)

    SeqIO.write(records, filename, "fasta")
    print(f"Saved {len(sequences)} sequences to {filename}")
|
|
def test_with_hmd_amp(sequences, sequence_ids):
    """Test sequences with the HMD-AMP classifier."""
    print("\nTesting sequences with HMD-AMP classifier...")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Paths to the fine-tuned feature extractor and the cascade-forest classifier.
    ftmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/ft_parts.pth'
    clsmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/clsmodel'

    # HMD-AMP's feature extraction reads from FASTA, so write a temporary file.
    temp_fasta = 'temp_sequences.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)

    try:
        # Extract per-sequence features with the fine-tuned model.
        seq_embeddings, _, seq_ids = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        # Load the trained classifier and predict AMP (1) vs non-AMP (0).
        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)
        binary_pred = cls_model.predict(seq_embeddings)

        n_amps = int(np.sum(binary_pred))
        print("HMD-AMP Results:")
        print(f"Total sequences: {len(sequences)}")
        print(f"Predicted AMPs: {n_amps} ({n_amps/len(sequences)*100:.1f}%)")
        print(f"Predicted non-AMPs: {len(sequences) - n_amps} ({(len(sequences) - n_amps)/len(sequences)*100:.1f}%)")

        # Sequence IDs look like "generated_seq_<i>_<cfg_type>_<j>", so the CFG type is
        # everything between the third underscore-separated field and the last one.
        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            'CFG_Type': ['_'.join(seq_id.split('_')[3:-1]) for seq_id in sequence_ids]
        })

        cfg_analysis = results_df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'sum', 'mean']).round(3)
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']

        print("\nResults by CFG configuration:")
        print(cfg_analysis)

        # Report the sequences classified as AMPs, with simple charge statistics.
        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\nSequences predicted as AMPs ({len(amp_results)}):")
            for idx, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                net_charge = seq.count('K') + seq.count('R') + seq.count('H') - seq.count('D') - seq.count('E')
                print(f"  {row['ID']}: {seq}")
                print(f"    Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print("\nNo sequences predicted as AMPs")

        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')

        print("\nResults saved:")
        print("  - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print("  - hmd_amp_cfg_analysis.csv (summary by CFG type)")

        return results_df, cfg_analysis

    finally:
        # Clean up the temporary FASTA file whether or not classification succeeded.
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)
|
|
def main():
    print("Starting sequence decoding and HMD-AMP testing...")

    # Step 1: load the generated embeddings for every CFG setting.
    embeddings, labels = load_generated_embeddings()

    # Step 2: decode the embeddings back to amino-acid sequences.
    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    # Step 3: save the decoded sequences as a timestamped FASTA file.
    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    # Step 4: classify the sequences with HMD-AMP.
    results_df, cfg_analysis = test_with_hmd_amp(sequences, sequence_ids)

    print(f"\nComplete! Generated and tested {len(sequences)} sequences")
    print(f"Sequences saved as: {fasta_filename}")

    total_amps = results_df['AMP_Prediction'].sum()
    print("\nFINAL SUMMARY:")
    print(f"Generated sequences: {len(sequences)}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{len(sequences)} ({total_amps/len(sequences)*100:.1f}%)")

    if total_amps > 0:
        print(f"Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print("No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|
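
# Example follow-up (a minimal sketch, not part of the script above): the per-sequence
# CSV written by test_with_hmd_amp can be re-loaded later to recompute per-CFG AMP rates
# without re-running feature extraction.
#
#   import pandas as pd
#   df = pd.read_csv('hmd_amp_detailed_results.csv')
#   print(df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'mean']))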