Attachment 'mri_easyatlas.py'
1 import os
2 import argparse
3 import numpy as np
4 import voxelmorph as vxm
5 import torch
6 import surfa as sf
7 import nibabel as nib
8 import glob
9 from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, distance_transform_edt, binary_fill_holes
10 from scipy.ndimage import label as scipy_label
11
12 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13 os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
14 import tensorflow as tf
15 import keras
16 import keras.backend as K
17 import keras.layers as KL
18
19
20
21 # set tensorflow logging
22 tf.get_logger().setLevel('ERROR')
23 K.set_image_data_format('channels_last')
24
25
26 def main():
27
28 parser = argparse.ArgumentParser(description="EasyAtlas: fast atlas construction with EasyReg", epilog='\n')
29
30 # input/outputs
31 parser.add_argument("--i", help="Input directory with scans")
32 parser.add_argument("--o", help="Output directory where atlas and other files will be written")
33 parser.add_argument("--threads", type=int, default=-1, help="(optional) Number of cores to be used. You can use -1 to use all available cores. Default is -1.")
34 parser.add_argument('--use_reliability_maps', action='store_true', help='Use reliability maps when averaging into atlas (recommended if data are not 1mm isotropic!)')
35
36 # parse commandline
37 args = parser.parse_args()
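# Example invocation (hypothetical paths; FreeSurfer must be sourced so that FREESURFER_HOME is set):
#   python mri_easyatlas.py --i /path/to/scans --o /path/to/atlas_output --threads 8 --use_reliability_maps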
38
39 #############
40
41 # Very first thing: we require FreeSurfer
42 if not os.environ.get('FREESURFER_HOME'):
43 sf.system.fatal('FREESURFER_HOME is not set. Please source freesurfer.')
44 fs_home = os.environ.get('FREESURFER_HOME')
45
46 if args.i is None:
47 sf.system.fatal('Input directory must be provided')
48 if args.o is None:
49 sf.system.fatal('Output directory must be provided')
50
51 # set the number of threads (the script runs on the CPU, since CUDA was disabled above)
52 if args.threads<0:
53 args.threads = os.cpu_count()
54 print('using all available threads ( %s )' % args.threads)
55 else:
56 print('using %s threads' % args.threads)
57 tf.config.threading.set_inter_op_parallelism_threads(args.threads)
58 tf.config.threading.set_intra_op_parallelism_threads(args.threads)
59 torch.set_num_threads(args.threads)
60
61 # path models
62 path_model_segmentation = fs_home + '/models/synthseg_2.0.h5'
63 path_model_parcellation = fs_home + '/models/synthseg_parc_2.0.h5'
64 path_model_registration_trained = fs_home + '/models/easyreg_v10_230103.h5'
65
66 # path labels
67 labels_segmentation = fs_home + '/models/synthseg_segmentation_labels_2.0.npy'
68 labels_parcellation = fs_home + '/models/synthseg_parcellation_labels.npy'
69 atlas_volsize = [160, 160, 192]
70 atlas_aff = np.matrix([[-1, 0, 0, 79], [0, 0, 1, -104], [0, -1, 0, 79], [0, 0, 0, 1]])
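# Atlas space: a 160 x 160 x 192 grid at 1 mm isotropic resolution. atlas_aff maps voxel indices to RAS
# coordinates in FreeSurfer-style (LIA) orientation, roughly centred on the head.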
71
72 # get label lists
73 labels_segmentation, _ = get_list_labels(label_list=labels_segmentation)
74 labels_segmentation, unique_idx = np.unique(labels_segmentation, return_index=True)
75 labels_parcellation, _ = np.unique(get_list_labels(labels_parcellation)[0], return_index=True)
76
77 # Create output (and SynthSeg) directory if needed
78 if os.path.exists(args.o) and os.path.isdir(args.o):
79 print('Output directory already exists; no need to create it')
80 else:
81 os.mkdir(args.o)
82 segdir = args.o + '/SynthSeg/'
83 if os.path.exists(segdir) and os.path.isdir(segdir):
84 print('SynthSeg directory already exists; no need to create it')
85 else:
86 os.mkdir(segdir)
87 regdir = args.o + '/Registrations/'
88 if os.path.exists(regdir) and os.path.isdir(regdir):
89 print('Registration directory already exists; no need to create it')
90 else:
91 os.mkdir(regdir)
92 tempdir = args.o + '/temp/'
93 if os.path.exists(tempdir) and os.path.isdir(tempdir):
94 print('Temporary directory already exists; no need to create it')
95 else:
96 os.mkdir(tempdir)
97
98 # Build list of input, affine, segmentation files (supports nii, mgz, nii.gz)
99 input_files = sorted(glob.glob(args.i + '/*.nii.gz') + glob.glob(args.i + '/*.nii') + glob.glob(args.i + '/*.mgz'))
100 seg_files = []
101 reg_files = []
102 linear_files = []
103 for file in input_files:
104 _, tail = os.path.split(file)
105 seg_files.append(segdir + '/' + tail)
106 reg_files.append(regdir + '/' + tail)
107 linear_files.append(tempdir + '/' + tail + '.npy')
108
109 # Decide if we need to segment anything
110 all_segs_ready = True
111 for file in seg_files:
112 if not os.path.exists(file):
113 all_segs_ready = False
114
115 # Run SynthSeg if needed
116 if all_segs_ready:
117 print('SynthSeg already there for all input files; no need to segment anything')
118 else:
119 print('Setting up segmentation net')
120 segmentation_net = build_seg_model(model_file_segmentation=path_model_segmentation,
121 model_file_parcellation=path_model_parcellation,
122 labels_segmentation=labels_segmentation,
123 labels_parcellation=labels_parcellation)
124 for i in range(len(input_files)):
125 if os.path.exists(seg_files[i]):
126 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmentation already there')
127 else:
128 print('Image ' + str(i + 1) + ' of ' + str(len(input_files)) + ': segmenting')
129 image, aff, h, im_res, shape, pad_idx, crop_idx = preprocess(path_image=input_files[i], crop=None,
130 min_pad=128, path_resample=None)
131 post_patch_segmentation, post_patch_parcellation = segmentation_net.predict(image)
132 seg_buffer, _, _ = postprocess(post_patch_seg=post_patch_segmentation,
133 post_patch_parc=post_patch_parcellation,
134 shape=shape,
135 pad_idx=pad_idx,
136 crop_idx=crop_idx,
137 labels_segmentation=labels_segmentation,
138 labels_parcellation=labels_parcellation,
139 aff=aff,
140 im_res=im_res)
141 save_volume(seg_buffer, aff, h, seg_files[i], dtype='int32')
142
143 # Now the linear registration part
144 print('Linear registration with centroids of segmentations')
145
146 # First, prepare a bunch of common variables
147 labels = np.array([2,4,5,7,8,10,11,12,13,14,15,16,17,18,26,28,41,43,44,46,47,49,50,51,52,53,54,58,60,
148 1001,1002,1003,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023,1024,1025,1026,1027,1028,1029,1030,1031,1032,1033,1034,1035,
149 2001,2002,2003,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022,2023,2024,2025,2026,2027,2028,2029,2030,2031,2032,2033,2034,2035])
150 nlab = len(labels)
151 atlasCOG = np.array([[-28.,-18.,-37.,-19.,-27.,-19.,-23.,-31.,-26.,-2.,-3.,-3.,-29.,-26.,-14.,-14.,24.,14.,31.,12.,18.,14.,19.,26.,21.,25.,22.,11.,8.,-52.,-6.,-36.,-7.,-24.,-37.,-39.,-52.,-9.,-27.,-26.,-14.,-8.,-59.,-28.,-7.,-49.,-43.,-47.,-12.,-46.,-6.,-43.,-10.,-7.,-33.,-11.,-23.,-55.,-50.,-10.,-29.,-46.,-38.,48.,4.,31.,3.,21.,33.,37.,47.,3.,24.,20.,8.,4.,54.,21.,5.,45.,38.,46.,8.,45.,3.,38.,6.,4.,29.,9.,19.,51.,49.,10.,24.,43.,33.],
152 [-30.,-17.,-13.,-36.,-40.,-22.,-3.,-5.,-9.,-14.,-31.,-21.,-15.,-1.,3.,-16.,-32.,-20.,-14.,-37.,-42.,-24.,-3.,-6.,-10.,-15.,-2.,3.,-17.,-44.,-5.,-15.,-71.,2.,-29.,-70.,-23.,-44.,-73.,22.,-57.,27.,-19.,-23.,-45.,4.,31.,20.,-68.,-38.,-33.,-26.,-60.,23.,22.,0.,-72.,-12.,-49.,49.,17.,-25.,-3.,-42.,-1.,-16.,-76.,0.,-34.,-69.,-16.,-44.,-73.,22.,-56.,28.,-18.,-25.,-45.,-3.,30.,14.,-69.,-37.,-32.,-30.,-60.,21.,21.,0.,-72.,-11.,-49.,48.,15.,-27.,-3.],
153 [12.,14.,-13.,-41.,-51.,1.,13.,3.,1.,0.,-40.,-28.,-15.,-10.,2.,-7.,11.,14.,-12.,-40.,-51.,2.,14.,4.,2.,-14.,-10.,4.,-7.,-8.,32.,40.,-14.,-21.,-28.,-4.,-28.,-3.,-35.,3.,-29.,4.,-17.,-21.,35.,18.,9.,20.,-24.,28.,25.,34.,7.,18.,35.,48.,16.,-5.,12.,22.,-18.,1.,4.,-12.,32.,43.,-11.,-21.,-29.,-3.,-27.,0.,-34.,3.,-25.,6.,-18.,-20.,36.,18.,11.,20.,-20.,26.,25.,34.,4.,24.,34.,47.,17.,-5.,10.,20.,-18.,0.,4.]])
154
155 II, JJ, KK = np.meshgrid(np.arange(atlas_volsize[0]), np.arange(atlas_volsize[1]), np.arange(atlas_volsize[2]), indexing='ij')
156 II = torch.tensor(II, device='cpu')
157 JJ = torch.tensor(JJ, device='cpu')
158 KK = torch.tensor(KK, device='cpu')
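# II, JJ, KK hold the voxel coordinates of the atlas grid; they are reused below to pull every image
# (and segmentation) into atlas space through the affine and, later, nonlinear transforms.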
159
160 # Loop over segmentations and get COGs of ROIs
161 COGs = np.zeros([len(input_files), 4, nlab])
162 OKs = np.zeros([len(input_files), nlab])
163 for i in range(len(input_files)):
164 print('Getting centroids of ROIs: case ' + str(i + 1) + ' of ' + str(len(input_files)))
165 COG = np.zeros([4, nlab])
166 ok = np.ones(nlab)
167 seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
168 label_to_idx = {lab: ii for ii, lab in enumerate(labels)}
169 coords_per_label = [[] for _ in range(nlab)]
170 nz = np.array(np.nonzero(seg_buffer)).T
171 vals = seg_buffer[tuple(nz.T)]
172 valid_mask = np.isin(vals, labels)
173 nz = nz[valid_mask]
174 vals = vals[valid_mask]
175 idxs = np.searchsorted(labels, vals)
176 for ii in range(nlab):
177 coords_per_label[ii] = nz[idxs == ii]
178 # Compute per-label median centroids
179 for ii, vox in enumerate(coords_per_label):
180 if vox.shape[0] > 50:
181 COG[:3, ii] = np.median(vox, axis=0)
182 COG[3, ii] = 1
183 else:
184 ok[ii] = 0
185 COGs[i] = np.matmul(seg_aff, COG)
186 OKs[i] = ok.copy()
187
188 # Linear registration matrices; first rigid, then affine
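# First pass: a rigid (Kabsch) fit of each subject's centroids to the reference centroids (atlasCOG); the
# rigidly aligned centroids are averaged across subjects into an unbiased mean set (rigidAtlasCOG).
# Second pass: a full affine (getM) from those mean centroids to each subject, giving the atlas-to-subject
# matrices Ms used to resample the images below.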
189 NUM = np.zeros(atlasCOG.shape)
190 DEN = np.zeros(atlasCOG.shape)
191 for i in range(len(input_files)):
192 M = getMrigid(COGs[i, :-1, OKs[i] > 0].T, atlasCOG[:, OKs[i] > 0])
193 NUM[:, OKs[i] > 0] = NUM[:, OKs[i] > 0] + (M @ COGs[i, :, OKs[i] > 0].T)[:-1, :]
194 DEN[:, OKs[i] > 0] = DEN[:, OKs[i] > 0] + 1
195 rigidAtlasCOG = NUM / DEN
196 Ms = np.zeros([len(input_files), 4, 4])
197 for i in range(len(input_files)):
198 Ms[i] = getM(rigidAtlasCOG[:, OKs[i] > 0], COGs[i, :, OKs[i] > 0].T)
199
200 # OK now we can deform to linear space (and compute linear atlas, while at it)
201 NUM = np.zeros(atlas_volsize)
202 DEN = np.zeros(atlas_volsize)
203 for i in range(len(input_files)):
204 print('Deforming to linear space: case ' + str(i + 1) + ' of ' + str(len(input_files)))
205 im_buffer, im_aff, im_hh = load_volume(input_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
206 im_buffer = torch.tensor(im_buffer, device='cpu')
207 voxdim = np.sqrt(np.sum(im_aff[:-1, :-1] ** 2, axis=0))
208 affine = torch.tensor(np.matmul(np.linalg.inv(im_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
209 II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
210 JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
211 KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
212 im_lin = fast_3D_interp_torch(im_buffer, II2, JJ2, KK2, 'linear')
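# Optional reliability map: for each atlas voxel, the distance (in mm) between the sampled location and the
# nearest voxel centre of the input scan; exp(-distance) is ~1 at acquired voxel centres and decays in
# between, so heavily interpolated samples are downweighted when averaging into the atlas.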
213 if args.use_reliability_maps:
214 lin_dists = torch.sqrt(((II2 - II2.round()) * voxdim[0]) ** 2 +
215 ((JJ2 - JJ2.round()) * voxdim[1]) ** 2 +
216 ((KK2 - KK2.round()) * voxdim[2]) ** 2)
217 lin_rel = torch.exp(-1.0 * lin_dists)
218 else:
219 lin_rel = torch.ones(II2.shape)
220
221 seg_buffer, seg_aff, seg_h = load_volume(seg_files[i], im_only=False, squeeze=True, dtype=None, aff_ref=None)
222 affine = torch.tensor(np.matmul(np.linalg.inv(seg_aff), np.matmul(Ms[i], atlas_aff)), device='cpu')
223 II2 = affine[0, 0] * II + affine[0, 1] * JJ + affine[0, 2] * KK + affine[0, 3]
224 JJ2 = affine[1, 0] * II + affine[1, 1] * JJ + affine[1, 2] * KK + affine[1, 3]
225 KK2 = affine[2, 0] * II + affine[2, 1] * JJ + affine[2, 2] * KK + affine[2, 3]
226 seg_lin = fast_3D_interp_torch(torch.tensor(seg_buffer.copy(), device='cpu'), II2, JJ2, KK2, 'nearest')
227 im_lin[seg_lin == 0] = 0
228 im_lin /= torch.median(im_lin[torch.logical_or(seg_lin==2, seg_lin==41)])
229 np.save(linear_files[i], torch.stack([im_lin, lin_rel]).detach().cpu().numpy())
230 NUM += (im_lin * lin_rel).detach().cpu().numpy()
231 DEN += lin_rel.detach().cpu().numpy()
232
233 print('Computing and saving affine atlas')
234 ATLAS = NUM / (1e-9 + DEN)
235 save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.affine.nii.gz')
236
237 print('Building nonlinear registration model')
238 # Build model
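# The network predicts stationary velocity fields (SVFs) in both directions; averaging svf1 with the negated
# svf2 symmetrises the prediction, VecInt integrates it by scaling and squaring, and RescaleTransform(2)
# brings the half-resolution field back to full resolution, yielding inverse-consistent forward and backward
# deformations (mirroring the symmetric construction used by EasyReg).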
239 source = tf.keras.Input(shape=(*atlas_volsize, 1))
240 target = tf.keras.Input(shape=(*atlas_volsize, 1))
241
242 config = {'name': 'vxm_dense', 'fill_value': None, 'input_model': None, 'unet_half_res': True, 'trg_feats': 1,
243 'src_feats': 1, 'use_probs': False, 'bidir': False, 'int_downsize': 2, 'int_steps': 10,
244 'nb_unet_conv_per_level': 1, 'unet_feat_mult': 1, 'nb_unet_levels': None,
245 'nb_unet_features': [[256, 256, 256, 256], [256, 256, 256, 256, 256, 256]], 'inshape': atlas_volsize}
246 cnn = vxm.networks.VxmDense(**config)
247 cnn.load_weights(path_model_registration_trained, by_name=True)
248 svf1 = cnn([source, target])[1]
249 svf2 = cnn([target, source])[1]
250 pos_svf = KL.Lambda(lambda x: 0.5 * x[0] - 0.5 * x[1])([svf1, svf2])
251 neg_svf = KL.Lambda(lambda x: -x)(pos_svf)
252 pos_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(pos_svf)
253 neg_def_small = vxm.layers.VecInt(method='ss', int_steps=10)(neg_svf)
254 pos_def = vxm.layers.RescaleTransform(2)(pos_def_small)
255 neg_def = vxm.layers.RescaleTransform(2)(neg_def_small)
256 model = tf.keras.Model(inputs=[source, target],
257 outputs=[pos_def, neg_def])
258 model.load_weights(path_model_registration_trained)
259
260 # Global atlas building iterations
261 MAX_IT = 5
262 for it in range(MAX_IT):
263 # Initialize new atlas to zeros
264 NUM = np.zeros_like(ATLAS)
265 DEN = np.zeros_like(ATLAS)
266 for i in range(len(input_files)):
267 print('Iteration ' + str(1 + it) + ' of ' + str(MAX_IT) + ', image ' + str(i+1) + ' of ' + str(len(input_files)))
268 lin = np.load(linear_files[i])
269 pred = model.predict([lin[0:1, ..., np.newaxis] / np.max(lin[0]) ,
270 ATLAS[np.newaxis, ..., np.newaxis]])
271 field = torch.tensor(pred[0], device='cpu').squeeze()
272 II2 = II + field[..., 0]
273 JJ2 = JJ + field[..., 1]
274 KK2 = KK + field[..., 2]
275 deformed_im = fast_3D_interp_torch(torch.tensor(lin[0], device='cpu'), II2 , JJ2, KK2, 'linear')
276 deformed_rel = fast_3D_interp_torch(torch.tensor(lin[1], device='cpu'), II2, JJ2, KK2, 'linear')
277 NUM += (deformed_im * deformed_rel).detach().cpu().numpy()
278 DEN += deformed_rel.detach().cpu().numpy()
279 if it == (MAX_IT-1):
280 T = Ms[i] @ atlas_aff
281 RR = T[0, 0] * II2 + T[0, 1] * JJ2 + T[0, 2] * KK2 + T[0, 3]
282 AA = T[1, 0] * II2 + T[1, 1] * JJ2 + T[1, 2] * KK2 + T[1, 3]
283 SS = T[2, 0] * II2 + T[2, 1] * JJ2 + T[2, 2] * KK2 + T[2, 3]
284 save_volume(torch.stack([RR, AA, SS], dim=-1).detach().cpu().numpy(), atlas_aff, None, reg_files[i])
285 ATLAS = NUM / (1e-9 + DEN)
286 save_volume(ATLAS, atlas_aff, None, args.o + '/atlas.iteration.' + str(it+1) + '.nii.gz')
287
288 # Clean up
289 print('Deleting temporary files')
290 for i in range(len(linear_files)):
291 os.remove(linear_files[i])
292 os.rmdir(tempdir)
293
294 print(' ')
295 print('All done!')
296 print(' ')
297 print('If you use EasyReg in your analysis, please cite:')
298 print('A ready-to-use machine learning tool for symmetric multi-modality registration of brain MRI.')
299 print('JE Iglesias. Scientific Reports, 13, article number 6657 (2023).')
300 print('https://www.nature.com/articles/s41598-023-33781-0')
301 print(' ')
302
303
304 #######################
305 # Auxiliary functions #
306 #######################
307
308
309 def get_list_labels(label_list=None, save_label_list=None, FS_sort=False):
310
311 # load label list if previously computed
312 label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int'))
313
314
315 # sort labels in neutral/left/right according to FS labels
316 n_neutral_labels = 0
317 if FS_sort:
318 neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108,
319 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
320 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340,
321 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530,
322 531, 532, 533, 534, 535, 536, 537]
323 neutral = list()
324 left = list()
325 right = list()
326 for la in label_list:
327 if la in neutral_FS_labels:
328 if la not in neutral:
329 neutral.append(la)
330 elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \
(la == 865) | (20100 < la < 20110):
331 if la not in left:
332 left.append(la)
333 elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \
(la == 866):
334 if la not in right:
335 right.append(la)
336 else:
337 raise Exception('label {} not in our current FS classification, '
338 'please update get_list_labels in utils.py'.format(la))
339 label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)])
340 if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)):
341 n_neutral_labels = len(neutral)
342 else:
343 n_neutral_labels = len(label_list)
344
345 # save labels if specified
346 if save_label_list is not None:
347 np.save(save_label_list, np.int32(label_list))
348
349 if FS_sort:
350 return np.int32(label_list), n_neutral_labels
351 else:
352 return np.int32(label_list), None
353
354 def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
355 # convert to list
356 if var is None:
357 return None
358 var = load_array_if_path(var, load_as_numpy=load_as_numpy)
359 if isinstance(var, (int, float, np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64)):
360 var = [var]
361 elif isinstance(var, tuple):
362 var = list(var)
363 elif isinstance(var, np.ndarray):
364 if var.shape == (1,):
365 var = [var[0]]
366 else:
367 var = np.squeeze(var).tolist()
368 elif isinstance(var, str):
369 var = [var]
370 elif isinstance(var, bool):
371 var = [var]
372 if isinstance(var, list):
373 if length is not None:
374 if len(var) == 1:
375 var = var * length
376 elif len(var) != length:
377 raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
378 'had {1}'.format(length, var))
379 else:
380 raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')
381
382 # convert items type
383 if dtype is not None:
384 if dtype == 'int':
385 var = [int(v) for v in var]
386 elif dtype == 'float':
387 var = [float(v) for v in var]
388 elif dtype == 'bool':
389 var = [bool(v) for v in var]
390 elif dtype == 'str':
391 var = [str(v) for v in var]
392 else:
393 raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
394 return var
395
396 def load_array_if_path(var, load_as_numpy=True):
397 if (isinstance(var, str)) & load_as_numpy:
398 assert os.path.isfile(var), 'No such path: %s' % var
399 var = np.load(var)
400 return var
401
402
403 def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
404
405 assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume
406
407 if path_volume.endswith(('.nii', '.nii.gz', '.mgz')):
408 x = nib.load(path_volume)
409 if squeeze:
410 volume = np.squeeze(x.get_fdata())
411 else:
412 volume = x.get_fdata()
413 aff = x.affine
414 header = x.header
415 else: # npz
416 volume = np.load(path_volume)['vol_data']
417 if squeeze:
418 volume = np.squeeze(volume)
419 aff = np.eye(4)
420 header = nib.Nifti1Header()
421 if dtype is not None:
422 if 'int' in dtype:
423 volume = np.round(volume)
424 volume = volume.astype(dtype=dtype)
425
426 # align image to reference affine matrix
427 if aff_ref is not None:
428 n_dims, _ = get_dims(list(volume.shape), max_channels=10)
429 volume, aff = align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)
430
431 if im_only:
432 return volume
433 else:
434 return volume, aff, header
435
436
437
438
439 def preprocess(path_image, n_levels=5, crop=None, min_pad=None, path_resample=None):
440 # read image and corresponding info
441 im, _, aff, n_dims, n_channels, h, im_res = get_volume_info(path_image, True)
442 if n_dims < 3:
443 sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
444 elif n_dims == 4 and n_channels == 1:
445 n_dims = 3
446 im = im[..., 0]
447 elif n_dims > 3:
448 sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
449 elif n_channels > 1:
450 print('WARNING: detected more than 1 channel, only keeping the first channel.')
451 im = im[..., 0]
452
453 # resample image if necessary
454 if np.any((im_res > 1.05) | (im_res < 0.95)):
455 im_res = np.array([1.] * 3)
456 im, aff = resample_volume(im, aff, im_res)
457 if path_resample is not None:
458 save_volume(im, aff, h, path_resample)
459
460 # align image
461 im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False)
462 shape = list(im.shape[:n_dims])
463
464 # crop image if necessary
465 if crop is not None:
466 crop = reformat_to_list(crop, length=n_dims, dtype='int')
467 crop_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in crop]
468 im, crop_idx = crop_volume(im, cropping_shape=crop_shape, return_crop_idx=True)
469 else:
470 crop_idx = None
471
472 # normalise image
473 im = rescale_volume(im, new_min=0, new_max=1, min_percentile=0.5, max_percentile=99.5)
474
475 # pad image
476 input_shape = im.shape[:n_dims]
477 pad_shape = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in input_shape]
478 min_pad = reformat_to_list(min_pad, length=n_dims, dtype='int')
479 min_pad = [find_closest_number_divisible_by_m(s, 2 ** n_levels, 'higher') for s in min_pad]
480 pad_shape = np.maximum(pad_shape, min_pad)
481 im, pad_idx = pad_volume(im, padding_shape=pad_shape, return_pad_idx=True)
482
483 # add batch and channel axes
484 im = add_axis(im, axis=[0, -1])
485
486 return im, aff, h, im_res, shape, pad_idx, crop_idx
487
488
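# resample_volume: anti-aliases with a Gaussian (sigma of about 0.25 target voxels, applied only when
# downsampling), then resamples the volume trilinearly onto the new grid and updates the affine so that
# world coordinates are preserved.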
489 def resample_volume(volume, aff, new_vox_size, interpolation='linear'):
490 pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1]
491 new_vox_size = np.array(new_vox_size)
492 factor = pixdim / new_vox_size
493 sigmas = 0.25 / factor
494 sigmas[factor > 1] = 0 # don't blur if upsampling
495
496 volume_filt = gaussian_filter(volume, sigmas)
497
498 # volume2 = zoom(volume_filt, factor, order=1, mode='reflect', prefilter=False)
499 x = np.arange(0, volume_filt.shape[0])
500 y = np.arange(0, volume_filt.shape[1])
501 z = np.arange(0, volume_filt.shape[2])
502
503 start = - (factor - 1) / (2 * factor)
504 step = 1.0 / factor
505 stop = start + step * np.ceil(volume_filt.shape * factor)
506
507 xi = np.arange(start=start[0], stop=stop[0], step=step[0])
508 yi = np.arange(start=start[1], stop=stop[1], step=step[1])
509 zi = np.arange(start=start[2], stop=stop[2], step=step[2])
510 xi[xi < 0] = 0
511 yi[yi < 0] = 0
512 zi[zi < 0] = 0
513 xi[xi > (volume_filt.shape[0] - 1)] = volume_filt.shape[0] - 1
514 yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1
515 zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1
516
517 xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=False)
518 xig = torch.tensor(xig, device='cpu')
519 yig = torch.tensor(yig, device='cpu')
520 zig = torch.tensor(zig, device='cpu')
521 volume2 = fast_3D_interp_torch(torch.tensor(volume_filt, device='cpu'), xig, yig, zig, 'linear')
522
523 aff2 = aff.copy()
524 for c in range(3):
525 aff2[:-1, c] = aff2[:-1, c] / factor[c]
526 aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1))
527
528 return volume2.numpy(), aff2
529
530 def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
531 if n % m == 0:
532 return n
533 else:
534 q = int(n / m)
535 lower = q * m
536 higher = (q + 1) * m
537 if answer_type == 'lower':
538 return lower
539 elif answer_type == 'higher':
540 return higher
541 elif answer_type == 'closer':
542 return lower if (n - lower) < (higher - n) else higher
543 else:
544 sf.system.fatal('answer_type should be lower, higher, or closer, had : %s' % answer_type)
545
546
547
548 def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
549
550 im, aff, header = load_volume(path_volume, im_only=False)
551
552 # understand if image is multichannel
553 im_shape = list(im.shape)
554 n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
555 im_shape = im_shape[:n_dims]
556
557 # get labels res
558 if '.nii' in path_volume:
559 data_res = np.array(header['pixdim'][1:n_dims + 1])
560 elif '.mgz' in path_volume:
561 data_res = np.array(header['delta']) # mgz image
562 else:
563 data_res = np.array([1.0] * n_dims)
564
565 # align to given affine matrix
566 if aff_ref is not None:
567 ras_axes = get_ras_axes(aff, n_dims=n_dims)
568 ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
569 im = align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
570 im_shape = np.array(im_shape)
571 data_res = np.array(data_res)
572 im_shape[ras_axes_ref] = im_shape[ras_axes]
573 data_res[ras_axes_ref] = data_res[ras_axes]
574 im_shape = im_shape.tolist()
575
576 # return info
577 if return_volume:
578 return im, im_shape, aff, n_dims, n_channels, header, data_res
579 else:
580 return im_shape, aff, n_dims, n_channels, header, data_res
581
582 def get_dims(shape, max_channels=10):
583 if shape[-1] <= max_channels:
584 n_dims = len(shape) - 1
585 n_channels = shape[-1]
586 else:
587 n_dims = len(shape)
588 n_channels = 1
589 return n_dims, n_channels
590
591
592 def get_ras_axes(aff, n_dims=3):
593 aff_inverted = np.linalg.inv(aff)
594 img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0)
595 for i in range(n_dims):
596 if i not in img_ras_axes:
597 unique, counts = np.unique(img_ras_axes, return_counts=True)
598 incorrect_value = unique[np.argmax(counts)]
599 img_ras_axes[np.where(img_ras_axes == incorrect_value)[0][-1]] = i
600
601 return img_ras_axes
602
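# align_volume_to_ref: permutes and flips the axes of `volume` so that its voxel-to-world orientation matches
# aff_ref (no resampling), adjusting the affine translation accordingly.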
603 def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True):
604
605 # work on copy
606 new_volume = volume.copy() if return_copy else volume
607 aff_flo = aff.copy()
608
609 # default value for aff_ref
610 if aff_ref is None:
611 aff_ref = np.eye(4)
612
613 # extract ras axes
614 if n_dims is None:
615 n_dims, _ = get_dims(new_volume.shape)
616 ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
617 ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims)
618
619 # align axes
620 aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo]
621 for i in range(n_dims):
622 if ras_axes_flo[i] != ras_axes_ref[i]:
623 new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i])
624 swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i])
625 ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx]
626
627 # align directions
628 dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0)
629 for i in range(n_dims):
630 if dot_products[i] < 0:
631 new_volume = np.flip(new_volume, axis=i)
632 aff_flo[:, i] = - aff_flo[:, i]
633 aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1)
634
635 if return_aff:
636 return new_volume, aff_flo
637 else:
638 return new_volume
639
640 def build_seg_model(model_file_segmentation,
641 model_file_parcellation,
642 labels_segmentation,
643 labels_parcellation):
644
645 if not os.path.isfile(model_file_segmentation):
646 sf.system.fatal("The provided model path does not exist.")
647
648 # get labels
649 n_labels_seg = len(labels_segmentation)
650
651 # build UNet
652 net = unet(nb_features=24,
653 input_shape=[None, None, None, 1],
654 nb_levels=5,
655 conv_size=3,
656 nb_labels=n_labels_seg,
657 feat_mult=2,
658 activation='elu',
659 nb_conv_per_level=2,
660 batch_norm=-1,
661 name='unet')
662 net.load_weights(model_file_segmentation, by_name=True)
663 input_image = net.inputs[0]
664 name_segm_prediction_layer = 'unet_prediction'
665
666 # smooth posteriors
667 last_tensor = net.output
668 last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
669 last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
670 net = keras.Model(inputs=net.inputs, outputs=last_tensor)
671
672 # add aparc segmenter
673 n_labels_parcellation = len(labels_parcellation)
674
675 last_tensor = net.output
676 last_tensor = KL.Lambda(lambda x: tf.cast(tf.argmax(x, axis=-1), 'int32'))(last_tensor)
677 last_tensor = ConvertLabels(np.arange(n_labels_seg), labels_segmentation)(last_tensor)
678 parcellation_masking_values = np.array([1 if ((ll == 3) | (ll == 42)) else 0 for ll in labels_segmentation])
679 last_tensor = ConvertLabels(labels_segmentation, parcellation_masking_values)(last_tensor)
680 last_tensor = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, 'int32'), depth=2, axis=-1))(last_tensor)
681 last_tensor = KL.Lambda(lambda x: tf.cast(tf.concat(x, axis=-1), 'float32'))([input_image, last_tensor])
682 net = keras.Model(inputs=net.inputs, outputs=last_tensor)
683
684 # build UNet
685 net = unet(nb_features=24,
686 input_shape=[None, None, None, 3],
687 nb_levels=5,
688 conv_size=3,
689 nb_labels=n_labels_parcellation,
690 feat_mult=2,
691 activation='elu',
692 nb_conv_per_level=2,
693 batch_norm=-1,
694 name='unet_parc',
695 input_model=net)
696 net.load_weights(model_file_parcellation, by_name=True)
697
698 # smooth predictions
699 last_tensor = net.output
700 last_tensor._keras_shape = tuple(last_tensor.get_shape().as_list())
701 last_tensor = GaussianBlur(sigma=0.5)(last_tensor)
702 net = keras.Model(inputs=net.inputs, outputs=[net.get_layer(name_segm_prediction_layer).output, last_tensor])
703
704 return net
705
706 def unet(nb_features,
707 input_shape,
708 nb_levels,
709 conv_size,
710 nb_labels,
711 name='unet',
712 prefix=None,
713 feat_mult=1,
714 pool_size=2,
715 padding='same',
716 dilation_rate_mult=1,
717 activation='elu',
718 skip_n_concatenations=0,
719 use_residuals=False,
720 final_pred_activation='softmax',
721 nb_conv_per_level=1,
722 layer_nb_feats=None,
723 conv_dropout=0,
724 batch_norm=None,
725 input_model=None):
726
727 # naming
728 model_name = name
729 if prefix is None:
730 prefix = model_name
731
732 # volume size data
733 ndims = len(input_shape) - 1
734 if isinstance(pool_size, int):
735 pool_size = (pool_size,) * ndims
736
737 # get encoding model
738 enc_model = conv_enc(nb_features,
739 input_shape,
740 nb_levels,
741 conv_size,
742 name=model_name,
743 prefix=prefix,
744 feat_mult=feat_mult,
745 pool_size=pool_size,
746 padding=padding,
747 dilation_rate_mult=dilation_rate_mult,
748 activation=activation,
749 use_residuals=use_residuals,
750 nb_conv_per_level=nb_conv_per_level,
751 layer_nb_feats=layer_nb_feats,
752 conv_dropout=conv_dropout,
753 batch_norm=batch_norm,
754 input_model=input_model)
755
756 # get decoder
757 # use_skip_connections=True makes it a u-net
758 lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
759 dec_model = conv_dec(nb_features,
760 None,
761 nb_levels,
762 conv_size,
763 nb_labels,
764 name=model_name,
765 prefix=prefix,
766 feat_mult=feat_mult,
767 pool_size=pool_size,
768 use_skip_connections=True,
769 skip_n_concatenations=skip_n_concatenations,
770 padding=padding,
771 dilation_rate_mult=dilation_rate_mult,
772 activation=activation,
773 use_residuals=use_residuals,
774 final_pred_activation=final_pred_activation,
775 nb_conv_per_level=nb_conv_per_level,
776 batch_norm=batch_norm,
777 layer_nb_feats=lnf,
778 conv_dropout=conv_dropout,
779 input_model=enc_model)
780 final_model = dec_model
781
782 return final_model
783
784 def conv_enc(nb_features,
785 input_shape,
786 nb_levels,
787 conv_size,
788 name=None,
789 prefix=None,
790 feat_mult=1,
791 pool_size=2,
792 dilation_rate_mult=1,
793 padding='same',
794 activation='elu',
795 layer_nb_feats=None,
796 use_residuals=False,
797 nb_conv_per_level=2,
798 conv_dropout=0,
799 batch_norm=None,
800 input_model=None):
801
802 # naming
803 model_name = name
804 if prefix is None:
805 prefix = model_name
806
807 # first layer: input
808 name = '%s_input' % prefix
809 if input_model is None:
810 input_tensor = KL.Input(shape=input_shape, name=name)
811 last_tensor = input_tensor
812 else:
813 input_tensor = input_model.inputs
814 last_tensor = input_model.outputs
815 if isinstance(last_tensor, list):
816 last_tensor = last_tensor[0]
817
818 # volume size data
819 ndims = len(input_shape) - 1
820 if isinstance(pool_size, int):
821 pool_size = (pool_size,) * ndims
822
823 # prepare layers
824 convL = getattr(KL, 'Conv%dD' % ndims)
825 conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
826 maxpool = getattr(KL, 'MaxPooling%dD' % ndims)
827
828 # down arm:
829 # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
830 lfidx = 0 # level feature index
831 for level in range(nb_levels):
832 lvl_first_tensor = last_tensor
833 nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
834 conv_kwargs['dilation_rate'] = dilation_rate_mult ** level
835
836 for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end
837 if layer_nb_feats is not None: # None or List of all the feature numbers
838 nb_lvl_feats = layer_nb_feats[lfidx]
839 lfidx += 1
840
841 name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
842 if conv < (nb_conv_per_level - 1) or (not use_residuals):
843 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
844 else: # no activation
845 last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
846
847 if conv_dropout > 0:
848 # conv dropout along feature space only
849 name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
850 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
851 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
852
853 if use_residuals:
854 convarm_layer = last_tensor
855
856 # the "add" layer is the original input
857 # However, it may not have the right number of features to be added
858 nb_feats_in = lvl_first_tensor.get_shape()[-1]
859 nb_feats_out = convarm_layer.get_shape()[-1]
860 add_layer = lvl_first_tensor
861 if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
862 name = '%s_expand_down_merge_%d' % (prefix, level)
863 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
864 add_layer = last_tensor
865
866 if conv_dropout > 0:
867 name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv)
868 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
869
870 name = '%s_res_down_merge_%d' % (prefix, level)
871 last_tensor = KL.add([add_layer, convarm_layer], name=name)
872
873 name = '%s_res_down_merge_act_%d' % (prefix, level)
874 last_tensor = KL.Activation(activation, name=name)(last_tensor)
875
876 if batch_norm is not None:
877 name = '%s_bn_down_%d' % (prefix, level)
878 last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
879
880 # max pool if we're not at the last level
881 if level < (nb_levels - 1):
882 name = '%s_maxpool_%d' % (prefix, level)
883 last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)
884
885 # create the model and return
886 model = keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
887 return model
888
889
890 def conv_dec(nb_features,
891 input_shape,
892 nb_levels,
893 conv_size,
894 nb_labels,
895 name=None,
896 prefix=None,
897 feat_mult=1,
898 pool_size=2,
899 use_skip_connections=False,
900 skip_n_concatenations=0,
901 padding='same',
902 dilation_rate_mult=1,
903 activation='elu',
904 use_residuals=False,
905 final_pred_activation='softmax',
906 nb_conv_per_level=2,
907 layer_nb_feats=None,
908 batch_norm=None,
909 conv_dropout=0,
910 input_model=None):
911
912 # naming
913 model_name = name
914 if prefix is None:
915 prefix = model_name
916
917 # if using skip connections, make sure an input model is provided
918 if use_skip_connections:
919 assert input_model is not None, "if using skip connections, an input model is required"
920
921 # first layer: input
922 input_name = '%s_input' % prefix
923 if input_model is None:
924 input_tensor = KL.Input(shape=input_shape, name=input_name)
925 last_tensor = input_tensor
926 else:
927 input_tensor = input_model.input
928 last_tensor = input_model.output
929 input_shape = last_tensor.shape.as_list()[1:]
930
931 # vol size info
932 ndims = len(input_shape) - 1
933 if isinstance(pool_size, int):
934 if ndims > 1:
935 pool_size = (pool_size,) * ndims
936
937 # prepare layers
938 convL = getattr(KL, 'Conv%dD' % ndims)
939 conv_kwargs = {'padding': padding, 'activation': activation}
940 upsample = getattr(KL, 'UpSampling%dD' % ndims)
941
942 # up arm:
943 # nb_levels - 1 layers of Deconvolution3D
944 # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
945 lfidx = 0
946 for level in range(nb_levels - 1):
947 nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
948 conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)
949
950 # upsample matching the max pooling layers size
951 name = '%s_up_%d' % (prefix, nb_levels + level)
952 last_tensor = upsample(size=pool_size, name=name)(last_tensor)
953 up_tensor = last_tensor
954
955 # merge layers combining previous layer
956 if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
957 conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
958 cat_tensor = input_model.get_layer(conv_name).output
959 name = '%s_merge_%d' % (prefix, nb_levels + level)
960 last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)
961
962 # convolution layers
963 for conv in range(nb_conv_per_level):
964 if layer_nb_feats is not None:
965 nb_lvl_feats = layer_nb_feats[lfidx]
966 lfidx += 1
967
968 name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
969 if conv < (nb_conv_per_level - 1) or (not use_residuals):
970 last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
971 else:
972 last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
973
974 if conv_dropout > 0:
975 name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
976 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
977 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
978
979 # residual block
980 if use_residuals:
981
982 # the "add" layer is the original input
983 # However, it may not have the right number of features to be added
984 add_layer = up_tensor
985 nb_feats_in = add_layer.get_shape()[-1]
986 nb_feats_out = last_tensor.get_shape()[-1]
987 if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
988 name = '%s_expand_up_merge_%d' % (prefix, level)
989 add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)
990
991 if conv_dropout > 0:
992 name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv)
993 noise_shape = [None, *[1] * ndims, nb_lvl_feats]
994 last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
995
996 name = '%s_res_up_merge_%d' % (prefix, level)
997 last_tensor = KL.add([last_tensor, add_layer], name=name)
998
999 name = '%s_res_up_merge_act_%d' % (prefix, level)
1000 last_tensor = KL.Activation(activation, name=name)(last_tensor)
1001
1002 if batch_norm is not None:
1003 name = '%s_bn_up_%d' % (prefix, level)
1004 last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
1005
1006 # Compute likelihood prediction (no activation yet)
1007 name = '%s_likelihood' % prefix
1008 last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
1009 like_tensor = last_tensor
1010
1011 # output prediction layer
1012 # we use a softmax to compute P(L_x|I) where x is each location
1013 if final_pred_activation == 'softmax':
1014 name = '%s_prediction' % prefix
1015 softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
1016 pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)
1017
1018 # otherwise create a layer that does nothing.
1019 else:
1020 name = '%s_prediction' % prefix
1021 pred_tensor = KL.Activation('linear', name=name)(like_tensor)
1022
1023 # create the model and return
1024 model = keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
1025 return model
1026
1027 def postprocess(post_patch_seg, post_patch_parc, shape, pad_idx, crop_idx,
1028 labels_segmentation, labels_parcellation, aff, im_res):
1029
1030 # get posteriors
1031 post_patch_seg = np.squeeze(post_patch_seg)
1032 post_patch_seg = crop_volume_with_idx(post_patch_seg, pad_idx, n_dims=3, return_copy=False)
1033
1034 # keep biggest connected component
1035 tmp_post_patch_seg = post_patch_seg[..., 1:]
1036 post_patch_seg_mask = np.sum(tmp_post_patch_seg, axis=-1) > 0.25
1037 post_patch_seg_mask = get_largest_connected_component(post_patch_seg_mask)
1038 post_patch_seg_mask = np.stack([post_patch_seg_mask]*tmp_post_patch_seg.shape[-1], axis=-1)
1039 tmp_post_patch_seg = mask_volume(tmp_post_patch_seg, mask=post_patch_seg_mask, return_copy=False)
1040 post_patch_seg[..., 1:] = tmp_post_patch_seg
1041
1042 # reset posteriors to zero outside the largest connected component of each topological class
1043 post_patch_seg_mask = post_patch_seg > 0.2
1044 post_patch_seg[..., 1:] *= post_patch_seg_mask[..., 1:]
1045
1046 # get hard segmentation
1047 post_patch_seg /= np.sum(post_patch_seg, axis=-1)[..., np.newaxis]
1048 seg_patch = labels_segmentation[post_patch_seg.argmax(-1).astype('int32')].astype('int32')
1049
1050 # postprocess parcellation
1051 post_patch_parc = np.squeeze(post_patch_parc)
1052 post_patch_parc = crop_volume_with_idx(post_patch_parc, pad_idx, n_dims=3, return_copy=False)
1053 mask = (seg_patch == 3) | (seg_patch == 42)
1054 post_patch_parc[..., 0] = np.ones_like(post_patch_parc[..., 0])
1055 post_patch_parc[..., 0] = mask_volume(post_patch_parc[..., 0], mask=mask < 0.1, return_copy=False)
1056 post_patch_parc /= np.sum(post_patch_parc, axis=-1)[..., np.newaxis]
1057 parc_patch = labels_parcellation[post_patch_parc.argmax(-1).astype('int32')].astype('int32')
1058 seg_patch[mask] = parc_patch[mask]
1059
1060 # paste patches back to matrix of original image size
1061 if crop_idx is not None:
1062 # we need to go through this because of the posteriors of the background, otherwise pad_volume would work
1063 seg = np.zeros(shape=shape, dtype='int32')
1064 posteriors = np.zeros(shape=[*shape, labels_segmentation.shape[0]])
1065 posteriors[..., 0] = np.ones(shape) # place background around patch
1066 seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch
1067 posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_seg
1068 else:
1069 seg = seg_patch
1070 posteriors = post_patch_seg
1071
1072 # align prediction back to first orientation
1073 seg = align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
1074 posteriors = align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
1075
1076 # compute volumes
1077 volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
1078 volumes = np.concatenate([np.array([np.sum(volumes)]), volumes])
1079 if post_patch_parc is not None:
1080 volumes_parc = np.sum(post_patch_parc[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
1081 total_volume_cortex = np.sum(volumes[np.where((labels_segmentation == 3) | (labels_segmentation == 42))[0] - 1])
1082 volumes_parc = volumes_parc / np.sum(volumes_parc) * total_volume_cortex
1083 volumes = np.concatenate([volumes, volumes_parc])
1084 volumes = np.around(volumes * np.prod(im_res), 3)
1085
1086 return seg, posteriors, volumes
1087
1088 def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
1089 mkdir(os.path.dirname(path))
1090 if '.npz' in path:
1091 np.savez_compressed(path, vol_data=volume)
1092 else:
1093 if header is None:
1094 header = nib.Nifti1Header()
1095 if isinstance(aff, str):
1096 if aff == 'FS':
1097 aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
1098 elif aff is None:
1099 aff = np.eye(4)
1100 nifty = nib.Nifti1Image(volume, aff, header)
1101 if dtype is not None:
1102 if 'int' in dtype:
1103 volume = np.round(volume)
1104 volume = volume.astype(dtype=dtype)
1105 nifty.set_data_dtype(dtype)
1106 if res is not None:
1107 if n_dims is None:
1108 n_dims, _ = get_dims(volume.shape)
1109 res = reformat_to_list(res, length=n_dims, dtype=None)
1110 nifty.header.set_zooms(res)
1111 nib.save(nifty, path)
1112
1113
1114
1115 def mkdir(path_dir):
1116
1117 if len(path_dir)>0:
1118 if path_dir[-1] == '/':
1119 path_dir = path_dir[:-1]
1120 if not os.path.isdir(path_dir):
1121 list_dir_to_create = [path_dir]
1122 while not os.path.isdir(os.path.dirname(list_dir_to_create[-1])):
1123 list_dir_to_create.append(os.path.dirname(list_dir_to_create[-1]))
1124 for dir_to_create in reversed(list_dir_to_create):
1125 os.mkdir(dir_to_create)
1126
1127
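# getM: least-squares affine fit. Given reference points ref (3 x N) and moving points mov (first three rows
# used, N columns), solves for the 4x4 matrix M (last row [0, 0, 0, 1]) that best maps ref onto mov.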
1128 def getM(ref, mov):
1129 zmat = np.zeros(ref.shape[::-1])
1130 zcol = np.zeros([ref.shape[1], 1])
1131 ocol = np.ones([ref.shape[1], 1])
1132 zero = np.zeros(zmat.shape)
1133 A = np.concatenate([
1134 np.concatenate([np.transpose(ref), zero, zero, ocol, zcol, zcol], axis=1),
1135 np.concatenate([zero, np.transpose(ref), zero, zcol, ocol, zcol], axis=1),
1136 np.concatenate([zero, zero, np.transpose(ref), zcol, zcol, ocol], axis=1)], axis=0)
1137 b = np.concatenate([np.transpose(mov[0, :]), np.transpose(mov[1, :]), np.transpose(mov[2, :])], axis=0)
1138 x = np.matmul(np.linalg.inv(np.matmul(np.transpose(A), A)), np.matmul(np.transpose(A), b))
1139 M = np.stack([
1140 [x[0], x[1], x[2], x[9]],
1141 [x[3], x[4], x[5], x[10]],
1142 [x[6], x[7], x[8], x[11]],
1143 [0, 0, 0, 1]])
1144 return M
1145
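# getMrigid: closed-form rigid (rotation + translation) alignment of point set A onto B (both 3 x N) via the
# Kabsch/SVD method, with a determinant check to avoid reflections.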
1146 def getMrigid(A, B):
1147 centroid_A = np.mean(A, axis=1, keepdims=True)
1148 centroid_B = np.mean(B, axis=1, keepdims=True)
1149 A_centered = A - centroid_A
1150 B_centered = B - centroid_B
1151 H = A_centered @ B_centered.T
1152 U, S, Vt = np.linalg.svd(H)
1153 R = Vt.T @ U.T
1154 if np.linalg.det(R) < 0:
1155 Vt[2, :] *= -1
1156 R = Vt.T @ U.T
1157 t = centroid_B - R @ centroid_A
1158 T = np.eye(4)
1159 T[:3, :3] = R
1160 T[:3, 3] = t.flatten()
1161 return T
1162
1163
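# fast_3D_interp_torch: samples volume X at (possibly fractional) voxel coordinates (II, JJ, KK).
# 'nearest' rounds and clamps the indices to the grid; 'linear' performs trilinear interpolation and
# returns 0 for locations outside the volume.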
1164 def fast_3D_interp_torch(X, II, JJ, KK, mode):
1165 if mode=='nearest':
1166 IIr = torch.round(II).long()
1167 JJr = torch.round(JJ).long()
1168 KKr = torch.round(KK).long()
1169 IIr[IIr < 0] = 0
1170 JJr[JJr < 0] = 0
1171 KKr[KKr < 0] = 0
1172 IIr[IIr > (X.shape[0] - 1)] = (X.shape[0] - 1)
1173 JJr[JJr > (X.shape[1] - 1)] = (X.shape[1] - 1)
1174 KKr[KKr > (X.shape[2] - 1)] = (X.shape[2] - 1)
1175 Y = X[IIr, JJr, KKr]
1176 elif mode=='linear':
1177 ok = (II>0) & (JJ>0) & (KK>0) & (II<=X.shape[0]-1) & (JJ<=X.shape[1]-1) & (KK<=X.shape[2]-1)
1178 IIv = II[ok]
1179 JJv = JJ[ok]
1180 KKv = KK[ok]
1181
1182 fx = torch.floor(IIv).long()
1183 cx = fx + 1
1184 cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
1185 wcx = IIv - fx
1186 wfx = 1 - wcx
1187
1188 fy = torch.floor(JJv).long()
1189 cy = fy + 1
1190 cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
1191 wcy = JJv - fy
1192 wfy = 1 - wcy
1193
1194 fz = torch.floor(KKv).long()
1195 cz = fz + 1
1196 cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
1197 wcz = KKv - fz
1198 wfz = 1 - wcz
1199
1200 c000 = X[fx, fy, fz]
1201 c100 = X[cx, fy, fz]
1202 c010 = X[fx, cy, fz]
1203 c110 = X[cx, cy, fz]
1204 c001 = X[fx, fy, cz]
1205 c101 = X[cx, fy, cz]
1206 c011 = X[fx, cy, cz]
1207 c111 = X[cx, cy, cz]
1208
1209 c00 = c000 * wfx + c100 * wcx
1210 c01 = c001 * wfx + c101 * wcx
1211 c10 = c010 * wfx + c110 * wcx
1212 c11 = c011 * wfx + c111 * wcx
1213
1214 c0 = c00 * wfy + c10 * wcy
1215 c1 = c01 * wfy + c11 * wcy
1216
1217 c = c0 * wfz + c1 * wcz
1218
1219 Y = torch.zeros(II.shape, device='cpu')
1220 Y[ok] = c.float()
1221
1222 else:
1223 sf.system.fatal('mode must be linear or nearest')
1224
1225 return Y
1226
1227
1228
1229 def fast_3D_interp_field_torch(X, II, JJ, KK):
1230
1231 ok = (II > 0) & (JJ > 0) & (KK > 0) & (II <= X.shape[0] - 1) & (JJ <= X.shape[1] - 1) & (KK <= X.shape[2] - 1)
1232 IIv = II[ok]
1233 JJv = JJ[ok]
1234 KKv = KK[ok]
1235
1236 fx = torch.floor(IIv).long()
1237 cx = fx + 1
1238 cx[cx > (X.shape[0] - 1)] = (X.shape[0] - 1)
1239 wcx = IIv - fx
1240 wfx = 1 - wcx
1241
1242 fy = torch.floor(JJv).long()
1243 cy = fy + 1
1244 cy[cy > (X.shape[1] - 1)] = (X.shape[1] - 1)
1245 wcy = JJv - fy
1246 wfy = 1 - wcy
1247
1248 fz = torch.floor(KKv).long()
1249 cz = fz + 1
1250 cz[cz > (X.shape[2] - 1)] = (X.shape[2] - 1)
1251 wcz = KKv - fz
1252 wfz = 1 - wcz
1253
1254 Y = torch.zeros([*II.shape, 3], device='cpu')
1255 for channel in range(3):
1256
1257 Xc = X[:, :, :, channel]
1258
1259 c000 = Xc[fx, fy, fz]
1260 c100 = Xc[cx, fy, fz]
1261 c010 = Xc[fx, cy, fz]
1262 c110 = Xc[cx, cy, fz]
1263 c001 = Xc[fx, fy, cz]
1264 c101 = Xc[cx, fy, cz]
1265 c011 = Xc[fx, cy, cz]
1266 c111 = Xc[cx, cy, cz]
1267
1268 c00 = c000 * wfx + c100 * wcx
1269 c01 = c001 * wfx + c101 * wcx
1270 c10 = c010 * wfx + c110 * wcx
1271 c11 = c011 * wfx + c111 * wcx
1272
1273 c0 = c00 * wfy + c10 * wcy
1274 c1 = c01 * wfy + c11 * wcy
1275
1276 c = c0 * wfz + c1 * wcz
1277
1278 Yc = torch.zeros(II.shape, device='cpu')
1279 Yc[ok] = c.float()
1280
1281 Y[:, :, :, channel] = Yc
1282
1283 return Y
1284
1285
1286 def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True):
1287
1288 # get info
1289 new_volume = volume.copy() if return_copy else volume
1290 n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims
1291
1292 # crop image
1293 if n_dims == 2:
1294 new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...]
1295 elif n_dims == 3:
1296 new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...]
1297 else:
1298 sf.system.fatal('cannot crop volumes with more than 3 dimensions')
1299
1300 if aff is not None:
1301 aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3]
1302 return new_volume, aff
1303 else:
1304 return new_volume
1305
1306
1307 def get_largest_connected_component(mask, structure=None):
1308 components, n_components = scipy_label(mask, structure)
1309 return components == np.argmax(np.bincount(components.flat)[1:]) + 1 if n_components > 0 else mask.copy()
1310
1311
1312 def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0,
1313 return_mask=False, return_copy=True):
1314
1315 # get info
1316 new_volume = volume.copy() if return_copy else volume
1317 vol_shape = list(new_volume.shape)
1318 n_dims, n_channels = get_dims(vol_shape)
1319
1320 # get mask and erode/dilate it
1321 if mask is None:
1322 mask = new_volume >= threshold
1323 else:
1324 assert list(mask.shape[:n_dims]) == vol_shape[:n_dims], 'mask should have shape {0}, or {1}, had {2}'.format(
1325 vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape))
1326 mask = mask > 0
1327 if dilate > 0:
1328 dilate_struct = build_binary_structure(dilate, n_dims)
1329 mask_to_apply = binary_dilation(mask, dilate_struct)
1330 else:
1331 mask_to_apply = mask
1332 if erode > 0:
1333 erode_struct = build_binary_structure(erode, n_dims)
1334 mask_to_apply = binary_erosion(mask_to_apply, erode_struct)
1335 if fill_holes:
1336 mask_to_apply = binary_fill_holes(mask_to_apply)
1337
1338 # replace values outside of mask by padding_char
1339 if mask_to_apply.shape == new_volume.shape:
1340 new_volume[np.logical_not(mask_to_apply)] = masking_value
1341 else:
1342 new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = masking_value
1343
1344 if return_mask:
1345 return new_volume, mask_to_apply
1346 else:
1347 return new_volume
1348
1349
1350 def build_binary_structure(connectivity, n_dims, shape=None):
1351 if shape is None:
1352 shape = [connectivity * 2 + 1] * n_dims
1353 else:
1354 shape = reformat_to_list(shape, length=n_dims)
1355 dist = np.ones(shape)
1356 center = tuple([tuple([int(s / 2)]) for s in shape])
1357 dist[center] = 0
1358 dist = distance_transform_edt(dist)
1359 struct = (dist <= connectivity) * 1
1360 return struct
1361
1362
1363
1364 def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'):
1365
1366 assert (cropping_margin is not None) | (cropping_shape is not None), \
'cropping_margin or cropping_shape should be provided'
1367 assert not ((cropping_margin is not None) & (cropping_shape is not None)), \
'only one of cropping_margin or cropping_shape should be provided'
1368
1369 # get info
1370 new_volume = volume.copy()
1371 vol_shape = new_volume.shape
1372 n_dims, _ = get_dims(vol_shape)
1373
1374 # find cropping indices
1375 if cropping_margin is not None:
1376 cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
1377 do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
1378 min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
1379 max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
1380 else:
1381 cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
1382 if mode == 'center':
1383 min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
1384 max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
1385 np.array(vol_shape)[:n_dims])
1386 elif mode == 'random':
1387 crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
1388 min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
1389 max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
1390 else:
1391 raise ValueError('mode should be either "center" or "random", had %s' % mode)
1392 crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])
1393
1394 # crop volume
1395 if n_dims == 2:
1396 new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...]
1397 elif n_dims == 3:
1398 new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...]
1399
1400 # sort outputs
1401 output = [new_volume]
1402 if aff is not None:
1403 aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ np.array(min_crop_idx)
1404 output.append(aff)
1405 if return_crop_idx:
1406 output.append(crop_idx)
1407 return output[0] if len(output) == 1 else tuple(output)
1408
1409
1410
1411 def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2., max_percentile=98., use_positive_only=False):
1412
1413 # select only positive intensities
1414 new_volume = volume.copy()
1415 intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten()
1416
1417 # define min and max intensities in original image for normalisation
1418 robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile)
1419 robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile)
1420
1421 # trim values outside range
1422 new_volume = np.clip(new_volume, robust_min, robust_max)
1423
1424 # rescale image
1425 if robust_min != robust_max:
1426 return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min)
1427 else: # avoid dividing by zero
1428 return np.zeros_like(new_volume)
1429
1430
1431
1432
1433 def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False):
1434 # get info
1435 new_volume = volume.copy()
1436 vol_shape = new_volume.shape
1437 n_dims, n_channels = get_dims(vol_shape)
1438 padding_shape = reformat_to_list(padding_shape, length=n_dims, dtype='int')
1439
1440 # check if need to pad
1441 if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')):
1442
1443 # get padding margins
1444 min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
1445 max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
1446 pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])])
1447 pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)])
1448 if n_channels > 1:
1449 pad_margins = tuple(list(pad_margins) + [(0, 0)])
1450
1451 # pad volume
1452 new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value)
1453
1454 if aff is not None:
1455 if n_dims == 2:
1456 min_margins = np.append(min_margins, 0)
1457 aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins
1458
1459 else:
1460 pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])])
1461
1462 # sort outputs
1463 output = [new_volume]
1464 if aff is not None:
1465 output.append(aff)
1466 if return_pad_idx:
1467 output.append(pad_idx)
1468 return output[0] if len(output) == 1 else tuple(output)
1469
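# Illustrative use of pad_volume (commented; 'seg' is a hypothetical array):
#   seg = np.ones((150, 150, 180))
#   padded, pad_idx = pad_volume(seg, padding_shape=[160, 160, 192], return_pad_idx=True)
#   # padded.shape == (160, 160, 192); pad_idx holds [min_margins, min_margins + original shape],
#   # so the original data can be recovered with
#   # padded[pad_idx[0]:pad_idx[3], pad_idx[1]:pad_idx[4], pad_idx[2]:pad_idx[5]]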
1470
1471
1472 def add_axis(x, axis=0):
1473 axis = reformat_to_list(axis)
1474 for ax in axis:
1475 x = np.expand_dims(x, axis=ax)
1476 return x
1477
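# Illustrative (commented): add_axis(x, axis=[0, -1]) adds a leading batch axis and a
# trailing channel axis, e.g. a hypothetical (160, 160, 192) volume becomes
# (1, 160, 160, 192, 1), a batched channels-last layout:
#   x_batched = add_axis(np.zeros((160, 160, 192)), axis=[0, -1])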
1478
1479 def volshape_to_meshgrid(volshape, **kwargs):
1480 """
1481 compute Tensor meshgrid from a volume size
1482 """
1483
1484 isint = [float(d).is_integer() for d in volshape]
1485 if not all(isint):
1486 raise ValueError("volshape needs to be a list of integers")
1487
1488 linvec = [tf.range(0, d) for d in volshape]
1489 return meshgrid(*linvec, **kwargs)
1490
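# Illustrative use of volshape_to_meshgrid (commented): for a 3D volume size it returns
# one coordinate tensor per axis, each of shape volshape.
#   grid = volshape_to_meshgrid([160, 160, 192], indexing='ij')
#   # grid[0][i, j, k] == i, grid[1][i, j, k] == j, grid[2][i, j, k] == k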
1491
1492 def meshgrid(*args, **kwargs):
1493
1494 indexing = kwargs.pop("indexing", "xy")
1495 if kwargs:
1496 key = list(kwargs.keys())[0]
1497 raise TypeError("'{}' is an invalid keyword argument "
1498 "for this function".format(key))
1499
1500 if indexing not in ("xy", "ij"):
1501 raise ValueError("indexing parameter must be either 'xy' or 'ij'")
1502
1503 # with ops.name_scope(name, "meshgrid", args) as name:
1504 ndim = len(args)
1505 s0 = (1,) * ndim
1506
1507 # Prepare reshape by inserting dimensions with size 1 where needed
1508 output = []
1509 for i, x in enumerate(args):
1510 output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
1511 # Create parameters for broadcasting each tensor to the full size
1512 shapes = [tf.size(x) for x in args]
1513 sz = [x.get_shape().as_list()[0] for x in args]
1514
1515 # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype
1516 if indexing == "xy" and ndim > 1:
1517 output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
1518 output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
1519 shapes[0], shapes[1] = shapes[1], shapes[0]
1520 sz[0], sz[1] = sz[1], sz[0]
1521
1522 # This is the part of the implementation from tf that is slow.
1523 # We replace it below to get a ~6x speedup (essentially using tile instead of * tf.ones())
1524 # mult_fact = tf.ones(shapes, output_dtype)
1525 # return [x * mult_fact for x in output]
1526 for i in range(len(output)):
1527 stack_sz = [*sz[:i], 1, *sz[(i + 1):]]
1528 if indexing == 'xy' and ndim > 1 and i < 2:
1529 stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0]
1530 output[i] = tf.tile(output[i], tf.stack(stack_sz))
1531 return output
1532
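# The tile-based broadcasting above produces the same tensors as multiplying each reshaped
# vector by tf.ones(shapes) but avoids materialising the ones tensor. Minimal sketch of the
# equivalence (commented, hypothetical values):
#   a = tf.reshape(tf.range(3), (3, 1))
#   b = tf.tile(a, [1, 4])                       # explicit replication
#   c = a * tf.ones((3, 4), dtype=a.dtype)       # broadcast against ones
#   # tf.reduce_all(tf.equal(b, c)) evaluates to True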
1533
1534 def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
1535
1536 # convert sigma into a tensor
1537 if not tf.is_tensor(sigma):
1538 sigma_tens = tf.convert_to_tensor(reformat_to_list(sigma), dtype='float32')
1539 else:
1540 assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
1541 sigma_tens = sigma
1542 shape = sigma_tens.get_shape().as_list()
1543
1544 # get n_dims and batchsize
1545 if shape[0] is not None:
1546 n_dims = shape[0]
1547 batchsize = None
1548 else:
1549 n_dims = shape[1]
1550 batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]
1551
1552 # reformat max_sigma
1553 if max_sigma is not None: # dynamic blurring
1554 max_sigma = np.array(reformat_to_list(max_sigma, length=n_dims))
1555 else: # sigma is fixed
1556 max_sigma = np.array(reformat_to_list(sigma, length=n_dims))
1557
1558    # randomise the blurring std dev and/or split it between dimensions
1559 if blur_range is not None:
1560 if blur_range != 1:
1561 sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)
1562
1563 # get size of blurring kernels
1564 windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1
1565
1566 if separable:
1567
1568 split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)
1569
1570 kernels = list()
1571 comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
1572 for (i, wsize) in enumerate(windowsize):
1573
1574 if wsize > 1:
1575
1576 # build meshgrid and replicate it along batch dim if dynamic blurring
1577 locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
1578 if batchsize is not None:
1579 locations = tf.tile(tf.expand_dims(locations, axis=0),
1580 tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
1581 axis=0))
1582 comb[i] += 1
1583
1584 # compute gaussians
1585 exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
1586 g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
1587 g = g / tf.reduce_sum(g)
1588
1589 for axis in comb[i]:
1590 g = tf.expand_dims(g, axis=axis)
1591 kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))
1592
1593 else:
1594 kernels.append(None)
1595
1596 else:
1597
1598 # build meshgrid
1599 mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
1600 diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
1601
1602 # replicate meshgrid to batch size and reshape sigma_tens
1603 if batchsize is not None:
1604 diff = tf.tile(tf.expand_dims(diff, axis=0),
1605 tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
1606 for i in range(n_dims):
1607 sigma_tens = tf.expand_dims(sigma_tens, axis=1)
1608 else:
1609 for i in range(n_dims):
1610 sigma_tens = tf.expand_dims(sigma_tens, axis=0)
1611
1612 # compute gaussians
1613 sigma_is_0 = tf.equal(sigma_tens, 0)
1614 exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
1615 norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
1616 kernels = K.sum(norms, -1)
1617 kernels = tf.exp(kernels)
1618 kernels /= tf.reduce_sum(kernels)
1619 kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)
1620
1621 return kernels
1622
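# Illustrative use of gaussian_kernel (commented; note that the separable branch relies on
# combinations() from itertools being imported elsewhere in this file):
#   kernels = gaussian_kernel([1., 1., 1.], separable=True)
#   # -> list with one 1-D kernel per axis, shaped for tf.nn.conv3d (None where the
#   #    window size collapses to 1)
#   kernel = gaussian_kernel([1., 1., 1.], separable=False)
#   # -> a single dense 3-D kernel with two trailing singleton (in/out channel) dimensions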
1623
1624 def get_mapping_lut(source, dest=None):
1625     """This function returns the look-up table to map a list of N values (source) to another list (dest).
1626 If the second list is not given, we assume it is equal to [0, ..., N-1]."""
1627
1628 # initialise
1629 source = np.array(reformat_to_list(source), dtype='int32')
1630 n_labels = source.shape[0]
1631
1632    # build new label list if necessary
1633 if dest is None:
1634 dest = np.arange(n_labels, dtype='int32')
1635 else:
1636 assert len(source) == len(dest), 'label_list and new_label_list should have the same length'
1637 dest = np.array(reformat_to_list(dest, dtype='int'))
1638
1639 # build look-up table
1640 lut = np.zeros(np.max(source) + 1, dtype='int32')
1641    for src, dst in zip(source, dest):
1642        lut[src] = dst
1643
1644 return lut
1645
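# Illustrative use of get_mapping_lut (commented):
#   lut = get_mapping_lut([0, 2, 3, 4, 17], dest=[0, 1, 2, 3, 4])
#   # lut is a dense int32 array of length 18 with lut[0]=0, lut[2]=1, lut[3]=2, lut[4]=3,
#   # lut[17]=4 and zeros elsewhere, so label volumes can be remapped with lut[labels]
#   # (numpy) or tf.gather(lut, labels) as done in ConvertLabels below.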
1646
1647 class GaussianBlur(KL.Layer):
1648 """Applies gaussian blur to an input image."""
1649
1650 def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
1651 self.sigma = reformat_to_list(sigma)
1652        assert np.all(np.array(self.sigma) >= 0), 'sigma should be greater than or equal to 0'
1653 self.use_mask = use_mask
1654
1655 self.n_dims = None
1656 self.n_channels = None
1657 self.blur_range = random_blur_range
1658 self.stride = None
1659 self.separable = None
1660 self.kernels = None
1661 self.convnd = None
1662 super(GaussianBlur, self).__init__(**kwargs)
1663
1664 def get_config(self):
1665 config = super().get_config()
1666 config["sigma"] = self.sigma
1667 config["random_blur_range"] = self.blur_range
1668 config["use_mask"] = self.use_mask
1669 return config
1670
1671 def build(self, input_shape):
1672
1673 # get shapes
1674 if self.use_mask:
1675 assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
1676 self.n_dims = len(input_shape[0]) - 2
1677 self.n_channels = input_shape[0][-1]
1678 else:
1679 self.n_dims = len(input_shape) - 2
1680 self.n_channels = input_shape[-1]
1681
1682 # prepare blurring kernel
1683 self.stride = [1]*(self.n_dims+2)
1684 self.sigma = reformat_to_list(self.sigma, length=self.n_dims)
1685 self.separable = np.linalg.norm(np.array(self.sigma)) > 5
1686 if self.blur_range is None: # fixed kernels
1687 self.kernels = gaussian_kernel(self.sigma, separable=self.separable)
1688 else:
1689 self.kernels = None
1690
1691 # prepare convolution
1692 self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
1693
1694 self.built = True
1695 super(GaussianBlur, self).build(input_shape)
1696
1697 def call(self, inputs, **kwargs):
1698
1699 if self.use_mask:
1700 image = inputs[0]
1701 mask = tf.cast(inputs[1], 'bool')
1702 else:
1703 image = inputs
1704 mask = None
1705
1706 # redefine the kernels at each new step when blur_range is activated
1707 if self.blur_range is not None:
1708 self.kernels = gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)
1709
1710 if self.separable:
1711 for k in self.kernels:
1712 if k is not None:
1713 image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME')
1714 for n in range(self.n_channels)], -1)
1715 if self.use_mask:
1716 maskb = tf.cast(mask, 'float32')
1717 maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME')
1718 for n in range(self.n_channels)], -1)
1719 image = image / (maskb + keras.backend.epsilon())
1720 image = tf.where(mask, image, tf.zeros_like(image))
1721 else:
1722 if any(self.sigma):
1723 image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME')
1724 for n in range(self.n_channels)], -1)
1725 if self.use_mask:
1726 maskb = tf.cast(mask, 'float32')
1727 maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME')
1728 for n in range(self.n_channels)], -1)
1729 image = image / (maskb + keras.backend.epsilon())
1730 image = tf.where(mask, image, tf.zeros_like(image))
1731
1732 return image
1733
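# Illustrative use of the GaussianBlur layer (commented; shapes and sigmas are hypothetical):
#   x = KL.Input(shape=(160, 160, 192, 1))
#   y = GaussianBlur(sigma=1.)(x)                              # fixed-width blur
#   y = GaussianBlur(sigma=1., random_blur_range=1.15)(x)      # sigma resampled at each call
#   # with use_mask=True the layer takes [image, mask], divides the blurred image by the
#   # blurred mask (plus epsilon) and zeroes everything outside the mask, so background
#   # voxels do not bleed into the blurred result.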
1734
1735 class ConvertLabels(KL.Layer):
1736
1737 def __init__(self, source_values, dest_values=None, **kwargs):
1738 self.source_values = source_values
1739 self.dest_values = dest_values
1740 self.lut = None
1741 super(ConvertLabels, self).__init__(**kwargs)
1742
1743 def get_config(self):
1744 config = super().get_config()
1745 config["source_values"] = self.source_values
1746 config["dest_values"] = self.dest_values
1747 return config
1748
1749 def build(self, input_shape):
1750 self.lut = tf.convert_to_tensor(get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32')
1751 self.built = True
1752 super(ConvertLabels, self).build(input_shape)
1753
1754 def call(self, inputs, **kwargs):
1755 return tf.gather(self.lut, tf.cast(inputs, dtype='int32'))
1756
1757
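# Illustrative use of ConvertLabels (commented): remaps arbitrary label values inside the
# graph, either to a contiguous 0..N-1 range or to explicit destination values.
#   compact = ConvertLabels(np.array([0, 2, 3, 4]))(seg_tensor)                        # -> 0..3
#   restored = ConvertLabels(np.arange(4), dest_values=np.array([0, 2, 3, 4]))(compact)
#   # 'seg_tensor' is a hypothetical tensor of integer label values.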
1758
1759
1760 # execute script
1761 if __name__ == '__main__':
1762 main()