#!/usr/bin/env python
import warnings
warnings.filterwarnings('ignore', '.*negative int.*')
import os, sys, optparse
from sets import Set
import fsutils

# Original Version - Douglas Greve, MGH
# Rewrite - Krish Subramaniam, MGH
# $Id: asegstats2table,v 1.32 2010/03/30 16:30:51 greve Exp $

# For each supported measure, the index (0-based, after splitting a data row
# of the .stats file on whitespace) of the column that holds it.
seg2statcol_map = dict(volume=3, mean=5, std=6)
# Delimiter names accepted on the command line and the literal character
# each one stands for in the output table.
delimiter2char = dict(comma=',', tab='\t', space=' ', semicolon=';')

# Usage text of the program; it is handed to optparse as the usage string and
# therefore shown by --help.
helptext = """
Converts a subcortical stats file created by recon-all and/or
mri_segstats (eg, aseg.stats) into a table in which each line is a
subject and each column is a segmentation (there is an option to
transpose that). The values are the volume of the segmentation in mm3,
or the mean or standard deviation of the intensity over the structure.
The first row is a list of the segmentation names. The first column
is the subject name. If the measure is volume and the subjects are
specified by name (not with -i or --inputs), the global volume measures
found in the header of the stats file (eg, the estimated intracranial
volume and BrainSegVol) are added as the last columns.

The subjects list can be specified in one of five ways:
  1. Specify each subject after -s

          -s subject1 -s subject2 ..

  2. Specify all subjects after --subjects. --subjects does not have
     to be the last argument. Eg:

          --subjects subject1 subject2 ...

  3. Specify each input file after -i

          -i subject1/stats/aseg.stats -i subject2/stats/aseg.stats ..

  4. Specify all the input stat files after --inputs. --inputs does not have
     to be the last argument. Eg:

          --inputs subject1/stats/aseg.stats subject2/stats/aseg.stats ...

  5. Specify a file which lists the subjects, one subject per line, after
     --subjectsfile. Eg:

          --subjectsfile subjects.txt

Methods 1, 2 and 5 assume the freesurfer directory structure. Methods 3
and 4 are general and can be used with any file produced by
mri_segstats regardless of whether it was created with recon-all or not,
however, the subject name is not printed in the file (just the row
number). Note that subjects (-s, --subjects), inputs (-i, --inputs) and
--subjectsfile are mutually exclusive, i.e. don't specify --subjects when
you are providing --inputs and vice versa.

By default, the volume (mm3) of each segmentation is reported. This
can be changed with '--meas measure', where measure can be
volume, mean or std. If mean, it reports the mean intensity value from
the 6th column of the stats file. If std, it reports the standard
deviation of the intensity from the 7th column.

By default, all segmentations found in the input stats file are
reported. This can be changed by specifying the maximum segmentation
number with --maxsegno. This can be convenient for removing
segmentations that are always empty.

Methods 1, 2 and 5 above use stats/aseg.stats by default.
This can be changed to subdir/statsfile with --subdir subdir
--statsfile statsfile.

The --common-segs flag outputs only the segmentations which are common to *all*
the statsfiles. This option is helpful if one or more statsfiles contain
segmentations different from the segs of other files (which results in the
script exiting, the default behavior). This option makes the
script continue.

The --all-segs flag outputs segmentations which are the union of all
segmentations in all statsfiles. This option is helpful if one or more
statsfiles contain segs different from the segs of other files (which results
in the script exiting, the default behavior). Subjects which don't have a
certain segmentation show a value of 0.

The --segno option outputs only the segmentations requested.
This is useful because if the number of segmentations is large,
the table becomes huge.

The --no-segno option excludes the given segmentations from the output.
This can be convenient for removing segs that are always empty.

The --transpose flag writes the transpose of the table.
This might be a useful way to see the table when the number of subjects is
small relative to the number of segmentations.

The --delimiter option controls what character comes between the measures
in the table. Valid options are 'tab' (default), 'space', 'comma' and
'semicolon'.

The --skip option skips a subject if its .stats file can't be found. The
default behavior is to exit the program.
"""

def options_parse():
    """
    Command Line Options Parser for asegstats2table.

    Builds the optparse parser, parses the command line and validates the
    result. On any missing or conflicting option an error message is printed
    and the program exits with status 1.

    Returns the optparse options object. Besides the declared options it
    carries the attribute `dodirect`: True when stats files were given
    directly (-i / --inputs), False when subjects were given
    (-s / --subjects / --subjectsfile).
    """
    parser = optparse.OptionParser(version='$Id: asegstats2table,v 1.32 2010/03/30 16:30:51 greve Exp $', usage=helptext)

    # help text
    h_sub = '(REQUIRED) subject1 <subject2 subject3..>'
    h_s = ' subjectname'
    h_subf = 'name of the file which has the list of subjects ( one subject per line)'
    h_inp = ' input1 <input2 input3..>'
    h_i = ' inputname'
    h_meas = 'measure: default is volume ( alt: mean, std)'
    h_max = ' maximum segmentation number to report'
    h_seg = 'segno1 <segno2 segno3..> : only include given segmentation numbers'
    h_noseg = 'segno1 <segno2 segno3..> : exclude given segmentation numbers'
    h_common = 'output only the common segmentations of all the statsfiles given'
    h_all = 'output all the segmentations of the statsfiles given'
    h_stats = 'use `fname` instead of "aseg.stats"'
    h_subdir = 'use `subdir` instead of "stats/"'
    h_tr = 'transpose the table ( default is subjects in rows and segmentations in cols)'
    h_t = '(REQUIRED) the output tablefile'
    h_deli = 'delimiter between measures in the table. default is tab (alt comma, space, semicolon )'
    h_skip = 'if a subject does not have stats file, skip it'
    h_v = 'increase verbosity'

    # Add options
    # The variable-length options (--subjects, --inputs, --segno, --no-segno)
    # are collected by the fsutils.callback_var callback.
    parser.add_option('--subjects', dest='subjects', action='callback',
                      callback=fsutils.callback_var, help=h_sub)
    parser.add_option('-s', dest='subjects', action='append',
                      help=h_s)
    parser.add_option('--subjectsfile', dest='subjectsfile', help=h_subf)
    parser.add_option('--inputs', dest='inputs', action='callback',
                      callback=fsutils.callback_var, help=h_inp)
    parser.add_option('-i', dest='inputs', action='append',
                      help=h_i)
    parser.add_option('-t', '--tablefile', dest='outputfile',
                      help=h_t)
    parser.add_option('-m', '--meas', dest='meas',
                      choices=('volume', 'mean', 'std'), default='volume', help=h_meas)
    # this option used to show the help text of --inputs by mistake
    parser.add_option('--maxsegno', dest='maxsegno',
                      help=h_max)
    parser.add_option('--segno', dest='segnos', action='callback',
                      callback=fsutils.callback_var, help=h_seg)
    parser.add_option('--no-segno', dest='no_segnos', action='callback',
                      callback=fsutils.callback_var, help=h_noseg)
    parser.add_option('--common-segs', dest='common_flag', action='store_true',
                      default=False, help=h_common)
    parser.add_option('--all-segs', dest='all_flag', action='store_true',
                      default=False, help=h_all)
    # --stats and --statsfile are two spellings of the same option
    parser.add_option('--stats', dest='statsfname',
                      help=h_stats)
    parser.add_option('--statsfile', dest='statsfname',
                      help=h_stats)
    parser.add_option('--subdir', dest='subdir',
                      help=h_subdir)
    parser.add_option('-d', '--delimiter', dest='delimiter',
                      choices=('comma', 'tab', 'space', 'semicolon'),
                      default='tab', help=h_deli)
    parser.add_option('--transpose', action='store_true', dest='transposeflag',
                      default=False, help=h_tr)
    parser.add_option('--skip', action='store_true', dest='skipflag',
                      default=False, help=h_skip)
    parser.add_option('-v', '--debug', action='store_true', dest='verboseflag',
                      default=False, help=h_v)

    (options, args) = parser.parse_args()

    # extensive error checks
    if options.subjects is not None:
        if len(options.subjects) < 1:
            print('ERROR: subjects are not specified (use --subjects SUBJECTS)')
            sys.exit(1)
        else:
            options.dodirect = False

    if options.inputs is not None:
        if len(options.inputs) < 1:
            print('ERROR: inputs are not specified')
            sys.exit(1)
        else:
            options.dodirect = True

    if options.subjectsfile is not None:
        options.dodirect = False

    # exactly one of the three ways of naming the stats files must be used
    if options.subjects is None and options.inputs is None and options.subjectsfile is None:
        print('ERROR: Specify one of --subjects, --inputs or --subjectsfile')
        print('       or run with --help for help.')
        sys.exit(1)

    if options.subjects is not None and options.inputs is not None:
        print('ERROR: Both subjects and inputs are specified. Please specify just one ')
        sys.exit(1)

    if options.subjects is not None and options.subjectsfile is not None:
        print('ERROR: Both subjectsfile and subjects are specified. Please specify just one ')
        sys.exit(1)

    if options.inputs is not None and options.subjectsfile is not None:
        print('ERROR: Both subjectsfile and inputs are specified. Please specify just one ')
        sys.exit(1)

    if not options.outputfile:
        print('ERROR: output table name should be specified (use --tablefile FILE)')
        sys.exit(1)

    # --maxsegno arrives as a string; reject anything that is not an
    # integer >= 1 here instead of failing later with a traceback
    if options.maxsegno is not None:
        try:
            maxsegno = int(options.maxsegno)
        except ValueError:
            print('ERROR: maximum segmentation number should be an integer')
            sys.exit(1)
        if maxsegno < 1:
            print('ERROR: maximum number of segs reported shouldnt be less than 1')
            sys.exit(1)

    if options.segnos is not None and len(options.segnos) < 1:
        print('ERROR: segmentation numbers should be specified with that option')
        sys.exit(1)

    if options.no_segnos is not None and len(options.no_segnos) < 1:
        print('ERROR: to be excluded segmentation numbers should be specified with that option')
        sys.exit(1)

    return options

def is_valid_segno(options, seg):
    """
    Tell whether segmentation number `seg` passes the filters the user
    asked for ( --segno, --no-segno and --maxsegno ).
    Returns True if the segmentation should be kept, False otherwise.
    """
    # when an include list was given, anything outside of it is rejected
    on_include_list = options.segnos is None or seg in options.segnos
    if not on_include_list:
        return False
    # anything on the exclude list is rejected
    on_exclude_list = options.no_segnos is not None and seg in options.no_segnos
    if on_exclude_list:
        return False
    # no upper bound requested, so nothing left to check
    if options.maxsegno is None:
        return True
    # both values are compared as integers
    return int(seg) <= int(options.maxsegno)
    
def sanity_check_segs(table, rows, stats, verbose=False):
    """
    Check that every row of the table has the same segmentations.

    table   : mapping of row name -> mapping of segmentation name -> value
    rows    : the row names to check
    stats   : unused, kept for the callers
    verbose : if True, print the debugging banner

    Returns the tuple (union, intersection, flag): the segmentations found in
    any row, the segmentations found in every row, and True only if the two
    are equal ( which they should be! ). When there are no rows at all there
    is nothing to disagree about, so ([], [], True) is returned.
    """
    # nothing was read ( e.g. every stats file was skipped ); without this
    # guard rows[0] below raises an IndexError
    if not rows:
        return ([], [], True)
    if verbose:
        print('Displaying debugging information for each subject, statsfile')
    all_keys = []
    intersection = table[rows[0]].keys()
    for row in rows:
        rowkeys = table[row].keys()
        all_keys.extend(rowkeys)
        # presumably keeps the order of the first argument -- see fsutils
        intersection = fsutils.intersect_order(intersection, rowkeys)
    union = fsutils.unique_union(all_keys)
    return (union, intersection, union == intersection)

# Start of each stats file header line that carries a global volume measure,
# paired with the column name that measure gets in the table. The value is
# read from the 4th comma-separated field of the line.
_VOLUME_MEASURES = (
    ('# Measure lhCortex, lhCortexVol,', 'lhCortexVol'),
    ('# Measure rhCortex, rhCortexVol,', 'rhCortexVol'),
    ('# Measure Cortex, CortexVol,', 'CortexVol'),
    ('# Measure lhCorticalWhiteMatter, lhCorticalWhiteMatterVol,', 'lhCorticalWhiteMatterVol'),
    ('# Measure rhCorticalWhiteMatter, rhCorticalWhiteMatterVol,', 'rhCorticalWhiteMatterVol'),
    ('# Measure SubCortGray, SubCortGrayVol,', 'SubCortGrayVol'),
    ('# Measure TotalGray, TotalGrayVol,', 'TotalGrayVol'),
    ('# Measure SuperTentorial, SuperTentorialVol,', 'SuperTentorialVol'),
    ('# Measure IntraCranialVol, ICV,', 'IntraCranialVol'),
    ('# Measure BrainSeg, BrainSegVol,', 'BrainSegVol'),
)


def _list_stats_files(o):
    """
    Return one (subjectnumber, subjectname, statsfilepath) triple per stats
    file to read. With --inputs the number is the position of the input and
    the name is empty; with subjects the number is empty and the path is
    SUBJECTS_DIR/subject/subdir/statsfile.
    Fills in the defaults of o.subdir and o.statsfname, and o.subjects when
    --subjectsfile was given.
    """
    if o.dodirect:
        return [(str(count), '', inp) for count, inp in enumerate(o.inputs)]
    # check the subjects dir
    subjdir = fsutils.check_subjdirs()
    print('SUBJECTS_DIR : %s' % subjdir)
    if o.subdir is None:
        o.subdir = 'stats'
    if o.statsfname is None:
        o.statsfname = 'aseg.stats'
    # in case the user gave --subjectsfile argument
    if o.subjectsfile is not None:
        o.subjects = []
        try:
            sf = open(o.subjectsfile)
            try:
                for line in sf:
                    subject = line.strip()
                    # a blank line ( e.g. at the end of the file ) is not a
                    # subject
                    if subject:
                        o.subjects.append(subject)
            finally:
                sf.close()
        except IOError:
            print('ERROR: the file %s doesnt exist' % o.subjectsfile)
            sys.exit(1)
    return [('', sub, os.path.join(subjdir, sub, o.subdir, o.statsfname))
            for sub in o.subjects]


def _read_stats_file(o, statfile, table, row):
    """
    Read the stats file `statfile` into table[row] and return the list of the
    column names found in it: the requested segmentations first and then, for
    the volume measure in subject mode, the global volume measures.
    """
    cols = []
    measure_col = seg2statcol_map[o.meas]
    fp = open(statfile, 'r')
    try:
        for line in fp:
            # valid line
            if line.rfind('#') == -1:
                strlst = line.split()
                # blank lines carry no data
                if not strlst:
                    continue
                seg = strlst[1]
                if is_valid_segno(o, seg):
                    segid = strlst[4]
                    cols.append(segid)
                    table[row][segid] = float(strlst[measure_col])
        # if measure is volume, also output ICV, BSV and the other global
        # volumes. The file is read a second time because these must come
        # after the segmentations ( last columns of the table ) although
        # they are in the header of the file. Since we aren't processing a
        # lot of files, this wouldn't hamper the speed.
        if not o.dodirect and o.meas == 'volume':
            fp.seek(0)
            for line in fp:
                for prefix, colname in _VOLUME_MEASURES:
                    if line.startswith(prefix):
                        cols.append(colname)
                        table[row][colname] = float(line.split(',')[3])
                        break
    finally:
        fp.close()
    return cols


def build_table(options):
    """
    This function builds the 2d table (spreadsheet )
    of subjects vs segs. The values are given by what the user chose in
    measure option.

    A stats file which can't be found ends the program, or is skipped when
    --skip was given. If the stats files don't all have the same
    segmentations, the table is reduced to the common ones ( --common-segs ),
    padded with 0.0 to all of them ( --all-segs ), or the program exits.

    Returns the tuple (rows, cols, table): the row names, the column names
    and the mapping row name -> column name -> value.
    """
    o = options
    # a devious trick in which all the subject files take the form
    # (subjectnumber, subjectname, subjectpath)
    # the table's first column will have subjectnumber+subjectname ( concat)
    # since --subjects and --inputs are mut-ex, so are subjectnumber and
    # subjectname
    stats = _list_stats_files(o)

    rows = []
    # the columns of the most recently read stats file
    cols = []
    # init the table
    table = fsutils.Ddict(fsutils.StableDict)
    for statno, subname, statfile in stats:
        if not os.path.exists(statfile):
            if not o.skipflag:
                print('ERROR: cannot find %s' % statfile)
                print('Use --skip flag if you want to continue in such cases')
                sys.exit(1)
            print('WARNING: cannot find %s. proceeding to next file' % statfile)
            continue
        if o.verboseflag:
            print('Processing file %s' % statfile)
        row = statno + subname
        rows.append(row)
        # if file size is less than 10, not a valid segstats file. warn
        # about it but carry on
        if os.path.getsize(statfile) < 10:
            print('WARNING: File:' + statfile + ' is not a valid aseg statsfile')
            print('Exclude it from the list of subjects/don\'t use common-segs option')
            print('Rerun recon-all -segstats to regenerate a valid statsfile')
        cols = _read_stats_file(o, statfile, table, row)

    # every stats file was skipped, there is nothing to put in a table
    if not rows:
        print('ERROR: none of the stats files could be found')
        sys.exit(1)

    # check sanity
    (union, intersection, sanityflag) = sanity_check_segs(table, rows, stats,
                                                          verbose=o.verboseflag)
    if not sanityflag:
        if o.common_flag:
            # write only the common segmentations
            new_table = fsutils.Ddict(fsutils.StableDict)
            for row in rows:
                for seg in intersection:
                    new_table[row][seg] = table[row][seg]
            table = new_table
            cols = intersection

        elif o.all_flag:
            # write all the segs, 0.0 where a subject doesn't have one
            new_table = fsutils.Ddict(fsutils.StableDict)
            for row in rows:
                for seg in union:
                    if seg not in table[row]:
                        new_table[row][seg] = 0.0
                    else:
                        new_table[row][seg] = table[row][seg]
            table = new_table
            cols = union

        else:
            print('ERROR: All stat files should have the same segmentations')
            print('If one or more stats file have different segs from others,')
            print('use --common-segs or --all-segs flag depending on the need.')
            print('(see help)')
            sys.exit(1)

    # if verbose
    if o.verboseflag:
        for row in rows:
            print(table[row])
            print('-'*40)
    return rows, cols, table

def write_table(options, rows, cols, table):
    """
    Hand the in-memory table to fsutils.TableWriter and have it written to
    options.outputfile, transposed if --transpose was given.
    """
    # the top left cell of the table names the measure
    corner = 'Measure:%s' % (options.meas)
    writer = fsutils.TableWriter(rows, cols, table)
    writer.assign_attributes(filename=options.outputfile,
                             row1col1=corner,
                             delimiter=delimiter2char[options.delimiter])
    if not options.transposeflag:
        writer.write()
    else:
        writer.write_transpose()
    

if __name__ == "__main__":
    # parse the command line, build the table in memory, write it to disk
    options = options_parse()
    if options.verboseflag:
        print(options)
    print('Building the table..')
    rows, cols, table = build_table(options)
    print('Writing the table to %s' % options.outputfile)
    write_table(options, rows, cols, table)
    sys.exit(0)
