import typing as T
import collections
from enum import Enum
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import pytorch_lightning as pl
from torchmetrics import Metric, SumMetric

from massspecgym.utils import ReturnScalarBootStrapper


class Stage(Enum):
    """Training / evaluation stage, used to prefix logged metric names."""

    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'
    NONE = 'none'

    def to_pref(self) -> str:
        """Return the metric-name prefix for this stage.

        E.g. ``"val_"`` for ``Stage.VAL``; the empty string for ``Stage.NONE``.
        """
        return f"{self.value}_" if self != Stage.NONE else ""


class MassSpecGymModel(pl.LightningModule, ABC):
    """Abstract base LightningModule shared by MassSpecGym models.

    Child classes implement ``step`` (model-specific computation for one batch)
    and ``on_batch_end`` (task-specific evaluation). This base class wires both
    into the Lightning train/val/test hooks, configures an Adam optimizer, and
    provides helpers for lazily created metrics (``_update_metric``) and for
    accumulating per-sample test results (``_update_df_test``).

    Args:
        lr: Learning rate passed to ``torch.optim.Adam``.
        weight_decay: Weight decay passed to ``torch.optim.Adam``.
        log_only_loss_at_stages: Stages (``Stage`` members or their string
            values) stored as ``self.log_only_loss_at_stages``; presumably read
            by child classes to skip metric computation at those stages —
            confirm against subclasses.
        bootstrap_metrics: Global switch; when False, ``_update_metric`` never
            computes bootstrapped standard deviations.
        df_test_path: Optional path stored as ``self.df_test_path``;
            presumably where child classes write the collected test
            dataframe — confirm.
    """

    def __init__(
        self,
        lr: float = 1e-4,
        weight_decay: float = 0.0,
        log_only_loss_at_stages: T.Sequence[Stage | str] = (),
        bootstrap_metrics: bool = True,
        df_test_path: T.Optional[str | Path] = None,
        *args,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # Set up metric logging
        self.log_only_loss_at_stages = [
            Stage(s) if isinstance(s, str) else s for s in log_only_loss_at_stages
        ]
        self.bootstrap_metrics = bootstrap_metrics

        # Init dictionary to store dataframe columns where rows correspond to samples
        # (for constructing test dataframe with predictions and metrics for each sample)
        self.df_test_path = Path(df_test_path) if df_test_path is not None else None
        self.df_test = collections.defaultdict(list)

    @abstractmethod
    def step(
        self, batch: dict, stage: Stage = Stage.NONE
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError(
            "Method `step` must be implemented in the model-specific child class."
        )

    def training_step(
        self, batch: dict, batch_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.step(batch, stage=Stage.TRAIN)

    def validation_step(
        self, batch: dict, batch_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.step(batch, stage=Stage.VAL)

    def test_step(
        self, batch: dict, batch_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.step(batch, stage=Stage.TEST)

    @abstractmethod
    def on_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, stage: Stage
    ) -> None:
        """
        Method to be called at the end of each batch. This method should be implemented by a child,
        task-dedicated class and contain the evaluation necessary for the task.
        """
        raise NotImplementedError(
            "Method `on_batch_end` must be implemented in the task-specific child class."
        )

    def on_train_batch_end(self, outputs: T.Any, batch: dict, batch_idx: int) -> None:
        return self.on_batch_end(outputs, batch, batch_idx, stage=Stage.TRAIN)

    def on_validation_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        # ``dataloader_idx`` is accepted but not forwarded: with several eval
        # dataloaders Lightning passes it positionally, and forwarding it would
        # collide with ``on_batch_end``'s ``stage`` argument.
        return self.on_batch_end(outputs, batch, batch_idx, stage=Stage.VAL)

    def on_test_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        # See ``on_validation_batch_end`` for why ``dataloader_idx`` is dropped.
        return self.on_batch_end(outputs, batch, batch_idx, stage=Stage.TEST)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )

    def get_checkpoint_monitors(self) -> list[dict]:
        """Return checkpoint / early-stopping monitor specs (validation loss, minimized)."""
        monitors = [
            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": True}
        ]
        return monitors

    def _update_metric(
        self,
        name: str,
        metric_class: type[Metric],
        update_args: T.Any,
        batch_size: T.Optional[int] = None,
        prog_bar: bool = False,
        metric_kwargs: T.Optional[dict] = None,
        log: bool = True,
        log_n_samples: bool = False,
        bootstrap: bool = False,
        num_bootstraps: int = 100
    ) -> None:
        """
        This method enables updating and logging metrics without instantiating them in advance in
        the __init__ method. The metrics are aggregated over batches and logged at the end of the
        epoch. If the metric does not exist yet, it is instantiated and added as an attribute to the
        model.

        Args:
            name: Attribute name of the metric on the model; also the logged key.
            metric_class: Metric class (or factory) used to create the metric on first use.
            update_args: Positional arguments passed to the metric call for this batch.
            batch_size: Forwarded to ``self.log``.
            prog_bar: Forwarded to ``self.log``.
            metric_kwargs: Keyword arguments for ``metric_class`` on first use.
            log: Whether to call ``self.log`` for this metric.
            log_n_samples: Also accumulate ``len(update_args[0])`` into a
                ``SumMetric`` named ``name + "_n_samples"``.
            bootstrap: Also track a bootstrapped standard deviation under
                ``name + "_std"`` (only if ``self.bootstrap_metrics`` is True).
            num_bootstraps: Number of bootstrap resamples for the ``_std`` metric.
        """
        # Bootstrapping requires both the per-call request and the global switch
        bootstrap = bootstrap and self.bootstrap_metrics
        if metric_kwargs is None:
            metric_kwargs = {}

        # Log total number of samples (useful for debugging)
        if log_n_samples:
            self._update_metric(
                name=name + "_n_samples",
                metric_class=SumMetric,
                update_args=(len(update_args[0]),),
                batch_size=1,
            )

        # Init metric if it does not exist yet; storing it as an attribute means
        # later calls with the same name reuse (and keep accumulating into) it
        if hasattr(self, name):
            metric = getattr(self, name)
        else:
            metric = metric_class(**metric_kwargs).to(self.device)
            setattr(self, name, metric)

        # Update
        metric(*update_args)

        # Log
        if log:
            self.log(
                name,
                metric,
                prog_bar=prog_bar,
                batch_size=batch_size,
                on_step=False,
                on_epoch=True,
                add_dataloader_idx=False,
                metric_attribute=name  # Suggested by a torchmetrics error
            )

        # Bootstrap: track the std of the same metric under ``name + "_std"``
        if bootstrap:
            def _bootstrapped_metric_class(**kwargs):
                return ReturnScalarBootStrapper(
                    metric_class(**kwargs), std=True, num_bootstraps=num_bootstraps
                )

            self._update_metric(
                name=name + "_std",
                metric_class=_bootstrapped_metric_class,
                update_args=update_args,
                batch_size=batch_size,
                metric_kwargs=metric_kwargs,
            )

    def _update_df_test(self, dct: dict) -> None:
        """Append per-sample values to the accumulated test columns.

        Args:
            dct: Mapping from column name to the batch's values for that column.
                Tensors are converted with ``tolist()``; a 0-d tensor becomes a
                single-element list so that ``extend`` does not fail.
        """
        for col, vals in dct.items():
            if isinstance(vals, torch.Tensor):
                vals = vals.tolist()
                if not isinstance(vals, list):  # 0-d tensor -> Python scalar
                    vals = [vals]
            self.df_test[col].extend(vals)