Upload 3 files
- data_prep.py +33 -0
- molecule_generator.py +571 -0
- shared_features.py +223 -0
data_prep.py
ADDED
import os
import sqlite3
import pandas as pd
from sklearn.model_selection import train_test_split

PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))  # goes from src/ → project root
DB_PATH = os.path.join(PROJECT_ROOT, "database_main.db")

TARGET_CN = "cn"  # Cetane number
N_FOLDS = 5
TOP_K = 5

print("Connecting to SQLite database...")
conn = sqlite3.connect("database_main.db")

query = """
SELECT
    F.Fuel_Name,
    F.SMILES,
    T.Standardised_DCN AS cn
FROM FUEL F
LEFT JOIN TARGET T ON F.fuel_id = T.fuel_id
"""
df = pd.read_sql_query(query, conn)
conn.close()
df.dropna(subset=[TARGET_CN, "SMILES"], inplace=True)

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
print(df.head())
print(df.columns)


def load_data():
    return df
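
data_prep.py runs its query at import time, so the cleaned DataFrame `df` is available as soon as the module is imported (molecule_generator.py relies on exactly that). A minimal, illustrative consumer script, not part of this upload:

# Illustrative only: assumes it runs next to data_prep.py with the database reachable.
from data_prep import load_data

data = load_data()  # the cleaned DataFrame built at import time
print(f"{len(data)} fuels with both a SMILES string and a standardised cetane number")
print(data[["Fuel_Name", "SMILES", "cn"]].head())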
molecule_generator.py
ADDED
import os
import sys
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional, Tuple, Callable
import joblib
import numpy as np
import pandas as pd
import random
from rdkit import Chem
from crem.crem import mutate_mol
from sklearn.base import BaseEstimator, RegressorMixin
from huggingface_hub import snapshot_download

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# === Project Setup ===
PROJECT_ROOT = Path.cwd()
SRC_DIR = PROJECT_ROOT / "src"
sys.path.append(str(PROJECT_ROOT))

from shared_features import FeatureSelector, featurize_df
from data_prep import df


class GenericPredictor:
    """Generic predictor that works for any property model."""

    def __init__(self, model_dir: Path, property_name: str):
        """
        Initialize predictor from a model directory.

        Args:
            model_dir: Path to the model directory containing artifacts/
            property_name: Name of the property (for display purposes)
        """
        print(f"Loading {property_name} Predictor...")

        model_path = model_dir / "model.joblib"
        selector_path = model_dir / "selector.joblib"

        # Debug info
        print(f">>> MODEL PATH: {model_path}")
        print(f">>> SELECTOR PATH: {selector_path}")
        print(f">>> MODEL EXISTS: {model_path.exists()}")
        print(f">>> SELECTOR EXISTS: {selector_path.exists()}")

        # Load artifacts
        self.model = joblib.load(model_path)
        self.selector = FeatureSelector.load(selector_path)
        self.property_name = property_name

        print(f"✓ {property_name} Predictor ready!\n")

    def predict(self, smiles_list):
        """Inference on a list of SMILES strings."""
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]

        X_full = featurize_df(smiles_list, return_df=False)

        if X_full is None:
            print(f"⚠ Warning: No valid molecules found for {self.property_name}!")
            return []

        X_selected = self.selector.transform(X_full)
        predictions = self.model.predict(X_selected)
        return predictions.tolist()

    def predict_with_details(self, smiles_list):
        """Inference with valid/invalid info."""
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]

        df = pd.DataFrame({"SMILES": smiles_list})
        X_full, df_valid = featurize_df(df, return_df=True)

        col_name = f"Predicted_{self.property_name}"

        if X_full is None:
            return pd.DataFrame(columns=["SMILES", col_name, "Valid"])

        X_selected = self.selector.transform(X_full)
        predictions = self.model.predict(X_selected)

        df_valid[col_name] = predictions
        df_valid["Valid"] = True

        all_results = pd.DataFrame({"SMILES": smiles_list})
        all_results = all_results.merge(
            df_valid[["SMILES", col_name, "Valid"]],
            on="SMILES", how="left"
        )
        all_results["Valid"] = all_results["Valid"].fillna(False)

        return all_results


# === Hugging Face Predictor Paths ===
HF_MODELS = {
    "cn": "SalZa2004/Cetane_Number_Predictor",
    "ysi": "SalZa2004/YSI_Predictor",
    "bp": "SalZa2004/Boiling_Point_Predictor",
    "density": "SalZa2004/Density_Predictor",
    "lhv": "SalZa2004/LHV_Predictor",
    "dynamic_viscosity": "SalZa2004/Dynamic_Viscosity_Predictor",
}

PREDICTOR_PATHS = {
    key: Path(
        snapshot_download(
            repo_id=repo,
            repo_type="model"
        )
    )
    for key, repo in HF_MODELS.items()
}


@dataclass
class EvolutionConfig:
    """Configuration for evolutionary algorithm."""
    target_cn: float
    minimize_ysi: bool = True
    generations: int = 6
    population_size: int = 100
    mutations_per_parent: int = 5
    survivor_fraction: float = 0.5
    min_bp: float = 60
    max_bp: float = 250
    min_dynamic_viscosity: float = 0.0
    max_dynamic_viscosity: float = 2.0
    min_density: float = 720
    min_lhv: float = 30
    use_bp_filter: bool = True
    use_density_filter: bool = True
    use_lhv_filter: bool = True
    use_dynamic_viscosity_filter: bool = True
    batch_size: int = 50
    max_offspring_attempts: int = 10


@dataclass
class Molecule:
    """Represents a molecule with its properties."""
    smiles: str
    cn: float
    cn_error: float
    bp: Optional[float] = None
    ysi: Optional[float] = None
    density: Optional[float] = None
    lhv: Optional[float] = None
    dynamic_viscosity: Optional[float] = None

    def dominates(self, other: 'Molecule') -> bool:
        """Check if this molecule Pareto-dominates another."""
        better_cn = self.cn_error <= other.cn_error
        better_ysi = self.ysi <= other.ysi if self.ysi is not None else True
        strictly_better = (self.cn_error < other.cn_error or
                           (self.ysi is not None and self.ysi < other.ysi))
        return better_cn and better_ysi and strictly_better

    def to_dict(self) -> Dict:
        """Convert to dictionary for DataFrame creation."""
        return {k: v for k, v in asdict(self).items() if v is not None}


class PropertyPredictor:
    """Handles batch prediction for all molecular properties."""

    def __init__(self, config: EvolutionConfig):
        self.config = config

        # Initialize only the predictors we need
        self.predictors = {}

        # Always need CN predictor
        self.predictors['cn'] = GenericPredictor(
            PREDICTOR_PATHS['cn'],
            'Cetane Number'
        )

        # Conditional predictors
        if config.minimize_ysi:
            self.predictors['ysi'] = GenericPredictor(
                PREDICTOR_PATHS['ysi'],
                'YSI'
            )

        if config.use_bp_filter:
            self.predictors['bp'] = GenericPredictor(
                PREDICTOR_PATHS['bp'],
                'Boiling Point'
            )

        if config.use_density_filter:
            self.predictors['density'] = GenericPredictor(
                PREDICTOR_PATHS['density'],
                'Density'
            )

        if config.use_lhv_filter:
            self.predictors['lhv'] = GenericPredictor(
                PREDICTOR_PATHS['lhv'],
                'Lower Heating Value'
            )

        if config.use_dynamic_viscosity_filter:
            self.predictors['dynamic_viscosity'] = GenericPredictor(
                PREDICTOR_PATHS['dynamic_viscosity'],
                'Dynamic Viscosity'
            )

        # Define validation rules
        self.validators = {
            'bp': lambda v: self.config.min_bp <= v <= self.config.max_bp,
            'density': lambda v: v > self.config.min_density,
            'lhv': lambda v: v > self.config.min_lhv,
            'dynamic_viscosity': lambda v: self.config.min_dynamic_viscosity < v <= self.config.max_dynamic_viscosity
        }

    def _safe_predict(self, predictions: List) -> List[Optional[float]]:
        """Safely convert predictions, handling None/NaN/inf values."""
        return [
            float(pred) if pred is not None and np.isfinite(pred) else None
            for pred in predictions
        ]

    def _predict_batch(self, property_name: str, smiles_list: List[str]) -> List[Optional[float]]:
        """Generic batch prediction method."""
        predictor = self.predictors.get(property_name)
        if not smiles_list or predictor is None:
            return [None] * len(smiles_list)

        try:
            predictions = predictor.predict(smiles_list)
            return self._safe_predict(predictions)
        except Exception as e:
            print(f"Warning: Batch {property_name.upper()} prediction failed: {e}")
            return [None] * len(smiles_list)

    def predict_all_properties(self, smiles_list: List[str]) -> Dict[str, List[Optional[float]]]:
        """Predict all properties for a batch of SMILES."""
        return {
            prop: self._predict_batch(prop, smiles_list)
            for prop in ['cn', 'ysi', 'bp', 'density', 'lhv', 'dynamic_viscosity']  # Check all possible properties
            if prop in self.predictors  # Only predict if predictor exists
        }

    def is_valid(self, property_name: str, value: Optional[float]) -> bool:
        """Check if a property value is valid according to config rules."""
        if value is None:
            return True
        validator = self.validators.get(property_name)
        return validator(value) if validator else True


class Population:
    """Manages the population of molecules."""

    def __init__(self, config: EvolutionConfig):
        self.config = config
        self.molecules: List[Molecule] = []
        self.seen_smiles: set = set()

    def add_molecule(self, mol: Molecule) -> bool:
        """Add a molecule if it's not already in the population."""
        if mol.smiles in self.seen_smiles:
            return False
        self.molecules.append(mol)
        self.seen_smiles.add(mol.smiles)
        return True

    def add_molecules(self, molecules: List[Molecule]) -> int:
        """Add multiple molecules, return count added."""
        return sum(self.add_molecule(mol) for mol in molecules)

    def pareto_front(self) -> List[Molecule]:
        """Extract the Pareto front from the population."""
        if not self.config.minimize_ysi:
            return []

        return [
            mol for mol in self.molecules
            if not any(other.dominates(mol) for other in self.molecules if other is not mol)
        ]

    def get_survivors(self) -> List[Molecule]:
        """Select survivors for the next generation."""
        target_size = int(self.config.population_size * self.config.survivor_fraction)

        if self.config.minimize_ysi:
            survivors = self.pareto_front()

            # Sort key for combined objectives
            sort_key = lambda m: m.cn_error + m.ysi

            if len(survivors) > target_size:
                survivors = sorted(survivors, key=sort_key)[:target_size]
            elif len(survivors) < target_size:
                remainder = [m for m in self.molecules if m not in survivors]
                remainder = sorted(remainder, key=sort_key)
                survivors.extend(remainder[:target_size - len(survivors)])
        else:
            survivors = sorted(self.molecules, key=lambda m: m.cn_error)[:target_size]

        return survivors

    def to_dataframe(self) -> pd.DataFrame:
        """Convert population to DataFrame."""
        df = pd.DataFrame([m.to_dict() for m in self.molecules])

        sort_cols = ["cn_error", "ysi"] if self.config.minimize_ysi else ["cn_error"]
        df = df.sort_values(sort_cols, ascending=True)
        df.insert(0, 'rank', range(1, len(df) + 1))
        return df


class MolecularEvolution:
    """Main evolutionary algorithm coordinator."""

    REP_DB_PATH = PROJECT_ROOT / "frag_db" / "diesel_fragments.db"

    def __init__(self, config: EvolutionConfig):
        self.config = config
        self.predictor = PropertyPredictor(config)
        self.population = Population(config)

    def _mutate_molecule(self, mol: Chem.Mol) -> List[str]:
        """Generate mutations for a molecule using CREM."""
        try:
            mutants = list(mutate_mol(
                mol,
                db_name=str(self.REP_DB_PATH),
                max_size=2,
                return_mol=False
            ))
            return [m for m in mutants if m and m not in self.population.seen_smiles]
        except Exception:
            return []

    def _create_molecules(self, smiles_list: List[str]) -> List[Molecule]:
        """Create Molecule objects from SMILES with predictions."""
        if not smiles_list:
            return []

        # Get all predictions at once
        predictions = self.predictor.predict_all_properties(smiles_list)

        molecules = []
        for i, smiles in enumerate(smiles_list):
            # Extract predictions for this molecule (disabled predictors simply have no key)
            props = {k: v[i] for k, v in predictions.items()}

            # Validate required properties
            if props['cn'] is None:
                continue
            if self.config.minimize_ysi and props['ysi'] is None:
                continue

            # Validate filtered properties
            if not all(self.predictor.is_valid(k, props.get(k)) for k in ['bp', 'density', 'lhv', 'dynamic_viscosity']):
                continue

            molecules.append(Molecule(
                smiles=smiles,
                cn=props['cn'],
                cn_error=abs(props['cn'] - self.config.target_cn),
                bp=props.get('bp'),
                ysi=props.get('ysi'),
                density=props.get('density'),
                lhv=props.get('lhv'),
                dynamic_viscosity=props.get('dynamic_viscosity')
            ))

        return molecules

    def initialize_population(self, initial_smiles: List[str]) -> int:
        """Initialize the population from initial SMILES."""
        print("Predicting properties for initial population...")
        molecules = self._create_molecules(initial_smiles)
        return self.population.add_molecules(molecules)

    def _log_generation_stats(self, generation: int):
        """Log statistics for the current generation."""
        mols = self.population.molecules
        best_cn = min(mols, key=lambda m: m.cn_error)
        avg_cn_err = np.mean([m.cn_error for m in mols])

        log_dict = {
            "generation": generation,
            "best_cn_error": best_cn.cn_error,
            "population_size": len(mols),
            "avg_cn_error": avg_cn_err,
        }

        print_msg = (f"Gen {generation}/{self.config.generations} | "
                     f"Pop {len(mols)} | "
                     f"Best CN err: {best_cn.cn_error:.3f} | "
                     f"Avg CN err: {avg_cn_err:.3f}")

        if self.config.minimize_ysi:
            front = self.population.pareto_front()
            best_ysi = min(mols, key=lambda m: m.ysi)
            avg_ysi = np.mean([m.ysi for m in mols])

            log_dict.update({
                "best_ysi": best_ysi.ysi,
                "pareto_size": len(front),
                "avg_ysi": avg_ysi,
            })

            print_msg += (f" | Best YSI: {best_ysi.ysi:.3f} | "
                          f"Avg YSI: {avg_ysi:.3f} | "
                          f"Pareto size: {len(front)}")

        print(print_msg)

    def _generate_offspring(self, survivors: List[Molecule]) -> List[Molecule]:
        """Generate offspring from survivors."""
        target_count = self.config.population_size - len(survivors)
        max_attempts = target_count * self.config.max_offspring_attempts

        all_children = []
        new_molecules = []

        for attempt in range(max_attempts):
            if len(new_molecules) >= target_count:
                break

            # Generate mutations
            parent = random.choice(survivors)
            mol = Chem.MolFromSmiles(parent.smiles)
            if mol is None:
                continue

            children = self._mutate_molecule(mol)
            all_children.extend(children[:self.config.mutations_per_parent])

            # Process in batches
            if len(all_children) >= self.config.batch_size:
                print(f" → Evaluating batch of {len(all_children)} offspring...")
                new_molecules.extend(self._create_molecules(all_children))
                all_children = []

        # Process remaining children
        if all_children:
            print(f" → Evaluating final batch of {len(all_children)} offspring...")
            new_molecules.extend(self._create_molecules(all_children))

        return new_molecules

    def _run_evolution_loop(self):
        """Run the main evolution loop."""
        for gen in range(1, self.config.generations + 1):
            self._log_generation_stats(gen)

            survivors = self.population.get_survivors()
            offspring = self._generate_offspring(survivors)

            # Create new population
            new_pop = Population(self.config)
            new_pop.add_molecules(survivors + offspring)
            self.population = new_pop

    def _generate_results(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Generate final results DataFrames."""
        final_df = self.population.to_dataframe()

        if self.config.minimize_ysi and "ysi" in final_df.columns:
            final_df = final_df[
                (final_df["cn_error"] < 5) &
                (final_df["ysi"] < 50)
            ].sort_values(["cn_error", "ysi"], ascending=True)

        # overwrite rank safely
        final_df["rank"] = range(1, len(final_df) + 1)

        if self.config.minimize_ysi:
            pareto_mols = self.population.pareto_front()
            pareto_df = pd.DataFrame([m.to_dict() for m in pareto_mols])

            if not pareto_df.empty:
                pareto_df = pareto_df[
                    (pareto_df['cn_error'] < 5) & (pareto_df['ysi'] < 50)
                ].sort_values(["cn_error", "ysi"], ascending=True)
                pareto_df.insert(0, 'rank', range(1, len(pareto_df) + 1))
        else:
            pareto_df = pd.DataFrame()

        return final_df, pareto_df

    def evolve(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Run the evolutionary algorithm."""
        # Initialize
        df_bins = pd.qcut(df["cn"], q=30)
        initial_smiles = (
            df.groupby(df_bins)
            .apply(lambda x: x.sample(20, random_state=42))
            .reset_index(drop=True)["SMILES"]
            .tolist()
        )
        init_count = self.initialize_population(initial_smiles)

        if init_count == 0:
            print("❌ No valid initial molecules")
            return pd.DataFrame(), pd.DataFrame()

        print(f"✓ Initial population size: {init_count}")

        # Evolution
        self._run_evolution_loop()

        # Results
        return self._generate_results()


def get_user_config() -> EvolutionConfig:
    """Get configuration from user input."""
    print("\n" + "="*70)
    print("MOLECULAR EVOLUTION WITH GENETIC ALGORITHM")
    print("="*70)

    while True:
        target = float(input("Enter target CN: ") or "50")
        if target > 40:
            break
        print("⚠️ Target CN is too low, optimization may be challenging.")
        print("Consider using a higher target CN for better results.\n")

    minimize_ysi = input("Minimise YSI (y/n): ").strip().lower() in ['y', 'yes']

    return EvolutionConfig(target_cn=target, minimize_ysi=minimize_ysi)


def save_results(final_df: pd.DataFrame, pareto_df: pd.DataFrame, minimize_ysi: bool):
    """Save results to CSV files."""
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    final_df.to_csv(results_dir / "final_population.csv", index=False)
    if minimize_ysi and not pareto_df.empty:
        pareto_df.to_csv(results_dir / "pareto_front.csv", index=False)

    print("\nSaved to results/")


def display_results(final_df: pd.DataFrame, pareto_df: pd.DataFrame, minimize_ysi: bool):
    """Display results to console."""
    cols = ["rank", "smiles", "cn", "cn_error", "ysi", "bp", "density", "lhv", "dynamic_viscosity"]

    print("\n=== Best Candidates ===")
    print(final_df.head(10)[cols].to_string(index=False))

    if minimize_ysi and not pareto_df.empty:
        print("\n=== PARETO FRONT (ranked) ===")
        print(pareto_df[["rank", "smiles", "cn", "cn_error", "ysi", "bp", "density", "lhv", "dynamic_viscosity"]]
              .head(20).to_string(index=False))


def main():
    """Main execution function."""
    config = get_user_config()

    evolution = MolecularEvolution(config)
    final_df, pareto_df = evolution.evolve()
    # Display and save results
    display_results(final_df, pareto_df, config.minimize_ysi)
    save_results(final_df, pareto_df, config.minimize_ysi)


if __name__ == "__main__":
    main()
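
For non-interactive use (for example, driving the generator from another script or a notebook instead of stdin prompts), the same pipeline can be run by constructing `EvolutionConfig` directly rather than going through `get_user_config()`. A minimal sketch, not part of the upload; it assumes the CReM fragment database and the Hugging Face model repos referenced above are reachable:

# Illustrative sketch: programmatic entry point that bypasses the interactive prompts.
from molecule_generator import EvolutionConfig, MolecularEvolution, save_results

config = EvolutionConfig(
    target_cn=55.0,      # aim for a cetane number around 55
    minimize_ysi=True,   # also minimise sooting tendency (YSI)
    generations=3,       # shorter run for a quick smoke test
    population_size=50,
)

evolution = MolecularEvolution(config)
final_df, pareto_df = evolution.evolve()
save_results(final_df, pareto_df, config.minimize_ysi)
print(final_df.head())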
shared_features.py
ADDED
import os
import sqlite3
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
DB_PATH = os.path.join(PROJECT_ROOT, "data", "database", "database_main.db")

def load_raw_data():
    """Load raw data from database."""
    print("Connecting to SQLite database...")
    conn = sqlite3.connect(DB_PATH)

    query = """
    SELECT
        F.Fuel_Name,
        F.SMILES,
        T.Standardised_DCN AS cn
    FROM FUEL F
    LEFT JOIN TARGET T ON F.fuel_id = T.fuel_id
    """
    df = pd.read_sql_query(query, conn)
    conn.close()

    # Clean data
    df.dropna(subset=["cn", "SMILES"], inplace=True)

    return df


# ============================================================================
# 2. FEATURIZATION MODULE
# ============================================================================
from rdkit import Chem
from rdkit.Chem import Descriptors, rdFingerprintGenerator
from tqdm import tqdm

# Get descriptor names globally
DESCRIPTOR_NAMES = [d[0] for d in Descriptors._descList]
desc_functions = [d[1] for d in Descriptors._descList]

def morgan_fp_from_mol(mol, radius=2, n_bits=2048):
    """Generate Morgan fingerprint."""
    fpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
    fp = fpgen.GetFingerprint(mol)
    arr = np.array(list(fp.ToBitString()), dtype=int)
    return arr

def physchem_desc_from_mol(mol):
    """Calculate physicochemical descriptors."""
    try:
        desc = np.array([fn(mol) for fn in desc_functions], dtype=np.float32)
        desc = np.nan_to_num(desc, nan=0.0, posinf=0.0, neginf=0.0)
        return desc
    except:
        return None

def featurize(smiles):
    """Convert SMILES to feature vector."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    fp = morgan_fp_from_mol(mol)
    desc = physchem_desc_from_mol(mol)

    if fp is None or desc is None:
        return None

    return np.hstack([fp, desc])

def featurize_df(df, smiles_col="SMILES", return_df=True):
    """
    Featurize a DataFrame or list of SMILES (vectorized for speed).
    """
    # Handle different input types
    if isinstance(df, (list, np.ndarray)):
        df = pd.DataFrame({smiles_col: df})
    elif isinstance(df, pd.Series):
        df = pd.DataFrame({smiles_col: df})

    # Convert all SMILES to molecules in batch
    mols = [Chem.MolFromSmiles(smi) for smi in df[smiles_col]]

    features = []
    valid_indices = []

    # Process valid molecules
    for i, mol in enumerate(tqdm(mols, desc="Featurizing")):
        if mol is None:
            continue

        try:
            fp = morgan_fp_from_mol(mol)
            desc = physchem_desc_from_mol(mol)

            if fp is not None and desc is not None:
                features.append(np.hstack([fp, desc]))
                valid_indices.append(i)
        except:
            continue

    if len(features) == 0:
        return (None, None) if return_df else None

    X = np.vstack(features)

    if return_df:
        df_valid = df.iloc[valid_indices].reset_index(drop=True)
        return X, df_valid
    else:
        return X


# ============================================================================
# 3. FEATURE SELECTOR CLASS
# ============================================================================
import joblib

class FeatureSelector:
    """Feature selection pipeline that can be saved and reused."""

    def __init__(self, n_morgan=2048, corr_threshold=0.95, top_k=300):
        self.n_morgan = n_morgan
        self.corr_threshold = corr_threshold
        self.top_k = top_k

        # Filled during fit()
        self.corr_cols_to_drop = None
        self.selected_indices = None
        self.is_fitted = False

    def fit(self, X, y):
        """Fit the feature selector on training data."""
        print("\n" + "="*70)
        print("FITTING FEATURE SELECTOR")
        print("="*70)

        # Step 1: Split Morgan and descriptors
        X_mfp = X[:, :self.n_morgan]
        X_desc = X[:, self.n_morgan:]

        print(f"Morgan fingerprints: {X_mfp.shape[1]}")
        print(f"Descriptors: {X_desc.shape[1]}")

        # Step 2: Remove correlated descriptors
        desc_df = pd.DataFrame(X_desc)
        corr_matrix = desc_df.corr().abs()
        upper = corr_matrix.where(
            np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
        )

        self.corr_cols_to_drop = [
            col for col in upper.columns if any(upper[col] > self.corr_threshold)
        ]

        print(f"Correlated descriptors removed: {len(self.corr_cols_to_drop)}")

        desc_filtered = desc_df.drop(columns=self.corr_cols_to_drop, axis=1).values
        X_corr = np.hstack([X_mfp, desc_filtered])

        print(f"Features after correlation filter: {X_corr.shape[1]}")

        # Step 3: Feature importance selection
        from sklearn.ensemble import ExtraTreesRegressor

        print("Running feature importance selection...")
        model = ExtraTreesRegressor(n_estimators=100, random_state=42, n_jobs=-1)
        model.fit(X_corr, y)

        importances = model.feature_importances_
        indices = np.argsort(importances)[::-1]

        self.selected_indices = indices[:self.top_k]

        print(f"Final selected features: {len(self.selected_indices)}")

        self.is_fitted = True
        return self

    def transform(self, X):
        """Apply the fitted feature selection to new data."""
        if not self.is_fitted:
            raise RuntimeError("FeatureSelector must be fitted before transform!")

        # Step 1: Split Morgan and descriptors
        X_mfp = X[:, :self.n_morgan]
        X_desc = X[:, self.n_morgan:]

        # Step 2: Remove same correlated descriptors
        desc_df = pd.DataFrame(X_desc)
        desc_filtered = desc_df.drop(columns=self.corr_cols_to_drop, axis=1).values
        X_corr = np.hstack([X_mfp, desc_filtered])

        # Step 3: Select same important features
        X_selected = X_corr[:, self.selected_indices]

        return X_selected

    def fit_transform(self, X, y):
        """Fit and transform in one step."""
        return self.fit(X, y).transform(X)

    def save(self, filepath='feature_selector.joblib'):
        """Save the fitted selector."""
        if not self.is_fitted:
            raise RuntimeError("Cannot save unfitted selector!")

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(filepath) if os.path.dirname(filepath) else '.', exist_ok=True)

        joblib.dump(self, filepath)
        print(f"✓ Feature selector saved to {filepath}")

    @staticmethod
    def load(filepath='feature_selector.joblib'):
        """Load a fitted selector."""
        selector = joblib.load(filepath)
        if not selector.is_fitted:
            raise RuntimeError("Loaded selector is not fitted!")
        print(f"✓ Feature selector loaded from {filepath}")
        return selector
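
`GenericPredictor` in molecule_generator.py expects each downloaded model repo to contain a `model.joblib` regressor plus a `selector.joblib` saved by this `FeatureSelector`. A minimal training-side sketch of how such a pair could be produced (illustrative only: the regressor choice and the `artifacts/` output directory are assumptions, not defined anywhere in this upload):

# Illustrative training sketch: featurize, select features, fit a regressor,
# then save the model/selector pair that GenericPredictor later loads.
import joblib
from sklearn.ensemble import ExtraTreesRegressor

from shared_features import FeatureSelector, featurize_df, load_raw_data

raw = load_raw_data()
X_full, df_valid = featurize_df(raw, return_df=True)  # Morgan bits + RDKit descriptors
y = df_valid["cn"].values

selector = FeatureSelector(top_k=300)
X_sel = selector.fit_transform(X_full, y)

model = ExtraTreesRegressor(n_estimators=300, random_state=42, n_jobs=-1)
model.fit(X_sel, y)

selector.save("artifacts/selector.joblib")   # creates artifacts/ if needed
joblib.dump(model, "artifacts/model.joblib")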