# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
from torch.utils.data import Dataset
from soma import aims
import numpy as np
import torch
import random
import math
import sigraph
from six.moves import range
from six.moves import zip
[docs]class SulciDataset(Dataset):
'''
Sulci dataset class
'''
def __init__(self, gfile_list, dict_sulci,
train=True, translation_file=None,
dict_bck2={}, dict_names={}):
self.gfile_list = gfile_list
self.dict_sulci = dict_sulci
if 'background' not in self.dict_sulci:
self.dict_sulci['background'] = -1
self.train = train
self.rot_angle = math.pi/16
self.translation_file = translation_file
self.dict_bck2 = dict_bck2
self.dict_names = dict_names
def transform(self, bck):
# rotation
if self.rot_angle is not None:
center = (np.max(bck, axis=0) - np.min(bck, axis=0))/2
transrot = random_rotation(center, self.rot_angle)
bck = rotation_bck(bck, transrot)
# translation
tr = np.min(bck, axis=0)
bck -= tr
return bck
def __getitem__(self, index):
gfile = self.gfile_list[index]
if gfile in self.dict_bck2:
bck2 = np.asarray(self.dict_bck2[gfile])
names = np.asarray(self.dict_names[gfile])
else:
graph = aims.read(gfile)
if self.translation_file is not None:
flt = sigraph.FoldLabelsTranslator()
flt.readLabels(self.translation_file)
flt.translate(graph)
# extract bucket/names
data = extract_data(graph)
bck2 = np.asarray(data['bck2'])
names = np.asarray(data['names'])
self.dict_bck2[gfile] = bck2
self.dict_names[gfile] = names
tr = np.min(bck2, axis=0)
bck2 = bck2 - tr
# data augmentation
if self.train:
bck2 = self.transform(bck2)
# input/label
bck2 = np.array(bck2, dtype=int)
bck_T = bck2.transpose()
img_size = np.max(bck2, axis=0) + 1
input = torch.zeros(
1, img_size[0], img_size[1], img_size[2], dtype=torch.float)
input[0][bck_T[0], bck_T[1], bck_T[2]] = 1
labels = torch.zeros(
img_size[0], img_size[1], img_size[2], dtype=torch.long)
labels += self.dict_sulci['background']
labels[bck_T[0], bck_T[1], bck_T[2]] = torch.tensor(
[self.dict_sulci[n] for n in names], dtype=torch.long)
return input, labels
def __len__(self):
return len(self.gfile_list)
[docs]class PatternDataset(Dataset):
'''
Pattern dataset (for pattern classification)
'''
def __init__(self, gfile_list, pattern, bb, train=True,
dict_bck={}, dict_label={}, labels=None):
self.gfile_list = gfile_list
self.labels = labels
self.pattern = pattern
self.bb = np.array(bb)
self.size = self.bb[:, 1] - self.bb[:, 0] + 1
self.tr = self.bb[:, 0]
self.rot_angle = math.pi/40
self.tr_sigma = 2
self.train = train
self.dict_bck = dict_bck
self.dict_label = dict_label
def transform(self, bck):
# rotation
if self.rot_angle is not None:
center = (np.max(bck, axis=0) - np.min(bck, axis=0))/2
transrot = random_rotation(center, self.rot_angle)
bck = rotation_bck(bck, transrot)
# translation
tr = [int(round(np.random.normal(0, self.tr_sigma))),
int(round(np.random.normal(0, self.tr_sigma))),
int(round(np.random.normal(0, self.tr_sigma)))]
bck += tr
return bck
def __getitem__(self, index):
gfile = self.gfile_list[index]
# Bucket extraction
if gfile in self.dict_bck:
bck = self.dict_bck[gfile]
label = self.dict_label[gfile]
else:
side = gfile[gfile.rfind('/')+1:gfile.rfind('/')+2]
flip = True if side == 'R' else False
graph = aims.read(gfile)
data = extract_data(graph, flip=flip)
bck = data['bck2']
if self.labels is not None:
label = self.labels[index]
elif self.pattern is None:
label = np.NaN
else:
label = 0
for n in data['names']:
if n.startswith(self.pattern):
label = 1
break
self.dict_bck[gfile] = bck
self.dict_label[gfile] = label
# Data augmentation
if self.train:
bck = self.transform(bck)
# Pytorch output
bck, _ = apply_bounding_box(bck, self.bb)
bck -= self.tr
values = np.ones(len(bck))
bck_T = bck.transpose()
input = torch.zeros(
1, self.size[0], self.size[1], self.size[2], dtype=torch.float)
input[0][bck_T[0], bck_T[1], bck_T[2]] = torch.tensor(
values, dtype=torch.float)
return input, label
def __len__(self):
return len(self.gfile_list)
[docs]def apply_bounding_box(points, bb):
'''
Crop points with a bounding box
'''
bb = np.asarray(bb)
points = np.asarray(points)
inidx = np.all(np.logical_and(bb[:, 0] <= points, points <= bb[:, 1]),
axis=1)
inbox = points[inidx]
return inbox, np.asarray(range(len(points)))[inidx]
[docs]def unit_vector(data, axis=None, out=None):
"""Return ndarray normalized by length, i.e. Euclidean norm, along axis.
>>> v0 = numpy.random.random(3)
>>> v1 = unit_vector(v0)
>>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
True
>>> v0 = numpy.random.rand(5, 4, 3)
>>> v1 = unit_vector(v0, axis=-1)
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
>>> numpy.allclose(v1, v2)
True
>>> v1 = unit_vector(v0, axis=1)
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
>>> numpy.allclose(v1, v2)
True
>>> v1 = numpy.empty((5, 4, 3))
>>> unit_vector(v0, axis=1, out=v1)
>>> numpy.allclose(v1, v2)
True
>>> list(unit_vector([]))
[]
>>> list(unit_vector([1]))
[1.0]
"""
if out is None:
data = np.array(data, dtype=np.float64, copy=True)
if data.ndim == 1:
data /= math.sqrt(np.dot(data, data))
return data
else:
if out is not data:
out[:] = np.array(data, copy=False)
data = out
length = np.atleast_1d(np.sum(data*data, axis))
np.sqrt(length, length)
if axis is not None:
length = np.expand_dims(length, axis)
data /= length
if out is None:
return data
[docs]def rotation_matrix(angle, direction, point=None):
"""Return matrix to rotate about axis defined by point and direction.
>>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0])
>>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1])
True
>>> angle = (random.random() - 0.5) * (2*math.pi)
>>> direc = numpy.random.random(3) - 0.5
>>> point = numpy.random.random(3) - 0.5
>>> R0 = rotation_matrix(angle, direc, point)
>>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
>>> is_same_transform(R0, R1)
True
>>> R0 = rotation_matrix(angle, direc, point)
>>> R1 = rotation_matrix(-angle, -direc, point)
>>> is_same_transform(R0, R1)
True
>>> I = numpy.identity(4, numpy.float64)
>>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
True
>>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2,
... direc, point)))
True
"""
sina = math.sin(angle)
cosa = math.cos(angle)
direction = unit_vector(direction[:3])
# rotation matrix around unit vector
R = np.diag([cosa, cosa, cosa])
R += np.outer(direction, direction) * (1.0 - cosa)
direction *= sina
R += np.array([[0.0, -direction[2], direction[1]],
[direction[2], 0.0, -direction[0]],
[-direction[1], direction[0], 0.0]])
M = np.identity(4)
M[:3, :3] = R
if point is not None:
# rotation not around origin
point = np.array(point[:3], dtype=np.float64, copy=False)
M[:3, 3] = point - np.dot(R, point)
return M
[docs]def random_rotation(center, rot_angle):
'''
Apply a random rotation (random axis and angle) around a given center
'''
th = random.uniform(0, 2*math.pi)
z = random.uniform(-1, 1)
direction = [np.sqrt(1-z**2)*np.cos(th), np.sqrt(1-z**2)*np.sin(th), z]
transrot = rotation_matrix(
np.random.normal(0, rot_angle),
direction,
np.array(center))
return transrot
[docs]def rotation_bck(bck, transrot):
'''
Apply a rotation to a bucket (set of voxels)
'''
bck_4d = np.vstack([np.transpose(bck), np.ones(len(bck))])
bck_4d_trans = np.dot(transrot, bck_4d)
bck = np.transpose(bck_4d_trans[:-1])
bck = np.array(bck, dtype=int)
return bck