diff --git a/elephant/asset/asset.py b/elephant/asset/asset.py index afff77eaf..e4d7a6fed 100644 --- a/elephant/asset/asset.py +++ b/elephant/asset/asset.py @@ -159,7 +159,6 @@ size = 1 rank = 0 - __all__ = [ "ASSET", "synchronous_events_intersection", @@ -529,7 +528,7 @@ def calculate_stretch_mat(theta_mat, D_mat): chunk_size = D_chunk.shape[0] assert (chunk_size == estimated_chunk or - chunk_size == last_chunk) # Safety check + chunk_size == last_chunk) # Safety check dX = x_array[:, start: start + chunk_size].T - x_array dY = y_array[:, start: start + chunk_size].T - y_array @@ -606,6 +605,7 @@ class _GPUBackend: Python objects, PyOpenCL and PyCUDA clean up and free allocated memory automatically when garbage collection is executed. """ + def __init__(self, max_chunk_size=None): self.max_chunk_size = max_chunk_size @@ -926,7 +926,7 @@ def pycuda(self, log_du): device = pycuda.autoinit.device max_l_block = device.MAX_SHARED_MEMORY_PER_BLOCK // ( - self.dtype.itemsize * (self.d + 2)) + self.dtype.itemsize * (self.d + 2)) n_threads = min(self.cuda_threads, max_l_block, device.MAX_THREADS_PER_BLOCK) if n_threads > device.WARP_SIZE: @@ -1965,6 +1965,13 @@ class ASSET(object): If None, the attribute `t_stop` of the spike trains is used (if the same for all spike trains). Default: None + bin_tolerance : float or 'default' or None, optional + Defines the tolerance value for rounding errors when binning the + spike trains. If 'default', the value is the default as defined in + :class:`~.conversion.BinnedSpikeTrain`. If None, no correction for + binning errors is performed. If a number, the binning will consider + this value. + Default: 'default' verbose : bool, optional If True, print messages and show progress bar. Default: True @@ -1978,11 +1985,15 @@ class ASSET(object): fully disjoint. + See Also + -------- + :class:`elephant.conversion.BinnedSpikeTrain` + """ def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms, t_start_i=None, t_start_j=None, t_stop_i=None, t_stop_j=None, - verbose=True): + bin_tolerance='default', verbose=True): self.spiketrains_i = spiketrains_i if spiketrains_j is None: spiketrains_j = spiketrains_i @@ -2014,13 +2025,21 @@ def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms, or (self.t_start_i < self.t_stop_j < self.t_stop_i): raise ValueError(msg) + # Define the tolerance parameter for binning. + # If `bin_tolerance` is 'default', `conv.BinnedSpikeTrain will be + # called without passing the parameter, and it will take what is + # defined by the behavior of that class. Otherwise, set to the value + # specified by `bin_tolerance` + tolerance_param = {'tolerance': bin_tolerance} if \ + bin_tolerance != 'default' else {} + # Compute the binned spike train matrices, along both time axes self.spiketrains_binned_i = conv.BinnedSpikeTrain( self.spiketrains_i, bin_size=self.bin_size, - t_start=self.t_start_i, t_stop=self.t_stop_i) + t_start=self.t_start_i, t_stop=self.t_stop_i, **tolerance_param) self.spiketrains_binned_j = conv.BinnedSpikeTrain( self.spiketrains_j, bin_size=self.bin_size, - t_start=self.t_start_j, t_stop=self.t_stop_j) + t_start=self.t_start_j, t_stop=self.t_stop_j, **tolerance_param) @property def x_edges(self): @@ -2680,8 +2699,8 @@ def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors, file_dir = file_path.parent file_name = file_path.stem mapped_array_file = tempfile.NamedTemporaryFile( - prefix=file_name, dir=file_dir, - delete=not keep_file) + prefix=file_name, dir=file_dir, + delete=not keep_file) # Compute the matrix D[i, j] of euclidean distances between pixels i # and j diff --git a/elephant/test/test_asset.py b/elephant/test/test_asset.py index 432d166b0..22a9ab0db 100644 --- a/elephant/test/test_asset.py +++ b/elephant/test/test_asset.py @@ -46,6 +46,77 @@ HAVE_CUDA = False +class AssetBinningTestCase(unittest.TestCase): + + def setUp(self): + spiketrain_1 = neo.SpikeTrain( + [1.3, 2.1, 3.9999999999, 4.9999], units=pq.ms, t_stop=6*pq.ms) + + spiketrain_2 = neo.SpikeTrain( + [0.9999999999, 1.9999, 4, 5], units=pq.ms, t_stop=6*pq.ms) + + self.spiketrains_i = [spiketrain_1, spiketrain_2] + self.spiketrains_j = [spiketrain_2, spiketrain_1] + + def test_bin_tolerance_default(self): + asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1*pq.ms) + bins_i = asset_obj.spiketrains_binned_i.to_array() + bins_j = asset_obj.spiketrains_binned_j.to_array() + + # Should shift spikes closer than 1e-8 to the right bin edge. + # This is the current default tolerance for `BinnedSpikeTrain`. + expected_bins_i = np.array( + [[0, 1, 1, 0, 2, 0], + [0, 2, 0, 0, 1, 1]]) + expected_bins_j = np.array( + [[0, 2, 0, 0, 1, 1], + [0, 1, 1, 0, 2, 0]]) + + self.assertTrue(np.array_equal(bins_i, expected_bins_i)) + self.assertTrue(np.array_equal(bins_j, expected_bins_j)) + + def test_bin_tolerance_none(self): + asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1*pq.ms, + bin_tolerance=None) + bins_i = asset_obj.spiketrains_binned_i.to_array() + bins_j = asset_obj.spiketrains_binned_j.to_array() + + # Should not shift any spikes. Bin should be the same as the integer + # part of the time. + expected_bins_i = np.array( + [[0, 1, 1, 1, 1, 0], + [1, 1, 0, 0, 1, 1]]) + expected_bins_j = np.array( + [[1, 1, 0, 0, 1, 1], + [0, 1, 1, 1, 1, 0]]) + + self.assertTrue(np.array_equal(bins_i, expected_bins_i)) + self.assertTrue(np.array_equal(bins_j, expected_bins_j)) + + def test_bin_tolerance_float(self): + asset_obj = asset.ASSET(spiketrains_i=self.spiketrains_i, + spiketrains_j=self.spiketrains_j, + bin_size=1*pq.ms, + bin_tolerance=1e-3) + bins_i = asset_obj.spiketrains_binned_i.to_array() + bins_j = asset_obj.spiketrains_binned_j.to_array() + + # Should shift spikes closer than 1e-3 to the right bin edge. + expected_bins_i = np.array( + [[0, 1, 1, 0, 1, 1], + [0, 1, 1, 0, 1, 1]]) + expected_bins_j = np.array( + [[0, 1, 1, 0, 1, 1], + [0, 1, 1, 0, 1, 1]]) + + self.assertTrue(np.array_equal(bins_i, expected_bins_i)) + self.assertTrue(np.array_equal(bins_j, expected_bins_j)) + + @unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn') class AssetTestCase(unittest.TestCase):