Spaces:
Sleeping
Sleeping
Commit
·
994fb49
1
Parent(s):
b1aa639
new experiments
Browse files- mvp/models/contrastive.py +42 -44
- mvp/models/mol_encoder.py +1 -2
- mvp/models/spec_encoder.py +5 -3
- mvp/params_formSpec.yaml +16 -13
- mvp/params_tmp.yaml +19 -15
- mvp/run.sh +3 -3
- mvp/tune.py +264 -0
- notebooks/2v1.ipynb +261 -0
- notebooks/UMAP_spectra_embeddings.ipynb +0 -0
- notebooks/attribute_viz.ipynb +56 -0
- notebooks/filip_viz.ipynb +706 -0
- notebooks/hyperparameter_tuning_result.ipynb +0 -0
- notebooks/magma_script.ipynb +146 -0
- notebooks/peak_embedding_UMAP.ipynb +0 -0
- notebooks/peak_formula_analysis.ipynb +370 -0
- notebooks/visualization.ipynb +201 -0
mvp/models/contrastive.py
CHANGED
|
@@ -27,17 +27,15 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 27 |
super().__init__(**kwargs)
|
| 28 |
self.save_hyperparameters()
|
| 29 |
|
| 30 |
-
if 'use_fp' not in self.hparams:
|
| 31 |
-
self.hparams.use_fp = False
|
| 32 |
if 'use_fp' not in self.hparams:
|
| 33 |
self.hparams.use_fp = False
|
| 34 |
if 'use_NL_spec' not in self.hparams:
|
| 35 |
self.hparams.use_NL_spec = False
|
| 36 |
|
| 37 |
-
if 'loss_strategy' not in self.hparams:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
|
| 42 |
self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
|
| 43 |
self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
|
|
@@ -58,26 +56,26 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 58 |
self.result_dct = defaultdict(lambda: defaultdict(list))
|
| 59 |
|
| 60 |
|
| 61 |
-
def _loss_setup(self):
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
def forward(self, batch, stage):
|
| 83 |
g = batch['cand'] if stage == Stage.TEST else batch['mol']
|
|
@@ -99,11 +97,11 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 99 |
def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
|
| 100 |
loss = 0
|
| 101 |
losses = {}
|
| 102 |
-
contr_loss,
|
| 103 |
-
contr_loss = self.loss_wts['contr_wt'] *contr_loss
|
| 104 |
losses['contr_loss'] = contr_loss.detach().item()
|
| 105 |
-
losses['cong_loss'] = cong_loss.detach().item()
|
| 106 |
-
losses['noncong_loss'] = noncong_loss.detach().item()
|
| 107 |
|
| 108 |
loss+=contr_loss
|
| 109 |
if self.hparams.pred_fp:
|
|
@@ -217,7 +215,7 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 217 |
on_epoch=True,
|
| 218 |
)
|
| 219 |
|
| 220 |
-
def test_step(self, batch):
|
| 221 |
# Unpack inputs
|
| 222 |
identifiers = batch['identifier']
|
| 223 |
cand_smiles = batch['cand_smiles']
|
|
@@ -274,17 +272,17 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
|
|
| 274 |
]
|
| 275 |
return monitors
|
| 276 |
|
| 277 |
-
def _update_loss_weights(self)-> None:
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
def on_train_epoch_end(self) -> None:
|
| 287 |
-
|
| 288 |
|
| 289 |
class MultiViewContrastive(ContrastiveModel):
|
| 290 |
|
|
@@ -473,7 +471,7 @@ class FilipContrastive(ContrastiveModel):
|
|
| 473 |
|
| 474 |
return losses
|
| 475 |
|
| 476 |
-
def test_step(self, batch):
|
| 477 |
# Unpack inputs
|
| 478 |
identifiers = batch['identifier']
|
| 479 |
cand_smiles = batch['cand_smiles']
|
|
@@ -758,7 +756,7 @@ class CrossAttenContrastive(ContrastiveModel):
|
|
| 758 |
g_n_nodes = batch['mol_n_nodes']
|
| 759 |
|
| 760 |
# encode peaks and nodes
|
| 761 |
-
spec_enc = self.spec_enc_model(spec)
|
| 762 |
mol_enc = self.mol_enc_model(g)
|
| 763 |
|
| 764 |
# pad mol_enc and spec_enc to have the same length
|
|
|
|
| 27 |
super().__init__(**kwargs)
|
| 28 |
self.save_hyperparameters()
|
| 29 |
|
|
|
|
|
|
|
| 30 |
if 'use_fp' not in self.hparams:
|
| 31 |
self.hparams.use_fp = False
|
| 32 |
if 'use_NL_spec' not in self.hparams:
|
| 33 |
self.hparams.use_NL_spec = False
|
| 34 |
|
| 35 |
+
# if 'loss_strategy' not in self.hparams:
|
| 36 |
+
# self.hparams.loss_strategy = 'static'
|
| 37 |
+
# self.hparams.contr_wt = 1.0
|
| 38 |
+
# self.hparams.use_contr = True
|
| 39 |
|
| 40 |
self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
|
| 41 |
self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
|
|
|
|
| 56 |
self.result_dct = defaultdict(lambda: defaultdict(list))
|
| 57 |
|
| 58 |
|
| 59 |
+
# def _loss_setup(self):
|
| 60 |
+
# self.loss_wts = {}
|
| 61 |
+
# self.loss_updates = {}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# for p, loss in zip(['use_contr','pred_fp', 'use_cons_spec', 'aug_cands'], ['contr_wt','fp_wt','cons_spec_wt' ,'aug_cands_wt']):
|
| 65 |
+
# if p not in self.hparams:
|
| 66 |
+
# self.hparams[p] = False
|
| 67 |
+
# if self.hparams[p]:
|
| 68 |
+
# if self.hparams.loss_strategy == 'linear':
|
| 69 |
+
# start_wt = self.hparams[loss+'_update']['start']
|
| 70 |
+
# end_wt = self.hparams[loss+'_update']['end']
|
| 71 |
+
# change = (end_wt - start_wt)/self.hparams.max_epochs
|
| 72 |
+
# self.loss_updates[loss] = change
|
| 73 |
+
# self.loss_wts[loss] = start_wt
|
| 74 |
+
# elif self.hparams.loss_strategy == 'manual':
|
| 75 |
+
# self.loss_updates[loss] = self.hparams[loss+'_update']
|
| 76 |
+
# self.loss_wts[loss] = self.hparams[loss]
|
| 77 |
+
# else:
|
| 78 |
+
# self.loss_wts[loss] = self.hparams[loss]
|
| 79 |
|
| 80 |
def forward(self, batch, stage):
|
| 81 |
g = batch['cand'] if stage == Stage.TEST else batch['mol']
|
|
|
|
| 97 |
def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
|
| 98 |
loss = 0
|
| 99 |
losses = {}
|
| 100 |
+
contr_loss, _, _ = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
|
| 101 |
+
# contr_loss = self.loss_wts['contr_wt'] *contr_loss
|
| 102 |
losses['contr_loss'] = contr_loss.detach().item()
|
| 103 |
+
# losses['cong_loss'] = cong_loss.detach().item()
|
| 104 |
+
# losses['noncong_loss'] = noncong_loss.detach().item()
|
| 105 |
|
| 106 |
loss+=contr_loss
|
| 107 |
if self.hparams.pred_fp:
|
|
|
|
| 215 |
on_epoch=True,
|
| 216 |
)
|
| 217 |
|
| 218 |
+
def test_step(self, batch, batch_idx):
|
| 219 |
# Unpack inputs
|
| 220 |
identifiers = batch['identifier']
|
| 221 |
cand_smiles = batch['cand_smiles']
|
|
|
|
| 272 |
]
|
| 273 |
return monitors
|
| 274 |
|
| 275 |
+
# def _update_loss_weights(self)-> None:
|
| 276 |
+
# if self.hparams.loss_strategy == 'linear':
|
| 277 |
+
# for loss in self.loss_wts:
|
| 278 |
+
# self.loss_wts[loss] += self.loss_updates[loss]
|
| 279 |
+
# elif self.hparams.loss_strategy == 'manual':
|
| 280 |
+
# for loss in self.loss_wts:
|
| 281 |
+
# if self.current_epoch in self.loss_updates[loss]:
|
| 282 |
+
# self.loss_wts[loss] = self.loss_updates[loss][self.current_epoch]
|
| 283 |
+
|
| 284 |
+
# def on_train_epoch_end(self) -> None:
|
| 285 |
+
# self._update_loss_weights()
|
| 286 |
|
| 287 |
class MultiViewContrastive(ContrastiveModel):
|
| 288 |
|
|
|
|
| 471 |
|
| 472 |
return losses
|
| 473 |
|
| 474 |
+
def test_step(self, batch, batch_idx):
|
| 475 |
# Unpack inputs
|
| 476 |
identifiers = batch['identifier']
|
| 477 |
cand_smiles = batch['cand_smiles']
|
|
|
|
| 756 |
g_n_nodes = batch['mol_n_nodes']
|
| 757 |
|
| 758 |
# encode peaks and nodes
|
| 759 |
+
spec_enc = self.spec_enc_model(spec, spec_n_forms)
|
| 760 |
mol_enc = self.mol_enc_model(g)
|
| 761 |
|
| 762 |
# pad mol_enc and spec_enc to have the same length
|
mvp/models/mol_encoder.py
CHANGED
|
@@ -12,14 +12,13 @@ class MolEnc(nn.Module):
|
|
| 12 |
|
| 13 |
self.return_emb = False
|
| 14 |
|
| 15 |
-
if args.model in ('
|
| 16 |
self.return_emb = True
|
| 17 |
|
| 18 |
dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
|
| 19 |
batchnorm = [True for _ in range(len(args.gnn_channels))]
|
| 20 |
gnn_map = {
|
| 21 |
"gcn": GCN(in_dim, args.gnn_channels, batchnorm = batchnorm, dropout = dropout),
|
| 22 |
-
"gat": GAT(in_dim, args.gnn_channels, args.attn_heads)
|
| 23 |
}
|
| 24 |
self.GNN = gnn_map[args.gnn_type]
|
| 25 |
self.pool = dgl.nn.pytorch.glob.MaxPooling()
|
|
|
|
| 12 |
|
| 13 |
self.return_emb = False
|
| 14 |
|
| 15 |
+
if args.model in ('filipContrastive', 'crossAttenContrastive'):
|
| 16 |
self.return_emb = True
|
| 17 |
|
| 18 |
dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
|
| 19 |
batchnorm = [True for _ in range(len(args.gnn_channels))]
|
| 20 |
gnn_map = {
|
| 21 |
"gcn": GCN(in_dim, args.gnn_channels, batchnorm = batchnorm, dropout = dropout),
|
|
|
|
| 22 |
}
|
| 23 |
self.GNN = gnn_map[args.gnn_type]
|
| 24 |
self.pool = dgl.nn.pytorch.glob.MaxPooling()
|
mvp/models/spec_encoder.py
CHANGED
|
@@ -121,12 +121,14 @@ class SpecFormulaTransformer(nn.Module):
|
|
| 121 |
self.use_cls = args.use_cls
|
| 122 |
if args.use_cls:
|
| 123 |
self.cls_embed = torch.nn.Embedding(1,args.formula_dims[-1])
|
| 124 |
-
encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=
|
| 125 |
-
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=
|
| 126 |
|
| 127 |
if not out_dim:
|
| 128 |
out_dim = args.final_embedding_dim
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def forward(self, spec, n_peaks):
|
| 132 |
h = self.formulaEnc(spec)
|
|
|
|
| 121 |
self.use_cls = args.use_cls
|
| 122 |
if args.use_cls:
|
| 123 |
self.cls_embed = torch.nn.Embedding(1,args.formula_dims[-1])
|
| 124 |
+
encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=args.formula_attn_heads, batch_first=True)
|
| 125 |
+
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=args.formula_transformer_layers)
|
| 126 |
|
| 127 |
if not out_dim:
|
| 128 |
out_dim = args.final_embedding_dim
|
| 129 |
+
|
| 130 |
+
if not self.returnEmb:
|
| 131 |
+
self.fc = nn.Linear(args.formula_dims[-1], out_dim)
|
| 132 |
|
| 133 |
def forward(self, spec, n_peaks):
|
| 134 |
h = self.formulaEnc(spec)
|
mvp/params_formSpec.yaml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Experiment setup
|
| 2 |
job_key: ''
|
| 3 |
-
run_name: '
|
| 4 |
run_details: ""
|
| 5 |
project_name: ''
|
| 6 |
wandb_entity_name: 'mass-spec-ml'
|
|
@@ -12,14 +12,14 @@ checkpoint_pth:
|
|
| 12 |
# Training setup
|
| 13 |
max_epochs: 2000
|
| 14 |
accelerator: 'gpu'
|
| 15 |
-
devices: [
|
| 16 |
log_every_n_steps: 250
|
| 17 |
val_check_interval: 1.0
|
| 18 |
|
| 19 |
# Data paths
|
| 20 |
candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
|
| 21 |
-
dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
|
| 22 |
-
subformula_dir_pth: /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
|
| 23 |
split_pth:
|
| 24 |
fp_dir_pth:
|
| 25 |
cons_spec_dir_pth:
|
|
@@ -28,9 +28,9 @@ partial_checkpoint: ""
|
|
| 28 |
|
| 29 |
# General hyperparameters
|
| 30 |
batch_size: 64
|
| 31 |
-
lr: 5.0e-05
|
| 32 |
-
weight_decay:
|
| 33 |
-
contr_temp: 0.05
|
| 34 |
early_stopping_patience: 300
|
| 35 |
loss_strategy: 'static'
|
| 36 |
num_workers: 50
|
|
@@ -39,7 +39,7 @@ num_workers: 50
|
|
| 39 |
############################## Data transforms ##############################
|
| 40 |
# - Spectra
|
| 41 |
spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
|
| 42 |
-
formula_source: '
|
| 43 |
# 1. Binner
|
| 44 |
max_mz: 1000
|
| 45 |
bin_width: 1
|
|
@@ -103,20 +103,23 @@ fc_dropout: 0.4
|
|
| 103 |
|
| 104 |
# - Spectra Token encoder
|
| 105 |
hidden_dims: [64, 128]
|
| 106 |
-
|
| 107 |
|
| 108 |
# - Formula-based spec encoders
|
| 109 |
-
formula_dropout:
|
| 110 |
-
formula_dims: [64, 128, 256]
|
| 111 |
cross_attn_heads: 2
|
| 112 |
use_cls: False
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# -- GAT params
|
| 115 |
attn_heads: [12,12,12]
|
| 116 |
|
| 117 |
# - Molecule encoder (GNN)
|
| 118 |
-
gnn_channels: [64,128,
|
| 119 |
gnn_type: "gcn"
|
| 120 |
num_gnn_layers: 3
|
| 121 |
gnn_hidden_dim: 512
|
| 122 |
-
gnn_dropout: 0.3
|
|
|
|
| 1 |
# Experiment setup
|
| 2 |
job_key: ''
|
| 3 |
+
run_name: 'optimized_filip-model'
|
| 4 |
run_details: ""
|
| 5 |
project_name: ''
|
| 6 |
wandb_entity_name: 'mass-spec-ml'
|
|
|
|
| 12 |
# Training setup
|
| 13 |
max_epochs: 2000
|
| 14 |
accelerator: 'gpu'
|
| 15 |
+
devices: [1]
|
| 16 |
log_every_n_steps: 250
|
| 17 |
val_check_interval: 1.0
|
| 18 |
|
| 19 |
# Data paths
|
| 20 |
candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
|
| 21 |
+
dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
|
| 22 |
+
subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
|
| 23 |
split_pth:
|
| 24 |
fp_dir_pth:
|
| 25 |
cons_spec_dir_pth:
|
|
|
|
| 28 |
|
| 29 |
# General hyperparameters
|
| 30 |
batch_size: 64
|
| 31 |
+
lr: 2.881339661302105e-05 # 5.0e-05
|
| 32 |
+
weight_decay: 1.1586679936312845e-05
|
| 33 |
+
contr_temp: 0.022772534845886608 # 0.05
|
| 34 |
early_stopping_patience: 300
|
| 35 |
loss_strategy: 'static'
|
| 36 |
num_workers: 50
|
|
|
|
| 39 |
############################## Data transforms ##############################
|
| 40 |
# - Spectra
|
| 41 |
spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
|
| 42 |
+
formula_source: 'default' # magma_1, magma_all, sirius, default
|
| 43 |
# 1. Binner
|
| 44 |
max_mz: 1000
|
| 45 |
bin_width: 1
|
|
|
|
| 103 |
|
| 104 |
# - Spectra Token encoder
|
| 105 |
hidden_dims: [64, 128]
|
| 106 |
+
|
| 107 |
|
| 108 |
# - Formula-based spec encoders
|
| 109 |
+
formula_dropout: 0.2
|
| 110 |
+
formula_dims: [512, 256, 512] #[64, 128, 256]
|
| 111 |
cross_attn_heads: 2
|
| 112 |
use_cls: False
|
| 113 |
+
peak_dropout: 0.414425691950033 # 0.2
|
| 114 |
+
formula_attn_heads: 4 # 2
|
| 115 |
+
formula_transformer_layers: 2
|
| 116 |
|
| 117 |
# -- GAT params
|
| 118 |
attn_heads: [12,12,12]
|
| 119 |
|
| 120 |
# - Molecule encoder (GNN)
|
| 121 |
+
gnn_channels: [64,128,512]
|
| 122 |
gnn_type: "gcn"
|
| 123 |
num_gnn_layers: 3
|
| 124 |
gnn_hidden_dim: 512
|
| 125 |
+
gnn_dropout: 0.23234950970370824 #0.3
|
mvp/params_tmp.yaml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Experiment setup
|
| 2 |
job_key: ''
|
| 3 |
-
run_name: '
|
| 4 |
run_details: ""
|
| 5 |
project_name: ''
|
| 6 |
wandb_entity_name: 'mass-spec-ml'
|
|
@@ -10,16 +10,16 @@ debug: False
|
|
| 10 |
checkpoint_pth:
|
| 11 |
|
| 12 |
# Training setup
|
| 13 |
-
max_epochs:
|
| 14 |
accelerator: 'gpu'
|
| 15 |
devices: [1]
|
| 16 |
log_every_n_steps: 250
|
| 17 |
val_check_interval: 1.0
|
| 18 |
|
| 19 |
# Data paths
|
| 20 |
-
candidates_pth: /
|
| 21 |
-
dataset_pth: /data/yzhouc01/
|
| 22 |
-
subformula_dir_pth: /data/yzhouc01/
|
| 23 |
split_pth:
|
| 24 |
fp_dir_pth:
|
| 25 |
cons_spec_dir_pth:
|
|
@@ -28,9 +28,9 @@ partial_checkpoint: ""
|
|
| 28 |
|
| 29 |
# General hyperparameters
|
| 30 |
batch_size: 64
|
| 31 |
-
lr: 5.0e-05
|
| 32 |
-
weight_decay:
|
| 33 |
-
contr_temp: 0.05
|
| 34 |
early_stopping_patience: 300
|
| 35 |
loss_strategy: 'static'
|
| 36 |
num_workers: 50
|
|
@@ -39,6 +39,7 @@ num_workers: 50
|
|
| 39 |
############################## Data transforms ##############################
|
| 40 |
# - Spectra
|
| 41 |
spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
|
|
|
|
| 42 |
# 1. Binner
|
| 43 |
max_mz: 1000
|
| 44 |
bin_width: 1
|
|
@@ -91,7 +92,7 @@ use_NL: False
|
|
| 91 |
task: 'retrieval'
|
| 92 |
spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
|
| 93 |
mol_enc: "GNN"
|
| 94 |
-
model:
|
| 95 |
contr_views: [['spec_enc', 'mol_enc']] #[['spec_enc', 'mol_enc'], ['spec_enc', 'NL_spec_enc'], ['mol_enc', 'NL_spec_enc']] #[['spec_enc', 'mol_enc'], ['mol_enc', 'cons_spec_enc'], ['cons_spec_enc', 'spec_enc'], ['fp_enc', 'mol_enc'], ['fp_enc', 'spec_enc'], ['fp_enc', 'cons_spec_enc']]
|
| 96 |
log_only_loss_at_stages: []
|
| 97 |
df_test_path: ""
|
|
@@ -102,20 +103,23 @@ fc_dropout: 0.4
|
|
| 102 |
|
| 103 |
# - Spectra Token encoder
|
| 104 |
hidden_dims: [64, 128]
|
| 105 |
-
|
| 106 |
|
| 107 |
# - Formula-based spec encoders
|
| 108 |
-
formula_dropout:
|
| 109 |
-
formula_dims: [64, 128, 256]
|
| 110 |
-
cross_attn_heads: 2
|
| 111 |
use_cls: False
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# -- GAT params
|
| 114 |
attn_heads: [12,12,12]
|
| 115 |
|
| 116 |
# - Molecule encoder (GNN)
|
| 117 |
-
gnn_channels: [64,128,
|
| 118 |
gnn_type: "gcn"
|
| 119 |
num_gnn_layers: 3
|
| 120 |
gnn_hidden_dim: 512
|
| 121 |
-
gnn_dropout: 0.3
|
|
|
|
| 1 |
# Experiment setup
|
| 2 |
job_key: ''
|
| 3 |
+
run_name: 'crossAttnModel'
|
| 4 |
run_details: ""
|
| 5 |
project_name: ''
|
| 6 |
wandb_entity_name: 'mass-spec-ml'
|
|
|
|
| 10 |
checkpoint_pth:
|
| 11 |
|
| 12 |
# Training setup
|
| 13 |
+
max_epochs: 1000
|
| 14 |
accelerator: 'gpu'
|
| 15 |
devices: [1]
|
| 16 |
log_every_n_steps: 250
|
| 17 |
val_check_interval: 1.0
|
| 18 |
|
| 19 |
# Data paths
|
| 20 |
+
candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
|
| 21 |
+
dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
|
| 22 |
+
subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
|
| 23 |
split_pth:
|
| 24 |
fp_dir_pth:
|
| 25 |
cons_spec_dir_pth:
|
|
|
|
| 28 |
|
| 29 |
# General hyperparameters
|
| 30 |
batch_size: 64
|
| 31 |
+
lr: 2.881339661302105e-05 # 5.0e-05
|
| 32 |
+
weight_decay: 1.1586679936312845e-05
|
| 33 |
+
contr_temp: 0.022772534845886608 # 0.05
|
| 34 |
early_stopping_patience: 300
|
| 35 |
loss_strategy: 'static'
|
| 36 |
num_workers: 50
|
|
|
|
| 39 |
############################## Data transforms ##############################
|
| 40 |
# - Spectra
|
| 41 |
spectra_view: SpecFormula #SpecMzIntTokens #SpecFormula
|
| 42 |
+
formula_source: 'default' # magma_1, magma_all, sirius, default
|
| 43 |
# 1. Binner
|
| 44 |
max_mz: 1000
|
| 45 |
bin_width: 1
|
|
|
|
| 92 |
task: 'retrieval'
|
| 93 |
spec_enc: Transformer_Formula # Transformer_MzInt #Transformer_Formula
|
| 94 |
mol_enc: "GNN"
|
| 95 |
+
model: crossAttenContrastive # "MultiviewContrastive"
|
| 96 |
contr_views: [['spec_enc', 'mol_enc']] #[['spec_enc', 'mol_enc'], ['spec_enc', 'NL_spec_enc'], ['mol_enc', 'NL_spec_enc']] #[['spec_enc', 'mol_enc'], ['mol_enc', 'cons_spec_enc'], ['cons_spec_enc', 'spec_enc'], ['fp_enc', 'mol_enc'], ['fp_enc', 'spec_enc'], ['fp_enc', 'cons_spec_enc']]
|
| 97 |
log_only_loss_at_stages: []
|
| 98 |
df_test_path: ""
|
|
|
|
| 103 |
|
| 104 |
# - Spectra Token encoder
|
| 105 |
hidden_dims: [64, 128]
|
| 106 |
+
|
| 107 |
|
| 108 |
# - Formula-based spec encoders
|
| 109 |
+
formula_dropout: 0.2
|
| 110 |
+
formula_dims: [128, 256, 512] #[64, 128, 256]
|
| 111 |
+
cross_attn_heads: 4 # 2
|
| 112 |
use_cls: False
|
| 113 |
+
peak_dropout: 0.414425691950033 # 0.2
|
| 114 |
+
formula_attn_heads: 4 # 2
|
| 115 |
+
formula_transformer_layers: 2
|
| 116 |
|
| 117 |
# -- GAT params
|
| 118 |
attn_heads: [12,12,12]
|
| 119 |
|
| 120 |
# - Molecule encoder (GNN)
|
| 121 |
+
gnn_channels: [64,128,512]
|
| 122 |
gnn_type: "gcn"
|
| 123 |
num_gnn_layers: 3
|
| 124 |
gnn_hidden_dim: 512
|
| 125 |
+
gnn_dropout: 0.23234950970370824 #0.3
|
mvp/run.sh
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
python test.py
|
| 3 |
-
python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
|
|
|
|
| 1 |
+
python train.py --param_pth params_tmp.yaml
|
| 2 |
+
python test.py --param_pth params_tmp.yaml
|
| 3 |
+
python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json --param_pth params_tmp.yaml
|
mvp/tune.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import yaml
|
| 6 |
+
import optuna
|
| 7 |
+
import time
|
| 8 |
+
import logging
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 12 |
+
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from pytorch_lightning import Trainer
|
| 15 |
+
from optuna.integration import PyTorchLightningPruningCallback
|
| 16 |
+
from pytorch_lightning.callbacks import Callback
|
| 17 |
+
|
| 18 |
+
from mvp.data.data_module import ContrastiveDataModule
|
| 19 |
+
from mvp.data.datasets import ContrastiveDataset
|
| 20 |
+
from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
|
| 21 |
+
from mvp.utils.models import get_model
|
| 22 |
+
from mvp.definitions import TEST_RESULTS_DIR
|
| 23 |
+
from functools import partial
|
| 24 |
+
from rdkit import RDLogger
|
| 25 |
+
from massspecgym.models.base import Stage
|
| 26 |
+
|
| 27 |
+
# Keep RDKit quiet: only messages at CRITICAL level are emitted.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line interface of the tuning script.
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", default="params_formSpec.yaml", type=str)
parser.add_argument("--n_trials", default=20, type=int)
|
| 34 |
+
|
| 35 |
+
class EpochLossTracker(Callback):
    """Lightning callback that records per-epoch losses on an Optuna trial.

    After every training / validation epoch the current value of the
    corresponding metric is read from ``trainer.callback_metrics`` and appended
    to ``self.history``. The history is stored on the trial under the
    ``"loss_history"`` user attribute, which is where ``save_trial_result``
    looks it up.

    The attribute is refreshed after every recorded epoch rather than only in
    ``on_fit_end``.
    NOTE(review): ``on_fit_end`` is presumably not reached when ``fit`` is
    aborted by an exception (e.g. a pruned trial) -- confirm against the
    installed Lightning version. Publishing incrementally keeps the history
    available in that case.
    """

    def __init__(self, trial):
        """Store the Optuna ``trial`` and start with an empty loss history."""
        super().__init__()
        self.trial = trial
        self.history = {"train_loss": [], "val_loss": []}

    def _publish(self):
        """Attach a snapshot of the history to the trial's user attributes."""
        snapshot = {name: list(values) for name, values in self.history.items()}
        self.trial.set_user_attr("loss_history", snapshot)

    def _record(self, trainer, history_key, metric_key):
        """Append ``callback_metrics[metric_key]`` to ``history[history_key]`` if it was logged."""
        if metric_key not in trainer.callback_metrics:
            return
        value = float(trainer.callback_metrics[metric_key].cpu().item())
        self.history[history_key].append(value)
        self._publish()

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the ``train_loss`` metric, when present."""
        self._record(trainer, "train_loss", "train_loss")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Record the validation loss (``Stage.VAL`` prefix + ``loss``), when present."""
        self._record(trainer, "val_loss", f"{Stage.VAL.to_pref()}loss")

    def on_fit_end(self, trainer, pl_module):
        """Attach the final history to the trial so save_trial_result can access it."""
        self._publish()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SafePruningCallback(PyTorchLightningPruningCallback, Callback):
    """Optuna pruning callback that is also a proper Lightning ``Callback``.

    Adds no behaviour of its own; it only combines the two base classes.
    """
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def setup_logging(log_path):
    """Configure the root logger to write to ``log_path`` and to stderr.

    Any handlers already attached to the root logger are removed *and closed*
    first, so calling this function repeatedly neither duplicates output nor
    leaks the file descriptor of a previous ``FileHandler``.

    Args:
        log_path: Path of the log file; records are appended to it.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove existing handlers (avoid duplicate or wrong outputs). Iterate over
    # a copy because removeHandler mutates logger.handlers; close each handler
    # so file handlers release their open file.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # File handler
    file_handler = logging.FileHandler(log_path, mode="a")
    file_handler.setFormatter(formatter)

    # Console handler (stderr so tqdm stays clean)
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def save_trial_result(base_dir, trial, params, duration):
    """Append one trial's results to ``<base_dir>/trial_history.csv``.

    The row holds the trial number, its duration, the per-epoch loss lists
    found under the trial's ``"loss_history"`` user attribute (empty lists when
    absent) and every hyperparameter in ``trial.params``.

    This function is also called from error handlers, so it tolerates a
    missing output directory and a history file that exists but is empty
    instead of raising and masking the original error.

    Args:
        base_dir: Directory holding ``trial_history.csv``; created if missing.
        trial: Object exposing ``number``, ``params`` and ``user_attrs``.
        params: Unused; kept so existing callers keep working.
        duration: Trial duration in seconds.
    """
    os.makedirs(base_dir, exist_ok=True)
    history_path = os.path.join(base_dir, "trial_history.csv")

    # Fetch losses from trial user_attrs
    loss_hist = trial.user_attrs.get("loss_history", {})
    record = {
        "number": trial.number,
        "duration_sec": duration,
        "train_loss": loss_hist.get("train_loss", []),
        "val_loss": loss_hist.get("val_loss", []),
        **trial.params,
    }
    row = pd.DataFrame([record])

    # Re-read and concatenate (rather than appending text) because trials may
    # carry different parameter sets, i.e. different columns.
    try:
        previous = pd.read_csv(history_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        previous = None

    if previous is None:
        df = row
    else:
        df = pd.concat([previous, row], ignore_index=True)
    df.to_csv(history_path, index=False)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _suggest_trial_params(trial, base_params):
    """Return a copy of ``base_params`` overridden with values sampled from ``trial``."""
    # Shallow copy is enough: every list-valued entry touched below is rebuilt
    # before it is appended to, so ``base_params`` itself is never mutated.
    params = base_params.copy()

    # Training-related params
    params["batch_size"] = trial.suggest_categorical("batch_size", [64, 128])
    params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.02, 0.1)

    # Spectra encoder-related params
    params["peak_dropout"] = trial.suggest_float("peak_dropout", 0.1, 0.5)
    params["formula_attn_heads"] = trial.suggest_categorical("formula_attn_heads", [2, 4])
    params["formula_transformer_layers"] = trial.suggest_categorical("formula_transformer_layers", [2, 4])

    # Layer widths are sampled as comma-separated strings and decoded here.
    choice = trial.suggest_categorical(
        "formula_dims",
        ["64,128", "512,256", "256,512", "128", "256"],
    )
    params["formula_dims"] = [int(x) for x in choice.split(",")]

    # Molecule encoder-related params
    params["gnn_dropout"] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
    choice = trial.suggest_categorical(
        "gnn_channels",
        ["64,128", "128,256", "256,512", "64,128,128"],
    )
    params["gnn_channels"] = [int(x) for x in choice.split(",")]

    # Ensure last layer matches final embedding dim
    final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [256, 512])
    params["formula_dims"].append(final_embedding_dim)
    params["gnn_channels"].append(final_embedding_dim)
    # Also expose the sampled value under its config key, so the trial runs
    # with the embedding size it reports to Optuna.
    params["final_embedding_dim"] = final_embedding_dim
    return params


def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
    """Optuna objective: train one sampled configuration and return its validation loss.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        base_params: Base configuration dict; not modified.
        trial_times: List collecting the durations (seconds) of finished
            trials; used for the ETA log line and appended to on success.
        base_dir: Output directory handed to ``save_trial_result``.
        total_trials: Planned number of trials, used for the ETA estimate.

    Returns:
        The value of the monitored validation-loss metric after fitting.

    Raises:
        optuna.TrialPruned: Re-raised when the pruning callback stops the trial.
        Exception: Any other error is logged with its traceback and re-raised.
        In both cases the trial is recorded via ``save_trial_result`` first.
    """
    start_time = time.time()
    # Fallback for the handlers below in case sampling itself fails.
    params = base_params

    try:
        params = _suggest_trial_params(trial, base_params)

        logging.info(f"Formula dims: {params['formula_dims']}")
        logging.info(f"GNN channels: {params['gnn_channels']}")

        # Init seed
        pl.seed_everything(params["seed"])

        # Init dataset + datamodule
        spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
        mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
        dataset = get_ms_dataset(params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params)

        collate_fn = partial(
            ContrastiveDataset.collate_fn,
            spec_enc=params["spec_enc"],
            spectra_view=params["spectra_view"],
            mask_peak_ratio=params["mask_peak_ratio"],
            aug_cands=params["aug_cands"],
        )

        data_module = ContrastiveDataModule(
            dataset=dataset,
            collate_fn=collate_fn,
            split_pth=params["split_pth"],
            batch_size=params["batch_size"],
            num_workers=params["num_workers"],
        )

        # Init model
        model = get_model(params["model"], params)

        # Metric to optimize
        monitor_metric = f"{Stage.VAL.to_pref()}loss"
        callbacks = [
            SafePruningCallback(trial, monitor=monitor_metric),
            EpochLossTracker(trial),
        ]

        trainer = Trainer(
            accelerator=params["accelerator"],
            devices=params["devices"],
            max_epochs=params["max_epochs"],
            logger=False,
            enable_checkpointing=False,
            callbacks=callbacks,
        )

        data_module.prepare_data()
        data_module.setup()

        # Validate before training
        trainer.validate(model, datamodule=data_module)

        # Fit (may be pruned early)
        trainer.fit(model, datamodule=data_module)

        # Duration
        duration = time.time() - start_time
        trial_times.append(duration)
        avg_time = sum(trial_times) / len(trial_times)
        remaining = (total_trials - trial.number - 1) * avg_time
        logging.info(f"[Trial {trial.number}] Duration: {duration/60:.2f} min | Avg: {avg_time/60:.2f} min | ETA: {remaining/60:.2f} min")

        value = trainer.callback_metrics[monitor_metric].item()
        trial.set_user_attr("duration", duration)

        # Save progress
        save_trial_result(base_dir, trial, params, duration)

        return value

    except optuna.TrialPruned:
        # Pruning is a normal outcome, not an error: record the trial and let
        # Optuna handle it, without logging a failure traceback.
        duration = time.time() - start_time
        logging.info(f"Trial {trial.number} pruned after {duration/60:.2f} min")
        save_trial_result(base_dir, trial, params, duration)
        raise

    except Exception as e:
        duration = time.time() - start_time
        logging.exception(f"Trial {trial.number} failed: {e}")
        save_trial_result(base_dir, trial, params, duration)
        raise
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def main(args):
    """Run an Optuna hyperparameter search and save the best configuration.

    Reads the base parameters from ``args.param_pth``, creates a dated
    experiment directory under ``TEST_RESULTS_DIR``, runs ``args.n_trials``
    trials of ``objective`` (minimizing its return value, with a median
    pruner), and finally writes the base parameters overlaid with the best
    trial's parameters to ``best_params.yaml`` inside that directory.

    Args:
        args: Parsed command-line namespace; ``param_pth`` and ``n_trials``
            are the attributes read here.
    """
    with open(args.param_pth) as param_file:
        params = yaml.load(param_file, Loader=yaml.FullLoader)

    # The directory name is keyed on date + run name, so a rerun on the same
    # day reuses the existing directory (exist_ok=True).
    # NOTE(review): TEST_RESULTS_DIR is presumably a pathlib.Path (it is
    # combined with "/") -- confirm against mvp.definitions.
    date_tag = datetime.datetime.now().strftime("%Y%m%d")
    base_dir = str(TEST_RESULTS_DIR / f"{date_tag}_{params['run_name']}_optuna")
    os.makedirs(base_dir, exist_ok=True)
    params["experiment_dir"] = base_dir

    # Setup logging
    setup_logging(os.path.join(base_dir, "optuna.log"))

    # Shared across trials: objective appends each trial's duration here and
    # uses the running average for its ETA log line.
    trial_times = []

    def run_trial(trial):
        # Bind the per-study context so Optuna only has to supply the trial.
        return objective(trial, params, trial_times, base_dir, args.n_trials)

    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(),
    )
    study.optimize(run_trial, n_trials=args.n_trials)

    # Print best trial
    logging.info("\nBest trial:")
    winning_params = study.best_trial.params
    logging.info(winning_params)

    # Merge base params with best trial (trial values win on key clashes)
    best_params = {**params, **winning_params}

    # Save best params to YAML
    best_param_path = os.path.join(base_dir, "best_params.yaml")
    with open(best_param_path, "w") as out_file:
        yaml.dump(best_params, out_file)

    logging.info(f"\nBest parameters saved to: {best_param_path}")
    logging.info(f"Run training with: python train.py --param_pth {best_param_path}")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
    # When "__file__" is missing from the module globals (presumably an
    # interactive/notebook session -- confirm), parse an empty argument list
    # so the parser defaults apply; otherwise pass None, which makes argparse
    # read the real command line.
    cli_argv = None if "__file__" in globals() else []
    main(parser.parse_args(cli_argv))
|
notebooks/2v1.ipynb
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 33,
|
| 6 |
+
"id": "d3fe3363",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pickle\n",
|
| 11 |
+
"from rdkit import Chem\n",
|
| 12 |
+
"import pandas\n",
|
| 13 |
+
"import matplotlib.pyplot as plt\n",
|
| 14 |
+
"import numpy as np\n",
|
| 15 |
+
"from rdkit.Chem.Draw import MolsToGridImage"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 2,
|
| 21 |
+
"id": "efdf1cc0",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"with open(\"/data/yzhouc01/FILIP-MS/experiments/20250824_filipContrastive/result_MassSpecGym_retrieval_candidates_formula.pkl\", 'rb') as f:\n",
|
| 26 |
+
" result = pickle.load(f)"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": null,
|
| 32 |
+
"id": "0113e869",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": []
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": 8,
|
| 40 |
+
"id": "de9cacb6",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [
|
| 43 |
+
{
|
| 44 |
+
"data": {
|
| 45 |
+
"text/html": [
|
| 46 |
+
"<div>\n",
|
| 47 |
+
"<style scoped>\n",
|
| 48 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 49 |
+
" vertical-align: middle;\n",
|
| 50 |
+
" }\n",
|
| 51 |
+
"\n",
|
| 52 |
+
" .dataframe tbody tr th {\n",
|
| 53 |
+
" vertical-align: top;\n",
|
| 54 |
+
" }\n",
|
| 55 |
+
"\n",
|
| 56 |
+
" .dataframe thead th {\n",
|
| 57 |
+
" text-align: right;\n",
|
| 58 |
+
" }\n",
|
| 59 |
+
"</style>\n",
|
| 60 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 61 |
+
" <thead>\n",
|
| 62 |
+
" <tr style=\"text-align: right;\">\n",
|
| 63 |
+
" <th></th>\n",
|
| 64 |
+
" <th>identifier</th>\n",
|
| 65 |
+
" <th>candidates</th>\n",
|
| 66 |
+
" <th>scores</th>\n",
|
| 67 |
+
" <th>labels</th>\n",
|
| 68 |
+
" <th>rank</th>\n",
|
| 69 |
+
" </tr>\n",
|
| 70 |
+
" </thead>\n",
|
| 71 |
+
" <tbody>\n",
|
| 72 |
+
" <tr>\n",
|
| 73 |
+
" <th>17551</th>\n",
|
| 74 |
+
" <td>MassSpecGymID0414164</td>\n",
|
| 75 |
+
" <td>[CN(C)[C@H]1[C@@H]2[C@H]([C@@H]3C(=C)C4=C(C=CC...</td>\n",
|
| 76 |
+
" <td>[0.5793027877807617, 0.5367040038108826, 0.487...</td>\n",
|
| 77 |
+
" <td>[True, False, False, False, False, False, Fals...</td>\n",
|
| 78 |
+
" <td>7</td>\n",
|
| 79 |
+
" </tr>\n",
|
| 80 |
+
" <tr>\n",
|
| 81 |
+
" <th>17552</th>\n",
|
| 82 |
+
" <td>MassSpecGymID0414165</td>\n",
|
| 83 |
+
" <td>[CN(C)[C@H]1[C@@H]2[C@H]([C@@H]3C(=C)C4=C(C(=C...</td>\n",
|
| 84 |
+
" <td>[0.38384538888931274, 0.24421672523021698, 0.2...</td>\n",
|
| 85 |
+
" <td>[True, False, False, False, False, False, Fals...</td>\n",
|
| 86 |
+
" <td>31</td>\n",
|
| 87 |
+
" </tr>\n",
|
| 88 |
+
" <tr>\n",
|
| 89 |
+
" <th>17553</th>\n",
|
| 90 |
+
" <td>MassSpecGymID0414166</td>\n",
|
| 91 |
+
" <td>[C[C@H]1/C=C/C=C/2\\CO[C@H]3[C@@]2([C@@H](C=C([...</td>\n",
|
| 92 |
+
" <td>[0.6297411918640137, 0.5269991159439087, 0.183...</td>\n",
|
| 93 |
+
" <td>[True, False, False, False, False, False, Fals...</td>\n",
|
| 94 |
+
" <td>11</td>\n",
|
| 95 |
+
" </tr>\n",
|
| 96 |
+
" <tr>\n",
|
| 97 |
+
" <th>17554</th>\n",
|
| 98 |
+
" <td>MassSpecGymID0414167</td>\n",
|
| 99 |
+
" <td>[C[C@H]1/C=C/C=C/2\\CO[C@H]3[C@@]2([C@@H](C=C([...</td>\n",
|
| 100 |
+
" <td>[0.613699197769165, 0.554176390171051, 0.21911...</td>\n",
|
| 101 |
+
" <td>[True, False, False, False, False, False, Fals...</td>\n",
|
| 102 |
+
" <td>12</td>\n",
|
| 103 |
+
" </tr>\n",
|
| 104 |
+
" <tr>\n",
|
| 105 |
+
" <th>17555</th>\n",
|
| 106 |
+
" <td>MassSpecGymID0414171</td>\n",
|
| 107 |
+
" <td>[C[C@@]1([C@H]2C[C@H]3[C@@H](C(=O)C(=C([C@]3(C...</td>\n",
|
| 108 |
+
" <td>[0.5223979949951172, 0.5548790693283081, 0.512...</td>\n",
|
| 109 |
+
" <td>[True, False, False, False, False, False, Fals...</td>\n",
|
| 110 |
+
" <td>14</td>\n",
|
| 111 |
+
" </tr>\n",
|
| 112 |
+
" </tbody>\n",
|
| 113 |
+
"</table>\n",
|
| 114 |
+
"</div>"
|
| 115 |
+
],
|
| 116 |
+
"text/plain": [
|
| 117 |
+
" identifier \\\n",
|
| 118 |
+
"17551 MassSpecGymID0414164 \n",
|
| 119 |
+
"17552 MassSpecGymID0414165 \n",
|
| 120 |
+
"17553 MassSpecGymID0414166 \n",
|
| 121 |
+
"17554 MassSpecGymID0414167 \n",
|
| 122 |
+
"17555 MassSpecGymID0414171 \n",
|
| 123 |
+
"\n",
|
| 124 |
+
" candidates \\\n",
|
| 125 |
+
"17551 [CN(C)[C@H]1[C@@H]2[C@H]([C@@H]3C(=C)C4=C(C=CC... \n",
|
| 126 |
+
"17552 [CN(C)[C@H]1[C@@H]2[C@H]([C@@H]3C(=C)C4=C(C(=C... \n",
|
| 127 |
+
"17553 [C[C@H]1/C=C/C=C/2\\CO[C@H]3[C@@]2([C@@H](C=C([... \n",
|
| 128 |
+
"17554 [C[C@H]1/C=C/C=C/2\\CO[C@H]3[C@@]2([C@@H](C=C([... \n",
|
| 129 |
+
"17555 [C[C@@]1([C@H]2C[C@H]3[C@@H](C(=O)C(=C([C@]3(C... \n",
|
| 130 |
+
"\n",
|
| 131 |
+
" scores \\\n",
|
| 132 |
+
"17551 [0.5793027877807617, 0.5367040038108826, 0.487... \n",
|
| 133 |
+
"17552 [0.38384538888931274, 0.24421672523021698, 0.2... \n",
|
| 134 |
+
"17553 [0.6297411918640137, 0.5269991159439087, 0.183... \n",
|
| 135 |
+
"17554 [0.613699197769165, 0.554176390171051, 0.21911... \n",
|
| 136 |
+
"17555 [0.5223979949951172, 0.5548790693283081, 0.512... \n",
|
| 137 |
+
"\n",
|
| 138 |
+
" labels rank \n",
|
| 139 |
+
"17551 [True, False, False, False, False, False, Fals... 7 \n",
|
| 140 |
+
"17552 [True, False, False, False, False, False, Fals... 31 \n",
|
| 141 |
+
"17553 [True, False, False, False, False, False, Fals... 11 \n",
|
| 142 |
+
"17554 [True, False, False, False, False, False, Fals... 12 \n",
|
| 143 |
+
"17555 [True, False, False, False, False, False, Fals... 14 "
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
"execution_count": 8,
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"output_type": "execute_result"
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"source": [
|
| 152 |
+
"result.tail()"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": 52,
|
| 158 |
+
"id": "f420c511",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [
|
| 161 |
+
{
|
| 162 |
+
"name": "stderr",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"[19:30:58] \n",
|
| 166 |
+
"\n",
|
| 167 |
+
"****\n",
|
| 168 |
+
"Pre-condition Violation\n",
|
| 169 |
+
"bad size\n",
|
| 170 |
+
"Violation occurred on line 183 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/MolDraw2D/MolDraw2D.cpp\n",
|
| 171 |
+
"Failed Expression: !legends || legends->size() == mols.size()\n",
|
| 172 |
+
"----------\n",
|
| 173 |
+
"Stacktrace:\n",
|
| 174 |
+
"----------\n",
|
| 175 |
+
"****\n",
|
| 176 |
+
"\n"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"ename": "RuntimeError",
|
| 181 |
+
"evalue": "Pre-condition Violation\n\tbad size\n\tViolation occurred on line 183 in file Code/GraphMol/MolDraw2D/MolDraw2D.cpp\n\tFailed Expression: !legends || legends->size() == mols.size()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n",
|
| 182 |
+
"output_type": "error",
|
| 183 |
+
"traceback": [
|
| 184 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 185 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 186 |
+
"Cell \u001b[0;32mIn[52], line 21\u001b[0m\n\u001b[1;32m 18\u001b[0m cand2 \u001b[38;5;241m=\u001b[39m Chem\u001b[38;5;241m.\u001b[39mMolFromSmiles(candidates[sorted_scores][\u001b[38;5;241m2\u001b[39m])\n\u001b[1;32m 20\u001b[0m mols \u001b[38;5;241m=\u001b[39m [target, cand, cand1]\n\u001b[0;32m---> 21\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mMolsToGridImage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmols\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmolsPerRow\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msubImgSize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlegends\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtarget (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mtarget_score\u001b[49m\u001b[38;5;132;43;01m:\u001b[39;49;00m\u001b[38;5;124;43m.3f\u001b[39;49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcand@1 (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mcand1_score\u001b[49m\u001b[38;5;132;43;01m:\u001b[39;49;00m\u001b[38;5;124;43m.3f\u001b[39;49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n\u001b[1;32m 24\u001b[0m display(img)\n",
|
| 187 |
+
"File \u001b[0;32m/data/yzc-conda/spec/lib/python3.11/site-packages/rdkit/Chem/Draw/IPythonConsole.py:271\u001b[0m, in \u001b[0;36mShowMols\u001b[0;34m(mols, maxMols, **kwargs)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdrawOptions\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m kwargs:\n\u001b[1;32m 269\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdrawOptions\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m drawOptions\n\u001b[0;32m--> 271\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmols\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m InteractiveRenderer\u001b[38;5;241m.\u001b[39misEnabled():\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m HTML(res)\n",
|
| 188 |
+
"File \u001b[0;32m/data/yzc-conda/spec/lib/python3.11/site-packages/rdkit/Chem/Draw/__init__.py:821\u001b[0m, in \u001b[0;36mMolsToGridImage\u001b[0;34m(mols, molsPerRow, subImgSize, legends, highlightAtomLists, highlightBondLists, useSVG, returnPNG, **kwargs)\u001b[0m\n\u001b[1;32m 817\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _MolsToGridSVG(mols, molsPerRow\u001b[38;5;241m=\u001b[39mmolsPerRow, subImgSize\u001b[38;5;241m=\u001b[39msubImgSize, legends\u001b[38;5;241m=\u001b[39mlegends,\n\u001b[1;32m 818\u001b[0m highlightAtomLists\u001b[38;5;241m=\u001b[39mhighlightAtomLists,\n\u001b[1;32m 819\u001b[0m highlightBondLists\u001b[38;5;241m=\u001b[39mhighlightBondLists, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 820\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 821\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_MolsToGridImage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmols\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmolsPerRow\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmolsPerRow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msubImgSize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubImgSize\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlegends\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlegends\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43mhighlightAtomLists\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhighlightAtomLists\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mhighlightBondLists\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhighlightBondLists\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturnPNG\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturnPNG\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 189 |
+
"File \u001b[0;32m/data/yzc-conda/spec/lib/python3.11/site-packages/rdkit/Chem/Draw/__init__.py:567\u001b[0m, in \u001b[0;36m_MolsToGridImage\u001b[0;34m(mols, molsPerRow, subImgSize, legends, highlightAtomLists, highlightBondLists, drawOptions, returnPNG, **kwargs)\u001b[0m\n\u001b[1;32m 565\u001b[0m \u001b[38;5;28msetattr\u001b[39m(dops, k, v)\n\u001b[1;32m 566\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m kwargs[k]\n\u001b[0;32m--> 567\u001b[0m \u001b[43md2d\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDrawMolecules\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmols\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlegends\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlegends\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhighlightAtoms\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhighlightAtomLists\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 568\u001b[0m \u001b[43m \u001b[49m\u001b[43mhighlightBonds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhighlightBondLists\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 569\u001b[0m d2d\u001b[38;5;241m.\u001b[39mFinishDrawing()\n\u001b[1;32m 570\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m returnPNG:\n",
|
| 190 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Pre-condition Violation\n\tbad size\n\tViolation occurred on line 183 in file Code/GraphMol/MolDraw2D/MolDraw2D.cpp\n\tFailed Expression: !legends || legends->size() == mols.size()\n\tRDKIT: 2024.03.5\n\tBOOST: 1_85\n"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"data": {
|
| 195 |
+
"image/png": "",
|
| 196 |
+
"text/plain": [
|
| 197 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"output_type": "display_data"
|
| 202 |
+
}
|
| 203 |
+
],
|
| 204 |
+
"source": [
|
| 205 |
+
"row = result[result['rank'] ==2].sample(1)\n",
|
| 206 |
+
"scores = row['scores'].iloc[0]\n",
|
| 207 |
+
"plt.hist(scores, bins=np.arange(0,1,0.01))\n",
|
| 208 |
+
"ax = plt.gca()\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"target_score = scores[0]\n",
|
| 211 |
+
"cand1_score = np.max(scores)\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"plt.text(0.665, 0.9, f'target: {target_score:.3f}',transform=ax.transAxes,)\n",
|
| 214 |
+
"plt.text(0.665, 0.86, f'cand@1: {cand1_score:.3f}',transform=ax.transAxes,)\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"candidates = np.array(row['candidates'].iloc[0])\n",
|
| 217 |
+
"target = Chem.MolFromSmiles(candidates[0])\n",
|
| 218 |
+
"sorted_scores = np.argsort(scores)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"cand = Chem.MolFromSmiles(candidates[sorted_scores][0])\n",
|
| 221 |
+
"cand1 = Chem.MolFromSmiles(candidates[sorted_scores][1])\n",
|
| 222 |
+
"cand2 = Chem.MolFromSmiles(candidates[sorted_scores][2])\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"mols = [target, cand, cand1]\n",
|
| 225 |
+
"img = MolsToGridImage(mols, molsPerRow=3, subImgSize=(200, 200), legends=[f'target ({target_score:.3f})', f'cand@1 ({cand1_score:.3f})'])\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"plt.show()\n",
|
| 228 |
+
"display(img)"
|
| 229 |
+
]
|
| 230 |
+
},
|
| 231 |
+
{
|
| 232 |
+
"cell_type": "code",
|
| 233 |
+
"execution_count": null,
|
| 234 |
+
"id": "541cd229",
|
| 235 |
+
"metadata": {},
|
| 236 |
+
"outputs": [],
|
| 237 |
+
"source": []
|
| 238 |
+
}
|
| 239 |
+
],
|
| 240 |
+
"metadata": {
|
| 241 |
+
"kernelspec": {
|
| 242 |
+
"display_name": "spec",
|
| 243 |
+
"language": "python",
|
| 244 |
+
"name": "python3"
|
| 245 |
+
},
|
| 246 |
+
"language_info": {
|
| 247 |
+
"codemirror_mode": {
|
| 248 |
+
"name": "ipython",
|
| 249 |
+
"version": 3
|
| 250 |
+
},
|
| 251 |
+
"file_extension": ".py",
|
| 252 |
+
"mimetype": "text/x-python",
|
| 253 |
+
"name": "python",
|
| 254 |
+
"nbconvert_exporter": "python",
|
| 255 |
+
"pygments_lexer": "ipython3",
|
| 256 |
+
"version": "3.11.7"
|
| 257 |
+
}
|
| 258 |
+
},
|
| 259 |
+
"nbformat": 4,
|
| 260 |
+
"nbformat_minor": 5
|
| 261 |
+
}
|
notebooks/UMAP_spectra_embeddings.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/attribute_viz.ipynb
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "76c4ed82",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"# preprocessing code\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"# input is a labeled formulas file from mist code\n",
|
| 13 |
+
"# molecule is smiles\n",
|
| 14 |
+
"\n"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": null,
|
| 20 |
+
"id": "04ea680e",
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"# load data"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"id": "990ccd94",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"# load model"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": null,
|
| 40 |
+
"id": "77334bf5",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"# encode spectra and molecules\n",
|
| 45 |
+
"# vilisualize attributes"
|
| 46 |
+
]
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"metadata": {
|
| 50 |
+
"language_info": {
|
| 51 |
+
"name": "python"
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"nbformat": 4,
|
| 55 |
+
"nbformat_minor": 5
|
| 56 |
+
}
|
notebooks/filip_viz.ipynb
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import torch\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"import plotly.graph_objects as go\n",
|
| 12 |
+
"from plotly.subplots import make_subplots\n",
|
| 13 |
+
"from rdkit import Chem\n",
|
| 14 |
+
"from rdkit.Chem import rdDepictor\n",
|
| 15 |
+
"from rdkit.Chem.Draw import rdMolDraw2D"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "markdown",
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"source": [
|
| 22 |
+
"## load model and dataset"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": 2,
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [
|
| 30 |
+
{
|
| 31 |
+
"name": "stdout",
|
| 32 |
+
"output_type": "stream",
|
| 33 |
+
"text": [
|
| 34 |
+
"Data path: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv\n",
|
| 35 |
+
"Processing formula spectra\n"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"name": "stderr",
|
| 40 |
+
"output_type": "stream",
|
| 41 |
+
"text": [
|
| 42 |
+
"100%|██████████| 213548/213548 [00:16<00:00, 13048.41it/s]\n",
|
| 43 |
+
"/data/yzhouc01/FILIP-MS/mvp/data/datasets.py:221: SettingWithCopyWarning: \n",
|
| 44 |
+
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 45 |
+
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
| 48 |
+
" tmp_df['spec'] = tmp_df.apply(lambda row: data_utils.make_tmp_subformula_spectra(row), axis=1)\n"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"name": "stdout",
|
| 53 |
+
"output_type": "stream",
|
| 54 |
+
"text": [
|
| 55 |
+
"Loaded Model from checkpoint\n"
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
],
|
| 59 |
+
"source": [
|
| 60 |
+
"import sys\n",
|
| 61 |
+
"sys.path.insert(0, \"/data/yzhouc01/MassSpecGym\")\n",
|
| 62 |
+
"sys.path.insert(0, \"/data/yzhouc01/FILIP-MS\")\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"from rdkit import RDLogger\n",
|
| 65 |
+
"import pytorch_lightning as pl\n",
|
| 66 |
+
"from pytorch_lightning import Trainer\n",
|
| 67 |
+
"from massspecgym.models.base import Stage\n",
|
| 68 |
+
"import os\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset\n",
|
| 71 |
+
"from mvp.utils.models import get_model\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"from mvp.definitions import TEST_RESULTS_DIR\n",
|
| 74 |
+
"import yaml\n",
|
| 75 |
+
"from functools import partial\n",
|
| 76 |
+
"# Suppress RDKit warnings and errors\n",
|
| 77 |
+
"lg = RDLogger.logger()\n",
|
| 78 |
+
"lg.setLevel(RDLogger.CRITICAL)\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"# Load model and data\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/lightning_logs/version_0/hparams.yaml'\n",
|
| 83 |
+
"with open(param_pth) as f:\n",
|
| 84 |
+
" params = yaml.load(f, Loader=yaml.FullLoader)\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"spec_featurizer = get_spec_featurizer(params['spectra_view'], params)\n",
|
| 87 |
+
"mol_featurizer = get_mol_featurizer(params['molecule_view'], params)\n",
|
| 88 |
+
"dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# load model\n",
|
| 92 |
+
"import torch \n",
|
| 93 |
+
"checkpoint_pth = \"/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt\"\n",
|
| 94 |
+
"params['checkpoint_pth'] = checkpoint_pth\n",
|
| 95 |
+
"model = get_model(params['model'], params)"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "markdown",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"## mol/spec embeddings"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": 7,
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"outputs": [
|
| 110 |
+
{
|
| 111 |
+
"name": "stdout",
|
| 112 |
+
"output_type": "stream",
|
| 113 |
+
"text": [
|
| 114 |
+
"Atom 0: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 115 |
+
"Atom 1: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 116 |
+
"Atom 2: O, Graph node features: tensor([0.1600, 0.0000, 0.0000, 1.0000, 0.0000])\n",
|
| 117 |
+
"Atom 3: N, Graph node features: tensor([0.1401, 0.0000, 0.0000, 0.0000, 1.0000])\n",
|
| 118 |
+
"Atom 4: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 119 |
+
"Atom 5: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 120 |
+
"Atom 6: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 121 |
+
"Atom 7: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 122 |
+
"Atom 8: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 123 |
+
"Atom 9: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 124 |
+
"Atom 10: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 125 |
+
"Atom 11: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 126 |
+
"Atom 12: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 127 |
+
"Atom 13: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 128 |
+
"Atom 14: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 129 |
+
"Atom 15: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 130 |
+
"Atom 16: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n",
|
| 131 |
+
"Atom 17: O, Graph node features: tensor([0.1600, 0.0000, 0.0000, 1.0000, 0.0000])\n",
|
| 132 |
+
"Atom 18: O, Graph node features: tensor([0.1600, 0.0000, 0.0000, 1.0000, 0.0000])\n",
|
| 133 |
+
"Atom 19: O, Graph node features: tensor([0.1600, 0.0000, 0.0000, 1.0000, 0.0000])\n",
|
| 134 |
+
"Atom 20: C, Graph node features: tensor([0.1201, 0.0000, 1.0000, 0.0000, 0.0000])\n"
|
| 135 |
+
]
|
| 136 |
+
}
|
| 137 |
+
],
|
| 138 |
+
"source": [
|
| 139 |
+
"# sanity check, rdkit order is preserved\n",
|
| 140 |
+
"i = 0\n",
|
| 141 |
+
"s = dataset.metadata.iloc[i]['smiles']\n",
|
| 142 |
+
"mol = Chem.MolFromSmiles(s, sanitize=True) \n",
|
| 143 |
+
"g = dataset[i]['mol']\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"# Compare RDKit atoms vs graph nodes\n",
|
| 146 |
+
"for atom in mol.GetAtoms():\n",
|
| 147 |
+
" idx = atom.GetIdx()\n",
|
| 148 |
+
" print(f\"Atom {idx}: {atom.GetSymbol()}, Graph node features: {g.ndata['h'][idx][:5]}\")"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "code",
|
| 153 |
+
"execution_count": 5,
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"outputs": [],
|
| 156 |
+
"source": [
|
| 157 |
+
"import numpy as np\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"# Atomic masses corresponding to your atom_labels\n",
|
| 160 |
+
"ATOM_LABELS = ['H', 'C', 'O', 'N', 'P', 'S', 'Cl', 'F', 'Br', 'I', 'B', 'As', 'Si', 'Se']\n",
|
| 161 |
+
"ATOM_MASSES = np.array([\n",
|
| 162 |
+
" 1.0078, 12.0000, 15.9949, 14.0031, 30.9738, 31.9721, \n",
|
| 163 |
+
" 35.45, 18.9984, 79.90, 126.90, 10.811, 74.9216, 28.085, 78.96\n",
|
| 164 |
+
"])\n",
|
| 165 |
+
"norm_vector = [102.0, 59.0, 25.0, 13.0, 3.0, 6.0, 6.0, 17.0, 4.0, 4.0, 1.0, 1.0, 5.0, 2.0]\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"def spectra_from_encoding(spectral_tensor, norm_vector=norm_vector):\n",
|
| 168 |
+
" \"\"\"\n",
|
| 169 |
+
" Convert encoded spectra (num_peaks x 15) into m/z, intensities, and molecular formulas.\n",
|
| 170 |
+
" Can undo normalization if a norm_vector is provided.\n",
|
| 171 |
+
" \n",
|
| 172 |
+
" Args:\n",
|
| 173 |
+
" spectral_tensor (np.ndarray or torch.Tensor): [num_peaks, 15]\n",
|
| 174 |
+
" norm_vector (np.ndarray or list): length 14, normalization factor for each atom\n",
|
| 175 |
+
" \n",
|
| 176 |
+
" Returns:\n",
|
| 177 |
+
" mzs (list of float): list of m/z values\n",
|
| 178 |
+
" intensities (list of float): list of intensities\n",
|
| 179 |
+
" formulas (list of str): molecular formula strings\n",
|
| 180 |
+
" \"\"\"\n",
|
| 181 |
+
" if hasattr(spectral_tensor, \"detach\"):\n",
|
| 182 |
+
" spectral_tensor = spectral_tensor.detach().cpu().numpy()\n",
|
| 183 |
+
" \n",
|
| 184 |
+
" counts = spectral_tensor[:, :14] # atom counts\n",
|
| 185 |
+
" intensities = spectral_tensor[:, 14] # last col = intensity\n",
|
| 186 |
+
" \n",
|
| 187 |
+
" # Undo normalization\n",
|
| 188 |
+
" if norm_vector is not None:\n",
|
| 189 |
+
" counts = counts * np.array(norm_vector)\n",
|
| 190 |
+
" \n",
|
| 191 |
+
" # Compute m/z\n",
|
| 192 |
+
" mzs = (counts * ATOM_MASSES).sum(axis=1)\n",
|
| 193 |
+
" \n",
|
| 194 |
+
" # Build molecular formula strings\n",
|
| 195 |
+
" formulas = []\n",
|
| 196 |
+
" for peak_counts in counts:\n",
|
| 197 |
+
" formula_parts = []\n",
|
| 198 |
+
" for elem, count in zip(ATOM_LABELS, peak_counts):\n",
|
| 199 |
+
" n = int(round(count))\n",
|
| 200 |
+
" if n > 0:\n",
|
| 201 |
+
" formula_parts.append(f\"{elem}{n if n > 1 else ''}\")\n",
|
| 202 |
+
" formulas.append(\"\".join(formula_parts) if formula_parts else \"Unknown\")\n",
|
| 203 |
+
" \n",
|
| 204 |
+
" return mzs.tolist(), intensities.tolist(), formulas"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "markdown",
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"source": [
|
| 211 |
+
"## visualization"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": 8,
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"\n",
|
| 221 |
+
"import torch.nn.functional as F\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"def mol_to_graph_coords(mol):\n",
|
| 224 |
+
" \"\"\"Return atom coordinates and bond list for a molecule.\"\"\"\n",
|
| 225 |
+
" rdDepictor.Compute2DCoords(mol)\n",
|
| 226 |
+
" conf = mol.GetConformer()\n",
|
| 227 |
+
" coords = {i: conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())}\n",
|
| 228 |
+
" bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]\n",
|
| 229 |
+
" return coords, bonds\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"def interactive_attention_visualization(spectral_embeds, graph_embeds, \n",
|
| 232 |
+
" peak_mzs, peak_intensities, peak_formulas, mol):\n",
|
| 233 |
+
" \"\"\"\n",
|
| 234 |
+
" Interactive visualization of peak-node similarity with color scale legend.\n",
|
| 235 |
+
" - Clicking a peak recolors nodes by similarity\n",
|
| 236 |
+
" - Clicking a node recolors peaks by similarity\n",
|
| 237 |
+
" \"\"\"\n",
|
| 238 |
+
" # Similarity matrix\n",
|
| 239 |
+
" spectral_embeds = F.normalize(spectral_embeds, p=2, dim=-1)\n",
|
| 240 |
+
" graph_embeds = F.normalize(graph_embeds, p=2, dim=-1)\n",
|
| 241 |
+
" \n",
|
| 242 |
+
" similarity = torch.matmul(spectral_embeds, graph_embeds.T).detach().cpu().numpy()\n",
|
| 243 |
+
" sim_norm = (similarity - similarity.min()) / (similarity.max() - similarity.min() + 1e-8)\n",
|
| 244 |
+
" \n",
|
| 245 |
+
" num_peaks, num_nodes = similarity.shape\n",
|
| 246 |
+
" \n",
|
| 247 |
+
" # --- Molecule graph ---\n",
|
| 248 |
+
" coords, bonds = mol_to_graph_coords(mol)\n",
|
| 249 |
+
" atom_labels = [a.GetSymbol() for a in mol.GetAtoms()]\n",
|
| 250 |
+
" atom_x = [coords[i].x for i in range(num_nodes)]\n",
|
| 251 |
+
" atom_y = [coords[i].y for i in range(num_nodes)]\n",
|
| 252 |
+
" \n",
|
| 253 |
+
" # --- Spectrum trace ---\n",
|
| 254 |
+
" spectrum_trace = go.Bar(\n",
|
| 255 |
+
" x=peak_mzs,\n",
|
| 256 |
+
" y=peak_intensities,\n",
|
| 257 |
+
" name='peak',\n",
|
| 258 |
+
" marker=dict(color=\"lightgray\", colorscale=\"Viridis\", cmin=0, cmax=1,\n",
|
| 259 |
+
" colorbar=dict(title=\"Similarity\", len=0.8, y=0.5)),\n",
|
| 260 |
+
" hovertext=[f\"Formula {f}\" for f in peak_formulas],\n",
|
| 261 |
+
" customdata=list(range(num_peaks)) # peak index\n",
|
| 262 |
+
" )\n",
|
| 263 |
+
" \n",
|
| 264 |
+
" # --- Graph nodes ---\n",
|
| 265 |
+
" graph_nodes = go.Scatter(\n",
|
| 266 |
+
" x=atom_x, y=atom_y,\n",
|
| 267 |
+
" mode=\"markers+text\",\n",
|
| 268 |
+
" name='node',\n",
|
| 269 |
+
" text=atom_labels,\n",
|
| 270 |
+
" textposition=\"middle center\",\n",
|
| 271 |
+
" marker=dict(size=20, color=\"lightgray\", colorscale=\"Viridis\", cmin=0, cmax=1,\n",
|
| 272 |
+
" colorbar=dict(title=\"Similarity\", len=0.8, y=0.5)),\n",
|
| 273 |
+
" customdata=list(range(num_nodes)),\n",
|
| 274 |
+
" # hovertext=[f\"Atom {i} ({label})\" for i, label in enumerate(atom_labels)]\n",
|
| 275 |
+
" )\n",
|
| 276 |
+
" \n",
|
| 277 |
+
" # --- Graph bonds ---\n",
|
| 278 |
+
" edge_x, edge_y = [], []\n",
|
| 279 |
+
" for i, j in bonds:\n",
|
| 280 |
+
" edge_x += [coords[i].x, coords[j].x, None]\n",
|
| 281 |
+
" edge_y += [coords[i].y, coords[j].y, None]\n",
|
| 282 |
+
" graph_edges = go.Scatter(\n",
|
| 283 |
+
" x=edge_x, y=edge_y,\n",
|
| 284 |
+
" mode=\"lines\", line=dict(color=\"gray\", width=2),\n",
|
| 285 |
+
" hoverinfo=\"none\", showlegend=False\n",
|
| 286 |
+
" )\n",
|
| 287 |
+
" \n",
|
| 288 |
+
" # --- Subplots ---\n",
|
| 289 |
+
" fig = make_subplots(rows=1, cols=2, subplot_titles=(\"Spectrum\", \"Molecule\"), \n",
|
| 290 |
+
" column_widths=[0.6, 0.4])\n",
|
| 291 |
+
" \n",
|
| 292 |
+
" fig.add_trace(spectrum_trace, row=1, col=1)\n",
|
| 293 |
+
" fig.add_trace(graph_edges, row=1, col=2)\n",
|
| 294 |
+
" fig.add_trace(graph_nodes, row=1, col=2)\n",
|
| 295 |
+
" \n",
|
| 296 |
+
" fig.update_xaxes(title=\"m/z\", row=1, col=1)\n",
|
| 297 |
+
" fig.update_yaxes(title=\"Intensity\", row=1, col=1)\n",
|
| 298 |
+
" fig.update_xaxes(visible=False, row=1, col=2)\n",
|
| 299 |
+
" fig.update_yaxes(visible=False, row=1, col=2)\n",
|
| 300 |
+
" \n",
|
| 301 |
+
" fig.update_layout(title=\"Peak ↔ Node Similarity\", showlegend=False)\n",
|
| 302 |
+
" \n",
|
| 303 |
+
" # --- Interactivity ---\n",
|
| 304 |
+
" from ipywidgets import VBox\n",
|
| 305 |
+
" fw = go.FigureWidget(fig)\n",
|
| 306 |
+
"\n",
|
| 307 |
+
" def highlight_nodes(trace, points, selector):\n",
|
| 308 |
+
" \"\"\"Click on peak → recolor nodes\"\"\"\n",
|
| 309 |
+
" if points.point_inds:\n",
|
| 310 |
+
" peak_idx = points.point_inds[0]\n",
|
| 311 |
+
" scores = sim_norm[peak_idx, :]\n",
|
| 312 |
+
" with fw.batch_update():\n",
|
| 313 |
+
" fw.data[2].marker.color = scores\n",
|
| 314 |
+
" fw.data[0].marker.color = [\"red\" if i == peak_idx else \"lightgray\" for i in range(num_peaks)]\n",
|
| 315 |
+
"\n",
|
| 316 |
+
" def highlight_peaks(trace, points, selector):\n",
|
| 317 |
+
" \"\"\"Click on node → recolor peaks\"\"\"\n",
|
| 318 |
+
" if points.point_inds:\n",
|
| 319 |
+
" node_idx = points.point_inds[0]\n",
|
| 320 |
+
" scores = sim_norm[:, node_idx]\n",
|
| 321 |
+
" with fw.batch_update():\n",
|
| 322 |
+
" fw.data[0].marker.color = scores\n",
|
| 323 |
+
" fw.data[2].marker.color = [\"red\" if i == node_idx else \"lightgray\" for i in range(num_nodes)]\n",
|
| 324 |
+
" \n",
|
| 325 |
+
" fw.data[0].on_click(highlight_nodes) # spectrum\n",
|
| 326 |
+
" fw.data[2].on_click(highlight_peaks) # nodes\n",
|
| 327 |
+
" \n",
|
| 328 |
+
" return fw\n"
|
| 329 |
+
]
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"cell_type": "code",
|
| 333 |
+
"execution_count": 7,
|
| 334 |
+
"metadata": {},
|
| 335 |
+
"outputs": [
|
| 336 |
+
{
|
| 337 |
+
"data": {
|
| 338 |
+
"text/html": [
|
| 339 |
+
"<div>\n",
|
| 340 |
+
"<style scoped>\n",
|
| 341 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 342 |
+
" vertical-align: middle;\n",
|
| 343 |
+
" }\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" .dataframe tbody tr th {\n",
|
| 346 |
+
" vertical-align: top;\n",
|
| 347 |
+
" }\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" .dataframe thead th {\n",
|
| 350 |
+
" text-align: right;\n",
|
| 351 |
+
" }\n",
|
| 352 |
+
"</style>\n",
|
| 353 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 354 |
+
" <thead>\n",
|
| 355 |
+
" <tr style=\"text-align: right;\">\n",
|
| 356 |
+
" <th></th>\n",
|
| 357 |
+
" <th>identifier</th>\n",
|
| 358 |
+
" <th>mzs</th>\n",
|
| 359 |
+
" <th>intensities</th>\n",
|
| 360 |
+
" <th>smiles</th>\n",
|
| 361 |
+
" <th>inchikey</th>\n",
|
| 362 |
+
" <th>formula</th>\n",
|
| 363 |
+
" <th>precursor_formula</th>\n",
|
| 364 |
+
" <th>parent_mass</th>\n",
|
| 365 |
+
" <th>precursor_mz</th>\n",
|
| 366 |
+
" <th>adduct</th>\n",
|
| 367 |
+
" <th>instrument_type</th>\n",
|
| 368 |
+
" <th>collision_energy</th>\n",
|
| 369 |
+
" <th>fold</th>\n",
|
| 370 |
+
" <th>simulation_challenge</th>\n",
|
| 371 |
+
" <th>formulas</th>\n",
|
| 372 |
+
" <th>formula_mzs</th>\n",
|
| 373 |
+
" <th>formula_intensities</th>\n",
|
| 374 |
+
" </tr>\n",
|
| 375 |
+
" </thead>\n",
|
| 376 |
+
" <tbody>\n",
|
| 377 |
+
" <tr>\n",
|
| 378 |
+
" <th>76895</th>\n",
|
| 379 |
+
" <td>MassSpecGymID0098304</td>\n",
|
| 380 |
+
" <td>65.0386,77.0387,79.0543,80.0621,82.0414,89.038...</td>\n",
|
| 381 |
+
" <td>0.07907907907907907,0.05905905905905906,0.1791...</td>\n",
|
| 382 |
+
" <td>COC1=C(C=CC(=C1)/C=C/C=O)O</td>\n",
|
| 383 |
+
" <td>DKZBBWMURDFHNE</td>\n",
|
| 384 |
+
" <td>C10H10O3</td>\n",
|
| 385 |
+
" <td>C10H11O3</td>\n",
|
| 386 |
+
" <td>178.063024</td>\n",
|
| 387 |
+
" <td>179.0703</td>\n",
|
| 388 |
+
" <td>[M+H]+</td>\n",
|
| 389 |
+
" <td>Orbitrap</td>\n",
|
| 390 |
+
" <td>34.023357</td>\n",
|
| 391 |
+
" <td>train</td>\n",
|
| 392 |
+
" <td>True</td>\n",
|
| 393 |
+
" <td>[C5H4, C6H4, C6H6, C6H7, C5H5O, C7H4, C7H5, C7...</td>\n",
|
| 394 |
+
" <td>[65.0386, 77.0387, 79.0543, 80.0621, 82.0414, ...</td>\n",
|
| 395 |
+
" <td>[0.281209886028212, 0.24302057526061452, 0.423...</td>\n",
|
| 396 |
+
" </tr>\n",
|
| 397 |
+
" <tr>\n",
|
| 398 |
+
" <th>72767</th>\n",
|
| 399 |
+
" <td>MassSpecGymID0092123</td>\n",
|
| 400 |
+
" <td>112.0506</td>\n",
|
| 401 |
+
" <td>1.0</td>\n",
|
| 402 |
+
" <td>C1[C@@H](O[C@@H](S1)CO)N2C=CC(=NC2=O)N</td>\n",
|
| 403 |
+
" <td>JTEGQNOMFQHVDC</td>\n",
|
| 404 |
+
" <td>C8H11N3O3S</td>\n",
|
| 405 |
+
" <td>C8H12N3O3S</td>\n",
|
| 406 |
+
" <td>229.052124</td>\n",
|
| 407 |
+
" <td>230.0594</td>\n",
|
| 408 |
+
" <td>[M+H]+</td>\n",
|
| 409 |
+
" <td>Orbitrap</td>\n",
|
| 410 |
+
" <td>15.000000</td>\n",
|
| 411 |
+
" <td>train</td>\n",
|
| 412 |
+
" <td>True</td>\n",
|
| 413 |
+
" <td>[C4H5N3O]</td>\n",
|
| 414 |
+
" <td>[112.0506]</td>\n",
|
| 415 |
+
" <td>[1.0]</td>\n",
|
| 416 |
+
" </tr>\n",
|
| 417 |
+
" <tr>\n",
|
| 418 |
+
" <th>221715</th>\n",
|
| 419 |
+
" <td>MassSpecGymID0401545</td>\n",
|
| 420 |
+
" <td>50.992828,51.670795,51.675632,51.678509,51.681...</td>\n",
|
| 421 |
+
" <td>0.0006785717253652819,0.001297853734957549,0.0...</td>\n",
|
| 422 |
+
" <td>CC(=O)N1CCC(CC1)C(=O)O</td>\n",
|
| 423 |
+
" <td>WFCLWJHOKCQYOQ</td>\n",
|
| 424 |
+
" <td>C8H13NO3</td>\n",
|
| 425 |
+
" <td>C8H14NO3</td>\n",
|
| 426 |
+
" <td>171.089724</td>\n",
|
| 427 |
+
" <td>172.0970</td>\n",
|
| 428 |
+
" <td>[M+H]+</td>\n",
|
| 429 |
+
" <td>Orbitrap</td>\n",
|
| 430 |
+
" <td>NaN</td>\n",
|
| 431 |
+
" <td>train</td>\n",
|
| 432 |
+
" <td>False</td>\n",
|
| 433 |
+
" <td>[C5H9N, C5H5NO2, C5H5NO2, C6H7O2, C6H7O2, C6H7...</td>\n",
|
| 434 |
+
" <td>[84.080452, 112.038101, 112.041489, 112.048691...</td>\n",
|
| 435 |
+
" <td>[0.11135079703351926, 0.060621778264910706, 0....</td>\n",
|
| 436 |
+
" </tr>\n",
|
| 437 |
+
" </tbody>\n",
|
| 438 |
+
"</table>\n",
|
| 439 |
+
"</div>"
|
| 440 |
+
],
|
| 441 |
+
"text/plain": [
|
| 442 |
+
" identifier \\\n",
|
| 443 |
+
"76895 MassSpecGymID0098304 \n",
|
| 444 |
+
"72767 MassSpecGymID0092123 \n",
|
| 445 |
+
"221715 MassSpecGymID0401545 \n",
|
| 446 |
+
"\n",
|
| 447 |
+
" mzs \\\n",
|
| 448 |
+
"76895 65.0386,77.0387,79.0543,80.0621,82.0414,89.038... \n",
|
| 449 |
+
"72767 112.0506 \n",
|
| 450 |
+
"221715 50.992828,51.670795,51.675632,51.678509,51.681... \n",
|
| 451 |
+
"\n",
|
| 452 |
+
" intensities \\\n",
|
| 453 |
+
"76895 0.07907907907907907,0.05905905905905906,0.1791... \n",
|
| 454 |
+
"72767 1.0 \n",
|
| 455 |
+
"221715 0.0006785717253652819,0.001297853734957549,0.0... \n",
|
| 456 |
+
"\n",
|
| 457 |
+
" smiles inchikey formula \\\n",
|
| 458 |
+
"76895 COC1=C(C=CC(=C1)/C=C/C=O)O DKZBBWMURDFHNE C10H10O3 \n",
|
| 459 |
+
"72767 C1[C@@H](O[C@@H](S1)CO)N2C=CC(=NC2=O)N JTEGQNOMFQHVDC C8H11N3O3S \n",
|
| 460 |
+
"221715 CC(=O)N1CCC(CC1)C(=O)O WFCLWJHOKCQYOQ C8H13NO3 \n",
|
| 461 |
+
"\n",
|
| 462 |
+
" precursor_formula parent_mass precursor_mz adduct instrument_type \\\n",
|
| 463 |
+
"76895 C10H11O3 178.063024 179.0703 [M+H]+ Orbitrap \n",
|
| 464 |
+
"72767 C8H12N3O3S 229.052124 230.0594 [M+H]+ Orbitrap \n",
|
| 465 |
+
"221715 C8H14NO3 171.089724 172.0970 [M+H]+ Orbitrap \n",
|
| 466 |
+
"\n",
|
| 467 |
+
" collision_energy fold simulation_challenge \\\n",
|
| 468 |
+
"76895 34.023357 train True \n",
|
| 469 |
+
"72767 15.000000 train True \n",
|
| 470 |
+
"221715 NaN train False \n",
|
| 471 |
+
"\n",
|
| 472 |
+
" formulas \\\n",
|
| 473 |
+
"76895 [C5H4, C6H4, C6H6, C6H7, C5H5O, C7H4, C7H5, C7... \n",
|
| 474 |
+
"72767 [C4H5N3O] \n",
|
| 475 |
+
"221715 [C5H9N, C5H5NO2, C5H5NO2, C6H7O2, C6H7O2, C6H7... \n",
|
| 476 |
+
"\n",
|
| 477 |
+
" formula_mzs \\\n",
|
| 478 |
+
"76895 [65.0386, 77.0387, 79.0543, 80.0621, 82.0414, ... \n",
|
| 479 |
+
"72767 [112.0506] \n",
|
| 480 |
+
"221715 [84.080452, 112.038101, 112.041489, 112.048691... \n",
|
| 481 |
+
"\n",
|
| 482 |
+
" formula_intensities \n",
|
| 483 |
+
"76895 [0.281209886028212, 0.24302057526061452, 0.423... \n",
|
| 484 |
+
"72767 [1.0] \n",
|
| 485 |
+
"221715 [0.11135079703351926, 0.060621778264910706, 0.... "
|
| 486 |
+
]
|
| 487 |
+
},
|
| 488 |
+
"execution_count": 7,
|
| 489 |
+
"metadata": {},
|
| 490 |
+
"output_type": "execute_result"
|
| 491 |
+
}
|
| 492 |
+
],
|
| 493 |
+
"source": [
|
| 494 |
+
"dataset.metadata[dataset.metadata['precursor_mz'] <=250].sample(3)"
|
| 495 |
+
]
|
| 496 |
+
},
|
| 497 |
+
{
|
| 498 |
+
"cell_type": "code",
|
| 499 |
+
"execution_count": 12,
|
| 500 |
+
"metadata": {},
|
| 501 |
+
"outputs": [
|
| 502 |
+
{
|
| 503 |
+
"data": {
|
| 504 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 505 |
+
"model_id": "202cce9fe93e477e9e9b2e8775a33be3",
|
| 506 |
+
"version_major": 2,
|
| 507 |
+
"version_minor": 0
|
| 508 |
+
},
|
| 509 |
+
"text/plain": [
|
| 510 |
+
"FigureWidget({\n",
|
| 511 |
+
" 'data': [{'customdata': [0, 1, 2, 3, 4, 5, 6, 7],\n",
|
| 512 |
+
" 'hovertext': [Formula H8C8O, Formula H10C10O, Formula H10C11O,\n",
|
| 513 |
+
" Formula H16C13O, Formula H16C12ON, Formula H16C14O,\n",
|
| 514 |
+
" Formula H19C14ON, Formula H21C14O2N],\n",
|
| 515 |
+
" 'marker': {'cmax': 1,\n",
|
| 516 |
+
" 'cmin': 0,\n",
|
| 517 |
+
" 'color': 'lightgray',\n",
|
| 518 |
+
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
|
| 519 |
+
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
|
| 520 |
+
" '#482878'], [0.2222222222222222,\n",
|
| 521 |
+
" '#3e4989'], [0.3333333333333333,\n",
|
| 522 |
+
" '#31688e'], [0.4444444444444444,\n",
|
| 523 |
+
" '#26828e'], [0.5555555555555556,\n",
|
| 524 |
+
" '#1f9e89'], [0.6666666666666666,\n",
|
| 525 |
+
" '#35b779'], [0.7777777777777778,\n",
|
| 526 |
+
" '#6ece58'], [0.8888888888888888,\n",
|
| 527 |
+
" '#b5de2b'], [1.0, '#fde725']]},\n",
|
| 528 |
+
" 'name': 'peak',\n",
|
| 529 |
+
" 'type': 'bar',\n",
|
| 530 |
+
" 'uid': 'c5467361-ef75-4547-a2f0-d124d0db63ce',\n",
|
| 531 |
+
" 'x': [120.05730010663046, 146.07290266870038, 158.07289873479382,\n",
|
| 532 |
+
" 188.11970182247236, 190.1227957280129, 200.1196978885658,\n",
|
| 533 |
+
" 217.14619823001323, 235.15669775236026],\n",
|
| 534 |
+
" 'xaxis': 'x',\n",
|
| 535 |
+
" 'y': [0.46595707535743713, 0.095524862408638, 0.08919080346822739,\n",
|
| 536 |
+
" 1.0, 0.10595753788948059, 0.1503429412841797,\n",
|
| 537 |
+
" 0.1142716035246849, 1.100000023841858],\n",
|
| 538 |
+
" 'yaxis': 'y'},\n",
|
| 539 |
+
" {'hoverinfo': 'none',\n",
|
| 540 |
+
" 'line': {'color': 'gray', 'width': 2},\n",
|
| 541 |
+
" 'mode': 'lines',\n",
|
| 542 |
+
" 'showlegend': False,\n",
|
| 543 |
+
" 'type': 'scatter',\n",
|
| 544 |
+
" 'uid': '86756d4d-d34e-4422-8024-50e923982319',\n",
|
| 545 |
+
" 'x': [-5.436182348119566, -3.9950746378612374, None,\n",
|
| 546 |
+
" -3.9950746378612374, -2.9140954429243706, None,\n",
|
| 547 |
+
" -2.9140954429243706, -3.274223958245831, None,\n",
|
| 548 |
+
" -3.274223958245831, -2.193244763308963, None,\n",
|
| 549 |
+
" -2.193244763308963, -0.7521370530506348, None,\n",
|
| 550 |
+
" -0.7521370530506348, -0.39200853772917454, None,\n",
|
| 551 |
+
" -0.39200853772917454, -1.4729877326660425, None,\n",
|
| 552 |
+
" -0.39200853772917454, 1.0490991725291536, None,\n",
|
| 553 |
+
" 1.0490991725291536, 0.38899287090069107, None,\n",
|
| 554 |
+
" 0.38899287090069107, 1.225427933877235, None,\n",
|
| 555 |
+
" 1.225427933877235, 2.721969298482242, None, 2.721969298482242,\n",
|
| 556 |
+
" 3.3820756001107055, None, 3.3820756001107055,\n",
|
| 557 |
+
" 2.545640537134162, None, 2.545640537134162, 3.2057468387626264,\n",
|
| 558 |
+
" None, 3.2057468387626264, 4.702288203367634, None,\n",
|
| 559 |
+
" 1.0490991725291536, 1.2087140187413736, None,\n",
|
| 560 |
+
" -1.4729877326660425, -2.9140954429243706, None,\n",
|
| 561 |
+
" 2.545640537134162, 1.0490991725291536, None],\n",
|
| 562 |
+
" 'xaxis': 'x2',\n",
|
| 563 |
+
" 'y': [-0.8275311534181835, -1.243714487339663, None,\n",
|
| 564 |
+
" -1.243714487339663, -0.2037702676270663, None,\n",
|
| 565 |
+
" -0.2037702676270663, 1.2523572860070113, None,\n",
|
| 566 |
+
" 1.2523572860070113, 2.2923015057196072, None,\n",
|
| 567 |
+
" 2.2923015057196072, 1.8761181717981272, None,\n",
|
| 568 |
+
" 1.8761181717981272, 0.4199906181640502, None,\n",
|
| 569 |
+
" 0.4199906181640502, -0.6199536015485465, None,\n",
|
| 570 |
+
" 0.4199906181640502, 0.0038072842425705966, None,\n",
|
| 571 |
+
" 0.0038072842425705966, -1.343137284234691, None,\n",
|
| 572 |
+
" -1.343137284234691, -2.588278394881763, None,\n",
|
| 573 |
+
" -2.588278394881763, -2.4864749370515753, None,\n",
|
| 574 |
+
" -2.4864749370515753, -1.1395303685743154, None,\n",
|
| 575 |
+
" -1.1395303685743154, 0.10561074207275722, None,\n",
|
| 576 |
+
" 0.10561074207275722, 1.4525553105500184, None,\n",
|
| 577 |
+
" 1.4525553105500184, 1.554358768380205, None,\n",
|
| 578 |
+
" 0.0038072842425705966, 1.4952908077414555, None,\n",
|
| 579 |
+
" -0.6199536015485465, -0.2037702676270663, None,\n",
|
| 580 |
+
" 0.10561074207275722, 0.0038072842425705966, None],\n",
|
| 581 |
+
" 'yaxis': 'y2'},\n",
|
| 582 |
+
" {'customdata': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n",
|
| 583 |
+
" 16],\n",
|
| 584 |
+
" 'marker': {'cmax': 1,\n",
|
| 585 |
+
" 'cmin': 0,\n",
|
| 586 |
+
" 'color': 'lightgray',\n",
|
| 587 |
+
" 'colorbar': {'len': 0.8, 'title': {'text': 'Similarity'}, 'y': 0.5},\n",
|
| 588 |
+
" 'colorscale': [[0.0, '#440154'], [0.1111111111111111,\n",
|
| 589 |
+
" '#482878'], [0.2222222222222222,\n",
|
| 590 |
+
" '#3e4989'], [0.3333333333333333,\n",
|
| 591 |
+
" '#31688e'], [0.4444444444444444,\n",
|
| 592 |
+
" '#26828e'], [0.5555555555555556,\n",
|
| 593 |
+
" '#1f9e89'], [0.6666666666666666,\n",
|
| 594 |
+
" '#35b779'], [0.7777777777777778,\n",
|
| 595 |
+
" '#6ece58'], [0.8888888888888888,\n",
|
| 596 |
+
" '#b5de2b'], [1.0, '#fde725']],\n",
|
| 597 |
+
" 'size': 20},\n",
|
| 598 |
+
" 'mode': 'markers+text',\n",
|
| 599 |
+
" 'name': 'node',\n",
|
| 600 |
+
" 'text': [C, O, C, C, C, C, C, C, C, C, C, C, C, C, C, N, O],\n",
|
| 601 |
+
" 'textposition': 'middle center',\n",
|
| 602 |
+
" 'type': 'scatter',\n",
|
| 603 |
+
" 'uid': '5c2fb088-99a4-4115-ada1-d9b560973db3',\n",
|
| 604 |
+
" 'x': [-5.436182348119566, -3.9950746378612374, -2.9140954429243706,\n",
|
| 605 |
+
" -3.274223958245831, -2.193244763308963, -0.7521370530506348,\n",
|
| 606 |
+
" -0.39200853772917454, -1.4729877326660425, 1.0490991725291536,\n",
|
| 607 |
+
" 0.38899287090069107, 1.225427933877235, 2.721969298482242,\n",
|
| 608 |
+
" 3.3820756001107055, 2.545640537134162, 3.2057468387626264,\n",
|
| 609 |
+
" 4.702288203367634, 1.2087140187413736],\n",
|
| 610 |
+
" 'xaxis': 'x2',\n",
|
| 611 |
+
" 'y': [-0.8275311534181835, -1.243714487339663, -0.2037702676270663,\n",
|
| 612 |
+
" 1.2523572860070113, 2.2923015057196072, 1.8761181717981272,\n",
|
| 613 |
+
" 0.4199906181640502, -0.6199536015485465, 0.0038072842425705966,\n",
|
| 614 |
+
" -1.343137284234691, -2.588278394881763, -2.4864749370515753,\n",
|
| 615 |
+
" -1.1395303685743154, 0.10561074207275722, 1.4525553105500184,\n",
|
| 616 |
+
" 1.554358768380205, 1.4952908077414555],\n",
|
| 617 |
+
" 'yaxis': 'y2'}],\n",
|
| 618 |
+
" 'layout': {'annotations': [{'font': {'size': 16},\n",
|
| 619 |
+
" 'showarrow': False,\n",
|
| 620 |
+
" 'text': 'Spectrum',\n",
|
| 621 |
+
" 'x': 0.27,\n",
|
| 622 |
+
" 'xanchor': 'center',\n",
|
| 623 |
+
" 'xref': 'paper',\n",
|
| 624 |
+
" 'y': 1.0,\n",
|
| 625 |
+
" 'yanchor': 'bottom',\n",
|
| 626 |
+
" 'yref': 'paper'},\n",
|
| 627 |
+
" {'font': {'size': 16},\n",
|
| 628 |
+
" 'showarrow': False,\n",
|
| 629 |
+
" 'text': 'Molecule',\n",
|
| 630 |
+
" 'x': 0.8200000000000001,\n",
|
| 631 |
+
" 'xanchor': 'center',\n",
|
| 632 |
+
" 'xref': 'paper',\n",
|
| 633 |
+
" 'y': 1.0,\n",
|
| 634 |
+
" 'yanchor': 'bottom',\n",
|
| 635 |
+
" 'yref': 'paper'}],\n",
|
| 636 |
+
" 'showlegend': False,\n",
|
| 637 |
+
" 'template': '...',\n",
|
| 638 |
+
" 'title': {'text': 'Peak ↔ Node Similarity'},\n",
|
| 639 |
+
" 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.54], 'title': {'text': 'm/z'}},\n",
|
| 640 |
+
" 'xaxis2': {'anchor': 'y2', 'domain': [0.64, 1.0], 'visible': False},\n",
|
| 641 |
+
" 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Intensity'}},\n",
|
| 642 |
+
" 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'visible': False}}\n",
|
| 643 |
+
"})"
|
| 644 |
+
]
|
| 645 |
+
},
|
| 646 |
+
"execution_count": 12,
|
| 647 |
+
"metadata": {},
|
| 648 |
+
"output_type": "execute_result"
|
| 649 |
+
}
|
| 650 |
+
],
|
| 651 |
+
"source": [
|
| 652 |
+
"from rdkit import Chem\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"# Data\n",
|
| 655 |
+
"# i = 40991\n",
|
| 656 |
+
"\n",
|
| 657 |
+
"ms_id = \"MassSpecGymID0033096\"\n",
|
| 658 |
+
"i = dataset.metadata[dataset.metadata['identifier'] == ms_id].index[0]\n",
|
| 659 |
+
"s = dataset.metadata.iloc[i]['smiles']\n",
|
| 660 |
+
"mol = Chem.MolFromSmiles(s)\n",
|
| 661 |
+
"g = dataset[i]['mol']\n",
|
| 662 |
+
"spec = dataset[i]['SpecFormula']\n",
|
| 663 |
+
"\n",
|
| 664 |
+
"peak_mzs, peak_intensities, peak_formulas = spectra_from_encoding(spec)\n",
|
| 665 |
+
"\n",
|
| 666 |
+
"# Embeddings\n",
|
| 667 |
+
"model = model.to(torch.device('cpu'))\n",
|
| 668 |
+
"model.eval()\n",
|
| 669 |
+
"with torch.no_grad():\n",
|
| 670 |
+
" spec_enc, mol_enc = model.forward(dataset[i], stage='test')\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"fw = interactive_attention_visualization(spec_enc, mol_enc, peak_mzs, peak_intensities, peak_formulas, mol)\n",
|
| 673 |
+
"fw\n"
|
| 674 |
+
]
|
| 675 |
+
},
|
| 676 |
+
{
|
| 677 |
+
"cell_type": "code",
|
| 678 |
+
"execution_count": null,
|
| 679 |
+
"metadata": {},
|
| 680 |
+
"outputs": [],
|
| 681 |
+
"source": []
|
| 682 |
+
}
|
| 683 |
+
],
|
| 684 |
+
"metadata": {
|
| 685 |
+
"kernelspec": {
|
| 686 |
+
"display_name": "spec",
|
| 687 |
+
"language": "python",
|
| 688 |
+
"name": "python3"
|
| 689 |
+
},
|
| 690 |
+
"language_info": {
|
| 691 |
+
"codemirror_mode": {
|
| 692 |
+
"name": "ipython",
|
| 693 |
+
"version": 3
|
| 694 |
+
},
|
| 695 |
+
"file_extension": ".py",
|
| 696 |
+
"mimetype": "text/x-python",
|
| 697 |
+
"name": "python",
|
| 698 |
+
"nbconvert_exporter": "python",
|
| 699 |
+
"pygments_lexer": "ipython3",
|
| 700 |
+
"version": "3.11.7"
|
| 701 |
+
},
|
| 702 |
+
"orig_nbformat": 4
|
| 703 |
+
},
|
| 704 |
+
"nbformat": 4,
|
| 705 |
+
"nbformat_minor": 2
|
| 706 |
+
}
|
notebooks/hyperparameter_tuning_result.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/magma_script.ipynb
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 12,
|
| 6 |
+
"id": "1205f9e4",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"import numpy as np\n",
|
| 14 |
+
"from rdkit import Chem\n",
|
| 15 |
+
"from collections import defaultdict\n",
|
| 16 |
+
"import numpy as np\n",
|
| 17 |
+
"import sys\n",
|
| 18 |
+
"import json"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": 3,
|
| 24 |
+
"id": "c1267f1b",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"with open(\"/data/yzhouc01/FILIP-MS/data/magma/MassSpecGymID0191762.json\", 'r') as f:\n",
|
| 29 |
+
" data = json.load(f)"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 7,
|
| 35 |
+
"id": "db06e7e6",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [
|
| 38 |
+
{
|
| 39 |
+
"data": {
|
| 40 |
+
"text/plain": [
|
| 41 |
+
"dict_keys(['mz', 'intensities', 'subformulas', 'substructures'])"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
"execution_count": 7,
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"output_type": "execute_result"
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"source": [
|
| 50 |
+
"data.keys()"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": 11,
|
| 56 |
+
"id": "582e23d9",
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"outputs": [
|
| 59 |
+
{
|
| 60 |
+
"data": {
|
| 61 |
+
"text/plain": [
|
| 62 |
+
"['C4H8O4']"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
"execution_count": 11,
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"output_type": "execute_result"
|
| 68 |
+
}
|
| 69 |
+
],
|
| 70 |
+
"source": [
|
| 71 |
+
"data['subformulas'][0]"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 14,
|
| 77 |
+
"id": "abe22e21",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [],
|
| 80 |
+
"source": [
|
| 81 |
+
"np.random.seed(42)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"formulas = []\n",
|
| 84 |
+
"mzs = []\n",
|
| 85 |
+
"intensities = []\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"for f, m, i in zip(data['subformulas'], data['mz'], data['intensities']):\n",
|
| 88 |
+
" if f:\n",
|
| 89 |
+
" formulas.append(np.random.choice(f))\n",
|
| 90 |
+
" mzs.append(m)\n",
|
| 91 |
+
" intensities.append(i)\n",
|
| 92 |
+
" "
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"execution_count": 15,
|
| 98 |
+
"id": "161f95f5",
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"outputs": [
|
| 101 |
+
{
|
| 102 |
+
"data": {
|
| 103 |
+
"text/plain": [
|
| 104 |
+
"69"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
"execution_count": 15,
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"output_type": "execute_result"
|
| 110 |
+
}
|
| 111 |
+
],
|
| 112 |
+
"source": [
|
| 113 |
+
"len(data['mz'])"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": null,
|
| 119 |
+
"id": "0c622b13",
|
| 120 |
+
"metadata": {},
|
| 121 |
+
"outputs": [],
|
| 122 |
+
"source": []
|
| 123 |
+
}
|
| 124 |
+
],
|
| 125 |
+
"metadata": {
|
| 126 |
+
"kernelspec": {
|
| 127 |
+
"display_name": "spec",
|
| 128 |
+
"language": "python",
|
| 129 |
+
"name": "python3"
|
| 130 |
+
},
|
| 131 |
+
"language_info": {
|
| 132 |
+
"codemirror_mode": {
|
| 133 |
+
"name": "ipython",
|
| 134 |
+
"version": 3
|
| 135 |
+
},
|
| 136 |
+
"file_extension": ".py",
|
| 137 |
+
"mimetype": "text/x-python",
|
| 138 |
+
"name": "python",
|
| 139 |
+
"nbconvert_exporter": "python",
|
| 140 |
+
"pygments_lexer": "ipython3",
|
| 141 |
+
"version": "3.11.7"
|
| 142 |
+
}
|
| 143 |
+
},
|
| 144 |
+
"nbformat": 4,
|
| 145 |
+
"nbformat_minor": 5
|
| 146 |
+
}
|
notebooks/peak_embedding_UMAP.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/peak_formula_analysis.ipynb
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "07d00685",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import torch\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"import plotly.graph_objects as go\n",
|
| 13 |
+
"from plotly.subplots import make_subplots\n",
|
| 14 |
+
"from rdkit import Chem\n",
|
| 15 |
+
"from rdkit.Chem import rdDepictor\n",
|
| 16 |
+
"from rdkit.Chem.Draw import rdMolDraw2D\n",
|
| 17 |
+
"import matplotlib.pyplot as plt\n",
|
| 18 |
+
"import json"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": 4,
|
| 24 |
+
"id": "cd9e10c7",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"import sys\n",
|
| 29 |
+
"sys.path.insert(0, \"/data/yzhouc01/MassSpecGym\")\n",
|
| 30 |
+
"sys.path.insert(0, \"/data/yzhouc01/FILIP-MS\")\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"from rdkit import RDLogger\n",
|
| 33 |
+
"import pytorch_lightning as pl\n",
|
| 34 |
+
"from pytorch_lightning import Trainer\n",
|
| 35 |
+
"from massspecgym.models.base import Stage\n",
|
| 36 |
+
"import os\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset,get_test_ms_dataset\n",
|
| 39 |
+
"from mvp.utils.models import get_model\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"from mvp.definitions import TEST_RESULTS_DIR\n",
|
| 42 |
+
"import yaml\n",
|
| 43 |
+
"from functools import partial\n",
|
| 44 |
+
"# Suppress RDKit warnings and errors\n",
|
| 45 |
+
"lg = RDLogger.logger()\n",
|
| 46 |
+
"lg.setLevel(RDLogger.CRITICAL)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 5,
|
| 52 |
+
"id": "9ba93f86",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [
|
| 55 |
+
{
|
| 56 |
+
"name": "stdout",
|
| 57 |
+
"output_type": "stream",
|
| 58 |
+
"text": [
|
| 59 |
+
"Data path: /data/yzhouc01/MVP/data/sample/data.tsv\n",
|
| 60 |
+
"Processing formula spectra\n"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"name": "stderr",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"100%|██████████| 10/10 [00:00<00:00, 5861.24it/s]"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"\n"
|
| 75 |
+
]
|
| 76 |
+
}
|
| 77 |
+
],
|
| 78 |
+
"source": [
|
| 79 |
+
"# Load model and data\n",
|
| 80 |
+
"# param_pth = '/data/yzhouc01/FILIP-MS/experiments/20250824_filipContrastive/lightning_logs/version_0/hparams.yaml'\n",
|
| 81 |
+
"param_pth = \"/data/yzhouc01/FILIP-MS/mvp/params_formSpec.yaml\"\n",
|
| 82 |
+
"with open(param_pth) as f:\n",
|
| 83 |
+
" params = yaml.load(f, Loader=yaml.FullLoader)\n",
|
| 84 |
+
"params['dataset_pth'] = \"/data/yzhouc01/MVP/data/sample/data.tsv\"\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"spec_featurizer = get_spec_featurizer(params['spectra_view'], params)\n",
|
| 87 |
+
"mol_featurizer = get_mol_featurizer(params['molecule_view'], params)\n",
|
| 88 |
+
"dataset = get_test_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": 7,
|
| 94 |
+
"id": "bcb28630",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [
|
| 97 |
+
{
|
| 98 |
+
"data": {
|
| 99 |
+
"text/plain": [
|
| 100 |
+
"{'precursor_mz': 404.1241,\n",
|
| 101 |
+
" 'formulas': array(['C3HO2', 'C4H2N2', 'C6H4O', 'C4H2N2O', 'C7H4N', 'C7H4O', 'C8H5',\n",
|
| 102 |
+
" 'C4H2N2O2', 'C5H7O3', 'C7H4NO', 'C8H5O', 'C9H5O', 'C9H8O',\n",
|
| 103 |
+
" 'C10H8O', 'C9H5O2', 'C9H8O2', 'C10H5NO', 'C10H8O2', 'C10H6N2O',\n",
|
| 104 |
+
" 'C9H5O4', 'C11H6N2O', 'C10H6N2O2', 'C12H7N2O', 'C11H6N2O2',\n",
|
| 105 |
+
" 'C12H9O3', 'C11H8O4', 'C11H8NO3', 'C11H11O4', 'C11H6N3O2',\n",
|
| 106 |
+
" 'C12H11O4', 'C13H10N2O3', 'C14H10N2O4', 'C13H7N2O5', 'C15H13N2O5',\n",
|
| 107 |
+
" 'C19H11N3O3', 'C22H17N3O5'], dtype='<U10'),\n",
|
| 108 |
+
" 'precursor_formula': 'C22H18N3O5'}"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
"execution_count": 7,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"output_type": "execute_result"
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
"source": [
|
| 117 |
+
"dataset.spectra[1].metadata"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"execution_count": 4,
|
| 123 |
+
"id": "fbebdab3",
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"outputs": [
|
| 126 |
+
{
|
| 127 |
+
"data": {
|
| 128 |
+
"text/plain": [
|
| 129 |
+
"{'precursor_mz': 226.0716,\n",
|
| 130 |
+
" 'formulas': array(['C5H5O2', 'C6H6O', 'C3H4NO3', 'C7H6O', 'C6H3O2', 'C3H5NO4',\n",
|
| 131 |
+
" 'C7H6O2', 'C7H6NO2', 'C8H4NO2', 'C7H6NO3', 'C8H9NO3', 'C7H6NO5',\n",
|
| 132 |
+
" 'C9H8NO4', 'C10H10NO4', 'C10H11NO5'], dtype='<U9'),\n",
|
| 133 |
+
" 'precursor_formula': 'C10H12NO5'}"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
"execution_count": 4,
|
| 137 |
+
"metadata": {},
|
| 138 |
+
"output_type": "execute_result"
|
| 139 |
+
}
|
| 140 |
+
],
|
| 141 |
+
"source": [
|
| 142 |
+
"dataset.spectra[0].metadata"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "code",
|
| 147 |
+
"execution_count": 9,
|
| 148 |
+
"id": "268c6470",
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [
|
| 151 |
+
{
|
| 152 |
+
"data": {
|
| 153 |
+
"text/plain": [
|
| 154 |
+
"{'precursor_mz': 226.0716,\n",
|
| 155 |
+
" 'formulas': array(['C6H6O', 'C4H3O3', 'C5H5O2', 'C6H6O', 'C4H3O3', 'C5H5O2',\n",
|
| 156 |
+
" 'C3H4NO3', 'C7H6O', 'C7H6O', 'C6H3O2', 'C6H3O2', 'C3H5NO4',\n",
|
| 157 |
+
" 'C7H6O2', 'C7H6O2', 'C7H6O2', 'C7H6O2', 'C7H6NO2', 'C7H6NO2',\n",
|
| 158 |
+
" 'C7H6NO2', 'C8H4NO2', 'C8H4NO2', 'C7H6NO3', 'C8H9NO3', 'C8H9NO3',\n",
|
| 159 |
+
" 'C8H9NO3', 'C8H9NO3', 'C8H10NO4', 'C7H9NO5', 'C7H6NO5', 'C9H8NO4',\n",
|
| 160 |
+
" 'C10H10NO4', 'C9H8NO5', 'C10H11NO5'], dtype='<U9'),\n",
|
| 161 |
+
" 'precursor_formula': 'C10H12NO5'}"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
"execution_count": 9,
|
| 165 |
+
"metadata": {},
|
| 166 |
+
"output_type": "execute_result"
|
| 167 |
+
}
|
| 168 |
+
],
|
| 169 |
+
"source": [
|
| 170 |
+
"dataset.spectra[0].metadata"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "markdown",
|
| 175 |
+
"id": "4a9f0227",
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"source": [
|
| 178 |
+
"# SIRIUS subformulas"
|
| 179 |
+
]
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "code",
|
| 183 |
+
"execution_count": 4,
|
| 184 |
+
"id": "5f40603f",
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"outputs": [
|
| 187 |
+
{
|
| 188 |
+
"data": {
|
| 189 |
+
"image/png": "",
|
| 190 |
+
"text/plain": [
|
| 191 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"output_type": "display_data"
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"source": [
|
| 199 |
+
"import numpy as np\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"import matplotlib.pyplot as plt\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"# Collect number of formulas per spectrum\n",
|
| 204 |
+
"n_formulas = [len(d.metadata['formulas']) for d in dataset.spectra]\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"# Calculate mean and median\n",
|
| 207 |
+
"mean_n_formulas = np.mean(n_formulas)\n",
|
| 208 |
+
"median_n_formulas = np.median(n_formulas)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"# Plot histogram\n",
|
| 211 |
+
"plt.hist(n_formulas, bins=30, alpha=0.7, color='skyblue')\n",
|
| 212 |
+
"plt.axvline(mean_n_formulas, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean_n_formulas:.2f}')\n",
|
| 213 |
+
"plt.axvline(median_n_formulas, color='green', linestyle='dashed', linewidth=2, label=f'Median: {median_n_formulas:.2f}')\n",
|
| 214 |
+
"plt.xlabel('Number of formulas per spectrum')\n",
|
| 215 |
+
"plt.ylabel('Count')\n",
|
| 216 |
+
"plt.title('Distribution of Number of Formulas per Spectrum (MIST labels)')\n",
|
| 217 |
+
"plt.legend()\n",
|
| 218 |
+
"plt.show()"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "markdown",
|
| 223 |
+
"id": "170af068",
|
| 224 |
+
"metadata": {},
|
| 225 |
+
"source": [
|
| 226 |
+
"# MIST subformulas"
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"cell_type": "markdown",
|
| 231 |
+
"id": "4798447e",
|
| 232 |
+
"metadata": {},
|
| 233 |
+
"source": []
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": 17,
|
| 238 |
+
"id": "5612e3dc",
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"outputs": [
|
| 241 |
+
{
|
| 242 |
+
"data": {
|
| 243 |
+
"image/png": "",
|
| 244 |
+
"text/plain": [
|
| 245 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 246 |
+
]
|
| 247 |
+
},
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"output_type": "display_data"
|
| 250 |
+
}
|
| 251 |
+
],
|
| 252 |
+
"source": [
|
| 253 |
+
"import numpy as np\n",
|
| 254 |
+
"\n",
|
| 255 |
+
"import matplotlib.pyplot as plt\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"# Collect number of formulas per spectrum\n",
|
| 258 |
+
"n_formulas = [len(d.metadata['formulas']) for d in dataset.spectra]\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"# Calculate mean and median\n",
|
| 261 |
+
"mean_n_formulas = np.mean(n_formulas)\n",
|
| 262 |
+
"median_n_formulas = np.median(n_formulas)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"# Plot histogram\n",
|
| 265 |
+
"plt.hist(n_formulas, bins=30, alpha=0.7, color='skyblue')\n",
|
| 266 |
+
"plt.axvline(mean_n_formulas, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean_n_formulas:.2f}')\n",
|
| 267 |
+
"plt.axvline(median_n_formulas, color='green', linestyle='dashed', linewidth=2, label=f'Median: {median_n_formulas:.2f}')\n",
|
| 268 |
+
"plt.xlabel('Number of formulas per spectrum')\n",
|
| 269 |
+
"plt.ylabel('Count')\n",
|
| 270 |
+
"plt.title('Distribution of Number of Formulas per Spectrum (MIST labels)')\n",
|
| 271 |
+
"plt.legend()\n",
|
| 272 |
+
"plt.show()"
|
| 273 |
+
]
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "code",
|
| 277 |
+
"execution_count": 7,
|
| 278 |
+
"id": "b35bb5a4",
|
| 279 |
+
"metadata": {},
|
| 280 |
+
"outputs": [
|
| 281 |
+
{
|
| 282 |
+
"data": {
|
| 283 |
+
"text/plain": [
|
| 284 |
+
"identifier MassSpecGymID0000001\n",
|
| 285 |
+
"mzs 91.0542,125.0233,154.0499,155.0577,185.0961,20...\n",
|
| 286 |
+
"intensities 0.24524524524524524,1.0,0.08008008008008008,0....\n",
|
| 287 |
+
"smiles CC(=O)N[C@@H](CC1=CC=CC=C1)C2=CC(=CC(=O)O2)OC\n",
|
| 288 |
+
"inchikey VFMQMACUYWGDOJ\n",
|
| 289 |
+
"formula C16H17NO4\n",
|
| 290 |
+
"precursor_formula C16H18NO4\n",
|
| 291 |
+
"parent_mass 287.115224\n",
|
| 292 |
+
"precursor_mz 288.1225\n",
|
| 293 |
+
"adduct [M+H]+\n",
|
| 294 |
+
"instrument_type Orbitrap\n",
|
| 295 |
+
"collision_energy 30.0\n",
|
| 296 |
+
"fold train\n",
|
| 297 |
+
"simulation_challenge True\n",
|
| 298 |
+
"formulas [C16H17NO4]\n",
|
| 299 |
+
"formula_mzs [288.1225]\n",
|
| 300 |
+
"formula_intensities [1.0]\n",
|
| 301 |
+
"Name: 0, dtype: object"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
"execution_count": 7,
|
| 305 |
+
"metadata": {},
|
| 306 |
+
"output_type": "execute_result"
|
| 307 |
+
}
|
| 308 |
+
],
|
| 309 |
+
"source": [
|
| 310 |
+
"dataset.metadata.iloc[0]"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "code",
|
| 315 |
+
"execution_count": 9,
|
| 316 |
+
"id": "da07f08a",
|
| 317 |
+
"metadata": {},
|
| 318 |
+
"outputs": [
|
| 319 |
+
{
|
| 320 |
+
"ename": "PermissionError",
|
| 321 |
+
"evalue": "[Errno 13] Permission denied: '/r/hassounlab/msgym_sirius/MassSpecGymID0000140.json'",
|
| 322 |
+
"output_type": "error",
|
| 323 |
+
"traceback": [
|
| 324 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 325 |
+
"\u001b[0;31mPermissionError\u001b[0m Traceback (most recent call last)",
|
| 326 |
+
"Cell \u001b[0;32mIn[9], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m spec_id \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39mmetadata\u001b[38;5;241m.\u001b[39miloc[\u001b[38;5;241m123\u001b[39m][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124midentifier\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 3\u001b[0m file \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(params[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msubformula_dir_pth\u001b[39m\u001b[38;5;124m'\u001b[39m], spec_id\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.json\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 5\u001b[0m data \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mload(f)\n",
|
| 327 |
+
"File \u001b[0;32m/data/yzc-conda/spec/lib/python3.11/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 322\u001b[0m )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 328 |
+
"\u001b[0;31mPermissionError\u001b[0m: [Errno 13] Permission denied: '/r/hassounlab/msgym_sirius/MassSpecGymID0000140.json'"
|
| 329 |
+
]
|
| 330 |
+
}
|
| 331 |
+
],
|
| 332 |
+
"source": [
|
| 333 |
+
"import json\n",
|
| 334 |
+
"spec_id = dataset.metadata.iloc[123]['identifier']\n",
|
| 335 |
+
"file = os.path.join(params['subformula_dir_pth'], spec_id+\".json\")\n",
|
| 336 |
+
"with open(file) as f:\n",
|
| 337 |
+
" data = json.load(f)"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"cell_type": "code",
|
| 342 |
+
"execution_count": null,
|
| 343 |
+
"id": "a1341478",
|
| 344 |
+
"metadata": {},
|
| 345 |
+
"outputs": [],
|
| 346 |
+
"source": []
|
| 347 |
+
}
|
| 348 |
+
],
|
| 349 |
+
"metadata": {
|
| 350 |
+
"kernelspec": {
|
| 351 |
+
"display_name": "spec",
|
| 352 |
+
"language": "python",
|
| 353 |
+
"name": "python3"
|
| 354 |
+
},
|
| 355 |
+
"language_info": {
|
| 356 |
+
"codemirror_mode": {
|
| 357 |
+
"name": "ipython",
|
| 358 |
+
"version": 3
|
| 359 |
+
},
|
| 360 |
+
"file_extension": ".py",
|
| 361 |
+
"mimetype": "text/x-python",
|
| 362 |
+
"name": "python",
|
| 363 |
+
"nbconvert_exporter": "python",
|
| 364 |
+
"pygments_lexer": "ipython3",
|
| 365 |
+
"version": "3.11.7"
|
| 366 |
+
}
|
| 367 |
+
},
|
| 368 |
+
"nbformat": 4,
|
| 369 |
+
"nbformat_minor": 5
|
| 370 |
+
}
|
notebooks/visualization.ipynb
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"data": {
|
| 10 |
+
"image/svg+xml": [
|
| 11 |
+
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:rdkit=\"http://www.rdkit.org/xml\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" version=\"1.1\" baseProfile=\"full\" xml:space=\"preserve\" width=\"400px\" height=\"400px\" viewBox=\"0 0 400 400\">\n",
|
| 12 |
+
"<!-- END OF HEADER -->\n",
|
| 13 |
+
"<rect style=\"opacity:1.0;fill:#FFFFFF;stroke:none\" width=\"400.0\" height=\"400.0\" x=\"0.0\" y=\"0.0\"> </rect>\n",
|
| 14 |
+
"<ellipse cx=\"38.2\" cy=\"220.1\" rx=\"18.2\" ry=\"18.2\" class=\"atom-0\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 15 |
+
"<ellipse cx=\"65.3\" cy=\"133.0\" rx=\"18.2\" ry=\"18.2\" class=\"atom-1\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 16 |
+
"<ellipse cx=\"154.2\" cy=\"113.0\" rx=\"18.2\" ry=\"18.2\" class=\"atom-2\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 17 |
+
"<ellipse cx=\"216.1\" cy=\"179.9\" rx=\"18.2\" ry=\"18.2\" class=\"atom-3\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 18 |
+
"<ellipse cx=\"189.0\" cy=\"267.0\" rx=\"18.2\" ry=\"18.2\" class=\"atom-4\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 19 |
+
"<ellipse cx=\"100.1\" cy=\"287.0\" rx=\"18.2\" ry=\"18.2\" class=\"atom-5\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 20 |
+
"<ellipse cx=\"305.0\" cy=\"159.9\" rx=\"18.2\" ry=\"18.2\" class=\"atom-6\" style=\"fill:#999999;fill-rule:evenodd;stroke:#999999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 21 |
+
"<ellipse cx=\"361.8\" cy=\"231.7\" rx=\"18.2\" ry=\"18.7\" class=\"atom-7\" style=\"fill:#F99999;fill-rule:evenodd;stroke:#F99999;stroke-width:1.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 22 |
+
"<path class=\"bond-0 atom-0 atom-1\" d=\"M 38.2,220.1 L 65.3,133.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 23 |
+
"<path class=\"bond-0 atom-0 atom-1\" d=\"M 53.6,216.6 L 76.0,144.6\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 24 |
+
"<path class=\"bond-1 atom-1 atom-2\" d=\"M 65.3,133.0 L 154.2,113.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 25 |
+
"<path class=\"bond-2 atom-2 atom-3\" d=\"M 154.2,113.0 L 216.1,179.9\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 26 |
+
"<path class=\"bond-2 atom-2 atom-3\" d=\"M 149.5,128.0 L 200.7,183.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 27 |
+
"<path class=\"bond-3 atom-3 atom-4\" d=\"M 216.1,179.9 L 189.0,267.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 28 |
+
"<path class=\"bond-4 atom-4 atom-5\" d=\"M 189.0,267.0 L 100.1,287.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 29 |
+
"<path class=\"bond-4 atom-4 atom-5\" d=\"M 178.3,255.4 L 104.7,272.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 30 |
+
"<path class=\"bond-5 atom-3 atom-6\" d=\"M 216.1,179.9 L 305.0,159.9\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 31 |
+
"<path class=\"bond-6 atom-6 atom-7\" d=\"M 305.0,159.9 L 328.5,185.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 32 |
+
"<path class=\"bond-6 atom-6 atom-7\" d=\"M 328.5,185.4 L 352.0,210.8\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 33 |
+
"<path class=\"bond-6 atom-6 atom-7\" d=\"M 300.3,175.0 L 318.4,194.6\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 34 |
+
"<path class=\"bond-6 atom-6 atom-7\" d=\"M 318.4,194.6 L 341.9,220.1\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 35 |
+
"<path class=\"bond-7 atom-5 atom-0\" d=\"M 100.1,287.0 L 38.2,220.1\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
|
| 36 |
+
"<path d=\"M 39.6,215.7 L 38.2,220.1 L 41.3,223.4\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 37 |
+
"<path d=\"M 64.0,137.4 L 65.3,133.0 L 69.8,132.0\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 38 |
+
"<path d=\"M 149.8,114.0 L 154.2,113.0 L 157.3,116.3\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 39 |
+
"<path d=\"M 190.3,262.6 L 189.0,267.0 L 184.5,268.0\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 40 |
+
"<path d=\"M 104.5,286.0 L 100.1,287.0 L 97.0,283.7\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 41 |
+
"<path d=\"M 300.5,160.9 L 305.0,159.9 L 306.1,161.2\" style=\"fill:none;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;\"/>\n",
|
| 42 |
+
"<path class=\"atom-7\" d=\"M 349.9 231.6 Q 349.9 225.4, 353.0 221.9 Q 356.0 218.5, 361.8 218.5 Q 367.5 218.5, 370.6 221.9 Q 373.6 225.4, 373.6 231.6 Q 373.6 237.9, 370.5 241.4 Q 367.4 245.0, 361.8 245.0 Q 356.1 245.0, 353.0 241.4 Q 349.9 237.9, 349.9 231.6 M 361.8 242.0 Q 365.7 242.0, 367.8 239.4 Q 370.0 236.8, 370.0 231.6 Q 370.0 226.5, 367.8 224.0 Q 365.7 221.4, 361.8 221.4 Q 357.8 221.4, 355.7 223.9 Q 353.6 226.5, 353.6 231.6 Q 353.6 236.8, 355.7 239.4 Q 357.8 242.0, 361.8 242.0 \" fill=\"#FF0000\"/>\n",
|
| 43 |
+
"</svg>"
|
| 44 |
+
],
|
| 45 |
+
"text/plain": [
|
| 46 |
+
"<IPython.core.display.SVG object>"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
"execution_count": 2,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"output_type": "execute_result"
|
| 52 |
+
}
|
| 53 |
+
],
|
| 54 |
+
"source": [
|
| 55 |
+
"from rdkit import Chem\n",
|
| 56 |
+
"from rdkit.Chem import Draw\n",
|
| 57 |
+
"from rdkit.Chem.Draw import rdMolDraw2D\n",
|
| 58 |
+
"from IPython.display import SVG\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# Example molecule\n",
|
| 61 |
+
"mol = Chem.MolFromSmiles(\"C1=CC=C(C=C1)C=O\") \n",
|
| 62 |
+
"Chem.rdDepictor.Compute2DCoords(mol)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# Define colors for atom types\n",
|
| 65 |
+
"atom_colors = {\n",
|
| 66 |
+
" 6: (0.6, 0.6, 0.6), # Carbon = light gray\n",
|
| 67 |
+
" 8: (0.98, 0.6, 0.6), # Oxygen = soft red/pink\n",
|
| 68 |
+
" 7: (0.55, 0.63, 0.8), # Nitrogen = light blue\n",
|
| 69 |
+
" 16: (0.8, 0.8, 0.55), # Sulfur = soft yellow\n",
|
| 70 |
+
" 17: (0.65, 0.85, 0.65), # Chlorine = light green\n",
|
| 71 |
+
" 1: (0.9, 0.9, 0.9), # Hydrogen = very light gray\n",
|
| 72 |
+
"}\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"# Default = muted purple (for other atoms)\n",
|
| 76 |
+
"default_color = (0.8, 0.7, 0.9)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"# Assign highlight colors\n",
|
| 79 |
+
"highlight_atoms = [atom.GetIdx() for atom in mol.GetAtoms()]\n",
|
| 80 |
+
"highlight_colors = {\n",
|
| 81 |
+
" atom.GetIdx(): atom_colors.get(atom.GetAtomicNum(), default_color)\n",
|
| 82 |
+
" for atom in mol.GetAtoms()\n",
|
| 83 |
+
"}\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# Draw with transparent background\n",
|
| 86 |
+
"drawer = rdMolDraw2D.MolDraw2DSVG(400, 400)\n",
|
| 87 |
+
"# drawer.drawOptions().clearBackground = False # 🔑 makes background transparent\n",
|
| 88 |
+
"rdMolDraw2D.PrepareAndDrawMolecule(\n",
|
| 89 |
+
" drawer,\n",
|
| 90 |
+
" mol,\n",
|
| 91 |
+
" highlightAtoms=highlight_atoms,\n",
|
| 92 |
+
" highlightAtomColors=highlight_colors\n",
|
| 93 |
+
")\n",
|
| 94 |
+
"drawer.FinishDrawing()\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"# Clean up RDKit's extra XML headers\n",
|
| 97 |
+
"svg = drawer.GetDrawingText().replace(\"svg:\", \"\")\n",
|
| 98 |
+
"SVG(svg)\n"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "code",
|
| 103 |
+
"execution_count": 9,
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [
|
| 106 |
+
{
|
| 107 |
+
"data": {
|
| 108 |
+
"image/png": "",
|
| 109 |
+
"text/plain": [
|
| 110 |
+
"<Figure size 640x480 with 1 Axes>"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"output_type": "display_data"
|
| 115 |
+
}
|
| 116 |
+
],
|
| 117 |
+
"source": [
|
| 118 |
+
"import networkx as nx\n",
|
| 119 |
+
"import matplotlib.pyplot as plt\n",
|
| 120 |
+
"from rdkit import Chem\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"# Example molecule\n",
|
| 123 |
+
"mol = Chem.MolFromSmiles(\"C1CCN(C1)C(=O)N\") \n",
|
| 124 |
+
"Chem.rdDepictor.Compute2DCoords(mol)\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"# Define colors for atom types\n",
|
| 127 |
+
"atom_colors = {\n",
|
| 128 |
+
" 6: \"lightgray\", # Carbon\n",
|
| 129 |
+
" 8: \"lightcoral\", # Oxygen\n",
|
| 130 |
+
" 7: \"lightblue\", # Nitrogen\n",
|
| 131 |
+
" 16: \"khaki\", # Sulfur\n",
|
| 132 |
+
" 17: \"lightgreen\", # Chlorine\n",
|
| 133 |
+
" 1: \"whitesmoke\", # Hydrogen\n",
|
| 134 |
+
"}\n",
|
| 135 |
+
"default_color = \"plum\"\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"# Convert RDKit Mol → NetworkX graph\n",
|
| 138 |
+
"G = nx.Graph()\n",
|
| 139 |
+
"for atom in mol.GetAtoms():\n",
|
| 140 |
+
" idx = atom.GetIdx()\n",
|
| 141 |
+
" pos = mol.GetConformer().GetAtomPosition(idx)\n",
|
| 142 |
+
" G.add_node(\n",
|
| 143 |
+
" idx,\n",
|
| 144 |
+
" label=atom.GetSymbol(),\n",
|
| 145 |
+
" color=atom_colors.get(atom.GetAtomicNum(), default_color),\n",
|
| 146 |
+
" pos=(pos.x, pos.y) # store RDKit 2D coords\n",
|
| 147 |
+
" )\n",
|
| 148 |
+
"for bond in mol.GetBonds():\n",
|
| 149 |
+
" G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), order=bond.GetBondTypeAsDouble())\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"# Extract positions\n",
|
| 152 |
+
"pos = {n: (data[\"pos\"][0], data[\"pos\"][1]) for n, data in G.nodes(data=True)}\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"# Draw nodes\n",
|
| 155 |
+
"node_colors = [G.nodes[n][\"color\"] for n in G.nodes]\n",
|
| 156 |
+
"nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=800, edgecolors=\"k\")\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"# Draw edges with bond order as width\n",
|
| 159 |
+
"# edge_widths = [1.5 * G[u][v][\"order\"] for u, v in G.edges()]\n",
|
| 160 |
+
"nx.draw_networkx_edges(G, pos)\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"# Draw atom labels\n",
|
| 163 |
+
"labels = {n: G.nodes[n][\"label\"] for n in G.nodes}\n",
|
| 164 |
+
"nx.draw_networkx_labels(G, pos, labels, font_size=12, font_weight=\"bold\")\n",
|
| 165 |
+
"\n",
|
| 166 |
+
"plt.axis(\"off\")\n",
|
| 167 |
+
"plt.gca().set_aspect(\"equal\", \"box\") # keep proportions\n",
|
| 168 |
+
"plt.show()\n"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "code",
|
| 173 |
+
"execution_count": null,
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"outputs": [],
|
| 176 |
+
"source": []
|
| 177 |
+
}
|
| 178 |
+
],
|
| 179 |
+
"metadata": {
|
| 180 |
+
"kernelspec": {
|
| 181 |
+
"display_name": "spec",
|
| 182 |
+
"language": "python",
|
| 183 |
+
"name": "python3"
|
| 184 |
+
},
|
| 185 |
+
"language_info": {
|
| 186 |
+
"codemirror_mode": {
|
| 187 |
+
"name": "ipython",
|
| 188 |
+
"version": 3
|
| 189 |
+
},
|
| 190 |
+
"file_extension": ".py",
|
| 191 |
+
"mimetype": "text/x-python",
|
| 192 |
+
"name": "python",
|
| 193 |
+
"nbconvert_exporter": "python",
|
| 194 |
+
"pygments_lexer": "ipython3",
|
| 195 |
+
"version": "3.11.7"
|
| 196 |
+
},
|
| 197 |
+
"orig_nbformat": 4
|
| 198 |
+
},
|
| 199 |
+
"nbformat": 4,
|
| 200 |
+
"nbformat_minor": 2
|
| 201 |
+
}
|