import nibabel as nib 
import math
import numpy as np
import os.path
import subprocess
from scipy import ndimage, weave
import operator
#import nipype.interfaces.fsl as fsl
def Normalize(X, numS=-1):
    '''
    Column-wise standardization of X.

    X    : array with dimensions Samples x dim.
    numS : if > 0, the column means and standard deviations are multiplied
           by Samples / numS before being used; the default (-1) uses them
           unchanged.

    Returns an array shaped like X where each column has had its (possibly
    rescaled) mean subtracted and has been divided by its (possibly
    rescaled) standard deviation plus 1e-8. The 1e-8 keeps constant columns
    finite: they come out as (near) zeros instead of NaN.
    '''
    Xm = np.mean(X, axis=0, keepdims=True)
    Xs = np.std(X, axis=0, keepdims=True)
    if numS > 0:
        Xm = Xm * X.shape[0] / numS
        Xs = Xs * X.shape[0] / numS
    # broadcasting over rows replaces the explicit np.tile of the 1 x dim rows
    return (X - Xm) / (Xs + 1e-8)

def PermutationTesting(Y, X, N):
    '''
    Permutation test for the multivariate covariance score of one label.

    Y : data with dimensions Samples x dimY
    X : label with dimensions Samples x 1 (only column 0 is permuted)
    N : number of permutations

    Returns (pval, cov_scores_p, vecs_p). pval is the fraction of permuted
    scores that are greater than or equal to the observed score.
    cov_scores_p and vecs_p are the scores and vectors obtained for every
    permuted label.
    '''
    numRows = X.shape[0]
    shuffled = np.zeros([numRows, N])
    # each column is an independent random reordering of the label
    for col in range(N):
        order = np.random.permutation(numRows)
        shuffled[:, col] = X[order, 0]
    cov_scores_p, vecs_p = MultivariateCovarianceAnalysis_SingleLabel(Y, shuffled)
    observed, _ = MultivariateCovarianceAnalysis_SingleLabel(Y, X)
    exceed = cov_scores_p[0, :] >= observed[0, 0]
    pval = np.sum(exceed) / np.double(N)
    return pval, cov_scores_p, vecs_p

def MultivariateCovarianceAnalysis_SingleLabel(Y, X):
    '''
    Direction and strength of maximal covariance between data and labels.

    Y : data with dimensions Samples x dimY
    X : labels with dimensions Samples x numX

    Returns (cov_score, vec):
      cov_score : 1 x numX, Euclidean norm of the vector of covariances
                  between the normalized columns of Y and each normalized
                  label
      vec       : dimY x numX, those covariance vectors scaled to
                  (approximately) unit norm
    When the normalized Y is all zeros, cov_score is zeros and vec is ones,
    with the same shapes as in the regular case.
    '''
    numSamples = X.shape[0]
    numX = X.shape[1]
    dimY = Y.shape[1]
    Yd = Normalize(Y)
    Xd = Normalize(X)

    if Yd.max() == 0:
        # Presumably meant to catch a constant Y. The original passed numX to
        # np.ones as a dtype (TypeError) and returned a 1-D score, unlike the
        # 1 x numX score of the regular branch.
        cov_score = np.zeros([1, numX])
        vec = np.ones([dimY, numX])
        return cov_score, vec
    mu_X = np.dot(Yd.T, Xd) / np.double(numSamples)
    cov_score = np.sqrt(np.sum(mu_X ** 2, axis=0, keepdims=True))  # 1 x numX
    vec = mu_X / (cov_score + 1e-8)
    return cov_score, vec

def MultivariateCovarianceAnalysis_SingleLabel_AllPoints(Y, X, numVar):
    '''
    Per-point multivariate covariance between data and labels.

    NOTE: this module defines this function twice with the same
    implementation; the later definition is the one bound at import time.

    Y      : data with dimensions Samples x (numVar * numPoints). By the
             reshape below, column j holds variable j // numPoints at point
             j % numPoints.
    X      : labels with dimensions Samples x numX
    numVar : number of variables per point; must divide Y.shape[1]

    Returns (cov_score, vec):
      cov_score : numPoints x numX, norm of the numVar-long covariance
                  vector at each point
      vec       : numVar x numPoints x numX, those vectors scaled to
                  (approximately) unit norm
    Raises ValueError if Y.shape[1] is not a multiple of numVar.
    '''
    numSamples = X.shape[0]
    dimY = Y.shape[1]
    numX = X.shape[1]
    if dimY % numVar != 0:
        raise ValueError('Y has {0} columns, not a multiple of numVar={1}'.format(dimY, numVar))
    # '//' keeps this an int on Python 3 too ('/' gives a float there)
    numPoints = dimY // numVar
    Yd = Normalize(Y)
    Xd = Normalize(X)

    cov_score = np.zeros([numPoints, numX])
    vec = np.zeros([numVar, numPoints, numX])
    mu_X = np.dot(Yd.T, Xd) / np.double(numSamples)
    for n in range(numX):
        mu_Xn = mu_X[:, n].reshape([numVar, numPoints])
        norm_v = np.sqrt(np.sum(mu_Xn ** 2, axis=0))
        vec[:, :, n] = mu_Xn / (norm_v + 1e-8)
        cov_score[:, n] = norm_v
    return cov_score, vec

def MultivariateCovarianceAnalysis_SingleLabel_AllPoints(Y, X, numVar):
    '''
    Per-point multivariate covariance between data and labels.

    Y      : data with dimensions Samples x (numVar * numPoints). By the
             reshape below, column j holds variable j // numPoints at point
             j % numPoints.
    X      : labels with dimensions Samples x numX
    numVar : number of variables per point; must divide Y.shape[1]

    Returns (cov_score, vec):
      cov_score : numPoints x numX, norm of the numVar-long covariance
                  vector at each point
      vec       : numVar x numPoints x numX, those vectors scaled to
                  (approximately) unit norm
    Raises ValueError if Y.shape[1] is not a multiple of numVar.
    '''
    numSamples = X.shape[0]
    dimY = Y.shape[1]
    numX = X.shape[1]
    if dimY % numVar != 0:
        raise ValueError('Y has {0} columns, not a multiple of numVar={1}'.format(dimY, numVar))
    # '//' keeps this an int on Python 3 too ('/' gives a float there)
    numPoints = dimY // numVar
    Yd = Normalize(Y)
    Xd = Normalize(X)

    cov_score = np.zeros([numPoints, numX])
    vec = np.zeros([numVar, numPoints, numX])
    mu_X = np.dot(Yd.T, Xd) / np.double(numSamples)
    for n in range(numX):
        mu_Xn = mu_X[:, n].reshape([numVar, numPoints])
        norm_v = np.sqrt(np.sum(mu_Xn ** 2, axis=0))
        vec[:, :, n] = mu_Xn / (norm_v + 1e-8)
        cov_score[:, n] = norm_v
    return cov_score, vec

def MultivariateCovarianceAnalysis_RegressingOut_AllPoints(Y, X, Z, numVar):
    '''
    Per-point covariance of X with the data after regressing out Z.

    Y      : data with dimensions Samples x (numVar * numPoints), laid out
             as variable-major blocks (column j = variable j // numPoints,
             point j % numPoints)
    X      : labels of interest, Samples x numX
    Z      : confounding labels, Samples x (at least numX); column n of Z is
             paired with column n of X
    numVar : number of variables per point; must divide Y.shape[1]

    Returns (cov_score_orth, vec_orth, cov_score_par):
      cov_score_orth : numPoints x numX, covariance of X along the part of
                       its covariance vector orthogonal to Z's
      vec_orth       : numVar x numPoints x numX, unit vectors of that
                       orthogonal direction
      cov_score_par  : numPoints x numX, covariance of X along Z's direction
                       minus corr(X, Z) times the norm of Z's covariance
                       vector (the part expected from the X-Z correlation);
                       can be negative
    Raises ValueError if Y.shape[1] is not a multiple of numVar.
    '''
    numSamples = X.shape[0]
    dimY = Y.shape[1]
    numX = X.shape[1]
    if dimY % numVar != 0:
        raise ValueError('Y has {0} columns, not a multiple of numVar={1}'.format(dimY, numVar))
    numPoints = dimY // numVar  # '//' keeps this an int on Python 3
    Yd = Normalize(Y)
    Xd = Normalize(X)
    Zd = Normalize(Z)

    cov_score_orth = np.zeros([numPoints, numX])
    cov_score_par = np.zeros([numPoints, numX])
    vec_orth = np.zeros([numVar, numPoints, numX])
    mu_X = np.dot(Yd.T, Xd) / np.double(numSamples)
    mu_Z = np.dot(Yd.T, Zd) / np.double(numSamples)

    for n in range(numX):
        mu_Xn = mu_X[:, n].reshape([numVar, numPoints])
        mu_Zn = mu_Z[:, n].reshape([numVar, numPoints])

        # orthogonal component: remove, per point, the projection of mu_Xn on mu_Zn
        proj = np.sum(mu_Xn * mu_Zn, axis=0) / (np.sum(mu_Zn * mu_Zn, axis=0) + 1e-5)
        v = mu_Xn - proj * mu_Zn
        norm_v = np.sqrt(np.sum(v ** 2, axis=0))
        vec_orth[:, :, n] = v / (norm_v + 1e-8)
        cov_score_orth[:, n] = np.sum(mu_Xn * vec_orth[:, :, n], axis=0)

        # residual parallel component
        norm_w = np.sqrt(np.sum(mu_Zn ** 2, axis=0))
        w = mu_Zn / (norm_w + 1e-8)
        xz = np.dot(Xd[:, n], Zd[:, n])
        cov_score_par[:, n] = np.sum(w * mu_Xn, axis=0) - xz * norm_w / np.double(numSamples)

    return cov_score_orth, vec_orth, cov_score_par
    
def MultivariateCovarianceAnalysis_RegressingOut_AllPoints_OLD(Y, X, Z, numVar):
    '''
    Legacy variant that returns only the orthogonal (Z regressed out) part.

    Y      : data with dimensions Samples x (numVar * numPoints)
    X      : labels of interest, Samples x numX
    Z      : confounding labels, Samples x (at least numX), paired by column
    numVar : number of variables per point; must divide Y.shape[1]

    Returns (cov_score, vec): numPoints x numX covariance of X along the
    direction orthogonal to Z's covariance vector, and the numVar x
    numPoints x numX unit vectors of that direction.
    Raises ValueError if Y.shape[1] is not a multiple of numVar.
    '''
    numSamples = X.shape[0]
    dimY = Y.shape[1]
    numX = X.shape[1]
    if dimY % numVar != 0:
        raise ValueError('Y has {0} columns, not a multiple of numVar={1}'.format(dimY, numVar))
    numPoints = dimY // numVar  # '//' keeps this an int on Python 3
    Yd = Normalize(Y)
    Xd = Normalize(X)
    Zd = Normalize(Z)

    cov_score = np.zeros([numPoints, numX])
    vec = np.zeros([numVar, numPoints, numX])
    mu_X = np.dot(Yd.T, Xd) / np.double(numSamples)
    mu_Z = np.dot(Yd.T, Zd) / np.double(numSamples)
    for n in range(numX):
        mu_Xn = mu_X[:, n].reshape([numVar, numPoints])
        mu_Zn = mu_Z[:, n].reshape([numVar, numPoints])
        # the leftover debug prints of the shapes (Python 2 syntax) were removed
        proj = np.sum(mu_Xn * mu_Zn, axis=0) / (np.sum(mu_Zn * mu_Zn, axis=0) + 1e-5)
        v = mu_Xn - proj * mu_Zn
        norm_v = np.sqrt(np.sum(v ** 2, axis=0))
        vec[:, :, n] = v / (norm_v + 1e-8)
        cov_score[:, n] = np.sum(mu_Xn * vec[:, :, n], axis=0)
    return cov_score, vec
    


def PermutationTesting_AllPoints(Y, X, numVar, N, permPerRun=100):
    '''
    Per-point permutation test of the multivariate covariance score.

    Y          : data with dimensions Samples x (numVar * numPoints)
    X          : label with dimensions Samples x 1 (only column 0 is used)
    numVar     : number of variables per point
    N          : number of permutations
    permPerRun : permutations evaluated per analysis call, which bounds the
                 size of the permuted-label matrix (default 100, the
                 previously hard-coded value)

    Returns (pval, maximal_statistics, cov_score, cov_scores_p):
      pval               : per point, fraction of permuted scores >= the
                           observed score
      maximal_statistics : per permutation, the maximum score over points
      cov_score          : observed scores (numPoints x 1)
      cov_scores_p       : permuted scores (numPoints x N)
    '''
    cov_score, _ = MultivariateCovarianceAnalysis_SingleLabel_AllPoints(Y, X, numVar)
    numRows = X.shape[0]
    cov_scores_p = np.zeros([cov_score.shape[0], N])
    # Chunks of permPerRun, the last one holding the remainder. Permutations
    # are drawn in the same order as before, so seeded runs are unchanged,
    # and no empty analysis call is made when N is a multiple of permPerRun.
    # int() replaces np.int, which recent NumPy versions no longer provide.
    for start in range(0, N, permPerRun):
        stop = min(start + permPerRun, N)
        Xp = np.zeros([numRows, stop - start])
        for n in range(stop - start):
            Xp[:, n] = X[np.random.permutation(numRows), 0]
        cov_scores_p[:, start:stop], _ = MultivariateCovarianceAnalysis_SingleLabel_AllPoints(Y, Xp, numVar)

    # the permutations come from X[:, 0], so compare against the first column
    pval = np.sum(cov_scores_p >= cov_score[:, :1], axis=1) / np.double(N)
    maximal_statistics = np.max(cov_scores_p, axis=0)
    return pval, maximal_statistics, cov_score, cov_scores_p

def TFCECorrection_Basic(stats, rows, data_shape):
    '''
    Apply TFCE to a single statistic map.

    The implementation is chosen from len(data_shape): 3 -> TFCE3D_Update,
    2 -> TFCE2D_Update, anything else -> TFCE1D_Update.
    Returns the enhanced statistics at the positions selected by rows.
    '''
    updaters = {3: TFCE3D_Update, 2: TFCE2D_Update}
    update = updaters.get(len(data_shape), TFCE1D_Update)
    return update(stats, rows, data_shape)

def TFCECorrection_PermutationResults(cov_score, cov_scores_p, rows, data_shape):
    '''
    TFCE-corrected permutation p-values.

    cov_score    : observed scores, numPoints x >=1 (only column 0 is used)
    cov_scores_p : permuted scores, numPoints x N
    rows         : selection of the numPoints positions inside an image of
                   shape data_shape
    data_shape   : image shape; 3-D and 2-D use the matching TFCE, anything
                   else is treated as 1-D

    Returns (pval, maximal_statistics):
      pval               : per point, fraction of permutations whose TFCE
                           score is >= the observed TFCE score
      maximal_statistics : per permutation, the maximum TFCE score
    Prints a progress line every 250 permutations.
    '''
    N = cov_scores_p.shape[1]
    dims = len(data_shape)
    if dims == 3:
        tfce_update = TFCE3D_Update
    elif dims == 2:
        tfce_update = TFCE2D_Update
    else:
        tfce_update = TFCE1D_Update

    cov_score_update = tfce_update(cov_score[:, 0], rows, data_shape)
    cov_scores_p_update = np.zeros(cov_scores_p.shape)
    for k in range(N):
        cov_scores_p_update[:, k] = tfce_update(cov_scores_p[:, k], rows, data_shape)
        if k % 250 == 0:
            # single-argument print() behaves the same on Python 2 and 3
            print('{0} / {1} done.'.format(k, N))
    pval = np.sum(cov_scores_p_update >= cov_score_update[:, np.newaxis], axis=1) / np.double(N)
    maximal_statistics = np.max(cov_scores_p_update, axis=0)
    return pval, maximal_statistics


def PermutationTesting_RegressingOut_AllPoints(Y, X, Z, numVar, N, permPerRun=100):
    '''
    Permutation test for the effect of X on Y after regressing out Z.

    Y          : data with dimensions Samples x (numVar * numPoints)
    X          : label of interest, Samples x 1 (only column 0 is permuted)
    Z          : confounding label, Samples x 1 (e.g. age)
    numVar     : number of variables per point
    N          : number of permutations
    permPerRun : permutations evaluated per analysis call (default 100, the
                 previously hard-coded value)

    X is permuted while Z is kept fixed, so the null hypothesis is that X
    carries no information about Y beyond what Z already explains.

    Returns (pval_orth, maximal_statistics_orth, pval_par,
             maximal_statistics_par, cov_scores_p_orth, cov_scores_p_par).
    The parallel statistics can be negative and are compared by absolute
    value.
    '''
    cov_score_orth, _, cov_score_par = MultivariateCovarianceAnalysis_RegressingOut_AllPoints(Y, X, Z, numVar)
    numRows = X.shape[0]
    cov_scores_p_orth = np.zeros([cov_score_orth.shape[0], N])
    cov_scores_p_par = np.zeros([cov_score_par.shape[0], N])
    # Same chunking and permutation order as before. np.int is gone and no
    # empty call is made when N is a multiple of permPerRun.
    for start in range(0, N, permPerRun):
        stop = min(start + permPerRun, N)
        size = stop - start
        Xp = np.zeros([numRows, size])
        for n in range(size):
            Xp[:, n] = X[np.random.permutation(numRows), 0]
        # Z is held fixed: one copy paired with every permuted column of X
        Zp = np.tile(Z, [1, size])
        orth_p, _, par_p = MultivariateCovarianceAnalysis_RegressingOut_AllPoints(Y, Xp, Zp, numVar)
        cov_scores_p_orth[:, start:stop] = orth_p
        cov_scores_p_par[:, start:stop] = par_p

    pval_orth = np.sum(cov_scores_p_orth >= cov_score_orth[:, :1], axis=1) / np.double(N)
    maximal_statistics_orth = np.max(cov_scores_p_orth, axis=0)
    pval_par = np.sum(np.abs(cov_scores_p_par) >= np.abs(cov_score_par[:, :1]), axis=1) / np.double(N)
    maximal_statistics_par = np.max(np.abs(cov_scores_p_par), axis=0)
    return pval_orth, maximal_statistics_orth, \
        pval_par, maximal_statistics_par, cov_scores_p_orth, cov_scores_p_par

def TFCECorrection_PermutationRegressingOut(cov_score_orth, cov_scores_p_orth, cov_score_par, cov_scores_p_par, rows, data_shape):
    '''
    TFCE-corrected permutation p-values for the regressing-out analysis.

    cov_score_orth / cov_score_par       : observed scores, numPoints x >=1
                                           (column 0 is used)
    cov_scores_p_orth / cov_scores_p_par : permuted scores, numPoints x N
    rows, data_shape                     : position of the points inside the
                                           image (3-D, 2-D, else 1-D TFCE)

    cov_score_orth is non-negative, while cov_score_par can also be negative;
    TFCE is therefore applied to |cov_score_par|.

    Returns (pval_orth, maximal_statistics_orth, pval_par,
             maximal_statistics_par). Prints progress every 250 permutations.
    '''
    N = cov_scores_p_orth.shape[1]
    dims = len(data_shape)
    if dims == 3:
        tfce_update = TFCE3D_Update
    elif dims == 2:
        tfce_update = TFCE2D_Update
    else:  # meaning 1 dimensional
        tfce_update = TFCE1D_Update

    cov_score_par_update = tfce_update(np.abs(cov_score_par[:, 0]), rows, data_shape)
    cov_score_orth_update = tfce_update(cov_score_orth[:, 0], rows, data_shape)

    cov_scores_p_orth_update = np.zeros(cov_scores_p_orth.shape)
    cov_scores_p_par_update = np.zeros(cov_scores_p_par.shape)
    for k in range(N):
        cov_scores_p_orth_update[:, k] = tfce_update(cov_scores_p_orth[:, k], rows, data_shape)
        cov_scores_p_par_update[:, k] = tfce_update(np.abs(cov_scores_p_par[:, k]), rows, data_shape)
        if k % 250 == 0:
            # single-argument print() behaves the same on Python 2 and 3
            print('{0} / {1} done.'.format(k, N))

    # count exceedances once instead of accumulating 1/N per permutation
    pval_orth = np.sum(cov_scores_p_orth_update >= cov_score_orth_update[:, np.newaxis], axis=1) / np.double(N)
    pval_par = np.sum(cov_scores_p_par_update >= cov_score_par_update[:, np.newaxis], axis=1) / np.double(N)
    maximal_statistics_orth = np.max(cov_scores_p_orth_update, axis=0)
    maximal_statistics_par = np.max(np.abs(cov_scores_p_par_update), axis=0)
    return pval_orth, maximal_statistics_orth, pval_par, maximal_statistics_par
    
    
def PermutationTesting_RegressingOut_AllPoints_OLD(Y, X, Z, numVar, N):
    '''
    Legacy permutation test on the orthogonal (Z regressed out) score only.

    Y      : data with dimensions Samples x (numVar * numPoints)
    X      : label of interest, Samples x 1 (only column 0 is permuted)
    Z      : confounding label, Samples x 1, kept fixed across permutations
    numVar : number of variables per point
    N      : number of permutations

    All N permutations are evaluated in a single analysis call.
    Returns (pval, maximal_statistics) for the orthogonal score.
    '''
    Xp = np.zeros([X.shape[0], N])
    Zp = np.tile(Z, [1, N])
    for k in range(N):
        Xp[:, k] = X[np.random.permutation(X.shape[0]), 0]
    # The analysis returns (orth score, orth vectors, parallel score). The
    # original unpacked only two values and always raised ValueError.
    cov_scores_p, _, _ = MultivariateCovarianceAnalysis_RegressingOut_AllPoints(Y, Xp, Zp, numVar)
    cov_score, _, _ = MultivariateCovarianceAnalysis_RegressingOut_AllPoints(Y, X, Z, numVar)
    pval = np.sum(cov_scores_p >= cov_score[:, :1], axis=1) / np.double(N)
    maximal_statistics = np.max(cov_scores_p, axis=0)
    return pval, maximal_statistics

    
def UnivariateCovarianceAnalysis_SingleLabel(Y, X):
    '''
    Covariance score between data and labels.

    Y : data, either a 1-D array of length Samples or Samples x dimY
    X : labels with dimensions Samples x numX
    Returns the cov_score (1 x numX) of MultivariateCovarianceAnalysis_SingleLabel.
    '''
    Y = np.asarray(Y)
    if Y.ndim == 1:
        # the multivariate analysis reads Y.shape[1], so a 1-D Y (as this
        # docstring always allowed) must become a single column first
        Y = Y[:, np.newaxis]
    cov_score, _ = MultivariateCovarianceAnalysis_SingleLabel(Y, X)
    return cov_score
    

def ReshapeResults(cov_scores, vecs, pvals, rows, data_shape, pth=2.):
    '''
    Scatter per-point results back into images of shape data_shape.

    The statistics are only computed at the positions selected by rows;
    every other position of the returned images is 0.

    cov_scores : numPoints x >=1 (column 0 is used)
    vecs       : numChannels x numPoints x >=1 (slice [:, :, 0] is used)
    pvals      : p-value per point
    pth        : threshold on |signed -log10 p| for the *_th outputs

    Returns (CImage, logPvalImage, logPvalImage_th, channels, channels_th).
    channels has shape numChannels x data_shape; channels_th keeps only the
    positions where |logPvalImage| > pth.
    '''
    scores = cov_scores[:, 0]
    CImage = np.zeros(data_shape)
    CImage[rows] = scores

    logPvalImage, logPvalImage_th = ReshapePValImage(scores, pvals, rows, data_shape, pth=pth)

    numChannels = vecs.shape[0]
    channels = np.zeros(np.append(numChannels, data_shape))
    for c in range(numChannels):
        channels[c][rows] = vecs[c, :, 0]
    significant = (np.abs(logPvalImage) > pth).astype(float)
    channels_th = channels * significant  # mask broadcast over the channel axis

    return CImage, logPvalImage, logPvalImage_th, channels, channels_th

def ReshapePValImage(stats, pvals, rows, data_shape, pth=2.):
    '''
    Build a signed -log10(p) image and its thresholded version.

    stats      : statistic per selected point; its sign is carried over to
                 the image, and points with stats == 0 get 0
    pvals      : p-value per selected point
    rows       : selection of the points inside an image of shape data_shape
    pth        : positions with |value| <= pth are zeroed in the second output

    Returns (logPvalImage, logPvalImage_th).
    '''
    neg_log_p = -np.log(pvals) / np.log(10.)
    neg_log_p[stats == 0] = 0.
    full_image = np.zeros(data_shape)
    full_image[rows] = neg_log_p * np.sign(stats)
    keep = np.abs(full_image) > pth
    return full_image, keep.astype(float) * full_image

def WriteResultsInTxt(cov_scores, vecs, pvals, pvalth, imageName, data_shape, rows):
    '''
    Write the results of the synthetic analysis as flat text files.

    The synthetic analysis uses 100 by 100 images.

    cov_scores : covariance value per point
    vecs       : vector per point
    pvals      : permutation-based p-values
    pvalth     : threshold on -log10(p)
    imageName  : prefix prepended verbatim (no separator is added) to every
                 file name, e.g. 'CN_Aging' or a folder-style prefix such as
                 'CN_Analysis/Aging'

    Files written: covariance_strength.txt, pvals_covariance.txt,
    pvals_covariance_th.txt, vectors_th_<c>.txt per channel, and
    data_info.txt (data_shape followed by the number of channels).
    '''
    CImage, logPvalImage, logPvalImage_th, channels, channels_th = ReshapeResults(
        cov_scores, vecs, pvals, rows, data_shape, pth=pvalth)
    numValues = np.prod(data_shape)

    np.savetxt(imageName + 'covariance_strength.txt', CImage.reshape(numValues), fmt='%1.4f')
    for suffix, image in (('pvals_covariance.txt', logPvalImage),
                          ('pvals_covariance_th.txt', logPvalImage_th)):
        np.savetxt(imageName + suffix, image.reshape(numValues))
    for c, channel in enumerate(channels_th):
        np.savetxt(imageName + 'vectors_th_{0}.txt'.format(c), channel.reshape(numValues))
    np.savetxt(imageName + 'data_info.txt', np.append(data_shape, channels_th.shape[0]))

def WritePValResultsInTxt(stats, pvals, pvalth, imageName, data_shape, rows):
    '''
    Write the signed -log10(p) map and its thresholded version as text.

    Files <imageName>pvals.txt and <imageName>pvals_th.txt each hold the
    flattened image, one value per line.
    '''
    logPvalImage, logPvalImage_th = ReshapePValImage(stats, pvals, rows, data_shape, pth=pvalth)
    numValues = np.prod(data_shape)
    for suffix, image in (('pvals.txt', logPvalImage), ('pvals_th.txt', logPvalImage_th)):
        np.savetxt(imageName + suffix, image.reshape(numValues))

def WriteResultsInBinary(cov_scores, vecs, pvals, pvalth, imageName, data_shape, rows):
    '''
    Write the reshaped results as raw binary files (ndarray.tofile: no
    header, C order, native float64).

    Files (imageName is prepended verbatim): covariance_strength.dat,
    pvals_covariance.dat, pvals_covariance_th.dat, vectors_th.dat (all
    channels), and data_info.txt (text: data_shape then number of channels).
    '''
    CImage, logPvalImage, logPvalImage_th, channels, channels_th = ReshapeResults(
        cov_scores, vecs, pvals, rows, data_shape, pth=pvalth)

    # Passing a path makes tofile open the file in binary mode and close it.
    # The original wrote through text-mode ('w') handles that stayed open if
    # a write failed.
    CImage.tofile(imageName + 'covariance_strength.dat')
    logPvalImage.tofile(imageName + 'pvals_covariance.dat')
    logPvalImage_th.tofile(imageName + 'pvals_covariance_th.dat')
    channels_th.tofile(imageName + 'vectors_th.dat')

    np.savetxt(imageName + 'data_info.txt', np.append(data_shape, channels_th.shape[0]))
    
def WritePValResultsInBinary(stats, pvals, pvalth, imageName, data_shape, rows):
    '''
    Write the signed -log10(p) map and its thresholded version as raw
    binary files <imageName>pvals.dat and <imageName>pvals_th.dat
    (ndarray.tofile: no header, C order, native float64).
    '''
    logPvalImage, logPvalImage_th = ReshapePValImage(stats, pvals, rows, data_shape, pth=pvalth)
    # a path argument lets tofile open in binary mode and always close the file
    logPvalImage.tofile(imageName + 'pvals.dat')
    logPvalImage_th.tofile(imageName + 'pvals_th.dat')
    

def WriteImage(template, rows, data, imageName, bg=0.0):
    '''
    Save data as a NIfTI image on the grid of template.

    template  : nibabel image supplying the shape and affine
    rows      : selection of the voxels that receive data
    data      : values for the selected voxels
    imageName : output path passed to nib.save
    bg        : value written to every other voxel
    '''
    # .shape/.affine replace get_data().shape/get_affine(), which nibabel has
    # deprecated and removed; .shape also avoids loading the voxel data
    CImage = np.full(template.shape, bg, dtype=float)
    CImage[rows] = data
    nib.save(nib.Nifti1Image(CImage, template.affine), imageName)
    
def WritePValImage(template, rows, data, pvalth, imageName):
    '''
    Save a thresholded -log10(p) map as a NIfTI image and return it.

    template  : nibabel image supplying the shape and affine
    rows      : selection of the voxels that receive the p-values
    data      : p-values; values above 1 are treated as 1
    pvalth    : voxels whose -log10(p) is not above pvalth are set to 0
    imageName : output path passed to nib.save

    Returns the thresholded map (template's shape). The caller's data array
    is no longer modified.
    '''
    # clip a copy; the original clipped data in place
    pvals = np.minimum(np.asarray(data, dtype=float), 1.)
    logPvalImage = np.zeros(template.shape)
    logPvalImage[rows] = -np.log(pvals) / np.log(10.0)
    logPvalImage_ = (logPvalImage > pvalth).astype(float) * logPvalImage
    nib.save(nib.Nifti1Image(logPvalImage_, template.affine), imageName)
    return logPvalImage_

def WriteSyntheticImage(data, imageName, bg=0.0):
    '''
    Save a synthetic result array to the text file imageName (np.savetxt).

    bg is accepted but ignored (presumably kept to mirror WriteImage).
    '''
    np.savetxt(imageName, data)
    
def WriteSyntheticPValImage(data, pvalth, imageName):
    '''
    Save a thresholded -log10(p) array to the text file imageName.

    data   : p-values; values above 1 are treated as 1
    pvalth : entries whose -log10(p) is not above pvalth are written as 0
    The caller's data array is no longer modified.
    '''
    # clip a copy; the original clipped data in place
    pvals = np.minimum(np.asarray(data, dtype=float), 1.)
    logPvalImage = -np.log(pvals) / np.log(10.0)
    logPvalImage_ = (logPvalImage > pvalth).astype(float) * logPvalImage
    np.savetxt(imageName, logPvalImage_)

def TFCE3D_Update(statvals, rows, data_shape, folder='./tmp'):
    '''
    Apply 3-D TFCE to statvals and return the enhanced values at rows.

    statvals   : values for the positions selected by rows
    rows       : selection inside a 3-D image of shape data_shape (assumed to
                 be a boolean mask, since np.nonzero(rows) is used to find its
                 3 index arrays -- TODO confirm with callers)
    data_shape : 3-D image shape
    folder     : unused; kept for backward compatibility

    TFCE runs on the bounding box of the selected voxels, with every
    unselected voxel set to 0.
    '''
    CImage = np.zeros(data_shape)
    CImage[rows] = statvals
    idx = np.nonzero(rows)
    if idx[0].size == 0:
        # nothing selected: no bounding box (the original crashed in min())
        return CImage[rows]
    # inclusive bounding box; the original used max() as an exclusive stop,
    # which left the last plane on every axis un-enhanced (always 0)
    box = tuple(slice(i.min(), i.max() + 1) for i in idx)
    Res = np.zeros(CImage.shape)
    Res[box] = tfce3D(CImage[box])
    return Res[rows]

def TFCE2D_Update(statvals, rows, data_shape, folder='./tmp'):
    '''
    Apply 2-D TFCE to statvals and return the enhanced values at rows.

    statvals   : values for the positions selected by rows
    rows       : selection inside a 2-D image of shape data_shape (assumed to
                 be a boolean mask -- TODO confirm with callers)
    data_shape : 2-D image shape
    folder     : unused; kept for backward compatibility

    TFCE runs on the bounding box of the selected pixels, with every
    unselected pixel set to 0.
    '''
    CImage = np.zeros(data_shape)
    CImage[rows] = statvals
    idx = np.nonzero(rows)
    if idx[0].size == 0:
        # nothing selected: no bounding box (the original crashed in min())
        return CImage[rows]
    # inclusive bounding box; the original used max() as an exclusive stop,
    # which left the last row/column un-enhanced (always 0)
    box = tuple(slice(i.min(), i.max() + 1) for i in idx)
    Res = np.zeros(CImage.shape)
    Res[box] = tfce2D(CImage[box])
    return Res[rows]

def TFCE1D_Update(statvals, rows, data_shape, folder='./tmp'):
    '''
    Apply 1-D TFCE to statvals and return the enhanced values at rows.

    statvals   : values for the positions selected by rows
    rows       : selection inside a 1-D signal of shape data_shape (assumed
                 to be a boolean mask -- TODO confirm with callers)
    data_shape : signal shape
    folder     : unused; kept for backward compatibility

    TFCE runs on the span between the first and last selected positions,
    with every unselected position set to 0.
    '''
    CImage = np.zeros(data_shape)
    CImage[rows] = statvals
    idx = np.nonzero(rows)
    if idx[0].size == 0:
        # nothing selected: no span (the original crashed in min())
        return CImage[rows]
    # inclusive span; the original used max() as an exclusive stop, which
    # left the last selected position un-enhanced (always 0)
    box = tuple(slice(i.min(), i.max() + 1) for i in idx)
    Res = np.zeros(CImage.shape)
    Res[box] = tfce1D(CImage[box])
    return Res[rows]

    
# motivated from https://github.com/Mouse-Imaging-Centre/minc-stuffs/blob/master/python/TFCE
def tfce3D(invol, dh=0.1, E=0.5, H=2.0):
    '''
    Threshold-free cluster enhancement of a 3-D volume.

    For each threshold h in 0, dh, 2*dh, ... below invol.max(), the voxels
    with invol > h are grouped into clusters with full (26-neighbour)
    connectivity. Every voxel in a cluster of size s then gains
    h**H * s**E * dh.

    Returns an array of invol's shape holding the accumulated scores.
    '''
    struct = np.ones([3, 3, 3])
    outvol = np.zeros(invol.shape)
    for h in np.arange(0, invol.max(), dh):
        labeled, numLabels = ndimage.label(invol > h, structure=struct)
        if numLabels == 0:
            continue
        # cluster size per label; entry 0 (background) is never used
        sizes = np.bincount(labeled.ravel()).astype(float)
        inside = labeled > 0
        # vectorized replacement for the former scipy.weave C loop (weave has
        # been removed from SciPy and needed gcc at run time)
        outvol[inside] += (h ** H) * sizes[labeled[inside]] ** E * dh
    return outvol
    
def tfce2D(invol, dh=0.1, E=0.5, H=2.0):
    '''
    Threshold-free cluster enhancement of a 2-D image.

    For each threshold h in 0, dh, 2*dh, ... below invol.max(), the pixels
    with invol > h are grouped into clusters with full (8-neighbour)
    connectivity. Every pixel in a cluster of size s then gains
    h**H * s**E * dh.

    Returns an array of invol's shape holding the accumulated scores.
    '''
    struct = np.ones([3, 3])
    outvol = np.zeros(invol.shape)
    for h in np.arange(0, invol.max(), dh):
        labeled, numLabels = ndimage.label(invol > h, structure=struct)
        if numLabels == 0:
            continue
        # cluster size per label; entry 0 (background) is never used
        sizes = np.bincount(labeled.ravel()).astype(float)
        inside = labeled > 0
        # vectorized replacement for the former scipy.weave C loop
        outvol[inside] += (h ** H) * sizes[labeled[inside]] ** E * dh
    return outvol

def tfce1D(invol, dh=0.1, E=0.5, H=2.0):
    '''
    Threshold-free cluster enhancement of a 1-D signal.

    For each threshold h in 0, dh, 2*dh, ... below invol.max(), the entries
    with invol > h are grouped into runs of adjacent entries. Every entry in
    a run of length s then gains h**H * s**E * dh.

    Returns an array of invol's shape holding the accumulated scores.
    '''
    struct = np.ones(3)
    outvol = np.zeros(invol.shape)
    for h in np.arange(0, invol.max(), dh):
        labeled, numLabels = ndimage.label(invol > h, structure=struct)
        if numLabels == 0:
            continue
        # run length per label; entry 0 (background) is never used
        sizes = np.bincount(labeled.ravel()).astype(float)
        inside = labeled > 0
        # vectorized replacement for the former scipy.weave C loop
        outvol[inside] += (h ** H) * sizes[labeled[inside]] ** E * dh
    return outvol


