| """ | |
| Custom Lightning callbacks | |
| """ | |
| import lightning as L | |
| from lightning.pytorch.callbacks import Callback | |
| import torch | |
| from typing import Any | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class CustomCallback(Callback):
    """Custom callback for monitoring training progress."""

    def __init__(self) -> None:
        super().__init__()
        self.start_time = None

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Called when training starts."""
        self.start_time = time.time()
        logger.info("Training started")

        # Log model size and the fraction of weights that will be updated.
        total_params = sum(p.numel() for p in pl_module.parameters())
        trainable_params = sum(p.numel() for p in pl_module.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"Trainable ratio: {trainable_params / total_params:.2%}")

    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Called when training ends."""
        if self.start_time is not None:
            duration = time.time() - self.start_time
            logger.info(f"Training completed in {duration:.2f} seconds")

    def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Called at the start of each training epoch."""
        logger.info(f"Starting epoch {trainer.current_epoch + 1}/{trainer.max_epochs}")

    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Called at the end of the validation epoch."""
        val_loss = trainer.logged_metrics.get("val/loss")
        if val_loss is not None:
            logger.info(f"Validation loss: {val_loss:.4f}")
class MemoryMonitorCallback(Callback):
    """Monitor GPU memory usage during training."""

    def __init__(self, log_every_n_steps: int = 100) -> None:
        super().__init__()
        self.log_every_n_steps = log_every_n_steps

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Log allocated and reserved CUDA memory every `log_every_n_steps` batches."""
        if batch_idx % self.log_every_n_steps == 0 and torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
            pl_module.log("train/memory_allocated_gb", memory_allocated, on_step=True)
            pl_module.log("train/memory_reserved_gb", memory_reserved, on_step=True)