Source code for soma.aimsalgo.sulci.graph_pointcloud

#!/usr/bin/env python
# -*- coding: utf-8

from __future__ import print_function
from __future__ import absolute_import
from soma import aims, aimsalgo
import pandas
import numpy as np
import subprocess
import six
from six.moves import range
from six.moves import zip


def get_vertex_points(v):
    '''
    Gather the voxel coordinates stored in the buckets of a graph node.

    The 'aims_ss', 'aims_bottom' and 'aims_other' buckets are read (in this
    order); a missing or empty bucket contributes no point.

    Returns
    -------
    numpy array of shape (4, N): one column per voxel, rows are the 3
    coordinates of the bucket keys followed by a row of zeros.
    '''
    def _bucket_to_columns(bucket):
        # empty / absent bucket: a 4 x 0 block keeps hstack shapes valid
        if bucket is None or len(bucket[0]) == 0:
            return np.zeros((4, 0), dtype=int)
        xyz = np.array(list(bucket[0].keys()))
        zero_col = np.zeros((xyz.shape[0], 1), dtype=int)
        return np.hstack((xyz, zero_col)).T

    bucket_names = ('aims_ss', 'aims_bottom', 'aims_other')
    blocks = [_bucket_to_columns(v.get(name)) for name in bucket_names]
    return np.hstack(blocks)
def maj_label(labels, pts, avol, get_counts=False):
    '''
    Majority label in region pts of label volume avol.

    Parameters
    ----------
    labels: sequence
        candidate label values to count.
    pts: tuple of index arrays
        region to inspect, used as a numpy (fancy) index into ``avol``.
    avol: numpy array
        label volume.
    get_counts: bool
        if True, also return the per-label voxel counts.

    Returns
    -------
    label, or (label, counts) if get_counts is True. ``counts`` is a list of
    ints aligned with ``labels``. On ties the first label in ``labels`` wins
    (numpy argmax semantics).
    '''
    # Extract the region values once: fancy indexing copies the data, so
    # re-indexing avol for every label was wasted work.
    region = avol[pts]
    counts = [int(np.count_nonzero(region == l)) for l in labels]
    imax = np.argmax(counts)
    if get_counts:
        return (labels[imax], counts)
    return labels[imax]
def build_split_graph(graph, data, roots, skel=None):
    '''
    Split graph vertices according to sub-vertex label classification (from
    Léonie Borne's automatic labelings).

    Parameters
    ----------
    graph: aims.Graph
        cortical folds graph
    data: pandas.DataFrame
        voxel-wise data. Row should include "point_x", "point_y", "point_z"
        for voxels coordinates *in Talairach space* (they are transformed
        back into native space during the process), and "after_cutting" for
        labels.
    roots: aims.Volume_S16
        folds skeleton roots regions, from Morphologist pipeline. Note that
        new region values are written into ``np.asarray(roots)`` during the
        process, so the volume is expected to be modified in place.
    skel: aims.Volume_S16 (optional)
        folds skeleton, from Morphologist pipeline. If provided, the skeleton
        will be modified with added junction points.

    Returns
    -------
    graph: aims.Graph
        the input graph, modified. Note that the input graph is not copied
        but modified in-place, and the same instance is returned.
    summary: dict
        'mislabeled': number of voxels which could not keep their label
        (small parts which cannot make a vertex, or too complex geometry to
        allow a nice cut).
        'cuts': number of cuts actually performed.
        'failed_cuts': number of cut attempts which have failed.
        'changed_vertices': number of vertices which needed rework (even if
        cuts failed and nothing was actually done)
    '''
    JUNCTION = 80  # skeleton value written on split (junction) points
    min_size = 20  # minimum size passed to FoldArgOverSegment.splitVertex
    # 6-connectivity neighborhood offsets (x, y, z, t)
    convmask = ([-1, 0, 0, 0], [1, 0, 0, 0], [0, -1, 0, 0], [0, 1, 0, 0],
                [0, 0, -1, 0], [0, 0, 1, 0])

    vs = graph['voxel_size']
    tal_tr = aims.GraphManip.talairach(graph)
    tal_inv = tal_tr.inverse()
    askel = None
    if skel is not None:
        askel = np.asarray(skel)
    # labels volume built from roots, then cleared
    lvol = aims.Volume(roots)
    lvol.fill(0)

    # get native coords from data
    coords = data.loc[:, ['point_x', 'point_y', 'point_z']].values
    nat_coords = tal_inv.toMatrix().dot(
        np.vstack((coords.T, np.ones(coords.shape[0]))))
    # mm -> voxel indices; the homogeneous row is reused as time index 0
    int_coords = np.round(
        nat_coords / np.expand_dims(vs[:3] + [1.], 1)).astype(int)
    int_coords[3, :] = 0

    # build labels volume: label names are coded as ints 1..n, 0 = unknown
    sorted_labels = sorted(np.unique(data.after_cutting))
    labels_map = {l: i + 1 for i, l in enumerate(sorted_labels)}
    labels_rmap = {i + 1: l for i, l in enumerate(sorted_labels)}
    labels_rmap[0] = 'unknown'  # in case it is not already here
    labels_int = np.array([labels_map[l] for l in data.after_cutting])
    avol = np.asarray(lvol)
    avol[tuple(int_coords)] = labels_int

    # build list of vertices to split; single-label vertices are just labeled
    todo = []
    for v in graph.vertices():
        pts = tuple(get_vertex_points(v))
        labels = np.unique(avol[pts])
        if len(labels) in (0, 1):
            # unique label in vertex
            if len(labels) == 1:
                v['label'] = six.ensure_text(
                    labels_rmap.get(labels[0], 'unknown'))
            continue
        todo.append(v)

    aroots = np.asarray(roots)
    roots_new = np.max(aroots) + 1  # first roots value not used yet
    mislabeled = 0
    cuts = 0
    failed_cuts = 0
    changed_vertices = len(todo)
    # split groups (vertices split from the same initial one)
    split_groups = {}

    # split each selected vertex
    while todo:
        v = todo.pop(0)
        apts = get_vertex_points(v)
        pts = tuple(apts)
        labels = np.unique(avol[pts])
        winner, sizes = maj_label(labels, pts, avol, True)
        print('split vertex', v.get('index'), 'in:', labels,
              [labels_rmap.get(l, 'unknown') for l in labels],
              ', size:', apts.shape[1], tuple(sizes))
        # grow a voronoi inside the roots region
        roots_val = aroots[tuple(apts[:, 0])]
        roots_nval = list(range(roots_new, roots_new + len(labels)))
        # print skel new values into roots
        for i, l in enumerate(labels):
            lp = apts[:, np.where(avol[pts] == l)[0]]
            aroots[tuple(lp)] = roots_nval[i]
        fm = aims.FastMarching('6')  # TODO: check if 6 connectivity is OK
        fm.doit(roots, [roots_val], roots_nval)
        voronoi = fm.voronoiVol()
        avor = np.asarray(voronoi)
        # reset aroots for now, we'll do it later with actual split
        for l in reversed(roots_nval):
            aroots[avor == l] = roots_val
        # cut points: vertex voxels whose 6-neighborhood sees more than one
        # voronoi region
        split_bk = aims.BucketMap_VOID()
        split_bk.header()['voxel_size'] = vs
        sb0 = split_bk[0]
        for p in apts.T:
            p2 = np.vstack([p + m for m in convmask]).T
            vals = np.unique(avor[tuple(p2)])
            if -1 in vals:
                # remove background value
                vals = [val for val in vals if val >= 0]
            if len(vals) > 1:
                if skel is not None:
                    askel[tuple(p)] = JUNCTION
                sb0[p[:3]] = 1
        v2 = aims.FoldArgOverSegment(graph).splitVertex(v, split_bk, min_size)
        if v2 is None:
            print('split failed.')
            v['label'] = six.ensure_text(labels_rmap[winner])
            failed_cuts += 1
            loc_mis = sum(sizes) - sizes[np.where(labels == winner)[0][0]]
            print('mislabeled:', loc_mis)
            mislabeled += loc_mis
        else:
            cuts += 1
            # update split groups (vertices split from the same initial one)
            split_group = split_groups.setdefault(v, set())
            split_groups[v2] = split_group
            split_group.add(v)
            split_group.add(v2)
            apts = get_vertex_points(v)
            apts2 = get_vertex_points(v2)
            pts = tuple(apts)
            labels = np.unique(avol[pts])
            print('new vertices.', len(pts[0]), apts2.shape[1])
            winner, sizes = maj_label(labels, pts, avol, True)
            if len(labels) != 1:
                print('split vertex doesn\'t have one label:', labels, sizes)
                loc_mis = sum(sizes) - sizes[np.where(labels == winner)[0][0]]
                if loc_mis > 50:
                    todo.append(v)  # do it again
                else:
                    mislabeled += loc_mis
                    print(' good enough.')
            v['label'] = six.ensure_text(labels_rmap[winner])
            print('v1 label:', v['label'])
            # new voronoi in aroots
            aroots[pts] = roots_new + 1
            apts = apts2
            pts = tuple(apts)
            labels = np.unique(avol[pts])
            winner, sizes = maj_label(labels, pts, avol, True)
            if len(labels) != 1:
                print('split vertex2 doesn\'t have one label:', labels, sizes)
                loc_mis = sum(sizes) - sizes[np.where(labels == winner)[0][0]]
                if loc_mis > 50:
                    todo.append(v2)  # do it again
                else:
                    mislabeled += loc_mis
                    print(' good enough.')
            # winner was computed just above from the same labels/pts/avol
            v2['label'] = six.ensure_text(labels_rmap[winner])
            print('v2 label:', v2['label'])
            aroots[pts] = roots_new
            # redo voronoi with new labels
            fm = aims.FastMarching('6')  # TODO: check if 6 connectivity is OK
            fm.doit(roots, [roots_val], [roots_new, roots_new + 1])
            voronoi = fm.voronoiVol()
            avor = np.asarray(voronoi)
            # change values, v gets initial roots_val, v2 gets roots_new
            aroots[avor == roots_new + 1] = roots_val
            roots_new += 1

    # fusion pass: in each split group, merge vertices which share the same
    # label and are adjacent
    _merge_split_groups(graph, split_groups)

    print('mislabeled points:', mislabeled)
    print()
    summary = {
        'mislabeled': mislabeled,
        'cuts': cuts,
        'failed_cuts': failed_cuts,
        'changed_vertices': changed_vertices,
    }
    return graph, summary


def _merge_split_groups(graph, split_groups):
    '''
    Fusion pass of build_split_graph: in each split group, merge vertices
    which share the same label and are linked by a junction.

    ``split_groups`` maps every vertex produced by splits to the set of all
    vertices split from the same initial vertex; that set object is shared by
    all its members.
    '''
    # The same group set is referenced once per member vertex: collect each
    # distinct group once, otherwise groups are revisited after merges.
    groups = {}
    for group in split_groups.values():
        groups[id(group)] = group

    merged_vertices = set()
    for split_group in groups.values():
        # vertices that have been merged should not be touched anymore: they
        # are dangling pointers on the C++ level, and accessing them
        # (v.edges(), but also v['label']) can trigger a segmentation fault,
        # see https://github.com/brainvisa/aims-free/issues/96
        live = [v for v in split_group if v not in merged_vertices]
        labels = {}
        for v in live:
            labels.setdefault(v['label'], []).append(v)
        if len(labels) == len(live):
            # all vertices have different labels: skip this step
            continue
        for vertices in labels.values():
            vertices = set(vertices)
            while len(vertices) >= 2:
                v = next(iter(vertices))
                # check junctions linking v to same-label vertices only
                junctions = [j for j in v.edges()
                             if j.getSyntax() == 'junction'
                             and all(w in vertices for w in j.vertices())]
                if len(junctions) == 0:
                    vertices.remove(v)
                else:
                    # merge v and 1st connected other vertex
                    v2 = [v3 for v3 in junctions[0].vertices()
                          if v3 is not v][0]
                    # v2 will disappear
                    vertices.remove(v2)
                    merged_vertices.add(v2)
                    aims.FoldArgOverSegment(graph).mergeVertices(v, v2)
                    del v2
                    # do v again next time since it may have other edges
if __name__ == '__main__':
    subject = 'anubis'
    center = 'panabase'
    side = 'L'
    # current substitution values for the path templates below; updated
    # with each processed item
    attribs = dict(subject=subject, center=center, side=side)

    in_file_tpl = '/neurospin/lnao/Panabase/lborne/results/sulci_recognition/database_learnclean/learning_database_learnclean/dropnet_new/left/net/result_%(side)s%(subject)s.csv'
    graph_file_tpl = '/neurospin/lnao/PClean/database_learnclean/%(center)s/%(subject)s/t1mri/t1/default_analysis/folds/3.3/base2018_manual/%(side)s%(subject)s_base2018_manual.arg'
    skel_file_tpl = '/neurospin/lnao/PClean/database_learnclean/%(center)s/%(subject)s/t1mri/t1/default_analysis/segmentation/%(side)sskeleton_%(subject)s.nii.gz'
    roots_file_tpl = '/neurospin/lnao/PClean/database_learnclean/%(center)s/%(subject)s/t1mri/t1/default_analysis/segmentation/%(side)sroots_%(subject)s.nii.gz'
    output_graph_tpl = '/tmp/%(side)s%(subject)s.arg'

    panabase_subjects = [
        'ammon', 'anubis', 'athena', 'atlas', 'beflo', 'caca', 'cronos',
        'demeter', 'eros', 'hades', 'horus', 'hyperion', 'icbm100T',
        'icbm125T', 'icbm200T', 'icbm201T', 'icbm300T', 'icbm310T',
        'icbm320T', 'isis', 'jah2', 'jason', 'jupiter', 'moon', 'neptune',
        'osiris', 'poseidon', 'ra', 'shiva', 'vayu', 'vishnu', 'zeus']
    subjects_map = {'panabase': panabase_subjects}
    sides = ('L', )  # right hemisphere ('R') currently disabled

    # one work item per (center, subject, side) combination
    elements = []
    for center_name, center_subjects in six.iteritems(subjects_map):
        for subject_name in center_subjects:
            for side_name in sides:
                elements.append({'center': center_name,
                                 'subject': subject_name,
                                 'side': side_name})
    # FIXME DEBUG: uncomment to process only the first items
    # elements = elements[:10]
    print('elements:')
    print(elements)

    summary = []
    for item in elements:
        attribs.update(item)
        print(item)
        in_file = in_file_tpl % attribs
        graph_file = graph_file_tpl % attribs
        skel_file = skel_file_tpl % attribs
        roots_file = roots_file_tpl % attribs
        output_graph = output_graph_tpl % attribs

        data = pandas.read_csv(in_file)
        graph = aims.read(graph_file)
        # skeleton update is disabled; re-enable with:
        # skel = aims.read(skel_file)
        skel = None
        roots = aims.read(roots_file)

        graph, loc_summary = build_split_graph(graph, data, roots, skel=skel)
        summary.append(loc_summary)
        graph['label_property'] = 'label'  # auto labeling
        aims.write(graph, output_graph)

    # aggregate the per-subject counters
    totals = dict(
        (key, sum(report[key] for report in summary))
        for key in ('mislabeled', 'cuts', 'failed_cuts', 'changed_vertices'))

    print('\nReport by subject:')
    for item, report in zip(elements, summary):
        print(item)
        print(report)
    print('\nSummary:')
    for key in ('mislabeled', 'changed_vertices', 'cuts', 'failed_cuts'):
        print(key + ':', totals[key])