FlowFinal / src /decode_and_test_sequences.py
esunAI's picture
Add decode_and_test_sequences.py
bc21134 verified
raw
history blame
7.58 kB
#!/usr/bin/env python3
"""
Decode all 80 generated sequences and test them with HMD-AMP.
"""
import torch
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
import os
from datetime import datetime
from tqdm import tqdm
import sys
# Import the decoder
from final_sequence_decoder import EmbeddingToSequenceConverter
# Import HMD-AMP components
sys.path.append('/home/edwardsun/flow/HMD-AMP')
from sklearn.utils import shuffle
import esm
from deepforest import CascadeForestClassifier
from src.utils import *
def load_generated_embeddings(base_path='/data2/edwardsun/generated_samples', today='20250829'):
    """Load the generated embeddings for every CFG setting of one generation run.

    One ``.pt`` file per CFG setting (presumably classifier-free guidance) is
    expected under ``base_path``, named
    ``generated_amps_best_model_<cfg_type>_<today>.pt``. Files that do not
    exist are reported and skipped.

    Args:
        base_path: Directory holding the generated ``.pt`` files.
        today: Date tag (``YYYYMMDD``) embedded in the file names.

    Returns:
        Tuple ``(all_embeddings, all_labels)``: a flat list with one entry per
        index along dim 0 of each loaded tensor, and a parallel list of labels
        of the form ``"<cfg_type>_<1-based index within its file>"``.
    """
    # Build each file name from its CFG type rather than parsing the type back
    # out of the name with substring tests: the label can then never be unset
    # or left over from a previous file.
    cfg_types = ('no_cfg', 'weak_cfg', 'strong_cfg', 'very_strong_cfg')
    all_embeddings = []
    all_labels = []
    for cfg_type in cfg_types:
        file = f'generated_amps_best_model_{cfg_type}_{today}.pt'
        file_path = os.path.join(base_path, file)
        if not os.path.exists(file_path):
            print(f"Warning: {file_path} not found, skipping")
            continue
        print(f"Loading {file}...")
        # NOTE(review): assumes the file holds a single tensor whose first
        # dimension indexes samples (originally documented as 20 per file).
        embeddings = torch.load(file_path, map_location='cpu')
        for i in range(embeddings.shape[0]):
            all_embeddings.append(embeddings[i])
            all_labels.append(f"{cfg_type}_{i+1}")
    print(f"βœ“ Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels
def decode_embeddings_to_sequences(embeddings, labels, method='diverse', temperature=0.8, device=None):
    """Decode generated embeddings into sequences with EmbeddingToSequenceConverter.

    Args:
        embeddings: Sized sequence of embeddings (e.g. from load_generated_embeddings).
        labels: Parallel sequence of labels, one per embedding.
        method: Decoding method forwarded to ``embedding_to_sequence``
            (default ``'diverse'``, as before).
        temperature: Temperature forwarded to ``embedding_to_sequence``
            (default ``0.8``, as before).
        device: Device passed to the decoder. ``None`` means ``'cuda'`` when
            available, otherwise ``'cpu'``.

    Returns:
        Tuple ``(sequences, sequence_ids)``. ``sequence_ids[i]`` is
        ``"generated_seq_<i+1>_<labels[i]>"``.

    Raises:
        ValueError: If ``embeddings`` and ``labels`` differ in length. Checked
            up front because ``zip`` would otherwise drop the extras silently.
    """
    if len(embeddings) != len(labels):
        raise ValueError(
            f"embeddings and labels differ in length ({len(embeddings)} vs {len(labels)})"
        )
    if device is None:
        # Same fallback test_with_hmd_amp uses, instead of hard-failing on
        # machines without CUDA.
        # NOTE(review): assumes EmbeddingToSequenceConverter accepts 'cpu' — confirm.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Initializing sequence decoder...")
    decoder = EmbeddingToSequenceConverter(device=device)
    sequences = []
    sequence_ids = []
    print("Decoding embeddings to sequences...")
    for i, (embedding, label) in enumerate(tqdm(zip(embeddings, labels), total=len(embeddings))):
        sequence = decoder.embedding_to_sequence(
            embedding,
            method=method,
            temperature=temperature,
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{i+1}_{label}")
    return sequences, sequence_ids
def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Write sequences to ``filename`` in FASTA format with Biopython.

    Args:
        sequences: Sequence strings, one per record.
        sequence_ids: Record IDs, parallel to ``sequences``.
        filename: Output path passed to ``SeqIO.write``.

    Raises:
        ValueError: If ``sequences`` and ``sequence_ids`` differ in length.
            Without this check ``zip`` would silently drop the extras while
            the message below still reported ``len(sequences)`` as saved.
    """
    if len(sequences) != len(sequence_ids):
        raise ValueError(
            f"sequences and sequence_ids differ in length ({len(sequences)} vs {len(sequence_ids)})"
        )
    # Empty description keeps the FASTA header to just the ID.
    records = [
        SeqRecord(Seq(seq), id=seq_id, description="")
        for seq_id, seq in zip(sequence_ids, sequences)
    ]
    SeqIO.write(records, filename, "fasta")
    print(f"βœ“ Saved {len(sequences)} sequences to {filename}")
def _cfg_type_from_id(seq_id):
    """Return the CFG label embedded in a sequence ID, or ``'unknown'``.

    IDs built by decode_embeddings_to_sequences look like
    ``generated_seq_<n>_<cfg_type>_<k>`` (e.g.
    ``generated_seq_61_very_strong_cfg_1``). The CFG label itself contains
    underscores, so ``split('_')[-2]`` would always return ``'cfg'``. The
    regex instead captures everything between the two numeric indices.
    """
    import re
    match = re.fullmatch(r'generated_seq_\d+_(.+)_\d+', seq_id)
    return match.group(1) if match else 'unknown'


def test_with_hmd_amp(sequences, sequence_ids):
    """Classify sequences with the HMD-AMP cascade-forest model and report results.

    Steps:
        1. Write the sequences to a temporary FASTA file.
        2. Extract features with HMD-AMP's ``amp_feature_extraction``.
        3. Predict with the saved ``CascadeForestClassifier``.
        4. Print overall and per-CFG summaries.
        5. Write two CSV files to the current directory.

    Args:
        sequences: Sequence strings to classify.
        sequence_ids: Parallel IDs. The per-CFG breakdown expects the
            ``generated_seq_<n>_<cfg_type>_<k>`` form; other IDs are grouped
            as ``'unknown'``.

    Returns:
        Tuple ``(results_df, cfg_analysis)``.
        ``results_df`` has one row per sequence, with columns ``ID``,
        ``Sequence``, ``AMP_Prediction`` and ``CFG_Type``.
        ``cfg_analysis`` is indexed by CFG type, with columns ``Total``,
        ``Predicted_AMPs`` and ``AMP_Rate``.

    Raises:
        ValueError: If the classifier returns a different number of
            predictions than there are sequences.
    """
    import tempfile

    print("\n🧬 Testing sequences with HMD-AMP classifier...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ftmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/ft_parts.pth'
    clsmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/clsmodel'

    # A unique temp path, so a fixed 'temp_sequences.fasta' in the working
    # directory is not clobbered by (or clobbering) another run.
    fd, temp_fasta = tempfile.mkstemp(prefix='hmd_amp_', suffix='.fasta')
    os.close(fd)
    try:
        # Written inside the try so the temp file is removed even if this fails.
        save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)
        # NOTE(review): the IDs returned as the third value are not used;
        # predictions are assumed to follow FASTA record order.
        seq_embeddings, _, _feature_ids = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)
        # Assumed to be 0/1 labels (1 = AMP), as the counts below treat them.
        binary_pred = np.asarray(cls_model.predict(seq_embeddings))
        if len(binary_pred) != len(sequences):
            raise ValueError(
                f"HMD-AMP returned {len(binary_pred)} predictions for {len(sequences)} sequences"
            )

        n_total = len(sequences)
        n_amp = int(np.sum(binary_pred))
        n_non_amp = n_total - n_amp

        def pct(count):
            # Guard against an empty batch instead of raising ZeroDivisionError.
            return count / n_total * 100 if n_total else 0.0

        print(f"πŸ“Š HMD-AMP Results:")
        print(f"Total sequences: {n_total}")
        print(f"Predicted AMPs: {n_amp} ({pct(n_amp):.1f}%)")
        print(f"Predicted non-AMPs: {n_non_amp} ({pct(n_non_amp):.1f}%)")

        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            'CFG_Type': [_cfg_type_from_id(seq_id) for seq_id in sequence_ids],
        })

        cfg_analysis = results_df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'sum', 'mean']).round(3)
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']
        print(f"\nπŸ“‹ Results by CFG Configuration:")
        print(cfg_analysis)

        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\nπŸ† Sequences predicted as AMPs ({len(amp_results)}):")
            for _, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                # Crude residue-count charge: +1 per K/R/H, -1 per D/E.
                net_charge = cationic + seq.count('H') - seq.count('D') - seq.count('E')
                print(f" {row['ID']}: {seq}")
                print(f" Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print(f"\n❌ No sequences predicted as AMPs")

        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')
        print(f"\nπŸ’Ύ Results saved:")
        print(f" - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print(f" - hmd_amp_cfg_analysis.csv (summary by CFG type)")
        return results_df, cfg_analysis
    finally:
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)
def main():
    """Run the pipeline end to end.

    Steps:
        1. Load the generated embeddings.
        2. Decode them into sequences.
        3. Save a timestamped FASTA file.
        4. Classify with HMD-AMP.
        5. Print a final summary.

    Exits early with a message when no embeddings are found.
    """
    print("πŸš€ Starting sequence decoding and HMD-AMP testing...")
    embeddings, labels = load_generated_embeddings()
    if not embeddings:
        # Every later step (and the percentage in the summary) would fail on
        # an empty batch, so stop with a clear message instead.
        print("No generated embeddings found; nothing to decode or test.")
        return

    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    # The per-CFG summary is already printed and saved inside test_with_hmd_amp.
    results_df, _cfg_analysis = test_with_hmd_amp(sequences, sequence_ids)
    print(f"\nβœ… Complete! Generated and tested {len(sequences)} sequences")
    print(f"πŸ“ Sequences saved as: {fasta_filename}")

    n_total = len(sequences)
    total_amps = int(results_df['AMP_Prediction'].sum())
    print(f"\nπŸ“Š FINAL SUMMARY:")
    print(f"Generated sequences: {n_total}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{n_total} ({total_amps/n_total*100:.1f}%)")
    if total_amps > 0:
        print(f"✨ Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print(f"πŸ” No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")
if __name__ == "__main__":
    # Entry point: run the pipeline only when executed as a script, not on import.
    main()