Source code for soma.aimsalgo.texture_cleaning

#!/usr/bin/env python

from soma import aims, aimsalgo
from soma.aims.labelstools import read_labels
import numpy as np
import json
import sys
import queue
from soma import mpfork


# constants (or default values)*

default_ero_dist = {'unknown': 7.,
                    'GapMap': 1.5,
                    'other': 0.5}
default_dil = 1.5
default_min_cc_size = 100.  # mm2


def _clean_one_tex(tex, mesh, otex, lvalue, label, ero_dist, dilation):
    print('region:', label, lvalue)
    if label.endswith('(GapMap)'):
        ero_label = 'GapMap'
    else:
        ero_label = 'other'
    if label in ero_dist:
        ero_label = label
    ero = ero_dist[ero_label]
    ntex = aims.TimeTexture_S16()
    ntex[0].resize(len(tex[0]))
    ntex[0].np[tex[0].np != lvalue] = 0
    ntex[0].np[tex[0].np == lvalue] = 12
    # print('dirty:', len(np.where(ntex[0].np != 0)[0]), ero)
    # 1. erode to eliminate small crap
    dtex = aims.meshdistance.MeshErosion(mesh, ntex, 0, 1, ero, True)
    # 2. dilate to grow back, and over to connect disconnected parts
    dtex = aims.meshdistance.MeshDilation(mesh, dtex, 0, 1, ero + dilation,
                                          True)
    # 3. re-erode back to original size
    dtex = aims.meshdistance.MeshErosion(mesh, dtex, 0, 1, dilation, True)
    # print('clean1:', len(np.where(dtex[0].np != 0)[0]))
    #aims.write(dtex, '/tmp/dil_%s_%d.gii' % (label, lvalue))
    #raise RuntimeError('DEBUG STOP')

    #with lock:
    #otex[0].np[otex[0].np == lvalue] = -1
    #otex[0].np[dtex[0].np != 0] = lvalue
    #return otex
    otex = np.zeros(tex[0].np.shape, dtype=np.int16) - 32000
    otex[tex[0].np == lvalue] = -1
    otex[dtex[0].np != 0] = lvalue
    #print(label, 'done.')
    return otex


def _filter_cc(mesh, otex, lvalue, min_cc_size):
    ftex = aimsalgo.AimsMeshFilterConnectedComponent(
        mesh, otex, lvalue, -1, 0, 0, min_cc_size)
    # check that all components have not been erased
    if not np.any(ftex[0].np == lvalue):
        # try with smaller limit
        ftex = aimsalgo.AimsMeshFilterConnectedComponent(
            mesh, otex, lvalue, -1, 0, 0, min_cc_size / 5)
        if not np.any(ftex[0].np == lvalue):
            # last trial: keep only the biggest component
            ftex = aimsalgo.AimsMeshFilterConnectedComponent(
                mesh, otex, lvalue, -1, 1, 0, 0)
    # has anything changed ?
    ootex = None
    changed = False
    if len(np.where(ftex[0].np == lvalue)[0]) \
            != len(np.where(otex[0].np == lvalue)[0]):
        # replace dtex with filtered texture
        ootex = ftex[0].np
        changed = True

    return changed, ootex, lvalue


[docs]def clean_texture(mesh, tex, labels, ero_dist=default_ero_dist, dilation=default_dil, min_cc_size=default_min_cc_size, max_threads=0): ''' Clean labels texture: - for each label in the nomenclature: - erode a certain amount depending on the label to eliminate small crap - dilate the same amount + ``dilation`` mm, to grow back and to connect disconnected parts - re-erode back to original size - Voronoi for all regions to fill gaps - filter out small disconnected parts (< ``min_cc_size`` mm2) - set a labels and colors table in the texture Parameters ---------- mesh: Aims mesh tex: Aims texture labels: dict labels map, normally obtained using :func:`read_labels` ero_dist: dict dilation: float min_cc_size: float max_threads: int 0: all CPU cores 1: mono-core 2+: that number of worker threads -n: all but n cores Returns ------- otex: Aims texture cleaned output texture ''' if isinstance(mesh, str): mesh = aims.read(mesh) if isinstance(tex, str): tex = aims.read(tex) if isinstance(labels, str): labels = read_labels(labels) otex = aims.TimeTexture_S16() #otex[0].assign([0] * len(tex[0])) otex[0].assign(np.round(tex[0].np).astype(np.int16)) sides = ['l', 'r'] side = 0 # aims.write(otex, '/tmp/otex_%s_init.gii' % sides[side]) used = np.unique(tex[0].np) # make unkonwn the 1st region so that others grow over it, not the # contrary lvalues = [l for l, lv in labels.items() if lv['Label'] == 'unknown'] \ + [l for l, lv in labels.items() if lv['Label'] != 'unknown'] # handle parallel processing q = queue.Queue() workers = mpfork.allocate_workers(q, max_threads) print('n workers:', len(workers)) res = [None] * len(lvalues) i = 0 for lvalue in lvalues: label_def = labels[lvalue] label = label_def['Label'] # color = label_def.get('RGB') #if color is None: #color = [float(i) / (len(labels) - 1) for x in range(3)] + [1.] #print('modify color for', label, ':', color) #label_def['RGB'] = color if lvalue not in used: continue job = (i, _clean_one_tex, (tex, mesh, otex, lvalue, label, ero_dist, dilation), {}, res) q.put(job) i += 1 for i in range(len(workers)): q.put(None) # wait for every job to complete q.join() # terminate all threads for w in workers: w.join() for onp in res: if onp is None: continue otex[0].np[onp!=-32000] = onp[onp!=-32000] # 4. Voronoi for all regions otex = aims.meshdistance.MeshVoronoi(mesh, otex, -1, -2, 10000, True, True) # 5. filter out small disconnected parts if min_cc_size > 0: workers = mpfork.allocate_workers(q, max_threads) res = [None] * len(labels) i = 0 for lvalue, label_def in labels.items(): label = label_def['Label'] if lvalue not in used: continue job = (i, _filter_cc, (mesh, otex, lvalue, min_cc_size), {}, res) q.put(job) i += 1 njobs = i for i in range(len(workers)): q.put(None) # wait for every job to complete q.join() # terminate all threads for w in workers: w.join() changed = False for resval in res[:njobs]: c, onp, lvalue = resval if not c: continue otex[0].np[otex[0].np == lvalue] = -1 otex[0].np[onp==lvalue] = lvalue changed = True if changed: # perform another voronoi to fill the new holes otex = aims.meshdistance.MeshVoronoi(mesh, otex, -1, -2, 10000, True, True) # write result otex.header()['GIFTI_labels_table'] = labels otex.header()['texture_properties'] = [{'interpolation': 'rgb'}] #otex.header()['palette'] = {'palette': 'parcellation720'} # check size changes for lvalue, label_def in labels.items(): label = label_def['Label'] s1 = len(np.where(tex[0].np == lvalue)[0]) if s1 != 0: s2 = len(np.where(otex[0].np == lvalue)[0]) if s2 < 0.5 * s1: col = '0;31' # red elif s2 < 0.75 * s1: col = '0;33' # orange elif s2 > 2 * s1: col = '0;35' # purple elif s2 > 1.5 * s1: col = '1;35' # light purple else: col = '0;32' # green print(lvalue, ':', label, ', size:\033[%sm' % col, s1, '->', s2, '\033[0m') return otex
def main(argv=sys.argv): import argparse import textwrap ## v 2.4 #tex_names = [ #'lh.JulichBrain_MPMAtlas_l_N10_nlin2Stdicbm152asym2009c_publicDOI_83fb39b2811305777db0eb80a0fc8b53.BV_MNI152_orig_to_hcp32k.gii', #'rh.JulichBrain_MPMAtlas_r_N10_nlin2Stdicbm152asym2009c_publicDOI_172e93a5bec140c111ac862268f0d046.BV_MNI152_orig_to_hcp32k.gii' #] #mesh_names = ['lh.r.white.gii', 'rh.r.white.gii'] #otex_names = ['/tmp/otex_l.gii', '/tmp/otex_r.gii'] #nomenclature = 'julichbrain_nomenclature.txt' class MultilineFormatter(argparse.HelpFormatter): def _fill_text(self, text, width, indent): text = text.replace('\n\n* ', '|n |n * ') text = text.replace('\n* ', '|n * ') text = text.replace('\n\n', '|n |n ') text = self._whitespace_matcher.sub(' ', text).strip() paragraphs = text.split('|n ') multiline_text = '' for paragraph in paragraphs: formatted_paragraph = textwrap.fill(paragraph, width, initial_indent=indent, subsequent_indent=indent) + '\n\n' multiline_text = multiline_text + formatted_paragraph return multiline_text labels_doc_lines = read_labels.__doc__.split('\n') labels_doc_lines2 = [] for l in labels_doc_lines[2:]: if l.strip() == 'Returns': break labels_doc_lines2.append(l) labels_doc = '\n'.join(labels_doc_lines2) parser = argparse.ArgumentParser( description='Clean parcellation textures (developed initially for ' 'JulichBrain projections). Should normally be used after volume to ' 'texture projection using AimsVol2Tex.\n\n' 'The nomenclature is basically an int/string labels conversion table, ' 'but may include colors in addition to build a consistent colormap. ' 'The nomenclature file may be given in different formats: JSON, CSV, ' 'or HIE.\n\n' + labels_doc, formatter_class=MultilineFormatter) parser.add_argument('-t', '--texture', nargs='*', help='input textures') parser.add_argument('-m', '--mesh', nargs='*', help='input meshes') parser.add_argument('-o', '--output', nargs='*', help='output textures') parser.add_argument('-n', '--nomenclature', help='nomenclature int->string table file ' '(JSON, or CSV, or HIE)') parser.add_argument('-e', '--erosion', default=default_ero_dist, type=json.loads, help='erosion distances dict to remove small parts. Default: %s' % repr(default_ero_dist)) parser.add_argument('-d', '--dilation', default=default_dil, type=float, help='dilation distance (mm) to connect disconnected ' 'regions. Default: %f' % default_dil) parser.add_argument('-s', '--minsize', default=default_min_cc_size, type=float, help='minimum disconnected component size (in mm2) ' 'under which remaining small parts are removed. ' 'Default: %f' % default_min_cc_size) parser.add_argument('-p', '--proc', default=0, type=int, help='Use parallel computing using this number of ' 'processors (cores). 0=all in the machine, positive ' 'number=this number of cores, negative number=all but ' 'this number. Default: 0 (all)') options = parser.parse_args(argv[1:]) tex_names = options.texture mesh_names = options.mesh otex_names = options.output nomenclature = options.nomenclature ero_dist = options.erosion dil = options.dilation min_size = options.minsize nproc = options.proc if not mesh_names or not tex_names or not otex_names or len(mesh_names) != len(tex_names) or len(mesh_names) != len(otex_names): raise ValueError( 'texture, mesh, and output parameters must be specified, with the ' 'same number of values') for mname, tname, otex_name in zip(mesh_names, tex_names, otex_names): otex = clean_texture(mname, tname, nomenclature, ero_dist=ero_dist, dilation=dil, min_cc_size=min_size, max_threads=nproc) aims.write(otex, otex_name) if __name__ == '__main__': main()