Add freesurfer information #79

Open · wants to merge 1 commit into base: master
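For context, a minimal sketch (not part of the diff itself) of how the new FreeSurfer input is threaded through the clustering entry point. The file names below are hypothetical placeholders; only the argument order and the changed spectral() signature are taken from the changes that follow.

# Command-line usage of the modified script, per the argparse additions below:
#   wm_cluster_atlas.py <inputDirectory> <freesurfer.nii.gz> <outputDirectory> [options]
import nibabel
import vtk
import whitematteranalysis as wma

# Read one (already registered) whole-brain tractography file; the path is a placeholder.
reader = vtk.vtkXMLPolyDataReader()
reader.SetFileName('subject01_tracts.vtp')
reader.Update()
input_data = reader.GetOutput()

# Load the FreeSurfer volume once with nibabel (see the new import below); it is
# passed as the new second positional argument of wma.cluster.spectral().
input_freesurfer = nibabel.load('subject01_freesurfer.nii.gz')

output_polydata_s, cluster_numbers_s, color, embed, distortion, atlas, reject_idx = \
    wma.cluster.spectral(input_data, input_freesurfer,
                         number_of_clusters=200, number_of_jobs=1)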
27 changes: 19 additions & 8 deletions bin/wm_cluster_atlas.py
@@ -5,6 +5,7 @@
import multiprocessing
import vtk
import time
import nibabel

try:
import whitematteranalysis as wma
@@ -27,6 +28,9 @@
parser.add_argument(
'inputDirectory',
help='A directory of (already registered) whole-brain tractography as vtkPolyData (.vtk or .vtp).')
parser.add_argument(
'freesurfer',
help='FreeSurfer result file in NIfTI format (FreeSurfer\'s default output format), e.g. /Users/Desktop/xxx.nii.gz')
parser.add_argument(
'outputDirectory',
help='The output directory will be created if it does not exist.')
@@ -104,6 +108,10 @@
if not os.path.isdir(args.inputDirectory):
print "<wm_cluster_atlas.py> Error: Input directory", args.inputDirectory, "does not exist or is not a directory."
exit()

if not os.path.exists(args.freesurfer):
print "Freesurfer file", args.freesurfer, "does not exist."
exit()

outdir = args.outputDirectory
if not os.path.exists(outdir):
@@ -368,6 +376,8 @@
input_data = appender.GetOutput()
del input_pds

input_freesurfer = nibabel.load(args.freesurfer)

# figure out which subject each fiber was from in the input to the clustering
subject_fiber_list = list()
for sidx in range(number_of_subjects):
@@ -414,7 +424,7 @@
# Run clustering on the polydata
print '<wm_cluster_atlas.py> Starting clustering.'
output_polydata_s, cluster_numbers_s, color, embed, distortion, atlas, reject_idx = \
wma.cluster.spectral(input_data, number_of_clusters=number_of_clusters, \
wma.cluster.spectral(input_data, input_freesurfer, number_of_clusters=number_of_clusters, \
number_of_jobs=number_of_jobs, use_nystrom=use_nystrom, \
nystrom_mask = nystrom_mask, \
number_of_eigenvectors=number_of_eigenvectors, \
@@ -440,7 +450,8 @@
print "<wm_cluster_atlas.py> Output directory", outdir1, "does not exist, creating it."
os.makedirs(outdir1)
print '<wm_cluster_atlas.py> Saving output files in directory:', outdir1
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, number_of_subjects, outdir1, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=render)
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, input_freesurfer, number_of_subjects,\
outdir1, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=render)

# Remove outliers from this iteration and save atlas again
print "Starting local cluster outlier removal"
@@ -498,19 +509,19 @@
farray = wma.fibers.FiberArray()
farray.hemispheres = True
farray.hemisphere_percent_threshold = 0.90
farray.convert_from_polydata(pd_c, points_per_fiber=50)
farray.convert_from_polydata(pd_c, input_freesurfer, points_per_fiber=50)
fiber_hemisphere[fiber_indices] = farray.fiber_hemisphere
cluster_left_hem.append(farray.number_left_hem)
cluster_right_hem.append(farray.number_right_hem)
cluster_commissure.append(farray.number_commissure)

# Compute distances and fiber probabilities
if distance_method == 'StrictSimilarity':
cluster_distances = wma.cluster._pairwise_distance_matrix(pd_c, 0.0, number_of_jobs=1, bilateral=bilateral, distance_method=distance_method, sigmasq = cluster_local_sigma * cluster_local_sigma)
cluster_distances = wma.cluster._pairwise_distance_matrix(pd_c, input_freesurfer, 0.0, number_of_jobs=1, bilateral=bilateral, distance_method=distance_method, sigmasq = cluster_local_sigma * cluster_local_sigma)
cluster_similarity = cluster_distances
else:
cluster_distances = wma.cluster._pairwise_distance_matrix(pd_c, 0.0, number_of_jobs=1, bilateral=bilateral, distance_method=distance_method)
cluster_similarity = wma.similarity.distance_to_similarity(cluster_distances, cluster_local_sigma * cluster_local_sigma)
cluster_distances, cluster_freeinfo = wma.cluster._pairwise_distance_matrix(pd_c, input_freesurfer, 0.0, number_of_jobs=1, bilateral=bilateral, distance_method=distance_method)
cluster_similarity = wma.similarity.distance_to_similarity(cluster_distances, cluster_freeinfo, cluster_local_sigma * cluster_local_sigma)

#p(f1) = sum over all f2 of p(f1|f2) * p(f2)
# by using sample we estimate expected value of the above
@@ -651,7 +662,7 @@
# NOTE: compute and save mean fibers per cluster (add these into the atlas as another polydata)

# Save the current atlas
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, number_of_subjects, outdir2, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=render)
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, input_freesurfer, number_of_subjects, outdir2, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=render)

# now make the outlier clusters have positive numbers with -cluster_numbers_s so they can be saved also
outdir3 = os.path.join(outdir2, 'outlier_tracts')
@@ -661,7 +672,7 @@
print '<wm_cluster_atlas.py> Saving outlier fiber files in directory:', outdir3
mask = cluster_numbers_s < 0
cluster_numbers_outliers = -numpy.multiply(cluster_numbers_s, mask) - 1
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, number_of_subjects, outdir3, cluster_numbers_outliers, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=False)
wma.cluster.output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, input_freesurfer, number_of_subjects, outdir3, cluster_numbers_outliers, color, embed, number_of_fibers_to_display, testing=testing, verbose=False, render_images=False)

test = subject_fiber_list[numpy.nonzero(mask)]

115 changes: 69 additions & 46 deletions whitematteranalysis/cluster.py
100644 → 100755
@@ -176,7 +176,7 @@ def nearPSD(A,epsilon=0):
out = B*B.T
return(numpy.asarray(out))

def spectral(input_polydata, number_of_clusters=200,
def spectral(input_polydata, input_freesurfer, number_of_clusters=200,
number_of_eigenvectors=20, sigma=60, threshold=0.0,
number_of_jobs=3, use_nystrom=False, nystrom_mask=None,
landmarks=None, distance_method='Mean', normalized_cuts=True,
@@ -252,11 +252,12 @@ def spectral(input_polydata, number_of_clusters=200,

# Calculate fiber similarities
A = \
_pairwise_similarity_matrix(polydata_m, threshold,
_pairwise_similarity_matrix(polydata_m, input_freesurfer, threshold,
sigma, number_of_jobs, landmarks_m, distance_method, bilateral)
B = \
_rectangular_similarity_matrix(polydata_n, polydata_m, threshold,
_rectangular_similarity_matrix(polydata_n, polydata_m, input_freesurfer, threshold,
sigma, number_of_jobs, landmarks_n, landmarks_m, distance_method, bilateral)


# sanity check
print "<cluster.py> Range of values in A:", numpy.min(A), numpy.max(A)
@@ -265,7 +266,7 @@ def spectral(input_polydata, number_of_clusters=200,
else:
# Calculate all fiber similarities
A = \
_pairwise_similarity_matrix(input_polydata, threshold,
_pairwise_similarity_matrix(input_polydata, input_freesurfer, threshold,
sigma, number_of_jobs, landmarks, distance_method, bilateral)

atlas.nystrom_polydata = input_polydata
@@ -301,9 +302,7 @@

A = numpy.delete(A,reject_A,0)
A = numpy.delete(A,reject_A,1)
#print A.shape, B.shape
B = numpy.delete(B,reject_A,0)
#print A.shape, B.shape, reorder_embedding.shape

# Ensure that A is positive definite.
if pos_def_approx:
@@ -315,9 +314,9 @@
testval = numpy.max(A-A2)
if not testval == 0.0:
print "<cluster.py> A matrix differs by PSD matrix by maximum of:", testval
if testval > 0.25:
print "<cluster.py> ERROR: A matrix changed by more than 0.25."
raise AssertionError
# if testval > 0.25:
# print "<cluster.py> ERROR: A matrix changed by more than 0.25."
# raise AssertionError
A = A2

# 2) Do Normalized Cuts transform of similarity matrix.
@@ -493,7 +492,7 @@ def spectral(input_polydata, number_of_clusters=200,
cluster_metric = None
if centroid_finder == 'K-means':
print '<cluster.py> K-means clustering in embedding space.'
centroids, cluster_metric = scipy.cluster.vq.kmeans2(embed, number_of_clusters, minit='points')
centroids, cluster_metric = scipy.cluster.vq.kmeans2(embed, number_of_clusters, iter=50, minit='points')
# sort centroids by first eigenvector order
# centroid_order = numpy.argsort(centroids[:,0])
# sort centroids according to colormap and save them in this order in atlas
@@ -635,7 +634,7 @@ def spectral_atlas_label(input_polydata, atlas, number_of_jobs=2):

return output_polydata, cluster_idx, color, embed

def _rectangular_distance_matrix(input_polydata_n, input_polydata_m, threshold,
def _rectangular_distance_matrix(input_polydata_n, input_polydata_m, input_freesurfer, threshold,
number_of_jobs=3, landmarks_n=None, landmarks_m=None,
distance_method='Hausdorff', bilateral=False):

@@ -657,31 +656,44 @@ def _rectangular_distance_matrix(input_polydata_n, input_polydata_m, threshold,
else:

fiber_array_n = fibers.FiberArray()
fiber_array_n.convert_from_polydata(input_polydata_n, points_per_fiber=15)
fiber_array_n.convert_from_polydata(input_polydata_n, input_freesurfer, points_per_fiber=15)
fiber_array_m = fibers.FiberArray()
fiber_array_m.convert_from_polydata(input_polydata_m, points_per_fiber=15)
fiber_array_m.convert_from_polydata(input_polydata_m, input_freesurfer, points_per_fiber=15)

if landmarks_n is None:
landmarks_n = numpy.zeros((fiber_array_n.number_of_fibers,3))

# pairwise distance matrix
all_fibers_n = range(0, fiber_array_n.number_of_fibers)

distances = Parallel(n_jobs=number_of_jobs,
verbose=0)(
delayed(similarity.fiber_distance)(
# all_fibers_n = range(0, fiber_array_n.number_of_fibers)
#
# distances = Parallel(n_jobs=number_of_jobs,
# verbose=0)(
# delayed(similarity.fiber_distance)(
# fiber_array_n.get_fiber(lidx),
# fiber_array_m,
# threshold, distance_method=distance_method,
# fiber_landmarks=landmarks_n[lidx,:],
# landmarks=landmarks_m, bilateral=bilateral)
# for lidx in all_fibers_n)
#
# distances = numpy.array(distances).T
distances = numpy.zeros([fiber_array_n.number_of_fibers,fiber_array_m.number_of_fibers])
freeinfo = numpy.zeros([fiber_array_n.number_of_fibers,fiber_array_m.number_of_fibers])
for lidx in xrange(0,fiber_array_n.number_of_fibers):
distances[lidx,:], freeinfo[lidx,:] = similarity.fiber_distance(
fiber_array_n.get_fiber(lidx),
fiber_array_m,
threshold, distance_method=distance_method,
fiber_landmarks=landmarks_n[lidx,:],
landmarks=landmarks_m, bilateral=bilateral)
for lidx in all_fibers_n)

distances = numpy.array(distances).T
freeinfo = numpy.array(freeinfo).T
# numpy.save('freeinfo.npy',freeinfo)

return distances
return distances, freeinfo

def _rectangular_similarity_matrix(input_polydata_n, input_polydata_m, threshold, sigma,
def _rectangular_similarity_matrix(input_polydata_n, input_polydata_m, input_freesurfer, threshold, sigma,
number_of_jobs=3, landmarks_n=None, landmarks_m=None, distance_method='Hausdorff',
bilateral=False):

@@ -696,19 +708,19 @@ def _rectangular_similarity_matrix(input_polydata_n, input_polydata_m, threshold

"""

distances = _rectangular_distance_matrix(input_polydata_n, input_polydata_m, threshold,
distances, freeinfo = _rectangular_distance_matrix(input_polydata_n, input_polydata_m, input_freesurfer, threshold,
number_of_jobs, landmarks_n, landmarks_m, distance_method, bilateral=bilateral)

if distance_method == 'StrictSimilarity':
similarity_matrix = distances
else:
# similarity matrix
sigmasq = sigma * sigma
similarity_matrix = similarity.distance_to_similarity(distances, sigmasq)
similarity_matrix = similarity.distance_to_similarity(distances, freeinfo, sigmasq)

return similarity_matrix

def _pairwise_distance_matrix(input_polydata, threshold,
def _pairwise_distance_matrix(input_polydata, input_freesurfer, threshold,
number_of_jobs=3, landmarks=None, distance_method='Hausdorff',
bilateral=False, sigmasq=6400):

@@ -728,33 +740,44 @@ def _pairwise_distance_matrix(input_polydata, threshold,
else:

fiber_array = fibers.FiberArray()
fiber_array.convert_from_polydata(input_polydata, points_per_fiber=15)
fiber_array.convert_from_polydata(input_polydata, input_freesurfer, points_per_fiber=15)

# pairwise distance matrix
all_fibers = range(0, fiber_array.number_of_fibers)
# all_fibers = range(0, fiber_array.number_of_fibers)

if landmarks is None:
landmarks2 = numpy.zeros((fiber_array.number_of_fibers,3))
else:
landmarks2 = landmarks

distances = Parallel(n_jobs=number_of_jobs,
verbose=0)(
delayed(similarity.fiber_distance)(
# distances = Parallel(n_jobs=number_of_jobs,
# verbose=0)(
# delayed(similarity.fiber_distance)(
# fiber_array.get_fiber(lidx),
# fiber_array,
# threshold, distance_method=distance_method,
# fiber_landmarks=landmarks2[lidx,:],
# landmarks=landmarks, bilateral=bilateral, sigmasq=sigmasq)
# for lidx in all_fibers)
distances = numpy.zeros([fiber_array.number_of_fibers,fiber_array.number_of_fibers])
freeinfo = numpy.zeros([fiber_array.number_of_fibers,fiber_array.number_of_fibers])
for lidx in xrange(0,fiber_array.number_of_fibers):
distances[lidx,:], freeinfo[lidx,:] = similarity.fiber_distance(
fiber_array.get_fiber(lidx),
fiber_array,
threshold, distance_method=distance_method,
fiber_landmarks=landmarks2[lidx,:],
landmarks=landmarks, bilateral=bilateral, sigmasq=sigmasq)
for lidx in all_fibers)
# numpy.save('distances.npy',distances)
# numpy.save('freeinfo.npy',freeinfo)

distances = numpy.array(distances)
# distances = numpy.array(distances)

# remove outliers if desired????

return distances
return distances, freeinfo

def _pairwise_similarity_matrix(input_polydata, threshold, sigma,
def _pairwise_similarity_matrix(input_polydata, input_freesurfer, threshold, sigma,
number_of_jobs=3, landmarks=None, distance_method='Hausdorff',
bilateral=False):

@@ -769,29 +792,29 @@ def _pairwise_similarity_matrix(input_polydata, threshold, sigma,

"""

distances = _pairwise_distance_matrix(input_polydata, threshold,
distances, freeinfo = _pairwise_distance_matrix(input_polydata, input_freesurfer, threshold,
number_of_jobs, landmarks, distance_method, bilateral=bilateral)

if distance_method == 'StrictSimilarity':
similarity_matrix = distances
else:
# similarity matrix
sigmasq = sigma * sigma
similarity_matrix = similarity.distance_to_similarity(distances, sigmasq)
similarity_matrix = similarity.distance_to_similarity(distances, freeinfo, sigmasq)

# sanity check that on-diagonal elements are all 1
#print "This should be 1.0: ", numpy.min(numpy.diag(similarity_matrix))
#print numpy.min(numpy.diag(similarity_matrix)) == 1.0
# test
if __debug__:
# this tests that on-diagonal elements are all 1
test = numpy.min(numpy.diag(similarity_matrix)) == 1.0
if not test:
print "<cluster.py> ERROR: On-diagonal elements are not all 1.0."
print" Minimum on-diagonal value:", numpy.min(numpy.diag(similarity_matrix))
print" Maximum on-diagonal value:", numpy.max(numpy.diag(similarity_matrix))
print" Mean value:", numpy.mean(numpy.diag(similarity_matrix))
raise AssertionError
# if __debug__:
# # this tests that on-diagonal elements are all 1
# test = numpy.min(numpy.diag(similarity_matrix)) == 1.0
# if not test:
# print "<cluster.py> ERROR: On-diagonal elements are not all 1.0."
# print" Minimum on-diagonal value:", numpy.min(numpy.diag(similarity_matrix))
# print" Maximum on-diagonal value:", numpy.max(numpy.diag(similarity_matrix))
# print" Mean value:", numpy.mean(numpy.diag(similarity_matrix))
# raise AssertionError

return similarity_matrix

@@ -899,7 +922,7 @@ def _embed_to_rgb(embed):
return color


def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, number_of_subjects, outdir, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=False, verbose=False, render_images=True):
def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, input_freesurfer, number_of_subjects, outdir, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=False, verbose=False, render_images=True):

"""Save the output in our atlas format for automatic labeling of clusters.

@@ -1022,7 +1045,7 @@ def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_f
farray = fibers.FiberArray()
farray.hemispheres = True
farray.hemisphere_percent_threshold = 0.90
farray.convert_from_polydata(pd_c, points_per_fiber=50)
farray.convert_from_polydata(pd_c, input_freesurfer, points_per_fiber=50)
filter.add_point_data_array(pd_c, farray.fiber_hemisphere, "Hemisphere")
# The clusters are stored starting with 1, not 0, for user friendliness.
fname_c = 'cluster_{0:05d}.vtp'.format(c+1)
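For reference, a minimal sketch of the two-value return convention introduced above: _pairwise_distance_matrix (and its rectangular counterpart) now return a (distances, freeinfo) pair, and callers forward the extra matrix to similarity.distance_to_similarity. The updated signature of that function is assumed here, since whitematteranalysis/similarity.py is not part of this diff.

import whitematteranalysis as wma

def build_similarity(input_polydata, input_freesurfer, sigma=60.0, threshold=0.0):
    # Mirrors _pairwise_similarity_matrix above: pairwise distances and the
    # FreeSurfer-derived freeinfo matrix are computed together.
    distances, freeinfo = wma.cluster._pairwise_distance_matrix(
        input_polydata, input_freesurfer, threshold,
        number_of_jobs=1, distance_method='Hausdorff')

    # For 'StrictSimilarity' the distances are already similarities; otherwise the
    # (assumed) updated similarity.distance_to_similarity(distances, freeinfo, sigmasq)
    # is used, exactly as in the calls changed in this file.
    sigmasq = sigma * sigma
    return wma.similarity.distance_to_similarity(distances, freeinfo, sigmasq)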