from __future__ import print_function
from __future__ import absolute_import
from soma import aims, aimsalgo
import numpy as np
import six
from six.moves import zip
inspect = tuple()  # debug vertex indices, e.g. (70989, 70988); empty = none
def _write_timestep(out_tex, p, ntex, front):
    '''
    Fill timestep ``p`` of ``out_tex`` from the current erosion state:
    0 for background, 2 for remaining object vertices (``ntex != 0``) and 1
    for the vertices listed in ``front``.
    '''
    out_tex[p].assign(np.zeros(ntex.shape, dtype=np.int16))
    nstep = np.asarray(out_tex[p])
    nstep[ntex != 0] = 2
    nstep[front] = 1


def mesh_skeleton(mesh, texture, curv_func=None, dist_tex=None,
                  do_timesteps=False, min_cc_size=20, min_branch_size=20,
                  debug_inspect=()):
    '''
    Process a skeleton of an object given as a binary texture.

    The current algorithm is rather simple: it erodes vertices iteratively in
    a given order until a vertex is "blocked" based on a curvature-like
    criterion function. The mesh vertices position and curvature are not
    used directly in the algorithm. Progress information is printed on the
    standard output at each iteration.

    Parameters
    ----------
    mesh: aims.AimsSurfaceTriangle
        triangular mesh to build the skeleton on
    texture: aims.TimeTexture (int values)
        input object definition: binary object in a texture, all non-zero
        values are considered in the object (only timestep 0 is read)
    curv_func: function
        "curvature" function, called as
        ``curv_func(nvert, ntex, neigh, v, texture)`` for each front vertex
        ``v``, deciding if the point can be removed from the object
        (eroded), normally based on curvature or "sharpness". A value <= 0
        means that the point cannot be removed. The default function
        (:func:`sharp_curve_func`) allows to remove a point unless it has
        fewer than 2 neighbours in the object (a "sharp edge").
        NOTE(review): the last argument passed is ``texture`` (the binary
        input), not ``dist_tex``, although the default function names that
        parameter ``dist_tex`` -- confirm which one is intended.
    dist_tex: aims.TimeTexture (float values), optional
        distance-like map inside the object, deciding the priority of eroded
        points. Points with the lowest values will be processed first.
        Typically if we want to build the skeleton of a thresholded curvature
        or a depth potential function (DPF), texture will be this binarized
        texture, and dist_tex will be the curvature or DPF texture itself.
        If None, front points are processed in the order they are found.
    do_timesteps: bool
        if True, the output texture will have one timestep per front
        propagation iteration
    min_cc_size: int
        small connected components can be removed afterwards. Such trimming
        only happens if do_timesteps is False.
    min_branch_size: int
        small branches can be pruned afterwards. Such trimming only happens if
        do_timesteps is False.
    debug_inspect: sequence (preferably set) of ints
        list of vertices for which debug information will be printed on the
        standard output. Useful to understand what happens there.

    Returns
    -------
    skel_tex: aims.TimeTexture_S16
        output skeleton texture. Value 0 is the background, 1 is the skeleton.
        If do_timesteps is True, then one timestep per propagation step will be
        found in the texture, and value 2 will be used for the object interior
        (not belonging to the propagation front in the current step)
    '''
    if curv_func is None:
        curv_func = sharp_curve_func
    nvert = np.asarray(mesh.vertex())
    # working state: 0 = background / eroded, 1 = interior, 2 = front
    ntex = np.array(texture[0], copy=True)
    ntex[ntex != 0] = 1  # binarize ntex
    neigh = aims.SurfaceManip.surfaceNeighbours(mesh)
    out_tex = aims.TimeTexture(np.int16)
    out_tex[0].assign([0] * len(ntex))
    fg = np.where(ntex != 0)[0]
    # initial front: object vertices having at least one background neighbour
    front = [i for i in fg if np.any(ntex[list(neigh[i])] == 0)]
    frozen = False
    p = 0
    while not frozen:
        frozen = True
        print('iter %d: front: %d' % (p, len(front)))
        # order front points by potential priority (when a potential is given)
        if dist_tex is not None:
            front = sort_potential(front, dist_tex)
        new_front = []
        if do_timesteps:
            _write_timestep(out_tex, p, ntex, front)
        ntex[front] = 2  # front points are part of the object
        # criterion is evaluated on the state before this pass erodes anything
        cfront = [curv_func(nvert, ntex, neigh, v, texture) for v in front]
        # sets mirror front / new_front for O(1) membership tests
        front_set = set(front)
        new_front_set = set()
        for v, c in zip(front, cfront):
            v2 = can_move(mesh, ntex, v, neigh, c, debug_inspect)
            if v in debug_inspect:
                print('v:', v, ', c:', c, ', v2:', v2)
            if v2 is None:
                new_front.append(v)  # v still in front line
                new_front_set.add(v)
            else:
                # remove point v; its object neighbours join the next front
                ntex[v] = 0
                v2 = [x for x in v2
                      if x not in front_set and x not in new_front_set]
                new_front += v2
                new_front_set.update(v2)
                ntex[v2] = 2
                frozen = False
        ntex[front] = 0
        front = new_front
        if do_timesteps:
            p += 1
    print('skeleton:', len(front))
    _write_timestep(out_tex, p, ntex, front)
    if not do_timesteps and (min_cc_size != 0 or min_branch_size != 0):
        out_tex = trim_skeleton(mesh, out_tex, min_cc_size=min_cc_size,
                                min_branch_size=min_branch_size)
    return out_tex
def sharp_curve_func(nvert, ntex, neigh, v, dist_tex):
    '''
    Default "curvature" function used as "curv_func" in :func:`mesh_skeleton`.

    Returns 0 (meaning the vertex must not be removed / eroded) when vertex
    ``v`` has fewer than 2 neighbours in the object (``ntex != 0``), and 1
    otherwise. ``nvert`` and ``dist_tex`` are accepted for interface
    compatibility but not used. A vertex with no neighbours at all gets 0.
    '''
    # explicit int dtype: an empty neighbour list must still be a valid index
    n_v = np.array(list(neigh[v]), dtype=int)
    if np.count_nonzero(ntex[n_v]) < 2:
        return 0
    return 1
def _same_cc(active, neigh):
s = set([active[0]])
todo = list(active[1:])
while todo:
#print(' samecc todo:', todo)
todo_next = []
n = len(todo)
while todo:
p = todo.pop(0)
#print('test:', p)
for nb in neigh[p]:
if nb in s:
s.add(p)
#print('add', p)
break
else:
#print('keep', p)
todo_next.append(p) # back in todo list
if len(todo_next) == n:
# no change this pass
#print('no change')
return False
todo = todo_next
return True
def can_move(mesh, ntex, v, neigh, c, debug_inspect):
    '''
    Decide whether front vertex ``v`` can be eroded.

    Parameters
    ----------
    mesh: unused (kept for interface compatibility)
    ntex: numpy array, current object state (non-zero = in the object)
    v: int, vertex index
    neigh: per-vertex neighbour sets
    c: value returned by the "curvature" function for ``v``; ``c <= 0``
        forbids the removal
    debug_inspect: vertices for which debug information is printed

    Returns
    -------
    None if ``v`` must stay, otherwise the array of its object neighbours
    (which then become front points). Removal is allowed when ``v`` has
    exactly one object neighbour, or several that are connected to each other
    (see :func:`_same_cc`), so removing ``v`` does not split the object
    locally.
    '''
    if c <= 0:
        return None
    # explicit int dtype: an empty neighbour list must still be a valid index
    n_v = np.array(list(neigh[v]), dtype=int)
    active = n_v[ntex[n_v] != 0]
    n = len(active)
    if v in debug_inspect:
        print(' can_move:', v, ':', n, active, n_v)
    if n == 0:
        return None
    if n == 1 or _same_cc(active, neigh):
        return active
    return None
def sort_potential(front, texture):
    '''
    Sort front points list according to texture value (ascending, stable).

    ``texture`` is indexed as ``texture[0][v]``. If ``texture`` is None (the
    default ``dist_tex`` of :func:`mesh_skeleton`), the points are returned
    in their original order, as a new list.
    '''
    if texture is None:
        return list(front)
    values = texture[0]  # fetch the timestep once, not once per point
    return sorted(front, key=lambda v: values[v])
def trim_skeleton(mesh, skeleton, min_cc_size=20, min_branch_size=20):
    '''
    Trim a skeleton texture by removing small connected components and small
    branches.

    Connected components (computed by
    ``aimsalgo.AimsMeshLabelConnectedComponent``) with fewer than
    ``min_cc_size`` vertices are erased from a copy of ``skeleton``, then
    :func:`prune_branches` is applied with ``min_branch_size``. Returns the
    resulting texture; the input ``skeleton`` is used only as a source.
    '''
    trimmed = aims.TimeTexture(skeleton)
    cc = aimsalgo.AimsMeshLabelConnectedComponent(mesh, skeleton, 0, 0)
    ncc = np.asarray(cc[0])
    nskel = np.asarray(trimmed[0])
    # one pass to get every label with its vertex count
    labels, counts = np.unique(ncc, return_counts=True)
    # the smallest label is skipped: it is assumed to be the background label
    # (NOTE(review): confirm this holds even when no background is present)
    for i, (label, size) in enumerate(zip(labels[1:], counts[1:])):
        if size < min_cc_size:
            print('remove:', label, ':', size, i)
            nskel[ncc == label] = 0
    return prune_branches(mesh, trimmed, min_branch_size=min_branch_size)
def topo_mark(mesh, texture, neigh=None):
    '''
    Mark skeleton vertices according to their topological type.

    Each non-zero vertex of ``texture[0]`` is classified by its number of
    non-zero neighbours:

    * 2: end point (at most one skeleton neighbour)
    * 3: bifurcation (more than two skeleton neighbours)
    * line points (exactly two skeleton neighbours) keep their input value,
      i.e. 1 for a binary skeleton

    Background (0) vertices are unchanged. The result is written in a new
    texture built with ``aims.TimeTexture(texture)``; classification reads
    only the input, so marks do not influence each other. ``neigh`` is
    computed from ``mesh`` when not given.
    '''
    if neigh is None:
        neigh = aims.SurfaceManip.surfaceNeighbours(mesh)
    ntex = np.asarray(texture[0])
    otex = aims.TimeTexture(texture)
    notex = np.asarray(otex[0])
    for v in np.where(ntex != 0)[0]:
        # explicit int dtype: an empty neighbour list must still be an index
        n_v = np.array(list(neigh[v]), dtype=int)
        n = np.count_nonzero(ntex[n_v])
        if n <= 1:
            notex[v] = 2  # end point
        elif n > 2:
            notex[v] = 3  # bifurcation
    return otex
def prune_branches(mesh, texture, min_branch_size=20, neigh=None):
    '''
    Prune the smallest branches in a skeleton texture.

    Vertices are first classified with :func:`topo_mark` (1: line point,
    2: end point, 3: bifurcation). A branch is an end point plus the connected
    run of line points it touches. Branches are grouped by the bifurcation
    component they reach; for every bifurcation with at least 3 branches, the
    branches smaller than ``min_branch_size`` are erased, always keeping the
    two biggest ones. Returns a new texture; ``texture`` is not modified.
    Isolated skeleton vertices (no skeleton neighbour) are left as they are.
    '''
    if neigh is None:
        neigh = aims.SurfaceManip.surfaceNeighbours(mesh)
    topo = topo_mark(mesh, texture, neigh)
    ntopo = np.asarray(topo[0])
    # connected components of line points: the branch bodies
    lines = aims.TimeTexture(topo)
    nlines = np.asarray(lines[0])
    nlines[nlines != 1] = 0
    cc = aimsalgo.AimsMeshLabelConnectedComponent(mesh, lines, 0, 0)
    ncc = np.asarray(cc[0])
    # connected components of bifurcation points
    bifs = aims.TimeTexture(topo)
    nbifs = np.asarray(bifs[0])
    nbifs[nbifs != 3] = 0
    cc_bif = aimsalgo.AimsMeshLabelConnectedComponent(mesh, bifs, 0, 0)
    ncc_bif = np.asarray(cc_bif[0])
    ntex = np.asarray(texture[0])
    otex = aims.TimeTexture(texture)
    notex = np.asarray(otex[0])
    end_points = np.where(ntopo == 2)[0]
    branches_per_cc = {}
    for v in end_points:
        # explicit int dtype: an empty neighbour list must still be an index
        n_v = np.array(list(neigh[v]), dtype=int)
        obj_neighbours = n_v[ntex[n_v] != 0]
        if len(obj_neighbours) == 0:
            # isolated skeleton vertex: not attached to any branch
            continue
        obj = obj_neighbours[0]
        ccobj = ncc[obj]  # (may be -1 for bifurcation)
        if ccobj == -1:
            branches_per_cc.setdefault(ncc_bif[obj], []).append((v, ccobj))
        else:
            # find bifurcation attached to this cc
            for v2 in np.where(ncc == ccobj)[0]:
                n_v2 = np.array(list(neigh[v2]), dtype=int)
                bif_neighbours = n_v2[ntopo[n_v2] == 3]
                if len(bif_neighbours) != 0:
                    # can be 0 for a line without bifurcation
                    branches_per_cc.setdefault(
                        ncc_bif[bif_neighbours[0]], []).append((v, ccobj))
                    break
    for branches in six.itervalues(branches_per_cc):
        if len(branches) < 3:
            continue
        # branch size = number of line points; -1 means no cc, size 0
        br_sz = [0 if b == -1 else len(np.where(ncc == b)[0])
                 for _, b in branches]
        ranks = np.argsort(br_sz)  # smallest to biggest
        for rank in ranks[:-2]:  # leave at least the last 2 (biggest)
            if br_sz[rank] < min_branch_size:
                end_vertex, branch_cc = branches[rank]
                notex[end_vertex] = 0
                if branch_cc != -1:
                    notex[ncc == branch_cc] = 0
    return otex
if __name__ == '__main__':
    # Development driver: builds the skeleton of the non-negative DPF region
    # of a mesh. Paths can be given on the command line:
    #     python mesh_skeleton.py [mesh.gii [dpf_texture.gii [output.gii]]]
    # The defaults are the original developer test data.
    import sys

    default_paths = [
        '/volatile/riviere/basetests-3.1.0/subjects/ratio_t1_dp/t1mri/default_acquisition/default_analysis/segmentation/mesh/ratio_t1_dp_Lwhite.gii',
        '/volatile/riviere/basetests-3.1.0/subjects/ratio_t1_dp/t1mri/default_acquisition/default_analysis/segmentation/mesh/surface_analysis/ratio_t1_dp_Lwhite_DPF.gii',
        '/tmp/stex.gii',
    ]
    args = sys.argv[1:4]
    mesh_path, dpf_path, out_path = args + default_paths[len(args):]

    mesh = aims.read(mesh_path)
    neigh = aims.SurfaceManip.surfaceNeighbours(mesh)
    tex = aims.read(dpf_path)
    # binary object: vertices with a non-negative DPF value
    texture = aims.TimeTexture((np.asarray(tex[0]) >= 0).astype(np.int32))
    ntex = np.asarray(texture[0])
    # use the implementation defined in this file (importing the installed
    # soma.aimsalgo.mesh_skeleton module would shadow this function's name)
    stex = mesh_skeleton(mesh, texture, dist_tex=tex)
    aims.write(stex, out_path)
    # kept for interactive inspection (python -i): background vertices that
    # touch the object, as vertex indices
    bg = np.where(ntex == 0)[0]
    front = bg[[np.any(ntex[list(neigh[i])]) for i in bg]]