import typing as T
from abc import ABC

import pandas as pd
import torch
from torchmetrics import CosineSimilarity, MeanMetric
from torchmetrics.functional.retrieval import retrieval_hit_rate
from torch_geometric.utils import unbatch

from massspecgym.models.base import MassSpecGymModel, Stage
import massspecgym.utils as utils


class RetrievalMassSpecGymModel(MassSpecGymModel, ABC):
    def __init__(
        self,
        at_ks: T.Iterable[int] = (1, 5, 20),
        myopic_mces_kwargs: T.Optional[T.Mapping] = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.at_ks = at_ks
        self.myopic_mces = utils.MyopicMCES(**(myopic_mces_kwargs or {}))

    def on_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
    ) -> None:
        """
        Compute evaluation metrics for the retrieval model based on the batch and corresponding
        predictions.
        """
        self.log(
            f"{stage.to_pref()}loss",
            outputs['loss'],
            batch_size=batch['spec'].size(0),
            sync_dist=True,
            prog_bar=True,
        )

        if stage in self.log_only_loss_at_stages:
            return

        metric_vals = {}
        metric_vals |= self.evaluate_retrieval_step(
            outputs["scores"],
            batch["labels"],
            batch["batch_ptr"],
            stage=stage,
        )
        metric_vals |= self.evaluate_mces_at_1(
            outputs["scores"],
            batch["labels"],
            batch["smiles"],
            batch["candidates_smiles"],
            batch["batch_ptr"],
            stage=stage,
        )

        if stage == Stage.TEST and self.df_test_path is not None:
            self._update_df_test(metric_vals)

    def evaluate_retrieval_step(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        batch_ptr: torch.Tensor,
        stage: Stage,
    ) -> dict[str, torch.Tensor]:
        """
        Main evaluation method for retrieval models. The retrieval step is evaluated by
        computing the hit rate at different top-k values.

        Args:
            scores (torch.Tensor): Concatenated scores for all candidates of all samples in the
                batch
            labels (torch.Tensor): Concatenated True/False labels for all candidates of all
                samples in the batch
            batch_ptr (torch.Tensor): Number of candidates per sample, used to split the
                concatenated tensors into per-sample chunks
        """
        # Initialize return dictionary to store metric values per sample
        metric_vals = {}

        # Evaluate hit rate at different top-k values
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = unbatch(scores, indexes)
        labels = unbatch(labels, indexes)
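        # Note: torchmetrics' retrieval_hit_rate returns 1.0 for a sample when at least one
        # positive candidate appears among the k highest-scoring candidates and 0.0 otherwise,
        # so the MeanMetric logged below is the fraction of samples retrieved within the top k.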
        for at_k in self.at_ks:
            hit_rates = []
            for scores_sample, labels_sample in zip(scores, labels):
                hit_rates.append(retrieval_hit_rate(scores_sample, labels_sample, top_k=at_k))
            hit_rates = torch.tensor(hit_rates, device=batch_ptr.device)

            metric_name = f"{stage.to_pref()}hit_rate@{at_k}"
            self._update_metric(
                metric_name,
                MeanMetric,
                (hit_rates,),
                batch_size=batch_ptr.size(0),
                bootstrap=stage == Stage.TEST,
            )
            metric_vals[metric_name] = hit_rates

        return metric_vals

    def evaluate_mces_at_1(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        smiles: list[str],
        candidates_smiles: list[str],
        batch_ptr: torch.Tensor,
        stage: Stage,
    ) -> dict[str, torch.Tensor]:
        """
        Evaluate the MCES (maximum common edge subgraph) distance between each sample's
        ground-truth molecule and its top-1 retrieved candidate.

        Args:
            scores (torch.Tensor): Concatenated scores for all candidates of all samples
            labels (torch.Tensor): Concatenated True/False labels for all candidates
            smiles (list[str]): Ground-truth SMILES string for each sample
            candidates_smiles (list[str]): Concatenated candidate SMILES strings for all samples
            batch_ptr (torch.Tensor): Number of candidates per sample
        """
        if labels.sum() != len(smiles):
            raise ValueError(
                "MCES@1 evaluation currently supports exactly 1 positive candidate per sample."
            )
        # Initialize return dictionary to store metric values per sample
        metric_vals = {}

        # Get top-1 predicted molecules for each ground-truth sample
        smiles_pred_top_1 = []
        batch_ptr = torch.cumsum(batch_ptr, dim=0)
        for i, j in zip(torch.cat([torch.tensor([0], device=batch_ptr.device), batch_ptr]), batch_ptr):
            scores_sample = scores[i:j]
            top_1_idx = i + torch.argmax(scores_sample)
            smiles_pred_top_1.append(candidates_smiles[top_1_idx])

        # Calculate MCES distance between top-1 predicted molecules and ground truth
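        # Note: MyopicMCES scores structural dissimilarity between two molecules via their
        # maximum common edge subgraph; identical structures give a distance of 0 and larger
        # values mean less overlap (exact behaviour depends on the wrapped myopic_mces
        # implementation and the kwargs passed to the constructor).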
        mces_dists = [
            self.myopic_mces(sm, sm_pred)
            for sm, sm_pred in zip(smiles, smiles_pred_top_1)
        ]
        mces_dists = torch.tensor(mces_dists, device=scores.device)

        # Log
        metric_name = f"{stage.to_pref()}mces@1"
        self._update_metric(
            metric_name,
            MeanMetric,
            (mces_dists,),
            batch_size=len(mces_dists),
            bootstrap=stage == Stage.TEST,
        )
        metric_vals[metric_name] = mces_dists

        return metric_vals

    def evaluate_fingerprint_step(
        self,
        y_true: torch.Tensor,
        y_pred: torch.Tensor,
        stage: Stage,
    ) -> None:
        """
        Utility evaluation method to assess the quality of predicted fingerprints. This method is
        not a part of the necessary evaluation logic (not called in the `on_batch_end` method)
        since retrieval models are not bound to predict fingerprints.

        Args:
            y_true (torch.Tensor): [batch_size, fingerprint_size] tensor of true fingerprints
            y_pred (torch.Tensor): [batch_size, fingerprint_size] tensor of predicted fingerprints
        """
        # Cosine similarity between predicted and true fingerprints
        self._update_metric(
            f"{stage.to_pref()}fingerprint_cos_sim",
            CosineSimilarity,
            (y_pred, y_true),
            batch_size=y_true.size(0),
            metric_kwargs=dict(reduction="mean"),
        )
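    # Illustrative call site (hypothetical; not invoked by this base class): a subclass that
    # also predicts fingerprints may call self.evaluate_fingerprint_step(fp_true, fp_pred, stage)
    # from its own step logic.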

    def test_step(
        self,
        batch: dict,
        batch_idx: int
    ) -> dict[str, torch.Tensor]:
        outputs = super().test_step(batch, batch_idx)

        # Get sorted candidate SMILES based on the predicted scores for each sample
        if self.df_test_path is not None:
            indexes = utils.batch_ptr_to_batch_idx(batch['batch_ptr'])
            scores = unbatch(outputs['scores'], indexes)
            candidates_smiles = utils.unbatch_list(batch['candidates_smiles'], indexes)

            sorted_candidate_smiles = []
            for scores_sample, candidates_smiles_sample in zip(scores, candidates_smiles):
                candidates_smiles_sample = [
                    x for _, x in sorted(zip(scores_sample, candidates_smiles_sample), reverse=True)
                ]
                sorted_candidate_smiles.append(candidates_smiles_sample)

            self._update_df_test({
                'identifier': batch['identifier'],
                'sorted_candidate_smiles': sorted_candidate_smiles
            })

        return outputs

    def on_test_epoch_end(self) -> None:
        # Save test data frame to disk
        if self.df_test_path is not None:
            df_test = pd.DataFrame(self.df_test)
            self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
            df_test.to_pickle(self.df_test_path)
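

# A minimal usage sketch (illustrative only; the subclass, `score_candidates`, and `loss_fn`
# are hypothetical and not part of this module). Judging from how `on_batch_end` consumes its
# `outputs`, a concrete model is expected to produce a dict with at least the batch loss and
# one score per concatenated candidate (the exact hook to override is defined by
# `MassSpecGymModel`):
#
#     class MyScoringModel(RetrievalMassSpecGymModel):
#         def step(self, batch: dict, stage: Stage) -> dict:
#             scores = self.score_candidates(batch)          # one score per candidate
#             loss = self.loss_fn(scores, batch["labels"])
#             return {"loss": loss, "scores": scores}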