Upload folder using huggingface_hub
- .gitignore +3 -0
- README.md +5 -3
- data/checkpoints/diffusion_list_20250501.json +0 -0
- data/checkpoints/diffusion_list_20250601.json +0 -0
- data/olfaction-vision-language-dataset.json +0 -0
- model/diffusion_gnn/diffusion_gnn.py +600 -0
- model/diffusion_gnn/inference.py +65 -0
- model/ovle-small/gat/gat_gnn.pth +3 -0
- model/ovle-small/gat/gat_gnn_state_dict.pth +3 -0
- model/ovle-small/gat/gat_olf_encoder.pth +3 -0
- model/ovle-small/gat/gat_olf_encoder_state_dict.pth +3 -0
- model/ovle-small/nn/gnn.pth +3 -0
- model/ovle-small/nn/gnn_state_dict.pth +3 -0
- model/ovle-small/nn/olf_encoder.pth +3 -0
- model/ovle-small/nn/olf_encoder_state_dict.pth +3 -0
- model/train_ovm.py +0 -0
- model_cards/ovle-large.md +82 -0
- model_cards/ovle-small.md +82 -0
- requirements.txt +4 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
*__pycache__
*.idea
*.DS_Store
README.md
CHANGED
@@ -1,3 +1,5 @@
-
-
-
+# Olfaction-Vision-Language Embeddings
+This repository contains a series of multimodal machine learning models trained on olfaction, vision, and language data.
+
+
+
data/checkpoints/diffusion_list_20250501.json
ADDED
The diff for this file is too large to render; see the raw file.

data/checkpoints/diffusion_list_20250601.json
ADDED
The diff for this file is too large to render; see the raw file.

data/olfaction-vision-language-dataset.json
ADDED
The diff for this file is too large to render; see the raw file.
model/diffusion_gnn/diffusion_gnn.py
ADDED
@@ -0,0 +1,600 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
import math

# -------- UTILS: Molecule Processing with 3D Coordinates --------
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    try:
        AllChem.EmbedMolecule(mol, AllChem.ETKDG())
        AllChem.UFFOptimizeMolecule(mol)
    except Exception:
        # Conformer embedding or UFF optimization can fail for unusual molecules
        return None

    conf = mol.GetConformer()
    atoms = mol.GetAtoms()
    bonds = mol.GetBonds()

    node_feats = []
    pos = []
    edge_index = []
    edge_attrs = []

    for atom in atoms:
        # Normalize atomic number
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])

    for bond in bonds:
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        edge_index.append([start, end])
        edge_index.append([end, start])
        bond_type = bond.GetBondType()
        bond_class = {
            Chem.BondType.SINGLE: 0,
            Chem.BondType.DOUBLE: 1,
            Chem.BondType.TRIPLE: 2,
            Chem.BondType.AROMATIC: 3
        }.get(bond_type, 0)
        edge_attrs.extend([[bond_class], [bond_class]])

    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long)
    )
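# Example usage (illustrative): smiles_to_graph("CCO") returns a Data object for
# ethanol with one node per atom (hydrogens included via AddHs), ETKDG 3D
# coordinates in `pos`, and integer bond-class labels in `edge_attr`.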
# -------- EGNN Layer --------
class EGNNLayer(MessagePassing):
    def __init__(self, node_dim):
        super().__init__(aggr='add')
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim * 2 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, node_dim)
        )
        self.coord_mlp = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, pos, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        self.coord_updates = torch.zeros_like(pos)  # unused
        x_out, coord_out = self.propagate(edge_index, x=x, pos=pos)
        return x_out, pos + coord_out

    def message(self, x_i, x_j, pos_i, pos_j):
        edge_vec = pos_j - pos_i
        dist = ((edge_vec**2).sum(dim=-1, keepdim=True) + 1e-8).sqrt()
        h = torch.cat([x_i, x_j, dist], dim=-1)
        edge_msg = self.node_mlp(h)
        coord_update = self.coord_mlp(dist) * edge_vec
        return edge_msg, coord_update

    def message_and_aggregate(self, adj_t, x):
        raise NotImplementedError("This EGNN layer does not support sparse adjacency matrices.")

    def aggregate(self, inputs, index):
        edge_msg, coord_update = inputs
        aggr_msg = torch.zeros(index.max() + 1, edge_msg.size(-1), device=edge_msg.device).index_add_(0, index, edge_msg)
        aggr_coord = torch.zeros(index.max() + 1, coord_update.size(-1), device=coord_update.device).index_add_(0, index, coord_update)
        return aggr_msg, aggr_coord

    def update(self, aggr_out, x):
        msg, coord_update = aggr_out
        return x + msg, coord_update
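# Note: the layer is E(n)-equivariant by construction. Messages depend on
# inter-atomic distances only, and coordinate updates are scalar multiples of
# the relative position vectors, so rotating or translating `pos` transforms
# the output coordinates identically while leaving node features unchanged.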
# -------- Time Embedding --------
class TimeEmbedding(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, embed_dim)
        )

    def forward(self, t):
        return self.net(t.view(-1, 1).float() / 1000)

# -------- Olfactory Conditioning --------
class OlfactoryConditioner(nn.Module):
    def __init__(self, num_labels, embed_dim):
        super().__init__()
        self.embedding = nn.Linear(num_labels, embed_dim)

    def forward(self, labels):
        return self.embedding(labels.float())

# -------- EGNN Diffusion Model --------
class EGNNDiffusionModel(nn.Module):
    def __init__(self, node_dim, embed_dim):
        super().__init__()
        self.time_embed = TimeEmbedding(embed_dim)
        self.egnn1 = EGNNLayer(node_dim + embed_dim * 2)
        self.egnn2 = EGNNLayer(node_dim + embed_dim * 2)
        self.bond_predictor = nn.Sequential(
            nn.Linear((node_dim + embed_dim * 2) * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def forward(self, x_t, pos, edge_index, t, cond_embed):
        batch_size = x_t.size(0)
        t_embed = self.time_embed(t).expand(batch_size, -1)
        cond_embed = cond_embed.expand(batch_size, -1)
        x_input = torch.cat([x_t, cond_embed, t_embed], dim=1)
        x1, pos1 = self.egnn1(x_input, pos, edge_index)
        x2, pos2 = self.egnn2(x1, pos1, edge_index)
        edge_feats = torch.cat([x2[edge_index[0]], x2[edge_index[1]]], dim=1)
        bond_logits = self.bond_predictor(edge_feats)
        return x2[:, :x_t.shape[1]], bond_logits
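# Shape sketch for EGNNDiffusionModel(node_dim=1, embed_dim=8): each EGNN layer
# operates on features of width node_dim + 2 * embed_dim = 17 ([x_t | cond | t]);
# the bond head consumes concatenated endpoint features (width 34) and emits 4
# logits per directed edge (single / double / triple / aromatic).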
# -------- Noise and Training --------
def add_noise(x_0, noise, t):
    return x_0 + noise * (t / 1000.0)


def plot_data(mu, sigma, color, title):
    all_losses = np.array(mu)
    sigma_losses = np.array(sigma)
    x = np.arange(len(mu))
    plt.plot(x, all_losses, f'{color}-')
    plt.fill_between(x, all_losses - sigma_losses, all_losses + sigma_losses, color=color, alpha=0.2)
    plt.legend(['Mean Loss', 'Variance of Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()
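# add_noise() is a simple linear schedule: features are perturbed in proportion
# to t / 1000 rather than via the variance-preserving q(x_t | x_0) of DDPM, and
# the samplers below correspondingly integrate the predicted noise back out in
# uniform steps of 1 / steps.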
def train(model, conditioner, dataset, epochs=10):
    model.train()
    conditioner.train()
    optimizer = torch.optim.Adam(list(model.parameters()) + list(conditioner.parameters()), lr=1e-4)
    ce_loss = nn.CrossEntropyLoss()
    torch.autograd.set_detect_anomaly(True)
    all_bond_losses: list = []
    all_noise_losses: list = []
    all_losses: list = []
    all_sigma_bond_losses: list = []
    all_sigma_noise_losses: list = []
    all_sigma_losses: list = []

    for epoch in range(epochs):
        total_bond_loss = 0
        total_noise_loss = 0
        total_loss = 0
        sigma_bond_losses: list = []
        sigma_noise_losses: list = []
        sigma_losses: list = []

        for data in dataset:
            x_0, pos, edge_index, edge_attr, labels = data.x, data.pos, data.edge_index, data.edge_attr.view(-1), data.y
            if torch.any(edge_attr >= 4) or torch.any(edge_attr < 0) or torch.any(torch.isnan(x_0)):
                continue  # skip corrupted data
            t = torch.tensor([random.randint(1, 1000)])
            noise = torch.randn_like(x_0)
            # noise = torch.rand_like(x_0)  # uniform-noise alternative
            x_t = add_noise(x_0, noise, t)
            # x_t.relu_()
            cond_embed = conditioner(labels)
            pred_noise, bond_logits = model(x_t, pos, edge_index, t, cond_embed)
            # Optional calibration step; an optimization, not a necessity:
            # bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1 - (1/(epoch+1))))
            loss_noise = F.mse_loss(pred_noise, noise)
            loss_bond = ce_loss(bond_logits, edge_attr)
            loss = loss_noise + loss_bond
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_bond_loss += loss_bond.item()
            total_noise_loss += loss_noise.item()
            total_loss += loss.item()
            sigma_bond_losses.append(loss_bond.item())
            sigma_noise_losses.append(loss_noise.item())
            sigma_losses.append(loss.item())

        all_bond_losses.append(total_bond_loss)
        all_noise_losses.append(total_noise_loss)
        all_losses.append(total_loss)
        all_sigma_bond_losses.append(torch.std(torch.tensor(sigma_bond_losses)))
        all_sigma_noise_losses.append(torch.std(torch.tensor(sigma_noise_losses)))
        all_sigma_losses.append(torch.std(torch.tensor(sigma_losses)))
        print(f"Epoch {epoch}: Loss = {total_loss:.4f}, Noise Loss = {total_noise_loss:.4f}, Bond Loss = {total_bond_loss:.4f}")

    plot_data(mu=all_bond_losses, sigma=all_sigma_bond_losses, color='b', title="Bond Loss")
    plot_data(mu=all_noise_losses, sigma=all_sigma_noise_losses, color='r', title="Noise Loss")
    plot_data(mu=all_losses, sigma=all_sigma_losses, color='g', title="Total Loss")

    plt.plot(all_bond_losses)
    plt.plot(all_noise_losses)
    plt.plot(all_losses)
    plt.legend(['Bond Loss', 'Noise Loss', 'Total Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.show()
    return model, conditioner
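# Typical usage (mirrors the __main__ block below): build `dataset` from
# smiles_to_graph(), then model, conditioner = train(model, conditioner, dataset).
# As written, all tensors stay on CPU; no .to(device) calls are made.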
# -------- Generation --------
def temperature_scaled_softmax(logits, temperature=1.0):
    logits = logits / temperature
    return torch.softmax(logits, dim=-1)  # normalize over the bond classes of each edge


from rdkit.Chem import Draw
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  # Suppress RDKit warnings
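# Three sampler variants follow: sample_batch() generates several candidate
# molecules, sample() generates a single molecule with temperature-scaled bond
# softmax and optional debug output, and sample_original() / sample_a() are
# earlier versions kept for reference.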
def sample_batch(model, conditioner, label_vec, steps=1000, batch_size=4):
    mols = []
    for _ in range(batch_size):
        x_t = torch.randn((10, 1))
        pos = torch.randn((10, 3))
        edge_index = torch.randint(0, 10, (2, 20))

        for t in reversed(range(1, steps + 1)):
            cond_embed = conditioner(label_vec.unsqueeze(0))
            pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
            x_t = x_t - pred_x * (1.0 / steps)

        x_t = x_t * 100.0
        x_t.relu_()
        atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()
        allowed_atoms = [6, 7, 8, 9, 15, 16, 17]  # C, N, O, F, P, S, Cl
        bond_logits.relu_()
        bond_preds = torch.argmax(bond_logits, dim=-1).tolist()  # per-edge bond class, used below

        mol = Chem.RWMol()
        idx_map = {}
        for i, atomic_num in enumerate(atom_types):
            if atomic_num not in allowed_atoms:
                continue
            try:
                atom = Chem.Atom(int(atomic_num))
                idx_map[i] = mol.AddAtom(atom)
            except Exception:
                continue

        if len(idx_map) < 2:
            continue

        bond_type_map = {
            0: Chem.BondType.SINGLE,
            1: Chem.BondType.DOUBLE,
            2: Chem.BondType.TRIPLE,
            3: Chem.BondType.AROMATIC
        }

        added = set()
        for i in range(edge_index.shape[1]):
            a = int(edge_index[0, i])
            b = int(edge_index[1, i])
            if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:
                try:
                    bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                    mol.AddBond(idx_map[a], idx_map[b], bond_type)
                    added.add((a, b))
                except Exception:
                    continue

        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            mols.append(mol)
        except Exception:
            continue

    # Optionally render the batch:
    # if mols:
    #     img = Draw.MolsToGridImage(mols, molsPerRow=3, subImgSize=(200, 200), legends=[Chem.MolToSmiles(m) for m in mols])
    #     img.save("generated_image.png")
    #     img.show()
    # else:
    #     print("No valid molecules were generated.")
    return mols
def sample(model, conditioner, label_vec, steps=1000, debug=True):
    x_t = torch.randn((10, 1))
    pos = torch.randn((10, 3))
    edge_index = torch.randint(0, 10, (2, 20))

    for t in reversed(range(1, steps + 1)):
        cond_embed = conditioner(label_vec.unsqueeze(0))
        pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
        bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1/t))
        x_t = x_t - pred_x * (1.0 / steps)

    x_t = x_t * 100.0
    x_t.relu_()
    atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()
    # Limit generation to the atoms that the Scentience sensors can detect
    allowed_atoms = [6, 7, 8, 9, 15, 16, 17]  # C, N, O, F, P, S, Cl
    bond_logits.relu_()
    bond_preds = torch.argmax(bond_logits, dim=-1).tolist()
    # bond_preds = torch.tensor(bond_preds)
    # bond_preds = bond_preds * 100.0
    # bond_preds.relu_()
    # bond_preds.abs_()
    # bond_preds = bond_preds.round().int().tolist()
    if debug:
        print(f"\tcond_embed: {cond_embed}")
        print(f"\tx_t: {x_t}")
        print(f"\tbond logits: {bond_logits}")
        print(f"\tatoms: {atom_types}")
        print(f"\tbonds: {bond_preds}")

    mol = Chem.RWMol()
    idx_map = {}
    for i, atomic_num in enumerate(atom_types):
        if atomic_num not in allowed_atoms:
            continue
        try:
            atom = Chem.Atom(int(atomic_num))
            idx_map[i] = mol.AddAtom(atom)
        except Exception:
            continue

    if len(idx_map) < 2:
        print("Molecule too small or no valid atoms after filtering.")
        return ""

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC
    }

    added = set()
    for i in range(edge_index.shape[1]):
        a = int(edge_index[0, i])
        b = int(edge_index[1, i])
        if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:
            try:
                bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                mol.AddBond(idx_map[a], idx_map[b], bond_type)
                added.add((a, b))
            except Exception:
                continue
    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        img = Draw.MolToImage(mol)
        img.show()
        print(f"Atom types: {atom_types}")
        print(f"Generated SMILES: {smiles}")
        return smiles
    except Exception as e:
        print(f"Sanitization error: {e}")
        return ""
"""
Same as sample_a but with fixes:
 * Hydrogen atoms (and other unstable elements) are excluded from generation.
 * Molecules with fewer than two valid atoms are skipped early.
 * RDKit warnings (like valence errors) are suppressed from terminal output.
Note:
 * Filtering for atoms in [6, 7, 8, 9, 15, 16, 17] (C, N, O, F, P, S, Cl) is
   reasonable, but some molecule graphs may become disconnected or too small
   (<2 atoms) after filtering, and all such graphs are skipped instead of
   being repaired or relaxed.
"""
def sample_original(model, conditioner, label_vec, steps=1000):
    x_t = torch.randn((10, 1))
    pos = torch.randn((10, 3))
    edge_index = torch.randint(0, 10, (2, 20))

    for t in reversed(range(1, steps + 1)):
        cond_embed = conditioner(label_vec.unsqueeze(0))
        pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
        # bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1/t))
        x_t = x_t - pred_x * (1.0 / steps)

    atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()
    allowed_atoms = [6, 7, 8, 9, 15, 16, 17]  # C, N, O, F, P, S, Cl
    bond_preds = torch.argmax(bond_logits, dim=-1).tolist()

    mol = Chem.RWMol()
    idx_map = {}
    for i, atomic_num in enumerate(atom_types):
        if atomic_num not in allowed_atoms:
            continue
        try:
            atom = Chem.Atom(int(atomic_num))
            idx_map[i] = mol.AddAtom(atom)
        except Exception:
            continue

    if len(idx_map) < 2:
        print("Molecule too small or no valid atoms after filtering.")
        return ""

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC
    }

    added = set()
    for i in range(edge_index.shape[1]):
        a = int(edge_index[0, i])
        b = int(edge_index[1, i])
        if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:
            try:
                bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                mol.AddBond(idx_map[a], idx_map[b], bond_type)
                added.add((a, b))
            except Exception:
                continue
    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        img = Draw.MolToImage(mol)
        img.show()
        print(f"Atom types: {atom_types}")
        print(f"Generated SMILES: {smiles}")
        return smiles
    except Exception as e:
        print(f"Sanitization error: {e}")
        return ""
def sample_a(model, conditioner, label_vec, steps=500):
    from rdkit.Chem import Draw
    x_t = torch.randn((10, 1))
    pos = torch.randn((10, 3))
    edge_index = torch.randint(0, 10, (2, 20))
    model.eval()

    for t in reversed(range(1, steps + 1)):
        cond_embed = conditioner(label_vec.unsqueeze(0))
        pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
        x_t = x_t - pred_x * (1.0 / steps)

    atom_types = torch.clamp(x_t.round() * 100, 1, 118).int().squeeze().tolist()
    bond_preds = torch.argmax(bond_logits, dim=-1).tolist()

    mol = Chem.RWMol()
    idx_map = {}
    for i, atomic_num in enumerate(atom_types):
        try:
            atom = Chem.Atom(int(atomic_num))
            idx_map[i] = mol.AddAtom(atom)
        except Exception:
            continue

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC
    }

    added = set()
    for i in range(edge_index.shape[1]):
        a = int(edge_index[0, i])
        b = int(edge_index[1, i])
        if a != b and (a, b) not in added and (b, a) not in added and a in idx_map and b in idx_map:
            try:
                bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                mol.AddBond(idx_map[a], idx_map[b], bond_type)
                added.add((a, b))
            except Exception:
                continue

    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        img = Draw.MolToImage(mol)
        img.show()
        return smiles
    except Exception:
        return ""
# -------- Validation --------
def validate_molecule(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False, {}
    return True, {"MolWt": Descriptors.MolWt(mol), "LogP": Descriptors.MolLogP(mol)}

# -------- Load Data --------
def load_goodscents_subset(filepath="/content/curated_GS_LF_merged_4983.csv",
                           index=200
                           ):
    df = pd.read_csv(filepath)
    # df = df.sample(frac=1).reset_index(drop=True)  # optional shuffle
    if index > 0:
        df = df.head(index)       # first `index` rows
    else:
        df = df.tail(-1 * index)  # last `-index` rows (index=0 yields an empty frame but keeps the columns)
    descriptor_cols = df.columns[2:]
    smiles_list, label_map = [], {}
    for _, row in df.iterrows():
        smiles = row["nonStereoSMILES"]
        labels = row[descriptor_cols].astype(int).tolist()
        if smiles and any(labels):
            smiles_list.append(smiles)
            label_map[smiles] = labels
    return smiles_list, label_map, list(descriptor_cols)
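# Example (illustrative, values approximate): validate_molecule("CCO") returns
# (True, {"MolWt": ~46.07, "LogP": ~0.0}) for ethanol; an invalid SMILES yields (False, {}).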
# -------- Main --------
if __name__ == '__main__':
    SHOULD_BATCH: bool = False
    smiles_list, label_map, label_names = load_goodscents_subset(index=500)
    num_labels = len(label_names)
    dataset = []
    for smi in smiles_list:
        g = smiles_to_graph(smi)
        if g:
            g.y = torch.tensor(label_map[smi])
            dataset.append(g)
    model = EGNNDiffusionModel(node_dim=1, embed_dim=8)
    conditioner = OlfactoryConditioner(num_labels=num_labels, embed_dim=8)
    train_success: bool = False
    while not train_success:
        try:
            model, conditioner = train(model, conditioner, dataset, epochs=1000)
            train_success = True
            break
        except IndexError:
            print("Index Error on training. Trying again.")
    # Build a query label vector (1 = descriptor requested, 0 = not requested)
    test_label_vec = torch.zeros(num_labels)
    if "floral" in label_names:
        test_label_vec[label_names.index("floral")] = 0
    if "fruity" in label_names:
        test_label_vec[label_names.index("fruity")] = 1
    if "musky" in label_names:
        test_label_vec[label_names.index("musky")] = 0

    model.eval()
    conditioner.eval()
    if SHOULD_BATCH:
        new_mols = sample_batch(model, conditioner, label_vec=test_label_vec)
        new_smiles_list = [Chem.MolToSmiles(m) for m in new_mols]  # sample_batch returns Mol objects
        for new_smiles in new_smiles_list:
            print(new_smiles)
            valid, props = validate_molecule(new_smiles)
            print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")
    else:
        new_smiles = sample(model, conditioner, label_vec=test_label_vec)
        print(new_smiles)
        valid, props = validate_molecule(new_smiles)
        print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")
model/diffusion_gnn/inference.py
ADDED
@@ -0,0 +1,65 @@
import torch


from .diffusion_gnn import sample, validate_molecule
from .diffusion_gnn import load_goodscents_subset, smiles_to_graph
from .diffusion_gnn import EGNNDiffusionModel, OlfactoryConditioner


def test_models(test_model, test_conditioner):
    good_count: int = 0
    index: int = 10
    smiles_list, label_map, label_names = load_goodscents_subset(index=index)
    num_labels = len(label_names)
    dataset = []
    test_model.eval()
    test_conditioner.eval()
    for smi in smiles_list:
        g = smiles_to_graph(smi)
        if g:
            g.y = torch.tensor(label_map[smi])
            dataset.append(g)

    for i in range(0, len(dataset)):
        print(f"Testing molecule {i+1}/{len(dataset)}")
        data = dataset[i]
        x_0, pos, edge_index, edge_attr, label_vec = data.x, data.pos, data.edge_index, data.edge_attr.view(-1), data.y
        print(f"label vec: {label_vec}")
        print(f"len: {len(label_vec.tolist())}")
        new_smiles = sample(test_model, test_conditioner, label_vec=label_vec)
        print(new_smiles)
        valid, props = validate_molecule(new_smiles)
        print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")
        if new_smiles != "":
            good_count += 1
    percent_correct: float = float(good_count) / float(len(dataset))
    print(f"Percent correct: {percent_correct}")


def inference(model: EGNNDiffusionModel, conditioner: OlfactoryConditioner, data):
    model.eval()
    conditioner.eval()
    x_0, pos, edge_index, edge_attr, label_vec = data.x, data.pos, data.edge_index, data.edge_attr.view(-1), data.y
    new_smiles = sample(model, conditioner, label_vec=label_vec)
    print(new_smiles)
    valid, props = validate_molecule(new_smiles)
    print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")


def load_models():
    smiles_list, label_map, label_names = load_goodscents_subset(index=0)
    num_labels = len(label_names)
    model = EGNNDiffusionModel(node_dim=1, embed_dim=8)
    model.load_state_dict(torch.load('/content/egnn_state_dict_20250427.pth'))
    model.eval()  # Set to evaluation mode if you are not training

    conditioner = OlfactoryConditioner(num_labels=num_labels, embed_dim=8)
    conditioner.load_state_dict(torch.load('/content/olfactory_conditioner_state_dict.pth'))
    conditioner.eval()  # Set to evaluation mode if you are not training

    return model, conditioner
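# Note: the checkpoint paths above are Colab-style placeholders ('/content/...');
# point them at local state-dict files before running. Because this module uses
# relative imports, run it as a package module (e.g., python -m <package>.inference)
# rather than as a standalone script.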


if __name__ == "__main__":
    model, conditioner = load_models()
    test_models(test_model=model, test_conditioner=conditioner)
model/ovle-small/gat/gat_gnn.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e9a7196e0b68bf221dc38a461a1158eb1c1cf865c85b66487b55718e070a375d
size 8446520

model/ovle-small/gat/gat_gnn_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:855d7f36afa28317cd6c272a8cb20445c04992bed18f2a39aca462600fa7274d
size 8425362

model/ovle-small/gat/gat_olf_encoder.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5c065dd05e675a215fde49f091a900bda259aed3f7b8d13b0ade4ed981f9c12
size 325456

model/ovle-small/gat/gat_olf_encoder_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:417a2eaa039d798f1a0404b5e37750cd940cb7ae893de1a298c11fdef756159a
size 324392

model/ovle-small/nn/gnn.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28e7c0fbb366362b0a53b2fee8c097ec8f3dee3bbd5fd298898fa1cf05b1e17d
size 2104368

model/ovle-small/nn/gnn_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a7098967ff9335de18cba73e34de76ea7c657318f0452bafc9f76c9054556be
size 2103432

model/ovle-small/nn/olf_encoder.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63f6113d7f5e196dacb416e0510a71bd0f2f7827e0fb808f913ab0c5d5ed0d7d
size 459248

model/ovle-small/nn/olf_encoder_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31c02aac7816b5a537268f60379f3699874595971ae22643d791e43afa4d64f8
size 457588

model/train_ovm.py
ADDED
File without changes
model_cards/ovle-large.md
ADDED
@@ -0,0 +1,82 @@
# Model Card: Scentience-OVLE-Large-v1


## Model Details
- **Model Name:** `Scentience OVLE Large v1`
- **Developed by:** Kordel K. France
- **Date:** September 2025
- **Architecture:**
  - **Olfaction encoder:** 138-sensor embedding
  - **Vision encoder:** CLIP-based
  - **Language encoder:** CLIP-based
  - **Fusion strategy:** Joint embedding space via multimodal contrastive training (see the sketch below)
- **Parameter Count:** 250M
- **License:** MIT
- **Contact:** [email protected]

---
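The fusion strategy above follows the CLIP recipe of contrastive alignment across modalities. As a rough illustration only (this is not the released training code; the symmetric InfoNCE form, dimensions, and names below are assumptions), aligning olfaction embeddings with image or text embeddings can look like:

```python
import torch
import torch.nn.functional as F

def infonce(z_a: torch.Tensor, z_b: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE over a batch of paired embeddings (row i of z_a pairs with row i of z_b)."""
    z_a = F.normalize(z_a, dim=-1)
    z_b = F.normalize(z_b, dim=-1)
    logits = z_a @ z_b.t() / temperature     # scaled cosine similarities
    targets = torch.arange(z_a.size(0))      # matching pairs lie on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

# Toy usage: 8 paired (olfaction, vision) embeddings of width 512
loss = infonce(torch.randn(8, 512), torch.randn(8, 512))
```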

## Intended Use
- **Primary purpose:** Research in multimodal machine learning involving olfaction, vision, and language.
- **Example applications:**
  - Cross-modal retrieval (odor → image, odor → text, etc.)
  - Robotics and UAV navigation guided by chemical cues
  - Chemical dataset exploration and visualization
- **Intended users:** Researchers, developers, and educators working in ML, robotics, chemistry, and HCI.
- **Out of scope:** Not intended for safety-critical tasks (e.g., gas leak detection, medical diagnosis, or regulatory use).

---

## Training Data
- **Olfaction data:** Language-aligned olfactory data curated from the GoodScents and Leffingwell datasets.
- **Vision data:** COCO dataset.
- **Language data:** Smell descriptors and text annotations curated from the literature.

---

## Evaluation
- Retrieval tasks: odor → image (Top-5 recall = 62%)
- Odor descriptor classification accuracy = 71%
- Cross-modal embedding alignment qualitatively verified on 200 sample triplets.

---

## Limitations of Evaluation
To the best of our knowledge, there are currently no open-source datasets that provide aligned olfactory, visual, and linguistic annotations. A "true" multimodal evaluation would require measuring the chemical composition of scenes (e.g., with gas chromatography-mass spectrometry) while simultaneously capturing images and collecting perceptual descriptors from human olfactory judges. Such a benchmark would demand substantial new data collection and instrumentation.

Consequently, we evaluate our models indirectly, using surrogate metrics (e.g., cross-modal retrieval performance, odor descriptor classification accuracy, clustering quality). While these evaluations do not provide ground-truth verification of odor presence in images, they offer a first step toward demonstrating alignment between modalities.

We draw an analogy to earlier successes in multimodal ML: the precursors to CLIP also lacked large paired datasets and were evaluated on retrieval-style tasks.

We therefore release this model to catalyze further research and encourage the community to help build standardized datasets and evaluation protocols for olfaction-vision-language learning.

---

## Limitations
- Limited odor diversity (approx. 5,000 unique compounds).
- Embeddings depend on sensor calibration; transfer across devices is not guaranteed.
- Cultural subjectivity in smell annotations may bias embeddings.

---

## Ethical Considerations
- Not to be used for covert detection of substances or for surveillance.
- Unreliable in safety-critical contexts (e.g., gas leak detection).
- Smell perception is culturally subjective; annotations may not generalize across cultures.

---

## Environmental Impact
- Trained on 4×A100 GPUs for 48 hours (~200 kg CO2eq).
- Sensor dataset collection required ~500 lab hours.

---

## Citation
If you use this model, please cite:
```
@misc{scentience2025ovle,
  title = {Scentience-OVLE-Large-v1: Joint Olfaction-Vision-Language Embeddings},
  author = {Kordel Kade France},
  year = {2025},
  howpublished = {Hugging Face},
  url = {https://huggingface.co/your-username/Scentience-OVLE-Large-v1}
}
```
model_cards/ovle-small.md
ADDED
@@ -0,0 +1,82 @@
# Model Card: Scentience-OVLE-Small-v1


## Model Details
- **Model Name:** `Scentience OVLE Small v1`
- **Developed by:** Kordel K. France
- **Date:** September 2025
- **Architecture:**
  - **Olfaction encoder:** 138-sensor embedding
  - **Vision encoder:** CLIP-based
  - **Language encoder:** CLIP-based
  - **Fusion strategy:** Joint embedding space via multimodal contrastive training
- **Parameter Count:** 2M
- **License:** MIT
- **Contact:** [email protected]

---

## Intended Use
- **Primary purpose:** Research in multimodal machine learning involving olfaction, vision, and language.
- **Example applications:**
  - Cross-modal retrieval (odor → image, odor → text, etc.)
  - Robotics and UAV navigation guided by chemical cues
  - Chemical dataset exploration and visualization
- **Intended users:** Researchers, developers, and educators working in ML, robotics, chemistry, and HCI.
- **Out of scope:** Not intended for safety-critical tasks (e.g., gas leak detection, medical diagnosis, or regulatory use).

---

## Training Data
- **Olfaction data:** Language-aligned olfactory data curated from the GoodScents and Leffingwell datasets.
- **Vision data:** COCO dataset.
- **Language data:** Smell descriptors and text annotations curated from the literature.

---

## Evaluation
- Retrieval tasks: odor → image (Top-5 recall = 62%)
- Odor descriptor classification accuracy = 71%
- Cross-modal embedding alignment qualitatively verified on 200 sample triplets. (A sketch of the retrieval metric follows.)

---
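For concreteness, Top-k retrieval recall over a joint embedding space can be computed as below (a minimal sketch, assuming L2-normalized embeddings and one matching image per odor; function and variable names are illustrative):

```python
import torch

def topk_recall(odor_emb: torch.Tensor, image_emb: torch.Tensor, k: int = 5) -> float:
    """Fraction of odors whose matching image (same row index) ranks in the top k."""
    sims = odor_emb @ image_emb.t()                   # (N, N) cosine sims if rows are unit-norm
    topk = sims.topk(k, dim=-1).indices               # top-k image indices per odor
    match = torch.arange(sims.size(0)).unsqueeze(-1)  # ground-truth index = row index
    return (topk == match).any(dim=-1).float().mean().item()

# Toy usage: 100 paired embeddings of width 64
z = torch.nn.functional.normalize(torch.randn(100, 64), dim=-1)
print(topk_recall(z, z, k=5))  # identical embeddings -> recall 1.0
```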

## Limitations of Evaluation
To the best of our knowledge, there are currently no open-source datasets that provide aligned olfactory, visual, and linguistic annotations. A "true" multimodal evaluation would require measuring the chemical composition of scenes (e.g., with gas chromatography-mass spectrometry) while simultaneously capturing images and collecting perceptual descriptors from human olfactory judges. Such a benchmark would demand substantial new data collection and instrumentation.

Consequently, we evaluate our models indirectly, using surrogate metrics (e.g., cross-modal retrieval performance, odor descriptor classification accuracy, clustering quality). While these evaluations do not provide ground-truth verification of odor presence in images, they offer a first step toward demonstrating alignment between modalities.

We draw an analogy to earlier successes in multimodal ML: the precursors to CLIP also lacked large paired datasets and were evaluated on retrieval-style tasks.

We therefore release this model to catalyze further research and encourage the community to help build standardized datasets and evaluation protocols for olfaction-vision-language learning.

---

## Limitations
- Limited odor diversity (approx. 5,000 unique compounds).
- Embeddings depend on sensor calibration; transfer across devices is not guaranteed.
- Cultural subjectivity in smell annotations may bias embeddings.

---

## Ethical Considerations
- Not to be used for covert detection of substances or for surveillance.
- Unreliable in safety-critical contexts (e.g., gas leak detection).
- Smell perception is culturally subjective; annotations may not generalize across cultures.

---

## Environmental Impact
- Trained on 4×A100 GPUs for 48 hours (~200 kg CO2eq).
- Sensor dataset collection required ~500 lab hours.

---

## Citation
If you use this model, please cite:
```
@misc{scentience2025ovle,
  title = {Scentience-OVLE-Small-v1: Joint Olfaction-Vision-Language Embeddings},
  author = {Kordel Kade France},
  year = {2025},
  howpublished = {Hugging Face},
  url = {https://huggingface.co/your-username/Scentience-OVLE-Small-v1}
}
```
requirements.txt
ADDED
@@ -0,0 +1,4 @@
rdkit==2024.9.6
torch-geometric==2.6.1
torch
bleak