Skip to content

Commit

Permalink
Changed the delay calculation to only include the first spike
Browse files Browse the repository at this point in the history
This speeds up the generation of pre curves and also avoids generating useless curves
  • Loading branch information
pherbers committed Mar 8, 2016
1 parent 4408af2 commit ae85b33
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions pam/pam_anim/pam_anim.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SPIKE_GROUP_NAME = "SPIKES"

SPIKE_OBJECTS = {}
SPIKE_OBJECTS_PRE = {}
CURVES = {}
CURVES_PRE = {}
TIMING_COLORS = []
Expand Down Expand Up @@ -58,14 +59,18 @@ def visualize(self):
try:
self.curveObject = pam_vis.visualizeOneConnectionPost(self.connectionID, self.sourceNeuronID, self.targetNeuronID,
bpy.context.scene.pam_visualize.smoothing)

if self.curvePre.curveObject is None:
self.curvePre.visualize(self.curveObject)

self.curveObject.name = "curve.%d_%d_%d" % (self.connectionID, self.sourceNeuronID, self.targetNeuronID)
bpy.data.groups[PATHS_GROUP_NAME].objects.link(self.curveObject)

self.curveObject.data.resolution_u = bpy.context.scene.pam_anim_mesh.path_bevel_resolution
frameLength = timeToFrames(self.timeLength)
frameLengthPost = timeToFrames(self.curvePre.timeLength - self.curvePre.post_delay)

# setAnimationSpeed(self.curveObject.data, frameLength)
# self.curveObject.data["timeLength"] = frameLength
setAnimationSpeed(self.curveObject.data, frameLengthPost)
self.curveObject.data["timeLength"] = frameLengthPost
except Exception as e:
logger.info(traceback.format_exc())
logger.error("Failed to visualize post connection " + str((self.connectionID, self.sourceNeuronID, self.targetNeuronID)))
Expand All @@ -78,22 +83,9 @@ def __init__(self, connectionID, sourceNeuronID, timeLength):
self.curves_post = []
self.timeLength = timeLength

def visualize(self, max_connections = 0):
def visualize(self, post_curve):

max_connections = max_connections or len(self.curves_post)
max_connections = min(max_connections, len(self.curves_post))

dist_mean = 0.0
count = 0
for curve_post in self.curves_post[:max_connections]:
curve_post.visualize()
if curve_post.curveObject:
count += 1
dist_mean += pam_vis.calculatePathLength(curve_post.curveObject)

if count == 0:
return
dist_mean /= count
dist_post = pam_vis.calculatePathLength(post_curve)

try:
self.curveObject = pam_vis.visualizeOneConnectionPre(self.connectionID, self.sourceNeuronID,
Expand All @@ -110,18 +102,18 @@ def visualize(self, max_connections = 0):

dist_pre = pam_vis.calculatePathLength(self.curveObject)

fork_percent = dist_pre / (dist_pre + dist_mean)
fork_percent = dist_pre / (dist_pre + dist_post)
fork_time = fork_percent * self.timeLength

frameLengthPre = timeToFrames(fork_time)
frameLengthPost = timeToFrames(self.timeLength - fork_time)

self.post_delay = fork_time

for curve_post in self.curves_post:
if curve_post.curveObject:
setAnimationSpeed(curve_post.curveObject.data, frameLengthPost)
curve_post.curveObject.data["timeLength"] = frameLengthPost
# for curve_post in self.curves_post:
# if curve_post.curveObject:
# setAnimationSpeed(curve_post.curveObject.data, frameLengthPost)
# curve_post.curveObject.data["timeLength"] = frameLengthPost

self.curveObject.data["timeLength"] = frameLengthPre
setAnimationSpeed(self.curveObject.data, frameLengthPre)
Expand Down Expand Up @@ -154,14 +146,14 @@ def __init__(self, connectionID, sourceNeuronID, targetNeuronID, targetNeuronInd
self.targetNeuronIndex = targetNeuronIndex
self.timingID = timingID

def visualizeCurve(self, max_connections = 0):
if self.curve_pre.curveObject is None:
self.curve_pre.visualize(max_connections)
def visualizeCurve(self):
if self.curve.curveObject is None:
self.curve.visualize()

def getDelay(self):
return self.curve_pre.post_delay

def visualize(self, meshObject, orientationOptions = {'orientationType': 'NONE'}, max_connections = 0):
def visualize(self, meshObject, orientationOptions = {'orientationType': 'NONE'}):
"""Generates an object for this spike
This function generates a curve object for it's connection if there is none.
Expand All @@ -173,7 +165,7 @@ def visualize(self, meshObject, orientationOptions = {'orientationType': 'NONE'}
orientationObject: bpy.types.Object, Only used for orientationType OBJECT
"""

self.visualizeCurve(max_connections)
self.visualizeCurve()

if self.curve.curveObject is None:
logger.error("No curve object to attatch to for spike " + self.__repr__() + "!")
Expand Down Expand Up @@ -234,9 +226,8 @@ def __init__(self, connectionID, sourceNeuronID, timingID, curve, startTime):
self.sourceNeuronID = sourceNeuronID
self.timingID = timingID

def visualizeCurve(self, max_connections = 0):
if self.curve.curveObject is None:
self.curve.visualize(max_connections)
def visualizeCurve(self):
pass

def getDelay(self):
return 0
Expand Down Expand Up @@ -296,7 +287,7 @@ def simulateTiming(timingID):

if at_least_one:
# distance = data.DELAYS[connectionID][neuronID][targetNeuronIndex]
SPIKE_OBJECTS[(curve_key, timingID)] = SpikeObjectPre(connectionID[0], neuronID, timingID, curve_pre, fireTime)
SPIKE_OBJECTS_PRE[(curve_key, timingID)] = SpikeObjectPre(connectionID[0], neuronID, timingID, curve_pre, fireTime)


def simulateConnection(connectionID, sourceNeuronID, targetNeuronIndex, timingID):
Expand Down Expand Up @@ -333,6 +324,7 @@ def simulateConnection(connectionID, sourceNeuronID, targetNeuronIndex, timingID
CURVES_PRE[curve_key_pre] = curve

curve_pre.curves_post.append(curve)
curve.curvePre = curve_pre

fireTime = data.TIMINGS[timingID][2]
SPIKE_OBJECTS[(curveKey, timingID)] = SpikeObject(connectionID, sourceNeuronID, targetNeuronID, targetNeuronIndex, timingID, curve, curve_pre, fireTime)
Expand Down Expand Up @@ -465,7 +457,7 @@ def simulateColors(labelController = None):
heapq.heappush(neuronUpdateQueue, (updateTime, connectionID[2], i, layerValuesDecay))

if at_least_one:
obj = SPIKE_OBJECTS[((connectionID[0], neuronID), timingID)]
obj = SPIKE_OBJECTS_PRE[((connectionID[0], neuronID), timingID)]
if obj.object:
obj.object.color = color
obj.object['spiking_labels'] = str(layerValuesDecay)
Expand Down Expand Up @@ -542,7 +534,9 @@ def generateAllTimings(frameStart = 0, frameEnd = 250, maxConns = 0, showPercent
continue

logger.info("Generating spike " + str(i) + "/" + str(total) + ": " + str(spike))
spike.visualize(bpy.data.objects[bpy.context.scene.pam_anim_mesh.mesh], bpy.context.scene.pam_anim_mesh, max_connections = maxConns)
spike.visualize(bpy.data.objects[bpy.context.scene.pam_anim_mesh.mesh], bpy.context.scene.pam_anim_mesh)
if SPIKE_OBJECTS_PRE[(key[0][0], key[0][1]), key[1]].object is None:
SPIKE_OBJECTS_PRE[(key[0][0], key[0][1]), key[1]].visualize(bpy.data.objects[bpy.context.scene.pam_anim_mesh.mesh], bpy.context.scene.pam_anim_mesh)


wm.progress_end()
Expand Down

0 comments on commit ae85b33

Please sign in to comment.