Attachment 'mri_easyatlas.py'

Download

   1 import os
   2 import argparse
   3 import numpy as np
   4 import voxelmorph as vxm
   5 import torch
   6 import surfa as sf
   7 import nibabel as nib
   8 import glob
   9 from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, distance_transform_edt, binary_fill_holes
  10 from scipy.ndimage import label as scipy_label
  11 
  12 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  13 os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
  14 import tensorflow as tf
  15 import keras
  16 import keras.backend as K
  17 import keras.layers as KL
  18 
  19 
  20 
  21 # set tensorflow logging
  22 tf.get_logger().setLevel('ERROR')
  23 K.set_image_data_format('channels_last')
  24 
  25 
  26 def main():
  27 
  28     parser = argparse.ArgumentParser(description="EasyAtlas: fast atlas construction with EasyReg", epilog='\n')
  29 
  30     # input/outputs
  31     parser.add_argument("--i", help="Input directory with scans")
  32     parser.add_argument("--o", help="Output directory where atlas and other files will be written")
  33     parser.add_argument("--threads", type=int, default=-1, help="(optional) Number of cores to be used. You can use -1 to use all available cores. Default is -1.")
  34     parser.add_argument('--use_reliability_maps', action='store_true', help='Use reliability maps when averaging into atlas (recommended if data are not 1mm isotropic!')
  35 
  36     # parse commandline
  37     args = parser.parse_args()
  38 
  39     #############
  40 
  41     # Very first thing: we require FreeSurfer
  42     if not os.environ.get('FREESURFER_HOME'):
  43         sf.system.fatal('FREESURFER_HOME is not set. Please source freesurfer.')
  44     fs_home = os.environ.get('FREESURFER_HOME')
  45 
  46     if args.i is None:
  47         sf.system.fatal('Input directory must be provided')
  48     if args.o is None:
  49         sf.system.fatal('Output directory must be provided')
  50 
  51     # limit the number of threads to be used if running on CPU
  52     if args.threads<0:
  53         args.threads = os.cpu_count()
  54         print('using all available threads ( %s )' % args.threads)
  55     else:
  56         print('using %s threads' % args.threads)
  57     tf.config.threading.set_inter_op_parallelism_threads(args.threads)
  58     tf.config.threading.set_intra_op_parallelism_threads(args.threads)
  59     torch.set_num_threads(args.threads)
  60 
  61     # path models
  62     path_model_segmentation = fs_home + '/models/synthseg_2.0.h5'
  63     path_model_parcellation = fs_home + '/models/synthseg_parc_2.0.h5'
  64     path_model_registration_trained = fs_home + '/models/easyreg_v10_230103.h5'
  65 
  66     # path labels
  67     labels_segmentation = fs_home +  '/models/synthseg_segmentation_labels_2.0.npy'
  68     labels_parcellation = fs_home +  '/models/synthseg_parcellation_labels.npy'
  69     atlas_volsize = [160, 160, 192]
  70     atlas_aff = np.matrix([[-1, 0, 0, 79], [0, 0, 1, -104], [0, -1, 0, 79], [0, 0, 0, 1]])
  71 
  72     # get label lists
  73     labels_segmentation, _ = get_list_labels(label_list=labels_segmentation)
  74     labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True)
  75     labels_parcellation, _ = np.unique(get_list_labels(labels_parcellation)[0], return_index=True)
  76 
  77     # Create output (and SynthSeg) directory if needed
  78     if os.path.exists(args.o) and os.path.isdir(args.o):
  79         print('Output directory already exists; no need to create it')
  80     else:
  81         os.mkdir(args.o)
  82     segdir = args.o + '/SynthSeg/'
  83     if os.path.exists(segdir) and os.path.isdir(segdir):
  84         print('SynthSeg directory already exists; no need to create it')
  85     else:
  86         os.mkdir(segdir)
  87     regdir = args.o + '/Registrations/'
  88     if os.path.exists(regdir) and os.path.isdir(regdir):
  89         print('Registration directory already exists; no need to create it')
  90     else:
  91         os.mkdir(regdir)
  92     tempdir = args.o + '/temp/'
  93     if os.path.exists(tempdir) and os.path.isdir(tempdir):
  94         print('Temporary directory already exists; no need to create it')
  95     else:
  96         os.mkdir(tempdir)
  97 
  98     # Build list of input, affine, segmentation files (supports nii, mgz, nii.gz)
  99     input_files = sorted(glob.glob(args.i + '/*.nii.gz') + glob.glob(args.i + '/*.nii') + glob.glob(args.i + '/*.mgz'))
 100     seg_files = []
 101     reg_files = []
 102     linear_files = []
 103     for file in input_files:
 104         _, tail = os.path.split(file)
 105         seg_files.append(segdir + '/' + tail)
 106         reg_files.append(regdir + '/' + tail)
 107         linear_files.append(tempdir + '/' + tail + '.npy')
 108 
 109     # Decide if we need to segment anything
 110     all_segs_ready = True
 111     for file in seg_files:
 112         if os.path.exists(file) is False:
 113             all_segs_ready = False
 114 
 115     # Run SynthSeg if needed
 116     if all_segs_ready:
 117         print('SynthSeg already there for all input files; no need to segment anything')
 118     else:
 119         print('Setting up segmentation net')
 120         segmentation_net = build_seg_model(model_file_segmentation=path_model_segmentation,
 121                                            model_file_parcellation=path_model_parcellation,
 122                                            labels_segmentation=labels_segmentation,
 123                                            labels_parcellation=labels_parcellation)
 124         for i in range(len(input_files)):
 125             if os.path.exists(seg_files[i]):
 126                 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmentation already there')
 127             else:
 128                 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmenting')
 129                 image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=input_files[i], crop=None,
 130                                                                              min_pad=128, path_resample=None)
 131                 post_patch_segmentation, post_patch_parcellation = segmentation_net.predict(image)
 132                 seg_buffer, _, _ = postprocess(post_patch_seg=post_patch_segmentation,
 133                                                    post_patch_parc=post_patch_parcellation,
 134                                                    shape=shape,
 135                                                    pad_idx=pad_idx,
 136                                                    crop_idx=crop_idx,
 137                                                    labels_segmentation=labels_segmentation,
 138                                                    labels_parcellation=labels_parcellation,
 139                                                    aff=aff,
 140                                                    im_res=im_res)
 141                 save_volume(seg_buffer, aff, h, seg_files[i], dtype='int32')
 142 
 143     # Now the linear registration part
 144     print('Linear registration with centroids of segmentations')
 145 
 146     # First, prepare a bunch of common variables
 147     labels = np.array([2,4,5,7,8,10,11,12,13,14,15,16,17,18,26,28,41,43,44,46,47,49,50,51,52,53,54,58,60,
 148                                     1001,1002,1003,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,1024,1025,1026,1027,1028,1029,1030,1031,1032,1033,1034,1035,
 149                                     2001,2002,2003,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022,2023,2024,2025,2026,2027,2028,2029,2030,2031,2032,2033,2034,2035])
 150     nlab = len(labels)
 151     atlasCOG = np.array([[-28.,-18.,-37.,-19.,-27.,-19.,-23.,-31.,-26.,-2.,-3.,-3.,-29.,-26.,-14.,-14.,24.,14.,31.,12.,18.,14.,19.,26.,21.,25.,22.,11.,8.,-52.,-6.,-36.,-7.,-24.,-37.,-39.,-52.,-9.,-27.,-26.,-14.,-8.,-59.,-28.,-7.,-49.,-43.,-47.,-12.,-46.,-6.,-43.,-10.,-7.,-33.,-11.,-23.,-55.,-50.,-10.,-29.,-46.,-38.,48.,4.,31.,3.,21.,33.,37.,47.,3.,24.,20.,8.,4.,54.,21.,5.,45.,38.,46.,8.,45.,3.,38.,6.,4.,29.,9.,19.,51.,49.,10.,24.,43.,33.],
 152                         [-30.,-17.,-13.,-36.,-40.,-22.,-3.,-5.,-9.,-14.,-31.,-21.,-15.,-1.,3.,-16.,-32.,-20.,-14.,-37.,-42.,-24.,-3.,-6.,-10.,-15.,-2.,3.,-17.,-44.,-5.,-15.,-71.,2.,-29.,-70.,-23.,-44.,-73.,22.,-57.,27.,-19.,-23.,-45.,4.,31.,20.,-68.,-38.,-33.,-26.,-60.,23.,22.,0.,-72.,-12.,-49.,49.,17.,-25.,-3.,-42.,-1.,-16.,-76.,0.,-34.,-69.,-16.,-44.,-73.,22.,-56.,28.,-18.,-25.,-45.,-3.,30.,14.,-69.,-37.,-32.,-30.,-60.,21.,21.,0.,-72.,-11.,-49.,48.,15.,-27.,-3.],
 153                         [12.,14.,-13.,-41.,-51.,1.,13.,3.,1.,0.,-40.,-28.,-15.,-10.,2.,-7.,11.,14.,-12.,-40.,-51.,2.,14.,4.,2.,-14.,-10.,4.,-7.,-8.,32.,40.,-14.,-21.,-28.,-4.,-28.,-3.,-35.,3.,-29.,4.,-17.,-21.,35.,18.,9.,20.,-24.,28.,25.,34.,7.,18.,35.,48.,16.,-5.,12.,22.,-18.,1.,4.,-12.,32.,43.,-11.,-21.,-29.,-3.,-27.,0.,-34.,3.,-25.,6.,-18.,-20.,36.,18.,11.,20.,-20.,26.,25.,34.,4.,24.,34.,47.,17.,-5.,10.,20.,-18.,0.,4.]])
 154 
 155     II, JJ, KK = np.meshgrid(np.arange(atlas_volsize[0]), np.arange(atlas_volsize[1]), np.arange(atlas_volsize[2]), indexing='ij')
 156     II = torch.tensor(II, device='cpu')
 157     JJ = torch.tensor(JJ, device='cpu')
 158     KK = torch.tensor(KK, device='cpu')
 159 
 160     # Loop over segmentations and get COGs of ROIs
 161     COGs = np.zeros([len(input_files), 4, nlab])
 162     OKs = np.zeros([len(input_files), nlab])
 163     for i in range(len(input_files)):
 164         print('Getting centroids of ROIs: case ' + str(i + 1) + ' of ' + str(len(input_files)))
 165         COG = np.zeros([4, nlab])
 166         ok = np.ones(nlab)
 167         seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 168         label_to_idx = {lab: ii for ii, lab in enumerate(labels)}
 169         coords_per_label = [[] for _ in range(nlab)]
 170         nz = np.array(np.nonzero(seg_buffer)).T
 171         vals = seg_buffer[tuple(nz.T)]
 172         valid_mask = np.isin(vals, labels)
 173         nz = nz[valid_mask]
 174         vals = vals[valid_mask]
 175         idxs = np.searchsorted(labels, vals)
 176         for ii in range(nlab):
 177             coords_per_label[ii] = nz[idxs == ii]
 178         # Compute per-label median centroids
 179         for ii, vox in enumerate(coords_per_label):
 180             if vox.shape[0] > 50:
 181                 COG[:3, ii] = np.median(vox, axis=0)
 182                 COG[3, ii] = 1
 183             else:
 184                 ok[ii] = 0
 185         COGs[i] = np.matmul(seg_aff, COG)
 186         OKs[i] = ok.copy()
 187 
 188     # Linear registration matrices; first rigid, then affine
 189     NUM = np.zeros(atlasCOG.shape)
 190     DEN = np.zeros(atlasCOG.shape)
 191     for i in range(len(input_files)):
 192         M = getMrigid(COGs[i, :-1, OKs[i] > 0].T, atlasCOG[:, OKs[i] > 0])
 193         NUM[:, OKs[i] > 0] = NUM[:, OKs[i] > 0] + (M @ COGs[i, :, OKs[i] > 0].T)[:-1, :]
 194         DEN[:, OKs[i] > 0] = DEN[:, OKs[i] > 0] + 1
 195     rigidAtlasCOG = NUM / DEN
 196     Ms = np.zeros([len(input_files), 4, 4])
 197     for i in range(len(input_files)):
 198         Ms[i] = getM(rigidAtlasCOG[:, OKs[i] > 0], COGs[i, :, OKs[i] > 0].T)
 199 
 200     # OK now we can deform to linear space (and compute linear atlas, while at it)
 201     NUM = np.zeros(atlas_volsize)
 202     DEN = np.zeros(atlas_volsize)
 203     for i in range(len(input_files)):
 204         print('Deforming to linear space: case ' + str(i + 1) + ' of ' + str(len(input_files)))
 205         im_buffer, im_aff, im_hh = load_volume(input_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 206         im_buffer = torch.tensor(im_buffer, device='cpu')
 207         voxdim = np.sqrt(np.sum(im_aff[:-1, :-1] ** 2, axis=0))
 208         affine = torch.tensor(np.matmul(np.linalg.inv(im_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
 209         II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
 210         JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
 211         KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
 212         im_lin = fast_3D_interp_torch(im_buffer, II2, JJ2, KK2, 'linear')
 213         if args.use_reliability_maps:
 214             lin_dists = torch.sqrt(((II2 - II2.round()) * voxdim[0]) ** 2 +
 215                                    ((JJ2 - JJ2.round()) * voxdim[1]) ** 2 +
 216                                    ((KK2 - KK2.round()) * voxdim[2]) ** 2)
 217             lin_rel = torch.exp(-1.0 * lin_dists)
 218         else:
 219             lin_rel = torch.ones(II2.shape)
 220 
 221         seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
 222         affine = torch.tensor(np.matmul(np.linalg.inv(seg_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
 223         II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
 224         JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
 225         KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
 226         seg_lin = fast_3D_interp_torch(torch.tensor(seg_buffer.copy(), device='cpu'), II2, JJ2, KK2, 'nearest')
 227         im_lin[seg_lin == 0] = 0
 228         im_lin /= torch.median(im_lin[torch.logical_or(seg_lin==2, seg_lin==41)])
 229         np.save(linear_files[i], torch.stack([im_lin, lin_rel]).detach().cpu().numpy())
 230         NUM += (im_lin * lin_rel).detach().cpu().numpy()
 231         DEN += lin_rel.detach().cpu().numpy()
 232 
 233     print('Computing and saving affine atlas')
 234     ATLAS = NUM / (1e-9 + DEN)
 235     save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.affine.nii.gz')
 236 
 237     print('Building nonlinear registration model')
 238     # Build model
 239     source = tf.keras.Input(shape=(*atlas_volsize, 1))
 240     target = tf.keras.Input(shape=(*atlas_volsize, 1))
 241 
 242     config = {'name': 'vxm_dense', 'fill_value': None, 'input_model': None, 'unet_half_res': True, 'trg_feats': 1,
 243               'src_feats': 1, 'use_probs': False, 'bidir': False, 'int_downsize': 2, 'int_steps': 10,
 244               'nb_unet_conv_per_level': 1, 'unet_feat_mult': 1, 'nb_unet_levels': None,
 245               'nb_unet_features': [[256, 256, 256, 256], [256, 256, 256, 256, 256, 256]], 'inshape': atlas_volsize}
 246     cnn = vxm.networks.VxmDense(**config)
 247     cnn.load_weights(path_model_registration_trained, by_name=True)
 248     svf1 = cnn([source, target])[1]
 249     svf2 = cnn([target, source])[1]
 250     pos_svf = KL.Lambda(lambda x: 0.5 * x[0] - 0.5 * x[1])([svf1, svf2])
 251     neg_svf = KL.Lambda(lambda x: -x)(pos_svf)
 252     pos_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(pos_svf)
 253     neg_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(neg_svf)
 254     pos_def = vxm.layers.RescaleTransform(2)(pos_def_small)
 255     neg_def = vxm.layers.RescaleTransform(2)(neg_def_small)
 256     model = tf.keras.Model(inputs=[source, target],
 257                            outputs=[pos_def, neg_def])
 258     model.load_weights(path_model_registration_trained)
 259 
 260     # Global atlas building iterations
 261     MAX_IT = 5
 262     for it in range(MAX_IT):
 263         # Initialize new atlas to zeros
 264         NUM = np.zeros_like(ATLAS)
 265         DEN = np.zeros_like(ATLAS)
 266         for i in range(len(input_files)):
 267             print('Iteration ' + str(1 + it) + ' of ' + str(MAX_IT) + ', image ' + str(i+1) + ' of ' + str(len(input_files)))
 268             lin = np.load(linear_files[i])
 269             pred = model.predict([lin[0:1, ..., np.newaxis] / np.max(lin[0]) ,
 270                                   ATLAS[np.newaxis, ..., np.newaxis]])
 271             field = torch.tensor(pred[0], device='cpu').squeeze()
 272             II2 = II + field[..., 0]
 273             JJ2 = JJ + field[..., 1]
 274             KK2 = KK + field[..., 2]
 275             deformed_im = fast_3D_interp_torch(torch.tensor(lin[0], device='cpu'), II2 , JJ2, KK2, 'linear')
 276             deformed_rel = fast_3D_interp_torch(torch.tensor(lin[1], device='cpu'), II2, JJ2, KK2, 'linear')
 277             NUM += (deformed_im * deformed_rel).detach().cpu().numpy()
 278             DEN += deformed_rel.detach().cpu().numpy()
 279             if it == (MAX_IT-1):
 280                 T = Ms[i] @ atlas_aff
 281                 RR = T[0, 0] * II2 + T[0, 1] * JJ2 + T[0, 2] * KK2 + T[0, 3]
 282                 AA = T[1, 0] * II2 + T[1, 1] * JJ2 + T[1, 2] * KK2 + T[1, 3]
 283                 SS = T[2, 0] * II2 + T[2, 1] * JJ2 + T[2, 2] * KK2 + T[2, 3]
 284                 save_volume(torch.stack([RR, AA, SS], dim=-1).detach().cpu().numpy(), atlas_aff, None, reg_files[i])
 285         ATLAS = NUM / (1e-9 + DEN)
 286         save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.iteration.' + str(it+1) + '.nii.gz')
 287 
 288     # Clean up
 289     print('Deleting temporary files')
 290     for i in range(len(linear_files)):
 291         os.remove(linear_files[i])
 292     os.rmdir(tempdir)
 293 
 294     print(' ')
 295     print('All done!')
 296     print(' ')
 297     print('If you use EasyReg in your analysis, please cite:')
 298     print('A ready-to-use machine learning tool for symmetric multi-modality registration of brain MRI.')
 299     print('JE Iglesias. Scientific Reports, 13, article number 6657 (2023).')
 300     print('https://www.nature.com/articles/s41598-023-33781-0')
 301     print(' ')
 302 
 303 
 304 #######################
 305 # Auxiliary functions #
 306 #######################
 307 
 308 
 309 def get_list_labels(label_list=None, save_label_list=None, FS_sort=False):
 310 
 311     # load label list if previously computed
 312     label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int'))
 313 
 314 
 315     # sort labels in neutral/left/right according to FS labels
 316     n_neutral_labels = 0
 317     if FS_sort:
 318         neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108,
 319                              109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
 320                              251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340,
 321                              502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530,
 322                              531, 532, 533, 534, 535, 536, 537]
 323         neutral = list()
 324         left = list()
 325         right = list()
 326         for la in label_list:
 327             if la in neutral_FS_labels:
 328                 if la not in neutral:
 329                     neutral.append(la)
 330             elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \
                    (la == 865) | (20100 < la < 20110):
 331                 if la not in left:
 332                     left.append(la)
 333             elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \
                    (la == 866):
 334                 if la not in right:
 335                     right.append(la)
 336             else:
 337                 raise Exception('label {} not in our current FS classification, '
 338                                 'please update get_list_labels in utils.py'.format(la))
 339         label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)])
 340         if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)):
 341             n_neutral_labels = len(neutral)
 342         else:
 343             n_neutral_labels = len(label_list)
 344 
 345     # save labels if specified
 346     if save_label_list is not None:
 347         np.save(save_label_list, np.int32(label_list))
 348 
 349     if FS_sort:
 350         return np.int32(label_list), n_neutral_labels
 351     else:
 352         return np.int32(label_list), None
 353 
 354 def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
 355     # convert to list
 356     if var is None:
 357         return None
 358     var = load_array_if_path(var, load_as_numpy=load_as_numpy)
 359     if isinstance(var, (int, float, np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64)):
 360         var = [var]
 361     elif isinstance(var, tuple):
 362         var = list(var)
 363     elif isinstance(var, np.ndarray):
 364         if var.shape == (1,):
 365             var = [var[0]]
 366         else:
 367             var = np.squeeze(var).tolist()
 368     elif isinstance(var, str):
 369         var = [var]
 370     elif isinstance(var, bool):
 371         var = [var]
 372     if isinstance(var, list):
 373         if length is not None:
 374             if len(var) == 1:
 375                 var = var * length
 376             elif len(var) != length:
 377                 raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
 378                                  'had {1}'.format(length, var))
 379     else:
 380         raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')
 381 
 382     # convert items type
 383     if dtype is not None:
 384         if dtype == 'int':
 385             var = [int(v) for v in var]
 386         elif dtype == 'float':
 387             var = [float(v) for v in var]
 388         elif dtype == 'bool':
 389             var = [bool(v) for v in var]
 390         elif dtype == 'str':
 391             var = [str(v) for v in var]
 392         else:
 393             raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
 394     return var
 395 
 396 def load_array_if_path(var, load_as_numpy=True):
 397     if (isinstance(var, str)) & load_as_numpy:
 398         assert os.path.isfile(var), 'No such path: %s' % var
 399         var = np.load(var)
 400     return var
 401 
 402 
 403 def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
 404 
 405     assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume
 406 
 407     if path_volume.endswith(('.nii', '.nii.gz', '.mgz')):
 408         x = nib.load(path_volume)
 409         if squeeze:
 410             volume = np.squeeze(x.get_fdata())
 411         else:
 412             volume = x.get_fdata()
 413         aff = x.affine
 414         header = x.header
 415     else:  # npz
 416         volume = np.load(path_volume)['vol_data']
 417         if squeeze:
 418             volume = np.squeeze(volume)
 419         aff = np.eye(4)
 420         header = nib.Nifti1Header()
 421     if dtype is not None:
 422         if 'int' in dtype:
 423             volume = np.round(volume)
 424         volume = volume.astype(dtype=dtype)
 425 
 426     # align image to reference affine matrix
 427     if aff_ref is not None:
 428         n_dims, _ = get_dims(list(volume.shape), max_channels=10)
 429         volume, aff = align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)
 430 
 431     if im_only:
 432         return volume
 433     else:
 434         return volume, aff, header
 435 
 436 
 437 
 438 
 439 def preprocess(path_image, n_levels=5, crop=None, min_pad=None, path_resample=None):
 440     # read image and corresponding info
 441     im, _, aff, n_dims, n_channels, h, im_res = get_volume_info(path_image, True)
 442     if n_dims < 3:
 443         sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
 444     elif n_dims == 4 and n_channels == 1:
 445         n_dims = 3
 446         im = im[..., 0]
 447     elif n_dims > 3:
 448         sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
 449     elif n_channels > 1:
 450         print('WARNING: detected more than 1 channel, only keeping the first channel.')
 451         im = im[..., 0]
 452 
 453     # resample image if necessary
 454     if np.any((im_res > 1.05) | (im_res < 0.95)):
 455         im_res = np.array([1.] * 3)
 456         im, aff = resample_volume(im, aff, im_res)
 457         if path_resample is not None:
 458             save_volume(im, aff, h, path_resample)
 459 
 460     # align image
 461     im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False)
 462     shape = list(im.shape[:n_dims])
 463 
 464     # crop image if necessary
 465     if crop is not None:
 466         crop = reformat_to_list(crop, length=n_dims, dtype='int')
 467         crop_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop]
 468         im, crop_idx = crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
 469     else:
 470         crop_idx = None
 471 
 472     # normalise image
 473     im = rescale_volume(im, new_min=0, new_max=1, min_percentile=0.5, max_percentile=99.5)
 474 
 475     # pad image
 476     input_shape = im.shape[:n_dims]
 477     pad_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape]
 478     min_pad = reformat_to_list(min_pad, length=n_dims, dtype='int')
 479     min_pad = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad]
 480     pad_shape = np.maximum(pad_shape, min_pad)
 481     im, pad_idx = pad_volume(im, padding_shape=pad_shape, return_pad_idx=True)
 482 
 483     # add batch and channel axes
 484     im = add_axis(im, axis=[0, -1])
 485 
 486     return im, aff, h, im_res, shape, pad_idx, crop_idx
 487 
 488 
 489 def resample_volume(volume, aff, new_vox_size, interpolation='linear'):
 490     pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1]
 491     new_vox_size = np.array(new_vox_size)
 492     factor = pixdim / new_vox_size
 493     sigmas = 0.25 / factor
 494     sigmas[factor > 1] = 0  # don't blur if upsampling
 495 
 496     volume_filt = gaussian_filter(volume, sigmas)
 497 
 498     # volume2 = zoom(volume_filt, factor, order=1, mode='reflect', prefilter=False)
 499     x = np.arange(0, volume_filt.shape[0])
 500     y = np.arange(0, volume_filt.shape[1])
 501     z = np.arange(0, volume_filt.shape[2])
 502 
 503     start = - (factor - 1) / (2 * factor)
 504     step = 1.0 / factor
 505     stop = start + step * np.ceil(volume_filt.shape * factor)
 506 
 507     xi = np.arange(start=start[0], stop=stop[0], step=step[0])
 508     yi = np.arange(start=start[1], stop=stop[1], step=step[1])
 509     zi = np.arange(start=start[2], stop=stop[2], step=step[2])
 510     xi[xi < 0] = 0
 511     yi[yi < 0] = 0
 512     zi[zi < 0] = 0
 513     xi[xi > (volume_filt.shape[0] - 1)] = volume_filt.shape[0] - 1
 514     yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1
 515     zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1
 516 
 517     xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=False)
 518     xig = torch.tensor(xig, device='cpu')
 519     yig = torch.tensor(yig, device='cpu')
 520     zig = torch.tensor(zig, device='cpu')
 521     volume2 = fast_3D_interp_torch(torch.tensor(volume_filt, device='cpu'), xig, yig, zig, 'linear')
 522 
 523     aff2 = aff.copy()
 524     for c in range(3):
 525         aff2[:-1, c] = aff2[:-1, c] / factor[c]
 526     aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1))
 527 
 528     return volume2.numpy(), aff2
 529 
 530 def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
 531     if n % m == 0:
 532         return n
 533     else:
 534         q = int(n / m)
 535         lower = q * m
 536         higher = (q + 1) * m
 537         if answer_type == 'lower':
 538             return lower
 539         elif answer_type == 'higher':
 540             return higher
 541         elif answer_type == 'closer':
 542             return lower if (n - lower) < (higher - n) else higher
 543         else:
 544             sf.system.fatal('answer_type should be lower, higher, or closer, had : %s' % answer_type)
 545 
 546 
 547 
 548 def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
 549 
 550     im, aff, header = load_volume(path_volume, im_only=False)
 551 
 552     # understand if image is multichannel
 553     im_shape = list(im.shape)
 554     n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
 555     im_shape = im_shape[:n_dims]
 556 
 557     # get labels res
 558     if '.nii' in path_volume:
 559         data_res = np.array(header['pixdim'][1:n_dims + 1])
 560     elif '.mgz' in path_volume:
 561         data_res = np.array(header['delta'])  # mgz image
 562     else:
 563         data_res = np.array([1.0] * n_dims)
 564 
 565     # align to given affine matrix
 566     if aff_ref is not None:
 567         ras_axes = get_ras_axes(aff, n_dims=n_dims)
 568         ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
 569         im = align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
 570         im_shape = np.array(im_shape)
 571         data_res = np.array(data_res)
 572         im_shape[ras_axes_ref] = im_shape[ras_axes]
 573         data_res[ras_axes_ref] = data_res[ras_axes]
 574         im_shape = im_shape.tolist()
 575 
 576     # return info
 577     if return_volume:
 578         return im, im_shape, aff, n_dims, n_channels, header, data_res
 579     else:
 580         return im_shape, aff, n_dims, n_channels, header, data_res
 581 
 582 def get_dims(shape, max_channels=10):
 583     if shape[-1] <= max_channels:
 584         n_dims = len(shape) - 1
 585         n_channels = shape[-1]
 586     else:
 587         n_dims = len(shape)
 588         n_channels = 1
 589     return n_dims, n_channels
 590 
 591 
 592 def get_ras_axes(aff, n_dims=3):
 593     aff_inverted = np.linalg.inv(aff)
 594     img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0)
 595     for i in range(n_dims):
 596         if i not in img_ras_axes:
 597             unique, counts = np.unique(img_ras_axes, return_counts=True)
 598             incorrect_value = unique[np.argmax(counts)]
 599             img_ras_axes[np.where(img_ras_axes == incorrect_value)[0][-1]] = i
 600 
 601     return img_ras_axes
 602 
def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True):
    """Reorient a volume so its voxel axes follow the order and directions of a reference affine.

    Axes are first permuted so that every world (RAS) axis maps to the same voxel axis as in
    ``aff_ref`` (see get_ras_axes), then flipped wherever the axis direction disagrees with
    ``aff_ref``. The affine is updated alongside so voxels keep their world positions. No
    resampling is performed.

    Args:
        volume: array to reorient; its first ``n_dims`` axes are the spatial axes.
        aff: 4x4 voxel-to-world affine of ``volume``.
        aff_ref: 4x4 reference affine; defaults to the identity.
        return_aff: also return the updated affine.
        n_dims: number of spatial dimensions; inferred with get_dims when None.
        return_copy: if False, no initial copy is made and the result may be ``volume``
            itself or a view of it.

    Returns:
        The reoriented volume, or ``(volume, affine)`` when ``return_aff`` is True.
    """

    # work on copy
    new_volume = volume.copy() if return_copy else volume
    aff_flo = aff.copy()

    # default value for aff_ref
    if aff_ref is None:
        aff_ref = np.eye(4)

    # extract ras axes
    if n_dims is None:
        n_dims, _ = get_dims(new_volume.shape)
    ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
    ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims)

    # align axes
    # The affine columns are permuted in one go. The volume is permuted through successive
    # pairwise swaps, and ras_axes_flo is updated after each swap so that it keeps
    # describing the volume's current axis order.
    aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo]
    for i in range(n_dims):
        if ras_axes_flo[i] != ras_axes_ref[i]:
            new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i])
            swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i])
            ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx]

    # align directions
    # A negative dot product between matching affine columns means the axis points the other
    # way: flip the volume, negate the column, and move the origin to the former last voxel.
    # Only the 3x3 part is used here, so this assumes n_dims <= 3.
    dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0)
    for i in range(n_dims):
        if dot_products[i] < 0:
            new_volume = np.flip(new_volume, axis=i)
            aff_flo[:, i] = - aff_flo[:, i]
            aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1)

    if return_aff:
        return new_volume, aff_flo
    else:
        return new_volume
 639 
 640 def build_seg_model(model_file_segmentation,
 641                 model_file_parcellation,
 642                 labels_segmentation,
 643                 labels_parcellation):
 644 
 645     if not os.path.isfile(model_file_segmentation):
 646         sf.system.fatal("The provided model path does not exist.")
 647 
 648     # get labels
 649     n_labels_seg = len(labels_segmentation)
 650 
 651     # build UNet
 652     net = unet(nb_features=24,
 653                input_shape=[None, None, None, 1],
 654                nb_levels=5,
 655                conv_size=3,
 656                nb_labels=n_labels_seg,
 657                feat_mult=2,
 658                activation='elu',
 659                nb_conv_per_level=2,
 660                batch_norm=-1,
 661                name='unet')
 662     net.load_weights(model_file_segmentation, by_name=True)
 663     input_image = net.inputs[0]
 664     name_segm_prediction_layer = 'unet_prediction'
 665 
 666     # smooth posteriors
 667     last_tensor = net.output
 668     last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
 669     last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
 670     net = keras.Model(inputs=net.inputs, outputs=last_tensor)
 671 
 672     # add aparc segmenter
 673     n_labels_parcellation = len(labels_parcellation)
 674 
 675     last_tensor = net.output
 676     last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(last_tensor)
 677     last_tensor = ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor)
 678     parcellation_masking_values = np.array([1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation])
 679     last_tensor = ConvertLabels(labels_segmentation, parcellation_masking_values)(last_tensor)
 680     last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=2, axis=-1))(last_tensor)
 681     last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor])
 682     net = keras.Model(inputs=net.inputs, outputs=last_tensor)
 683 
 684     # build UNet
 685     net = unet(nb_features=24,
 686                input_shape=[None, None, None, 3],
 687                nb_levels=5,
 688                conv_size=3,
 689                nb_labels=n_labels_parcellation,
 690                feat_mult=2,
 691                activation='elu',
 692                nb_conv_per_level=2,
 693                batch_norm=-1,
 694                name='unet_parc',
 695                input_model=net)
 696     net.load_weights(model_file_parcellation, by_name=True)
 697 
 698     # smooth predictions
 699     last_tensor = net.output
 700     last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
 701     last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
 702     net = keras.Model(inputs=net.inputs, outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor])
 703 
 704     return net
 705 
 706 def unet(nb_features,
 707          input_shape,
 708          nb_levels,
 709          conv_size,
 710          nb_labels,
 711          name='unet',
 712          prefix=None,
 713          feat_mult=1,
 714          pool_size=2,
 715          padding='same',
 716          dilation_rate_mult=1,
 717          activation='elu',
 718          skip_n_concatenations=0,
 719          use_residuals=False,
 720          final_pred_activation='softmax',
 721          nb_conv_per_level=1,
 722          layer_nb_feats=None,
 723          conv_dropout=0,
 724          batch_norm=None,
 725          input_model=None):
 726 
 727     # naming
 728     model_name = name
 729     if prefix is None:
 730         prefix = model_name
 731 
 732     # volume size data
 733     ndims = len(input_shape) - 1
 734     if isinstance(pool_size, int):
 735         pool_size = (pool_size,) * ndims
 736 
 737     # get encoding model
 738     enc_model = conv_enc(nb_features,
 739                          input_shape,
 740                          nb_levels,
 741                          conv_size,
 742                          name=model_name,
 743                          prefix=prefix,
 744                          feat_mult=feat_mult,
 745                          pool_size=pool_size,
 746                          padding=padding,
 747                          dilation_rate_mult=dilation_rate_mult,
 748                          activation=activation,
 749                          use_residuals=use_residuals,
 750                          nb_conv_per_level=nb_conv_per_level,
 751                          layer_nb_feats=layer_nb_feats,
 752                          conv_dropout=conv_dropout,
 753                          batch_norm=batch_norm,
 754                          input_model=input_model)
 755 
 756     # get decoder
 757     # use_skip_connections=True makes it a u-net
 758     lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
 759     dec_model = conv_dec(nb_features,
 760                          None,
 761                          nb_levels,
 762                          conv_size,
 763                          nb_labels,
 764                          name=model_name,
 765                          prefix=prefix,
 766                          feat_mult=feat_mult,
 767                          pool_size=pool_size,
 768                          use_skip_connections=True,
 769                          skip_n_concatenations=skip_n_concatenations,
 770                          padding=padding,
 771                          dilation_rate_mult=dilation_rate_mult,
 772                          activation=activation,
 773                          use_residuals=use_residuals,
 774                          final_pred_activation=final_pred_activation,
 775                          nb_conv_per_level=nb_conv_per_level,
 776                          batch_norm=batch_norm,
 777                          layer_nb_feats=lnf,
 778                          conv_dropout=conv_dropout,
 779                          input_model=enc_model)
 780     final_model = dec_model
 781 
 782     return final_model
 783 
 784 def conv_enc(nb_features,
 785              input_shape,
 786              nb_levels,
 787              conv_size,
 788              name=None,
 789              prefix=None,
 790              feat_mult=1,
 791              pool_size=2,
 792              dilation_rate_mult=1,
 793              padding='same',
 794              activation='elu',
 795              layer_nb_feats=None,
 796              use_residuals=False,
 797              nb_conv_per_level=2,
 798              conv_dropout=0,
 799              batch_norm=None,
 800              input_model=None):
 801 
 802     # naming
 803     model_name = name
 804     if prefix is None:
 805         prefix = model_name
 806 
 807     # first layer: input
 808     name = '%s_input' % prefix
 809     if input_model is None:
 810         input_tensor = KL.Input(shape=input_shape, name=name)
 811         last_tensor = input_tensor
 812     else:
 813         input_tensor = input_model.inputs
 814         last_tensor = input_model.outputs
 815         if isinstance(last_tensor, list):
 816             last_tensor = last_tensor[0]
 817 
 818     # volume size data
 819     ndims = len(input_shape) - 1
 820     if isinstance(pool_size, int):
 821         pool_size = (pool_size,) * ndims
 822 
 823     # prepare layers
 824     convL = getattr(KL, 'Conv%dD' % ndims)
 825     conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
 826     maxpool = getattr(KL, 'MaxPooling%dD' % ndims)
 827 
 828     # down arm:
 829     # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
 830     lfidx = 0  # level feature index
 831     for level in range(nb_levels):
 832         lvl_first_tensor = last_tensor
 833         nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
 834         conv_kwargs['dilation_rate'] = dilation_rate_mult ** level
 835 
 836         for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
 837             if layer_nb_feats is not None:  # None or List of all the feature numbers
 838                 nb_lvl_feats = layer_nb_feats[lfidx]
 839                 lfidx += 1
 840 
 841             name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
 842             if conv < (nb_conv_per_level - 1) or (not use_residuals):
 843                 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
 844             else:  # no activation
 845                 last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
 846 
 847             if conv_dropout > 0:
 848                 # conv dropout along feature space only
 849                 name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
 850                 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 851                 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
 852 
 853         if use_residuals:
 854             convarm_layer = last_tensor
 855 
 856             # the "add" layer is the original input
 857             # However, it may not have the right number of features to be added
 858             nb_feats_in = lvl_first_tensor.get_shape()[-1]
 859             nb_feats_out = convarm_layer.get_shape()[-1]
 860             add_layer = lvl_first_tensor
 861             if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
 862                 name = '%s_expand_down_merge_%d' % (prefix, level)
 863                 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
 864                 add_layer = last_tensor
 865 
 866                 if conv_dropout > 0:
 867                     name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv)
 868                     noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 869 
 870             name = '%s_res_down_merge_%d' % (prefix, level)
 871             last_tensor = KL.add([add_layer, convarm_layer], name=name)
 872 
 873             name = '%s_res_down_merge_act_%d' % (prefix, level)
 874             last_tensor = KL.Activation(activation, name=name)(last_tensor)
 875 
 876         if batch_norm is not None:
 877             name = '%s_bn_down_%d' % (prefix, level)
 878             last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
 879 
 880         # max pool if we're not at the last level
 881         if level < (nb_levels - 1):
 882             name = '%s_maxpool_%d' % (prefix, level)
 883             last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)
 884 
 885     # create the model and return
 886     model = keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
 887     return model
 888 
 889 
 890 def conv_dec(nb_features,
 891              input_shape,
 892              nb_levels,
 893              conv_size,
 894              nb_labels,
 895              name=None,
 896              prefix=None,
 897              feat_mult=1,
 898              pool_size=2,
 899              use_skip_connections=False,
 900              skip_n_concatenations=0,
 901              padding='same',
 902              dilation_rate_mult=1,
 903              activation='elu',
 904              use_residuals=False,
 905              final_pred_activation='softmax',
 906              nb_conv_per_level=2,
 907              layer_nb_feats=None,
 908              batch_norm=None,
 909              conv_dropout=0,
 910              input_model=None):
 911 
 912     # naming
 913     model_name = name
 914     if prefix is None:
 915         prefix = model_name
 916 
 917     # if using skip connections, make sure need to use them.
 918     if use_skip_connections:
 919         assert input_model is not None, "is using skip connections, tensors dictionary is required"
 920 
 921     # first layer: input
 922     input_name = '%s_input' % prefix
 923     if input_model is None:
 924         input_tensor = KL.Input(shape=input_shape, name=input_name)
 925         last_tensor = input_tensor
 926     else:
 927         input_tensor = input_model.input
 928         last_tensor = input_model.output
 929         input_shape = last_tensor.shape.as_list()[1:]
 930 
 931     # vol size info
 932     ndims = len(input_shape) - 1
 933     if isinstance(pool_size, int):
 934         if ndims > 1:
 935             pool_size = (pool_size,) * ndims
 936 
 937     # prepare layers
 938     convL = getattr(KL, 'Conv%dD' % ndims)
 939     conv_kwargs = {'padding': padding, 'activation': activation}
 940     upsample = getattr(KL, 'UpSampling%dD' % ndims)
 941 
 942     # up arm:
 943     # nb_levels - 1 layers of Deconvolution3D
 944     #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
 945     lfidx = 0
 946     for level in range(nb_levels - 1):
 947         nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
 948         conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)
 949 
 950         # upsample matching the max pooling layers size
 951         name = '%s_up_%d' % (prefix, nb_levels + level)
 952         last_tensor = upsample(size=pool_size, name=name)(last_tensor)
 953         up_tensor = last_tensor
 954 
 955         # merge layers combining previous layer
 956         if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
 957             conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
 958             cat_tensor = input_model.get_layer(conv_name).output
 959             name = '%s_merge_%d' % (prefix, nb_levels + level)
 960             last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)
 961 
 962         # convolution layers
 963         for conv in range(nb_conv_per_level):
 964             if layer_nb_feats is not None:
 965                 nb_lvl_feats = layer_nb_feats[lfidx]
 966                 lfidx += 1
 967 
 968             name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
 969             if conv < (nb_conv_per_level - 1) or (not use_residuals):
 970                 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
 971             else:
 972                 last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
 973 
 974             if conv_dropout > 0:
 975                 name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
 976                 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 977                 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
 978 
 979         # residual block
 980         if use_residuals:
 981 
 982             # the "add" layer is the original input
 983             # However, it may not have the right number of features to be added
 984             add_layer = up_tensor
 985             nb_feats_in = add_layer.get_shape()[-1]
 986             nb_feats_out = last_tensor.get_shape()[-1]
 987             if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
 988                 name = '%s_expand_up_merge_%d' % (prefix, level)
 989                 add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)
 990 
 991                 if conv_dropout > 0:
 992                     name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv)
 993                     noise_shape = [None, *[1] * ndims, nb_lvl_feats]
 994                     last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
 995 
 996             name = '%s_res_up_merge_%d' % (prefix, level)
 997             last_tensor = KL.add([last_tensor, add_layer], name=name)
 998 
 999             name = '%s_res_up_merge_act_%d' % (prefix, level)
1000             last_tensor = KL.Activation(activation, name=name)(last_tensor)
1001 
1002         if batch_norm is not None:
1003             name = '%s_bn_up_%d' % (prefix, level)
1004             last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
1005 
1006     # Compute likelyhood prediction (no activation yet)
1007     name = '%s_likelihood' % prefix
1008     last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
1009     like_tensor = last_tensor
1010 
1011     # output prediction layer
1012     # we use a softmax to compute P(L_x|I) where x is each location
1013     if final_pred_activation == 'softmax':
1014         name = '%s_prediction' % prefix
1015         softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
1016         pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)
1017 
1018     # otherwise create a layer that does nothing.
1019     else:
1020         name = '%s_prediction' % prefix
1021         pred_tensor = KL.Activation('linear', name=name)(like_tensor)
1022 
1023     # create the model and retun
1024     model = keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
1025     return model
1026 
def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx,
                labels_segmentation, labels_parcellation, aff, im_res):
    """Turn network posteriors into a label map, posteriors and structure volumes.

    Args:
        post_patch_seg: segmentation posteriors of the padded patch. Channel 0 is
            background and channel ``k`` corresponds to ``labels_segmentation[k]``.
        post_patch_parc: parcellation posteriors of the same padded patch. Channel
            0 is reset to "not cortex" and channel ``k`` corresponds to
            ``labels_parcellation[k]``.
        shape: spatial shape of the volume the patch was cropped from.
        pad_idx: ``[min0, min1, min2, max0, max1, max2]`` removing network padding.
        crop_idx: same layout, locating the patch inside ``shape``, or None if the
            patch is the whole volume.
        labels_segmentation, labels_parcellation: numpy arrays of label values.
        aff: affine passed as ``aff_ref`` to ``align_volume_to_ref`` to restore
            the input orientation.
        im_res: voxel size(s); volumes are voxel counts times ``prod(im_res)``.

    Returns:
        seg: int32 label map in which cortex voxels (labels 3 / 42) carry
            parcellation labels.
        posteriors: segmentation posteriors in the input orientation.
        volumes: ``[total, one per non-background segmentation channel,
            one per non-background parcellation channel]``, rounded to 3 decimals.
    """

    # get posteriors
    post_patch_seg = np.squeeze(post_patch_seg)
    post_patch_seg = crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False)

    # keep biggest connected component of the foreground (sum of non-background posteriors > 0.25)
    tmp_post_patch_seg = post_patch_seg[..., 1:]
    post_patch_seg_mask = np.sum(tmp_post_patch_seg, axis=-1) > 0.25
    post_patch_seg_mask = get_largest_connected_component(post_patch_seg_mask)
    post_patch_seg_mask = np.stack([post_patch_seg_mask]*tmp_post_patch_seg.shape[-1], axis=-1)
    tmp_post_patch_seg = mask_volume(tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False)
    post_patch_seg[..., 1:] = tmp_post_patch_seg

    # zero out weak foreground posteriors (<= 0.2); the background channel is left untouched
    post_patch_seg_mask = post_patch_seg > 0.2
    post_patch_seg[..., 1:] *= post_patch_seg_mask[..., 1:]

    # get hard segmentation
    post_patch_seg /= np.sum(post_patch_seg, axis=-1)[..., np.newaxis]
    seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype('int32')].astype('int32')

    # postprocess parcellation: channel 0 becomes 1 outside the cortex and 0 inside it,
    # so cortex voxels are forced onto one of the parcellation labels
    post_patch_parc = np.squeeze(post_patch_parc)
    post_patch_parc = crop_volume_with_idx(post_patch_parc, pad_idx, n_dims=3, return_copy=False)
    mask = (seg_patch == 3) | (seg_patch == 42)
    post_patch_parc[..., 0] = np.ones_like(post_patch_parc[..., 0])
    post_patch_parc[..., 0] = mask_volume(post_patch_parc[..., 0], mask=mask < 0.1, return_copy=False)
    post_patch_parc /= np.sum(post_patch_parc, axis=-1)[..., np.newaxis]
    parc_patch = labels_parcellation[post_patch_parc.argmax(-1).astype('int32')].astype('int32')
    seg_patch[mask] = parc_patch[mask]

    # paste patches back to matrix of original image size
    if crop_idx is not None:
        # we need to go through this because of the posteriors of the background, otherwise pad_volume would work
        seg = np.zeros(shape=shape, dtype='int32')
        posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]])
        posteriors[..., 0] = np.ones(shape)  # place background around patch
        seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch
        posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_seg
    else:
        seg = seg_patch
        posteriors = post_patch_seg

    # align prediction back to first orientation
    seg = align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
    posteriors = align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)

    # compute volumes: volumes[0] is the total, then volumes[k] matches labels_segmentation[k] (k >= 1)
    volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
    volumes = np.concatenate([np.array([np.sum(volumes)]), volumes])
    if post_patch_parc is not None:
        volumes_parc = np.sum(post_patch_parc[..., 1:], axis=tuple(range(0, len(post_patch_parc.shape) - 1)))
        # the total is already prepended, so label index k maps straight to volumes[k] (no "- 1")
        cortex_idx = np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0]
        total_volume_cortex = np.sum(volumes[cortex_idx])
        volumes_parc = volumes_parc / np.sum(volumes_parc) * total_volume_cortex
        volumes = np.concatenate([volumes, volumes_parc])
    volumes = np.around(volumes * np.prod(im_res), 3)

    return seg, posteriors, volumes
1087 
def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
    """Save a volume as compressed ``.npz`` or as a NIfTI image.

    Args:
        volume: array to save.
        aff: voxel-to-world affine, ``None`` (identity) or the string ``'FS'``
            (a fixed FreeSurfer-style orientation matrix). Unused for ``.npz``.
        header: nibabel header to reuse, or None for a new ``Nifti1Header``.
        path: output path; any path containing ``'.npz'`` is saved with numpy.
            Missing parent directories are created.
        res: optional voxel size(s) written to the header zooms.
        dtype: optional output data type (string or numpy dtype). Integer types
            are rounded before casting.
        n_dims: number of spatial dims used to expand ``res``; inferred from
            the volume when None.
    """
    mkdir(os.path.dirname(path))
    if '.npz' in path:
        np.savez_compressed(path, vol_data=volume)
        return

    if header is None:
        header = nib.Nifti1Header()
    if isinstance(aff, str):
        if aff == 'FS':
            aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    elif aff is None:
        aff = np.eye(4)

    # round/cast BEFORE building the image: the image wraps whatever array it is given,
    # so casting afterwards only changed a local copy that was never written
    if dtype is not None:
        if np.issubdtype(np.dtype(dtype), np.integer):
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)
    nifty = nib.Nifti1Image(volume, aff, header)
    if dtype is not None:
        nifty.set_data_dtype(dtype)

    if res is not None:
        if n_dims is None:
            n_dims, _ = get_dims(volume.shape)
        res = reformat_to_list(res, length=n_dims, dtype=None)
        nifty.header.set_zooms(res)
    nib.save(nifty, path)
1112 
1113 
1114 
def mkdir(path_dir):
    """Create a directory and any missing parents; no-op if it already exists.

    Args:
        path_dir: directory path (a trailing separator is fine). An empty string
            does nothing, so callers can pass ``os.path.dirname`` of a bare
            file name.
    """
    if not path_dir:
        return
    # os.makedirs handles relative paths: a manual walk up via os.path.dirname never
    # terminates there, because dirname eventually yields '' and isdir('') is False
    os.makedirs(path_dir, exist_ok=True)
1126 
1127 
def getM(ref, mov):
    """Least-squares affine transform mapping the points ``ref`` onto ``mov``.

    Args:
        ref: array of shape (3, N) holding N source points as columns.
        mov: array with N columns whose first three rows are the target
            coordinates (any further rows are ignored).

    Returns:
        4x4 float array ``M`` minimising ``||M[:3, :3] @ ref + M[:3, 3:] - mov[:3]||``.
    """
    ref = np.asarray(ref, dtype=float)
    mov = np.asarray(mov, dtype=float)
    # every output coordinate shares the design matrix [x y z 1]; lstsq avoids
    # forming inv(A^T A), which squares the condition number and fails on
    # degenerate (e.g. coplanar) point sets
    design = np.concatenate([ref.T, np.ones((ref.shape[1], 1))], axis=1)
    params, *_ = np.linalg.lstsq(design, mov[:3].T, rcond=None)
    M = np.eye(4)
    M[:3, :3] = params[:3].T
    M[:3, 3] = params[3]
    return M
1145 
def getMrigid(A, B):
    """Best-fit rigid transform (rotation + translation) taking points A to B.

    Kabsch/SVD solution: the cross-covariance of the centred point sets is
    decomposed, and a reflection is corrected by negating the last right
    singular vector.

    Args:
        A, B: arrays of shape (3, N) with corresponding points as columns.

    Returns:
        4x4 homogeneous matrix T with ``T[:3, :3] @ A + T[:3, 3:] ≈ B``.
    """
    mean_a = np.mean(A, axis=1, keepdims=True)
    mean_b = np.mean(B, axis=1, keepdims=True)
    cross_cov = (A - mean_a) @ (B - mean_b).T
    u, _, vt = np.linalg.svd(cross_cov)
    rotation = vt.T @ u.T
    if np.linalg.det(rotation) < 0:
        # improper solution (reflection): flip the least significant direction
        vt[2, :] = -vt[2, :]
        rotation = vt.T @ u.T
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = (mean_b - rotation @ mean_a).flatten()
    return transform
1162 
1163 
def fast_3D_interp_torch(X, II, JJ, KK, mode):
    """Sample a 3D volume at (possibly fractional) voxel coordinates.

    Args:
        X: tensor indexed as ``X[i, j, k]``.
        II, JJ, KK: same-shaped tensors of coordinates along the three axes of X.
        mode: ``'nearest'`` rounds and clamps coordinates into the volume.
            ``'linear'`` does trilinear interpolation; points outside
            ``[0, size - 1]`` on any axis get 0.

    Returns:
        Tensor shaped like ``II``. In ``'nearest'`` mode it has X's dtype; in
        ``'linear'`` mode it is float32 on the device of ``II``.
    """
    if mode == 'nearest':
        # round to the closest voxel, then clamp so every lookup is in range
        IIr = torch.clamp(torch.round(II).long(), 0, X.shape[0] - 1)
        JJr = torch.clamp(torch.round(JJ).long(), 0, X.shape[1] - 1)
        KKr = torch.clamp(torch.round(KK).long(), 0, X.shape[2] - 1)
        Y = X[IIr, JJr, KKr]
    elif mode == 'linear':
        # valid domain is [0, size - 1], inclusive at BOTH ends: a strict "> 0" would
        # zero every sample lying exactly on the first voxel plane of an axis
        ok = (II >= 0) & (JJ >= 0) & (KK >= 0) & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1)
        IIv = II[ok]
        JJv = JJ[ok]
        KKv = KK[ok]

        # floor / ceil neighbours (ceil clamped at the far border) and their weights
        fx = torch.floor(IIv).long()
        cx = torch.clamp(fx + 1, max=X.shape[0] - 1)
        wcx = IIv - fx
        wfx = 1 - wcx

        fy = torch.floor(JJv).long()
        cy = torch.clamp(fy + 1, max=X.shape[1] - 1)
        wcy = JJv - fy
        wfy = 1 - wcy

        fz = torch.floor(KKv).long()
        cz = torch.clamp(fz + 1, max=X.shape[2] - 1)
        wcz = KKv - fz
        wfz = 1 - wcz

        c000 = X[fx, fy, fz]
        c100 = X[cx, fy, fz]
        c010 = X[fx, cy, fz]
        c110 = X[cx, cy, fz]
        c001 = X[fx, fy, cz]
        c101 = X[cx, fy, cz]
        c011 = X[fx, cy, cz]
        c111 = X[cx, cy, cz]

        c00 = c000 * wfx + c100 * wcx
        c01 = c001 * wfx + c101 * wcx
        c10 = c010 * wfx + c110 * wcx
        c11 = c011 * wfx + c111 * wcx

        c0 = c00 * wfy + c10 * wcy
        c1 = c01 * wfy + c11 * wcy

        c = c0 * wfz + c1 * wcz

        Y = torch.zeros(II.shape, device=II.device)
        Y[ok] = c.float()

    else:
        sf.system.fatal('mode must be linear or nearest')

    return Y
1226 
1227 
1228 
def fast_3D_interp_field_torch(X, II, JJ, KK):
    """Trilinearly sample a multi-channel 3D field at fractional voxel coordinates.

    Args:
        X: tensor indexed as ``X[i, j, k, channel]``.
        II, JJ, KK: same-shaped tensors (any shape) of coordinates along the
            first three axes of ``X``.

    Returns:
        float32 tensor of shape ``(*II.shape, X.shape[3])`` on the device of
        ``II``. Points outside ``[0, size - 1]`` on any axis are 0.
    """
    # valid domain is inclusive at both ends (index 0 is a real voxel)
    ok = (II >= 0) & (JJ >= 0) & (KK >= 0) & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1)
    IIv = II[ok]
    JJv = JJ[ok]
    KKv = KK[ok]

    # neighbour indices; weights get a trailing axis so they broadcast over channels
    fx = torch.floor(IIv).long()
    cx = torch.clamp(fx + 1, max=X.shape[0] - 1)
    wcx = (IIv - fx)[:, None]
    wfx = 1 - wcx

    fy = torch.floor(JJv).long()
    cy = torch.clamp(fy + 1, max=X.shape[1] - 1)
    wcy = (JJv - fy)[:, None]
    wfy = 1 - wcy

    fz = torch.floor(KKv).long()
    cz = torch.clamp(fz + 1, max=X.shape[2] - 1)
    wcz = (KKv - fz)[:, None]
    wfz = 1 - wcz

    # gather all channels at once: each corner is (n_valid_points, n_channels)
    c000 = X[fx, fy, fz]
    c100 = X[cx, fy, fz]
    c010 = X[fx, cy, fz]
    c110 = X[cx, cy, fz]
    c001 = X[fx, fy, cz]
    c101 = X[cx, fy, cz]
    c011 = X[fx, cy, cz]
    c111 = X[cx, cy, cz]

    c00 = c000 * wfx + c100 * wcx
    c01 = c001 * wfx + c101 * wcx
    c10 = c010 * wfx + c110 * wcx
    c11 = c011 * wfx + c111 * wcx

    c0 = c00 * wfy + c10 * wcy
    c1 = c01 * wfy + c11 * wcy

    c = c0 * wfz + c1 * wcz

    Y = torch.zeros([*II.shape, X.shape[3]], device=II.device)
    Y[ok] = c.float()
    return Y
1284 
1285 
def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True):
    """Crop the leading spatial axes of a volume with explicit bounds.

    Args:
        volume: array whose first ``n_dims`` axes are spatial; trailing axes are kept.
        crop_idx: ``[min_0, ..., min_{n-1}, max_0, ..., max_{n-1}]`` (max exclusive).
        aff: optional affine. It is updated IN PLACE so its origin moves to the
            first kept voxel, and it is returned as well.
        n_dims: number of spatial dims (2 or 3); defaults to ``len(crop_idx) // 2``.
        return_copy: if False, crop a view of ``volume`` instead of a copy.

    Returns:
        The cropped volume, or ``(cropped_volume, aff)`` when ``aff`` is given.
    """

    # get info
    new_volume = volume.copy() if return_copy else volume
    n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims

    # crop image
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...]
    else:
        sf.system.fatal('cannot crop volumes with more than 3 dimensions')

    if aff is not None:
        # only the n_dims minimum indices and spatial columns are involved; a fixed
        # crop_idx[:3] would mix min and max bounds for 2D crops
        min_idx = np.asarray(crop_idx[:n_dims])
        aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :n_dims] @ min_idx
        return new_volume, aff
    else:
        return new_volume
1305 
1306 
def get_largest_connected_component(mask, structure=None):
    """Keep only the biggest connected component of a binary mask.

    Args:
        mask: binary array.
        structure: optional connectivity structure passed to ``scipy.ndimage.label``.

    Returns:
        Boolean array that is True only on the largest component (on ties, the
        lowest label wins), or a copy of ``mask`` when it has no foreground.
    """
    labelled, n_found = scipy_label(mask, structure)
    if n_found == 0:
        return mask.copy()
    sizes = np.bincount(labelled.ravel())[1:]  # drop the background count
    return labelled == 1 + np.argmax(sizes)
1310 
1311 
def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0,
                return_mask=False, return_copy=True):
    """Set every voxel outside a (optionally morphologically processed) mask to a constant.

    Args:
        volume: array to mask; may have a trailing channel axis.
        mask: optional mask whose leading (spatial) shape matches the volume;
            values > 0 count as inside. If None, the mask is ``volume >= threshold``.
        threshold: threshold used when ``mask`` is None.
        dilate, erode: radii of ball-shaped dilation / erosion applied in that
            order (0 skips the step).
        fill_holes: fill holes of the final mask.
        masking_value: value written outside the mask.
        return_mask: also return the mask that was applied.
        return_copy: if False, ``volume`` is modified in place.

    Returns:
        The masked volume, or ``(masked_volume, applied_mask)`` when requested.
    """
    out = volume.copy() if return_copy else volume
    shape = list(out.shape)
    n_dims, n_channels = get_dims(shape)

    if mask is not None:
        spatial = shape[:n_dims]
        assert list(mask.shape[:n_dims]) == spatial, 'mask should have shape {0}, or {1}, had {2}'.format(
            spatial, spatial + [n_channels], list(mask.shape))
        mask = mask > 0
    else:
        mask = out >= threshold

    final_mask = binary_dilation(mask, build_binary_structure(dilate, n_dims)) if dilate > 0 else mask
    if erode > 0:
        final_mask = binary_erosion(final_mask, build_binary_structure(erode, n_dims))
    if fill_holes:
        final_mask = binary_fill_holes(final_mask)

    # a spatial-only mask is repeated along the channel axis before indexing
    outside = np.logical_not(final_mask)
    if final_mask.shape != out.shape:
        outside = np.stack([outside] * n_channels, axis=-1)
    out[outside] = masking_value

    return (out, final_mask) if return_mask else out
1348 
1349 
def build_binary_structure(connectivity, n_dims, shape=None):
    """Create a ball-shaped binary structuring element.

    Args:
        connectivity: radius; elements whose Euclidean distance to the centre
            is at most this value are set.
        n_dims: number of dimensions.
        shape: optional explicit shape, normalised with ``reformat_to_list`` to
            ``n_dims`` entries; defaults to ``2 * connectivity + 1`` per axis.

    Returns:
        Integer (0/1) array of the requested shape.
    """
    if shape is not None:
        shape = reformat_to_list(shape, length=n_dims)
    else:
        shape = [connectivity * 2 + 1] * n_dims
    grid = np.ones(shape)
    grid[tuple(int(size / 2) for size in shape)] = 0
    distances = distance_transform_edt(grid)
    return (distances <= connectivity) * 1
1361 
1362 
1363 
def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'):
    """Crop the spatial axes of a volume, either by a margin or to a target shape.

    Exactly one of ``cropping_margin`` / ``cropping_shape`` must be given. Only
    2D and 3D volumes are actually cropped.

    Args:
        volume: array whose leading axes are spatial (a trailing channel axis is kept).
        cropping_margin: voxels removed from each side of every spatial axis. An
            axis is left untouched if it is not longer than twice its margin.
        cropping_shape: target spatial shape (clipped to the volume size).
        aff: optional affine, updated IN PLACE so its origin follows the crop.
        return_crop_idx: also return ``[min_0, ..., max_0, ...]`` crop indices.
        mode: placement when ``cropping_shape`` is used: ``'center'`` or ``'random'``.

    Returns:
        The cropped volume, followed by the affine and/or the crop indices when
        requested.

    Raises:
        ValueError: for an unknown ``mode``.
    """

    assert (cropping_margin is not None) | (cropping_shape is not None), \
        'cropping_margin or cropping_shape should be provided'
    assert not ((cropping_margin is not None) & (cropping_shape is not None)), \
        'only one of cropping_margin or cropping_shape should be provided'

    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)

    # find cropping indices
    if cropping_margin is not None:
        cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
        do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
        min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
        max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
    else:
        cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
        if mode == 'center':
            min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
            max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
                                      np.array(vol_shape)[:n_dims])
        elif mode == 'random':
            crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
            min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
            max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
        else:
            raise ValueError('mode should be either "center" or "random", had %s' % mode)
    crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])

    # crop volume
    if n_dims == 2:
        new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...]
    elif n_dims == 3:
        new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...]

    # sort outputs
    output = [new_volume]
    if aff is not None:
        # move the origin to the first kept voxel using only the n_dims spatial columns,
        # so 2D volumes no longer hit a 3x3 @ 2-vector shape mismatch
        aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :n_dims] @ np.array(min_crop_idx)
        output.append(aff)
    if return_crop_idx:
        output.append(crop_idx)
    return output[0] if len(output) == 1 else tuple(output)
1408 
1409 
1410 
def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2., max_percentile=98., use_positive_only=False):
    """Linearly rescale the intensities of a volume to ``[new_min, new_max]`` using robust bounds.

    The lower/upper bounds are the ``min_percentile`` / ``max_percentile`` percentiles of the
    intensities (exact min / max when the percentile is 0 / 100). Values outside the bounds
    are clipped before rescaling.

    :param volume: numpy array to rescale (not modified).
    :param new_min: output value assigned to the lower robust bound.
    :param new_max: output value assigned to the upper robust bound.
    :param min_percentile: percentile used for the lower bound.
    :param max_percentile: percentile used for the upper bound.
    :param use_positive_only: if True, the bounds are computed from strictly positive voxels
        only (every voxel is still clipped and rescaled).
    :return: the rescaled array. When there is no intensity range to map (constant image, or
        no positive voxel with ``use_positive_only``), every voxel is set to ``new_min``.
    """
    # select only positive intensities
    new_volume = volume.copy()
    intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten()

    # nothing to compute the bounds from (np.min / np.percentile would raise on an empty array)
    if intensities.size == 0:
        return np.zeros_like(new_volume, dtype=float) + new_min

    # define min and max intensities in original image for normalisation
    robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile)
    robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile)

    # trim values outside range
    new_volume = np.clip(new_volume, robust_min, robust_max)

    # rescale image
    if robust_min != robust_max:
        return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min)
    # constant image: avoid dividing by zero and map everything to the bottom of the output range
    return np.zeros_like(new_volume) + new_min
1429 
1430 
1431 
1432 
def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False):
    """Pad the spatial axes of a volume up to a target shape.

    For each spatial axis the missing voxels are split between both sides (``floor`` of half
    the difference before, ``ceil`` after). Axes already at least as large as the target get
    no padding, and trailing non-spatial axes (e.g. channels) are never padded.

    :param volume: numpy array; its number of spatial dimensions is given by ``get_dims``.
    :param padding_shape: target spatial shape (one value, or one per spatial axis).
    :param padding_value: constant written into the padded voxels.
    :param aff: (optional) voxel-to-world affine (assumed 4x4). A copy with an updated
        translation is returned; the caller's array is not modified.
    :param return_pad_idx: if True, also return the position of the original volume inside
        the padded one as ``[start_0, ..., start_{n-1}, stop_0, ..., stop_{n-1}]``.
    :return: the padded volume, followed by the affine (if ``aff`` was given) and the padding
        indices (if ``return_pad_idx``); a single array if neither was requested.
    """
    # get info
    new_volume = volume.copy()
    vol_shape = new_volume.shape
    n_dims, _ = get_dims(vol_shape)
    padding_shape = reformat_to_list(padding_shape, length=n_dims, dtype='int')

    # copy so that the caller's affine is never modified in place
    if aff is not None:
        aff = np.array(aff, copy=True)

    # check if need to pad
    if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')):

        # get padding margins (negative differences, i.e. axes already big enough, give 0)
        shape_diff = np.array(padding_shape) - np.array(vol_shape)[:n_dims]
        min_margins = np.maximum(np.int32(np.floor(shape_diff / 2)), 0)
        max_margins = np.maximum(np.int32(np.ceil(shape_diff / 2)), 0)
        pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])])
        pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)])
        # np.pad needs one (before, after) pair per axis: leave every trailing axis unpadded
        # (this also covers a singleton channel axis, which was previously not handled)
        pad_margins = pad_margins + ((0, 0),) * (len(vol_shape) - n_dims)

        # pad volume
        new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value)

        if aff is not None:
            # the affine acts on 3 spatial axes: zero-pad the voxel shift for 2D volumes
            if n_dims == 2:
                min_margins = np.append(min_margins, 0)
            aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins

    else:
        pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])])

    # sort outputs
    output = [new_volume]
    if aff is not None:
        output.append(aff)
    if return_pad_idx:
        output.append(pad_idx)
    return output[0] if len(output) == 1 else tuple(output)
1469 
1470 
1471 
def add_axis(x, axis=0):
    """Insert one or several singleton axes into an array.

    :param x: input array (anything accepted by ``np.expand_dims``).
    :param axis: index, or list of indices, at which new axes are inserted; they are
        applied one after the other, in the given order.
    :return: the array with the extra singleton dimensions.
    """
    expanded = x
    for new_axis in reformat_to_list(axis):
        expanded = np.expand_dims(expanded, axis=new_axis)
    return expanded
1477 
1478 
def volshape_to_meshgrid(volshape, **kwargs):
    """Compute coordinate grids (one TensorFlow tensor per axis) spanning a volume.

    :param volshape: sequence of integer-valued sizes, one per dimension.
    :param kwargs: forwarded to ``meshgrid`` (which only accepts ``indexing``).
    :return: list of coordinate tensors, as returned by ``meshgrid``.
    :raises ValueError: if an entry of ``volshape`` is not an integer value.
    """
    integer_flags = [float(size).is_integer() for size in volshape]
    if False in integer_flags:
        raise ValueError("volshape needs to be a list of integers")

    axis_ranges = [tf.range(0, size) for size in volshape]
    return meshgrid(*axis_ranges, **kwargs)
1490 
1491 
def meshgrid(*args, **kwargs):
    """Broadcast 1-D coordinate vectors into N-D coordinate grids (a faster ``tf.meshgrid``).

    Behaves like ``tf.meshgrid`` but expands each vector with ``tf.tile`` instead of
    multiplying it by a ``tf.ones`` tensor (the original comment reports a ~6x speedup).

    :param args: 1-D tensors whose static length is known (``get_shape`` is used to read it).
    :param kwargs: only ``indexing`` is accepted, either ``"xy"`` (default) or ``"ij"``.
    :return: list with one tensor per input, each of the full grid shape.
    :raises TypeError: if any other keyword argument is given.
    :raises ValueError: if ``indexing`` is not ``"xy"`` or ``"ij"``.
    """
    indexing = kwargs.pop("indexing", "xy")
    if kwargs:
        key = list(kwargs.keys())[0]
        raise TypeError("'{}' is an invalid keyword argument "
                        "for this function".format(key))

    if indexing not in ("xy", "ij"):
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")

    ndim = len(args)
    s0 = (1,) * ndim

    # reshape every vector so that it only varies along its own axis (size 1 elsewhere)
    output = []
    for i, x in enumerate(args):
        output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))

    # static length of each input, used as tiling multiples below
    sz = [x.get_shape().as_list()[0] for x in args]

    # "xy" (cartesian) indexing swaps the roles of the first two axes
    if indexing == "xy" and ndim > 1:
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
        sz[0], sz[1] = sz[1], sz[0]

    # tile each reshaped vector along all the other axes to reach the full grid shape
    for i in range(len(output)):
        stack_sz = [*sz[:i], 1, *sz[(i + 1):]]
        if indexing == 'xy' and ndim > 1 and i < 2:
            stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0]
        output[i] = tf.tile(output[i], tf.stack(stack_sz))
    return output
1532 
1533 
def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build Gaussian blurring kernel(s) as TensorFlow tensors.

    :param sigma: standard deviation(s) of the Gaussian: a number or list (one value per
        dimension), or a tensor. A tensor whose first static dimension is unknown is treated as
        batched (presumably shape (batch, n_dims) - TODO confirm against callers).
    :param max_sigma: upper bound on sigma used to size the kernel window; required when sigma
        is a tensor, defaults to sigma otherwise.
    :param blur_range: if not None and != 1, sigma is multiplied by a random factor drawn
        uniformly in [1/blur_range, blur_range].
    :param separable: if True, return a list with one 1-D kernel per dimension (None where the
        window size is 1); otherwise return a single N-D kernel.
    :return: kernel tensor(s), each with two trailing singleton axes so they can be used as
        ``tf.nn.convNd`` filters.
    """
    # combinations is needed by the separable branch; import it locally because the file's
    # top-level import block does not bring it into scope (NameError otherwise)
    from itertools import combinations

    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize (an unknown leading dimension means a batch of sigmas)
    if shape[0] is not None:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(reformat_to_list(sigma, length=n_dims))

    # randomise the blurring std dev
    if blur_range is not None:
        if blur_range != 1:
            sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)

    # get size of blurring kernels: odd window of roughly 2.5 * max_sigma voxels
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        # comb[i] lists every axis except i: the 1-D kernel of axis i is expanded along them
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    # account for the extra leading batch axis
                    comb[i] += 1

                # compute gaussians
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                g = g / tf.reduce_sum(g)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians (a zero sigma is replaced by 1 to avoid dividing by zero)
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        kernels /= tf.reduce_sum(kernels)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
1622 
1623 
def get_mapping_lut(source, dest=None):
    """Return a look-up table mapping each value of ``source`` to the matching value of ``dest``.

    The table is an int32 array of length ``max(source) + 1`` with ``lut[source[k]] == dest[k]``;
    positions not listed in ``source`` hold 0. If ``dest`` is omitted it defaults to
    ``[0, ..., N-1]`` with N = len(source). When a source value is repeated, its last
    occurrence wins.
    """
    src_values = np.array(reformat_to_list(source), dtype='int32')
    n_values = src_values.shape[0]

    if dest is not None:
        assert len(src_values) == len(dest), 'label_list and new_label_list should have the same length'
        dst_values = np.array(reformat_to_list(dest, dtype='int'))
    else:
        dst_values = np.arange(n_values, dtype='int32')

    lut = np.zeros(np.max(src_values) + 1, dtype='int32')
    for src_value, dst_value in zip(src_values, dst_values):
        lut[src_value] = dst_value
    return lut
1645 
1646 
class GaussianBlur(KL.Layer):
    """Keras layer that applies a Gaussian blur to an image, one channel at a time.

    With ``use_mask=True`` the layer takes ``[image, mask]`` as input and performs a normalised
    blur: after each convolution the image is divided by the identically blurred mask, and
    voxels outside the mask are set to zero.

    :param sigma: standard deviation(s) of the Gaussian (one value, or one per spatial axis);
        must be >= 0.
    :param random_blur_range: if not None, kernels are rebuilt at every call with sigma
        randomly rescaled (see gaussian_kernel's ``blur_range``).
    :param use_mask: whether a mask is given as second input.
    """

    def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
        self.sigma = reformat_to_list(sigma)
        assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0'
        self.use_mask = use_mask

        # the attributes below are filled in by build()
        self.n_dims = None
        self.n_channels = None
        self.blur_range = random_blur_range
        self.stride = None
        self.separable = None
        self.kernels = None
        self.convnd = None
        super(GaussianBlur, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config.update({"sigma": self.sigma,
                       "random_blur_range": self.blur_range,
                       "use_mask": self.use_mask})
        return config

    def build(self, input_shape):
        # with a mask, the first input is the image whose shape drives everything else
        if self.use_mask:
            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
            image_shape = input_shape[0]
        else:
            image_shape = input_shape
        self.n_dims = len(image_shape) - 2
        self.n_channels = image_shape[-1]

        self.stride = [1] * (self.n_dims + 2)
        self.sigma = reformat_to_list(self.sigma, length=self.n_dims)
        # large blurs are applied as a sequence of 1-D convolutions
        self.separable = np.linalg.norm(np.array(self.sigma)) > 5
        # fixed kernels are built once; random ones are rebuilt in call()
        self.kernels = gaussian_kernel(self.sigma, separable=self.separable) if self.blur_range is None else None

        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)

        self.built = True
        super(GaussianBlur, self).build(input_shape)

    def _blur_channels(self, tensor, kernel):
        """Convolve every channel of ``tensor`` separately with ``kernel`` and re-stack them."""
        return tf.concat([self.convnd(tf.expand_dims(tensor[..., channel], -1), kernel, self.stride, 'SAME')
                          for channel in range(self.n_channels)], -1)

    def _normalise_by_mask(self, image, mask, kernel):
        """Divide ``image`` by the blurred mask and zero every voxel outside the mask."""
        # NOTE(review): the mask is indexed per image channel, so it presumably has the same
        # number of channels as the image - confirm against callers
        blurred_mask = self._blur_channels(tf.cast(mask, 'float32'), kernel)
        image = image / (blurred_mask + keras.backend.epsilon())
        return tf.where(mask, image, tf.zeros_like(image))

    def call(self, inputs, **kwargs):
        if self.use_mask:
            image, mask = inputs[0], tf.cast(inputs[1], 'bool')
        else:
            image, mask = inputs, None

        # draw new kernels at each call when random blurring is enabled
        if self.blur_range is not None:
            self.kernels = gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)

        # separable: one pass per axis whose kernel exists; otherwise a single N-D pass
        if self.separable:
            passes = [kernel for kernel in self.kernels if kernel is not None]
        else:
            passes = [self.kernels] if any(self.sigma) else []

        for kernel in passes:
            image = self._blur_channels(image, kernel)
            if self.use_mask:
                image = self._normalise_by_mask(image, mask, kernel)

        return image
1733 
1734 
class ConvertLabels(KL.Layer):
    """Keras layer that remaps integer label values through a look-up table.

    Each occurrence of ``source_values[k]`` in the input becomes ``dest_values[k]`` (or ``k``
    when ``dest_values`` is None); the table is built by ``get_mapping_lut``.
    NOTE(review): input values above ``max(source_values)`` fall outside the table, and
    ``tf.gather`` then behaves device-dependently - confirm inputs are bounded.
    """

    def __init__(self, source_values, dest_values=None, **kwargs):
        self.source_values = source_values
        self.dest_values = dest_values
        self.lut = None  # built in build()
        super(ConvertLabels, self).__init__(**kwargs)

    def get_config(self):
        config = super().get_config()
        config.update(source_values=self.source_values, dest_values=self.dest_values)
        return config

    def build(self, input_shape):
        mapping = get_mapping_lut(self.source_values, dest=self.dest_values)
        self.lut = tf.convert_to_tensor(mapping, dtype='int32')
        self.built = True
        super(ConvertLabels, self).build(input_shape)

    def call(self, inputs, **kwargs):
        label_indices = tf.cast(inputs, dtype='int32')
        return tf.gather(self.lut, label_indices)
1756 
1757 
1758 
1759 
# Run the command-line entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Attached Files

To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily.

You are not allowed to attach a file to this page.