import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score
class EarlyStopping:
    """Stops training early when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')  # lowest validation loss seen so far
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            # First call: record the score and save an initial checkpoint
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count towards the patience limit
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: update the best score, save a checkpoint, reset the counter
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves the model when the validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss
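
For context, here is a minimal sketch of how the class above might be wired into a training loop together with the ReduceLROnPlateau scheduler and accuracy_score imported at the top. It assumes a classification model; the names model, criterion, optimizer, train_loader, val_loader, device and n_epochs are placeholders assumed to be defined elsewhere, not part of the class itself.

# Example usage (a sketch; model, criterion, optimizer, loaders, device,
# n_epochs are assumed to exist elsewhere in the script)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(n_epochs):
    # Training phase
    model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Validation phase: average loss and accuracy over the validation set
    model.eval()
    val_losses, preds, labels = [], [], []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            val_losses.append(criterion(outputs, targets).item())
            preds.extend(outputs.argmax(dim=1).cpu().tolist())
            labels.extend(targets.cpu().tolist())
    val_loss = sum(val_losses) / len(val_losses)
    print(f'epoch {epoch}: val_loss={val_loss:.4f} '
          f'val_acc={accuracy_score(labels, preds):.4f}')

    scheduler.step(val_loss)         # reduce the learning rate when val_loss plateaus
    early_stopping(val_loss, model)  # checkpoint on improvement, count otherwise
    if early_stopping.early_stop:
        print('Early stopping triggered')
        break

# Restore the best weights saved by EarlyStopping
model.load_state_dict(torch.load('checkpoint.pt'))

Because EarlyStopping checkpoints the model every time the validation loss improves, reloading 'checkpoint.pt' after the loop leaves the model with its best validation weights rather than those from the final (possibly overfitted) epoch.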