diff --git a/.gitignore b/.gitignore index 197e3bb3..71aadef8 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ test/models/__pycache__/* test/network/__pycache__/* test/analysis/__pycache__/* *.pyc +**/*.pyc dist/* logs/* .pytest_cache/* diff --git a/bindsnet/__pycache__/__init__.cpython-310.pyc b/bindsnet/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b69fecbe..00000000 Binary files a/bindsnet/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/__pycache__/utils.cpython-310.pyc b/bindsnet/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 29d33ba0..00000000 Binary files a/bindsnet/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/analysis/__pycache__/__init__.cpython-310.pyc b/bindsnet/analysis/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 0a28eb47..00000000 Binary files a/bindsnet/analysis/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/analysis/__pycache__/pipeline_analysis.cpython-310.pyc b/bindsnet/analysis/__pycache__/pipeline_analysis.cpython-310.pyc deleted file mode 100644 index 8482a745..00000000 Binary files a/bindsnet/analysis/__pycache__/pipeline_analysis.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/analysis/__pycache__/plotting.cpython-310.pyc b/bindsnet/analysis/__pycache__/plotting.cpython-310.pyc deleted file mode 100644 index 916d80c4..00000000 Binary files a/bindsnet/analysis/__pycache__/plotting.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/analysis/__pycache__/visualization.cpython-310.pyc b/bindsnet/analysis/__pycache__/visualization.cpython-310.pyc deleted file mode 100644 index 2f817c76..00000000 Binary files a/bindsnet/analysis/__pycache__/visualization.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/conversion/__pycache__/__init__.cpython-310.pyc b/bindsnet/conversion/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index feacf5b7..00000000 Binary files a/bindsnet/conversion/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/conversion/__pycache__/conversion.cpython-310.pyc b/bindsnet/conversion/__pycache__/conversion.cpython-310.pyc deleted file mode 100644 index aa962d24..00000000 Binary files a/bindsnet/conversion/__pycache__/conversion.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/conversion/__pycache__/nodes.cpython-310.pyc b/bindsnet/conversion/__pycache__/nodes.cpython-310.pyc deleted file mode 100644 index e9ce3a49..00000000 Binary files a/bindsnet/conversion/__pycache__/nodes.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/conversion/__pycache__/topology.cpython-310.pyc b/bindsnet/conversion/__pycache__/topology.cpython-310.pyc deleted file mode 100644 index 179ac233..00000000 Binary files a/bindsnet/conversion/__pycache__/topology.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/__init__.cpython-310.pyc b/bindsnet/datasets/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b51f6946..00000000 Binary files a/bindsnet/datasets/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/alov300.cpython-310.pyc b/bindsnet/datasets/__pycache__/alov300.cpython-310.pyc deleted file mode 100644 index 0b984934..00000000 Binary files a/bindsnet/datasets/__pycache__/alov300.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/collate.cpython-310.pyc b/bindsnet/datasets/__pycache__/collate.cpython-310.pyc deleted file mode 100644 index 856bb942..00000000 Binary files a/bindsnet/datasets/__pycache__/collate.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/dataloader.cpython-310.pyc b/bindsnet/datasets/__pycache__/dataloader.cpython-310.pyc deleted file mode 100644 index bcad70c5..00000000 Binary files a/bindsnet/datasets/__pycache__/dataloader.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/davis.cpython-310.pyc b/bindsnet/datasets/__pycache__/davis.cpython-310.pyc deleted file mode 100644 index 60c0ee11..00000000 Binary files a/bindsnet/datasets/__pycache__/davis.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/preprocess.cpython-310.pyc b/bindsnet/datasets/__pycache__/preprocess.cpython-310.pyc deleted file mode 100644 index 432de202..00000000 Binary files a/bindsnet/datasets/__pycache__/preprocess.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/spoken_mnist.cpython-310.pyc b/bindsnet/datasets/__pycache__/spoken_mnist.cpython-310.pyc deleted file mode 100644 index 274a50d7..00000000 Binary files a/bindsnet/datasets/__pycache__/spoken_mnist.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-310.pyc b/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-310.pyc deleted file mode 100644 index 4d03c625..00000000 Binary files a/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/encoding/__pycache__/__init__.cpython-310.pyc b/bindsnet/encoding/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 27ae69a3..00000000 Binary files a/bindsnet/encoding/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/encoding/__pycache__/encoders.cpython-310.pyc b/bindsnet/encoding/__pycache__/encoders.cpython-310.pyc deleted file mode 100644 index 0b74cf42..00000000 Binary files a/bindsnet/encoding/__pycache__/encoders.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/encoding/__pycache__/encodings.cpython-310.pyc b/bindsnet/encoding/__pycache__/encodings.cpython-310.pyc deleted file mode 100644 index f7082471..00000000 Binary files a/bindsnet/encoding/__pycache__/encodings.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/encoding/__pycache__/loaders.cpython-310.pyc b/bindsnet/encoding/__pycache__/loaders.cpython-310.pyc deleted file mode 100644 index 5915414a..00000000 Binary files a/bindsnet/encoding/__pycache__/loaders.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/environment/__pycache__/__init__.cpython-310.pyc b/bindsnet/environment/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 88a73cf6..00000000 Binary files a/bindsnet/environment/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/environment/__pycache__/environment.cpython-310.pyc b/bindsnet/environment/__pycache__/environment.cpython-310.pyc deleted file mode 100644 index da3dc8b8..00000000 Binary files a/bindsnet/environment/__pycache__/environment.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/evaluation/__pycache__/__init__.cpython-310.pyc b/bindsnet/evaluation/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index f5cb7f23..00000000 Binary files a/bindsnet/evaluation/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/evaluation/__pycache__/evaluation.cpython-310.pyc b/bindsnet/evaluation/__pycache__/evaluation.cpython-310.pyc deleted file mode 100644 index 2ccc3832..00000000 Binary files a/bindsnet/evaluation/__pycache__/evaluation.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/learning/__pycache__/__init__.cpython-310.pyc b/bindsnet/learning/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 6289cc4c..00000000 Binary files a/bindsnet/learning/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/learning/__pycache__/learning.cpython-310.pyc b/bindsnet/learning/__pycache__/learning.cpython-310.pyc deleted file mode 100644 index 099de4fe..00000000 Binary files a/bindsnet/learning/__pycache__/learning.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/learning/__pycache__/reward.cpython-310.pyc b/bindsnet/learning/__pycache__/reward.cpython-310.pyc deleted file mode 100644 index 06358553..00000000 Binary files a/bindsnet/learning/__pycache__/reward.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/models/__pycache__/__init__.cpython-310.pyc b/bindsnet/models/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 4b00cfdf..00000000 Binary files a/bindsnet/models/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/models/__pycache__/models.cpython-310.pyc b/bindsnet/models/__pycache__/models.cpython-310.pyc deleted file mode 100644 index a5c96c22..00000000 Binary files a/bindsnet/models/__pycache__/models.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index 8ae3f136..96f13bba 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -6,9 +6,11 @@ from torch.nn.modules.utils import _pair from bindsnet.learning import PostPre +from bindsnet.learning.MCC_learning import PostPre as MMCPostPre from bindsnet.network import Network from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes -from bindsnet.network.topology import Connection, LocalConnection +from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection +from bindsnet.network.topology_features import Weight class TwoLayerNetwork(Network): @@ -94,6 +96,7 @@ class DiehlAndCook2015(Network): def __init__( self, n_inpt: int, + device: str = "cpu", n_neurons: int = 100, exc: float = 22.5, inh: float = 17.5, @@ -102,6 +105,7 @@ def __init__( reduction: Optional[callable] = None, wmin: float = 0.0, wmax: float = 1.0, + w_dtype: torch.dtype = torch.float32, norm: float = 78.4, theta_plus: float = 0.05, tc_theta_decay: float = 1e7, @@ -124,6 +128,7 @@ def __init__( dimension. :param wmin: Minimum allowed weight on input to excitatory synapses. :param wmax: Maximum allowed weight on input to excitatory synapses. + :param w_dtype: Data type for :code:`w` tensor :param norm: Input to excitatory layer connection weights normalization constant. :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane @@ -170,27 +175,53 @@ def __init__( # Connections w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) - input_exc_conn = Connection( + input_exc_conn = MulticompartmentConnection( source=input_layer, target=exc_layer, - w=w, - update_rule=PostPre, - nu=nu, - reduction=reduction, - wmin=wmin, - wmax=wmax, - norm=norm, + device=device, + pipeline=[ + Weight( + 'weight', + w, + value_dtype=w_dtype, + range=[wmin, wmax], + norm=norm, + reduction=reduction, + nu=nu, + learning_rule=MMCPostPre + ) + ] ) w = self.exc * torch.diag(torch.ones(self.n_neurons)) - exc_inh_conn = Connection( - source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc + exc_inh_conn = MulticompartmentConnection( + source=exc_layer, + target=inh_layer, + device=device, + pipeline=[ + Weight( + 'weight', + w, + value_dtype=w_dtype, + range=[0, self.exc] + ) + ] ) w = -self.inh * ( torch.ones(self.n_neurons, self.n_neurons) - torch.diag(torch.ones(self.n_neurons)) ) - inh_exc_conn = Connection( - source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 + inh_exc_conn = MulticompartmentConnection( + source=inh_layer, + target=exc_layer, + device=device, + pipeline=[ + Weight( + 'weight', + w, + value_dtype=w_dtype, + range=[-self.inh, 0] + ) + ] ) # Add to network diff --git a/bindsnet/network/__pycache__/__init__.cpython-310.pyc b/bindsnet/network/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index d24bf0a9..00000000 Binary files a/bindsnet/network/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/network/__pycache__/monitors.cpython-310.pyc b/bindsnet/network/__pycache__/monitors.cpython-310.pyc deleted file mode 100644 index d0bec1c2..00000000 Binary files a/bindsnet/network/__pycache__/monitors.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/network/__pycache__/network.cpython-310.pyc b/bindsnet/network/__pycache__/network.cpython-310.pyc deleted file mode 100644 index 8cd220a6..00000000 Binary files a/bindsnet/network/__pycache__/network.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/network/__pycache__/nodes.cpython-310.pyc b/bindsnet/network/__pycache__/nodes.cpython-310.pyc deleted file mode 100644 index fa8f1b06..00000000 Binary files a/bindsnet/network/__pycache__/nodes.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/network/__pycache__/topology.cpython-310.pyc b/bindsnet/network/__pycache__/topology.cpython-310.pyc deleted file mode 100644 index 1d55aea8..00000000 Binary files a/bindsnet/network/__pycache__/topology.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index cb5fafa1..5c6deedb 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -141,6 +141,14 @@ def reset_state_variables(self) -> None: Contains resetting logic for the connection. """ + @staticmethod + def cast_dtype_if_needed(w, w_dtype): + if w.dtype != w_dtype: + warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}") + return w.to(dtype=w_dtype) + else: + return w + class AbstractMulticompartmentConnection(ABC, Module): # language=rst @@ -261,6 +269,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -275,6 +284,7 @@ def __init__( :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: @@ -296,9 +306,11 @@ def __init__( w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax) else: w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) + w = w.to(dtype=w_dtype) else: if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any(): w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax) + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) @@ -525,6 +537,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -543,6 +556,7 @@ def __init__( :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: @@ -595,9 +609,11 @@ def __init__( self.out_channels, self.in_channels, self.kernel_size ) w += self.wmin + w = w.to(dtype=w_dtype) else: if (self.wmin == -inf).any() or (self.wmax == inf).any(): w = torch.clamp(w, self.wmin, self.wmax) + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) self.b = Parameter( @@ -667,6 +683,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -685,6 +702,7 @@ def __init__( :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: @@ -750,9 +768,11 @@ def __init__( self.out_channels, self.in_channels, *self.kernel_size ) w += self.wmin + w = w.to(dtype=w_dtype) else: if (self.wmin == -inf).any() or (self.wmax == inf).any(): w = torch.clamp(w, self.wmin, self.wmax) + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) self.b = Parameter( @@ -824,6 +844,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -842,6 +863,7 @@ def __init__( :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: @@ -926,9 +948,11 @@ def __init__( self.out_channels, self.in_channels, *self.kernel_size ) w += self.wmin + w = w.to(dtype=w_dtype) else: if (self.wmin == -inf).any() or (self.wmax == inf).any(): w = torch.clamp(w, self.wmin, self.wmax) + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) self.b = Parameter( @@ -1276,6 +1300,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -1299,6 +1324,7 @@ def __init__( :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: @@ -1378,10 +1404,11 @@ def __init__( w = torch.clamp(w, self.wmin, self.wmax) else: w = self.wmin + w * (self.wmax - self.wmin) - + w = w.to(dtype=w_dtype) else: if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any(): w = torch.clamp(w, self.wmin, self.wmax) + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) @@ -1456,6 +1483,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: """ @@ -1474,6 +1502,7 @@ def __init__( In this case, their shape should be the same size as the connection weights. :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. :param torch.Tensor w: Strengths of synapses. @@ -1507,12 +1536,14 @@ def __init__( w = torch.rand( self.in_channels, self.n_filters * self.conv_size, self.kernel_size ) + w = w.to(dtype=w_dtype) else: assert w.shape == ( self.in_channels, self.out_channels * self.conv_size, self.kernel_size, ), error + w = self.cast_dtype_if_needed(w, w_dtype) if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1588,6 +1619,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: """ @@ -1606,6 +1638,7 @@ def __init__( In this case, their shape should be the same size as the connection weights. :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. :param torch.Tensor w: Strengths of synapses. @@ -1649,12 +1682,14 @@ def __init__( w = torch.rand( self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod ) + w = w.to(dtype=w_dtype) else: assert w.shape == ( self.in_channels, self.out_channels * self.conv_prod, self.kernel_prod, ), error + w = self.cast_dtype_if_needed(w, w_dtype) if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1731,6 +1766,7 @@ def __init__( nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: """ @@ -1749,6 +1785,7 @@ def __init__( In this case, their shape should be the same size as the connection weights. :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. :param torch.Tensor w: Strengths of synapses. @@ -1794,12 +1831,14 @@ def __init__( w = torch.rand( self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod ) + w = w.to(dtype=w_dtype) else: assert w.shape == ( self.in_channels, self.out_channels * self.conv_prod, self.kernel_prod, ), error + w = self.cast_dtype_if_needed(w, w_dtype) if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1875,6 +1914,7 @@ def __init__( target: Nodes, nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, weight_decay: float = 0.0, + w_dtype: torch.dtype = torch.float32, **kwargs, ) -> None: # language=rst @@ -1886,6 +1926,7 @@ def __init__( accepts a pair of tensors to individualize learning rates of each neuron. In this case, their shape should be the same size as the connection weights. :param weight_decay: Constant multiple to decay weights by on each iteration. + :param w_dtype: Data type for :code:`w` tensor Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. @@ -1904,10 +1945,11 @@ def __init__( w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax) else: w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin) + w = w.to(dtype=w_dtype) else: if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): w = torch.clamp(w, self.wmin, self.wmax) - + w = self.cast_dtype_if_needed(w, w_dtype) self.w = Parameter(w, requires_grad=False) def compute(self, s: torch.Tensor) -> torch.Tensor: diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py index f99cf39f..c7484a0e 100644 --- a/bindsnet/network/topology_features.py +++ b/bindsnet/network/topology_features.py @@ -4,6 +4,7 @@ import numpy as np import torch +import warnings from torch import device from torch.nn import Parameter import torch.nn.functional as F @@ -22,6 +23,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, range: Optional[Union[list, tuple]] = None, clamp_frequency: Optional[int] = 1, norm: Optional[Union[torch.Tensor, float, int]] = None, @@ -38,6 +40,7 @@ def __init__( Instantiates a :code:`Feature` object. Will assign all incoming arguments as class variables :param name: Name of the feature :param value: Core numeric object for the feature. This parameters function will vary depending on the feature + :param value_dtype: Data type for :code:`value` tensor :param range: Range of acceptable values for the :code:`value` parameter :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) @@ -119,6 +122,15 @@ def __init__( self.assert_valid_range() if value is not None: self.assert_feature_in_range() + self.value = self.cast_dtype_if_needed(self.value, value_dtype) + + @staticmethod + def cast_dtype_if_needed(value, value_dtype): + if value.dtype != value_dtype: + warnings.warn(f"Provided value has data type {value.dtype} but parameter w_dtype is {value_dtype}") + return value.to(dtype=value_dtype) + else: + return value @abstractmethod def reset_state_variables(self) -> None: @@ -312,6 +324,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, learning_rule: Optional[bindsnet.learning.LearningRule] = None, @@ -327,6 +340,7 @@ def __init__( :param value: Number(s) in [0, 1] which represent the probability of a signal traversing a synapse. Tensor values assume that probabilities will be matched to adjacent synapses in the connection. Scalars will be applied to all synapses. + :param value_dtype: Data type for :code:`value` tensor :param range: Range of acceptable values for the :code:`value` parameter. Should be in [0, 1] :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) @@ -342,6 +356,7 @@ def __init__( super().__init__( name=name, value=value, + value_dtype=value_dtype, range=[0, 1] if range is None else range, norm=norm, learning_rule=learning_rule, @@ -419,6 +434,7 @@ def __init__( super().__init__( name=name, value=value, + value_dtype=torch.bool ) self.name = name @@ -497,6 +513,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, norm_frequency: Optional[str] = "sample", @@ -511,6 +528,7 @@ def __init__( Multiplies signals by scalars :param name: Name of the feature :param value: Values to scale signals by + :param value_dtype: Data type for :code:`value` tensor :param range: Range of acceptable values for the :code:`value` parameter :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) @@ -530,6 +548,7 @@ def __init__( super().__init__( name=name, value=value, + value_dtype=value_dtype, range=[-torch.inf, +torch.inf] if range is None else range, norm=norm, learning_rule=learning_rule, @@ -587,6 +606,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, ) -> None: @@ -595,6 +615,7 @@ def __init__( Adds scalars to signals :param name: Name of the feature :param value: Values to add to the signals + :param value_dtype: Data type for :code:`value` tensor :param range: Range of acceptable values for the :code:`value` parameter :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) @@ -603,6 +624,7 @@ def __init__( super().__init__( name=name, value=value, + value_dtype=value_dtype, range=[-torch.inf, +torch.inf] if range is None else range, norm=norm, ) @@ -628,6 +650,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, range: Optional[Sequence[float]] = None, ) -> None: # language=rst @@ -635,9 +658,10 @@ def __init__( Adds scalars to signals :param name: Name of the feature :param value: Values to scale signals by + :param value_dtype: Data type for :code:`value` tensor """ - super().__init__(name=name, value=value, range=range) + super().__init__(name=name, value=value, value_dtype=value_dtype, range=range) def reset_state_variables(self) -> None: pass @@ -664,6 +688,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, degrade_function: callable = None, parent_feature: Optional[AbstractFeature] = None, ) -> None: @@ -673,13 +698,14 @@ def __init__( Note: If :code:`parent_feature` is provided, it will override :code:`value`. :param name: Name of the feature :param value: Value used to degrade feature + :param value_dtype: Data type for :code:`value` tensor :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or constant to be *subtracted* from the propagating spikes. :param parent_feature: Parent feature with desired :code:`value` to inherit """ # Note: parent_feature will override value. See abstract constructor - super().__init__(name=name, value=value, parent_feature=parent_feature) + super().__init__(name=name, value=value, value_dtype=value_dtype, parent_feature=parent_feature) self.degrade_function = degrade_function @@ -695,6 +721,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.001, @@ -710,6 +737,9 @@ def __init__( :param const_decay: The spontaneous activation of the synapses. """ + self.value_dtype = value_dtype + value = value.to(self.value_dtype) + # Define the ANN class ANN(nn.Module): def __init__(self, input_size, hidden_size, output_size): @@ -743,7 +773,7 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value) + super().__init__(name=name, value=value, value_dtype=self.value_dtype) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: @@ -758,7 +788,7 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # Update the masks if self.counter % self.spike_buffer.shape[1] == 0: with torch.no_grad(): - ann_decision = self.ann(self.spike_buffer.to(torch.float32)) + ann_decision = self.ann(self.spike_buffer.to(self.value_dtype)) self.mask += ( ann_decision.view(self.mask.shape) * self.const_update_rate ) # update mask with learning rate fraction @@ -766,7 +796,7 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: self.mask = torch.clamp(self.mask, -1, 1) # cap the mask # self.mask = torch.clamp(self.mask, -1, 1) - self.value = (self.mask > 0).float() + self.value = (self.mask > 0).to(self.value_dtype) return conn_spikes * self.value @@ -785,6 +815,7 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + value_dtype: torch.dtype = torch.float32, ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.01, @@ -796,9 +827,12 @@ def __init__( :param name: Name of the feature :param ann_values: Values to be use to build an ANN that will adapt the connectivity of the layer. :param value: Values to be use to build an initial mask for the synapses. + :param value_dtype: Data type for :code:`value` tensor :param const_update_rate: The mask upatate rate of the ANN decision. :param const_decay: The spontaneous activation of the synapses. """ + self.value_dtype = value_dtype + value = value.to(self.value_dtype) # Define the ANN class ANN(nn.Module): @@ -833,7 +867,7 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value) + super().__init__(name=name, value=value, value_dtype=self.value_dtype) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: @@ -848,7 +882,7 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # Update the masks if self.counter % self.spike_buffer.shape[1] == 0: with torch.no_grad(): - ann_decision = self.ann(self.spike_buffer.to(torch.float32)) + ann_decision = self.ann(self.spike_buffer.to(self.value_dtype)) self.mask += ( ann_decision.view(self.mask.shape) * self.const_update_rate ) # update mask with learning rate fraction @@ -856,7 +890,7 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: self.mask = torch.clamp(self.mask, -1, 1) # cap the mask # self.mask = torch.clamp(self.mask, -1, 1) - self.value = (self.mask > 0).float() + self.value = (self.mask > 0).to(self.value_dtype) return conn_spikes * self.value diff --git a/bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc b/bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 1316877f..00000000 Binary files a/bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/pipeline/__pycache__/action.cpython-310.pyc b/bindsnet/pipeline/__pycache__/action.cpython-310.pyc deleted file mode 100644 index bc066cb0..00000000 Binary files a/bindsnet/pipeline/__pycache__/action.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/pipeline/__pycache__/base_pipeline.cpython-310.pyc b/bindsnet/pipeline/__pycache__/base_pipeline.cpython-310.pyc deleted file mode 100644 index 4d1c2c36..00000000 Binary files a/bindsnet/pipeline/__pycache__/base_pipeline.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-310.pyc b/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-310.pyc deleted file mode 100644 index 29a9039a..00000000 Binary files a/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-310.pyc b/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-310.pyc deleted file mode 100644 index 88e638b0..00000000 Binary files a/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/preprocessing/__pycache__/__init__.cpython-310.pyc b/bindsnet/preprocessing/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index dd1e526d..00000000 Binary files a/bindsnet/preprocessing/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/bindsnet/preprocessing/__pycache__/preprocessing.cpython-310.pyc b/bindsnet/preprocessing/__pycache__/preprocessing.cpython-310.pyc deleted file mode 100644 index 86c739a8..00000000 Binary files a/bindsnet/preprocessing/__pycache__/preprocessing.cpython-310.pyc and /dev/null differ diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 8338af19..8e762339 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -40,6 +40,12 @@ parser.add_argument("--dt", type=int, default=1.0) parser.add_argument("--intensity", type=float, default=128) parser.add_argument("--progress_interval", type=int, default=10) +parser.add_argument( + "--w_dtype", + type=str, + default='float32', + help='Datatype to use for weights. Examples: float32, float16, bfloat16 etc' +) parser.add_argument("--train", dest="train", action="store_true") parser.add_argument("--test", dest="train", action="store_false") parser.add_argument("--plot", dest="plot", action="store_true") @@ -102,6 +108,8 @@ nu=(1e-4, 1e-2), theta_plus=theta_plus, inpt_shape=(1, 28, 28), + device=device, + w_dtype=getattr(torch, args.w_dtype) ) # Directs network to GPU @@ -271,7 +279,7 @@ image = batch["image"][:, 0].view(28, 28) inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28) lable = batch["label"][0] - input_exc_weights = network.connections[("X", "Ae")].w + input_exc_weights = network.connections[("X", "Ae")].feature_index['weight'].value square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28 ) diff --git a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-7.4.4.pyc b/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 7bed2960..00000000 Binary files a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.1.1.pyc b/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 7bed2960..00000000 Binary files a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.2.2.pyc b/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index b1a7536e..00000000 Binary files a/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-7.4.4.pyc b/test/conversion/__pycache__/test_conversion.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 608608f0..00000000 Binary files a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.1.1.pyc b/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 608608f0..00000000 Binary files a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.2.2.pyc b/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index e12e1822..00000000 Binary files a/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-7.4.4.pyc b/test/encoding/__pycache__/test_encoding.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 024727cb..00000000 Binary files a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.1.1.pyc b/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 024727cb..00000000 Binary files a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.2.2.pyc b/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index f03e5b6a..00000000 Binary files a/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/import/__pycache__/test_import.cpython-310-pytest-7.4.4.pyc b/test/import/__pycache__/test_import.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 067df56d..00000000 Binary files a/test/import/__pycache__/test_import.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/import/__pycache__/test_import.cpython-310-pytest-8.1.1.pyc b/test/import/__pycache__/test_import.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 067df56d..00000000 Binary files a/test/import/__pycache__/test_import.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/import/__pycache__/test_import.cpython-310-pytest-8.2.2.pyc b/test/import/__pycache__/test_import.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index 067df56d..00000000 Binary files a/test/import/__pycache__/test_import.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/models/__pycache__/test_models.cpython-310-pytest-7.4.4.pyc b/test/models/__pycache__/test_models.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 5ecf1176..00000000 Binary files a/test/models/__pycache__/test_models.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/models/__pycache__/test_models.cpython-310-pytest-8.1.1.pyc b/test/models/__pycache__/test_models.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 5ecf1176..00000000 Binary files a/test/models/__pycache__/test_models.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/models/__pycache__/test_models.cpython-310-pytest-8.2.2.pyc b/test/models/__pycache__/test_models.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index 8b9c2002..00000000 Binary files a/test/models/__pycache__/test_models.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_connections.cpython-310-pytest-7.4.4.pyc b/test/network/__pycache__/test_connections.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 2ed35151..00000000 Binary files a/test/network/__pycache__/test_connections.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_connections.cpython-310-pytest-8.1.1.pyc b/test/network/__pycache__/test_connections.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 2ed35151..00000000 Binary files a/test/network/__pycache__/test_connections.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_connections.cpython-310-pytest-8.2.2.pyc b/test/network/__pycache__/test_connections.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index dbdda889..00000000 Binary files a/test/network/__pycache__/test_connections.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_learning.cpython-310-pytest-7.4.4.pyc b/test/network/__pycache__/test_learning.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index c7c83f1d..00000000 Binary files a/test/network/__pycache__/test_learning.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_learning.cpython-310-pytest-8.1.1.pyc b/test/network/__pycache__/test_learning.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index c7c83f1d..00000000 Binary files a/test/network/__pycache__/test_learning.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_learning.cpython-310-pytest-8.2.2.pyc b/test/network/__pycache__/test_learning.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index acf60e50..00000000 Binary files a/test/network/__pycache__/test_learning.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_monitors.cpython-310-pytest-7.4.4.pyc b/test/network/__pycache__/test_monitors.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index c7d1c3fc..00000000 Binary files a/test/network/__pycache__/test_monitors.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_monitors.cpython-310-pytest-8.1.1.pyc b/test/network/__pycache__/test_monitors.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index c7d1c3fc..00000000 Binary files a/test/network/__pycache__/test_monitors.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_monitors.cpython-310-pytest-8.2.2.pyc b/test/network/__pycache__/test_monitors.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index 27f2e7ab..00000000 Binary files a/test/network/__pycache__/test_monitors.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_network.cpython-310-pytest-7.4.4.pyc b/test/network/__pycache__/test_network.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index e1e54af9..00000000 Binary files a/test/network/__pycache__/test_network.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_network.cpython-310-pytest-8.1.1.pyc b/test/network/__pycache__/test_network.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index e1e54af9..00000000 Binary files a/test/network/__pycache__/test_network.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_network.cpython-310-pytest-8.2.2.pyc b/test/network/__pycache__/test_network.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index e0082300..00000000 Binary files a/test/network/__pycache__/test_network.cpython-310-pytest-8.2.2.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_nodes.cpython-310-pytest-7.4.4.pyc b/test/network/__pycache__/test_nodes.cpython-310-pytest-7.4.4.pyc deleted file mode 100644 index 41cf3244..00000000 Binary files a/test/network/__pycache__/test_nodes.cpython-310-pytest-7.4.4.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_nodes.cpython-310-pytest-8.1.1.pyc b/test/network/__pycache__/test_nodes.cpython-310-pytest-8.1.1.pyc deleted file mode 100644 index 41cf3244..00000000 Binary files a/test/network/__pycache__/test_nodes.cpython-310-pytest-8.1.1.pyc and /dev/null differ diff --git a/test/network/__pycache__/test_nodes.cpython-310-pytest-8.2.2.pyc b/test/network/__pycache__/test_nodes.cpython-310-pytest-8.2.2.pyc deleted file mode 100644 index 47d90be7..00000000 Binary files a/test/network/__pycache__/test_nodes.cpython-310-pytest-8.2.2.pyc and /dev/null differ