%
% SynthSeg(input, outputseg, modelfile, outputvols)
%
% Contrast- and resolution-agnostic segmentation of brain MRI using domain
% randomization.
%
% Usage: SynthSeg(input, outputseg, modelfile)
%        SynthSeg(input, outputseg, modelfile, outputvols)
%
% input: can be a nifti file (.nii/.nii.gz) or a directory (in which case,
%        all nifti files in the directory will be processed).
%
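% outputseg: like the input, can be a nifti file or a directory. If it is a
%         file, it should not exist (code will throw an error if it does).
%         If it is a directory, it is created if needed, and each input
%         <name>.nii / <name>.nii.gz is segmented into <name>.seg.nii /
%         <name>.seg.nii.gz inside it (these files must not already exist).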
%
% modelfile: h5 file with the weights of the neural network. If the file
%            does not exist, the code will download it from MathWorks (so
%            next time you run the code it will already be there).
%
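% outputvols: (optional) csv file with the estimated volumes of the brain
%             ROIs, one row per processed image. The file must not already
%             exist (code will throw an error if it does).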
%
%
% If you use this method in a publication, please cite:
% [1] Billot, Benjamin, Douglas N. Greve, Oula Puonti, Axel Thielscher, Koen Van Leemput,
% Bruce Fischl, Adrian V. Dalca, and Juan Eugenio Iglesias. "SynthSeg: Domain Randomisation
% for Segmentation of Brain Scans of Any Contrast and Resolution." ArXiv:2107.09559, 2021.
% http://arxiv.org/abs/2107.09559.
%
function SynthSeg(input, outputseg, modelfile, outputvols)

% Check the number of arguments; outputvols (CSV of ROI volumes) is optional
if nargin~=3 && nargin~=4
    error('SynthSeg requires 3 or 4 arguments (see help).')
end
if nargin==3
    outputvols=[];
end

% Normalize string scalars to character vectors: all path building below
% uses [ ... ] concatenation, which only yields a single path for char.
% (char([]) is '', so an omitted outputvols stays empty.)
input = char(input);
outputseg = char(outputseg);
modelfile = char(modelfile);
outputvols = char(outputvols);

% Decide between single-file and directory mode, and make sure files and
% directories exist (or not) as needed. endsWith is used instead of
% indexing input(end-k:end), which errors on names shorter than the suffix.
filemode = endsWith(input, '.nii') || endsWith(input, '.nii.gz');
if filemode
    disp('Input is a single nifti file.');
    if exist(input,'file')==0
        error('Input file does not exist.');
    end
    if exist(outputseg,'file')>0
        error('Output file already exists.');
    end
    [filepath,~,~] = fileparts(outputseg);
    if ~isempty(filepath) && exist(filepath,'dir')==0
        error(['Path of output file does not exist: ' filepath]);
    end
else
    disp('Input is not nifti file; assuming that input is a directory.');
    if exist(input,'dir')==0
        error('Input directory does not exist.');
    end
    if exist(outputseg,'dir')==0
        disp(['Output directory does not exist; creating: ' outputseg]);
        mkdir(outputseg);
    end
end

% Download the network weights from MathWorks if they are not present yet
if exist(modelfile,'file')>0
    disp('Model file found; no need to download.');
else
    disp('Model file not found; downloading (it''s 50MB, it may take a bit, depending on your connection...)');
    [filepath,~,~] = fileparts(modelfile);
    if ~isempty(filepath) && exist(filepath,'dir')==0
        error(['The directory where I am trying to save the model file does not exist: ' filepath]);
    end
    websave(modelfile,'https://www.mathworks.com/supportfiles/image/data/trainedBrainSynthSegNetwork.h5');
    disp('Download complete');
end

% Validate the optional CSV path. exist() needs an actual name, so it is
% only called when a CSV path was given.
if ~isempty(outputvols)
    if exist(outputvols,'file')>0
        error('CSV file with output volumes already exists.');
    end
    [filepath,~,~] = fileparts(outputvols);
    if ~isempty(filepath) && exist(filepath,'dir')==0
        error(['The directory where I am trying to save the CSV file does not exist: ' filepath]);
    end
end

% Prepare list of files to process (and the matching output file names)
inputs={};
outputs={};
if filemode
    inputs{1}=input;
    outputs{1}=outputseg;
else
    % <name>.nii -> <name>.seg.nii, then <name>.nii.gz -> <name>.seg.nii.gz;
    % refuse to overwrite any existing segmentation.
    d=dir([input filesep '*.nii']);
    for i=1:length(d)
        inputs{end+1}=[input filesep d(i).name]; %#ok<AGROW>
        outputs{end+1}=[outputseg filesep d(i).name(1:end-4) '.seg.nii']; %#ok<AGROW>
        if exist(outputs{end},'file')
            error(['Output file already exists: ' outputs{end}]);
        end
    end
    d=dir([input filesep '*.nii.gz']);
    for i=1:length(d)
        inputs{end+1}=[input filesep d(i).name]; %#ok<AGROW>
        outputs{end+1}=[outputseg filesep d(i).name(1:end-7) '.seg.nii.gz']; %#ok<AGROW>
        if exist(outputs{end},'file')
            error(['Output file already exists: ' outputs{end}]);
        end
    end
    if isempty(inputs)
        warning('SynthSeg:noInputs', 'No nifti files (.nii/.nii.gz) found in input directory: %s', input);
    end
end


% Go over images and segment
disp('Segmenting!')
currentSize = [0, 0, 0];
net=[];
[classNames,labelIDs] = getBrainCANDISegmentationLabels;
% create CSV file if needed (and write first line with column titles)
if ~isempty(outputvols)
    fid=fopen(outputvols,'w');
    if fid<0
        error(['Could not open CSV file for writing: ' outputvols]);
    end
    % Guarantees the CSV is closed even if processing of some image fails;
    % clearing this object (at the end) runs fclose.
    closeCSV = onCleanup(@() fclose(fid));
    fprintf(fid,'FileName');
    for i=2:length(classNames) % skip the background
        fprintf(fid,',%s',classNames{i});
    end
    fprintf(fid,'\n');
end

for i=1:length(inputs)

    disp(['   Working on image ' num2str(i) ' of ' num2str(length(inputs)) '.']);

    disp('      Reading image...');
    metaData = niftiinfo(inputs{i});
    X = niftiread(metaData);

    disp('      Preprocessing image...');
    [X1,aff] = preProcessBrainCANDIData(X,metaData);
    inputSize = size(X1);
    % read network weights (the first time, or if shape changes for some reason);
    % isequal also copes with size vectors of different lengths
    if isempty(net) || ~isequal(currentSize,inputSize)
        disp('      Reading in network weights...');
        % Silence warnings during the Keras import, then restore the
        % caller's previous warning state (also if the import fails).
        warnState = warning('off','all');
        try
            lgraph = importKerasLayers(modelfile,ImportWeights=true,ImageInputSize=inputSize);
            % Swap the network's "unet_prediction" layer for a softmax layer
            lgraph = replaceLayer(lgraph,"unet_prediction",softmaxLayer);
            net = dlnetwork(lgraph);
        catch err
            warning(warnState);
            rethrow(err);
        end
        warning(warnState);
        currentSize=inputSize;
    else
        disp('      No need to read network weights again.')
    end

    disp('      Prediction 1: original orientation...');
    X2 = dlarray(X1,"SSSCB");
    predictIm = predict(net,X2);

    disp('      Prediction 2: flipped orientation...');
    % Second prediction on the volume mirrored along dimension 1.
    % NOTE(review): postProcessBrainCANDIData presumably un-flips and
    % combines both predictions - confirm against that helper.
    flippedData = dlarray(flip(X1,1),"SSSCB");
    flipPredictIm = predict(net,flippedData);

    disp('      Postprocessing...')
    [seg, vols] = postProcessBrainCANDIData(predictIm,flipPredictIm,labelIDs);

    % write volumes if needed: first column is the processed file
    % (CSV-quoted when the path contains a comma or a double quote)
    if ~isempty(outputvols)
        rowName = inputs{i};
        if any(rowName==',') || any(rowName=='"')
            rowName = ['"' strrep(rowName,'"','""') '"'];
        end
        fprintf(fid,'%s',rowName);
        for v=2:length(vols)  % skip the background
            fprintf(fid,',%f',vols(v));
        end
        fprintf(fid,'\n');
    end

    disp('      Writing to disk...');
    % Reuse the input header, adapted to an 8-bit label volume whose
    % geometry is given by the affine returned by preprocessing
    metaData.Datatype='uint8';
    metaData.BitsPerPixel=8;
    metaData.PixelDimensions=ones(1,3);
    metaData.Transform.T=aff';
    metaData.ImageSize=size(seg);
    metaData.raw=[];
    niftiwrite(uint8(seg), outputs{i}, metaData);

    disp('      Image done!');

end
% close CSV file if needed (deleting the onCleanup object calls fclose)
if ~isempty(outputvols)
    clear closeCSV
end
disp('All done!');





