from __future__ import print_function
from __future__ import absolute_import
from ...deeptools.dataset import PatternDataset
from ...deeptools.early_stopping import EarlyStopping
from ...deeptools.models import resnet18
from ..analyse.stats import balanced_accuracy
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import time
import copy
import json
import os
from six.moves import range
class ResnetPatternClassification(object):
    '''
    ResNet-18 classifier for pattern classification.

    bounding_box: bounding box of the volume patch fed to the network.
    pattern: name of the pattern to classify.
    cuda: CUDA device index (-1 to run on the CPU).
    names_filter: filter on the names to consider.
    lr, momentum: initial SGD learning rate and momentum.
    batch_size: mini-batch size used for training and labeling.
    dict_bck, dict_label: precomputed data passed on to PatternDataset.
    '''
def __init__(self, bounding_box, pattern=None, cuda=-1, names_filter=None,
lr=0.0001, momentum=0.9, batch_size=10, dict_bck=None,
dict_label=None):
if dict_bck is None:
dict_bck = {}
if dict_label is None:
dict_label = {}
self.bb = bounding_box
self.pattern = pattern
self.names_filter = names_filter
self.lr = lr
self.momentum = momentum
self.batch_size = batch_size
self.dict_bck = dict_bck
self.dict_label = dict_label
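        # Hyperparameter grids explored by cv_inner()/find_hyperparameters(),
        # and the plateau schedule used during training: after 'patience'
        # epochs without improvement the learning rate is divided by
        # 'division'; after twice 'patience', training stops early.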
self.lr_range = [1e-2, 1e-3, 1e-4, 1e-5]
self.momentum_range = [0.8, 0.7, 0.6, 0.5]
self.patience = 5
self.division = 10
        # run on the requested GPU if available, otherwise fall back to CPU
        if cuda == -1 or not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda', cuda)
print('Working on', self.device)
def learning(self, gfile_list_train, gfile_list_test, y_train, y_test):
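        '''
        Train the ResNet on gfile_list_train / y_train while monitoring
        gfile_list_test / y_test, and keep in self.trained_model the weights
        of the epoch with the best balanced validation accuracy.
        '''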
        print('TRAINING ON %i train + %i test samples' %
              (len(gfile_list_train), len(gfile_list_test)))
print()
print('PARAMETERS')
print('----------')
print('batch_size:', self.batch_size)
print('learning rate:', self.lr, 'momentum', self.momentum)
print('patience:', self.patience, 'division:', self.division)
print()
# DATASET / DATALOADERS
print('Extract validation dataloader...')
valdataset = PatternDataset(
gfile_list_test, self.pattern, self.bb, train=False,
dict_bck=self.dict_bck, dict_label=self.dict_label,
labels=y_test)
valloader = torch.utils.data.DataLoader(
valdataset, batch_size=self.batch_size,
shuffle=False, num_workers=0)
print('Extract train dataloader...')
traindataset = PatternDataset(
gfile_list_train, self.pattern, self.bb, train=True,
dict_bck=self.dict_bck, dict_label=self.dict_label,
labels=y_train)
trainloader = torch.utils.data.DataLoader(
traindataset, batch_size=self.batch_size,
shuffle=False, num_workers=0)
# MODEL
print('Network initialization...')
model = resnet18()
model = model.to(self.device)
# OPTIMIZER
lr = self.lr
optimizer = optim.SGD(model.parameters(), lr=lr,
momentum=self.momentum, nesterov=True)
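        # two plateau monitors on the epoch loss: 'divide_lr' triggers a
        # learning-rate division after 'patience' bad epochs, 'es_stop'
        # halts training after twice as many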
divide_lr = EarlyStopping(patience=self.patience)
es_stop = EarlyStopping(patience=self.patience*2)
# LOSS FUNCTION
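        # weight each class by its inverse frequency to compensate imbalance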
class_sample_count = np.array(
[len(np.where(y_train == t)[0]) for t in np.unique(y_train)])
w = 1. / class_sample_count
w = torch.tensor([w[0], w[1]], dtype=torch.float)
self.criterion = nn.CrossEntropyLoss(weight=w.to(self.device))
# TRAINING
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
        num_epochs = 200  # upper bound; early stopping usually ends sooner
print()
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
start_time = time.time()
# Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
dataloader = trainloader
else:
model.eval() # Set model to evaluate mode
dataloader = valloader
                running_loss = 0.0
# Iterate over data.
y_pred, y_true = [], []
for inputs, labels in dataloader:
inputs = inputs.to(self.device)
labels = labels.to(self.device)
# zero the parameter gradients
optimizer.zero_grad()
# forward
                    # track history only in the training phase
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = self.criterion(outputs, labels)
# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()
                    # statistics
                    running_loss += loss.item() * inputs.size(0)
y_true.extend(labels.tolist())
y_pred.extend(preds.tolist())
epoch_loss = running_loss / len(dataloader.dataset)
epoch_acc = balanced_accuracy(y_true, y_pred, [0, 1])
print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc))
# deep copy the model
if phase == 'val' and epoch_acc >= best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
            # early stopping / LR schedule driven by the validation loss
            # (epoch_loss comes from the last phase, 'val')
es_stop(epoch_loss, model)
divide_lr(epoch_loss, model)
if divide_lr.early_stop:
print('Divide learning rate by', self.division)
lr = lr/self.division
                optimizer = optim.SGD(model.parameters(), lr=lr,
                                      momentum=self.momentum, nesterov=True)
divide_lr = EarlyStopping(patience=self.patience)
if es_stop.early_stop:
print("Early stopping")
break
time_elapsed = time.time() - start_time
print('Epoch took {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
print()
time_elapsed = time.time() - since
print()
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))
# load best model weights
model.load_state_dict(best_model_wts)
self.trained_model = model
def labeling(self, gfile_list, labels=None):
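        '''
        Apply the trained network to gfile_list and return a pandas
        DataFrame indexed by file, with the predicted ('y_pred') and
        provided ('y_true') labels.
        '''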
self.trained_model = self.trained_model.to(self.device)
self.trained_model.eval()
dataset = PatternDataset(
gfile_list, self.pattern, self.bb, train=False,
dict_bck=self.dict_bck, dict_label=self.dict_label,
labels=labels)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=self.batch_size,
shuffle=False, num_workers=0)
        result = pd.DataFrame(index=list(gfile_list))
start_time = time.time()
with torch.no_grad():
i = 0
            for inputs, labels in dataloader:
                print('Labeling (%i/%i)' % (i + 1, len(dataloader)))
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
outputs = self.trained_model(inputs)
                _, preds = torch.max(outputs, 1)
# statistics
slist = gfile_list[i*self.batch_size:(i+1)*self.batch_size]
result.loc[slist, 'y_pred'] = preds.tolist()
result.loc[slist, 'y_true'] = labels.tolist()
i += 1
print('Labeling took %i s.' % (time.time()-start_time))
return result
def find_hyperparameters(self, result_matrix, param_outfile, step=0):
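        '''
        Pick hyperparameters from the inner cross-validation results in
        result_matrix (result_matrix[i][j]: labeling of fold i for the j-th
        candidate value) and save the choices to param_outfile (JSON).
        Step 0 selects a coarse learning rate, step 1 refines it, step 2
        selects the momentum and loads the best values into self.lr and
        self.momentum.
        '''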
# STEP 0
if step == 0:
best_bacc = 0
            for j, lr in enumerate(self.lr_range):
# compute acc
y_true, y_pred = [], []
for i in range(3):
y_true.extend(list(result_matrix[i][j]['y_true']))
y_pred.extend(list(result_matrix[i][j]['y_pred']))
bacc = balanced_accuracy(y_true, y_pred, [0, 1])
                print('lr: %f, bacc: %f' % (lr, bacc))
# save best acc
                if bacc > best_bacc:
best_bacc = bacc
best_lr = lr
if os.path.exists(param_outfile):
with open(param_outfile) as f:
param = json.load(f)
else:
param = {}
param['best_lr0'] = best_lr
param['best_bacc'] = best_bacc
param['bounding_box'] = [list(b) for b in self.bb]
with open(param_outfile, 'w') as f:
json.dump(param, f)
# STEP 1
elif step == 1:
with open(param_outfile) as f:
param = json.load(f)
best_lr0 = param['best_lr0']
best_lr = param['best_lr0']
best_bacc = param['best_bacc']
lr1_range = [best_lr0/4, best_lr0/2, best_lr0*2, best_lr0*4]
            for j, lr in enumerate(lr1_range):
# compute acc
y_true, y_pred = [], []
for i in range(3):
y_true.extend(list(result_matrix[i][j]['y_true']))
y_pred.extend(list(result_matrix[i][j]['y_pred']))
bacc = balanced_accuracy(y_true, y_pred, [0, 1])
print('lr: %f, bacc: %f' % (lr, bacc))
# save best acc
if bacc > best_bacc:
                    best_bacc = bacc
best_lr = lr
param['best_lr1'] = best_lr
param['best_bacc'] = best_bacc
with open(param_outfile, 'w') as f:
json.dump(param, f)
# STEP 2
elif step == 2:
with open(param_outfile) as f:
param = json.load(f)
best_bacc = param['best_bacc']
best_momentum = 0.9
            for j, momentum in enumerate(self.momentum_range):
# compute acc
y_true, y_pred = [], []
for i in range(3):
y_true.extend(list(result_matrix[i][j]['y_true']))
y_pred.extend(list(result_matrix[i][j]['y_pred']))
bacc = balanced_accuracy(y_true, y_pred, [0, 1])
print('momentum: %f, bacc: %f' % (momentum, bacc))
# save best acc
if bacc > best_bacc:
best_bacc = bacc
best_momentum = momentum
param['best_momentum'] = best_momentum
with open(param_outfile, 'w') as f:
json.dump(param, f)
# train with best parameters
self.lr = param['best_lr1']
self.momentum = param['best_momentum']
print()
print('Best hyperparameters:',
'learning rate %f, momentum %f, acc %f' %
(self.lr, self.momentum, param['best_bacc']))
print()
def cv_inner(self, gfile_list_train, gfile_list_test, y_train, y_test,
param_outfile, step=0):
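        '''
        Run one fold of the inner cross-validation: train and label the test
        set once per candidate value (coarse learning rates at step 0,
        refined ones at step 1, momenta at step 2) and return the list of
        labeling results.
        '''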
# STEP 0
if step == 0:
momentum = 0.9
result_list = []
for lr in self.lr_range:
print()
print('TEST learning rate', lr)
print('======================')
result = self.test_hyperparameters(
lr, momentum,
gfile_list_train, gfile_list_test, y_train, y_test)
result_list.append(result)
# STEP 1
elif step == 1:
with open(param_outfile) as f:
param = json.load(f)
momentum = 0.9
best_lr0 = param['best_lr0']
result_list = []
for lr in [best_lr0/4, best_lr0/2, best_lr0*2, best_lr0*4]:
print()
print('TEST learning rate', lr)
print('======================')
result = self.test_hyperparameters(
lr, momentum,
gfile_list_train, gfile_list_test, y_train, y_test)
result_list.append(result)
# STEP 2
elif step == 2:
with open(param_outfile) as f:
param = json.load(f)
best_lr1 = param['best_lr1']
result_list = []
for momentum in self.momentum_range:
print()
print('TEST momentum', momentum)
print('======================')
result = self.test_hyperparameters(
best_lr1, momentum,
gfile_list_train, gfile_list_test, y_train, y_test)
result_list.append(result)
return result_list
def test_hyperparameters(self, lr, momentum,
gfile_list_train, gfile_list_test,
y_train, y_test):
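        '''
        Train with the given learning rate and momentum, then label the test
        set and return the resulting DataFrame.
        '''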
self.lr = lr
self.momentum = momentum
print()
        s = 'TRAIN WITH lr %s momentum %s' % (lr, momentum)
print(s)
print('='*len(s))
self.learning(gfile_list_train, gfile_list_test, y_train, y_test)
print()
print('TEST labeling')
print('=============')
result = self.labeling(gfile_list_test)
return result
def load(self, model_file):
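        '''
        Load trained weights from model_file (read on the CPU, then moved to
        the configured device).
        '''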
self.trained_model = resnet18()
self.trained_model.load_state_dict(torch.load(
model_file, map_location='cpu'))
self.trained_model = self.trained_model.to(self.device)
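
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: 'bb', the gfile lists and the label
# arrays below are hypothetical placeholders, not values provided by this
# module):
#
#     clf = ResnetPatternClassification(bounding_box=bb, pattern='my_pattern',
#                                       cuda=-1, lr=1e-4, momentum=0.9)
#     clf.learning(gfiles_train, gfiles_test, y_train, y_test)
#     result = clf.labeling(gfiles_test, labels=y_test)
#
#     torch.save(clf.trained_model.state_dict(), 'model.pt')
#     clf.load('model.pt')  # restore the trained weights later
# ----------------------------------------------------------------------------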