# Source code for deepsulci.deeptools.models

# -*- coding: utf-8 -*-
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from functools import partial
from six.moves import range
from six.moves import zip

class UNet3D(nn.Module):
    """3DUnet model from
    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse
    Annotation" <https://arxiv.org/pdf/1606.06650.pdf>`_

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output segmentation masks;
            Note that the out_channels might correspond to either different
            semantic classes or to different binary segmentation masks.
            It's up to the user of the class to interpret the out_channels
            and use the proper loss criterion during training
            (i.e. NLLLoss (multi-class) or BCELoss (two-class) respectively)
        final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
            final 1x1x1 convolution, otherwise apply nn.LogSoftmax. MUST be
            True if nn.BCELoss (two-class) is used to train the model. MUST
            be False if nn.CrossEntropyLoss (multi-class) is used to train
            the model.
        interpolate (bool): if True use F.interpolate for upsampling
            otherwise use ConvTranspose3d
        conv_layer_order (string): determines the order of layers in
            `DoubleConv` module. e.g. 'crg' stands for
            Conv3d+ReLU+GroupNorm3d. See `DoubleConv` for more info.
        init_channel_number (int): number of feature maps produced by the
            first encoder; each subsequent encoder doubles it.
    """

    def __init__(self, in_channels, out_channels, final_sigmoid,
                 interpolate=True, conv_layer_order='crg',
                 init_channel_number=64):
        super(UNet3D, self).__init__()
        # number of groups for the GroupNorm
        num_groups = min(init_channel_number // 2, 32)
        # encoder path consists of 4 subsequent Encoder modules;
        # the number of feature maps is the same as in the paper.
        # The first encoder skips max pooling (is_max_pool=False) so the
        # input resolution is preserved at the top level.
        self.encoders = nn.ModuleList([
            Encoder(in_channels, init_channel_number, is_max_pool=False,
                    conv_layer_order=conv_layer_order, ind=0,
                    num_groups=num_groups),
            Encoder(init_channel_number, 2 * init_channel_number,
                    conv_layer_order=conv_layer_order, ind=1,
                    num_groups=num_groups),
            Encoder(2 * init_channel_number, 4 * init_channel_number,
                    conv_layer_order=conv_layer_order, ind=2,
                    num_groups=num_groups),
            Encoder(4 * init_channel_number, 8 * init_channel_number,
                    conv_layer_order=conv_layer_order, ind=3,
                    num_groups=num_groups),
        ])
        # each decoder's in_channels = skip-connection channels + upsampled
        # channels (the two are concatenated in Decoder.forward)
        self.decoders = nn.ModuleList([
            Decoder(4 * init_channel_number + 8 * init_channel_number,
                    4 * init_channel_number, interpolate, ind=4,
                    conv_layer_order=conv_layer_order,
                    num_groups=num_groups),
            Decoder(2 * init_channel_number + 4 * init_channel_number,
                    2 * init_channel_number, interpolate, ind=5,
                    conv_layer_order=conv_layer_order,
                    num_groups=num_groups),
            Decoder(init_channel_number + 2 * init_channel_number,
                    init_channel_number, interpolate, ind=6,
                    conv_layer_order=conv_layer_order,
                    num_groups=num_groups)
        ])
        # in the last layer a 1x1x1 convolution reduces the number of output
        # channels to the number of labels
        self.final_conv = nn.Conv3d(init_channel_number, out_channels, 1)
        if final_sigmoid:
            self.final_activation = nn.Sigmoid()
        else:
            self.final_activation = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Run the full encoder/decoder path on a 5D tensor (N, C, D, H, W)
        and return per-voxel scores (logits in training mode)."""
        # encoder part
        encoders_features = []
        for encoder in self.encoders:
            x = encoder(x)
            # reverse the encoder outputs to be aligned with the decoder
            encoders_features.insert(0, x)
        # remove the last encoder's output from the list
        # !!remember: it's the 1st in the list
        encoders_features = encoders_features[1:]
        # decoder part
        for decoder, encoder_features in zip(self.decoders,
                                             encoders_features):
            # pass the output from the corresponding encoder and the output
            # of the previous decoder
            x = decoder(encoder_features, x)
        x = self.final_conv(x)
        # apply final_activation (i.e. Sigmoid or LogSoftmax) only for
        # prediction. During training the network outputs logits and it's up
        # to the user to normalize it before visualising with tensorboard or
        # computing accuracy
        if not self.training:
            x = self.final_activation(x)
        return x
class DoubleConv(nn.Sequential):
    """A module consisting of two consecutive convolution layers
    (e.g. BatchNorm3d+ReLU+Conv3d) with the number of output channels
    'out_channels // 2' and 'out_channels' respectively.
    We use (Conv3d+ReLU+GroupNorm3d) by default.
    This can be changed however by providing the 'order' argument, e.g. in
    order to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
    Uses padded convolutions so that the output (H_out, W_out) equals
    (H_in, W_in) and no cropping is needed in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 order='crg', num_groups=32):
        super(DoubleConv, self).__init__()
        if in_channels < out_channels:
            # encoder path: widen in two steps (in -> out//2 -> out)
            conv1_in_channels = in_channels
            conv1_out_channels = out_channels // 2
            conv2_in_channels = conv1_out_channels
            conv2_out_channels = out_channels
        else:
            # decoder path: go straight to out_channels in the first conv
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels
        # conv1
        self._add_conv(1, conv1_in_channels, conv1_out_channels, kernel_size,
                       order, num_groups)
        # conv2
        self._add_conv(2, conv2_in_channels, conv2_out_channels, kernel_size,
                       order, num_groups)

    def _add_conv(self, pos, in_channels, out_channels, kernel_size, order,
                  num_groups):
        """Add one conv layer with non-linearity and optional normalization.

        Args:
            pos (int): the order (position) of the layer. MUST be 1 or 2
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            kernel_size (int): size of the convolving kernel
            order (string): order of things, e.g.
                'cr' -> conv + ReLU
                'crg' -> conv + ReLU + groupnorm
            num_groups (int): number of groups for the GroupNorm
        """
        assert pos in [1, 2], 'pos MUST be either 1 or 2'
        assert 'c' in order, "'c' (conv layer) MUST be present"
        assert 'r' in order, "'r' (ReLU layer) MUST be present"
        # BUGFIX: original used `order[0] is not 'r'` — identity comparison
        # with a string literal (SyntaxWarning on Python >= 3.8 and only
        # correct due to CPython string interning). Use equality instead.
        assert order[0] != 'r', \
            'ReLU cannot be the first operation in the layer'
        for i, char in enumerate(order):
            if char == 'r':
                self.add_module('relu%i' % pos, nn.ReLU(inplace=True))
            elif char == 'c':
                # padding=1 keeps spatial dims unchanged for kernel_size=3
                self.add_module('conv%i' % pos,
                                nn.Conv3d(in_channels, out_channels,
                                          kernel_size, padding=1))
            elif char == 'g':
                is_before_conv = i < order.index('c')
                assert not is_before_conv, \
                    'GroupNorm3d MUST go after the Conv3d'
                self.add_module('norm%i' % pos,
                                GroupNorm3d(out_channels,
                                            num_groups=num_groups))
            elif char == 'b':
                # BatchNorm may come before or after the conv; size its
                # feature count accordingly
                is_before_conv = i < order.index('c')
                if is_before_conv:
                    self.add_module('norm%i' % pos,
                                    nn.BatchNorm3d(in_channels))
                else:
                    self.add_module('norm%i' % pos,
                                    nn.BatchNorm3d(out_channels))
            else:
                # BUGFIX: the message previously omitted the supported 'g'
                raise ValueError(
                    "Unsupported layer type '%s'. "
                    "MUST be one of 'b', 'g', 'r', 'c'" % char)
class Encoder(nn.Module):
    """One encoder stage: an optional MaxPool3d followed by a DoubleConv.

    The MaxPool kernel_size may differ from the standard (2, 2, 2), e.g. for
    anisotropic volumetric data (use a complementary scale_factor in the
    decoder path in that case).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int): size of the convolving kernel
        is_max_pool (bool): if True apply MaxPool3d before the DoubleConv
        max_pool_kernel_size (tuple): the size of the window to take a max
            over
        conv_layer_order (string): order of layers in the `DoubleConv`
            module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        ind (int): index of this stage within the network (bookkeeping only)
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3,
                 is_max_pool=True, max_pool_kernel_size=(2, 2, 2),
                 conv_layer_order='crg', num_groups=32, ind=0):
        super(Encoder, self).__init__()
        self.ind = ind
        if is_max_pool:
            self.max_pool = nn.MaxPool3d(kernel_size=max_pool_kernel_size,
                                         padding=1)
        else:
            self.max_pool = None
        self.double_conv = DoubleConv(in_channels, out_channels,
                                      kernel_size=conv_kernel_size,
                                      order=conv_layer_order,
                                      num_groups=num_groups)

    def forward(self, x):
        """Optionally downsample, then apply the double convolution."""
        pooled = x if self.max_pool is None else self.max_pool(x)
        return self.double_conv(pooled)
class Decoder(nn.Module):
    """One decoder stage: upsampling (learned ConvTranspose3d or nearest
    interpolation) followed by a DoubleConv over the concatenation with the
    corresponding encoder output.

    Args:
        in_channels (int): number of input channels (skip-connection
            channels + upsampled channels)
        out_channels (int): number of output channels
        interpolate (bool): if True use F.interpolate for upsampling,
            otherwise learn a ConvTranspose3d (if you have enough GPU memory
            and ain't afraid of overfitting)
        kernel_size (int): size of the convolving kernel
        scale_factor (tuple): used as stride for the ConvTranspose3d
        conv_layer_order (string): order of layers in the `DoubleConv`
            module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        ind (int): index of this stage within the network (bookkeeping only)
    """

    def __init__(self, in_channels, out_channels, interpolate,
                 kernel_size=3, scale_factor=(2, 2, 2),
                 conv_layer_order='crg', num_groups=32, ind=0):
        super(Decoder, self).__init__()
        self.ind = ind
        self.upsample = None
        if not interpolate:
            # learned upsampling, sized so that it reverses the MaxPool3d:
            # D_out = (D_in - 1) * stride - 2 * padding + kernel_size
            #         + output_padding
            self.upsample = nn.ConvTranspose3d(2 * out_channels,
                                               2 * out_channels,
                                               kernel_size=kernel_size,
                                               stride=scale_factor,
                                               padding=1,
                                               output_padding=1)
        self.double_conv = DoubleConv(in_channels, out_channels,
                                      kernel_size=kernel_size,
                                      order=conv_layer_order,
                                      num_groups=num_groups)

    def forward(self, encoder_features, x):
        """Upsample `x`, concatenate with the skip connection, convolve."""
        if self.upsample is not None:
            x = self.upsample(x)
        else:
            # interpolate straight to the skip connection's spatial size
            x = F.interpolate(x, size=encoder_features.size()[2:],
                              mode='nearest')
        # concatenate encoder output with the upsampled input channel-wise
        return self.double_conv(torch.cat((encoder_features, x), dim=1))
class _GroupNorm(nn.Module): dim_to_params_shape = { 3: (1, 1, 1, 1, 1), 2: (1, 1, 1, 1), 1: (1, 1, 1) } def __init__(self, num_features, dim, num_groups=32, eps=1e-5): super(_GroupNorm, self).__init__() assert dim in [1, 2, 3], 'Unsupported dimensionality: %i' % dim params_shape = list(self.dim_to_params_shape[dim]) params_shape[1] = num_features self.weight = nn.Parameter(torch.ones(params_shape)) self.bias = nn.Parameter(torch.zeros(params_shape)) self.num_groups = num_groups self.eps = eps def forward(self, x): self._check_input_dim(x) # save original shape shape = x.size() N = shape[0] C = shape[1] G = self.num_groups assert C % G == 0, 'Channel dim must be multiply of number of groups' x = x.view(N, G, -1) mean = x.mean(-1, keepdim=True) var = x.var(-1, keepdim=True) x = (x - mean) / (var + self.eps).sqrt() # restore original shape x = x.view(shape) return x * self.weight + self.bias def _check_input_dim(self, x): raise NotImplementedError
class GroupNorm3d(_GroupNorm):
    """Group normalization for 5D inputs (N, C, D, H, W)."""

    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm3d, self).__init__(num_features, 3, num_groups, eps)

    def _check_input_dim(self, x):
        # accept only 5D tensors (batch, channels, depth, height, width)
        if x.dim() == 5:
            return
        raise ValueError('Expected 5D input (got %iD input)' % x.dim())
class GroupNorm2d(_GroupNorm):
    """Group normalization for 4D inputs (N, C, H, W)."""

    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm2d, self).__init__(num_features, 2, num_groups, eps)

    def _check_input_dim(self, x):
        # accept only 4D tensors (batch, channels, height, width)
        if x.dim() == 4:
            return
        raise ValueError('Expected 4D input (got %iD input)' % x.dim())
class GroupNorm1d(_GroupNorm):
    """Group normalization for 3D inputs (N, C, L)."""

    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm1d, self).__init__(num_features, 1, num_groups, eps)

    def _check_input_dim(self, x):
        # accept only 3D tensors (batch, channels, length)
        if x.dim() == 3:
            return
        raise ValueError('Expected 3D input (got %iD input)' % x.dim())
###############################################################################
###############################################################################
###############################################################################


def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a 3x3x3 convolution with padding 1 and no bias
    (spatial size is preserved when stride == 1)."""
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut ('type A'): spatially subsample `x` and
    zero-pad its channels up to `planes`.

    Args:
        x (Tensor): 5D input (N, C, D, H, W) with C <= planes
        planes (int): number of output channels after zero padding
        stride (int or tuple): spatial subsampling factor

    Returns:
        Tensor of shape (N, planes, D', H', W')
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # BUGFIX: allocate the channel padding with the same dtype and device
    # as `out`. The original built a CPU float32 torch.Tensor and relied on
    # the deprecated torch.cuda.FloatTensor isinstance check and Variable
    # wrapper, which broke for half/double inputs.
    zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                              out.size(2), out.size(3), out.size(4))
    return torch.cat([out, zero_pads], dim=1)
class BasicBlock(nn.Module):
    """Two-convolution 3D residual block (ResNet-18/34 style).

    Args:
        inplanes (int): number of input channels
        planes (int): number of output channels
        stride (int): stride of the first convolution
        downsample (nn.Module or None): optional module applied to the
            shortcut branch to match shape with the main branch
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        identity = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + identity
        return self.relu(y)
class Bottleneck(nn.Module):
    """Three-convolution 3D bottleneck residual block (ResNet-50+ style):
    1x1x1 reduce -> 3x3x3 -> 1x1x1 expand (x4).

    Args:
        inplanes (int): number of input channels
        planes (int): bottleneck width; output has planes * 4 channels
        stride (int): stride of the middle 3x3x3 convolution
        downsample (nn.Module or None): optional module applied to the
            shortcut branch to match shape with the main branch
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        identity = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return self.relu(y + identity)
class ResNet(nn.Module):
    """3D ResNet classifier for single-channel volumetric input.

    Args:
        block: residual block class (BasicBlock or Bottleneck); must expose
            an `expansion` class attribute and accept
            (inplanes, planes, stride, downsample)
        layers (sequence of 4 ints): number of blocks in each stage
        shortcut_type (string): 'A' uses the parameter-free
            downsample_basic_block shortcut, anything else (e.g. 'B') uses a
            1x1x1 Conv3d + BatchNorm3d shortcut
        num_classes (int): output size of the final fully-connected layer
    """

    def __init__(self, block, layers, shortcut_type='B', num_classes=2):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(1, 2, 2),
                               padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2,
                                    padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type,
                                       stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # NOTE(review): dropout and softmax are registered but not used in
        # forward(); they are kept for backward compatibility with callers
        # that may apply them externally.
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # BUGFIX: LogSoftmax() without `dim` is deprecated/ambiguous;
        # normalize over the class dimension explicitly.
        self.softmax = nn.LogSoftmax(dim=1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # kaiming_normal_ operates in place; the original's
                # `m.weight = ...` reassignment was redundant
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one residual stage of `blocks` blocks; the first block may
        downsample (stride) and/or change the channel count via `downsample`.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                # parameter-free shortcut: avg-pool + channel zero-padding
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                # projection shortcut
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes); apply
        self.softmax externally if log-probabilities are needed."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def resnet18(**kwargs):
    """Construct a ResNet-18 model (BasicBlock, stages [2, 2, 2, 2])."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
def resnet34(**kwargs):
    """Construct a ResNet-34 model (BasicBlock, stages [3, 4, 6, 3])."""
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
def resnet50(**kwargs):
    """Construct a ResNet-50 model (Bottleneck, stages [3, 4, 6, 3])."""
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnet101(**kwargs):
    """Construct a ResNet-101 model (Bottleneck, stages [3, 4, 23, 3])."""
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def resnet152(**kwargs):
    """Construct a ResNet-152 model (Bottleneck, stages [3, 8, 36, 3])."""
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)