Skip to content

Commit

Permalink
Updated
Browse files Browse the repository at this point in the history
  • Loading branch information
Hjorthmedh committed Oct 6, 2023
1 parent 0c3562b commit 806bbef
Showing 1 changed file with 73 additions and 12 deletions.
85 changes: 73 additions & 12 deletions snudda/place/region_mesh_redux.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import numexpr
import numpy as np
from scipy.spatial import cKDTree
import open3d as o3d
Expand All @@ -16,16 +16,20 @@ def __init__(self, mesh_path):
scale_factor = 1e-6
self.mesh.scale(scale_factor, center=(0, 0, 0))

self.mesh.remove_non_manifold_edges()

self.min_coord = self.mesh.get_min_bound()
self.max_coord = self.mesh.get_max_bound()

self.scene = o3d.t.geometry.RaycastingScene()
legacy_mesh = o3d.t.geometry.TriangleMesh.from_legacy(self.mesh)
self.scene.add_triangles(legacy_mesh)

# SHALL WE PICk NUMBER OF PUTATIVE POINTS BASED ON VOLUME???

self.scene.add_triangles(legacy_mesh)

def check_inside(self, points):
""" Check if points are inside, returns bool array."""

# http://www.open3d.org/docs/latest/tutorial/geometry/distance_queries.html
query_point = o3d.core.Tensor(points, dtype=o3d.core.Dtype.Float32)
Expand Down Expand Up @@ -67,16 +71,23 @@ def __init__(self, mesh_path: str, d_min: float, seed=None, rng=None, n_putative
putative_points = self.remove_outside(putative_points)

self.putative_points = putative_points
self.allocated_points = None
self.allocated_points = np.zeros(shape=(putative_points.shape[0],), dtype=bool)

def plot_putative_points(self):

self.plot_points(points=self.putative_points)

def plot_points(self, points, colour=None):
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
x, y, z = self.putative_points.T
ax.scatter(x, y, z, marker='.')
x, y, z = points.T
ax.scatter(x, y, z, marker='.', color=colour)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.ion()
plt.show()

def get_point_cloud(self, n):
Expand Down Expand Up @@ -135,15 +146,44 @@ def remove_outside(self, points):

return points[keep_flag, :]

def set_neuron_density(self, neuron_type, neuron_density):
def get_neuron_positions(self, n_positions, neuron_density=None):

pass
""" neuron_density either None (even), or a str representing a function f(x,y,z)
where x,y,z are the coordinates in meters """

def place_neurons(self, neuron_type, number_of_neurons):
# We have the putative_points, pick positions from them, based on neuron density
# then update the allocated points

# använd np.choice ?
# 1. Calculate the distance to the closest free (non-allocated) neighbour for all points
# This is so that if we allocated neurons, we can correct for that, and get flat gradients

free_positions = self.putative_points[~self.allocated_points]

# k=2, since we don't want distance to point itself, but closest neighbour
# closest_distance, _ = cKDTree(data=free_positions).query(x=free_positions, k=2)
closest_distance, _ = cKDTree(data=free_positions).query(x=free_positions, k=2)

# Volume is proportional to distance**3, so scale probabilities to pick position by that
free_volume = np.power(np.mean(closest_distance[:, 1:2], axis=1), 3)
x, y, z = free_positions.T

if neuron_density:
# TODO: Temp disabled volume... still does not seem to work
P_neuron = np.multiply(numexpr.evaluate(neuron_density), free_volume)
# P_neuron = numexpr.evaluate(neuron_density)
else:
P_neuron = free_volume

P_neuron /= np.sum(P_neuron)

idx = self.rng.choice(len(free_positions), n_positions, p=P_neuron, replace=False)

neuron_positions = free_positions[idx, :]
used_idx = np.where(~self.allocated_points)[0][idx]
self.allocated_points[used_idx] = True

return neuron_positions

pass


class NeuronBender:
Expand All @@ -169,8 +209,29 @@ def __init__(self):

mesh_path="/home/hjorth/HBP/Snudda/snudda/data/mesh/Striatum-dorsal-right-hemisphere.obj"

nep = NeuronPlacer(mesh_path=mesh_path, d_min=10e-6)
nep.plot_putative_points()
nep = NeuronPlacer(mesh_path=mesh_path, d_min=10e-6, n_putative_points=10000000)
# nep.plot_putative_points()

points_flat = nep.get_neuron_positions(5000)
nep.plot_points(points_flat, colour="black")

points = nep.get_neuron_positions(200000, neuron_density="exp((y-0.0025)*2000)")
nep.plot_points(points, colour="red")

points_flat2 = nep.get_neuron_positions(5000)
nep.plot_points(points_flat, colour="blue")

import matplotlib.pyplot as plt
plt.figure()
plt.hist(points_flat[:,1], color="black")
plt.hist(points[:,1], color="red", alpha=0.5)
plt.show()

plt.figure()
plt.hist(points_flat[:,1], color="black")
plt.hist(points_flat2[:,1], color="blue", alpha=0.5)
plt.show()


import pdb
pdb.set_trace()

0 comments on commit 806bbef

Please sign in to comment.