# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
from ...deeptools.dataset import extract_data
from ..analyse.stats import balanced_accuracy
from soma import aims
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import StratifiedKFold
from sklearn import preprocessing
from sklearn.svm import SVC
from pcl import registration as reg
import numpy as np
import json
import itertools
import pcl
import joblib
from six.moves import zip


class SVMPatternClassification(object):
    '''
    SVM classifier for pattern classification.

    Each training subject's filtered bucket is registered onto the query
    bucket with ICP; the distance between the registered searched pattern and
    the query bucket gives one feature per training subject, and the feature
    vectors are fed to an RBF-kernel SVC.
    '''

def __init__(self, pattern=None, names_filter=None,
C=1, gamma=0.01, trans=[0],
dict_bck=None, dict_bck_filtered=None,
dict_searched_pattern=None, dict_label=None):
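        '''
        Parameters
        ----------
        pattern : str
            name (prefix) of the searched pattern
        names_filter : list of str
            name prefixes used to select the bucket points registered by ICP
        C, gamma : float
            hyperparameters of the RBF-kernel SVC
        trans : list of float
            translation offsets combined into the ICP initialisations
        dict_bck, dict_bck_filtered, dict_searched_pattern, dict_label : dict
            optional caches of previously extracted data, keyed by graph file
        '''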
self.pattern = pattern
self.nfilter = names_filter
self.C = C
self.gamma = gamma
        # one 4x4 translation matrix per (x, y, z) offset combination,
        # used as initialisations of the ICP registration
        self.transrot_init = [
            [[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]
            for x, y, z in itertools.product(trans, repeat=3)]
self.C_range = np.logspace(-4, -1, 4)
self.gamma_range = np.logspace(-1, 3, 5)
        self.trans_range = [[0],
                            [-5, 0, 5],
                            [-10, 0, 10],
                            [-20, 0, 20],
                            [-10, -5, 0, 5, 10],
                            [-20, -10, 0, 10, 20]]
        self.dict_bck = {} if dict_bck is None else dict_bck
        self.dict_bck_filtered = \
            {} if dict_bck_filtered is None else dict_bck_filtered
        self.dict_searched_pattern = \
            {} if dict_searched_pattern is None else dict_searched_pattern
        self.dict_label = {} if dict_label is None else dict_label
self.bck_filtered_list = []
self.label_list = []
self.searched_pattern_list = []

    def learning(self, gfile_list):
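        '''
        Extract buckets and labels from the graph files and train the SVC on
        the ICP-based distance features (see compute_distmatrix).
        '''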
        self.bck_filtered_list, self.label_list = [], []
        self.searched_pattern_list = []
# Extract buckets and labels from the graphs
for gfile in gfile_list:
if gfile not in self.dict_bck:
graph = aims.read(gfile)
                # hemisphere side is the first letter ('L'/'R') of the name
                side = gfile[gfile.rfind('/')+1:gfile.rfind('/')+2]
                data = extract_data(graph, flip=(side == 'R'))
label = 0
fn, fp = [], []
for name in data['names']:
if name.startswith(self.pattern):
label = 1
                    fn.append(sum(1 for n in self.nfilter
                                  if name.startswith(n)))
                    fp.append(1 if name.startswith(self.pattern) else 0)
bck_filtered = np.asarray(data['bck'])[np.asarray(fn) == 1]
spattern = np.asarray(data['bck'])[np.asarray(fp) == 1]
# save data
self.dict_bck[gfile] = data['bck']
self.dict_label[gfile] = label
self.dict_bck_filtered[gfile] = bck_filtered
self.dict_searched_pattern[gfile] = spattern
self.label_list.append(self.dict_label[gfile])
if len(self.dict_searched_pattern[gfile]) != 0:
self.bck_filtered_list.append(self.dict_bck_filtered[gfile])
self.searched_pattern_list.append(
self.dict_searched_pattern[gfile])
# Compute distance matrix
X_train = []
for gfile in gfile_list:
X_train.append(self.compute_distmatrix(self.dict_bck[gfile]))
# Train SVM
self.clf = SVC(C=self.C, gamma=self.gamma, shrinking=True,
class_weight='balanced')
        self.scaler = preprocessing.StandardScaler().fit(X_train)
        X_train = self.scaler.transform(X_train)
self.clf.fit(X_train, self.label_list)

    def labeling(self, gfile_list):
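        '''Predict a label (0 or 1) for every graph file in gfile_list.'''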
y_pred = []
for gfile in gfile_list:
yp = self.subject_labeling(gfile)
y_pred.append(yp)
return y_pred

    def subject_labeling(self, gfile):
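        '''Predict the label of a single graph file.'''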
print('Labeling %s' % gfile)
# Extract bucket
if gfile not in self.dict_bck:
graph = aims.read(gfile)
side = gfile[gfile.rfind('/')+1:gfile.rfind('/')+2]
data = extract_data(graph, flip=True if side == 'R' else False)
sbck = data['bck']
else:
sbck = self.dict_bck[gfile]
sbck = np.array(sbck)
# Compute distance matrix
X_test = [self.compute_distmatrix(sbck)]
# Compute classification
        X_test = self.scaler.transform(X_test)
ypred = self.clf.predict(X_test)
return ypred[0]

    def find_hyperparameters(self, gfile_list, param_outfile):
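        '''
        Grid-search C, gamma and the translation range with a 3-fold
        stratified cross-validation, keep the combination that maximises the
        balanced accuracy, store it on the instance and write it to
        param_outfile as JSON.
        '''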
gfile_list = np.asarray(gfile_list)
best_bacc = 0
best_C, best_gamma, best_trans = self.C, self.gamma, [0]
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
for C, gamma, trans in itertools.product(
self.C_range, self.gamma_range, self.trans_range):
print('------ TEST PARAMETERS ------')
print('C: %r, gamma: %r, trans: %s' %
(C, gamma, str(trans)))
print()
self.C = C
self.gamma = gamma
            self.transrot_init = [
                [[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]
                for x, y, z in itertools.product(trans, repeat=3)]
y_true, y_pred = [], []
            # labels must already be cached in dict_label (filled by a
            # previous learning() call or passed to the constructor)
            y = np.asarray([self.dict_label[gfile] for gfile in gfile_list])
for train, test in skf.split(gfile_list, y):
print('--- LEARNING (%i samples)' % len(train))
self.learning(gfile_list[train])
print()
print('--- LABELING TEST SET (%i samples)' % len(test))
y_pred_test = self.labeling(gfile_list[test])
y_true.extend(y[test])
y_pred.extend(y_pred_test)
print()
bacc = balanced_accuracy(y_true, y_pred, [0, 1])
if bacc > best_bacc:
best_bacc = bacc
best_C = C
best_gamma = gamma
best_trans = trans
print('--- RESULT')
print('%0.2f for C=%r, gamma=%r, tr=%s' %
(bacc, C, gamma, str(trans)))
print()
print()
print('Best parameters set found on development set:')
print('C=%r, gamma=%r, tr=%s' % (best_C, best_gamma, str(best_trans)))
print()
self.C = best_C
self.gamma = best_gamma
        self.transrot_init = [
            [[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]
            for x, y, z in itertools.product(best_trans, repeat=3)]
param = {'C': best_C,
'gamma': best_gamma,
'trans': best_trans,
'names_filter': self.nfilter,
'best_bacc': best_bacc,
'pattern': self.pattern}
with open(param_outfile, 'w') as f:
json.dump(param, f)

    def compute_distmatrix(self, sbck):
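        '''
        Build the feature vector of the bucket sbck: each training subject's
        filtered bucket is registered onto sbck with ICP (trying every
        initialisation in self.transrot_init); the best transform is applied
        to that subject's searched pattern, whose distance to sbck is one
        feature.
        '''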
        X = []
        # the target cloud only depends on sbck, so build it once
        pc1 = pcl.PointCloud(np.asarray(sbck, dtype=np.float32))
        for bck_filtered, searched_pattern in zip(self.bck_filtered_list,
                                                  self.searched_pattern_list):
            # Try different initializations and select the best score
            dist_min = 100.
            transrot_min = np.eye(4)  # identity fallback if no fit improves
            for trans in self.transrot_init:
                pc2 = pcl.PointCloud(np.asarray(apply_trans(
                    trans, bck_filtered), dtype=np.float32))
                converged, transrot, trans_pc2, d = reg.icp(pc2, pc1)
                if d < dist_min:
                    dist_min = d
                    transrot_min = np.dot(transrot, trans)
trans_searched_pattern = apply_trans(
transrot_min, searched_pattern)
X.append(distance_data_to_model(trans_searched_pattern, sbck))
return X

    def save(self, clf_file, scaler_file):
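        '''Save the trained SVC and its scaler to disk with joblib.'''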
joblib.dump(self.clf, clf_file)
joblib.dump(self.scaler, scaler_file)

    def load(self, clf_file, scaler_file):
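        '''Load an SVC and a scaler previously written by save().'''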
self.clf = joblib.load(clf_file)
        self.scaler = joblib.load(scaler_file)


def apply_trans(transrot, data):
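    '''Apply a 4x4 homogeneous transformation to an (N, 3) point array.'''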
data = np.asarray(data)
data_tmp = np.vstack((data.T, np.ones(data.shape[0])))
new_data_tmp = np.dot(transrot, data_tmp)
new_data = new_data_tmp[:3].T
return new_data


def distance_data_to_model(data, model):
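    '''
    Mean squared nearest-neighbour distance from the points of data to the
    model point cloud (values below 1e-6 are clipped to 0).
    '''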
nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(model)
distances, indices = nbrs.kneighbors(data)
    dist = (distances ** 2).sum() / len(distances)
    if dist < 1e-6:
        dist = 0
return dist
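

if __name__ == '__main__':
    # Minimal usage sketch. The graph paths and the pattern / filter names
    # below are hypothetical placeholders; real values depend on the dataset.
    train_graphs = ['/path/to/Lsubject01.arg', '/path/to/Lsubject02.arg']
    test_graphs = ['/path/to/Lsubject03.arg']

    clf = SVMPatternClassification(pattern='S.C.', names_filter=['S.C.'])
    clf.learning(train_graphs)            # extract buckets and fit the SVC
    predictions = clf.labeling(test_graphs)
    print(predictions)                    # one 0/1 label per test graph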