# mvp/models/contrastive.py
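"""Contrastive spectrum-molecule retrieval models.

``ContrastiveModel`` aligns spectrum and molecule embeddings with a
contrastive objective, optionally adding fingerprint-prediction and
consensus-spectrum auxiliary losses. ``MultiViewContrastive`` generalizes
this to pairwise contrastive losses over an arbitrary set of encoded views.
"""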
import typing as T
import torch
import torch.nn as nn
import pandas as pd
from collections import defaultdict
import numpy as np
import os
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
from massspecgym.models.base import Stage
from massspecgym import utils
from mvp.utils.loss import contrastive_loss, cand_spec_sim_loss, fp_loss, cons_spec_loss
import mvp.utils.models as model_utils
import torch.nn.functional as F


class ContrastiveModel(RetrievalMassSpecGymModel):
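    """Retrieval model that embeds spectra and candidate molecules in a shared
    space via a contrastive loss; at test time candidates are ranked by cosine
    similarity to the query spectrum embedding.
    """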
    def __init__(
        self,
        external_test=False,
        **kwargs,
    ):
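        """Set hyperparameter defaults and build the spectrum/molecule encoders.

        Args:
            external_test: evaluate on an external candidate set without
                ground-truth labels (used by ``MultiViewContrastive``).
        """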
super().__init__(**kwargs)
self.save_hyperparameters()
self.external_test = external_test
        # Backward-compatible defaults for configs that predate these hparams.
        if 'use_fp' not in self.hparams:
            self.hparams.use_fp = False
        if 'loss_strategy' not in self.hparams:
            self.hparams.loss_strategy = 'static'
            self.hparams.contr_wt = 1.0
            self.hparams.use_contr = True
self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
if self.hparams.pred_fp:
self.fp_loss = fp_loss(self.hparams.fp_loss_type)
self.fp_pred_model = model_utils.get_fp_pred_model(self.hparams)
if self.hparams.use_cons_spec:
self.cons_spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
self.cons_loss = cons_spec_loss(self.hparams.cons_loss_type)
self.spec_view = self.hparams.spectra_view
# result storage for testing results
self.result_dct = defaultdict(lambda: defaultdict(list))
def forward(self, batch, stage):
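        """Encode a batch into a (spectrum embedding, molecule embedding) pair.

        Candidate molecules are encoded at test time, target molecules
        otherwise. With ``use_cons_spec``, non-test batches are encoded from
        the consensus spectrum with a dedicated encoder.
        """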
        g = batch['cand'] if stage == Stage.TEST else batch['mol']
        if self.hparams.use_cons_spec and stage != Stage.TEST:
            spec = batch['cons_spec']
            n_peaks = batch.get('cons_n_peaks')
            spec_enc = self.cons_spec_enc_model(spec, n_peaks)
        else:
            spec = batch[self.spec_view]
            n_peaks = batch.get('n_peaks')
            spec_enc = self.spec_enc_model(spec, n_peaks)
        fp = batch['fp'] if self.hparams.use_fp else None
        mol_enc = self.mol_enc_model(g, fp=fp)
return spec_enc, mol_enc
def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
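        """Combine the weighted contrastive loss with any enabled auxiliary
        losses; returns a dict of detached per-term values plus the total
        under 'loss'.
        """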
        loss = 0
        losses = {}
        # NOTE: assumes self.loss_wts is populated externally (e.g., by the
        # training setup) from the corresponding *_wt hyperparameters.
        contr_loss, cong_loss, noncong_loss = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
        contr_loss = self.loss_wts['contr_wt'] * contr_loss
        losses['contr_loss'] = contr_loss.detach().item()
        losses['cong_loss'] = cong_loss.detach().item()
        losses['noncong_loss'] = noncong_loss.detach().item()
        loss += contr_loss
        if self.hparams.pred_fp:
            fp_loss_val = self.loss_wts['fp_wt'] * self.fp_loss(output['fp'], batch['fp'])
            loss += fp_loss_val
            losses['fp_loss'] = fp_loss_val.detach().item()
        if 'aug_cand_enc' in output:
            aug_cand_loss = self.loss_wts['aug_cand_wt'] * cand_spec_sim_loss(spec_enc, output['aug_cand_enc'])
            loss += aug_cand_loss
            losses['aug_cand_loss'] = aug_cand_loss.detach().item()
        if 'ind_spec' in output:
            spec_loss = self.loss_wts['cons_spec_wt'] * self.cons_loss(spec_enc, output['ind_spec'])
            loss += spec_loss
            losses['cons_spec_loss'] = spec_loss.detach().item()
        losses['loss'] = loss
return losses
    def step(self, batch: dict, stage=Stage.NONE):
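        """Run the forward pass; return raw encodings at test time, losses otherwise."""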
# Compute spectra and mol encoding
spec_enc, mol_enc = self.forward(batch, stage)
if stage == Stage.TEST:
return dict(spec_enc=spec_enc, mol_enc=mol_enc)
# Aux tasks
output = {}
if self.hparams.pred_fp:
output['fp'] = self.fp_pred_model(mol_enc)
        if self.hparams.use_cons_spec:
            # also encode the individual spectrum so the consensus loss can
            # compare it against the consensus-spectrum embedding
            spec = batch[self.spec_view]
            n_peaks = batch.get('n_peaks')
            output['ind_spec'] = self.spec_enc_model(spec, n_peaks)
# Calculate loss
losses = self.compute_loss(batch, spec_enc, mol_enc, output)
return losses
def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
# total loss
self.log(
f'{stage.to_pref()}loss',
outputs['loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=True,
on_epoch=True,
# on_step=True
)
# contr loss
if self.hparams.use_contr:
self.log(
f'{stage.to_pref()}contr_loss',
outputs['contr_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=False,
on_epoch=True,
# on_step=True
)
# noncongruent pairs
self.log(
f'{stage.to_pref()}noncong_loss',
outputs['noncong_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=False,
on_epoch=True,
# on_step=True
)
# congruent pairs
self.log(
f'{stage.to_pref()}cong_loss',
outputs['cong_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=False,
on_epoch=True,
# on_step=True
)
        if self.hparams.pred_fp:
            self.log(
                f'{stage.to_pref()}fp_loss',
                outputs['fp_loss'],
                batch_size=len(batch['identifier']),
                sync_dist=True,
                prog_bar=False,
                on_epoch=True,
            )
if self.hparams.use_cons_spec:
self.log(
f'{stage.to_pref()}cons_loss',
outputs['cons_spec_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=False,
on_epoch=True,
)
def test_step(self, batch, batch_idx):
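        """Score each candidate against its query spectrum.

        Candidates belonging to the same identifier are assumed to be
        contiguous in the batch, so the flat similarity vector can be split
        back into per-identifier groups.
        """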
# Unpack inputs
        identifiers = batch['identifier']
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))
outputs = self.step(batch, stage=Stage.TEST)
spec_enc = outputs['spec_enc']
mol_enc = outputs['mol_enc']
# Calculate scores
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = F.cosine_similarity(spec_enc, mol_enc)
        scores = torch.split(scores, list(id_to_ct.values()))
cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
labels = utils.unbatch_list(batch['label'], indexes)
return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
# save scores
for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
self.result_dct[i]['candidates'].extend(cands)
self.result_dct[i]['scores'].extend(scores.cpu().tolist())
self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
def _compute_rank(self, scores, labels):
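        """Return the 1-based rank of the true candidate, counting ties
        against it; -1 if the candidate list has no positive label.
        """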
if not any(labels):
return -1
        scores = np.array(scores)
        target_score = scores[labels][0]
        rank = np.count_nonzero(scores >= target_score)
        return rank
def on_test_epoch_end(self) -> None:
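        """Aggregate per-identifier results into a DataFrame, add candidate
        ranks, and pickle it to ``df_test_path``.
        """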
self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
# Compute rank
self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
if not self.df_test_path:
self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
# self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
self.df_test.to_pickle(self.df_test_path)
def get_checkpoint_monitors(self) -> T.List[dict]:
monitors = [
{"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False}, # monitor train loss
]
return monitors
    def _update_loss_weights(self) -> None:
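        """Adjust loss weights once per epoch according to ``loss_strategy``:
        'linear' adds a fixed per-loss increment each epoch, 'manual' applies
        scheduled values at the given epochs, and 'static' leaves them as-is.
        """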
if self.hparams.loss_strategy == 'linear':
for loss in self.loss_wts:
self.loss_wts[loss] += self.loss_updates[loss]
elif self.hparams.loss_strategy == 'manual':
for loss in self.loss_wts:
if self.current_epoch in self.loss_updates[loss]:
self.loss_wts[loss] = self.loss_updates[loss][self.current_epoch]
def on_train_epoch_end(self) -> None:
        self._update_loss_weights()


class MultiViewContrastive(ContrastiveModel):
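    """Contrastive model over multiple views (spectrum, molecule, and
    optionally fingerprint, consensus-spectrum, and NL-spectrum encodings),
    trained on every view pair listed in ``hparams.contr_views``.
    """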
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
# build fingerprint encoder model
if self.hparams.use_fp:
self.fp_enc_model = model_utils.get_fp_enc_model(self.hparams)
        # build NL spectrum encoder model; forward() requires it when
        # use_NL_spec is enabled
        if self.hparams.use_NL_spec:
            self.NL_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
def forward(self, batch, stage):
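        """Encode all enabled views; returns them keyed by view name."""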
g = batch['cand'] if stage == Stage.TEST else batch['mol']
spec = batch[self.spec_view]
        n_peaks = batch.get('n_peaks')
spec_enc = self.spec_enc_model(spec, n_peaks)
mol_enc = self.mol_enc_model(g)
views = {'spec_enc': spec_enc, 'mol_enc': mol_enc}
if self.hparams.use_fp:
fp_enc = self.fp_enc_model(batch['fp'])
views['fp_enc'] = fp_enc
if self.hparams.use_cons_spec:
spec = batch['cons_spec']
            n_peaks = batch.get('cons_n_peaks')
spec_enc = self.cons_spec_enc_model(spec, n_peaks)
views['cons_spec_enc'] = spec_enc
if self.hparams.use_NL_spec:
spec = batch['NL_spec']
            n_peaks = batch.get('NL_n_peaks')
spec_enc = self.NL_enc_model(spec, n_peaks)
views['NL_spec_enc'] = spec_enc
return views
    def step(self, batch: dict, stage=Stage.NONE):
# Compute spectra and mol encoding
views = self.forward(batch, stage)
if stage == Stage.TEST:
return views
# Calculate loss
losses = self.compute_loss(batch, views)
return losses
def compute_loss(self, batch: dict, views: dict):
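        """Sum contrastive losses over all configured view pairs; per-pair
        terms are keyed by the view names with their '_enc' suffix stripped.
        """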
        loss = 0
        losses = {}
        for v1, v2 in self.hparams.contr_views:
            contr_loss, cong_loss, noncong_loss = contrastive_loss(views[v1], views[v2], self.hparams.contr_temp)
            loss += contr_loss
            # v[:-4] strips the '_enc' suffix from the view name
            losses[f'{v1[:-4]}-{v2[:-4]}_contr_loss'] = contr_loss.detach().item()
            losses[f'{v1[:-4]}-{v2[:-4]}_cong_loss'] = cong_loss.detach().item()
            losses[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'] = noncong_loss.detach().item()
        losses['loss'] = loss
return losses
def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
# total loss
self.log(
f'{stage.to_pref()}loss',
outputs['loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
prog_bar=True,
on_epoch=True,
# on_step=True
)
for v1, v2 in self.hparams.contr_views:
self.log(
f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_contr_loss',
outputs[f'{v1[:-4]}-{v2[:-4]}_contr_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
on_epoch=True,
)
self.log(
f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_cong_loss',
outputs[f'{v1[:-4]}-{v2[:-4]}_cong_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
on_epoch=True,
)
self.log(
f'{stage.to_pref()}{v1[:-4]}-{v2[:-4]}_noncong_loss',
outputs[f'{v1[:-4]}-{v2[:-4]}_noncong_loss'],
batch_size=len(batch['identifier']),
sync_dist=True,
on_epoch=True,
)
def test_step(self, batch, batch_idx):
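        """Score candidates with cosine similarity separately for every view pair."""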
# Unpack inputs
        identifiers = batch['identifier']
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))
outputs = self.step(batch, stage=Stage.TEST)
scores = {}
for v1, v2 in self.hparams.contr_views:
# if 'cons_spec_enc' in (v1, v2):
# continue
            v1_enc = outputs[v1]
            v2_enc = outputs[v2]
            s = F.cosine_similarity(v1_enc, v2_enc)
            scores[f'{v1[:-4]}-{v2[:-4]}_scores'] = torch.split(s, list(id_to_ct.values()))
indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
labels = utils.unbatch_list(batch['label'], indexes)
return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
# save scores
for i, cands, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels']):
self.result_dct[i]['candidates'].extend(cands)
self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
for v1, v2 in self.hparams.contr_views:
for i, scores in zip(outputs['identifiers'], outputs['scores'][f'{v1[:-4]}-{v2[:-4]}_scores']):
self.result_dct[i][f'{v1[:-4]}-{v2[:-4]}_scores'].extend(scores.cpu().tolist())
def _get_top_cand(self, scores, candidates):
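        """Return the highest-scoring candidate."""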
return candidates[np.argmax(np.array(scores))]
def on_test_epoch_end(self) -> None:
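        """Build the result DataFrame: per-pair ranks when labels are
        available, or only each pair's top candidate for external test sets.
        """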
self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
# Compute rank
        if self.external_test:
            self.df_test.drop('labels', axis=1, inplace=True)
            for v1, v2 in self.hparams.contr_views:
                self.df_test[f'top_{v1[:-4]}-{v2[:-4]}_cand'] = self.df_test.apply(
                    lambda row: self._get_top_cand(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['candidates']), axis=1
                )
        else:
            for v1, v2 in self.hparams.contr_views:
                self.df_test[f'{v1[:-4]}-{v2[:-4]}_rank'] = self.df_test.apply(
                    lambda row: self._compute_rank(row[f'{v1[:-4]}-{v2[:-4]}_scores'], row['labels']), axis=1
                )
        if not self.df_test_path:
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)