gbyuvd commited on
Commit
63e238b
·
verified ·
1 Parent(s): de8452e

Update evaluate_molecular_model.py

Browse files
Files changed (1) hide show
  1. evaluate_molecular_model.py +425 -338
evaluate_molecular_model.py CHANGED
@@ -1,339 +1,426 @@
1
- # evaluate_molecular_model.py
2
- import os
3
- import sys
4
- import json
5
- import argparse
6
- import random
7
- from typing import List, Optional
8
- from tqdm import tqdm
9
-
10
- import torch
11
- from rdkit import Chem
12
- from rdkit.Chem import AllChem
13
- from rdkit import RDLogger
14
- import selfies as sf
15
- import pandas as pd
16
-
17
- # Suppress RDKit warnings
18
- RDLogger.DisableLog('rdApp.*')
19
-
20
- # Add local path
21
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
22
-
23
- from FastChemTokenizerHF import FastChemTokenizerSelfies
24
- from ChemQ3MTP import ChemQ3MTPForCausalLM
25
-
26
- # ----------------------------
27
- # Robust Conversion & Validation (as per your spec)
28
- # ----------------------------
29
-
30
- def selfies_to_smiles(selfies_str: str) -> Optional[str]:
31
- """Convert SELFIES string to SMILES, handling tokenizer artifacts."""
32
- try:
33
- clean_selfies = selfies_str.replace(" ", "")
34
- return sf.decoder(clean_selfies)
35
- except Exception:
36
- return None
37
-
38
- def is_valid_smiles(smiles: str) -> bool:
39
- """Check if a SMILES string represents a valid molecule."""
40
- if not isinstance(smiles, str) or len(smiles.strip()) == 0:
41
- return False
42
- mol = Chem.MolFromSmiles(smiles.strip())
43
- return mol is not None
44
-
45
- def get_sa_label_and_confidence(selfies_str: str) -> tuple[str, float]:
46
- """Get SA label (Easy/Hard) and confidence from the model's SA classifier."""
47
- try:
48
- from ChemQ3MTP.rl_utils import get_sa_classifier
49
- classifier = get_sa_classifier()
50
- if classifier is None:
51
- return "Unknown", 0.0
52
-
53
- # Get raw classifier output: [{'label': 'Easy', 'score': 0.9187200665473938}]
54
- result = classifier(selfies_str, truncation=True, max_length=128)[0]
55
- return result["label"], result["score"]
56
- except Exception as e:
57
- return "Unknown", 0.0
58
-
59
- def get_morgan_fingerprint_from_smiles(smiles: str, radius=2, n_bits=2048):
60
- mol = Chem.MolFromSmiles(smiles)
61
- if mol is None:
62
- return None
63
- return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
64
-
65
- def tanimoto_sim(fp1, fp2):
66
- from rdkit.DataStructs import TanimotoSimilarity
67
- return TanimotoSimilarity(fp1, fp2)
68
-
69
- # ----------------------------
70
- # Main Evaluation Function
71
- # ----------------------------
72
-
73
- def evaluate_model(
74
- model_path: str,
75
- train_data_path: str = "../data/chunk_5.csv",
76
- n_samples: int = 1000,
77
- seed: int = 42,
78
- max_gen_len: int = 32
79
- ):
80
- torch.manual_seed(seed)
81
- random.seed(seed)
82
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
- print(f"🚀 Evaluating model at: {model_path}")
84
- print(f" Device: {device} | Samples: {n_samples} | Seed: {seed}\n")
85
-
86
- # Load tokenizer and model
87
- tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")
88
- model = ChemQ3MTPForCausalLM.from_pretrained(model_path)
89
- model.to(device)
90
- model.eval()
91
-
92
- # Load training set and normalize SELFIES (remove spaces)
93
- print("📂 Loading and normalizing training set for novelty...")
94
- train_df = pd.read_csv(train_data_path)
95
- train_selfies_clean = set()
96
- for s in train_df["SELFIES"].dropna().astype(str):
97
- clean_s = s.replace(" ", "")
98
- train_selfies_clean.add(clean_s)
99
- print(f" Training set size: {len(train_selfies_clean)} unique (space-free) SELFIES\n")
100
-
101
- # === MTP-AWARE GENERATION ===
102
- print("GenerationStrategy: Using MTP-aware generation...")
103
- all_selfies_raw = []
104
- batch_size = 32
105
- num_batches = (n_samples + batch_size - 1) // batch_size
106
-
107
- with torch.no_grad():
108
- for _ in tqdm(range(num_batches), desc="Generating"):
109
- current_batch_size = min(batch_size, n_samples - len(all_selfies_raw))
110
- if current_batch_size <= 0:
111
- break
112
-
113
- input_ids = torch.full(
114
- (current_batch_size, 1),
115
- tokenizer.bos_token_id,
116
- dtype=torch.long,
117
- device=device
118
- )
119
-
120
- if hasattr(model, 'generate_with_logprobs'):
121
- try:
122
- outputs = model.generate_with_logprobs(
123
- input_ids=input_ids,
124
- max_new_tokens=25,
125
- temperature=1.0,
126
- top_k=50,
127
- top_p=0.95,
128
- do_sample=True,
129
- return_probs=True,
130
- tokenizer=tokenizer
131
- )
132
- batch_selfies = outputs[0] # list of raw SELFIES (may have spaces)
133
- except Exception as e:
134
- print(f"⚠️ MTP generation failed: {e}. Falling back.")
135
- gen_tokens = model.generate(
136
- input_ids,
137
- max_length=max_gen_len,
138
- do_sample=True,
139
- top_k=50,
140
- top_p=0.95,
141
- temperature=1.0,
142
- pad_token_id=tokenizer.pad_token_id,
143
- eos_token_id=tokenizer.eos_token_id
144
- )
145
- batch_selfies = [
146
- tokenizer.decode(seq, skip_special_tokens=True)
147
- for seq in gen_tokens
148
- ]
149
- else:
150
- gen_tokens = model.generate(
151
- input_ids,
152
- max_length=max_gen_len,
153
- do_sample=True,
154
- top_k=50,
155
- top_p=0.95,
156
- temperature=1.0,
157
- pad_token_id=tokenizer.pad_token_id,
158
- eos_token_id=tokenizer.eos_token_id
159
- )
160
- batch_selfies = [
161
- tokenizer.decode(seq, skip_special_tokens=True)
162
- for seq in gen_tokens
163
- ]
164
-
165
- all_selfies_raw.extend(batch_selfies)
166
- if len(all_selfies_raw) >= n_samples:
167
- break
168
-
169
- all_selfies_raw = all_selfies_raw[:n_samples]
170
- print(f"\n✅ Generated {len(all_selfies_raw)} raw SELFIES strings.\n")
171
-
172
- # Process: SELFIES → clean SELFIES → SMILES → valid molecules
173
- valid_records = []
174
- print("🧪 Processing SELFIES and converting to SMILES...")
175
- for i, raw_selfies in enumerate(tqdm(all_selfies_raw, desc="Converting")):
176
- # Clean the SELFIES (remove spaces as tokenizer uses whitespace)
177
- clean_selfies = raw_selfies.replace(" ", "")
178
-
179
- # Convert to SMILES
180
- smiles = selfies_to_smiles(clean_selfies)
181
-
182
- if smiles and is_valid_smiles(smiles):
183
- valid_records.append({
184
- "raw_selfies": raw_selfies,
185
- "selfies_clean": clean_selfies,
186
- "selfies": clean_selfies, # canonical version
187
- "smiles": smiles.strip()
188
- })
189
-
190
- # >>> DEBUG: Print multiple examples and SA label analysis <<<
191
- if valid_records:
192
- print("\n🔍 DEBUG: Sample generated molecules")
193
- print("-" * 70)
194
- for i in range(min(5, len(valid_records))):
195
- example = valid_records[i]
196
- print(f"Example {i+1}:")
197
- print(f" Raw SELFIES : {example['raw_selfies'][:80]}{'...' if len(example['raw_selfies']) > 80 else ''}")
198
- print(f" SMILES : {example['smiles']}")
199
-
200
- # Get SA label and confidence
201
- label, confidence = get_sa_label_and_confidence(example['raw_selfies'])
202
- print(f" SA Label : {label} (confidence: {confidence:.3f})")
203
-
204
- if i == 0:
205
- # Test SA classifier with simple molecules
206
- simple_label, simple_conf = get_sa_label_and_confidence('[C]')
207
- benzene_label, benzene_conf = get_sa_label_and_confidence('[c] [c] [c] [c] [c] [c] [Ring1] [=Branch1]')
208
- print(f" 🧪 SA Test - Simple molecule: {simple_label} ({simple_conf:.3f})")
209
- print(f" 🧪 SA Test - Benzene: {benzene_label} ({benzene_conf:.3f})")
210
-
211
- # Check molecule properties
212
- mol = Chem.MolFromSmiles(example['smiles'])
213
- if mol:
214
- print(f" Atoms : {mol.GetNumAtoms()}")
215
- print(f" Bonds : {mol.GetNumBonds()}")
216
- print()
217
- print("-" * 70)
218
-
219
- # SA Label distribution analysis
220
- sa_labels = []
221
- for r in valid_records[:100]:
222
- label, _ = get_sa_label_and_confidence(r["raw_selfies"])
223
- sa_labels.append(label)
224
-
225
- easy_count = sa_labels.count("Easy")
226
- hard_count = sa_labels.count("Hard")
227
- unknown_count = sa_labels.count("Unknown")
228
-
229
- print(f"🔍 SA Label Analysis (first 100 molecules):")
230
- print(f" Easy to synthesize: {easy_count}/100 ({easy_count}%)")
231
- print(f" Hard to synthesize: {hard_count}/100 ({hard_count}%)")
232
- if unknown_count > 0:
233
- print(f" Unknown/Failed: {unknown_count}/100 ({unknown_count}%)")
234
- else:
235
- print("\n⚠️ WARNING: No valid molecules generated in sample!")
236
- # <<< END DEBUG >>>
237
-
238
- # Now compute metrics...
239
- validity = len(valid_records) / n_samples
240
-
241
- unique_valid = list({r["selfies_clean"]: r for r in valid_records}.values())
242
- uniqueness = len(unique_valid) / len(valid_records) if valid_records else 0.0
243
-
244
- novel_count = sum(1 for r in unique_valid if r["selfies_clean"] not in train_selfies_clean)
245
- novelty = novel_count / len(unique_valid) if unique_valid else 0.0
246
-
247
- # SA Label Counts (using model's SA classifier)
248
- sa_labels_all = []
249
- for r in unique_valid:
250
- label, _ = get_sa_label_and_confidence(r["raw_selfies"])
251
- sa_labels_all.append(label)
252
-
253
- easy_total = sa_labels_all.count("Easy")
254
- hard_total = sa_labels_all.count("Hard")
255
- unknown_total = sa_labels_all.count("Unknown")
256
- total_labeled = len(sa_labels_all)
257
-
258
- # Internal Diversity (on SMILES)
259
- if len(unique_valid) >= 2:
260
- fps = []
261
- for r in unique_valid:
262
- fp = get_morgan_fingerprint_from_smiles(r["smiles"])
263
- if fp is not None:
264
- fps.append(fp)
265
- if len(fps) >= 2:
266
- total_sim, count = 0.0, 0
267
- for i in range(len(fps)):
268
- for j in range(i + 1, len(fps)):
269
- total_sim += tanimoto_sim(fps[i], fps[j])
270
- count += 1
271
- internal_diversity = 1.0 - (total_sim / count)
272
- else:
273
- internal_diversity = 0.0
274
- else:
275
- internal_diversity = 0.0
276
-
277
- # ----------------------------
278
- # Final Summary
279
- # ----------------------------
280
- print("\n" + "="*55)
281
- print("📊 MOLECULAR GENERATION EVALUATION SUMMARY")
282
- print("="*55)
283
- print(f"Model Path : {model_path}")
284
- print(f"Generation Mode : {'MTP-aware' if hasattr(model, 'generate_with_logprobs') else 'Standard'}")
285
- print(f"Samples Generated: {n_samples}")
286
- print("-"*55)
287
- print(f"Validity : {validity:.4f} ({len(valid_records)}/{n_samples})")
288
- print(f"Uniqueness : {uniqueness:.4f} (unique valid)")
289
- print(f"Novelty (vs train): {novelty:.4f} (space-free SELFIES)")
290
- print(f"Synthesis Labels : Easy: {easy_total}/{total_labeled} ({easy_total/max(1,total_labeled)*100:.1f}%) | Hard: {hard_total}/{total_labeled} ({hard_total/max(1,total_labeled)*100:.1f}%)")
291
- if unknown_total > 0:
292
- print(f" Unknown: {unknown_total}/{total_labeled} ({unknown_total/max(1,total_labeled)*100:.1f}%)")
293
- print(f"Internal Diversity: {internal_diversity:.4f} (1 - avg Tanimoto)")
294
- print("="*55)
295
-
296
- results = {
297
- "model_path": model_path,
298
- "generation_mode": "MTP-aware" if hasattr(model, 'generate_with_logprobs') else "standard",
299
- "n_samples": n_samples,
300
- "validity": validity,
301
- "uniqueness": uniqueness,
302
- "novelty": novelty,
303
- "sa_easy_count": easy_total,
304
- "sa_hard_count": hard_total,
305
- "sa_easy_percentage": easy_total/max(1,total_labeled)*100,
306
- "sa_hard_percentage": hard_total/max(1,total_labeled)*100,
307
- "internal_diversity": internal_diversity,
308
- "valid_molecules_count": len(valid_records)
309
- }
310
-
311
- if unknown_total > 0:
312
- results["sa_unknown_count"] = unknown_total
313
- results["sa_unknown_percentage"] = unknown_total/max(1,total_labeled)*100
314
-
315
- output_json = os.path.join(model_path, "evaluation_summary.json")
316
- with open(output_json, "w") as f:
317
- json.dump(results, f, indent=2)
318
- print(f"\n💾 Results saved to: {output_json}")
319
-
320
- return results
321
-
322
- # ----------------------------
323
- # CLI
324
- # ----------------------------
325
-
326
- if __name__ == "__main__":
327
- parser = argparse.ArgumentParser(description="Evaluate molecular generative model with MTP-aware generation")
328
- parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint")
329
- parser.add_argument("--n_samples", type=int, default=1000, help="Number of molecules to generate")
330
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
331
- parser.add_argument("--train_data", type=str, default="../data/chunk_5.csv", help="Training data CSV")
332
-
333
- args = parser.parse_args()
334
- evaluate_model(
335
- model_path=args.model_path,
336
- train_data_path=args.train_data,
337
- n_samples=args.n_samples,
338
- seed=args.seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  )
 
1
+ # evaluate_molecular_model.py
2
+ # evaluate_molecular_model.py
3
+ import os
4
+ import sys
5
+ import json
6
+ import argparse
7
+ import random
8
+ from typing import List, Optional
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+ from rdkit import Chem
13
+ from rdkit.Chem import AllChem
14
+ from rdkit import RDLogger
15
+ import selfies as sf
16
+ import pandas as pd
17
+
18
+ # Suppress RDKit warnings
19
+ RDLogger.DisableLog('rdApp.*')
20
+
21
+ # Add local path
22
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ from FastChemTokenizerHF import FastChemTokenizerSelfies
25
+ from ChemQ3MTP import ChemQ3MTPForCausalLM
26
+
27
+ # ----------------------------
28
+ # Robust Conversion & Validation (as per your spec)
29
+ # ----------------------------
30
+
31
+ def selfies_to_smiles(selfies_str: str) -> Optional[str]:
32
+ """Convert SELFIES string to SMILES, handling tokenizer artifacts."""
33
+ try:
34
+ clean_selfies = selfies_str.replace(" ", "")
35
+ return sf.decoder(clean_selfies)
36
+ except Exception:
37
+ return None
38
+
39
+
40
+ def is_valid_smiles(smiles: str) -> bool:
41
+ """
42
+ Check if a SMILES string represents a valid molecule.
43
+ FIXED: Now properly checks for heavy atoms (non-hydrogens) >= 3
44
+ and rejects disconnected/separated molecules
45
+ """
46
+ if not isinstance(smiles, str) or len(smiles.strip()) == 0:
47
+ return False
48
+
49
+ smiles = smiles.strip()
50
+
51
+ # FAST CHECK: Reject separated molecules (contains dots)
52
+ if '.' in smiles:
53
+ return False # Disconnected components indicated by dots
54
+
55
+ try:
56
+ mol = Chem.MolFromSmiles(smiles)
57
+ if mol is None:
58
+ return False
59
+
60
+ # CRITICAL FIX: Check heavy atoms (non-hydrogens), not total atoms
61
+ heavy_atoms = mol.GetNumHeavyAtoms() # This excludes hydrogens
62
+ if heavy_atoms < 3:
63
+ return False
64
+
65
+ return True
66
+ except Exception:
67
+ return False
68
+
69
+ def passes_durrant_lab_filter(smiles: str) -> bool:
70
+ """
71
+ Apply Durrant's lab filter to remove improbable substructures.
72
+ FIXED: More robust error handling, pattern checking, and disconnected molecule rejection.
73
+ Returns True if molecule passes the filter (is acceptable), False otherwise.
74
+ """
75
+ if not smiles or not isinstance(smiles, str) or len(smiles.strip()) == 0:
76
+ return False
77
+
78
+ try:
79
+ mol = Chem.MolFromSmiles(smiles.strip())
80
+ if mol is None:
81
+ return False
82
+
83
+ # Check heavy atoms again (belt and suspenders approach)
84
+ if mol.GetNumHeavyAtoms() < 3:
85
+ return False
86
+
87
+ # REJECT SEPARATED/DISCONNECTED MOLECULES (double check here too)
88
+ fragments = Chem.rdmolops.GetMolFrags(mol, asMols=False)
89
+ if len(fragments) > 1:
90
+ return False # Reject molecules with multiple disconnected parts
91
+
92
+ # Define SMARTS patterns for problematic substructures
93
+ problematic_patterns = [
94
+ "C=[N-]", # Carbon double bonded to negative nitrogen
95
+ "[N-]C=[N+]", # Nitrogen anion bonded to nitrogen cation
96
+ "[nH+]c[n-]", # Aromatic nitrogen cation adjacent to nitrogen anion
97
+ "[#7+]~[#7+]", # Positive nitrogen connected to positive nitrogen
98
+ "[#7-]~[#7-]", # Negative nitrogen connected to negative nitrogen
99
+ "[!#7]~[#7+]~[#7-]~[!#7]", # Bridge: non-nitrogen - pos nitrogen - neg nitrogen - non-nitrogen
100
+ "[#5]", # Boron atoms
101
+ "O=[PH](=O)([#8])([#8])", # Phosphoryl with hydroxyls
102
+ "N=c1cc[#7]c[#7]1", # Nitrogen in aromatic ring with another nitrogen
103
+ "[$([NX2H1]),$([NX3H2])]=C[$([OH]),$([O-])]", # N=CH-OH or N=CH-O-
104
+ ]
105
+
106
+ # Check for metals (excluding common biologically relevant ions)
107
+ metal_exclusions = {11, 12, 19, 20} # Na, Mg, K, Ca
108
+ for atom in mol.GetAtoms():
109
+ atomic_num = atom.GetAtomicNum()
110
+ # More precise metal detection
111
+ if atomic_num > 20 and atomic_num not in metal_exclusions:
112
+ return False
113
+
114
+ # Check for each problematic pattern
115
+ for pattern in problematic_patterns:
116
+ try:
117
+ patt_mol = Chem.MolFromSmarts(pattern)
118
+ if patt_mol is not None:
119
+ matches = mol.GetSubstructMatches(patt_mol)
120
+ if matches:
121
+ return False # Found problematic substructure
122
+ except Exception:
123
+ # If SMARTS parsing fails, continue to next pattern
124
+ continue
125
+
126
+ return True # Passed all checks
127
+
128
+ except Exception:
129
+ return False
130
+
131
+
132
+ def get_sa_label_and_confidence(selfies_str: str) -> tuple[str, float]:
133
+ """Get SA label (Easy/Hard) and confidence from the model's SA classifier."""
134
+ try:
135
+ from ChemQ3MTP.rl_utils import get_sa_classifier
136
+ classifier = get_sa_classifier()
137
+ if classifier is None:
138
+ return "Unknown", 0.0
139
+
140
+ # Get raw classifier output: [{'label': 'Easy', 'score': 0.9187200665473938}]
141
+ result = classifier(selfies_str, truncation=True, max_length=128)[0]
142
+ return result["label"], result["score"]
143
+ except Exception as e:
144
+ return "Unknown", 0.0
145
+
146
+ def get_morgan_fingerprint_from_smiles(smiles: str, radius=2, n_bits=2048):
147
+ mol = Chem.MolFromSmiles(smiles)
148
+ if mol is None:
149
+ return None
150
+ return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
151
+
152
+ def tanimoto_sim(fp1, fp2):
153
+ from rdkit.DataStructs import TanimotoSimilarity
154
+ return TanimotoSimilarity(fp1, fp2)
155
+
156
+ # ----------------------------
157
+ # Main Evaluation Function
158
+ # ----------------------------
159
+
160
+ def evaluate_model(
161
+ model_path: str,
162
+ train_data_path: str = "../data/chunk_5.csv",
163
+ n_samples: int = 1000,
164
+ seed: int = 42,
165
+ max_gen_len: int = 32
166
+ ):
167
+ torch.manual_seed(seed)
168
+ random.seed(seed)
169
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170
+ print(f"🚀 Evaluating model at: {model_path}")
171
+ print(f" Device: {device} | Samples: {n_samples} | Seed: {seed}\n")
172
+
173
+ # Load tokenizer and model
174
+ tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")
175
+ model = ChemQ3MTPForCausalLM.from_pretrained(model_path)
176
+ model.to(device)
177
+ model.eval()
178
+
179
+ # Load training set and normalize SELFIES (remove spaces)
180
+ print("📂 Loading and normalizing training set for novelty...")
181
+ train_df = pd.read_csv(train_data_path)
182
+ train_selfies_clean = set()
183
+ for s in train_df["SELFIES"].dropna().astype(str):
184
+ clean_s = s.replace(" ", "")
185
+ train_selfies_clean.add(clean_s)
186
+ print(f" Training set size: {len(train_selfies_clean)} unique (space-free) SELFIES\n")
187
+
188
+ # === MTP-AWARE GENERATION ===
189
+ print("GenerationStrategy: Using MTP-aware generation...")
190
+ all_selfies_raw = []
191
+ batch_size = 32
192
+ num_batches = (n_samples + batch_size - 1) // batch_size
193
+
194
+ with torch.no_grad():
195
+ for _ in tqdm(range(num_batches), desc="Generating"):
196
+ current_batch_size = min(batch_size, n_samples - len(all_selfies_raw))
197
+ if current_batch_size <= 0:
198
+ break
199
+
200
+ input_ids = torch.full(
201
+ (current_batch_size, 1),
202
+ tokenizer.bos_token_id,
203
+ dtype=torch.long,
204
+ device=device
205
+ )
206
+
207
+ if hasattr(model, 'generate_with_logprobs'):
208
+ try:
209
+ outputs = model.generate_with_logprobs(
210
+ input_ids=input_ids,
211
+ max_new_tokens=25,
212
+ temperature=1.0,
213
+ top_k=50,
214
+ top_p=0.95,
215
+ do_sample=True,
216
+ return_probs=True,
217
+ tokenizer=tokenizer
218
+ )
219
+ batch_selfies = outputs[0] # list of raw SELFIES (may have spaces)
220
+ except Exception as e:
221
+ print(f"⚠️ MTP generation failed: {e}. Falling back.")
222
+ gen_tokens = model.generate(
223
+ input_ids,
224
+ max_length=max_gen_len,
225
+ do_sample=True,
226
+ top_k=50,
227
+ top_p=0.95,
228
+ temperature=1.0,
229
+ pad_token_id=tokenizer.pad_token_id,
230
+ eos_token_id=tokenizer.eos_token_id
231
+ )
232
+ batch_selfies = [
233
+ tokenizer.decode(seq, skip_special_tokens=True)
234
+ for seq in gen_tokens
235
+ ]
236
+ else:
237
+ gen_tokens = model.generate(
238
+ input_ids,
239
+ max_length=max_gen_len,
240
+ do_sample=True,
241
+ top_k=50,
242
+ top_p=0.95,
243
+ temperature=1.0,
244
+ pad_token_id=tokenizer.pad_token_id,
245
+ eos_token_id=tokenizer.eos_token_id
246
+ )
247
+ batch_selfies = [
248
+ tokenizer.decode(seq, skip_special_tokens=True)
249
+ for seq in gen_tokens
250
+ ]
251
+
252
+ all_selfies_raw.extend(batch_selfies)
253
+ if len(all_selfies_raw) >= n_samples:
254
+ break
255
+
256
+ all_selfies_raw = all_selfies_raw[:n_samples]
257
+ print(f"\n✅ Generated {len(all_selfies_raw)} raw SELFIES strings.\n")
258
+
259
+ # Process: SELFIES → clean SELFIES → SMILES → valid molecules
260
+ valid_records = []
261
+ print("🧪 Processing SELFIES and converting to SMILES...")
262
+ for i, raw_selfies in enumerate(tqdm(all_selfies_raw, desc="Converting")):
263
+ # Clean the SELFIES (remove spaces as tokenizer uses whitespace)
264
+ clean_selfies = raw_selfies.replace(" ", "")
265
+
266
+ # Convert to SMILES
267
+ smiles = selfies_to_smiles(clean_selfies)
268
+
269
+ if smiles is not None and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
270
+ valid_records.append({
271
+ "raw_selfies": raw_selfies,
272
+ "selfies_clean": clean_selfies,
273
+ "selfies": clean_selfies, # canonical version
274
+ "smiles": smiles.strip()
275
+ })
276
+
277
+ # >>> DEBUG: Print multiple examples and SA label analysis <<<
278
+ if valid_records:
279
+ print("\n🔍 DEBUG: Sample generated molecules")
280
+ print("-" * 70)
281
+ for i in range(min(5, len(valid_records))):
282
+ example = valid_records[i]
283
+ print(f"Example {i+1}:")
284
+ print(f" Raw SELFIES : {example['raw_selfies'][:80]}{'...' if len(example['raw_selfies']) > 80 else ''}")
285
+ print(f" SMILES : {example['smiles']}")
286
+
287
+ # Get SA label and confidence
288
+ label, confidence = get_sa_label_and_confidence(example['raw_selfies'])
289
+ print(f" SA Label : {label} (confidence: {confidence:.3f})")
290
+
291
+ if i == 0:
292
+ # Test SA classifier with simple molecules
293
+ simple_label, simple_conf = get_sa_label_and_confidence('[C]')
294
+ benzene_label, benzene_conf = get_sa_label_and_confidence('[c] [c] [c] [c] [c] [c] [Ring1] [=Branch1]')
295
+ print(f" 🧪 SA Test - Simple molecule: {simple_label} ({simple_conf:.3f})")
296
+ print(f" 🧪 SA Test - Benzene: {benzene_label} ({benzene_conf:.3f})")
297
+
298
+ # Check molecule properties
299
+ mol = Chem.MolFromSmiles(example['smiles'])
300
+ if mol:
301
+ print(f" Atoms : {mol.GetNumAtoms()}")
302
+ print(f" Bonds : {mol.GetNumBonds()}")
303
+ print()
304
+ print("-" * 70)
305
+
306
+ # SA Label distribution analysis
307
+ sa_labels = []
308
+ for r in valid_records[:100]:
309
+ label, _ = get_sa_label_and_confidence(r["raw_selfies"])
310
+ sa_labels.append(label)
311
+
312
+ easy_count = sa_labels.count("Easy")
313
+ hard_count = sa_labels.count("Hard")
314
+ unknown_count = sa_labels.count("Unknown")
315
+
316
+ print(f"🔍 SA Label Analysis (first 100 molecules):")
317
+ print(f" Easy to synthesize: {easy_count}/100 ({easy_count}%)")
318
+ print(f" Hard to synthesize: {hard_count}/100 ({hard_count}%)")
319
+ if unknown_count > 0:
320
+ print(f" Unknown/Failed: {unknown_count}/100 ({unknown_count}%)")
321
+ else:
322
+ print("\n⚠️ WARNING: No valid molecules generated in sample!")
323
+ # <<< END DEBUG >>>
324
+
325
+ # Now compute metrics...
326
+ validity = len(valid_records) / n_samples
327
+
328
+ unique_valid = list({r["selfies_clean"]: r for r in valid_records}.values())
329
+ uniqueness = len(unique_valid) / len(valid_records) if valid_records else 0.0
330
+
331
+ novel_count = sum(1 for r in unique_valid if r["selfies_clean"] not in train_selfies_clean)
332
+ novelty = novel_count / len(unique_valid) if unique_valid else 0.0
333
+
334
+ # SA Label Counts (using model's SA classifier)
335
+ sa_labels_all = []
336
+ for r in unique_valid:
337
+ label, _ = get_sa_label_and_confidence(r["raw_selfies"])
338
+ sa_labels_all.append(label)
339
+
340
+ easy_total = sa_labels_all.count("Easy")
341
+ hard_total = sa_labels_all.count("Hard")
342
+ unknown_total = sa_labels_all.count("Unknown")
343
+ total_labeled = len(sa_labels_all)
344
+
345
+ # Internal Diversity (on SMILES)
346
+ if len(unique_valid) >= 2:
347
+ fps = []
348
+ for r in unique_valid:
349
+ fp = get_morgan_fingerprint_from_smiles(r["smiles"])
350
+ if fp is not None:
351
+ fps.append(fp)
352
+ if len(fps) >= 2:
353
+ total_sim, count = 0.0, 0
354
+ for i in range(len(fps)):
355
+ for j in range(i + 1, len(fps)):
356
+ total_sim += tanimoto_sim(fps[i], fps[j])
357
+ count += 1
358
+ internal_diversity = 1.0 - (total_sim / count)
359
+ else:
360
+ internal_diversity = 0.0
361
+ else:
362
+ internal_diversity = 0.0
363
+
364
+ # ----------------------------
365
+ # Final Summary
366
+ # ----------------------------
367
+ print("\n" + "="*55)
368
+ print("📊 MOLECULAR GENERATION EVALUATION SUMMARY")
369
+ print("="*55)
370
+ print(f"Model Path : {model_path}")
371
+ print(f"Generation Mode : {'MTP-aware' if hasattr(model, 'generate_with_logprobs') else 'Standard'}")
372
+ print(f"Samples Generated: {n_samples}")
373
+ print("-"*55)
374
+ print(f"Validity : {validity:.4f} ({len(valid_records)}/{n_samples})")
375
+ print(f"Uniqueness : {uniqueness:.4f} (unique valid)")
376
+ print(f"Novelty (vs train): {novelty:.4f} (space-free SELFIES)")
377
+ print(f"Synthesis Labels : Easy: {easy_total}/{total_labeled} ({easy_total/max(1,total_labeled)*100:.1f}%) | Hard: {hard_total}/{total_labeled} ({hard_total/max(1,total_labeled)*100:.1f}%)")
378
+ if unknown_total > 0:
379
+ print(f" Unknown: {unknown_total}/{total_labeled} ({unknown_total/max(1,total_labeled)*100:.1f}%)")
380
+ print(f"Internal Diversity: {internal_diversity:.4f} (1 - avg Tanimoto)")
381
+ print("="*55)
382
+
383
+ results = {
384
+ "model_path": model_path,
385
+ "generation_mode": "MTP-aware" if hasattr(model, 'generate_with_logprobs') else "standard",
386
+ "n_samples": n_samples,
387
+ "validity": validity,
388
+ "uniqueness": uniqueness,
389
+ "novelty": novelty,
390
+ "sa_easy_count": easy_total,
391
+ "sa_hard_count": hard_total,
392
+ "sa_easy_percentage": easy_total/max(1,total_labeled)*100,
393
+ "sa_hard_percentage": hard_total/max(1,total_labeled)*100,
394
+ "internal_diversity": internal_diversity,
395
+ "valid_molecules_count": len(valid_records)
396
+ }
397
+
398
+ if unknown_total > 0:
399
+ results["sa_unknown_count"] = unknown_total
400
+ results["sa_unknown_percentage"] = unknown_total/max(1,total_labeled)*100
401
+
402
+ output_json = os.path.join(model_path, "evaluation_summary.json")
403
+ with open(output_json, "w") as f:
404
+ json.dump(results, f, indent=2)
405
+ print(f"\n💾 Results saved to: {output_json}")
406
+
407
+ return results
408
+
409
+ # ----------------------------
410
+ # CLI
411
+ # ----------------------------
412
+
413
+ if __name__ == "__main__":
414
+ parser = argparse.ArgumentParser(description="Evaluate molecular generative model with MTP-aware generation")
415
+ parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint")
416
+ parser.add_argument("--n_samples", type=int, default=1000, help="Number of molecules to generate")
417
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
418
+ parser.add_argument("--train_data", type=str, default="../data/chunk_5.csv", help="Training data CSV")
419
+
420
+ args = parser.parse_args()
421
+ evaluate_model(
422
+ model_path=args.model_path,
423
+ train_data_path=args.train_data,
424
+ n_samples=args.n_samples,
425
+ seed=args.seed
426
  )