#!/usr/bin/env python3
"""
Decode all generated sequence embeddings and test them with HMD-AMP.

Pipeline:
  1. Load the flow-model embeddings saved for each CFG setting
     (no / weak / strong / very strong).
  2. Decode every embedding to an amino-acid sequence.
  3. Write the sequences to a timestamped FASTA file.
  4. Classify them with the HMD-AMP CascadeForestClassifier and summarise
     the predicted-AMP rate per CFG setting.
"""
import os
import sys
import tempfile
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm

# Import the decoder
from final_sequence_decoder import EmbeddingToSequenceConverter

# Import HMD-AMP components. The HMD-AMP checkout must be on sys.path before
# these imports so that `src.utils` resolves to its package.
HMD_AMP_ROOT = '/home/edwardsun/flow/HMD-AMP'
sys.path.append(HMD_AMP_ROOT)
from sklearn.utils import shuffle  # noqa: E402,F401
import esm  # noqa: E402,F401
from deepforest import CascadeForestClassifier  # noqa: E402
# NOTE(review): `amp_feature_extraction` is not defined in this file; it is
# presumably provided by this wildcard import — confirm against src/utils.py.
from src.utils import *  # noqa: E402,F401,F403

# CFG settings, in the order their embedding files are loaded.
CFG_TYPES = ('no_cfg', 'weak_cfg', 'strong_cfg', 'very_strong_cfg')
DEFAULT_SAMPLES_DIR = '/data2/edwardsun/generated_samples'
DEFAULT_SAMPLE_DATE = '20250829'


def load_generated_embeddings(base_path=DEFAULT_SAMPLES_DIR, date=DEFAULT_SAMPLE_DATE):
    """Load the generated embeddings for every CFG setting.

    Looks for ``generated_amps_best_model_{cfg}_{date}.pt`` in *base_path* for
    each entry of ``CFG_TYPES``; missing files are reported and skipped.

    Args:
        base_path: Directory holding the ``.pt`` files.
        date: ``YYYYMMDD`` tag embedded in the file names.

    Returns:
        ``(embeddings, labels)``: a list of per-sample tensors (row ``i`` of
        each loaded tensor) and a parallel list of labels ``"{cfg}_{i+1}"``.
    """
    all_embeddings = []
    all_labels = []

    for cfg_type in CFG_TYPES:
        file_name = f'generated_amps_best_model_{cfg_type}_{date}.pt'
        file_path = os.path.join(base_path, file_name)
        if not os.path.exists(file_path):
            # Previously skipped silently, hiding an incomplete run.
            print(f"Warning: {file_path} not found, skipping")
            continue

        print(f"Loading {file_name}...")
        # torch.load unpickles; only load files produced by this project.
        embeddings = torch.load(file_path, map_location='cpu')

        # One generated sample per row along dim 0.
        for i in range(embeddings.shape[0]):
            all_embeddings.append(embeddings[i])
            all_labels.append(f"{cfg_type}_{i+1}")

    print(f"āœ“ Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels


def decode_embeddings_to_sequences(embeddings, labels):
    """Decode embeddings to amino-acid sequences.

    Args:
        embeddings: Per-sample embedding tensors.
        labels: Parallel labels, as produced by ``load_generated_embeddings``.

    Returns:
        ``(sequences, sequence_ids)`` where each ID has the form
        ``generated_seq_{n}_{label}`` (``n`` is 1-based).
    """
    print("Initializing sequence decoder...")
    # Same device policy as test_with_hmd_amp instead of hard-coding CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    decoder = EmbeddingToSequenceConverter(device=device)

    sequences = []
    sequence_ids = []
    print("Decoding embeddings to sequences...")
    for i, (embedding, label) in enumerate(tqdm(zip(embeddings, labels), total=len(embeddings))):
        # Decode using diverse method for better results
        sequence = decoder.embedding_to_sequence(
            embedding, method='diverse', temperature=0.8
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{i+1}_{label}")

    return sequences, sequence_ids


def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Write *sequences* to *filename* in FASTA format, one record per ID."""
    records = [
        SeqRecord(Seq(seq), id=seq_id, description="")
        for seq_id, seq in zip(sequence_ids, sequences)
    ]
    SeqIO.write(records, filename, "fasta")
    print(f"āœ“ Saved {len(sequences)} sequences to {filename}")


def _cfg_type_from_id(seq_id):
    """Return the CFG label (e.g. ``'very_strong_cfg'``) encoded in *seq_id*.

    IDs look like ``generated_seq_{n}_{cfg}_{k}``; CFG labels themselves contain
    underscores, so ``seq_id.split('_')[-2]`` would always give ``'cfg'``.
    Instead, strip the three-token prefix and the trailing counter.
    Returns ``'unknown'`` for IDs that do not follow that format.
    """
    parts = seq_id.split('_', 3)
    if len(parts) < 4 or '_' not in parts[3]:
        return 'unknown'
    return parts[3].rsplit('_', 1)[0]


def test_with_hmd_amp(sequences, sequence_ids):
    """Classify *sequences* with HMD-AMP and summarise the results.

    Prints overall and per-CFG AMP rates, lists the sequences predicted as
    AMPs, and writes ``hmd_amp_detailed_results.csv`` and
    ``hmd_amp_cfg_analysis.csv`` to the current directory.

    Returns:
        ``(results_df, cfg_analysis)``: per-sequence predictions and the
        per-CFG summary table (columns Total, Predicted_AMPs, AMP_Rate).

    Raises:
        ValueError: if *sequences* is empty or its length differs from
            *sequence_ids*.
        RuntimeError: if the classifier returns a different number of
            predictions than sequences supplied.
    """
    if not sequences:
        raise ValueError("test_with_hmd_amp() needs at least one sequence")
    if len(sequences) != len(sequence_ids):
        raise ValueError(
            f"got {len(sequences)} sequences but {len(sequence_ids)} IDs"
        )

    print("\n🧬 Testing sequences with HMD-AMP classifier...")

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load models
    ftmodel_save_path = os.path.join(HMD_AMP_ROOT, 'AMP', 'ft_parts.pth')
    clsmodel_save_path = os.path.join(HMD_AMP_ROOT, 'AMP', 'clsmodel')

    # Unique temporary FASTA for HMD-AMP, so concurrent runs never share or
    # overwrite each other's input file.
    fd, temp_fasta = tempfile.mkstemp(suffix='.fasta')
    os.close(fd)

    try:
        save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)

        # Generate sequence features using HMD-AMP's feature extraction.
        # NOTE(review): predictions are assumed to come back in FASTA order.
        seq_embeddings, _, _ = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        # Load classifier
        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)

        # Make predictions
        binary_pred = cls_model.predict(seq_embeddings)

        n_total = len(sequences)
        if len(binary_pred) != n_total:
            raise RuntimeError(
                f"HMD-AMP returned {len(binary_pred)} predictions for {n_total} sequences"
            )
        n_amps = int(np.sum(binary_pred))
        n_non_amps = n_total - n_amps

        print("šŸ“Š HMD-AMP Results:")
        print(f"Total sequences: {n_total}")
        print(f"Predicted AMPs: {n_amps} ({n_amps/n_total*100:.1f}%)")
        print(f"Predicted non-AMPs: {n_non_amps} ({n_non_amps/n_total*100:.1f}%)")

        # Analyze results by CFG type
        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            'CFG_Type': [_cfg_type_from_id(seq_id) for seq_id in sequence_ids],
        })

        # Group by CFG type
        cfg_analysis = (
            results_df.groupby('CFG_Type')['AMP_Prediction']
            .agg(['count', 'sum', 'mean'])
            .round(3)
        )
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']

        print("\nšŸ“‹ Results by CFG Configuration:")
        print(cfg_analysis)

        # Show predicted AMPs
        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\nšŸ† Sequences predicted as AMPs ({len(amp_results)}):")
            for _, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                # Crude charge estimate: K/R/H count +1 and D/E count -1.
                net_charge = cationic + seq.count('H') - seq.count('D') - seq.count('E')
                print(f" {row['ID']}: {seq}")
                print(f" Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print("\nāŒ No sequences predicted as AMPs")

        # Save detailed results
        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')

        print("\nšŸ’¾ Results saved:")
        print(" - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print(" - hmd_amp_cfg_analysis.csv (summary by CFG type)")

        return results_df, cfg_analysis

    finally:
        # Clean up temporary file
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)


def main():
    """Run load -> decode -> save FASTA -> HMD-AMP classification."""
    print("šŸš€ Starting sequence decoding and HMD-AMP testing...")

    # Load embeddings
    embeddings, labels = load_generated_embeddings()
    if not embeddings:
        # Nothing to do; avoids loading models and dividing by zero below.
        print("No generated embeddings found - nothing to decode or test.")
        return

    # Decode to sequences
    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    # Save sequences as FASTA
    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    # Test with HMD-AMP
    results_df, _ = test_with_hmd_amp(sequences, sequence_ids)

    print(f"\nāœ… Complete! Generated and tested {len(sequences)} sequences")
    print(f"šŸ“ Sequences saved as: {fasta_filename}")

    # Final summary
    total_amps = int(results_df['AMP_Prediction'].sum())
    print("\nšŸ“Š FINAL SUMMARY:")
    print(f"Generated sequences: {len(sequences)}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{len(sequences)} ({total_amps/len(sequences)*100:.1f}%)")

    if total_amps > 0:
        print(f"✨ Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print("šŸ” No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")


if __name__ == "__main__":
    main()