Source code for capsul.pipeline.custom_nodes.cv_node
# -*- coding: utf-8 -*-
'''
:class:`CrossValidationFoldNode`
--------------------------------
'''
from __future__ import absolute_import
from capsul.pipeline.pipeline_nodes import Node
from soma.controller import Controller
import traits.api as traits
import sys
class CrossValidationFoldNode(Node):
    '''
    This "inert" node filters a list to separate it into (typically) learn and
    test sublists.

    The "outputs" are "train" and "test" output traits.

    The node exposes three input traits:

    * ``inputs``: the list to be split
    * ``fold``: index of the fold whose elements go to ``test``
    * ``nfolds``: total number of folds (defaults to 10)

    Each time one of them changes, ``train`` and ``test`` are recomputed by
    :meth:`filter_callback`.
    '''

    _doc_path = 'api/pipeline.html#crossvalidationfoldnode'

    def __init__(self, pipeline, name, input_type=None):
        '''
        Parameters
        ----------
        pipeline:
            pipeline the node is inserted in (forwarded to ``Node``).
        name: str
            node name (forwarded to ``Node``).
        input_type: trait instance (optional)
            trait used for the items of the ``inputs``, ``train`` and
            ``test`` lists. When not given (or evaluating to false),
            ``traits.Any(traits.Undefined)`` is used.
        '''
        in_traits = [{'name': tr, 'optional': True}
                     for tr in ('inputs', 'fold', 'nfolds')]
        out_traits = [{'name': tr, 'optional': True}
                      for tr in ('train', 'test')]
        super(CrossValidationFoldNode, self).__init__(
            pipeline, name, in_traits, out_traits)
        if input_type:
            ptype = input_type
        else:
            ptype = traits.Any(traits.Undefined)

        self.add_trait('inputs', traits.List(ptype, output=False))
        self.add_trait('fold', traits.Int())
        self.add_trait('nfolds', traits.Int(10))
        is_output = True  # not a choice for now.
        self.add_trait('train', traits.List(ptype, output=is_output))
        self.add_trait('test', traits.List(ptype, output=is_output))

        self.set_callbacks()

    def set_callbacks(self, update_callback=None):
        '''
        Register ``update_callback`` as a change handler on the ``inputs``,
        ``fold`` and ``nfolds`` traits.

        Parameters
        ----------
        update_callback: callable (optional)
            handler to install. Defaults to :meth:`filter_callback`.
        '''
        if update_callback is None:
            update_callback = self.filter_callback
        for name in ('inputs', 'fold', 'nfolds'):
            self.on_trait_change(update_callback, name)

    def filter_callback(self):
        '''
        Recompute ``train`` and ``test`` from ``inputs``, ``fold`` and
        ``nfolds``.

        ``inputs`` is cut into ``nfolds`` contiguous chunks; the first
        ``len(inputs) % nfolds`` chunks hold one more element than the
        others. The chunk number ``fold`` becomes ``test``, all remaining
        elements (in their original order) become ``train``.

        If ``nfolds`` is not strictly positive, or ``fold`` is not in
        ``[0, nfolds)``, no chunk can be selected: ``train`` receives all the
        inputs and ``test`` is empty.
        '''
        items = self.inputs
        count = len(items)
        nfolds = self.nfolds
        fold = self.fold

        # Guard against a division by zero (nfolds == 0) and against negative
        # slice bounds (fold < 0), which would duplicate elements in "train".
        # fold >= nfolds already produced this same result in the plain
        # arithmetic below, so it is handled here for clarity.
        if nfolds <= 0 or fold < 0 or fold >= nfolds:
            self.train = list(items)
            self.test = []
            return

        n = count // nfolds
        ninc = count % nfolds
        # folds with an index below ninc get one extra element, hence the
        # min(ninc, ...) shift on both bounds.
        begin = fold * n + min(ninc, fold)
        end = min((fold + 1) * n + min(ninc, fold + 1), count)
        self.train = items[:begin] + items[end:]
        self.test = items[begin:end]

    def configured_controller(self):
        '''
        Return a configuration controller (see :meth:`configure_controller`)
        whose ``param_type`` is the class name of the trait type used for the
        items of the ``inputs`` list.
        '''
        c = self.configure_controller()
        inner_trait = self.trait('inputs').inner_traits[0]
        c.param_type = inner_trait.trait_type.__class__.__name__
        return c

    @classmethod
    def configure_controller(cls):
        '''
        Return a new ``Controller`` with a single ``param_type`` string
        trait, defaulting to ``'Str'``.
        '''
        c = Controller()
        c.add_trait('param_type', traits.Str('Str'))
        return c

    @classmethod
    def build_node(cls, pipeline, name, conf_controller):
        '''
        Build a node from a configuration controller.

        Parameters
        ----------
        pipeline:
            pipeline the node is inserted in.
        name: str
            node name.
        conf_controller: Controller
            controller with a ``param_type`` attribute, as returned by
            :meth:`configure_controller`. ``'Str'`` and ``'File'`` give list
            items of that trait type with an ``Undefined`` default; any other
            value which is neither ``None`` nor ``Undefined`` is looked up by
            name in ``traits.api`` and instantiated without arguments.
            Otherwise the node default item type is used.

        Returns
        -------
        node: instance of ``cls``
        '''
        param_type = conf_controller.param_type
        t = None
        if param_type == 'Str':
            t = traits.Str(traits.Undefined)
        elif param_type == 'File':
            t = traits.File(traits.Undefined)
        elif param_type not in (None, traits.Undefined):
            t = getattr(traits, param_type)()
        # instantiate cls rather than the hard-coded base class, so that
        # subclasses get an instance of their own type.
        return cls(pipeline, name, input_type=t)

    def params_to_command(self):
        '''
        Return the command description of this node: ``['custom_job']``.
        '''
        return ['custom_job']

    def build_job(self, name=None, referenced_input_files=None,
                  referenced_output_files=None, param_dict=None):
        '''
        Build the soma-workflow ``CrossValidationFoldJob`` matching this
        node.

        Parameters
        ----------
        name: str (optional)
            job name.
        referenced_input_files: list (optional)
            forwarded to the job; an empty list is used when ``None``.
        referenced_output_files: list (optional)
            forwarded to the job; an empty list is used when ``None``.
        param_dict: dict (optional)
            additional job parameters. It is copied, not modified; the
            current values of ``inputs``, ``train``, ``test``, ``nfolds`` and
            ``fold`` are then set in the copy, overriding any entry with the
            same key.

        Returns
        -------
        job: CrossValidationFoldJob
        '''
        # imported here so that soma_workflow is only needed when a job is
        # actually built.
        from soma_workflow.custom_jobs \
            import CrossValidationFoldJob

        # None defaults instead of mutable [] defaults, which would be shared
        # between all calls.
        if referenced_input_files is None:
            referenced_input_files = []
        if referenced_output_files is None:
            referenced_output_files = []

        if param_dict is None:
            param_dict = {}
        else:
            param_dict = dict(param_dict)
        param_dict['inputs'] = self.inputs
        param_dict['train'] = self.train
        param_dict['test'] = self.test
        param_dict['nfolds'] = self.nfolds
        param_dict['fold'] = self.fold
        job = CrossValidationFoldJob(
            name=name,
            referenced_input_files=referenced_input_files,
            referenced_output_files=referenced_output_files,
            param_dict=param_dict)
        return job