# NOTE: page-header residue captured together with this file (not program code):
#   carrotcake3's picture
#   Update app.py
#   93c3d7e verified
from flask import Flask, render_template, request, redirect, url_for, send_file, session
import sqlite3
import pandas as pd
import os
import joblib
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw
from sklearn.base import BaseEstimator, RegressorMixin
from huggingface_hub import hf_hub_download
import sys
import pubchempy as pcp
import importlib.util
from molecule_generator import run_generation
# ------------------------------
# MODEL REPOSITORY CONFIGURATION
# ------------------------------
# Hugging Face Hub repository IDs handed to hf_hub_download() below.
# The CN repository is also where shared_features.py is fetched from.
REPO_ID_CN = "SalZa2004/Cetane_Number_Predictor"
REPO_ID_YSI = "SalZa2004/YSI_Predictor"
# -------------------------
# LOAD SHARED FEATURES
# -------------------------
def load_shared_features():
    """Download ``shared_features.py`` from the CN repo and import it.

    The file is fetched with ``hf_hub_download``, executed as a module named
    ``shared_features`` and registered in ``sys.modules``. Its
    ``FeatureSelector`` class is also published as a global of this module.

    Returns:
        The imported module object.
    """
    # NOTE(review): this executes Python code downloaded from the Hub; it is
    # only as trustworthy as the REPO_ID_CN repository.
    source_path = hf_hub_download(REPO_ID_CN, "shared_features.py")
    module_spec = importlib.util.spec_from_file_location(
        "shared_features", source_path
    )
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    # Register the module for joblib under each of these names (presumably
    # the names the pickled objects were saved under; confirm against the
    # training scripts). Note this replaces the real ``__main__`` entry.
    for alias in ("shared_features", "main", "__main__"):
        sys.modules[alias] = module

    globals()["FeatureSelector"] = module.FeatureSelector
    return module


shared = load_shared_features()
# -------------------------
# LOAD CN MODEL
# -------------------------
def load_model_cn(shared):
    """Fetch and load the CN model and its feature selector.

    Args:
        shared: The shared-features module. It is not used in the body;
            presumably it is passed so the module is loaded before
            unpickling. TODO confirm.

    Returns:
        A ``(model, selector)`` tuple loaded from ``model.joblib`` and
        ``selector.joblib`` in the ``REPO_ID_CN`` repository.
    """
    import joblib

    # Download both artefacts first, then deserialize them in the same order.
    artefact_paths = [
        hf_hub_download(REPO_ID_CN, filename)
        for filename in ("model.joblib", "selector.joblib")
    ]
    # NOTE(review): joblib.load unpickles the files, so the repository
    # contents must be trusted.
    model, selector = (joblib.load(path) for path in artefact_paths)
    return model, selector


cn_model, cn_selector = load_model_cn(shared)
# -------------------------
# LOAD YSI MODEL
# -------------------------
def load_model_ysi(shared):
    """Fetch and load the YSI model and its feature selector.

    Args:
        shared: The shared-features module. It is not used in the body;
            presumably it is passed so the module is loaded before
            unpickling. TODO confirm.

    Returns:
        A ``(model, selector)`` tuple loaded from ``model.joblib`` and
        ``selector.joblib`` in the ``REPO_ID_YSI`` repository.
    """
    import joblib

    # Download both artefacts first, then deserialize them in the same order.
    artefact_paths = [
        hf_hub_download(REPO_ID_YSI, filename)
        for filename in ("model.joblib", "selector.joblib")
    ]
    # NOTE(review): joblib.load unpickles the files, so the repository
    # contents must be trusted.
    model, selector = (joblib.load(path) for path in artefact_paths)
    return model, selector


ysi_model, ysi_selector = load_model_ysi(shared)
# -------------------------
# HELPER FUNCTIONS
# -------------------------
def validate_smiles(smiles):
    """Return True when *smiles* is a non-blank string that RDKit can parse.

    Missing values (``None``/NaN), any other non-string input and empty or
    whitespace-only strings are rejected up front, without calling RDKit.

    Args:
        smiles: Candidate SMILES value, e.g. a form field or a CSV cell.

    Returns:
        ``True`` if ``Chem.MolFromSmiles`` returns a molecule, else ``False``.
    """
    # NaN is a float and None is not a str, so the isinstance check also
    # covers missing values. Blank strings are rejected explicitly rather
    # than relying on the parser to turn them down.
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    return Chem.MolFromSmiles(smiles) is not None
def predict_cn(smiles):
    """Predict a value for one SMILES string with the CN model.

    Returns:
        The first prediction as a ``float``, or ``None`` when
        ``shared.featurize_df`` returns ``None`` for the input.
    """
    features = shared.featurize_df([smiles], return_df=False)
    if features is None:
        return None
    # Apply the CN-specific feature selection before calling the model.
    selected = cn_selector.transform(features)
    prediction = cn_model.predict(selected)
    return float(prediction[0])
def predict_ysi(smiles):
    """Predict a value for one SMILES string with the YSI model.

    Returns:
        The first prediction as a ``float``, or ``None`` when
        ``shared.featurize_df`` returns ``None`` for the input.
    """
    features = shared.featurize_df([smiles], return_df=False)
    if features is None:
        return None
    # Apply the YSI-specific feature selection before calling the model.
    selected = ysi_selector.transform(features)
    prediction = ysi_model.predict(selected)
    return float(prediction[0])
def pubchem_name_to_smiles(name):
    """Look up a compound name on PubChem and return its canonical SMILES.

    Returns:
        ``canonical_smiles`` of the first compound found, or ``None`` when
        *name* is not a non-blank string, when nothing is found, or when
        the lookup raises.
    """
    # Normalise once: anything that is not a string counts as blank.
    query = name.strip() if isinstance(name, str) else ""
    if not query:
        return None

    try:
        hits = pcp.get_compounds(query, "name")
        return hits[0].canonical_smiles if hits else None
    except Exception:
        # Best-effort lookup: any failure is reported as "not found".
        return None
def pubchem_smiles_to_name(smiles):
    """Look up a SMILES string on PubChem and return a display name.

    Returns:
        The first compound's ``iupac_name`` when it is set, otherwise its
        ``title``; ``None`` when nothing is found or the lookup raises.
    """
    try:
        hits = pcp.get_compounds(smiles, "smiles")
        if not hits:
            return None
        best = hits[0]
        # Prefer the IUPAC name and fall back to the title.
        # NOTE(review): confirm ``title`` exists on the pubchempy Compound;
        # if it does not, the AttributeError is swallowed below and the
        # result is None.
        return getattr(best, "iupac_name", None) or best.title
    except Exception:
        # Best-effort lookup: any failure is reported as "no name".
        return None
# Create the Flask application object; the route decorators below register
# their view functions on it. The server itself is started at the bottom of
# the file.
app = Flask(__name__)
@app.route("/")
def dashboard():
    """Render the landing page template."""
    template_name = "dashboard.html"
    return render_template(template_name)
def _clean_cell(value):
    """Return a CSV cell as a stripped string; missing cells become ""."""
    return "" if pd.isna(value) else str(value).strip()


def _safe_predict(predictor, smiles):
    """Call ``predictor(smiles)``; log and return None if it raises."""
    try:
        return predictor(smiles)
    except Exception:
        # One bad molecule must not abort the whole batch; the caller
        # reports it as "Prediction failed" for this row.
        app.logger.exception(
            "%s raised for SMILES %r", predictor.__name__, smiles
        )
        return None


def _evaluate_fuel(name, smiles, img_filename):
    """Run the resolve, validate, predict and draw pipeline for one fuel.

    Args:
        name: Stripped fuel name; "" when none was given.
        smiles: Stripped SMILES string; "" requests a PubChem lookup by
            ``name``.
        img_filename: File name under ``static/generated`` that receives
            the molecule drawing when prediction succeeds.

    Returns:
        A dict with keys ``name``, ``smiles``, ``dcn``, ``ysi``, ``error``
        and ``img_id``. ``error`` stays ``None`` on success.
    """
    has_name = bool(name) and name != "-"
    entry = {
        "name": name if name else "-",
        "smiles": smiles,
        "dcn": None,
        "ysi": None,
        "error": None,
        "img_id": None,
    }

    # STEP 1: if the SMILES is empty, convert NAME to SMILES via PubChem.
    if smiles == "" and has_name:
        final_smiles = pubchem_name_to_smiles(name)
        if final_smiles is None:
            entry["error"] = "Name not found in PubChem"
            return entry
    else:
        final_smiles = smiles

    # STEP 2: validate the SMILES.
    if not validate_smiles(final_smiles):
        entry["error"] = "Invalid SMILES"
        return entry
    entry["smiles"] = final_smiles

    # STEP 3: convert SMILES to a name. The result is only used when the
    # user gave no name, so the network lookup is skipped otherwise.
    if not has_name:
        iupac_name = pubchem_smiles_to_name(final_smiles)
        if iupac_name:
            entry["name"] = iupac_name

    # STEP 4: predict DCN and YSI independently, so a failure in one
    # predictor does not discard the other's result.
    pred_cn = _safe_predict(predict_cn, final_smiles)
    pred_ysi = _safe_predict(predict_ysi, final_smiles)
    if pred_cn is None and pred_ysi is None:
        entry["error"] = "Prediction failed"
        return entry
    entry["dcn"] = round(pred_cn, 2) if pred_cn is not None else None
    entry["ysi"] = round(pred_ysi, 2) if pred_ysi is not None else None

    # STEP 5: draw the molecule for successfully predicted rows.
    # NOTE(review): file names are index-based ("mol_0.png", ...), so
    # concurrent requests overwrite each other's images.
    mol = Chem.MolFromSmiles(final_smiles)
    img = Draw.MolToImage(mol, size=(300, 250))
    img_path = os.path.join("static", "generated", img_filename)
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    img.save(img_path)
    entry["img_id"] = img_filename
    return entry


@app.route("/pure", methods=["GET", "POST"])
def pure_predictor():
    """Render the pure-fuel predictor page and handle its two input modes.

    GET renders an empty page. POST with form field ``mode == "csv"``
    reads an uploaded ``csv_file`` that must contain a ``SMILES`` column
    (``IUPAC names`` is optional). Any other POST reads the parallel form
    lists ``fuel_name[]`` and ``smiles[]``. Each row goes through
    ``_evaluate_fuel`` and the resulting dicts are passed to
    ``pure_predictor.html`` as ``results``.
    """
    results = []
    error = None

    if request.method != "POST":
        return render_template(
            "pure_predictor.html", results=results, error=error
        )

    # ------------------
    # CSV FILE INPUT
    # ------------------
    if request.form.get("mode") == "csv":
        csv_file = request.files.get("csv_file")
        if not csv_file:
            error = "No CSV file uploaded."
            return render_template(
                "pure_predictor.html", results=results, error=error
            )
        try:
            df = pd.read_csv(csv_file)
            if "SMILES" not in df.columns:
                error = "CSV must contain a 'SMILES' column."
                return render_template(
                    "pure_predictor.html", results=results, error=error
                )
            for i, row in df.iterrows():
                results.append(
                    _evaluate_fuel(
                        _clean_cell(row.get("IUPAC names", "")),
                        _clean_cell(row.get("SMILES", "")),
                        f"mol_csv_{i}.png",
                    )
                )
            return render_template("pure_predictor.html", results=results)
        except Exception as e:
            # Rows processed before the failure are still shown.
            error = f"Failed to read CSV file: {e}"
            return render_template(
                "pure_predictor.html", results=results, error=error
            )

    # ------------------
    # MANUAL INPUT
    # ------------------
    names = request.form.getlist("fuel_name[]")
    smiles_list = request.form.getlist("smiles[]")
    for i, (name, smiles) in enumerate(zip(names, smiles_list)):
        results.append(
            _evaluate_fuel(name.strip(), smiles.strip(), f"mol_{i}.png")
        )
    return render_template("pure_predictor.html", results=results, error=error)
from flask import send_file
import io
import json
@app.route("/download_results", methods=["POST"])
def download_results():
    """Return the posted prediction results as a CSV attachment.

    Reads the form field ``results_data``, which must hold a JSON list of
    result objects (keys ``name``, ``smiles``, ``dcn``, ``ysi``,
    ``error``), and sends it back as ``pure_fuel_predictions.csv``.

    Returns:
        The CSV file response, or a plain-text message with HTTP 400 when
        the field is missing, is not valid JSON, or is not a list of
        objects.
    """
    results_json = request.form.get("results_data")
    if not results_json:
        # json.loads(None) would raise TypeError and surface as a 500.
        return "Missing 'results_data' form field.", 400
    try:
        results = json.loads(results_json)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return "'results_data' is not valid JSON.", 400
    if not isinstance(results, list) or not all(
        isinstance(r, dict) for r in results
    ):
        return "'results_data' must be a JSON list of objects.", 400

    columns = ["IUPAC Name", "SMILES", "Predicted DCN", "Predicted YSI", "Status"]
    cleaned_rows = [
        {
            "IUPAC Name": r.get("name", "-"),
            "SMILES": r.get("smiles", "-"),
            "Predicted DCN": r.get("dcn", None),
            "Predicted YSI": r.get("ysi", None),
            "Status": (
                "OK" if r.get("error") in (None, "", "OK") else r.get("error")
            ),
        }
        for r in results
    ]
    # Passing ``columns`` fixes the column order and keeps the header row
    # when ``results`` is empty; selecting columns from an empty frame
    # afterwards would raise KeyError.
    df = pd.DataFrame(cleaned_rows, columns=columns)

    # With no target given, to_csv returns the CSV text directly.
    csv_bytes = df.to_csv(index=False).encode()
    return send_file(
        io.BytesIO(csv_bytes),
        mimetype="text/csv",
        as_attachment=True,
        download_name="pure_fuel_predictions.csv",
    )
@app.route("/mixture")
def mixture_predictor():
    """Render the mixture predictor page template."""
    template_name = "mixture_predictor.html"
    return render_template(template_name)
@app.route("/generate", methods=["GET", "POST"])
def generative():
    """Render the generative page; on POST, run generation first.

    A POST reads ``target_cn`` (parsed as float) and the ``minimize_ysi``
    checkbox from the form, calls ``run_generation`` and converts the two
    returned frames to HTML tables. Any exception along the way is shown
    on the page as ``error`` instead of propagating.
    """
    context = {"final_table": None, "pareto_table": None, "error": None}

    if request.method == "POST":
        table_classes = "table table-striped"
        try:
            target_cn = float(request.form.get("target_cn"))
            # An HTML checkbox posts "on" when ticked and is absent otherwise.
            minimize_ysi = request.form.get("minimize_ysi") == "on"
            final_df, pareto_df = run_generation(
                target_cn=target_cn,
                minimize_ysi=minimize_ysi,
            )
            context["final_table"] = final_df.to_html(
                index=False, classes=table_classes
            )
            context["pareto_table"] = pareto_df.to_html(
                index=False, classes=table_classes
            )
        except Exception as exc:
            context["error"] = str(exc)

    return render_template("generative.html", **context)
@app.route("/constraints")
def constraints():
    """Render the constraints page template."""
    template_name = "constraints.html"
    return render_template(template_name)
@app.route("/dataset")
def dataset():
    """Render the dataset page template."""
    template_name = "dataset.html"
    return render_template(template_name)
@app.route("/download/pure")
def download_pure():
    """Send the pure-fuel dataset workbook as ``pure_fuel_dataset.xlsx``."""
    workbook_path = "datasets/pure_fuel_properties_compiled_v2.xlsx"
    xlsx_mimetype = (
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    )
    return send_file(
        workbook_path,
        as_attachment=True,
        download_name="pure_fuel_dataset.xlsx",
        mimetype=xlsx_mimetype,
    )
@app.route("/download/mixture")
def download_mixture():
    """Send the mixture database workbook as ``mixture_fuel_dataset.xlsx``."""
    workbook_path = "datasets/mixture_database.xlsx"
    xlsx_mimetype = (
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    )
    return send_file(
        workbook_path,
        as_attachment=True,
        download_name="mixture_fuel_dataset.xlsx",
        mimetype=xlsx_mimetype,
    )
@app.route("/about")
def about():
    """Render the about page template."""
    template_name = "about.html"
    return render_template(template_name)
if __name__ == "__main__":
    # When run as a script, serve on all interfaces at port 7860.
    app.run(host="0.0.0.0", port=7860)