# Source code for deepsulci.deeptools.early_stopping

# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
import numpy as np
import torch


class EarlyStopping(object):
    """Early stops the training if validation loss doesn't improve after a
    given patience.

    Call the instance once per epoch with the current validation loss and
    the model. When the loss improves, the model's ``state_dict()`` is saved
    to ``path`` and the patience counter is reset. Otherwise the counter is
    incremented, and ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0.0,
                 path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss
                improved. Default: 7
            verbose (bool): If True, prints a message for each validation
                loss improvement. Default: False
            delta (float): Minimum decrease of the validation loss that
                counts as an improvement. Default: 0.0
            path (str): File the best model's state_dict is saved to.
                Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Update the early-stopping state with a new validation loss.

        Args:
            val_loss (float): Validation loss of the current epoch.
            model: Object with a ``state_dict()`` method (e.g. a
                ``torch.nn.Module``); its state is saved on improvement.
        """
        # Negate so that a higher score means a lower (better) loss.
        score = -val_loss
        if self.best_score is None:
            # First call: nothing to compare against yet.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score > self.best_score + self.delta:
            # Strict improvement only. A plateaued (equal) or NaN loss goes
            # to the else branch. It neither overwrites the saved checkpoint
            # nor resets the patience counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print('EarlyStopping counter: %i out of %i'
                  % (self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.

        Writes ``model.state_dict()`` to ``self.path`` with ``torch.save``
        and records ``val_loss`` as the new minimum.
        '''
        if self.verbose:
            print('Validation loss decreased (%.6f -> %.6f). Saving model...'
                  % (self.val_loss_min, val_loss))
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss