Lowering precision #707

Open · wants to merge 2 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -18,6 +18,7 @@ test/models/__pycache__/*
test/network/__pycache__/*
test/analysis/__pycache__/*
*.pyc
**/*.pyc
dist/*
logs/*
.pytest_cache/*
Binary file removed bindsnet/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/__pycache__/utils.cpython-310.pyc
Binary file removed bindsnet/analysis/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/datasets/__pycache__/davis.cpython-310.pyc
Binary file removed bindsnet/encoding/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/learning/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/learning/__pycache__/reward.cpython-310.pyc
Binary file removed bindsnet/models/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/models/__pycache__/models.cpython-310.pyc
Binary files removed: additional __pycache__ .pyc files (file names not shown in this view)
57 changes: 44 additions & 13 deletions bindsnet/models/models.py
@@ -6,9 +6,11 @@
from torch.nn.modules.utils import _pair

from bindsnet.learning import PostPre
from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
from bindsnet.network import Network
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
from bindsnet.network.topology import Connection, LocalConnection
from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
from bindsnet.network.topology_features import Weight


class TwoLayerNetwork(Network):
Expand Down Expand Up @@ -94,6 +96,7 @@ class DiehlAndCook2015(Network):
def __init__(
self,
n_inpt: int,
device: str = "cpu",
n_neurons: int = 100,
exc: float = 22.5,
inh: float = 17.5,
@@ -102,6 +105,7 @@ def __init__(
reduction: Optional[callable] = None,
wmin: float = 0.0,
wmax: float = 1.0,
w_dtype: torch.dtype = torch.float32,
norm: float = 78.4,
theta_plus: float = 0.05,
tc_theta_decay: float = 1e7,
@@ -124,6 +128,7 @@ def __init__(
dimension.
:param wmin: Minimum allowed weight on input to excitatory synapses.
:param wmax: Maximum allowed weight on input to excitatory synapses.
:param w_dtype: Data type for the :code:`w` tensor.
:param norm: Input to excitatory layer connection weights normalization
constant.
:param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
@@ -170,27 +175,53 @@ def __init__(

# Connections
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
input_exc_conn = Connection(
input_exc_conn = MulticompartmentConnection(
source=input_layer,
target=exc_layer,
w=w,
update_rule=PostPre,
nu=nu,
reduction=reduction,
wmin=wmin,
wmax=wmax,
norm=norm,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[wmin, wmax],
norm=norm,
reduction=reduction,
nu=nu,
learning_rule=MMCPostPre
)
]
)
w = self.exc * torch.diag(torch.ones(self.n_neurons))
exc_inh_conn = Connection(
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
exc_inh_conn = MulticompartmentConnection(
source=exc_layer,
target=inh_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[0, self.exc]
)
]
)
w = -self.inh * (
torch.ones(self.n_neurons, self.n_neurons)
- torch.diag(torch.ones(self.n_neurons))
)
inh_exc_conn = Connection(
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
inh_exc_conn = MulticompartmentConnection(
source=inh_layer,
target=exc_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
value_dtype=w_dtype,
range=[-self.inh, 0]
)
]
)

# Add to network
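For reference, a minimal sketch of how the lowered-precision model path could be exercised after this change. It assumes the new `device` and `w_dtype` keyword arguments of `DiehlAndCook2015` shown in the diff above; the input size and neuron count are illustrative, and the dataset/training loop is omitted.

```python
import torch

from bindsnet.models import DiehlAndCook2015

# Build the Diehl & Cook network with half-precision synaptic weights.
# `device` and `w_dtype` are the parameters introduced in this PR.
network = DiehlAndCook2015(
    n_inpt=28 * 28,         # e.g. a flattened MNIST input
    n_neurons=100,
    device="cpu",
    w_dtype=torch.float16,  # store connection weights at reduced precision
)
```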
Binary file removed bindsnet/network/__pycache__/__init__.cpython-310.pyc
Binary file removed bindsnet/network/__pycache__/network.cpython-310.pyc
Binary file removed bindsnet/network/__pycache__/nodes.cpython-310.pyc
Binary files removed: additional __pycache__ .pyc files (file names not shown in this view)
46 changes: 44 additions & 2 deletions bindsnet/network/topology.py
@@ -141,6 +141,14 @@ def reset_state_variables(self) -> None:
Contains resetting logic for the connection.
"""

@staticmethod
def cast_dtype_if_needed(w, w_dtype):
if w.dtype != w_dtype:
warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
return w.to(dtype=w_dtype)
else:
return w


class AbstractMulticompartmentConnection(ABC, Module):
# language=rst
@@ -261,6 +269,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -275,6 +284,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.

Keyword arguments:

@@ -296,9 +306,11 @@ def __init__(
w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
else:
w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
w = w.to(dtype=w_dtype)
else:
if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)

@@ -525,6 +537,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -543,6 +556,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.

Keyword arguments:

@@ -595,9 +609,11 @@ def __init__(
self.out_channels, self.in_channels, self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -667,6 +683,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -685,6 +702,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.

Keyword arguments:

@@ -750,9 +768,11 @@ def __init__(
self.out_channels, self.in_channels, *self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -824,6 +844,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -842,6 +863,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.

Keyword arguments:

@@ -926,9 +948,11 @@ def __init__(
self.out_channels, self.in_channels, *self.kernel_size
)
w += self.wmin
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
@@ -1276,6 +1300,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
@@ -1299,6 +1324,7 @@ def __init__(
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.

Keyword arguments:

@@ -1378,10 +1404,11 @@ def __init__(
w = torch.clamp(w, self.wmin, self.wmax)
else:
w = self.wmin + w * (self.wmax - self.wmin)

w = w.to(dtype=w_dtype)
else:
if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)
w = self.cast_dtype_if_needed(w, w_dtype)

self.w = Parameter(w, requires_grad=False)

@@ -1456,6 +1483,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1474,6 +1502,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1507,12 +1536,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_size, self.kernel_size
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_size,
self.kernel_size,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1588,6 +1619,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1606,6 +1638,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1649,12 +1682,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_prod,
self.kernel_prod,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1731,6 +1766,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
"""
@@ -1749,6 +1785,7 @@ def __init__(
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to some rule.
:param torch.Tensor w: Strengths of synapses.
@@ -1794,12 +1831,14 @@ def __init__(
w = torch.rand(
self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
)
w = w.to(dtype=w_dtype)
else:
assert w.shape == (
self.in_channels,
self.out_channels * self.conv_prod,
self.kernel_prod,
), error
w = self.cast_dtype_if_needed(w, w_dtype)

if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)
@@ -1875,6 +1914,7 @@ def __init__(
target: Nodes,
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
weight_decay: float = 0.0,
w_dtype: torch.dtype = torch.float32,
**kwargs,
) -> None:
# language=rst
Expand All @@ -1886,6 +1926,7 @@ def __init__(
accepts a pair of tensors to individualize learning rates of each neuron.
In this case, their shape should be the same size as the connection weights.
:param weight_decay: Constant multiple to decay weights by on each iteration.
:param w_dtype: Data type for the :code:`w` tensor.
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
@@ -1904,10 +1945,11 @@ def __init__(
w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
else:
w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
w = w.to(dtype=w_dtype)
else:
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)

w = self.cast_dtype_if_needed(w, w_dtype)
self.w = Parameter(w, requires_grad=False)

def compute(self, s: torch.Tensor) -> torch.Tensor:
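As an illustration of the `w_dtype` plumbing added throughout `topology.py` (a sketch under the assumption that the signatures land as shown above, not library documentation): a user-supplied `w` whose dtype disagrees with `w_dtype` should trigger the `cast_dtype_if_needed` warning and be cast, while randomly initialized weights are created directly in the requested dtype. The node sizes below are arbitrary.

```python
import torch

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

source = Input(n=100)
target = LIFNodes(n=50)

# Randomly initialized weights are created in float16 straight away.
conn = Connection(source=source, target=target, w_dtype=torch.float16)

# A float32 tensor combined with w_dtype=torch.float16 is expected to emit a
# dtype-mismatch warning and be cast by cast_dtype_if_needed.
w = 0.3 * torch.rand(source.n, target.n)  # defaults to torch.float32
conn_cast = Connection(source=source, target=target, w=w, w_dtype=torch.float16)

print(conn.w.dtype, conn_cast.w.dtype)  # expected: torch.float16 torch.float16
```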