import os
from datetime import datetime

import esm
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

CANONICAL_AAS = list("ACDEFGHIKLMNPQRSTVWY")


class EmbeddingToSequenceConverter:
    """
    Decode contextual ESM2 hidden states to amino-acid sequences via the model's LM head.

    Accepts [L, 1280] or [B, L, 1280] tensors (L≈50 in your pipeline). Output is
    restricted to the 20 canonical amino acids in ``CANONICAL_AAS``.
    """

    def __init__(self, device="cuda", apply_final_layer_norm=True):
        """Load ESM-2 (t33, 650M) and cache the vocabulary ids of the canonical residues.

        Args:
            device: Requested torch device; falls back to CPU when CUDA is unavailable.
            apply_final_layer_norm: Run ``emb_layer_norm_after`` on the hidden states
                before the LM head (default True, the original behavior).
                NOTE(review): ESM-2's own forward pass applies this norm *before* it
                stores the final layer in ``representations[33]``; if the embeddings
                being decoded came from there they are already normalised and False
                is probably correct. Confirm against the pipeline that produced them.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.apply_final_layer_norm = apply_final_layer_norm
        print("Loading ESM model for sequence decoding...")
        self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.model.eval().to(self.device)
        self.aa_list = CANONICAL_AAS
        # Vocabulary ids of the canonical residues, in the same order as aa_list.
        self.aa_token_ids = torch.tensor(
            [self.alphabet.get_idx(a) for a in self.aa_list],
            device=self.device,
            dtype=torch.long,
        )
        print("✓ ESM model loaded for sequence decoding")

    @torch.inference_mode()
    def _logits_from_hidden(self, hidden):
        """Project hidden states to logits over the 20 canonical amino acids.

        Args:
            hidden: Tensor of shape [L, D] or [B, L, D]; a 2-D input is treated as a
                batch of one.

        Returns:
            Tensor of shape [B, L, 20] whose last axis follows ``self.aa_list``.
        """
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0)
        # Match the model's device and dtype to avoid dtype mismatches under autocast.
        hidden = hidden.to(device=self.device, dtype=self.model.lm_head.weight.dtype)
        if self.apply_final_layer_norm and hasattr(self.model, "emb_layer_norm_after"):
            hidden = self.model.emb_layer_norm_after(hidden)
        logits_full = self.model.lm_head(hidden)  # [B, L, |V|]
        # Keep only canonical residues so no other vocabulary token can be emitted.
        return logits_full.index_select(-1, self.aa_token_ids)  # [B, L, 20]

    @staticmethod
    def _sample_from_logits(logits, temperature, top_p, top_k):
        """Sample one residue index per position from ``logits`` of shape [..., 20].

        Applies temperature scaling, then optional top-k, then optional nucleus
        (top-p) filtering, and draws with ``torch.multinomial``.

        Returns:
            ``(idx, probs)``: ``idx`` has shape ``logits.shape[:-1]``; ``probs`` is the
            distribution after temperature and top-k (before top-p), which callers
            use for the confidence score.
        """
        if temperature and temperature > 0:
            logits = logits / temperature
        probs = logits.softmax(-1)
        V = probs.size(-1)
        if top_k and 0 < top_k < V:
            kth = torch.topk(probs, top_k, dim=-1).values[..., -1:]
            # Entries tied with the k-th value survive too.
            probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(-1, keepdim=True).clamp_min(1e-12)

        flat = probs.reshape(-1, V)  # one row per position
        if top_p and 0 < top_p < 1:
            sorted_probs, sorted_idx = torch.sort(flat, descending=True, dim=-1)
            # Nucleus = smallest prefix whose mass reaches top_p: drop a token only
            # when the tokens ranked above it already reach top_p. (Masking on the
            # inclusive cumsum would also drop the token that crosses the threshold.)
            mass_before = sorted_probs.cumsum(-1) - sorted_probs
            mask = mass_before >= top_p
            mask[:, 0] = False  # always keep the most likely residue
            sorted_probs = sorted_probs.masked_fill(mask, 0)
            sorted_probs = sorted_probs / sorted_probs.sum(-1, keepdim=True).clamp_min(1e-12)
            samples = torch.multinomial(sorted_probs, 1)  # positions in sorted order
            idx = sorted_idx.gather(-1, samples)
        else:
            idx = torch.multinomial(flat, 1)
        return idx.view(probs.shape[:-1]), probs

    @torch.inference_mode()
    def embedding_to_sequence(self, embedding, method="diverse", temperature=0.8,
                              top_p=0.9, top_k=0, seed=None, return_conf=False):
        """Decode one embedding to an amino-acid string.

        Args:
            embedding: [L, D] tensor, or [B, L, D]. In the batched case only the
                first item is decoded (use ``batch_embedding_to_sequences`` for all).
            method: "nearest"/"nearest_neighbor" takes the per-position argmax; any
                other value samples with temperature / top-k / top-p.
            temperature: Logit divisor when sampling; ignored when falsy or <= 0.
            top_p: Nucleus mass in (0, 1); other values disable nucleus filtering.
            top_k: Keep only the k most likely residues; values outside (0, 20) disable it.
            seed: If given, passed to ``torch.manual_seed`` (global RNG) before sampling.
            return_conf: Also return the mean per-position max probability.

        Returns:
            ``seq``, or ``(seq, conf)`` when ``return_conf`` is True.
        """
        logits = self._logits_from_hidden(embedding)[0]  # [L, 20]
        if method in ("nearest", "nearest_neighbor"):
            idx = logits.argmax(-1)
            probs = logits.softmax(-1)
        else:
            if seed is not None:
                torch.manual_seed(seed)
            idx, probs = self._sample_from_logits(logits, temperature, top_p, top_k)
        seq = "".join(self.aa_list[i] for i in idx.tolist())
        if return_conf:
            conf = probs.max(-1).values.mean().item()  # avg per-position max prob
            return seq, conf
        return seq

    @torch.inference_mode()
    def batch_embedding_to_sequences(self, embeddings, method="diverse", temperature=0.8,
                                     top_p=0.9, top_k=0, seed=None, return_conf=False,
                                     max_tokens=100_000):
        """Decode a batch of embeddings to amino-acid strings.

        Args:
            embeddings: [B, L, D] tensor; a 2-D [L, D] tensor is decoded as a single
                sequence via ``embedding_to_sequence``.
            method, temperature, top_p, top_k: As in ``embedding_to_sequence``.
            seed: If given, passed to ``torch.manual_seed`` (global RNG) once up front.
            return_conf: Return ``(seq, conf)`` pairs instead of bare strings.
            max_tokens: Upper bound on positions (chunk_size * L) sent through the LM
                head at once, to limit device memory; at least one sequence per chunk.

        Returns:
            List of sequences, or list of ``(seq, conf)`` tuples; empty for B == 0.
        """
        if embeddings.dim() == 2:
            return [self.embedding_to_sequence(embeddings, method, temperature, top_p,
                                               top_k, seed, return_conf)]
        num_seqs, seq_len, _ = embeddings.shape
        if num_seqs == 0:
            return []
        if seed is not None:
            torch.manual_seed(seed)

        # Run the LM head chunk by chunk to avoid OOM; max(..., 1) guards L == 0.
        chunk = max(1, max_tokens // max(seq_len, 1))
        logits = torch.cat(
            [self._logits_from_hidden(embeddings[s:s + chunk])
             for s in range(0, num_seqs, chunk)],
            dim=0,
        )  # [B, L, 20]

        if method in ("nearest", "nearest_neighbor"):
            idx = logits.argmax(-1)  # [B, L]
            probs = logits.softmax(-1)
        else:
            idx, probs = self._sample_from_logits(logits, temperature, top_p, top_k)
        seqs = ["".join(self.aa_list[i] for i in row) for row in idx.tolist()]
        if return_conf:
            conf = probs.max(-1).values.mean(-1).tolist()  # [B]
            return list(zip(seqs, conf))
        return seqs

    def validate_sequence(self, s):
        """Return True if every residue in ``s`` is one of the canonical amino acids."""
        return set(s).issubset(self.aa_list)

    def filter_valid_sequences(self, sequences):
        """Return the sequences that pass ``validate_sequence``, warning about the rest."""
        valid = []
        for seq in sequences:
            if self.validate_sequence(seq):
                valid.append(seq)
            else:
                print(f"Warning: Invalid sequence found: {seq}")
        return valid


def _decode_cfg_file(converter, cfg_name, file_path):
    """Load one embeddings file, decode it, print samples, and return its results dict.

    Any failure is reported and yields an empty result, so one bad file does not
    abort the whole run.
    """
    print(f"\n{'='*50}")
    print(f"Processing {cfg_name}...")
    print(f"Loading: {file_path}")
    try:
        # Load embeddings
        embeddings = torch.load(file_path, map_location='cpu')
        print(f"✓ Loaded {len(embeddings)} embeddings, shape: {embeddings.shape}")

        # Decode to sequences using diverse method
        print("Decoding sequences...")
        sequences = converter.batch_embedding_to_sequences(embeddings, method='diverse', temperature=0.5)

        # Filter valid sequences
        valid_sequences = converter.filter_valid_sequences(sequences)
        print(f"✓ Valid sequences: {len(valid_sequences)}/{len(sequences)}")

        results = {
            'sequences': valid_sequences,
            'total': len(sequences),
            'valid': len(valid_sequences),
        }

        # Show sample sequences
        print(f"\nSample sequences ({cfg_name}):")
        for i, seq in enumerate(valid_sequences[:5]):
            print(f" {i+1}: {seq}")
        return results
    except Exception as e:  # per-file boundary: report and continue with the next file
        print(f"❌ Error processing {file_path}: {e}")
        return {'sequences': [], 'total': 0, 'valid': 0}


def _summarize_and_save(cfg_name, results, today):
    """Print statistics for one CFG setting and write its sequences to a dated file.

    Settings with no valid sequences are skipped entirely (nothing printed or saved).
    """
    sequences = results['sequences']
    if not sequences:
        return

    # Sequence length statistics
    lengths = [len(seq) for seq in sequences]
    avg_length = np.mean(lengths)
    std_length = np.std(lengths)

    # Amino acid composition, in canonical order (ties keep that order when sorted)
    all_aas = ''.join(sequences)
    aa_counts = {aa: all_aas.count(aa) for aa in CANONICAL_AAS}

    # Diversity (unique sequences)
    unique_sequences = len(set(sequences))
    diversity_ratio = unique_sequences / len(sequences)

    print(f"\n{cfg_name}:")
    print(f" Total sequences: {results['total']}")
    print(f" Valid sequences: {results['valid']}")
    print(f" Unique sequences: {unique_sequences}")
    print(f" Diversity ratio: {diversity_ratio:.3f}")
    print(f" Avg length: {avg_length:.1f} ± {std_length:.1f}")
    print(f" Length range: {min(lengths)}-{max(lengths)}")

    # Show top amino acids
    sorted_aas = sorted(aa_counts.items(), key=lambda x: x[1], reverse=True)
    print(f" Top 5 AAs: {', '.join([f'{aa}({count})' for aa, count in sorted_aas[:5]])}")

    # Create output directory if it doesn't exist
    output_dir = '/data2/edwardsun/decoded_sequences'
    os.makedirs(output_dir, exist_ok=True)

    # Save sequences to a dated file; the setting name is slugified for the filename.
    slug = cfg_name.lower().replace(' ', '_').replace('(', '').replace(')', '').replace('.', '')
    output_file = os.path.join(output_dir, f"decoded_sequences_{slug}_{today}.txt")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"# Decoded sequences from {cfg_name}\n")
        f.write(f"# Total: {results['total']}, Valid: {results['valid']}, Unique: {unique_sequences}\n")
        f.write("# Generated from best model (loss: 0.017183, step: 53)\n\n")
        for i, seq in enumerate(sequences):
            f.write(f"seq_{i+1:03d}\t{seq}\n")

    print(f" ✓ Saved to: {output_file}")


def _print_overall_comparison(all_results):
    """Print valid/unique counts per CFG setting and the most/least diverse setting."""
    print(f"\n{'='*60}")
    print("OVERALL COMPARISON")
    print(f"{'='*60}")

    cfg_names = list(all_results)
    valid_counts = [all_results[name]['valid'] for name in cfg_names]
    unique_counts = [len(set(all_results[name]['sequences'])) for name in cfg_names]

    print(f"Valid sequences: {dict(zip(cfg_names, valid_counts))}")
    print(f"Unique sequences: {dict(zip(cfg_names, unique_counts))}")

    # Ratios are only comparable when every setting produced at least one valid sequence.
    if cfg_names and all(valid_counts):
        diversity_ratios = [u / v for u, v in zip(unique_counts, valid_counts)]
        best, worst = max(diversity_ratios), min(diversity_ratios)
        most_diverse = cfg_names[diversity_ratios.index(best)]
        least_diverse = cfg_names[diversity_ratios.index(worst)]

        print(f"\nMost diverse: {most_diverse} (ratio: {best:.3f})")
        print(f"Least diverse: {least_diverse} (ratio: {worst:.3f})")


def main():
    """
    Decode all CFG-generated peptide embeddings to sequences and analyze distribution.
    Uses the best trained model (loss: 0.017183, step: 53).
    """
    print("=== CFG-Generated Peptide Sequence Decoder (Best Model) ===")

    # Initialize converter
    converter = EmbeddingToSequenceConverter()

    # NOTE: input and output filenames embed today's date, so this must run on the
    # same calendar day the samples were generated; otherwise each load fails
    # (and is reported per file).
    today = datetime.now().strftime('%Y%m%d')

    # Load all CFG-generated embeddings (using best model)
    cfg_files = {
        'No CFG (0.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_no_cfg_{today}.pt',
        'Weak CFG (3.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_weak_cfg_{today}.pt',
        'Strong CFG (7.5)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_strong_cfg_{today}.pt',
        'Very Strong CFG (15.0)': f'/data2/edwardsun/generated_samples/generated_amps_best_model_very_strong_cfg_{today}.pt',
    }

    all_results = {}
    for cfg_name, file_path in cfg_files.items():
        all_results[cfg_name] = _decode_cfg_file(converter, cfg_name, file_path)

    # Analysis and comparison
    print(f"\n{'='*60}")
    print("CFG ANALYSIS SUMMARY")
    print(f"{'='*60}")
    for cfg_name, results in all_results.items():
        _summarize_and_save(cfg_name, results, today)

    _print_overall_comparison(all_results)

    print("\n✓ Decoding complete! Check the output files for detailed sequences.")


if __name__ == "__main__":
    main()